#include <bits/errno-return.h>
#include <fcntl.h>
#include <luna/Common.h>
#include <luna/Format.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/syscall.h>
#include <sys/wait.h>
#include <unistd.h>

// The standard streams. They start out null here; presumably the C runtime startup code creates
// them before main() runs (not visible in this file — confirm).
FILE* stdin = nullptr;
FILE* stderr = nullptr;
FILE* stdout = nullptr;

// Table of streams opened by this file (fopen(), fdopen(), freopen(), popen()), indexed by the
// stream's file descriptor. Null entries are unused; fflush(nullptr) walks every non-null entry.
FILE* s_open_files[FOPEN_MAX];

// Bit flags kept in FILE::_buf.status.
enum FileStatusFlags
{
    BufferIsMalloced = 0x1, // _buf.buffer was calloc()'d by setvbuf() and is ours to free.
    LastRead = 0x2,         // The buffer holds read-ahead data from the descriptor.
    LastWrite = 0x4,        // The buffer holds output not yet written to the descriptor.
};

// Returns the directory temporary files should be created in: the value of $TMPDIR when that
// variable is set, otherwise "/tmp".
static const char* read_tmpdir()
{
    const char* const from_env = getenv("TMPDIR");
    return from_env ? from_env : "/tmp";
}

// Translates an fopen()-style mode string into open(2) flags.
// The first character selects the base mode: 'r' (read), 'w' (write, create, truncate) or
// 'a' (write, create, append). A '+' anywhere in the string upgrades the access mode to
// read/write. Any other characters (such as 'b') are ignored.
// Returns the flags, or -1 with errno = EINVAL for a null mode or an unrecognized first character.
static int fopen_parse_mode(const char* mode)
{
    if (!mode)
    {
        errno = EINVAL;
        return -1;
    }

    int result = 0;
    switch (*mode)
    {
    case 'r': result |= O_RDONLY; break;
    case 'w': result |= (O_WRONLY | O_CREAT | O_TRUNC); break;
    case 'a': result |= (O_WRONLY | O_CREAT | O_APPEND); break;

    default: errno = EINVAL; return -1;
    }

    // Replace the access mode instead of OR-ing O_RDWR into it. OR-ing only works where
    // O_RDWR == (O_RDONLY | O_WRONLY); with other encodings it produces an invalid access mode.
    if (strchr(mode, '+')) result = (result & ~O_ACCMODE) | O_RDWR;

    return result;
}

// Checks that descriptor `fd` was opened with an access mode that allows the access requested by
// new_flags (open(2) flags as produced by fopen_parse_mode()). A read/write descriptor allows any
// request; otherwise the access modes must match exactly.
// Returns 0 if compatible, -1 with errno = EINVAL if not, or -1 with fcntl()'s errno when the
// descriptor's flags cannot be read (e.g. fd is not open).
static int fdopen_check_compatible_mode(int fd, int new_flags)
{
    int old_flags = fcntl(fd, F_GETFL);
    if (old_flags < 0) return -1;

    int old_mode = old_flags & O_ACCMODE;
    int new_mode = new_flags & O_ACCMODE;

    // Compare access modes as values, not as bit sets. Where O_RDONLY is 0, a bitwise subset test
    // would let a read request through on a write-only descriptor.
    if (old_mode != O_RDWR && old_mode != new_mode)
    {
        errno = EINVAL;
        return -1;
    }

    return 0;
}

// Writes a stream's pending output to its descriptor and empties the buffer.
// write() may accept fewer bytes than requested, so keep writing the remainder until everything
// has been accepted or write() fails. The buffer is emptied and LastWrite cleared even on failure,
// so a broken stream does not resend the same bytes on every later flush.
// Returns 0 on success (always, for unbuffered streams); EOF if write() failed or made no progress.
static int flush_write_buffer(FILE* stream)
{
    if (stream->_buf.mode == _IONBF) return 0;

    int result = 0;
    size_t done = 0;
    const size_t pending = (size_t)stream->_buf.size;

    while (done < pending)
    {
        ssize_t nwritten = write(stream->_fd, stream->_buf.buffer + done, pending - done);
        if (nwritten <= 0)
        {
            // Treat "no progress" like an error; retrying would spin forever.
            result = EOF;
            break;
        }
        done += (size_t)nwritten;
    }

    stream->_buf.index = 0;
    stream->_buf.size = 0;

    stream->_buf.status &= ~FileStatusFlags::LastWrite;

    return result;
}

