summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorNicolas "Pixel" Noble <pixel@nobis-crew.org>2013-12-21 18:32:27 -0800
committerNicolas "Pixel" Noble <pixel@nobis-crew.org>2013-12-21 18:32:27 -0800
commit9754372d5e4125bf5850d9cd3ae93d529efdef8d (patch)
treefc20e375256b95bbd13fecde0d85181100a198e4 /src
parent9697add8b75b96662c8d39477e58d5841c4b9cba (diff)
Preliminary WebSocket protocol support.
Diffstat (limited to 'src')
-rw-r--r--src/BWebSocket.cc166
-rw-r--r--src/Base64.cc121
-rw-r--r--src/HttpServer.cc29
-rw-r--r--src/SHA1.cc147
-rw-r--r--src/Selectable.cc7
-rw-r--r--src/Socket.cc1
6 files changed, 462 insertions, 9 deletions
diff --git a/src/BWebSocket.cc b/src/BWebSocket.cc
new file mode 100644
index 0000000..0f94728
--- /dev/null
+++ b/src/BWebSocket.cc
@@ -0,0 +1,166 @@
+#include "BWebSocket.h"
+#include "SHA1.h"
+#include "Base64.h"
+#include "TaskMan.h"
+
+#define rotate(value) (((value) << 8) | ((value) >> 24))
+
+void Balau::WebSocketWorker::Do() {
+ uint8_t c;
+
+ try {
+ while (!m_socket->isClosed()) {
+ switch (m_state) {
+ case READ_H:
+ c = m_socket->readU8().get();
+ m_fin = c & 0x80;
+ if ((c >> 4) & 7) goto error;
+ c &= 15;
+ if (!m_firstFragment && c) goto error;
+ if (m_firstFragment)
+ m_opcode = c;
+ m_state = READ_PLB;
+ case READ_PLB:
+ c = m_socket->readU8().get();
+ m_hasMask = c & 0x80;
+ if (m_enforceServer && !m_hasMask)
+ goto error;
+ if (m_enforceClient && m_hasMask)
+ goto error;
+ m_payloadLen = c & 0x7f;
+ m_state = READ_PLL;
+ if (m_payloadLen == 126) {
+ m_payloadLen = 0;
+ m_remainingBytes = 2;
+ }
+ else if (m_payloadLen == 127) {
+ m_payloadLen = 0;
+ m_remainingBytes = 8;
+ }
+ else {
+ m_remainingBytes = 0;
+ }
+ case READ_PLL:
+ while (m_remainingBytes) {
+ c = m_socket->readU8().get();
+ m_payloadLen <<= 8;
+ m_payloadLen += c;
+ m_remainingBytes--;
+ }
+ m_state = READ_MK;
+ if (m_firstFragment)
+ m_totalLen = m_payloadLen;
+ else
+ m_totalLen += m_payloadLen;
+ if (m_hasMask) m_remainingBytes = 4;
+ case READ_MK:
+ while (m_remainingBytes) {
+ c = m_socket->readU8().get();
+ m_mask <<= 8;
+ m_mask += c;
+ m_remainingBytes--;
+ }
+ m_state = READ_PL;
+ m_remainingBytes = m_payloadLen;
+ if (m_totalLen >= MAX_WEBSOCKET_LIMIT)
+ goto error;
+ m_payload = (uint8_t *)realloc(m_payload, m_totalLen);
+ case READ_PL:
+ while (m_remainingBytes) {
+ int r = m_socket->read(m_payload + m_totalLen - m_remainingBytes, m_remainingBytes);
+ if (r < 0)
+ goto error;
+ m_remainingBytes -= r;
+ }
+
+ m_firstFragment = m_fin;
+
+ if (m_fin) {
+ if (m_hasMask) {
+ for (int i = 0; i < m_totalLen; i++) {
+ m_payload[i] ^= m_mask >> 24;
+ m_mask = rotate(m_mask);
+ }
+ }
+ processMessage();
+ }
+ }
+ }
+
+ error:
+ m_socket->close();
+ }
+ catch (Balau::EAgain & e) {
+ taskSwitch();
+ }
+}
+
+void Balau::WebSocketWorker::processMessage() {
+
+}
+
+void Balau::WebSocketServerBase::sendError(IO<Handle> out, const char * serverName) {
+ const char * status = Http::getStatusMsg(400);
+ String errorMsg;
+ errorMsg.set(
+"HTTP/1.0 400 %s\r\n"
+"Content-Type: text/plain; charset=UTF-8\r\n"
+"Connection: close\r\n"
+"Server: %s\r\n"
+"\r\n"
+"400 - %s",
+ status, serverName, status);
+ out->writeString(errorMsg);
+}
+
+static const Balau::String magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+
+bool Balau::WebSocketServerBase::Do(HttpServer * server, Http::Request & req, HttpServer::Action::ActionMatch & match, IO<Handle> out) throw (GeneralException) {
+ WebSocketWorker * worker = NULL;
+
+ if (!req.upgrade)
+ goto error;
+
+ if (req.headers["Upgrade"] != "websocket")
+ goto error;
+
+ if (req.headers["Sec-WebSocket-Key"] == "")
+ goto error;
+
+ worker = spawnWorker(out, req.uri);
+ if (!worker->parse(req))
+ goto error;
+
+ TaskMan::registerTask(worker);
+ {
+ HttpServer::Response response(server, req, out);
+
+ String & key = req.headers["Sec-WebSocket-Key"];
+ uint8_t * toHash = (uint8_t *)alloca(key.strlen() + magic.strlen());
+ memcpy(toHash, key.to_charp(), key.strlen());
+ memcpy(toHash + key.strlen(), magic.to_charp(), magic.strlen());
+
+ SHA1 h;
+ uint8_t digest[20];
+ h.update(toHash, key.strlen() + magic.strlen());
+ h.final(digest);
+
+ String accept = Base64::encode(digest, 20);
+
+ response.SetResponseCode(101);
+ response.AddHeader("Upgrade: websocket");
+ response.AddHeader("Connection: Upgrade");
+ response.AddHeader("Sec-WebSocket-Accept", accept);
+ response.AddHeader("Sec-WebSocket-Version: 13");
+ response.SetContentType("");
+ response.Flush();
+ }
+
+ return false;
+
+error:
+ if (worker)
+ TaskMan::registerTask(worker);
+ sendError(out, server->getServerName().to_charp());
+ return false;
+}
diff --git a/src/Base64.cc b/src/Base64.cc
new file mode 100644
index 0000000..3b638c6
--- /dev/null
+++ b/src/Base64.cc
@@ -0,0 +1,121 @@
+#include <functional>
+#include "Base64.h"
+
+static char cb64[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+static char lookup[] = {
+// x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 xA xB xC xD xE xF
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 0x
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 1x
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63, // 2x
+ 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1, // 3x
+ -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, // 4x
+ 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1, // 5x
+ -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, // 6x
+ 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1, // 7x
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 8x
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 9x
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Ax
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Bx
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Cx
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Dx
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Ex
+ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Fx
+};
+
+const double Balau::Base64::ratio = 4 / 3;
+
+void Balau::Base64::encode_block(unsigned char in_tab[3], int len, char out[5]) {
+ out[0] = cb64[in_tab[0] >> 2];
+ out[1] = cb64[((in_tab[0] & 3) << 4) | ((in_tab[1] & 240) >> 4)];
+ out[2] = len > 1 ? cb64[((in_tab[1] & 15) << 2) | ((in_tab[2] & 192) >> 6)] : '=';
+ out[3] = len > 2 ? cb64[in_tab[2] & 63] : '=';
+ out[4] = 0;
+}
+
+Balau::String Balau::Base64::encode(const uint8_t * data, int stream_size) {
+ String encoded;
+ encoded.reserve(stream_size * ratio + 1);
+ unsigned char in_tab[3];
+ int len, i, s_pos;
+
+ s_pos = 0;
+
+ while (stream_size > 0) {
+ in_tab[0] = 0;
+ in_tab[1] = 0;
+ in_tab[2] = 0;
+
+ len = stream_size >= 3 ? 3 : stream_size;
+
+ for (i = 0; i < len; i++) {
+ in_tab[i] = data[s_pos + i];
+ }
+
+ char block[5];
+ encode_block(in_tab, len, block);
+
+ encoded += block;
+
+ s_pos += 3;
+ stream_size -= 3;
+ }
+
+ return encoded;
+}
+
+int Balau::Base64::stri(char x) {
+ return lookup[(unsigned char) x];
+}
+
+int Balau::Base64::decode_block(char s1, char s2, char s3, char s4, unsigned char * out_tab) {
+ int len, sb1, sb2, sb3, sb4;
+
+ len = s3 == '=' ? 1 : s4 == '=' ? 2 : 3;
+ s3 = (s3 == '=') || (s3 == 0) ? 'A' : s3;
+ s4 = (s4 == '=') || (s4 == 0) ? 'A' : s4;
+
+ sb1 = stri(s1);
+ sb2 = stri(s2);
+ sb3 = stri(s3);
+ sb4 = stri(s4);
+
+ out_tab[0] = (sb1 << 2) | (sb2 >> 4);
+ out_tab[1] = ((sb2 << 4) & 255) | (sb3 >> 2);
+ out_tab[2] = ((sb3 << 6) & 240) | sb4;
+
+ return len;
+}
+
+int Balau::Base64::decode(const String & str_in, uint8_t * data_out) {
+ int s_len = str_in.strlen(), len = 0, i, t_len, idx;
+ char s1, s2, s3, s4;
+ unsigned char t_out[3];
+ unsigned char * out = (unsigned char *) malloc(s_len * 3 / 4 + 4);
+ unsigned char * p = out;
+ std::function<char()> readNext = [&]() {
+ char r = '=';
+
+ if (idx >= s_len)
+ return r;
+
+ do {
+ r = str_in[idx++];
+ } while (r == '\r' || r == '\n' || r == ' ' || r == '\t');
+
+ return r;
+ };
+
+ for (idx = 0; idx < s_len;) {
+ s1 = readNext();
+ s2 = readNext();
+ s3 = readNext();
+ s4 = readNext();
+ t_len = decode_block(s1, s2, s3, s4, t_out);
+
+ for (i = 0; i < t_len; i++) *(p++) = t_out[i];
+
+ len += t_len;
+ }
+
+ return len;
+}
diff --git a/src/HttpServer.cc b/src/HttpServer.cc
index 4a3e086..8eea5bb 100644
--- a/src/HttpServer.cc
+++ b/src/HttpServer.cc
@@ -18,6 +18,7 @@ class OutputCheck : public Balau::Handle {
virtual bool isClosed() { return m_h->isClosed(); }
virtual bool isEOF() { return m_h->isEOF(); }
virtual bool canWrite() { return true; }
+ virtual bool canRead() { return m_h->canRead(); }
virtual const char * getName() { return m_name.to_charp(); }
virtual ssize_t write(const void * buf, size_t count) throw (Balau::GeneralException) {
if (!count)
@@ -25,6 +26,9 @@ class OutputCheck : public Balau::Handle {
m_wrote = true;
return m_h->write(buf, count);
}
+ virtual ssize_t read(void * buf, size_t count) throw (Balau::GeneralException) {
+ return m_h->read(buf, count);
+ }
bool wrote() { return m_wrote; }
private:
Balau::IO<Balau::Handle> m_h;
@@ -131,7 +135,7 @@ const Balau::String SetDefaultTemplateTask::m_defaultErrorTemplate(
};
-Balau::HttpWorker::HttpWorker(IO<Handle> io, void * _server) : m_socket(new WriteOnly(io)), m_strm(new BStream(io)) {
+Balau::HttpWorker::HttpWorker(IO<Handle> io, void * _server) : m_socket(io), m_strm(new BStream(io)) {
m_server = (HttpServer *) _server;
m_name.set("HttpWorker(%s)", m_socket->getName());
// get stuff from server, such as port number, root document, base URL, default 400/404 actions, etc...
@@ -258,6 +262,7 @@ bool Balau::HttpWorker::handleClient() {
Http::StringMap variables;
Http::FileList files;
bool persistent = false;
+ bool upgrade = false;
// read client's request
do {
@@ -439,6 +444,9 @@ bool Balau::HttpWorker::handleClient() {
persistent = true;
} else if (t == "TE") {
Printer::elog(E_HTTPSERVER, "%s got the 'TE' connection marker (which is still unknown)", m_name.to_charp());
+ } else if (t == "Upgrade") {
+ upgrade = true;
+ persistent = true;
} else {
Printer::elog(E_HTTPSERVER, "%s has an improper Connection HTTP header (%s)", m_name.to_charp(), t.to_charp());
send400();
@@ -551,6 +559,7 @@ bool Balau::HttpWorker::handleClient() {
auto f = m_server->findAction(uri.to_charp(), host.to_charp());
if (f.action) {
+ m_strm->detach();
IO<OutputCheck> out(new OutputCheck(m_socket));
Http::Request req;
req.method = method;
@@ -560,6 +569,7 @@ bool Balau::HttpWorker::handleClient() {
req.headers = httpHeaders;
req.files = files;
req.persistent = persistent;
+ req.upgrade = upgrade;
req.version = httpVersion;
try {
if (!f.action->Do(m_server, req, f.matches, out))
@@ -684,17 +694,20 @@ void Balau::HttpServer::Response::Flush() {
headers->writeString(response);
headers->writeString(" ");
headers->writeString(Http::getStatusMsg(m_responseCode));
- headers->writeString("\r\nContent-Type: ");
- headers->writeString(m_type);
- headers->writeString("\r\nContent-Length: ");
- String len(m_buffer->getSize());
- headers->writeString(len);
+ if (m_type != "") {
+ headers->writeString("\r\nContent-Type: ");
+ headers->writeString(m_type);
+ }
+ if (!m_noSize) {
+ headers->writeString("\r\nContent-Length: ");
+ String len(m_buffer->getSize());
+ headers->writeString(len);
+ }
headers->writeString("\r\nServer: ");
headers->writeString(m_server->getServerName());
headers->writeString("\r\n");
- if ((m_req.version == "1.1") && !m_req.persistent) {
+ if ((m_req.version == "1.1") && !m_req.persistent)
headers->writeString("Connection: close\r\n");
- }
while (!m_extraHeaders.empty()) {
String s = m_extraHeaders.front();
diff --git a/src/SHA1.cc b/src/SHA1.cc
new file mode 100644
index 0000000..21bf018
--- /dev/null
+++ b/src/SHA1.cc
@@ -0,0 +1,147 @@
+// based of public domain implementation found here: http://svn.ghostscript.com/jbig2dec/trunk/sha1.c
+
+/*
+Test Vectors (from FIPS PUB 180-1)
+"abc"
+ A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D
+"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"
+ 84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1
+A million repetitions of "a"
+ 34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F
+*/
+
+#include "SHA1.h"
+
+#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits))))
+
+/* blk0() and blk() perform the initial expand. */
+/* I got the idea of expanding during the round function from SSLeay */
+/* FIXME: can we do this in an endian-proof way? */
+#ifdef WORDS_BIGENDIAN
+#define blk0(i) block->l[i]
+#else
+#define blk0(i) (block->l[i] = (rol(block->l[i],24)&0xFF00FF00) \
+ |(rol(block->l[i],8)&0x00FF00FF))
+#endif
+#define blk(i) (block->l[i&15] = rol(block->l[(i+13)&15]^block->l[(i+8)&15] \
+ ^block->l[(i+2)&15]^block->l[i&15],1))
+
+/* (R0+R1), R2, R3, R4 are the different operations used in SHA1 */
+#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30);
+#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30);
+#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30);
+#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30);
+#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30);
+
+/* Hash a single 512-bit block. This is the core of the algorithm. */
+void Balau::SHA1::transform(uint32_t state[5], const uint8_t buffer[64])
+{
+ uint32_t a, b, c, d, e;
+ typedef union {
+ uint8_t c[64];
+ uint32_t l[16];
+ } CHAR64LONG16;
+ CHAR64LONG16* block;
+
+ block = (CHAR64LONG16*)buffer;
+
+ /* Copy m_state[] to working vars */
+ a = state[0];
+ b = state[1];
+ c = state[2];
+ d = state[3];
+ e = state[4];
+
+ /* 4 rounds of 20 operations each. Loop unrolled. */
+ R0(a,b,c,d,e, 0); R0(e,a,b,c,d, 1); R0(d,e,a,b,c, 2); R0(c,d,e,a,b, 3);
+ R0(b,c,d,e,a, 4); R0(a,b,c,d,e, 5); R0(e,a,b,c,d, 6); R0(d,e,a,b,c, 7);
+ R0(c,d,e,a,b, 8); R0(b,c,d,e,a, 9); R0(a,b,c,d,e,10); R0(e,a,b,c,d,11);
+ R0(d,e,a,b,c,12); R0(c,d,e,a,b,13); R0(b,c,d,e,a,14); R0(a,b,c,d,e,15);
+ R1(e,a,b,c,d,16); R1(d,e,a,b,c,17); R1(c,d,e,a,b,18); R1(b,c,d,e,a,19);
+ R2(a,b,c,d,e,20); R2(e,a,b,c,d,21); R2(d,e,a,b,c,22); R2(c,d,e,a,b,23);
+ R2(b,c,d,e,a,24); R2(a,b,c,d,e,25); R2(e,a,b,c,d,26); R2(d,e,a,b,c,27);
+ R2(c,d,e,a,b,28); R2(b,c,d,e,a,29); R2(a,b,c,d,e,30); R2(e,a,b,c,d,31);
+ R2(d,e,a,b,c,32); R2(c,d,e,a,b,33); R2(b,c,d,e,a,34); R2(a,b,c,d,e,35);
+ R2(e,a,b,c,d,36); R2(d,e,a,b,c,37); R2(c,d,e,a,b,38); R2(b,c,d,e,a,39);
+ R3(a,b,c,d,e,40); R3(e,a,b,c,d,41); R3(d,e,a,b,c,42); R3(c,d,e,a,b,43);
+ R3(b,c,d,e,a,44); R3(a,b,c,d,e,45); R3(e,a,b,c,d,46); R3(d,e,a,b,c,47);
+ R3(c,d,e,a,b,48); R3(b,c,d,e,a,49); R3(a,b,c,d,e,50); R3(e,a,b,c,d,51);
+ R3(d,e,a,b,c,52); R3(c,d,e,a,b,53); R3(b,c,d,e,a,54); R3(a,b,c,d,e,55);
+ R3(e,a,b,c,d,56); R3(d,e,a,b,c,57); R3(c,d,e,a,b,58); R3(b,c,d,e,a,59);
+ R4(a,b,c,d,e,60); R4(e,a,b,c,d,61); R4(d,e,a,b,c,62); R4(c,d,e,a,b,63);
+ R4(b,c,d,e,a,64); R4(a,b,c,d,e,65); R4(e,a,b,c,d,66); R4(d,e,a,b,c,67);
+ R4(c,d,e,a,b,68); R4(b,c,d,e,a,69); R4(a,b,c,d,e,70); R4(e,a,b,c,d,71);
+ R4(d,e,a,b,c,72); R4(c,d,e,a,b,73); R4(b,c,d,e,a,74); R4(a,b,c,d,e,75);
+ R4(e,a,b,c,d,76); R4(d,e,a,b,c,77); R4(c,d,e,a,b,78); R4(b,c,d,e,a,79);
+
+ /* Add the working vars back into context.state[] */
+ state[0] += a;
+ state[1] += b;
+ state[2] += c;
+ state[3] += d;
+ state[4] += e;
+
+ /* Wipe variables */
+ a = b = c = d = e = 0;
+}
+
+/* SHA1Init - Initialize new context */
+void Balau::SHA1::reset()
+{
+ /* SHA1 initialization constants */
+ m_state[0] = 0x67452301;
+ m_state[1] = 0xEFCDAB89;
+ m_state[2] = 0x98BADCFE;
+ m_state[3] = 0x10325476;
+ m_state[4] = 0xC3D2E1F0;
+ m_count[0] = m_count[1] = 0;
+}
+
+/* Run your data through this. */
+void Balau::SHA1::update(const uint8_t* data, const size_t len)
+{
+ size_t i, j;
+
+ j = (m_count[0] >> 3) & 63;
+ if ((m_count[0] += len << 3) < (len << 3)) m_count[1]++;
+ m_count[1] += (len >> 29);
+ if ((j + len) > 63) {
+ memcpy(&m_buffer[j], data, (i = 64-j));
+ transform(m_state, m_buffer);
+ for ( ; i + 63 < len; i += 64) {
+ transform(m_state, data + i);
+ }
+ j = 0;
+ }
+ else i = 0;
+ memcpy(&m_buffer[j], &data[i], len - i);
+
+}
+
+/* Add padding and return the message digest. */
+void Balau::SHA1::final(uint8_t * digest)
+{
+ uint32_t i;
+ uint8_t finalcount[8];
+
+ for (i = 0; i < 8; i++) {
+ finalcount[i] = (unsigned char)((m_count[(i >= 4 ? 0 : 1)]
+ >> ((3-(i & 3)) * 8) ) & 255); /* Endian independent */
+ }
+ update((uint8_t *)"\200", 1);
+ while ((m_count[0] & 504) != 448) {
+ update((uint8_t *)"\0", 1);
+ }
+ update(finalcount, 8); /* Should cause a SHA1_Transform() */
+ for (i = 0; i < DIGEST_SIZE; i++) {
+ digest[i] = (uint8_t)
+ ((m_state[i>>2] >> ((3-(i & 3)) * 8) ) & 255);
+ }
+
+ /* Wipe variables */
+ i = 0;
+ memset(m_buffer, 0, 64);
+ memset(m_state, 0, 20);
+ memset(m_count, 0, 8);
+ memset(finalcount, 0, 8); /* SWR */
+}
diff --git a/src/Selectable.cc b/src/Selectable.cc
index 798a448..4213d34 100644
--- a/src/Selectable.cc
+++ b/src/Selectable.cc
@@ -52,6 +52,8 @@ void Balau::Selectable::SelectableEvent::gotOwner(Task * task) {
} else {
Printer::elog(E_SELECT, "...with a new task (%p -> %p); stopping first", m_task, task);
m_evt.stop();
+ m_evt.set<SelectableEvent, &SelectableEvent::evt_cb>(this);
+ m_evt.set(m_fd, m_evtType);
}
m_task = task;
m_evt.set(task->getLoop());
@@ -96,6 +98,7 @@ ssize_t Balau::Selectable::read(void * buf, size_t count) throw (GeneralExceptio
ssize_t r = recv(getSocket(m_fd), (char *) buf, count, 0);
if (r >= 0) {
+ m_evtR->resetMaybe();
if (r == 0)
close();
return r;
@@ -137,8 +140,10 @@ ssize_t Balau::Selectable::write(const void * buf, size_t count) throw (GeneralE
EAssert(r != 0, "send() returned 0 (broken pipe ?)");
- if (r > 0)
+ if (r > 0) {
+ m_evtW->resetMaybe();
return r;
+ }
#ifndef _WIN32
int err = errno;
diff --git a/src/Socket.cc b/src/Socket.cc
index 7673c45..66dae18 100644
--- a/src/Socket.cc
+++ b/src/Socket.cc
@@ -522,6 +522,7 @@ Balau::IO<Balau::Socket> Balau::Socket::accept() throw (GeneralException) {
Task::operationYield(m_evtR, Task::INTERRUPTIBLE);
} else {
String msg = getErrorMessage();
+ m_evtR->stop();
throw GeneralException(String("Unexpected error accepting a connection: #") + errno + "(" + msg + ")");
}
} else {