#include <errno.h>
#include <fcntl.h>
#include <grp.h>
#include <os/ArgumentParser.h>
#include <pwd.h>
#include <shadow.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <termios.h>
#include <unistd.h>

// Terminal attributes captured by getpass() before echo is disabled, and the
// descriptor of the terminal they belong to (-1 while no terminal is open).
static struct termios orig;
static int fd = -1;

// Puts the saved attributes in `orig` back onto the terminal `fd`.
//
// This runs from the atexit() hook and from signal_handler(), so it can be
// invoked after getpass() has already closed the terminal. In that case
// (fd == -1) there is nothing to restore and no syscall is made.
void restore_terminal()
{
    if (fd < 0) return;
    tcsetattr(fd, TCSANOW, &orig);
}

// Installed by getpass() for SIGINT, SIGTERM and SIGQUIT while echo is off.
// First puts the terminal back the way it was, then re-sends the same signal
// to this process. The handler is registered with SA_RESETHAND, so by now the
// disposition is SIG_DFL again and the re-raised signal gets its default action.
void signal_handler(int sig)
{
    restore_terminal();
    static_cast<void>(raise(sig));
}

// Reads one line (a password) from the controlling terminal with echo disabled.
//
// Prints "Password: " to stdout, turns off ECHO on the controlling terminal
// while reading, then restores the saved attributes. While echo is off, the
// attributes are also restored if SIGINT/SIGTERM/SIGQUIT arrives or the
// process exits. The trailing newline, if any, is stripped.
//
// Returns a pointer to a static buffer holding the password. The buffer is
// overwritten by the next call, and callers should wipe it when done.
// Returns nullptr if the terminal cannot be opened or configured, on a read
// error (reported via perror), or on end-of-file before any input (not
// reported, since it is not an error).
char* getpass()
{
    char ctty[L_ctermid];
    ctermid(ctty);

    FILE* f = fopen(ctty, "r");
    if (!f)
    {
        perror("Failed to open controlling terminal");
        return nullptr;
    }

    fd = fileno(f);

    // Make our process group the terminal's foreground group before using it.
    // Best-effort: the result is not checked.
    tcsetpgrp(fd, getpgid(0));

    fputs("Password: ", stdout);
    fflush(stdout);

    if (tcgetattr(fd, &orig) < 0)
    {
        perror("tcgetattr");
        fclose(f);
        fd = -1;
        return nullptr;
    }

    // From here on echo may be off, so make sure the original attributes are
    // put back if we are interrupted or the process exits.
    struct sigaction sa;
    sa.sa_handler = signal_handler;
    sigemptyset(&sa.sa_mask);
    sa.sa_flags = SA_RESETHAND;
    sigaction(SIGINT, &sa, nullptr);
    sigaction(SIGTERM, &sa, nullptr);
    sigaction(SIGQUIT, &sa, nullptr);

    atexit(restore_terminal);

    struct termios tc = orig;
    tc.c_lflag &= ~ECHO;
    if (tcsetattr(fd, TCSANOW, &tc) < 0)
    {
        perror("tcsetattr");
        fclose(f);
        fd = -1;
        return nullptr;
    }

    static char buf[BUFSIZ];
    char* rc = fgets(buf, sizeof(buf), f);

    // fgets() returns NULL both on a read error and on end-of-file, and only
    // the former is a real error. Capture the stream's error flag and errno
    // now, before fclose() and the terminal/stdout calls below can change them.
    const bool read_failed = !rc && ferror(f);
    const int saved_errno = errno;

    restore_terminal();
    putchar('\n');

    fclose(f);
    fd = -1;

    if (!rc)
    {
        if (read_failed)
        {
            errno = saved_errno;
            perror("fgets");
        }
        // On EOF (e.g. Ctrl-D) there is nothing to report; just signal failure.
        return nullptr;
    }

    char* newline = strrchr(rc, '\n');
    if (newline) *newline = 0;

    return buf;
}

