#include <luna/Alignment.h>
#include <luna/Sort.h>

static void swap_sized(void* ptr1, void* ptr2, usize size)
{
    char* x = (char*)ptr1;
    char* y = (char*)ptr2;

    while (size--)
    {
        char t = *x;
        *x = *y;
        *y = t;
        x += 1;
        y += 1;
    }
}

// Partitions the inclusive index range [start, end] of the array at `base` around the
// element stored at index `end` (Lomuto scheme).
//
// Every element for which compar(element, pivot) < 0 is moved in front of the pivot,
// then the pivot is swapped into the slot right after those elements.
//
// base:   address of element 0 of the array.
// start:  index of the first element of the range.
// end:    index of the last element of the range; this element is used as the pivot.
// size:   size of one element in bytes.
// compar: comparison callback; only whether its result is negative is inspected.
//
// Returns the index at which the pivot element ends up.
static usize partition(void* base, usize start, usize end, usize size, compar_t compar)
{
    auto atindex = [&base, &size](usize index) { return offset_ptr(base, index * size); };

    void* pivot = atindex(end);

    // Slot that receives the next element smaller than the pivot. Tracking the slot
    // itself (rather than the index before it) avoids computing start - 1, which wraps
    // around when start is 0.
    usize store = start;

    // `j < end` instead of `j <= end - 1`: the latter never terminates when end is 0.
    for (usize j = start; j < end; j++)
    {
        if (compar(atindex(j), pivot) < 0)
        {
            swap_sized(atindex(store), atindex(j), size);
            store++;
        }
    }

    // Everything before `store` compared less than the pivot, so the pivot belongs here.
    swap_sized(atindex(store), pivot, size);
    return store;
}

// Sorts the inclusive index range [start, end] of the array at `base` in place.
//
// base:   address of element 0 of the array.
// start:  index of the first element to sort.
// end:    index of the last element to sort. A range with start >= end is left untouched.
// size:   size of one element in bytes.
// compar: comparison callback forwarded to partition().
static void quicksort_impl(void* base, usize start, usize end, usize size, compar_t compar)
{
    // Recurse only into the smaller side of each split and keep looping on the larger
    // one, so the recursion depth stays logarithmic in the range length even when the
    // splits are maximally unbalanced (e.g. already sorted input).
    while (start < end)
    {
        usize pivot = partition(base, start, end, size, compar);

        if (pivot - start < end - pivot)
        {
            // Left side is smaller. It is empty when pivot == start; in that case
            // pivot - 1 must not be formed at all, because for start == 0 it wraps
            // around to the largest usize and would describe a huge bogus range.
            if (pivot > start) quicksort_impl(base, start, pivot - 1, size, compar);

            // pivot < end holds in this branch, so pivot + 1 cannot overflow.
            start = pivot + 1;
        }
        else
        {
            // Right side is smaller or equal; it is empty when pivot == end.
            if (pivot < end) quicksort_impl(base, pivot + 1, end, size, compar);

            // pivot > start holds in this branch (otherwise start == end), so
            // pivot - 1 cannot wrap around.
            end = pivot - 1;
        }
    }
}

// Sorts an array of `nmemb` elements of `size` bytes each, in place.
//
// base:   address of the first element.
// nmemb:  number of elements in the array.
// size:   size of one element in bytes.
// compar: comparison callback taking two element pointers; elements for which it
//         returns a negative value relative to another are placed before that one.
void c_quicksort(void* base, usize nmemb, usize size, compar_t compar)
{
    // Zero or one element is already sorted. The early return also prevents
    // nmemb - 1 from wrapping around to the largest usize when nmemb is 0.
    if (nmemb < 2) return;

    quicksort_impl(base, 0, nmemb - 1, size, compar);
}