diff --git a/apps/su.cpp b/apps/su.cpp index bc2a84f3..6d25892c 100644 --- a/apps/su.cpp +++ b/apps/su.cpp @@ -1,16 +1,63 @@ +#include #include #include #include #include +#include #include +static struct termios orig; + +void restore_terminal() +{ + ioctl(fileno(stdin), TCSETS, &orig); +} + +char* getpass() +{ + fputs("Password: ", stdout); + + if (ioctl(fileno(stdin), TCGETS, &orig) < 0) + { + perror("ioctl(TCGETS)"); + return nullptr; + } + + atexit(restore_terminal); + + struct termios tc = orig; + tc.c_lflag &= ~ECHO; + if (ioctl(fileno(stdin), TCSETS, &tc) < 0) + { + perror("ioctl(TCSETS)"); + return nullptr; + } + + static char buf[1024]; + char* rc = fgets(buf, sizeof(buf), stdin); + + restore_terminal(); + putchar('\n'); + + if (!rc) + { + perror("fgets"); + return nullptr; + } + + char* newline = strrchr(rc, '\n'); + if (newline) *newline = 0; + + return buf; +} + int main(int argc, char** argv) { StringView name; if (geteuid() != 0) { - fprintf(stderr, "su must be run as root!\n"); + fprintf(stderr, "su must be setuid root!\n"); return 1; } @@ -25,10 +72,18 @@ int main(int argc, char** argv) return 1; } - if (getuid() != geteuid()) + if (getuid() != geteuid() && *entry->pw_passwd) { - fprintf(stderr, "FIXME: you have to enter %s's password first!\n", name.chars()); - return 1; + char* pass = getpass(); + if (!pass) return 1; + + if (strcmp(pass, entry->pw_passwd)) + { + fprintf(stderr, "Wrong password!\n"); + return 1; + } + + memset(pass, 0, strlen(pass)); } setgid(entry->pw_gid);