diff --git a/libc/CMakeLists.txt b/libc/CMakeLists.txt index 46834ba4..307e1523 100644 --- a/libc/CMakeLists.txt +++ b/libc/CMakeLists.txt @@ -20,6 +20,7 @@ set(SOURCES src/pwd.cpp src/grp.cpp src/locale.cpp + src/scanf.cpp src/sys/stat.cpp src/sys/mman.cpp src/sys/wait.cpp diff --git a/libc/include/stdio.h b/libc/include/stdio.h index aef67d83..fe3cc3ae 100644 --- a/libc/include/stdio.h +++ b/libc/include/stdio.h @@ -122,16 +122,34 @@ extern "C" int snprintf(char* buf, size_t max, const char* format, ...); /* Write formatted output into a buffer. */ - int vsprintf(char*, const char*, va_list); + int vsprintf(char* buf, const char* format, va_list ap); /* Write up to max bytes of formatted output into a buffer. */ - int vsnprintf(char*, size_t, const char*, va_list); + int vsnprintf(char* buf, size_t max, const char* format, va_list ap); /* Write formatted output to standard output. */ - int vprintf(const char*, va_list ap); + int vprintf(const char* format, va_list ap); /* Write formatted output to standard output. */ - int printf(const char*, ...); + int printf(const char* format, ...); + + /* Scan formatted input from a string. */ + int vsscanf(const char* str, const char* format, va_list ap); + + /* Scan formatted input from a string. */ + int sscanf(const char* str, const char* format, ...); + + /* Scan formatted input from a file. */ + int vfscanf(FILE* stream, const char* format, va_list ap); + + /* Scan formatted input from a file. */ + int fscanf(FILE* stream, const char* format, ...); + + /* Scan formatted input from standard input. */ + int vscanf(const char* format, va_list ap); + + /* Scan formatted input from standard input. */ + int scanf(const char* format, ...); /* Write a string followed by a newline to standard output. */ int puts(const char* s); diff --git a/libc/src/scanf.cpp b/libc/src/scanf.cpp new file mode 100644 index 00000000..1a7d3f72 --- /dev/null +++ b/libc/src/scanf.cpp @@ -0,0 +1,137 @@ +#include +#include +#include +#include +#include + +#define FLAG_DISCARD (1 << 0) +#define FLAG_ALLOC (1 << 1) +#define FLAG_WIDTH (1 << 2) + +static int parse_flags(const char** format) +{ + int result = 0; + + while (true) + { + switch (**format) + { + case '*': + result |= FLAG_DISCARD; + (*format)++; + break; + case 'm': + result |= FLAG_ALLOC; + (*format)++; + break; + default: return result; + } + } +} + +static size_t parse_width(const char** format, int& flags) +{ + size_t result = 0; + + if (_isdigit(**format)) + { + result = scan_unsigned_integer(format); + flags |= FLAG_WIDTH; + } + + return result; +} + +extern "C" +{ + int vsscanf(const char* str, const char* format, va_list ap) + { + int parsed = 0; + + if (*str == 0) return EOF; + + while (*format) + { + if (*format != '%') + { + normal: + if (!_isspace(*format)) + { + if (*str != *format) return parsed; + str++; + format++; + if (*str == 0) return parsed; + continue; + } + + format += strspn(format, " \t\f\r\n\v"); + str += strspn(str, " \t\f\r\n\v"); + if (*str == 0) return parsed; + continue; + } + else + { + format++; + if (*format == '%') goto normal; + + int flags = parse_flags(&format); + size_t width = parse_width(&format, flags); + char specifier = *format++; + if (!specifier) return parsed; + + switch (specifier) + { + case 's': { + str += strspn(str, " \t\f\r\n\v"); + size_t chars = strcspn(str, " \t\f\r\n\v"); + if (!chars) return parsed; + if ((flags & FLAG_WIDTH) && chars > width) chars = width; + if (!(flags & FLAG_DISCARD)) + { + char* ptr; + if (flags & FLAG_ALLOC) + { + ptr = (char*)malloc(chars + 1); + if (!ptr) return parsed; + *va_arg(ap, char**) = ptr; + } + else + ptr = va_arg(ap, char*); + memcpy(ptr, str, chars); + ptr[chars] = 0; + } + str += chars; + parsed++; + break; + } + case 'c': { + if (strlen(str) < width) return parsed; + if (!(flags & FLAG_WIDTH)) width = 1; + if (!(flags & FLAG_DISCARD)) + { + char* ptr; + if (flags & FLAG_ALLOC) + { + ptr = (char*)malloc(width); + if (!ptr) return parsed; + *va_arg(ap, char**) = ptr; + } + else + ptr = va_arg(ap, char*); + memcpy(ptr, str, width); + } + str += width; + parsed++; + break; + } + default: { + fprintf(stderr, "vsscanf: unknown conversion specifier: %%%c\n", specifier); + return parsed; + } + } + } + } + + return parsed; + } +} diff --git a/libc/src/stdio.cpp b/libc/src/stdio.cpp index 0112746b..16c509ad 100644 --- a/libc/src/stdio.cpp +++ b/libc/src/stdio.cpp @@ -381,6 +381,54 @@ extern "C" return rc; } + int sscanf(const char* str, const char* format, ...) + { + va_list ap; + va_start(ap, format); + + int rc = vsscanf(str, format, ap); + + va_end(ap); + + return rc; + } + + int vfscanf(FILE* stream, const char* format, va_list ap) + { + char buf[BUFSIZ]; + if (!fgets(buf, sizeof(buf), stream)) return EOF; + return vsscanf(buf, format, ap); + } + + int fscanf(FILE* stream, const char* format, ...) + { + va_list ap; + va_start(ap, format); + + int rc = vfscanf(stream, format, ap); + + va_end(ap); + + return rc; + } + + int vscanf(const char* format, va_list ap) + { + return vfscanf(stdin, format, ap); + } + + int scanf(const char* format, ...) + { + va_list ap; + va_start(ap, format); + + int rc = vfscanf(stdin, format, ap); + + va_end(ap); + + return rc; + } + int puts(const char* s) { if (fputs(s, stdout) < 0) return -1; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d7dabaef..dd452813 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -19,6 +19,7 @@ luna_test(libluna/TestUtf8.cpp TestUtf8) luna_test(libluna/TestFormat.cpp TestFormat) luna_test(libluna/TestHashTable.cpp TestHashTable) luna_test(libluna/TestCPath.cpp TestCPath) +luna_test(libc/TestScanf.cpp TestScanf) luna_app(run-tests.cpp run-tests) endif() diff --git a/tests/libc/TestScanf.cpp b/tests/libc/TestScanf.cpp new file mode 100644 index 00000000..98a1e27d --- /dev/null +++ b/tests/libc/TestScanf.cpp @@ -0,0 +1,42 @@ +#include +#include +#include + +// FIXME: Add more tests. + +TestResult test_basic_scanf() +{ + char hello[21]; + char world[21]; + + int parsed = sscanf("hello world", "%20s %20s", hello, world); + validate(parsed == 2); + + validate(!strcmp(hello, "hello")); + validate(!strcmp(world, "world")); + + test_success; +} + +TestResult test_incomplete_scanf() +{ + char hello[21]; + char world[21]; + + int parsed = sscanf("hello ", "%20s %20s", hello, world); + validate(parsed == 1); + + validate(!strcmp(hello, "hello")); + + test_success; +} + +Result test_main() +{ + test_prelude; + + run_test(test_basic_scanf); + run_test(test_incomplete_scanf); + + return {}; +}