#include <errno.h>
#include <pwd.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

/*
 * Replace the current process image with the program named by argv[0],
 * passing argv as its argument vector.
 *
 * Never returns: on success execv() does not come back, and on any failure
 * the error is reported on stderr and the process exits with EXIT_FAILURE.
 */
void run_program(char** argv)
{
    /* Guard against an empty vector (e.g. a passwd entry without a shell). */
    if (!argv || !argv[0]) {
        fprintf(stderr, "No program to execute\n");
        exit(EXIT_FAILURE);
    }

    execv(argv[0], argv);

    /* Only reached when execv() failed. */
    perror("execv");
    exit(EXIT_FAILURE);
}

/*
 * Remove a single trailing '\n' from str, in place.
 * An empty string is left untouched.
 */
void strip_newline(char* str)
{
    size_t len = strlen(str);

    /* len == 0 must be excluded: len - 1 would wrap around (size_t). */
    if (len > 0 && str[len - 1] == '\n') {
        str[len - 1] = '\0';
    }
}

/*
 * Prompt for a password on stdout and read one line from stdin.
 *
 * Returns a pointer to a static buffer holding the line without its trailing
 * newline, or NULL if nothing could be read (EOF or read error). The buffer
 * is overwritten by the next call.
 *
 * NOTE(review): terminal echo is not disabled here, so the password is
 * visible while typed unless the terminal is configured elsewhere.
 */
static const char* collect_password(void)
{
    static char buf[BUFSIZ];

    printf("Password: ");
    /* The prompt has no newline, so flush it before blocking on input. */
    fflush(stdout);

    if (!fgets(buf, sizeof buf, stdin)) {
        buf[0] = '\0';
        putchar('\n');
        return NULL;
    }

    strip_newline(buf);
    putchar('\n');
    return buf;
}

/*
 * Usage: <program> [username [command [args...]]]
 *
 * Switches to `username` (default "root") and executes `command` with its
 * arguments, or the user's login shell (pw_shell) when no command is given.
 * When the real uid differs from the effective uid, a password is requested
 * and compared against the user's pw_passwd field first.
 *
 * Exits with EXIT_FAILURE on any error; on success the process image is
 * replaced and this function does not return.
 */
int main(int argc, char** argv)
{
    /* argc may be 0 (argv[0] == NULL) when exec'd by a hostile caller. */
    const char* program_name = (argc > 0 && argv[0]) ? argv[0] : "this program";
    const char* username = (argc < 2) ? "root" : argv[1];

    /* Being setuid root shows up in the *effective* uid, not the real one. */
    if (geteuid() != 0) {
        fprintf(stderr, "%s must be setuid root\n", program_name);
        return EXIT_FAILURE;
    }

    /*
     * getpwnam() returns NULL both for "no such user" and for real errors;
     * only errno distinguishes them, so clear it first and capture it before
     * endpwent() gets a chance to overwrite it.
     */
    errno = 0;
    struct passwd* user = getpwnam(username);
    int lookup_errno = errno;
    endpwent();

    if (!user) {
        if (lookup_errno) {
            errno = lookup_errno;
            perror("getpwnam");
        } else {
            fprintf(stderr, "Unknown user %s\n", username);
        }
        return EXIT_FAILURE;
    }

    if (getuid() != geteuid()) {
        /* We were started from a non-root user: authenticate. */
        /*
         * NOTE(review): this compares the typed password verbatim with
         * pw_passwd. On systems that store hashed or shadowed passwords this
         * can never match a real password -- confirm against the target
         * platform.
         */
        const char* pw = collect_password();
        if (!pw || !user->pw_passwd || strcmp(pw, user->pw_passwd) != 0) {
            fprintf(stderr, "Invalid password\n");
            return EXIT_FAILURE;
        }
    }

    /*
     * Drop the group first: once setuid() has switched to a non-root user,
     * the process no longer has the privilege to change its gid.
     * NOTE(review): supplementary groups are not reset here; consider
     * initgroups()/setgroups() if the platform provides them.
     */
    if (setgid(user->pw_gid) < 0) {
        perror("setgid");
        return EXIT_FAILURE;
    }

    if (setuid(user->pw_uid) < 0) {
        perror("setuid");
        return EXIT_FAILURE;
    }

    char* default_argv[] = { user->pw_shell, NULL };

    if (argc < 3) {
        run_program(default_argv);
    } else {
        run_program(argv + 2);
    }

    /* Not reached: run_program() either execs or exits. */
    return EXIT_FAILURE;
}