diff options
-rw-r--r-- | includes/Local.h | 6 | ||||
-rw-r--r-- | includes/Main.h | 25 | ||||
-rw-r--r-- | includes/Task.h | 44 | ||||
-rw-r--r-- | includes/TaskMan.h | 17 | ||||
-rw-r--r-- | src/Local.cc | 2 | ||||
-rw-r--r-- | src/Main.cc | 2 | ||||
-rw-r--r-- | src/Task.cc | 64 | ||||
-rw-r--r-- | src/TaskMan.cc | 62 | ||||
-rw-r--r-- | tests/test-Sanity.cc | 3 | ||||
-rw-r--r-- | tests/test-String.cc | 3 | ||||
-rw-r--r-- | tests/test-Tasks.cc | 14 |
11 files changed, 163 insertions, 79 deletions
diff --git a/includes/Local.h b/includes/Local.h index e3d08a4..79862a6 100644 --- a/includes/Local.h +++ b/includes/Local.h @@ -11,7 +11,7 @@ class TLSManager { void * createTLS(); }; -extern TLSManager * tlsManager; +extern TLSManager * g_tlsManager; class Local : public AtStart { public: @@ -27,8 +27,8 @@ class Local : public AtStart { int getIndex() { return m_idx; } private: static void * create() { void * r = malloc(s_size * sizeof(void *)); return r; } - static void * getTLS() { return tlsManager->getTLS(); } - static void * setTLS(void * val) { return tlsManager->setTLS(val); } + static void * getTLS() { return g_tlsManager->getTLS(); } + static void * setTLS(void * val) { return g_tlsManager->setTLS(val); } virtual void doStart(); int m_idx; static int s_size; diff --git a/includes/Main.h b/includes/Main.h index eee7ff7..44764a3 100644 --- a/includes/Main.h +++ b/includes/Main.h @@ -37,9 +37,17 @@ class Exit : public GeneralException { }; #include <Printer.h> +#include <Task.h> +#include <TaskMan.h> namespace Balau { +class MainTask : public Task { + public: + virtual const char * getName() { return "Main Task"; } + virtual void Do(); +}; + class Main { public: enum Status { @@ -49,11 +57,10 @@ class Main { STOPPING, STOPPED, }; - Main() : m_status(UNKNOWN) { application = this; } - virtual int startup() throw (GeneralException) = 0; - static Status status() { return application->m_status; } + Main() : m_status(UNKNOWN) { Assert(s_application == 0); s_application = this; } + static Status status() { return s_application->m_status; } int bootstrap(int _argc, char ** _argv) { - int r; + int r = 0; m_status = STARTING; argc = _argc; @@ -65,7 +72,8 @@ class Main { try { m_status = RUNNING; - r = startup(); + MainTask * mainTask = new MainTask(); + TaskMan::getTaskMan()->mainLoop(); m_status = STOPPING; } catch (Exit e) { @@ -96,7 +104,7 @@ class Main { char ** enve; private: Status m_status; - static Main * application; + static Main * s_application; }; #define BALAU_STARTUP \ @@ -106,12 +114,11 @@ class Application : public Balau::Main { \ virtual int startup() throw (Balau::GeneralException); \ }; \ \ -static Application application; \ -\ extern "C" { \ int main(int argc, char ** argv) { \ setlocale(LC_ALL, ""); \ - return application.bootstrap(argc, argv); \ + Balau::Main mainClass; \ + return mainClass.bootstrap(argc, argv); \ } \ } diff --git a/includes/Task.h b/includes/Task.h index fb210a2..c2777fe 100644 --- a/includes/Task.h +++ b/includes/Task.h @@ -2,10 +2,37 @@ #include <stdlib.h> #include <coro.h> +#include <Exceptions.h> +#include <vector> namespace Balau { class TaskMan; +class Task; + +namespace Events { + +class BaseEvent { + public: + BaseEvent() : m_signal(false), m_task(NULL) { } + bool gotSignal() { return m_signal; } + void doSignal() { m_signal = true; } + Task * taskWaiting() { Assert(m_task); return m_task; } + void registerOwner(Task * task) { Assert(m_task == NULL); m_task = task; } + private: + bool m_signal; + Task * m_task; +}; + +class TaskEvent : public BaseEvent { + public: + TaskEvent(Task * taskWaited); + Task * taskWaited() { return m_taskWaited; } + private: + Task * m_taskWaited; +}; + +}; class Task { public: @@ -19,20 +46,25 @@ class Task { Task(); virtual ~Task(); virtual const char * getName() = 0; - Status getStatus() { return status; } + Status getStatus() { return m_status; } + static Task * getCurrentTask(); protected: void suspend(); virtual void Do() = 0; + void waitFor(Events::BaseEvent * event); private: size_t stackSize() { return 128 * 1024; } void switchTo(); static void coroutine(void *); - void * stack; - coro_context ctx; - TaskMan * taskMan; - Status status; - void * tls; + void * m_stack; + coro_context m_ctx; + TaskMan * m_taskMan; + Status m_status; + void * m_tls; friend class TaskMan; + friend class Events::TaskEvent; + typedef std::vector<Events::TaskEvent *> waitedByList_t; + waitedByList_t m_waitedBy; }; }; diff --git a/includes/TaskMan.h b/includes/TaskMan.h index ac95f71..585fb7f 100644 --- a/includes/TaskMan.h +++ b/includes/TaskMan.h @@ -1,8 +1,9 @@ #pragma once +#include <stdint.h> #include <coro.h> #include <ext/hash_set> -#include <stdint.h> +#include <vector> namespace gnu = __gnu_cxx; @@ -15,18 +16,20 @@ class TaskMan { TaskMan(); ~TaskMan(); void mainLoop(); - void stop() { stopped = true; } + void stop() { m_stopped = true; } static TaskMan * getTaskMan(); private: void registerTask(Task * t); void unregisterTask(Task * t); - coro_context returnContext; + coro_context m_returnContext; friend class Task; - struct taskHash { size_t operator()(const Task * t) const { return reinterpret_cast<uintptr_t>(t); } }; - typedef gnu::hash_set<Task *, taskHash> taskList; - taskList tasks, pendingAdd; - volatile bool stopped; + struct taskHasher { size_t operator()(const Task * t) const { return reinterpret_cast<uintptr_t>(t); } }; + typedef gnu::hash_set<Task *, taskHasher> taskHash_t; + typedef std::vector<Task *> taskList_t; + taskHash_t m_tasks; + taskList_t m_pendingAdd; + volatile bool m_stopped; }; }; diff --git a/src/Local.cc b/src/Local.cc index c354db9..1619557 100644 --- a/src/Local.cc +++ b/src/Local.cc @@ -18,7 +18,7 @@ void * Balau::TLSManager::createTLS() { } static Balau::TLSManager dummyTLSManager; -Balau::TLSManager * Balau::tlsManager = &dummyTLSManager; +Balau::TLSManager * Balau::g_tlsManager = &dummyTLSManager; int Balau::Local::s_size = 0; void ** Balau::Local::m_globals = 0; diff --git a/src/Main.cc b/src/Main.cc index 04d867f..eb5e589 100644 --- a/src/Main.cc +++ b/src/Main.cc @@ -31,4 +31,4 @@ Balau::AtExit::AtExit(int priority) : m_priority(priority) { *ptr = this; } -Balau::Main * Balau::Main::application = NULL; +Balau::Main * Balau::Main::s_application = NULL; diff --git a/src/Task.cc b/src/Task.cc index 9bdad0c..73a2329 100644 --- a/src/Task.cc +++ b/src/Task.cc @@ -4,50 +4,72 @@ #include "Printer.h" #include "Local.h" +static Balau::LocalTmpl<Balau::Task> localTask; + Balau::Task::Task() { size_t size = stackSize(); - stack = malloc(size); - coro_create(&ctx, coroutine, this, stack, size); - taskMan = TaskMan::getTaskMan(); - taskMan->registerTask(this); - tls = tlsManager->createTLS(); - status = STARTING; + m_stack = malloc(size); + coro_create(&m_ctx, coroutine, this, m_stack, size); + + m_taskMan = TaskMan::getTaskMan(); + m_taskMan->registerTask(this); + + m_tls = g_tlsManager->createTLS(); + void * oldTLS = g_tlsManager->getTLS(); + g_tlsManager->setTLS(m_tls); + localTask.set(this); + g_tlsManager->setTLS(oldTLS); + + m_status = STARTING; } Balau::Task::~Task() { - free(stack); - free(tls); + free(m_stack); + free(m_tls); } void Balau::Task::coroutine(void * arg) { Task * task = reinterpret_cast<Task *>(arg); Assert(task); try { - task->status = RUNNING; + task->m_status = RUNNING; task->Do(); - task->status = STOPPED; + task->m_status = STOPPED; } catch (GeneralException & e) { Printer::log(M_WARNING, "Task %s caused an exception: `%s' - stopping.", task->getName(), e.getMsg()); - task->status = FAULTED; + task->m_status = FAULTED; } catch (...) { Printer::log(M_WARNING, "Task %s caused an unknown exception - stopping.", task->getName()); - task->status = FAULTED; + task->m_status = FAULTED; } - coro_transfer(&task->ctx, &task->taskMan->returnContext); + coro_transfer(&task->m_ctx, &task->m_taskMan->m_returnContext); } void Balau::Task::switchTo() { - status = RUNNING; - void * oldTLS = tlsManager->getTLS(); - tlsManager->setTLS(tls); - coro_transfer(&taskMan->returnContext, &ctx); - tlsManager->setTLS(oldTLS); - if (status == RUNNING) - status = IDLE; + m_status = RUNNING; + void * oldTLS = g_tlsManager->getTLS(); + g_tlsManager->setTLS(m_tls); + coro_transfer(&m_taskMan->m_returnContext, &m_ctx); + g_tlsManager->setTLS(oldTLS); + if (m_status == RUNNING) + m_status = IDLE; } void Balau::Task::suspend() { - coro_transfer(&ctx, &taskMan->returnContext); + coro_transfer(&m_ctx, &m_taskMan->m_returnContext); +} + +Balau::Task * Balau::Task::getCurrentTask() { + return localTask.get(); +} + +void Balau::Task::waitFor(Balau::Events::BaseEvent * e) { + e->registerOwner(this); + // probably have to register the event in the Task manager +} + +Balau::Events::TaskEvent::TaskEvent(Task * taskWaited) : m_taskWaited(taskWaited) { + m_taskWaited->m_waitedBy.push_back(this); } diff --git a/src/TaskMan.cc b/src/TaskMan.cc index bbaf35c..6730f13 100644 --- a/src/TaskMan.cc +++ b/src/TaskMan.cc @@ -6,8 +6,8 @@ static Balau::DefaultTmpl<Balau::TaskMan> defaultTaskMan(50); static Balau::LocalTmpl<Balau::TaskMan> localTaskMan; -Balau::TaskMan::TaskMan() : stopped(false) { - coro_create(&returnContext, 0, 0, 0, 0); +Balau::TaskMan::TaskMan() : m_stopped(false) { + coro_create(&m_returnContext, 0, 0, 0, 0); if (!localTaskMan.getGlobal()) localTaskMan.setGlobal(this); } @@ -21,21 +21,13 @@ Balau::TaskMan::~TaskMan() { void Balau::TaskMan::mainLoop() { // We need at least one round before bailing :) do { - taskList::iterator i; + taskList_t::iterator iL; + taskHash_t::iterator iH; Task * t; - // lock pending - // Adding tasks that were added, maybe from other threads - for (i = pendingAdd.begin(); i != pendingAdd.end(); i++) { - Assert(tasks.find(*i) == tasks.end()); - tasks.insert(*i); - } - pendingAdd.clear(); - // unlock pending - // checking "STARTING" tasks, and running them once - for (i = tasks.begin(); i != tasks.end(); i++) { - t = *i; + for (iH = m_tasks.begin(); iH != m_tasks.end(); iH++) { + t = *iH; if (t->getStatus() == Task::STARTING) { t->switchTo(); } @@ -43,25 +35,53 @@ void Balau::TaskMan::mainLoop() { // That's probably where we poll for events - // checking "STOPPED" tasks, and destroying them + // lock pending + // Adding tasks that were added, maybe from other threads + for (iL = m_pendingAdd.begin(); iL != m_pendingAdd.end(); iL++) { + t = *iL; + Assert(m_tasks.find(t) == m_tasks.end()); + m_tasks.insert(t); + } + m_pendingAdd.clear(); + // unlock pending + + // Dealing with stopped and faulted tasks. + // First by signalling the waiters. + for (iH = m_tasks.begin(); iH != m_tasks.end(); iH++) { + t = *iH; + if (((t->getStatus() == Task::STOPPED) || (t->getStatus() == Task::FAULTED)) && + (t->m_waitedBy.size() != 0)) { + Task::waitedByList_t::iterator i; + while ((i = t->m_waitedBy.begin()) != t->m_waitedBy.end()) { + Events::TaskEvent * e = *i; + e->doSignal(); + e->taskWaiting()->switchTo(); + t->m_waitedBy.erase(i); + } + } + } + + // Then, by destroying them. bool didDelete; do { didDelete = false; - for (i = tasks.begin(); i != tasks.end(); i++) { - t = *i; - if ((t->getStatus() == Task::STOPPED) || (t->getStatus() == Task::FAULTED)) { + for (iH = m_tasks.begin(); iH != m_tasks.end(); iH++) { + t = *iH; + if (((t->getStatus() == Task::STOPPED) || (t->getStatus() == Task::FAULTED)) && + (t->m_waitedBy.size() == 0)) { delete t; - tasks.erase(i); + m_tasks.erase(iH); didDelete = true; break; } } } while (didDelete); - } while (!stopped && tasks.size() != 0); + + } while (!m_stopped && m_tasks.size() != 0); } void Balau::TaskMan::registerTask(Balau::Task * t) { // lock pending - pendingAdd.insert(t); + m_pendingAdd.push_back(t); // unlock pending } diff --git a/tests/test-Sanity.cc b/tests/test-Sanity.cc index 1495f76..52a85cc 100644 --- a/tests/test-Sanity.cc +++ b/tests/test-Sanity.cc @@ -4,11 +4,10 @@ BALAU_STARTUP; using namespace Balau; -int Application::startup() throw (Balau::GeneralException) { +void MainTask::Do() { Printer::log(M_STATUS, "Test::Sanity running."); Assert(sizeof(off_t) == 8); Printer::log(M_STATUS, "Test::Sanity passed."); - return 0; } diff --git a/tests/test-String.cc b/tests/test-String.cc index 36c87e4..8e6eacb 100644 --- a/tests/test-String.cc +++ b/tests/test-String.cc @@ -5,7 +5,7 @@ BALAU_STARTUP; using namespace Balau; -int Application::startup() throw (Balau::GeneralException) { +void MainTask::Do() { Printer::log(M_STATUS, "Test::String running."); String x = "foobar"; @@ -46,5 +46,4 @@ int Application::startup() throw (Balau::GeneralException) { Assert(((unsigned char) y[0]) == 0xe9); Printer::log(M_STATUS, "Test::String passed."); - return 0; } diff --git a/tests/test-Tasks.cc b/tests/test-Tasks.cc index 2eedb47..dcc692c 100644 --- a/tests/test-Tasks.cc +++ b/tests/test-Tasks.cc @@ -11,23 +11,25 @@ class CustomPrinter : public Printer { static CustomPrinter * customPrinter = NULL; -class MainTask : public Task { +class TestTask : public Task { public: virtual const char * getName() { return "MainTask"; } private: virtual void Do() { + Printer::log(M_STATUS, "xyz"); customPrinter->setLocal(); Printer::enable(M_ALL); - Printer::log(M_DEBUG, "In MainTask::Do()"); + Printer::log(M_DEBUG, "In TestTask::Do()"); } }; -int Application::startup() throw (Balau::GeneralException) { +void MainTask::Do() { customPrinter = new CustomPrinter(); Printer::log(M_STATUS, "Test::Tasks running."); - Task * mainTask = new MainTask(); - TaskMan::getTaskMan()->mainLoop(); + Task * testTask = new TestTask(); + Events::TaskEvent e(testTask); + waitFor(&e); + suspend(); Printer::log(M_STATUS, "Test::Tasks passed."); Printer::log(M_DEBUG, "You shouldn't see that message."); - return 0; } |