summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--includes/Local.h6
-rw-r--r--includes/Main.h25
-rw-r--r--includes/Task.h44
-rw-r--r--includes/TaskMan.h17
-rw-r--r--src/Local.cc2
-rw-r--r--src/Main.cc2
-rw-r--r--src/Task.cc64
-rw-r--r--src/TaskMan.cc62
-rw-r--r--tests/test-Sanity.cc3
-rw-r--r--tests/test-String.cc3
-rw-r--r--tests/test-Tasks.cc14
11 files changed, 163 insertions, 79 deletions
diff --git a/includes/Local.h b/includes/Local.h
index e3d08a4..79862a6 100644
--- a/includes/Local.h
+++ b/includes/Local.h
@@ -11,7 +11,7 @@ class TLSManager {
void * createTLS();
};
-extern TLSManager * tlsManager;
+extern TLSManager * g_tlsManager;
class Local : public AtStart {
public:
@@ -27,8 +27,8 @@ class Local : public AtStart {
int getIndex() { return m_idx; }
private:
static void * create() { void * r = malloc(s_size * sizeof(void *)); return r; }
- static void * getTLS() { return tlsManager->getTLS(); }
- static void * setTLS(void * val) { return tlsManager->setTLS(val); }
+ static void * getTLS() { return g_tlsManager->getTLS(); }
+ static void * setTLS(void * val) { return g_tlsManager->setTLS(val); }
virtual void doStart();
int m_idx;
static int s_size;
diff --git a/includes/Main.h b/includes/Main.h
index eee7ff7..44764a3 100644
--- a/includes/Main.h
+++ b/includes/Main.h
@@ -37,9 +37,17 @@ class Exit : public GeneralException {
};
#include <Printer.h>
+#include <Task.h>
+#include <TaskMan.h>
namespace Balau {
+class MainTask : public Task {
+ public:
+ virtual const char * getName() { return "Main Task"; }
+ virtual void Do();
+};
+
class Main {
public:
enum Status {
@@ -49,11 +57,10 @@ class Main {
STOPPING,
STOPPED,
};
- Main() : m_status(UNKNOWN) { application = this; }
- virtual int startup() throw (GeneralException) = 0;
- static Status status() { return application->m_status; }
+ Main() : m_status(UNKNOWN) { Assert(s_application == 0); s_application = this; }
+ static Status status() { return s_application->m_status; }
int bootstrap(int _argc, char ** _argv) {
- int r;
+ int r = 0;
m_status = STARTING;
argc = _argc;
@@ -65,7 +72,8 @@ class Main {
try {
m_status = RUNNING;
- r = startup();
+ MainTask * mainTask = new MainTask();
+ TaskMan::getTaskMan()->mainLoop();
m_status = STOPPING;
}
catch (Exit e) {
@@ -96,7 +104,7 @@ class Main {
char ** enve;
private:
Status m_status;
- static Main * application;
+ static Main * s_application;
};
#define BALAU_STARTUP \
@@ -106,12 +114,11 @@ class Application : public Balau::Main { \
virtual int startup() throw (Balau::GeneralException); \
}; \
\
-static Application application; \
-\
extern "C" { \
int main(int argc, char ** argv) { \
setlocale(LC_ALL, ""); \
- return application.bootstrap(argc, argv); \
+ Balau::Main mainClass; \
+ return mainClass.bootstrap(argc, argv); \
} \
}
diff --git a/includes/Task.h b/includes/Task.h
index fb210a2..c2777fe 100644
--- a/includes/Task.h
+++ b/includes/Task.h
@@ -2,10 +2,37 @@
#include <stdlib.h>
#include <coro.h>
+#include <Exceptions.h>
+#include <vector>
namespace Balau {
class TaskMan;
+class Task;
+
+namespace Events {
+
+class BaseEvent {
+ public:
+ BaseEvent() : m_signal(false), m_task(NULL) { }
+ bool gotSignal() { return m_signal; }
+ void doSignal() { m_signal = true; }
+ Task * taskWaiting() { Assert(m_task); return m_task; }
+ void registerOwner(Task * task) { Assert(m_task == NULL); m_task = task; }
+ private:
+ bool m_signal;
+ Task * m_task;
+};
+
+class TaskEvent : public BaseEvent {
+ public:
+ TaskEvent(Task * taskWaited);
+ Task * taskWaited() { return m_taskWaited; }
+ private:
+ Task * m_taskWaited;
+};
+
+};
class Task {
public:
@@ -19,20 +46,25 @@ class Task {
Task();
virtual ~Task();
virtual const char * getName() = 0;
- Status getStatus() { return status; }
+ Status getStatus() { return m_status; }
+ static Task * getCurrentTask();
protected:
void suspend();
virtual void Do() = 0;
+ void waitFor(Events::BaseEvent * event);
private:
size_t stackSize() { return 128 * 1024; }
void switchTo();
static void coroutine(void *);
- void * stack;
- coro_context ctx;
- TaskMan * taskMan;
- Status status;
- void * tls;
+ void * m_stack;
+ coro_context m_ctx;
+ TaskMan * m_taskMan;
+ Status m_status;
+ void * m_tls;
friend class TaskMan;
+ friend class Events::TaskEvent;
+ typedef std::vector<Events::TaskEvent *> waitedByList_t;
+ waitedByList_t m_waitedBy;
};
};
diff --git a/includes/TaskMan.h b/includes/TaskMan.h
index ac95f71..585fb7f 100644
--- a/includes/TaskMan.h
+++ b/includes/TaskMan.h
@@ -1,8 +1,9 @@
#pragma once
+#include <stdint.h>
#include <coro.h>
#include <ext/hash_set>
-#include <stdint.h>
+#include <vector>
namespace gnu = __gnu_cxx;
@@ -15,18 +16,20 @@ class TaskMan {
TaskMan();
~TaskMan();
void mainLoop();
- void stop() { stopped = true; }
+ void stop() { m_stopped = true; }
static TaskMan * getTaskMan();
private:
void registerTask(Task * t);
void unregisterTask(Task * t);
- coro_context returnContext;
+ coro_context m_returnContext;
friend class Task;
- struct taskHash { size_t operator()(const Task * t) const { return reinterpret_cast<uintptr_t>(t); } };
- typedef gnu::hash_set<Task *, taskHash> taskList;
- taskList tasks, pendingAdd;
- volatile bool stopped;
+ struct taskHasher { size_t operator()(const Task * t) const { return reinterpret_cast<uintptr_t>(t); } };
+ typedef gnu::hash_set<Task *, taskHasher> taskHash_t;
+ typedef std::vector<Task *> taskList_t;
+ taskHash_t m_tasks;
+ taskList_t m_pendingAdd;
+ volatile bool m_stopped;
};
};
diff --git a/src/Local.cc b/src/Local.cc
index c354db9..1619557 100644
--- a/src/Local.cc
+++ b/src/Local.cc
@@ -18,7 +18,7 @@ void * Balau::TLSManager::createTLS() {
}
static Balau::TLSManager dummyTLSManager;
-Balau::TLSManager * Balau::tlsManager = &dummyTLSManager;
+Balau::TLSManager * Balau::g_tlsManager = &dummyTLSManager;
int Balau::Local::s_size = 0;
void ** Balau::Local::m_globals = 0;
diff --git a/src/Main.cc b/src/Main.cc
index 04d867f..eb5e589 100644
--- a/src/Main.cc
+++ b/src/Main.cc
@@ -31,4 +31,4 @@ Balau::AtExit::AtExit(int priority) : m_priority(priority) {
*ptr = this;
}
-Balau::Main * Balau::Main::application = NULL;
+Balau::Main * Balau::Main::s_application = NULL;
diff --git a/src/Task.cc b/src/Task.cc
index 9bdad0c..73a2329 100644
--- a/src/Task.cc
+++ b/src/Task.cc
@@ -4,50 +4,72 @@
#include "Printer.h"
#include "Local.h"
+static Balau::LocalTmpl<Balau::Task> localTask;
+
Balau::Task::Task() {
size_t size = stackSize();
- stack = malloc(size);
- coro_create(&ctx, coroutine, this, stack, size);
- taskMan = TaskMan::getTaskMan();
- taskMan->registerTask(this);
- tls = tlsManager->createTLS();
- status = STARTING;
+ m_stack = malloc(size);
+ coro_create(&m_ctx, coroutine, this, m_stack, size);
+
+ m_taskMan = TaskMan::getTaskMan();
+ m_taskMan->registerTask(this);
+
+ m_tls = g_tlsManager->createTLS();
+ void * oldTLS = g_tlsManager->getTLS();
+ g_tlsManager->setTLS(m_tls);
+ localTask.set(this);
+ g_tlsManager->setTLS(oldTLS);
+
+ m_status = STARTING;
}
Balau::Task::~Task() {
- free(stack);
- free(tls);
+ free(m_stack);
+ free(m_tls);
}
void Balau::Task::coroutine(void * arg) {
Task * task = reinterpret_cast<Task *>(arg);
Assert(task);
try {
- task->status = RUNNING;
+ task->m_status = RUNNING;
task->Do();
- task->status = STOPPED;
+ task->m_status = STOPPED;
}
catch (GeneralException & e) {
Printer::log(M_WARNING, "Task %s caused an exception: `%s' - stopping.", task->getName(), e.getMsg());
- task->status = FAULTED;
+ task->m_status = FAULTED;
}
catch (...) {
Printer::log(M_WARNING, "Task %s caused an unknown exception - stopping.", task->getName());
- task->status = FAULTED;
+ task->m_status = FAULTED;
}
- coro_transfer(&task->ctx, &task->taskMan->returnContext);
+ coro_transfer(&task->m_ctx, &task->m_taskMan->m_returnContext);
}
void Balau::Task::switchTo() {
- status = RUNNING;
- void * oldTLS = tlsManager->getTLS();
- tlsManager->setTLS(tls);
- coro_transfer(&taskMan->returnContext, &ctx);
- tlsManager->setTLS(oldTLS);
- if (status == RUNNING)
- status = IDLE;
+ m_status = RUNNING;
+ void * oldTLS = g_tlsManager->getTLS();
+ g_tlsManager->setTLS(m_tls);
+ coro_transfer(&m_taskMan->m_returnContext, &m_ctx);
+ g_tlsManager->setTLS(oldTLS);
+ if (m_status == RUNNING)
+ m_status = IDLE;
}
void Balau::Task::suspend() {
- coro_transfer(&ctx, &taskMan->returnContext);
+ coro_transfer(&m_ctx, &m_taskMan->m_returnContext);
+}
+
+Balau::Task * Balau::Task::getCurrentTask() {
+ return localTask.get();
+}
+
+void Balau::Task::waitFor(Balau::Events::BaseEvent * e) {
+ e->registerOwner(this);
+ // probably have to register the event in the Task manager
+}
+
+Balau::Events::TaskEvent::TaskEvent(Task * taskWaited) : m_taskWaited(taskWaited) {
+ m_taskWaited->m_waitedBy.push_back(this);
}
diff --git a/src/TaskMan.cc b/src/TaskMan.cc
index bbaf35c..6730f13 100644
--- a/src/TaskMan.cc
+++ b/src/TaskMan.cc
@@ -6,8 +6,8 @@
static Balau::DefaultTmpl<Balau::TaskMan> defaultTaskMan(50);
static Balau::LocalTmpl<Balau::TaskMan> localTaskMan;
-Balau::TaskMan::TaskMan() : stopped(false) {
- coro_create(&returnContext, 0, 0, 0, 0);
+Balau::TaskMan::TaskMan() : m_stopped(false) {
+ coro_create(&m_returnContext, 0, 0, 0, 0);
if (!localTaskMan.getGlobal())
localTaskMan.setGlobal(this);
}
@@ -21,21 +21,13 @@ Balau::TaskMan::~TaskMan() {
void Balau::TaskMan::mainLoop() {
// We need at least one round before bailing :)
do {
- taskList::iterator i;
+ taskList_t::iterator iL;
+ taskHash_t::iterator iH;
Task * t;
- // lock pending
- // Adding tasks that were added, maybe from other threads
- for (i = pendingAdd.begin(); i != pendingAdd.end(); i++) {
- Assert(tasks.find(*i) == tasks.end());
- tasks.insert(*i);
- }
- pendingAdd.clear();
- // unlock pending
-
// checking "STARTING" tasks, and running them once
- for (i = tasks.begin(); i != tasks.end(); i++) {
- t = *i;
+ for (iH = m_tasks.begin(); iH != m_tasks.end(); iH++) {
+ t = *iH;
if (t->getStatus() == Task::STARTING) {
t->switchTo();
}
@@ -43,25 +35,53 @@ void Balau::TaskMan::mainLoop() {
// That's probably where we poll for events
- // checking "STOPPED" tasks, and destroying them
+ // lock pending
+ // Adding tasks that were added, maybe from other threads
+ for (iL = m_pendingAdd.begin(); iL != m_pendingAdd.end(); iL++) {
+ t = *iL;
+ Assert(m_tasks.find(t) == m_tasks.end());
+ m_tasks.insert(t);
+ }
+ m_pendingAdd.clear();
+ // unlock pending
+
+ // Dealing with stopped and faulted tasks.
+ // First by signalling the waiters.
+ for (iH = m_tasks.begin(); iH != m_tasks.end(); iH++) {
+ t = *iH;
+ if (((t->getStatus() == Task::STOPPED) || (t->getStatus() == Task::FAULTED)) &&
+ (t->m_waitedBy.size() != 0)) {
+ Task::waitedByList_t::iterator i;
+ while ((i = t->m_waitedBy.begin()) != t->m_waitedBy.end()) {
+ Events::TaskEvent * e = *i;
+ e->doSignal();
+ e->taskWaiting()->switchTo();
+ t->m_waitedBy.erase(i);
+ }
+ }
+ }
+
+ // Then, by destroying them.
bool didDelete;
do {
didDelete = false;
- for (i = tasks.begin(); i != tasks.end(); i++) {
- t = *i;
- if ((t->getStatus() == Task::STOPPED) || (t->getStatus() == Task::FAULTED)) {
+ for (iH = m_tasks.begin(); iH != m_tasks.end(); iH++) {
+ t = *iH;
+ if (((t->getStatus() == Task::STOPPED) || (t->getStatus() == Task::FAULTED)) &&
+ (t->m_waitedBy.size() == 0)) {
delete t;
- tasks.erase(i);
+ m_tasks.erase(iH);
didDelete = true;
break;
}
}
} while (didDelete);
- } while (!stopped && tasks.size() != 0);
+
+ } while (!m_stopped && m_tasks.size() != 0);
}
void Balau::TaskMan::registerTask(Balau::Task * t) {
// lock pending
- pendingAdd.insert(t);
+ m_pendingAdd.push_back(t);
// unlock pending
}
diff --git a/tests/test-Sanity.cc b/tests/test-Sanity.cc
index 1495f76..52a85cc 100644
--- a/tests/test-Sanity.cc
+++ b/tests/test-Sanity.cc
@@ -4,11 +4,10 @@ BALAU_STARTUP;
using namespace Balau;
-int Application::startup() throw (Balau::GeneralException) {
+void MainTask::Do() {
Printer::log(M_STATUS, "Test::Sanity running.");
Assert(sizeof(off_t) == 8);
Printer::log(M_STATUS, "Test::Sanity passed.");
- return 0;
}
diff --git a/tests/test-String.cc b/tests/test-String.cc
index 36c87e4..8e6eacb 100644
--- a/tests/test-String.cc
+++ b/tests/test-String.cc
@@ -5,7 +5,7 @@ BALAU_STARTUP;
using namespace Balau;
-int Application::startup() throw (Balau::GeneralException) {
+void MainTask::Do() {
Printer::log(M_STATUS, "Test::String running.");
String x = "foobar";
@@ -46,5 +46,4 @@ int Application::startup() throw (Balau::GeneralException) {
Assert(((unsigned char) y[0]) == 0xe9);
Printer::log(M_STATUS, "Test::String passed.");
- return 0;
}
diff --git a/tests/test-Tasks.cc b/tests/test-Tasks.cc
index 2eedb47..dcc692c 100644
--- a/tests/test-Tasks.cc
+++ b/tests/test-Tasks.cc
@@ -11,23 +11,25 @@ class CustomPrinter : public Printer {
static CustomPrinter * customPrinter = NULL;
-class MainTask : public Task {
+class TestTask : public Task {
public:
virtual const char * getName() { return "MainTask"; }
private:
virtual void Do() {
+ Printer::log(M_STATUS, "xyz");
customPrinter->setLocal();
Printer::enable(M_ALL);
- Printer::log(M_DEBUG, "In MainTask::Do()");
+ Printer::log(M_DEBUG, "In TestTask::Do()");
}
};
-int Application::startup() throw (Balau::GeneralException) {
+void MainTask::Do() {
customPrinter = new CustomPrinter();
Printer::log(M_STATUS, "Test::Tasks running.");
- Task * mainTask = new MainTask();
- TaskMan::getTaskMan()->mainLoop();
+ Task * testTask = new TestTask();
+ Events::TaskEvent e(testTask);
+ waitFor(&e);
+ suspend();
Printer::log(M_STATUS, "Test::Tasks passed.");
Printer::log(M_DEBUG, "You shouldn't see that message.");
- return 0;
}