diff options
-rw-r--r-- | includes/Main.h | 6 | ||||
-rw-r--r-- | includes/Socket.h | 3 | ||||
-rw-r--r-- | includes/Task.h | 5 | ||||
-rw-r--r-- | includes/TaskMan.h | 30 | ||||
-rw-r--r-- | includes/Threads.h | 7 | ||||
-rw-r--r-- | src/Handle.cc | 2 | ||||
-rw-r--r-- | src/HttpServer.cc | 2 | ||||
-rw-r--r-- | src/Task.cc | 45 | ||||
-rw-r--r-- | src/TaskMan.cc | 124 | ||||
-rw-r--r-- | tests/test-Sockets.cc | 4 | ||||
-rw-r--r-- | tests/test-Tasks.cc | 2 |
11 files changed, 177 insertions, 53 deletions
diff --git a/includes/Main.h b/includes/Main.h index 3f81137..5fa64c1 100644 --- a/includes/Main.h +++ b/includes/Main.h @@ -45,7 +45,7 @@ namespace Balau { class MainTask : public Task { public: MainTask() : m_stopTaskManOnExit(true) { } - virtual ~MainTask() { if (m_stopTaskManOnExit) TaskMan::getTaskMan()->stop(); } + virtual ~MainTask() { if (m_stopTaskManOnExit) TaskMan::stop(); } virtual const char * getName() { return "Main Task"; } virtual void Do(); void stopTaskManOnExit(bool v) { m_stopTaskManOnExit = v; } @@ -77,8 +77,8 @@ class Main { try { m_status = RUNNING; - MainTask * mainTask = new MainTask(); - TaskMan::getTaskMan()->mainLoop(); + MainTask * mainTask = createTask(new MainTask()); + TaskMan::getDefaultTaskMan()->mainLoop(); m_status = STOPPING; } catch (Exit e) { diff --git a/includes/Socket.h b/includes/Socket.h index 164090c..a19b1e2 100644 --- a/includes/Socket.h +++ b/includes/Socket.h @@ -8,6 +8,7 @@ #include <netdb.h> #endif #include <Handle.h> +#include <TaskMan.h> #include <Task.h> #include <Printer.h> @@ -78,7 +79,7 @@ class Listener : public ListenerBase { public: Listener(int port, const char * local = "", void * opaque = NULL) : ListenerBase(port, local, opaque) { } protected: - virtual void factory(IO<Socket> & io, void * opaque) { new Worker(io, opaque); } + virtual void factory(IO<Socket> & io, void * opaque) { createTask(new Worker(io, opaque)); } virtual void setName() { m_name = String(ClassName(this).c_str()) + " - " + m_listener->getName(); } }; diff --git a/includes/Task.h b/includes/Task.h index 156add2..22c522e 100644 --- a/includes/Task.h +++ b/includes/Task.h @@ -128,10 +128,13 @@ class Task { m_okayToEAgain = enable; return oldValue; } + TaskMan * getMyTaskMan() { return m_taskMan; } private: size_t stackSize() { return 128 * 1024; } + void setup(TaskMan * taskMan); void switchTo(); - static void CALLBACK coroutine(void *); + static void CALLBACK coroutineTrampoline(void *); + void coroutine(); void * m_stack; #ifndef _WIN32 coro_context m_ctx; diff --git a/includes/TaskMan.h b/includes/TaskMan.h index b4645fe..afc2b8a 100644 --- a/includes/TaskMan.h +++ b/includes/TaskMan.h @@ -8,40 +8,54 @@ #include <ext/hash_set> #include <vector> #include <Threads.h> +#include <Exceptions.h> namespace gnu = __gnu_cxx; namespace Balau { class Task; +class TaskScheduler; + +namespace Events { + +class Async; + +}; class TaskMan { public: TaskMan(); ~TaskMan(); void mainLoop(); - void stop() { m_stopped = true; } - static TaskMan * getTaskMan(); + static TaskMan * getDefaultTaskMan(); struct ev_loop * getLoop() { return m_loop; } void signalTask(Task * t); - + static void stop(); + void stopMe() { m_stopped = true; } private: - void registerTask(Task * t); + static void registerTask(Task * t); + void addToPending(Task * t); #ifndef _WIN32 coro_context m_returnContext; #else void * m_fiber; #endif friend class Task; + friend class TaskScheduler; + template<class T> + friend T * createTask(T * t); struct taskHasher { size_t operator()(const Task * t) const { return reinterpret_cast<uintptr_t>(t); } }; typedef gnu::hash_set<Task *, taskHasher> taskHash_t; - typedef std::vector<Task *> taskList_t; taskHash_t m_tasks, m_signaledTasks; - taskList_t m_pendingAdd; - Lock m_pendingLock; - volatile bool m_stopped; + Queue<Task *> m_pendingAdd; + bool m_stopped; struct ev_loop * m_loop; bool m_allowedToSignal; + ev::async m_evt; }; +template<class T> +T * createTask(T * t) { TaskMan::registerTask(t); Assert(dynamic_cast<Task *>(t)); return t; } + }; diff --git a/includes/Threads.h b/includes/Threads.h index 2347a84..a1f270e 100644 --- a/includes/Threads.h +++ b/includes/Threads.h @@ -57,6 +57,13 @@ class Queue { m_lock.leave(); return t; } + int size() { + int r; + m_lock.enter(); + r = m_queue.size(); + m_lock.leave(); + return r; + } private: std::queue<T> m_queue; Lock m_lock; diff --git a/src/Handle.cc b/src/Handle.cc index 839c58a..b8dd45e 100644 --- a/src/Handle.cc +++ b/src/Handle.cc @@ -43,7 +43,7 @@ void eioInterface::readyCB(ev::async & w, int revents) { void eioInterface::doStart() { Balau::Printer::elog(Balau::E_HANDLE, "Starting the eio interface"); - Balau::TaskMan * taskMan = Balau::TaskMan::getTaskMan(); + Balau::TaskMan * taskMan = Balau::TaskMan::getDefaultTaskMan(); Assert(taskMan); struct ev_loop * loop = taskMan->getLoop(); diff --git a/src/HttpServer.cc b/src/HttpServer.cc index 608b8a8..b027824 100644 --- a/src/HttpServer.cc +++ b/src/HttpServer.cc @@ -354,7 +354,7 @@ typedef Balau::Listener<Balau::HttpWorker> HttpListener; void Balau::HttpServer::start() { Assert(!m_started); - m_listenerPtr = new HttpListener(m_port, m_local.to_charp()); + m_listenerPtr = createTask(new HttpListener(m_port, m_local.to_charp())); m_started = true; } diff --git a/src/Task.cc b/src/Task.cc index 3772dd5..3a6f8ad 100644 --- a/src/Task.cc +++ b/src/Task.cc @@ -7,28 +7,29 @@ static Balau::LocalTmpl<Balau::Task> localTask; Balau::Task::Task() { + m_status = STARTING; + m_okayToEAgain = false; + + Printer::elog(E_TASK, "Created a Task at %p"); +} + +void Balau::Task::setup(TaskMan * taskMan) { size_t size = stackSize(); #ifndef _WIN32 m_stack = malloc(size); - coro_create(&m_ctx, coroutine, this, m_stack, size); + coro_create(&m_ctx, coroutineTrampoline, this, m_stack, size); #else m_stack = NULL; - m_fiber = CreateFiber(size, coroutine, this); + m_fiber = CreateFiber(size, coroutineTrampoline, this); #endif - m_taskMan = TaskMan::getTaskMan(); - m_taskMan->registerTask(this); + m_taskMan = taskMan; m_tls = g_tlsManager->createTLS(); void * oldTLS = g_tlsManager->getTLS(); g_tlsManager->setTLS(m_tls); localTask.set(this); g_tlsManager->setTLS(oldTLS); - - m_status = STARTING; - m_okayToEAgain = false; - - Printer::elog(E_TASK, "Created a Task at %p"); } Balau::Task::~Task() { @@ -37,27 +38,31 @@ Balau::Task::~Task() { free(m_tls); } -void Balau::Task::coroutine(void * arg) { +void Balau::Task::coroutineTrampoline(void * arg) { Task * task = reinterpret_cast<Task *>(arg); Assert(task); - Assert(task->m_status == STARTING); + task->coroutine(); +} + +void Balau::Task::coroutine() { + Assert(m_status == STARTING); try { - task->m_status = RUNNING; - task->Do(); - task->m_status = STOPPED; + m_status = RUNNING; + Do(); + m_status = STOPPED; } catch (GeneralException & e) { - Printer::log(M_WARNING, "Task %s at %p caused an exception: `%s' - stopping.", task->getName(), task, e.getMsg()); - task->m_status = FAULTED; + Printer::log(M_WARNING, "Task %s at %p caused an exception: `%s' - stopping.", getName(), this, e.getMsg()); + m_status = FAULTED; } catch (...) { - Printer::log(M_WARNING, "Task %s at %p caused an unknown exception - stopping.", task->getName(), task); - task->m_status = FAULTED; + Printer::log(M_WARNING, "Task %s at %p caused an unknown exception - stopping.", getName(), this); + m_status = FAULTED; } #ifndef _WIN32 - coro_transfer(&task->m_ctx, &task->m_taskMan->m_returnContext); + coro_transfer(&m_ctx, &m_taskMan->m_returnContext); #else - SwitchToFiber(task->m_taskMan->m_fiber); + SwitchToFiber(m_taskMan->m_fiber); #endif } diff --git a/src/TaskMan.cc b/src/TaskMan.cc index 0fc4668..8c99e52 100644 --- a/src/TaskMan.cc +++ b/src/TaskMan.cc @@ -3,9 +3,94 @@ #include "Main.h" #include "Local.h" +class Stopper : public Balau::Task { + virtual void Do(); + virtual const char * getName(); +}; + +void Stopper::Do() { + getMyTaskMan()->stopMe(); +} + +const char * Stopper::getName() { + return "Stopper"; +} + static Balau::DefaultTmpl<Balau::TaskMan> defaultTaskMan(50); static Balau::LocalTmpl<Balau::TaskMan> localTaskMan; +namespace Balau { + +class TaskScheduler : public Thread, public AtStart, public AtExit { + public: + TaskScheduler() : AtStart(100), m_stopping(false) { } + void registerTask(Task * t); + virtual void * proc(); + virtual void doStart(); + virtual void doExit(); + void registerTaskMan(TaskMan * t); + void unregisterTaskMan(TaskMan * t); + void stopAll(); + private: + Queue<Task *> m_queue; + volatile bool m_stopping; +}; + +}; + +static Balau::TaskScheduler s_scheduler; + +void Balau::TaskScheduler::registerTask(Task * t) { + Printer::elog(E_TASK, "TaskScheduler::registerTask with t = %p", t); + m_queue.push(t); +} + +void Balau::TaskScheduler::registerTaskMan(TaskMan * t) { + // meh. We need a round-robin queue system. +} + +void Balau::TaskScheduler::unregisterTaskMan(TaskMan * t) { + // and here, we need to remove that taskman from the round robin queue. +} + +void Balau::TaskScheduler::stopAll() { + m_stopping = true; + // and finally, we need to crawl the whole list and stop all of them. + TaskMan * tm = localTaskMan.getGlobal(); + tm->addToPending(new Stopper()); +} + +void * Balau::TaskScheduler::proc() { + while (true) { + Printer::elog(E_TASK, "TaskScheduler waiting for a task to pop"); + Task * t = m_queue.pop(); + if (!t) + break; + if (dynamic_cast<Stopper *>(t) || m_stopping) + break; + // pick up a task manager here... for now let's take the global one. + // but we need some sort of round robin across all of the threads, as described above. + TaskMan * tm = localTaskMan.getGlobal(); + Printer::elog(E_TASK, "TaskScheduler popped task %s at %p; adding to TaskMan %p", t->getName(), t, tm); + tm->addToPending(t); + tm->m_evt.send(); + } + Printer::elog(E_TASK, "TaskScheduler stopping."); + return NULL; +} + +void Balau::TaskScheduler::doStart() { + threadStart(); +} + +void Balau::TaskScheduler::doExit() { + Task * s = NULL; + m_queue.push(s); + join(); +} + +void asyncDummy(ev::async & w, int revents) { } + Balau::TaskMan::TaskMan() : m_stopped(false), m_allowedToSignal(false) { #ifndef _WIN32 coro_create(&m_returnContext, 0, 0, 0, 0); @@ -13,12 +98,17 @@ Balau::TaskMan::TaskMan() : m_stopped(false), m_allowedToSignal(false) { m_fiber = ConvertThreadToFiber(NULL); Assert(m_fiber); #endif - if (!localTaskMan.getGlobal()) { + TaskMan * global = localTaskMan.getGlobal(); + if (!global) { localTaskMan.setGlobal(this); m_loop = ev_default_loop(EVFLAG_AUTO); } else { m_loop = ev_loop_new(EVFLAG_AUTO); } + m_evt.set(m_loop); + m_evt.set<asyncDummy>(); + m_evt.start(); + s_scheduler.registerTaskMan(this); } #ifdef _WIN32 @@ -35,17 +125,18 @@ class WinSocketStartup : public Balau::AtStart { static WinSocketStartup wsa; #endif -Balau::TaskMan * Balau::TaskMan::getTaskMan() { return localTaskMan.get(); } +Balau::TaskMan * Balau::TaskMan::getDefaultTaskMan() { return localTaskMan.get(); } Balau::TaskMan::~TaskMan() { Assert(localTaskMan.getGlobal() != this); + s_scheduler.unregisterTaskMan(this); + // probably way more work to do here in order to clean up tasks from that thread ev_loop_destroy(m_loop); } void Balau::TaskMan::mainLoop() { // We need at least one round before bailing :) do { - taskList_t::iterator iL; taskHash_t::iterator iH; Task * t; bool noWait = false; @@ -65,15 +156,13 @@ void Balau::TaskMan::mainLoop() { if (m_tasks.size() == 0) noWait = true; - m_pendingLock.enter(); if (m_pendingAdd.size() != 0) noWait = true; - m_pendingLock.leave(); // libev's event "loop". We always runs it once though. m_allowedToSignal = true; Printer::elog(E_TASK, "Going to libev main loop"); - ev_run(m_loop, noWait ? EVRUN_NOWAIT : EVRUN_ONCE); + ev_run(m_loop, noWait || m_stopped ? EVRUN_NOWAIT : EVRUN_ONCE); Printer::elog(E_TASK, "Getting out of libev main loop"); // let's check what task got stopped, and signal them @@ -99,15 +188,13 @@ void Balau::TaskMan::mainLoop() { } m_signaledTasks.clear(); - m_pendingLock.enter(); // Adding tasks that were added, maybe from other threads - for (iL = m_pendingAdd.begin(); iL != m_pendingAdd.end(); iL++) { - t = *iL; + while ((m_pendingAdd.size() != 0) || (m_tasks.size() == 0) && !m_stopped) { + t = m_pendingAdd.pop(); Assert(m_tasks.find(t) == m_tasks.end()); + t->setup(this); m_tasks.insert(t); } - m_pendingAdd.clear(); - m_pendingLock.leave(); // Finally, let's destroy tasks that no longer are necessary. bool didDelete; @@ -125,13 +212,16 @@ void Balau::TaskMan::mainLoop() { } } while (didDelete); - } while (!m_stopped && m_tasks.size() != 0); + } while (!m_stopped); + Printer::elog(E_TASK, "TaskManager stopping."); } void Balau::TaskMan::registerTask(Balau::Task * t) { - m_pendingLock.enter(); - m_pendingAdd.push_back(t); - m_pendingLock.leave(); + s_scheduler.registerTask(t); +} + +void Balau::TaskMan::addToPending(Balau::Task * t) { + m_pendingAdd.push(t); } void Balau::TaskMan::signalTask(Task * t) { @@ -139,3 +229,7 @@ void Balau::TaskMan::signalTask(Task * t) { Assert(m_allowedToSignal); m_signaledTasks.insert(t); } + +void Balau::TaskMan::stop() { + s_scheduler.stopAll(); +} diff --git a/tests/test-Sockets.cc b/tests/test-Sockets.cc index 0afedb1..87d557b 100644 --- a/tests/test-Sockets.cc +++ b/tests/test-Sockets.cc @@ -64,8 +64,8 @@ void MainTask::Do() { Printer::enable(M_ALL); Printer::log(M_STATUS, "Test::Sockets running."); - Events::TaskEvent evtSvr(listener = new Listener<Worker>(1234)); - Events::TaskEvent evtCln(new Client); + Events::TaskEvent evtSvr(listener = Balau::createTask(new Listener<Worker>(1234))); + Events::TaskEvent evtCln(Balau::createTask(new Client)); Printer::log(M_STATUS, "Created %s", listener->getName()); waitFor(&evtSvr); waitFor(&evtCln); diff --git a/tests/test-Tasks.cc b/tests/test-Tasks.cc index 2760e52..8677fb7 100644 --- a/tests/test-Tasks.cc +++ b/tests/test-Tasks.cc @@ -32,7 +32,7 @@ void MainTask::Do() { customPrinter = new CustomPrinter(); Printer::log(M_STATUS, "Test::Tasks running."); - Task * testTask = new TestTask(); + Task * testTask = Balau::createTask(new TestTask()); Events::TaskEvent taskEvt(testTask); waitFor(&taskEvt); Assert(!taskEvt.gotSignal()); |