summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--includes/Task.h13
-rw-r--r--src/Task.cc41
-rw-r--r--tests/test-Tasks.cc15
3 files changed, 60 insertions, 9 deletions
diff --git a/includes/Task.h b/includes/Task.h
index 8a504c7..c4d9a46 100644
--- a/includes/Task.h
+++ b/includes/Task.h
@@ -18,8 +18,9 @@ class BaseEvent {
BaseEvent() : m_signal(false), m_task(NULL) { }
bool gotSignal() { return m_signal; }
void doSignal();
+ void reset() { Assert(m_task != NULL); m_signal = false; gotOwner(m_task); }
Task * taskWaiting() { Assert(m_task); return m_task; }
- void registerOwner(Task * task) { Assert(m_task == NULL); m_task = task; gotOwner(task); }
+ void registerOwner(Task * task) { if (m_task == task) return; Assert(m_task == NULL); m_task = task; gotOwner(task); }
protected:
virtual void gotOwner(Task * task) { }
private:
@@ -31,6 +32,7 @@ class Timeout : public BaseEvent {
public:
Timeout(ev_tstamp tstamp);
void evt_cb(ev::timer & w, int revents);
+ void set(ev_tstamp tstamp);
private:
virtual void gotOwner(Task * task);
ev::timer m_evt;
@@ -69,12 +71,14 @@ class Task {
virtual const char * getName() = 0;
Status getStatus() { return m_status; }
static Task * getCurrentTask();
- static void yield(Events::BaseEvent * evt) { Task * t = getCurrentTask(); t->waitFor(evt); t->yield(); }
+ static void yield(Events::BaseEvent * evt) { Task * t = getCurrentTask(); t->waitFor(evt, true); t->yield(true); }
TaskMan * getTaskMan() { return m_taskMan; }
+ struct ev_loop * getLoop();
protected:
- void yield();
+ void yield(bool override = false);
virtual void Do() = 0;
- void waitFor(Events::BaseEvent * event);
+ void waitFor(Events::BaseEvent * event, bool override = false);
+ void setPreemptible(bool enable);
private:
size_t stackSize() { return 128 * 1024; }
void switchTo();
@@ -88,6 +92,7 @@ class Task {
friend class Events::TaskEvent;
typedef std::vector<Events::TaskEvent *> waitedByList_t;
waitedByList_t m_waitedBy;
+ struct ev_loop * m_loop;
};
};
diff --git a/src/Task.cc b/src/Task.cc
index a47bd8b..938b3de 100644
--- a/src/Task.cc
+++ b/src/Task.cc
@@ -21,9 +21,12 @@ Balau::Task::Task() {
g_tlsManager->setTLS(oldTLS);
m_status = STARTING;
+ m_loop = NULL;
}
Balau::Task::~Task() {
+ if (m_loop)
+ ev_loop_destroy(m_loop);
free(m_stack);
free(m_tls);
}
@@ -57,16 +60,40 @@ void Balau::Task::switchTo() {
m_status = IDLE;
}
-void Balau::Task::yield() {
- coro_transfer(&m_ctx, &m_taskMan->m_returnContext);
+void Balau::Task::yield(bool override) {
+ if (m_loop && override) {
+ ev_run(m_loop, 0);
+ } else {
+ coro_transfer(&m_ctx, &m_taskMan->m_returnContext);
+ }
}
Balau::Task * Balau::Task::getCurrentTask() {
return localTask.get();
}
-void Balau::Task::waitFor(Balau::Events::BaseEvent * e) {
+void Balau::Task::waitFor(Balau::Events::BaseEvent * e, bool override) {
+ struct ev_loop * loop = m_loop;
+ if (!override)
+ m_loop = NULL;
e->registerOwner(this);
+ m_loop = loop;
+}
+
+void Balau::Task::setPreemptible(bool enable) {
+ if (!m_loop && !enable) {
+ m_loop = ev_loop_new(EVFLAG_AUTO);
+ } else if (m_loop) {
+ ev_loop_destroy(m_loop);
+ m_loop = NULL;
+ }
+}
+
+struct ev_loop * Balau::Task::getLoop() {
+ if (m_loop)
+ return m_loop;
+ else
+ return getTaskMan()->getLoop();
}
void Balau::Events::BaseEvent::doSignal() {
@@ -79,12 +106,16 @@ Balau::Events::TaskEvent::TaskEvent(Task * taskWaited) : m_taskWaited(taskWaited
}
Balau::Events::Timeout::Timeout(ev_tstamp tstamp) {
+ set(tstamp);
+}
+
+void Balau::Events::Timeout::set(ev_tstamp tstamp) {
m_evt.set<Timeout, &Timeout::evt_cb>(this);
m_evt.set(tstamp);
}
void Balau::Events::Timeout::gotOwner(Task * task) {
- m_evt.set(task->getTaskMan()->getLoop());
+ m_evt.set(task->getLoop());
m_evt.start();
}
@@ -93,5 +124,5 @@ void Balau::Events::Timeout::evt_cb(ev::timer & w, int revents) {
}
void Balau::Events::Custom::gotOwner(Task * task) {
- m_loop = task->getTaskMan()->getLoop();
+ m_loop = task->getLoop();
}
diff --git a/tests/test-Tasks.cc b/tests/test-Tasks.cc
index 1faa194..bdafa8d 100644
--- a/tests/test-Tasks.cc
+++ b/tests/test-Tasks.cc
@@ -22,6 +22,12 @@ class TestTask : public Task {
}
};
+static void yieldingFunction() {
+ Events::Timeout timeout(0.2);
+ Task::yield(&timeout);
+ Assert(timeout.gotSignal());
+}
+
void MainTask::Do() {
customPrinter = new CustomPrinter();
Printer::log(M_STATUS, "Test::Tasks running.");
@@ -39,6 +45,15 @@ void MainTask::Do() {
yield();
Assert(timeout.gotSignal());
+ timeout.set(0.1);
+ timeout.reset();
+ setPreemptible(false);
+ yieldingFunction();
+ Assert(!timeout.gotSignal());
+ waitFor(&timeout);
+ yield();
+ Assert(timeout.gotSignal());
+
Printer::log(M_STATUS, "Test::Tasks passed.");
Printer::log(M_DEBUG, "You shouldn't see that message.");
}