#include <alloca.h>
#include <assert.h>
#include <fcntl.h>
#include <luna/Heap.h>
#include <os/ArgumentParser.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <time.h>
#include <unistd.h>

// One square of the Game of Life grid.
struct Cell
{
    // Whether the cell is alive in the generation currently displayed.
    bool state;
    // Scratch value written by next_generation() while the next generation is
    // being computed; copied into `state` once every cell has been evaluated.
    bool new_state;
};

// Grid dimensions in cells. These initial values are overwritten in
// luna_main() with the parsed "rows" / "columns" positional arguments.
static int g_num_rows = 76;
static int g_num_columns = 102;

// Framebuffer dimensions, as returned by the FB_GET_WIDTH / FB_GET_HEIGHT ioctls.
static int g_fb_width;
static int g_fb_height;

// File descriptor for /dev/fb0, opened in luna_main().
static int g_fd;

// Cell storage allocated by fill_cells(); indexed row-major by find_cell().
static Cell* g_cells;
// Start of the memory-mapped framebuffer (set by mmap() in luna_main()).
static char* g_fb;

// Allocates the grid (g_num_rows * g_num_columns cells) and gives every cell
// a random initial state drawn from rand(). Both `state` and `new_state`
// start out equal. An allocation failure is propagated to the caller via TRY.
static Result<void> fill_cells()
{
    g_cells = (Cell*)TRY(calloc_impl(g_num_rows, g_num_columns * sizeof(Cell), false));

    const isize cell_count = g_num_rows * g_num_columns;
    for (isize idx = 0; idx < cell_count; idx++)
    {
        // Lowest bit of rand(): an even result is a dead cell, an odd one is alive.
        const bool alive = (rand() % 2) != 0;
        Cell& cell = g_cells[idx];
        cell.new_state = alive;
        cell.state = alive;
    }

    return {};
}

// Returns a reference to the cell at (row, column) in the row-major grid.
// Both coordinates must lie inside the grid; this is enforced with asserts.
static Cell& find_cell(int row, int column)
{
    // Check the lower bounds too: a negative coordinate would otherwise pass
    // the upper-bound checks and index memory before the start of g_cells.
    assert(row >= 0 && row < g_num_rows);
    assert(column >= 0 && column < g_num_columns);
    return g_cells[row * g_num_columns + column];
}

// Bytes occupied by one framebuffer pixel; used for every framebuffer size and
// offset computation in this file.
// NOTE(review): presumably a 32-bit pixel format - confirm against the fb driver.
static constexpr int BYTES_PER_PIXEL = 4;

static void draw_cells()
{
    const int CELL_WIDTH = g_fb_width / g_num_columns;
    const int CELL_HEIGHT = g_fb_height / g_num_rows;

    for (int i = 0; i < g_num_rows; i++)
    {

        for (int j = 0; j < g_num_columns; j++)
        {
            char* buf = g_fb + (i * g_fb_width * CELL_HEIGHT * BYTES_PER_PIXEL);

            auto& cell = find_cell(i, j);
            u8 color = cell.state ? 0xff : 0x00;
            for (int k = 0; k < CELL_HEIGHT; k++)
            {
                memset(buf + (j * CELL_WIDTH * BYTES_PER_PIXEL), color, CELL_WIDTH * BYTES_PER_PIXEL);
                buf += g_fb_width * BYTES_PER_PIXEL;
            }
        }
    }

    msync(g_fb, g_fb_height * g_fb_width * BYTES_PER_PIXEL, MS_SYNC);
}

// Counts how many of the (up to eight) cells surrounding (row, column) are
// alive. Positions that would fall outside the grid are skipped, so cells on
// an edge or in a corner simply have fewer neighbors.
static int find_neighbors(int row, int column)
{
    // Which sides of the cell still lie inside the grid.
    const bool has_above = row > 0;
    const bool has_below = (row + 1) < g_num_rows;
    const bool has_left = column > 0;
    const bool has_right = (column + 1) < g_num_columns;

    int alive = 0;

    if (has_above)
    {
        if (has_left) alive += find_cell(row - 1, column - 1).state;
        alive += find_cell(row - 1, column).state;
        if (has_right) alive += find_cell(row - 1, column + 1).state;
    }

    if (has_left) alive += find_cell(row, column - 1).state;
    if (has_right) alive += find_cell(row, column + 1).state;

    if (has_below)
    {
        if (has_left) alive += find_cell(row + 1, column - 1).state;
        alive += find_cell(row + 1, column).state;
        if (has_right) alive += find_cell(row + 1, column + 1).state;
    }

    return alive;
}

