78 lines
1.8 KiB
C++
78 lines
1.8 KiB
C++
#pragma once
|
|
#include "Log.h"
|
|
#include "arch/CPU.h"
|
|
#include "thread/Scheduler.h"
|
|
#include "thread/Thread.h"
|
|
#include <luna/CircularQueue.h>
|
|
|
|
template <usize ConcurrentThreads> class KMutex
|
|
{
|
|
public:
|
|
void lock()
|
|
{
|
|
int expected = 0;
|
|
while (!m_lock.compare_exchange_strong(expected, 1))
|
|
{
|
|
expected = 0;
|
|
auto* current = Scheduler::current();
|
|
|
|
// We cannot be interrupted between these functions, otherwise we might never exit the loop
|
|
CPU::disable_interrupts();
|
|
bool ok = m_blocked_threads.try_push(current);
|
|
if (!ok) kernel_sleep(10);
|
|
else
|
|
kernel_wait_for_event();
|
|
CPU::enable_interrupts();
|
|
}
|
|
};
|
|
|
|
void unlock()
|
|
{
|
|
int expected = 1;
|
|
if (!m_lock.compare_exchange_strong(expected, 0))
|
|
{
|
|
kwarnln("KMutex::unlock() called on an unlocked lock with value %d", expected);
|
|
}
|
|
|
|
Thread* blocked;
|
|
if (m_blocked_threads.try_pop(blocked)) blocked->wake_up();
|
|
}
|
|
|
|
bool try_lock()
|
|
{
|
|
int expected = 0;
|
|
return m_lock.compare_exchange_strong(expected, 1);
|
|
}
|
|
|
|
private:
|
|
CircularQueue<Thread*, ConcurrentThreads> m_blocked_threads;
|
|
Atomic<int> m_lock;
|
|
};
|
|
|
|
template <usize ConcurrentThreads> class ScopedKMutexLock
|
|
{
|
|
public:
|
|
ScopedKMutexLock(KMutex<ConcurrentThreads>& lock) : m_lock(lock)
|
|
{
|
|
m_lock.lock();
|
|
}
|
|
|
|
~ScopedKMutexLock()
|
|
{
|
|
if (!m_taken_over) m_lock.unlock();
|
|
}
|
|
|
|
ScopedKMutexLock(const ScopedKMutexLock&) = delete;
|
|
ScopedKMutexLock(ScopedKMutexLock&&) = delete;
|
|
|
|
KMutex<ConcurrentThreads>& take_over()
|
|
{
|
|
m_taken_over = true;
|
|
return m_lock;
|
|
}
|
|
|
|
private:
|
|
KMutex<ConcurrentThreads>& m_lock;
|
|
bool m_taken_over { false };
|
|
};
|