diff options
-rw-r--r-- | includes/Task.h | 13 | ||||
-rw-r--r-- | src/Task.cc | 41 | ||||
-rw-r--r-- | tests/test-Tasks.cc | 15 |
3 files changed, 60 insertions, 9 deletions
diff --git a/includes/Task.h b/includes/Task.h index 8a504c7..c4d9a46 100644 --- a/includes/Task.h +++ b/includes/Task.h @@ -18,8 +18,9 @@ class BaseEvent { BaseEvent() : m_signal(false), m_task(NULL) { } bool gotSignal() { return m_signal; } void doSignal(); + void reset() { Assert(m_task != NULL); m_signal = false; gotOwner(m_task); } Task * taskWaiting() { Assert(m_task); return m_task; } - void registerOwner(Task * task) { Assert(m_task == NULL); m_task = task; gotOwner(task); } + void registerOwner(Task * task) { if (m_task == task) return; Assert(m_task == NULL); m_task = task; gotOwner(task); } protected: virtual void gotOwner(Task * task) { } private: @@ -31,6 +32,7 @@ class Timeout : public BaseEvent { public: Timeout(ev_tstamp tstamp); void evt_cb(ev::timer & w, int revents); + void set(ev_tstamp tstamp); private: virtual void gotOwner(Task * task); ev::timer m_evt; @@ -69,12 +71,14 @@ class Task { virtual const char * getName() = 0; Status getStatus() { return m_status; } static Task * getCurrentTask(); - static void yield(Events::BaseEvent * evt) { Task * t = getCurrentTask(); t->waitFor(evt); t->yield(); } + static void yield(Events::BaseEvent * evt) { Task * t = getCurrentTask(); t->waitFor(evt, true); t->yield(true); } TaskMan * getTaskMan() { return m_taskMan; } + struct ev_loop * getLoop(); protected: - void yield(); + void yield(bool override = false); virtual void Do() = 0; - void waitFor(Events::BaseEvent * event); + void waitFor(Events::BaseEvent * event, bool override = false); + void setPreemptible(bool enable); private: size_t stackSize() { return 128 * 1024; } void switchTo(); @@ -88,6 +92,7 @@ class Task { friend class Events::TaskEvent; typedef std::vector<Events::TaskEvent *> waitedByList_t; waitedByList_t m_waitedBy; + struct ev_loop * m_loop; }; }; diff --git a/src/Task.cc b/src/Task.cc index a47bd8b..938b3de 100644 --- a/src/Task.cc +++ b/src/Task.cc @@ -21,9 +21,12 @@ Balau::Task::Task() { g_tlsManager->setTLS(oldTLS); m_status = STARTING; + m_loop = NULL; } Balau::Task::~Task() { + if (m_loop) + ev_loop_destroy(m_loop); free(m_stack); free(m_tls); } @@ -57,16 +60,40 @@ void Balau::Task::switchTo() { m_status = IDLE; } -void Balau::Task::yield() { - coro_transfer(&m_ctx, &m_taskMan->m_returnContext); +void Balau::Task::yield(bool override) { + if (m_loop && override) { + ev_run(m_loop, 0); + } else { + coro_transfer(&m_ctx, &m_taskMan->m_returnContext); + } } Balau::Task * Balau::Task::getCurrentTask() { return localTask.get(); } -void Balau::Task::waitFor(Balau::Events::BaseEvent * e) { +void Balau::Task::waitFor(Balau::Events::BaseEvent * e, bool override) { + struct ev_loop * loop = m_loop; + if (!override) + m_loop = NULL; e->registerOwner(this); + m_loop = loop; +} + +void Balau::Task::setPreemptible(bool enable) { + if (!m_loop && !enable) { + m_loop = ev_loop_new(EVFLAG_AUTO); + } else if (m_loop) { + ev_loop_destroy(m_loop); + m_loop = NULL; + } +} + +struct ev_loop * Balau::Task::getLoop() { + if (m_loop) + return m_loop; + else + return getTaskMan()->getLoop(); } void Balau::Events::BaseEvent::doSignal() { @@ -79,12 +106,16 @@ Balau::Events::TaskEvent::TaskEvent(Task * taskWaited) : m_taskWaited(taskWaited } Balau::Events::Timeout::Timeout(ev_tstamp tstamp) { + set(tstamp); +} + +void Balau::Events::Timeout::set(ev_tstamp tstamp) { m_evt.set<Timeout, &Timeout::evt_cb>(this); m_evt.set(tstamp); } void Balau::Events::Timeout::gotOwner(Task * task) { - m_evt.set(task->getTaskMan()->getLoop()); + m_evt.set(task->getLoop()); m_evt.start(); } @@ -93,5 +124,5 @@ void Balau::Events::Timeout::evt_cb(ev::timer & w, int revents) { } void Balau::Events::Custom::gotOwner(Task * task) { - m_loop = task->getTaskMan()->getLoop(); + m_loop = task->getLoop(); } diff --git a/tests/test-Tasks.cc b/tests/test-Tasks.cc index 1faa194..bdafa8d 100644 --- a/tests/test-Tasks.cc +++ b/tests/test-Tasks.cc @@ -22,6 +22,12 @@ class TestTask : public Task { } }; +static void yieldingFunction() { + Events::Timeout timeout(0.2); + Task::yield(&timeout); + Assert(timeout.gotSignal()); +} + void MainTask::Do() { customPrinter = new CustomPrinter(); Printer::log(M_STATUS, "Test::Tasks running."); @@ -39,6 +45,15 @@ void MainTask::Do() { yield(); Assert(timeout.gotSignal()); + timeout.set(0.1); + timeout.reset(); + setPreemptible(false); + yieldingFunction(); + Assert(!timeout.gotSignal()); + waitFor(&timeout); + yield(); + Assert(timeout.gotSignal()); + Printer::log(M_STATUS, "Test::Tasks passed."); Printer::log(M_DEBUG, "You shouldn't see that message."); } |