static void next_generation()
{
    for (int i = 0; i < g_num_rows; i++)
    {
        for (int j = 0; j < g_num_columns; j++)
        {
            auto& cell = find_cell(i, j);
            int neighbors = find_neighbors(i, j);
            if (!cell.state && neighbors == 3) cell.new_state = true;
            else if (cell.state && (neighbors < 2 || neighbors > 3))
                cell.new_state = false;
        }
    }

    for (isize i = 0; i < (g_num_rows * g_num_columns); i++) g_cells[i].state = g_cells[i].new_state;
}

// Program entry point.
//
// Parses the command line, opens and memory-maps /dev/fb0, seeds a random
// grid and then shows `num_iterations` generations, sleeping
// `delay_between_iterations` milliseconds before each one and `delay_at_end`
// milliseconds after the last.
//
// Returns 0 on success and 1 when the framebuffer cannot be opened, queried or
// mapped, or when the requested grid does not fit the framebuffer. Argument
// parsing and allocation errors are propagated through TRY.
Result<int> luna_main(int argc, char** argv)
{
    u64 delay_between_iterations = 250;
    u64 delay_at_end = 3000;
    u64 num_iterations = 100;

    StringView columns;
    StringView rows;
    StringView delay;
    StringView end_delay;
    StringView iterations;
    StringView seed;

    os::ArgumentParser parser;
    parser.add_description("A framebuffer-based implementation for Conway's Game of Life.");
    parser.add_system_program_info("gol"_sv);
    parser.add_positional_argument(rows, "rows"_sv, "76"_sv);
    parser.add_positional_argument(columns, "columns"_sv, "102"_sv);
    parser.add_value_argument(delay, 'd', "delay"_sv, "the delay between generations (in ms)");
    parser.add_value_argument(end_delay, 'e', "end-delay"_sv,
                              "after finishing, how much to wait before returning to the shell (in ms)");
    parser.add_value_argument(iterations, 'i', "iterations"_sv, "how many generations to show (default: 100)");
    parser.add_value_argument(seed, 's', "seed"_sv, "the seed for the random number generator");
    parser.parse(argc, argv);

    // Keep the requested dimensions unsigned until they have been validated
    // against the framebuffer size; only then are they narrowed to int.
    const u64 requested_columns = TRY(columns.to_uint());
    const u64 requested_rows = TRY(rows.to_uint());
    if (!delay.is_empty()) delay_between_iterations = TRY(delay.to_uint());
    if (!end_delay.is_empty()) delay_at_end = TRY(end_delay.to_uint());
    if (!iterations.is_empty()) num_iterations = TRY(iterations.to_uint());
    if (!seed.is_empty()) srand(static_cast<unsigned>(TRY(seed.to_uint())));
    else
        srand(static_cast<unsigned>(time(nullptr)));

    g_fd = open("/dev/fb0", O_RDWR);
    if (g_fd < 0)
    {
        perror("gol: cannot open framebuffer for writing");
        return 1;
    }

    // Releases the cell storage and the framebuffer descriptor on every exit
    // path below, including early returns taken by TRY. g_cells is still null
    // until fill_cells() succeeds, and free(nullptr) is a no-op.
    struct Cleanup
    {
        ~Cleanup()
        {
            free(g_cells);
            g_cells = nullptr;
            close(g_fd);
        }
    } cleanup;

    g_fb_height = ioctl(g_fd, FB_GET_HEIGHT);
    g_fb_width = ioctl(g_fd, FB_GET_WIDTH);
    if (g_fb_height < 0 || g_fb_width < 0)
    {
        perror("gol: cannot query framebuffer dimensions");
        return 1;
    }

    // Every cell needs at least one pixel in each direction. This also rules
    // out a zero-sized grid (division by zero when computing the cell size)
    // and values that would not fit in an int.
    if (requested_rows == 0 || requested_rows > static_cast<u64>(g_fb_height) || requested_columns == 0 ||
        requested_columns > static_cast<u64>(g_fb_width))
    {
        fprintf(stderr, "gol: rows must be between 1 and %d, and columns between 1 and %d\n", g_fb_height,
                g_fb_width);
        return 1;
    }
    g_num_columns = static_cast<int>(requested_columns);
    g_num_rows = static_cast<int>(requested_rows);

    TRY(fill_cells());

    // Compute the mapping size in size_t so that it cannot overflow an int.
    const size_t fb_size = static_cast<size_t>(g_fb_height) * g_fb_width * BYTES_PER_PIXEL;

    g_fb = static_cast<char*>(mmap(nullptr, fb_size, PROT_READ | PROT_WRITE, MAP_SHARED, g_fd, 0));
    if (g_fb == MAP_FAILED)
    {
        perror("gol: cannot map framebuffer into memory");
        return 1;
    }

    draw_cells();

    while (num_iterations--)
    {
        usleep(delay_between_iterations * 1000);
        next_generation();
        draw_cells();
    }

    usleep(delay_at_end * 1000);

    munmap(g_fb, fb_size);

    return 0;
}