// Sets the calling process's supplementary group list to every group in the
// group database whose member list contains `name`.
//
// Returns an error if appending to the list fails or if setgroups() fails
// (carrying errno in the latter case).
Result<void> set_supplementary_groups(const char* name)
{
    Vector<gid_t> gids;

    setgrent();
    for (group* record = getgrent(); record; record = getgrent())
    {
        // Scan the member list, stopping at the first match.
        bool is_member = false;
        for (char** member = record->gr_mem; *member && !is_member; ++member)
            is_member = strcmp(*member, name) == 0;

        if (is_member) { TRY(gids.try_append(record->gr_gid)); }
    }
    endgrent();

    const int count = static_cast<int>(gids.size());
    if (setgroups(count, gids.data()) < 0) return err(errno);

    return {};
}

// Entry point for su: switch to another user (root by default) and exec their shell.
//
// The binary must be setuid root. A password is required when the invoking
// user is not root (real uid != effective uid) or when -p/--prompt is given,
// unless the target account's password field is empty. With -l/--login the
// working directory becomes the user's home, the environment is cleared (PATH
// is reset), and a new process group is created.
//
// Returns 1 on failure. On success it does not return, because the process
// image is replaced by the user's shell.
Result<int> luna_main(int argc, char** argv)
{
    StringView name;
    bool prompt_password = false;
    bool login = false;

    if (geteuid() != 0)
    {
        fprintf(stderr, "%s must be setuid root!\n", argv[0]);
        return 1;
    }

    os::ArgumentParser parser;
    parser.add_description("Switch to a different user (by default, root)."_sv);
    parser.add_system_program_info("su"_sv);
    parser.add_positional_argument(name, "name"_sv, "root"_sv);
    parser.add_switch_argument(prompt_password, 'p', "prompt"_sv, "prompt for a password even if running as root");
    parser.add_switch_argument(login, 'l', "login"_sv, "change directory to the user's home and start a login shell");
    parser.parse(argc, argv);

    struct passwd* entry = getpwnam(name.chars());
    if (!entry)
    {
        fprintf(stderr, "%s: user %s not found!\n", argv[0], name.chars());
        return 1;
    }

    endpwent();

    if ((prompt_password || getuid() != geteuid()) && *entry->pw_passwd)
    {
        // Ignore SIGTTOU while getpass() manipulates the terminal. The previous
        // disposition is restored afterwards: an ignored signal would otherwise
        // stay ignored in the shell across execl().
        auto previous_sigttou = signal(SIGTTOU, SIG_IGN);

        const char* passwd = entry->pw_passwd;

        // If the user's password entry is 'x', read their password from the shadow file instead.
        if (!strcmp(entry->pw_passwd, "x"))
        {
            struct spwd* sp = getspnam(name.chars());

            if (!sp)
            {
                fprintf(stderr, "%s: user %s not found in shadow file!\n", argv[0], name.chars());
                return 1;
            }

            endspent();

            passwd = sp->sp_pwdp;
        }

        if (!strcmp(passwd, "!"))
        {
            fprintf(stderr, "%s: %s's password is disabled!\n", argv[0], entry->pw_name);
            return 1;
        }

        char* pass = getpass();
        signal(SIGTTOU, previous_sigttou);
        if (!pass) return 1;

        const bool matches = !strcmp(pass, passwd);
        // Wipe the typed password from memory whether or not it was correct.
        memset(pass, 0, strlen(pass));

        if (!matches)
        {
            fprintf(stderr, "%s: wrong password!\n", argv[0]);
            return 1;
        }
    }

    TRY(set_supplementary_groups(name.chars()));

    // Drop privileges: group first, while we still have the root privileges
    // needed to change it. Either failure must be fatal; otherwise the shell
    // below would run with root's ids.
    if (setgid(entry->pw_gid) < 0)
    {
        perror("setgid");
        return 1;
    }

    if (setuid(entry->pw_uid) < 0)
    {
        perror("setuid");
        return 1;
    }

    if (login)
    {
        // Not fatal: the shell still starts, just in the current directory.
        if (chdir(entry->pw_dir) < 0) perror("chdir");
        clearenv();
        setenv("PATH", "/usr/bin:/usr/local/bin", 1);
        setpgid(0, 0);
    }

    if (login || entry->pw_uid != 0) setenv("USER", entry->pw_name, 1);

    setenv("HOME", entry->pw_dir, 1);
    setenv("SHELL", entry->pw_shell, 1);

    // execl()'s argument list must end with a null *pointer*; a bare NULL may be an integer 0 in C++.
    execl(entry->pw_shell, entry->pw_shell, static_cast<char*>(nullptr));

    perror("execl");
    return 1;
}