// Discards whatever read-ahead data is left in a stream's buffer.
// The descriptor was advanced past those bytes when they were read into the buffer, so seek it
// back by the number of unconsumed bytes. Its offset then matches the position the program has
// actually read up to. The lseek() result is not checked. Always returns 0.
static int flush_read_buffer(FILE* stream)
{
    if (stream->_buf.mode != _IONBF)
    {
        ssize_t unconsumed = stream->_buf.size - stream->_buf.index;
        lseek(stream->_fd, -unconsumed, SEEK_CUR);

        stream->_buf.size = 0;
        stream->_buf.index = 0;
        stream->_buf.status &= ~FileStatusFlags::LastRead;
    }

    return 0;
}

// Hands `size` bytes from `data` to a stream.
// Buffered streams: pending read-ahead is discarded first. The buffer is written out whenever the
// incoming data would not fit, then as much data as fits is copied in. Line-buffered streams are
// also written out after any copied chunk that contains '\n'.
// Unbuffered streams: the data goes straight to write(), at most BUFSIZ bytes per call.
// Returns the number of bytes accepted, or a negative value on error.
static ssize_t write_into_buffer(FILE* stream, const u8* data, ssize_t size)
{
    ssize_t accepted = 0;

    for (; size > 0;)
    {
        ssize_t chunk;

        if (stream->_buf.mode == _IONBF) { chunk = write(stream->_fd, data, min(size, BUFSIZ)); }
        else
        {
            if (stream->_buf.status & FileStatusFlags::LastRead) flush_read_buffer(stream);

            if ((stream->_buf.size + size) > stream->_buf.capacity && flush_write_buffer(stream) < 0) return -1;

            ssize_t room = stream->_buf.capacity - stream->_buf.size;
            chunk = min(room, size);
            memcpy(stream->_buf.buffer + stream->_buf.size, data, chunk);

            stream->_buf.status |= FileStatusFlags::LastWrite;
            stream->_buf.size += chunk;

            const bool line_buffered = stream->_buf.mode == _IOLBF;
            if (line_buffered && memchr(data, '\n', chunk) && flush_write_buffer(stream) < 0) return -1;
        }

        if (chunk < 0) return chunk;

        data += chunk;
        size -= chunk;
        accepted += chunk;
    }

    return accepted;
}

// Refills a stream's buffer with one read() of up to `capacity` bytes and rewinds the read index.
// On a read error the buffer is left empty. Either way the buffer is marked LastRead.
// Returns read()'s result unchanged.
static ssize_t read_data_into_buffer(FILE* stream)
{
    stream->_buf.index = 0;

    const ssize_t got = read(stream->_fd, stream->_buf.buffer, stream->_buf.capacity);
    stream->_buf.size = (got < 0) ? 0 : got;
    stream->_buf.status |= FileStatusFlags::LastRead;

    return got;
}

// Copies up to `size` bytes from a stream into `data`.
// Buffered streams: pending output is written out first. The buffer is refilled with read()
// whenever it has been fully consumed. Unbuffered streams: read() is called directly, at most
// BUFSIZ bytes per call.
// Returns the number of bytes delivered, which is less than `size` only when read() returned 0.
// Returns a negative value on a read error.
static ssize_t read_from_buffer(FILE* stream, u8* data, ssize_t size)
{
    ssize_t delivered = 0;

    while (size > 0)
    {
        ssize_t chunk;

        if (stream->_buf.mode == _IONBF) { chunk = read(stream->_fd, data, min(size, BUFSIZ)); }
        else
        {
            if (stream->_buf.status & FileStatusFlags::LastWrite) flush_write_buffer(stream);

            if (stream->_buf.index == stream->_buf.size)
            {
                const ssize_t refill = read_data_into_buffer(stream);
                if (refill < 0) return -1;
                if (refill == 0) return delivered;
            }

            ssize_t available = stream->_buf.size - stream->_buf.index;
            chunk = min(available, size);
            memcpy(data, stream->_buf.buffer + stream->_buf.index, chunk);
            stream->_buf.index += chunk;
        }

        if (chunk < 0) return chunk;
        if (chunk == 0) return delivered;

        data += chunk;
        size -= chunk;
        delivered += chunk;
    }

    return delivered;
}

extern "C"
{
    // Empties the table of open streams so that no stale entries are visible.
    void _init_stdio()
    {
        for (FILE*& slot : s_open_files) slot = nullptr;
    }

    int fflush(FILE* stream)
    {
        if (stream && stream->_buf.mode != _IONBF)
        {
            if (stream->_buf.status & FileStatusFlags::LastWrite) flush_write_buffer(stream);
            else if (stream->_buf.status & FileStatusFlags::LastRead)
                flush_read_buffer(stream);
        }
        else if (!stream)
        {
            for (int i = 0; i < FOPEN_MAX; i++)
            {
                if (s_open_files[i]) fflush(s_open_files[i]);
            }
        }
        return 0;
    }

    // Opens `path` with an fopen()-style `mode` and wraps the descriptor in a new FILE.
    // Terminals get line buffering; everything else is fully buffered.
    // Returns nullptr with errno set on failure. No descriptor or FILE is leaked in that case.
    FILE* fopen(const char* path, const char* mode)
    {
        int flags;

        if ((flags = fopen_parse_mode(mode)) < 0) return nullptr;

        FILE* f = (FILE*)malloc(sizeof(FILE));
        if (!f) { return nullptr; }

        int fd = open(path, flags, 0666);
        if (fd < 0)
        {
            free(f);
            return nullptr;
        }

        // s_open_files is indexed by descriptor and has only FOPEN_MAX slots.
        if (fd >= FOPEN_MAX)
        {
            close(fd);
            free(f);
            errno = EMFILE;
            return nullptr;
        }

        f->_fd = fd;
        clearerr(f);

        f->_flags = flags;
        f->_buf.status = 0;
        f->_buf.mode = isatty(fd) ? _IOLBF : _IOFBF;
        f->_buf.size = f->_buf.index = 0;
        f->_buf.capacity = 0;
        f->_buf.buffer = nullptr;
        // If the buffer cannot be allocated, fail now. Otherwise a "buffered" stream would be left
        // with no buffer.
        if (setvbuf(f, NULL, f->_buf.mode, 0) < 0)
        {
            int saved_errno = errno;
            close(fd);
            free(f);
            errno = saved_errno;
            return nullptr;
        }

        s_open_files[fd] = f;

        return f;
    }

    // Wraps the already-open descriptor `fd` in a new FILE.
    // `mode` is an fopen()-style mode string and must not ask for access the descriptor lacks.
    // `buffering_mode` is _IONBF, _IOLBF or _IOFBF. A negative value means the default: line
    // buffering for terminals, full buffering otherwise.
    // Returns nullptr with errno set on failure. The descriptor is then left open and still belongs
    // to the caller.
    FILE* _fdopen_impl(int fd, const char* mode, int buffering_mode)
    {
        int flags;

        if ((flags = fopen_parse_mode(mode)) < 0) return nullptr;

        if (fdopen_check_compatible_mode(fd, flags) < 0) return nullptr;

        // s_open_files is indexed by descriptor and has only FOPEN_MAX slots.
        if (fd >= FOPEN_MAX)
        {
            errno = EMFILE;
            return nullptr;
        }

        FILE* f = (FILE*)malloc(sizeof(FILE));
        if (!f) { return nullptr; }

        f->_fd = fd;
        clearerr(f);

        f->_flags = flags;
        f->_buf.status = 0;
        f->_buf.mode = buffering_mode < 0 ? (isatty(fd) ? _IOLBF : _IOFBF) : buffering_mode;
        f->_buf.size = f->_buf.index = 0;
        f->_buf.capacity = 0;
        f->_buf.buffer = nullptr;
        // Fails on allocation failure or an invalid buffering_mode; never hand out a half-set-up stream.
        if (setvbuf(f, NULL, f->_buf.mode, 0) < 0)
        {
            int saved_errno = errno;
            free(f);
            errno = saved_errno;
            return nullptr;
        }

        s_open_files[fd] = f;

        return f;
    }

    // Wraps an open descriptor in a FILE, letting _fdopen_impl() choose the buffering mode.
    FILE* fdopen(int fd, const char* mode)
    {
        constexpr int use_default_buffering = -1;
        return _fdopen_impl(fd, mode, use_default_buffering);
    }

    // Reopens `stream` on `path` with a new fopen()-style mode, reusing the same FILE object.
    // Pending output is flushed and the old descriptor closed first. Failures at that stage are
    // ignored, as C permits.
    // Returns `stream`, or nullptr with errno set on failure. After a failure the stream holds no
    // descriptor (_fd == -1) and no buffer, so nothing is closed or freed twice later.
    FILE* freopen(const char* path, const char* mode, FILE* stream)
    {
        int flags;

        if ((flags = fopen_parse_mode(mode)) < 0) return nullptr;

        // Write out buffered data before its descriptor goes away.
        fflush(stream);
        close(stream->_fd);

        if (stream->_fd >= 0 && stream->_fd < FOPEN_MAX && s_open_files[stream->_fd] == stream)
            s_open_files[stream->_fd] = nullptr;

        if (stream->_buf.buffer && (stream->_buf.status & FileStatusFlags::BufferIsMalloced)) free(stream->_buf.buffer);

        // Reset to a consistent, buffer-less state right away. If reopening fails below, the freed
        // buffer must not still be referenced.
        stream->_fd = -1;
        stream->_buf.buffer = nullptr;
        stream->_buf.status = 0;
        stream->_buf.mode = _IONBF;
        stream->_buf.size = stream->_buf.index = 0;
        stream->_buf.capacity = 0;

        if (!path) { fail("FIXME: freopen() called with path=nullptr"); }

        int fd = open(path, flags, 0666);
        if (fd < 0) { return nullptr; }

        // s_open_files is indexed by descriptor and has only FOPEN_MAX slots.
        if (fd >= FOPEN_MAX)
        {
            close(fd);
            errno = EMFILE;
            return nullptr;
        }

        stream->_fd = fd;
        clearerr(stream);

        stream->_flags = flags;
        if (setvbuf(stream, NULL, isatty(fd) ? _IOLBF : _IOFBF, 0) < 0)
        {
            int saved_errno = errno;
            close(fd);
            stream->_fd = -1;
            errno = saved_errno;
            return nullptr;
        }

        s_open_files[fd] = stream;

        return stream;
    }

    // Flushes and closes `stream`, then frees its buffer (if allocated here) and the FILE itself.
    // C requires the stream to be disassociated from its file even if flushing or closing fails.
    // So the stream is always released, and any failure is reported as EOF; otherwise 0 is returned.
    int fclose(FILE* stream)
    {
        int result = 0;

        if (fflush(stream) < 0) result = EOF;

        int fd = stream->_fd;
        if (close(fd) < 0) result = EOF;

        if (stream->_buf.buffer && (stream->_buf.status & FileStatusFlags::BufferIsMalloced)) free(stream->_buf.buffer);

        // Only drop the table entry if it still refers to this stream. Another FILE may have been
        // created on the same descriptor number (e.g. a second fdopen()).
        if (fd >= 0 && fd < FOPEN_MAX && s_open_files[fd] == stream) s_open_files[fd] = nullptr;

        free(stream);

        return result;
    }

    // Returns the file descriptor underlying `stream`.
    int fileno(FILE* stream)
    {
        const int fd = stream->_fd;
        return fd;
    }

    // Reads up to `nmemb` items of `size` bytes each into `buf` and returns the number of complete
    // items read.
    // A short count sets the end-of-file indicator: read_from_buffer() stops early only when read()
    // returns 0. A read error sets the error indicator and returns 0.
    size_t fread(void* buf, size_t size, size_t nmemb, FILE* stream)
    {
        // Test the factors separately; their product could wrap around to zero.
        if (size == 0 || nmemb == 0) return 0;

        // read_from_buffer() takes the byte count as an ssize_t; reject totals it cannot represent.
        if (nmemb > ((size_t)-1 >> 1) / size)
        {
            stream->_err = 1;
            errno = EINVAL;
            return 0;
        }

        const size_t total = size * nmemb;
        ssize_t nread = read_from_buffer(stream, (u8*)buf, (ssize_t)total);

        if (nread < 0)
        {
            stream->_err = 1;
            return 0;
        }

        if ((size_t)nread < total) stream->_eof = 1;

        return (size_t)nread / size;
    }

    size_t fwrite(const void* buf, size_t size, size_t nmemb, FILE* stream)
    {
        if (size * nmemb == 0) return 0;

        ssize_t nwrite = write_into_buffer(stream, (const u8*)buf, size * nmemb);
        if (nwrite < 0)
        {
            stream->_err = 1;
            return 0;
        }

        return (size_t)nwrite / size;
    }

    // Repositions the stream. The buffer is flushed first, so that SEEK_CUR is relative to the
    // position the program actually sees. Returns 0 on success, -1 if lseek() fails.
    int fseek(FILE* stream, long offset, int whence)
    {
        fflush(stream);

        const long new_position = lseek(stream->_fd, offset, whence);
        if (new_position >= 0)
        {
            // man fseek(3): A successful call to the fseek() function clears the end-of-file indicator for the stream.
            stream->_eof = 0;
            return 0;
        }

        return -1;
    }

    // Returns the stream's current position (or lseek()'s error result).
    // Flushing first makes the descriptor's offset agree with the program's view of the stream.
    long ftell(FILE* stream)
    {
        fflush(stream);
        const long position = lseek(stream->_fd, 0, SEEK_CUR);
        return position;
    }

    // Moves the stream back to its start and clears both its error and end-of-file indicators.
    // The lseek() result is not checked.
    void rewind(FILE* stream)
    {
        fflush(stream);
        (void)lseek(stream->_fd, 0, SEEK_SET);
        clearerr(stream);
    }

    int fgetpos(FILE* stream, fpos_t* pos)
    {
        long offset = ftell(stream);
        if (offset < 0) return -1;

        *pos = offset;

        return 0;
    }

    // Moves the stream to a position previously obtained from fgetpos(). Returns fseek()'s result.
    int fsetpos(FILE* stream, const fpos_t* pos)
    {
        fflush(stream);
        const fpos_t target = *pos;
        return fseek(stream, target, SEEK_SET);
    }

    // Returns the stream's error indicator (set when a read or write on it has failed).
    int ferror(FILE* stream)
    {
        const int error_flag = stream->_err;
        return error_flag;
    }

    // Returns the stream's end-of-file indicator.
    int feof(FILE* stream)
    {
        const int eof_flag = stream->_eof;
        return eof_flag;
    }

    int fputc(int c, FILE* stream)
    {
        u8 value = (u8)c;
        ssize_t rc = write_into_buffer(stream, &value, 1);
        if (rc <= 0) return EOF;
        return c;
    }

    // putc() has no separate fast path here; it behaves exactly like fputc().
    int putc(int c, FILE* stream)
    {
        const int rc = fputc(c, stream);
        return rc;
    }

    // Writes `c` to stdout via fputc().
    int putchar(int c)
    {
        const int rc = fputc(c, stdout);
        return rc;
    }

    int fputs(const char* str, FILE* stream)
    {
        ssize_t rc = write_into_buffer(stream, (const u8*)str, strlen(str));
        return (rc < 0) ? -1 : 0;
    }

    int fgetc(FILE* stream)
    {
        u8 value;
        ssize_t rc = read_from_buffer(stream, &value, 1);
        if (rc < 0)
        {
            stream->_err = 1;
            return EOF;
        }
        else if (rc == 0) { return EOF; }
        return value;
    }

    // getc() has no separate fast path here; it behaves exactly like fgetc().
    int getc(FILE* stream)
    {
        const int c = fgetc(stream);
        return c;
    }

    // Reads one byte from stdin via fgetc().
    int getchar()
    {
        const int c = fgetc(stdin);
        return c;
    }

    // Reads at most size - 1 bytes into `buf`, stopping after a newline (which is kept) or at EOF,
    // and NUL-terminates the result.
    // Returns `buf`, or NULL if nothing was read.
    char* fgets(char* buf, size_t size, FILE* stream)
    {
        size_t count = 0;

        for (; count + 1 < size;)
        {
            const int c = fgetc(stream);
            if (c == EOF) break;

            buf[count++] = static_cast<char>(c);
            if (c == '\n') break;
        }

        if (count == 0) return NULL;

        buf[count] = 0;
        return buf;
    }

    // Reads a full line, including its '\n'. This is getdelim() with '\n' as the delimiter.
    ssize_t getline(char** linep, size_t* n, FILE* stream)
    {
        const int newline = '\n';
        return getdelim(linep, n, newline, stream);
    }

    // Reads from `stream` up to and including the first `delim` byte (or until EOF) into *linep,
    // growing the buffer as needed, and NUL-terminates it.
    // *linep / *n describe a malloc()ed buffer and its size. A null *linep gets a fresh BUFSIZ
    // buffer. *linep and *n are kept up to date whenever the buffer is reallocated.
    // Returns the number of bytes stored (excluding the NUL), or -1 if nothing could be read,
    // on bad arguments (errno = EINVAL), or on allocation failure. On allocation failure *linep
    // still refers to a valid buffer, which the caller must free.
    ssize_t getdelim(char** linep, size_t* n, int delim, FILE* stream)
    {
        if (!n || !linep)
        {
            errno = EINVAL;
            return -1;
        }

        char* buf = *linep;
        size_t size = *n;
        size_t len = 0;

        if (!buf)
        {
            buf = (char*)malloc(BUFSIZ);
            if (!buf) return -1;
            size = BUFSIZ;
            *linep = buf;
            *n = size;
        }

        while (1)
        {
            int c = fgetc(stream);
            if (c == EOF) break;

            // Keep room for this byte plus the terminating NUL. Grow geometrically so that long
            // lines cost amortized O(1) copying per byte.
            if (len + 2 > size)
            {
                size_t new_size = size ? size * 2 : 64;
                char* grown = (char*)realloc(buf, new_size);
                if (!grown) return -1;
                buf = grown;
                size = new_size;
                *linep = buf;
                *n = size;
            }

            buf[len++] = (char)c;
            if (c == delim) break;
        }

        if (len == 0) return -1;

        // The growth check above guarantees len < size here.
        buf[len] = '\0';

        return (ssize_t)len;
    }

    // Resets both the error and the end-of-file indicators of `stream`.
    void clearerr(FILE* stream)
    {
        stream->_err = 0;
        stream->_eof = 0;
    }

    // Formats into `buf` with vstring_format(), passing `max` through. The bounding behaviour is
    // vstring_format()'s contract, which is not visible here. Returns its result converted to int.
    int vsnprintf(char* buf, size_t max, const char* format, va_list ap)
    {
        const auto formatted = vstring_format(buf, max, format, ap);
        return (int)formatted;
    }

    // Variadic front end for vsnprintf().
    int snprintf(char* buf, size_t max, const char* format, ...)
    {
        va_list ap;
        va_start(ap, format);
        const int rc = vsnprintf(buf, max, format, ap);
        va_end(ap);
        return rc;
    }

    // Unbounded formatting: vsnprintf() with the largest possible size.
    int vsprintf(char* buf, const char* format, va_list ap)
    {
        const size_t unbounded = static_cast<size_t>(-1);
        return vsnprintf(buf, unbounded, format, ap);
    }

    // Variadic front end for vsprintf().
    int sprintf(char* buf, const char* format, ...)
    {
        va_list ap;
        va_start(ap, format);
        const int rc = vsprintf(buf, format, ap);
        va_end(ap);
        return rc;
    }

    int vfprintf(FILE* stream, const char* format, va_list ap)
    {
        usize count = cstyle_format(
                          format,
                          [](char c, void* f) -> Result<void> {
                              int rc = fputc(c, (FILE*)f);
                              if (rc == EOF) return err(errno);
                              return {};
                          },
                          stream, ap)
                          .value_or(-1);

        if (count == (usize)-1) return -1;

        return (int)count;
    }

    // Variadic front end for vfprintf().
    int fprintf(FILE* stream, const char* format, ...)
    {
        va_list ap;
        va_start(ap, format);
        const int rc = vfprintf(stream, format, ap);
        va_end(ap);
        return rc;
    }

    // vfprintf() to stdout.
    int vprintf(const char* format, va_list ap)
    {
        FILE* const out = stdout;
        return vfprintf(out, format, ap);
    }

    // Variadic front end for vprintf().
    int printf(const char* format, ...)
    {
        va_list ap;
        va_start(ap, format);
        const int rc = vprintf(format, ap);
        va_end(ap);
        return rc;
    }

    // Variadic front end for vsscanf().
    int sscanf(const char* str, const char* format, ...)
    {
        va_list ap;
        va_start(ap, format);
        const int matched = vsscanf(str, format, ap);
        va_end(ap);
        return matched;
    }

    // Limited vfscanf(): reads a single line (at most BUFSIZ - 1 bytes) with fgets() and parses it
    // with vsscanf(). Only that one line is visible to the format.
    // Returns EOF if no line could be read.
    int vfscanf(FILE* stream, const char* format, va_list ap)
    {
        char line[BUFSIZ];
        char* got = fgets(line, sizeof(line), stream);
        if (got == nullptr) return EOF;

        return vsscanf(line, format, ap);
    }

    // Variadic front end for vfscanf().
    int fscanf(FILE* stream, const char* format, ...)
    {
        va_list ap;
        va_start(ap, format);
        const int matched = vfscanf(stream, format, ap);
        va_end(ap);
        return matched;
    }

    // vfscanf() from stdin.
    int vscanf(const char* format, va_list ap)
    {
        FILE* const in = stdin;
        return vfscanf(in, format, ap);
    }

    // Variadic front end for vscanf().
    int scanf(const char* format, ...)
    {
        va_list ap;
        va_start(ap, format);
        const int matched = vscanf(format, ap);
        va_end(ap);
        return matched;
    }

    // Writes `s` followed by a newline to stdout. Returns 0 on success, -1 if either write fails.
    int puts(const char* s)
    {
        const bool ok = fputs(s, stdout) >= 0 && putchar('\n') != EOF;
        return ok ? 0 : -1;
    }

    // Prints "<s>: <strerror(errno)>\n" to stderr, or just the message when `s` is null or empty.
    void perror(const char* s)
    {
        // Capture errno before fprintf() gets a chance to change it.
        const int saved_errno = errno;

        const bool has_prefix = s != nullptr && s[0] != '\0';
        if (has_prefix) fprintf(stderr, "%s: ", s);

        fprintf(stderr, "%s\n", strerror(saved_errno));
    }

    // Deletes `path`, which may be a file or a directory.
    int remove(const char* path)
    {
        // On Luna, unlink() allows removal of directories, so it covers both cases.
        const int rc = unlink(path);
        return rc;
    }

    // Creates an anonymous temporary file (O_TMPFILE) in the temporary directory and returns it
    // opened for reading and writing. Returns nullptr on failure, without leaking the descriptor.
    FILE* tmpfile()
    {
        const int fd = open(read_tmpdir(), O_RDWR | O_TMPFILE, 0600);
        if (fd < 0) return nullptr;

        FILE* stream = fdopen(fd, "w+b");
        if (stream == nullptr) close(fd);

        return stream;
    }

    int ungetc(int c, FILE* stream)
    {
        if (stream->_buf.index == 0)
            return EOF; // No data currently in the read buffer, or no data has been read from it.

        if (stream->_buf.mode == _IONBF)
            return EOF; // FIXME: C doesn't state that ungetc() should only work on buffered streams.

        stream->_buf.index--;
        stream->_buf.buffer[stream->_buf.index] = (char)c;

        return 0;
    }

    // Sets the buffering mode and buffer of `stream`. This must happen before any data has been
    // buffered.
    // `mode` is _IONBF, _IOLBF or _IOFBF. For the buffered modes, `buf`/`size` may supply a caller
    // buffer. If `buf` is NULL or `size` is 0, a BUFSIZ buffer is calloc()ed instead and freed by
    // the library later.
    // Returns 0 on success. Returns -1 for an invalid mode (errno = EINVAL), when the buffer already
    // holds data, or when allocation fails.
    int setvbuf(FILE* stream, char* buf, int mode, size_t size)
    {
        if (mode < 0 || mode > _IOFBF) return errno = EINVAL, -1;
        if (stream->_buf.size != 0 || stream->_buf.index != 0) return -1; // Buffer is already in use.

        int status = 0;
        if (mode == _IONBF)
        {
            buf = NULL;
            size = 0;
        }
        else if (buf == NULL || size == 0)
        {
            // A zero-capacity buffer can never hold a byte: the buffered write path would make no
            // progress and loop forever. Use our own buffer instead.
            size = BUFSIZ;
            buf = (char*)calloc(size, 1);
            if (!buf) return -1;
            status = FileStatusFlags::BufferIsMalloced;
        }

        if (stream->_buf.buffer && (stream->_buf.status & FileStatusFlags::BufferIsMalloced)) free(stream->_buf.buffer);

        stream->_buf.buffer = buf;
        stream->_buf.capacity = size;
        stream->_buf.mode = mode;
        stream->_buf.status = status;

        return 0;
    }

    // Fully buffers `stream` using a caller buffer of BUFSIZ bytes, or makes it unbuffered when `buf` is NULL.
    void setbuf(FILE* stream, char* buf)
    {
        const int mode = buf ? _IOFBF : _IONBF;
        setvbuf(stream, buf, mode, BUFSIZ);
    }

    // Like setbuf(), but with an explicit buffer size.
    void setbuffer(FILE* stream, char* buf, size_t size)
    {
        const int mode = buf ? _IOFBF : _IONBF;
        setvbuf(stream, buf, mode, size);
    }

    // Switches `stream` to line buffering with a library-allocated buffer.
    void setlinebuf(FILE* stream)
    {
        const size_t let_library_choose = 0;
        setvbuf(stream, NULL, _IOLBF, let_library_choose);
    }

    // Renames `oldpath` to `newpath`, replacing `newpath` if it exists. This is built from link()
    // and unlink(), so it is not atomic. A failure to remove `oldpath` at the end is ignored.
    // Returns 0 on success, -1 with errno set on failure.
    int rename(const char* oldpath, const char* newpath)
    {
        // FIXME: Implement this atomically in-kernel.

        // Renaming a path onto itself must leave the file alone. Without this check, the replace
        // path below would unlink the only name of the file. Different spellings of the same path,
        // or two hard links to one file, are still not detected.
        if (strcmp(oldpath, newpath) == 0) return 0;

        // Link first, so that a failure (e.g. oldpath does not exist) never destroys newpath.
        if (link(oldpath, newpath) < 0)
        {
            if (errno != EEXIST) return -1;

            // newpath already exists: replace it.
            if (unlink(newpath) < 0) return -1;
            if (link(oldpath, newpath) < 0) return -1;
        }

        unlink(oldpath);
        return 0;
    }

    FILE* popen(const char* command, const char* type)
    {
        int pfds[2];
        if (pipe(pfds) < 0) return nullptr;

        if (*type != 'r' && *type != 'w')
        {
            errno = EINVAL;
            return nullptr;
        }

        pid_t child = fork();
        if (child < 0)
        {
            close(pfds[0]);
            close(pfds[1]);
            return nullptr;
        }
        if (child == 0)
        {
            if (*type == 'r')
            {
                close(pfds[0]);
                dup2(pfds[1], STDOUT_FILENO);
            }
            else
            {
                close(pfds[1]);
                dup2(pfds[0], STDIN_FILENO);
            }

            execl("/bin/sh", "sh", "-c", command, nullptr);
            _exit(127);
        }

        int fd;
        if (*type == 'r')
        {
            close(pfds[1]);
            fd = pfds[0];
        }
        else
        {
            close(pfds[0]);
            fd = pfds[1];
        }

        int err = errno;
        FILE* f = (FILE*)malloc(sizeof(FILE));
        if (!f)
        {
            errno = err;
            close(fd);
            return nullptr;
        }

        f->_fd = fd;
        f->_pid = child;
        clearerr(f);

        f->_flags = *type == 'r' ? O_RDONLY : O_WRONLY;
        f->_buf.status = 0;
        f->_buf.mode = _IOFBF;
        f->_buf.size = f->_buf.index = 0;
        f->_buf.buffer = nullptr;
        setvbuf(f, NULL, f->_buf.mode, 0);

        s_open_files[fd] = f;

        return f;
    }

    // Closes a stream returned by popen() and waits for its child process.
    // The fclose() result is ignored. Returns the child's wait status, or -1 if waitpid() fails.
    int pclose(FILE* stream)
    {
        // Read the pid before fclose() frees the FILE.
        const pid_t child = stream->_pid;
        fclose(stream);

        int wstatus;
        return waitpid(child, &wstatus, 0) < 0 ? -1 : wstatus;
    }
}