diff options
-rw-r--r-- | includes/BString.h | 2 | ||||
-rw-r--r-- | includes/BWebSocket.h | 62 | ||||
-rw-r--r-- | includes/Base64.h | 19 | ||||
-rw-r--r-- | includes/Http.h | 1 | ||||
-rw-r--r-- | includes/HttpServer.h | 2 | ||||
-rw-r--r-- | includes/SHA1.h | 24 | ||||
-rw-r--r-- | includes/Selectable.h | 5 | ||||
-rw-r--r-- | includes/Task.h | 7 | ||||
-rw-r--r-- | src/BWebSocket.cc | 166 | ||||
-rw-r--r-- | src/Base64.cc | 121 | ||||
-rw-r--r-- | src/HttpServer.cc | 29 | ||||
-rw-r--r-- | src/SHA1.cc | 147 | ||||
-rw-r--r-- | src/Selectable.cc | 7 | ||||
-rw-r--r-- | src/Socket.cc | 1 |
14 files changed, 582 insertions, 11 deletions
diff --git a/includes/BString.h b/includes/BString.h index a4178a4..e783935 100644 --- a/includes/BString.h +++ b/includes/BString.h @@ -117,6 +117,8 @@ class String : private std::string { const char & operator[](ssize_t i) const { if (i < 0) i = strlen() + i; return at(i); } char & operator[](ssize_t i) { if (i < 0) i = strlen() + i; return at(i); } + + void reserve(size_t s) { std::string::reserve(s); } }; }; diff --git a/includes/BWebSocket.h b/includes/BWebSocket.h new file mode 100644 index 0000000..41e5a7d --- /dev/null +++ b/includes/BWebSocket.h @@ -0,0 +1,62 @@ +#pragma once + +#include <Task.h> +#include <StacklessTask.h> +#include <BStream.h> +#include <HttpServer.h> + +namespace Balau { + +class WebSocketActionBase; + +class WebSocketWorker : public StacklessTask { + public: + virtual bool parse(Http::Request & req) { return true; } + protected: + WebSocketWorker(IO<Handle> socket, const String & url) : m_socket(new BStream(socket)) { m_name = String("WebSocket:") + url + "/" + m_socket->getName(); } + ~WebSocketWorker() { free(m_payload); } + private: + void processMessage(); + const char * getName() const { return m_name.to_charp(); } + void Do(); + String m_name; + IO<BStream> m_socket; + enum { + READ_H, + READ_PLB, + READ_PLL, + READ_MK, + READ_PL, + } m_status = READ_H; + enum { MAX_WEBSOCKET_LIMIT = 4 * 1024 * 1024 }; + uint8_t * m_payload = NULL; + uint64_t m_payloadLen; + uint64_t m_totalLen; + uint64_t m_remainingBytes; + uint32_t m_mask; + uint8_t m_opcode; + bool m_hasMask; + bool m_fin; + bool m_firstFragment = true; + bool m_enforceServer = false; + bool m_enforceClient = false; + friend class WebSocketActionBase; +}; + +class WebSocketServerBase : public HttpServer::Action { + protected: + WebSocketServerBase(const Regex & regex) : Action(regex) { } + virtual WebSocketWorker * spawnWorker(IO<Handle> socket, const String & url) = 0; + private: + void sendError(IO<Handle> out, const char * serverName); + bool Do(HttpServer * server, Http::Request & req, HttpServer::Action::ActionMatch & match, IO<Handle> out) throw (GeneralException); +}; + +template<class T> +class WebSocketServer : public WebSocketServerBase { + protected: + WebSocketServer(const Regex & regex) : WebSocketServerBase(regex) { } + virtual WebSocketWorker * spawnWorker(IO<Handle> socket, const String & url) { return new T(socket, url); } +}; + +}; diff --git a/includes/Base64.h b/includes/Base64.h new file mode 100644 index 0000000..d039c73 --- /dev/null +++ b/includes/Base64.h @@ -0,0 +1,19 @@ +#pragma once + +#include <Exceptions.h> + +namespace Balau { + +class Base64 { +public: + static String encode(const uint8_t * data, int len); + static int decode(const String & str_in, uint8_t * data_out); + static const double ratio; + +private: + static void encode_block(unsigned char in_tab[3], int len, char out[4]); + static int stri(char); + static int decode_block(char s1, char s2, char s3, char s4, unsigned char * out_tab); +}; + +}; diff --git a/includes/Http.h b/includes/Http.h index 7c351a6..fae2ad9 100644 --- a/includes/Http.h +++ b/includes/Http.h @@ -27,6 +27,7 @@ struct Request { StringMap headers; FileList files; bool persistent; + bool upgrade; String version; }; diff --git a/includes/HttpServer.h b/includes/HttpServer.h index d044e92..683390c 100644 --- a/includes/HttpServer.h +++ b/includes/HttpServer.h @@ -24,6 +24,7 @@ class HttpServer { Response(HttpServer * server, Http::Request req, IO<Handle> out) : m_server(server), m_req(req), m_out(out), m_buffer(new Buffer()), m_responseCode(200), m_type("text/html; charset=UTF-8"), m_flushed(false) { } void SetResponseCode(int code) { m_responseCode = code; } void SetContentType(const String & type) { m_type = type; } + void setNoSize() { m_noSize = true; } IO<Buffer> get() { return m_buffer; } IO<Buffer> operator->() { return m_buffer; } void Flush(); @@ -39,6 +40,7 @@ class HttpServer { String m_type; std::list<String> m_extraHeaders; bool m_flushed; + bool m_noSize = false; Response(const Response &) = delete; Response & operator=(const Response &) = delete; diff --git a/includes/SHA1.h b/includes/SHA1.h new file mode 100644 index 0000000..f70f6b7 --- /dev/null +++ b/includes/SHA1.h @@ -0,0 +1,24 @@ +#pragma once + +#include <Exceptions.h> + +namespace Balau { + +class SHA1 { + public: + SHA1() { reset(); } + void reset(); + void update(const uint8_t* data, const size_t len); + void final(uint8_t * digest); + + enum { DIGEST_SIZE = 20 }; + + private: + void transform(uint32_t state[5], const uint8_t buffer[64]); + + uint32_t m_state[5]; + uint32_t m_count[2]; + uint8_t m_buffer[64]; +}; + +}; diff --git a/includes/Selectable.h b/includes/Selectable.h index b3b5b8c..8667415 100644 --- a/includes/Selectable.h +++ b/includes/Selectable.h @@ -24,14 +24,17 @@ class Selectable : public Handle { class SelectableEvent : public Events::BaseEvent { public: - SelectableEvent(int fd, int evt = ev::READ | ev::WRITE) : m_task(NULL) { Printer::elog(E_SELECT, "Got a new SelectableEvent at %p", this); m_evt.set<SelectableEvent, &SelectableEvent::evt_cb>(this); m_evt.set(fd, evt); } + SelectableEvent(int fd, int evt = ev::READ | ev::WRITE) : m_task(NULL), m_evtType(evt), m_fd(fd) { Printer::elog(E_SELECT, "Got a new SelectableEvent at %p", this); m_evt.set<SelectableEvent, &SelectableEvent::evt_cb>(this); m_evt.set(fd, evt); } virtual ~SelectableEvent() { Printer::elog(E_SELECT, "Destroying a SelectableEvent at %p", this); m_evt.stop(); } void stop() { Printer::elog(E_SELECT, "Stopping a SelectableEvent at %p", this); reset(); m_evt.stop(); } private: void evt_cb(ev::io & w, int revents) { Printer::elog(E_SELECT, "Got a libev callback on a SelectableEvent at %p", this); doSignal(); } virtual void gotOwner(Task * task); + virtual bool relaxed() { return true; } ev::io m_evt; + int m_evtType; + int m_fd; Task * m_task = NULL; }; diff --git a/includes/Task.h b/includes/Task.h index 1041d08..7521b9f 100644 --- a/includes/Task.h +++ b/includes/Task.h @@ -50,6 +50,10 @@ class BaseEvent { virtual ~BaseEvent() { if (m_cb) delete m_cb; } bool gotSignal() { return m_signal; } void doSignal(); + void resetMaybe() { + if (m_task) + reset(); + } void reset() { // could be potentially changed into a simple return AAssert(m_task != NULL, "Can't reset an event that doesn't have a task"); @@ -60,12 +64,13 @@ class BaseEvent { void registerOwner(Task * task) { if (m_task == task) return; - AAssert(m_task == NULL, "Can't register an event for another task"); + AAssert(m_task == NULL || relaxed(), "Can't register an event for another task"); m_task = task; gotOwner(task); } protected: virtual void gotOwner(Task * task) { } + virtual bool relaxed() { return false; } private: Callback * m_cb = NULL; bool m_signal = false; diff --git a/src/BWebSocket.cc b/src/BWebSocket.cc new file mode 100644 index 0000000..0f94728 --- /dev/null +++ b/src/BWebSocket.cc @@ -0,0 +1,166 @@ +#include "BWebSocket.h" +#include "SHA1.h" +#include "Base64.h" +#include "TaskMan.h" + +#define rotate(value) (((value) << 8) | ((value) >> 24)) + +void Balau::WebSocketWorker::Do() { + uint8_t c; + + try { + while (!m_socket->isClosed()) { + switch (m_state) { + case READ_H: + c = m_socket->readU8().get(); + m_fin = c & 0x80; + if ((c >> 4) & 7) goto error; + c &= 15; + if (!m_firstFragment && c) goto error; + if (m_firstFragment) + m_opcode = c; + m_state = READ_PLB; + case READ_PLB: + c = m_socket->readU8().get(); + m_hasMask = c & 0x80; + if (m_enforceServer && !m_hasMask) + goto error; + if (m_enforceClient && m_hasMask) + goto error; + m_payloadLen = c & 0x7f; + m_state = READ_PLL; + if (m_payloadLen == 126) { + m_payloadLen = 0; + m_remainingBytes = 2; + } + else if (m_payloadLen == 127) { + m_payloadLen = 0; + m_remainingBytes = 8; + } + else { + m_remainingBytes = 0; + } + case READ_PLL: + while (m_remainingBytes) { + c = m_socket->readU8().get(); + m_payloadLen <<= 8; + m_payloadLen += c; + m_remainingBytes--; + } + m_state = READ_MK; + if (m_firstFragment) + m_totalLen = m_payloadLen; + else + m_totalLen += m_payloadLen; + if (m_hasMask) m_remainingBytes = 4; + case READ_MK: + while (m_remainingBytes) { + c = m_socket->readU8().get(); + m_mask <<= 8; + m_mask += c; + m_remainingBytes--; + } + m_state = READ_PL; + m_remainingBytes = m_payloadLen; + if (m_totalLen >= MAX_WEBSOCKET_LIMIT) + goto error; + m_payload = (uint8_t *)realloc(m_payload, m_totalLen); + case READ_PL: + while (m_remainingBytes) { + int r = m_socket->read(m_payload + m_totalLen - m_remainingBytes, m_remainingBytes); + if (r < 0) + goto error; + m_remainingBytes -= r; + } + + m_firstFragment = m_fin; + + if (m_fin) { + if (m_hasMask) { + for (int i = 0; i < m_totalLen; i++) { + m_payload[i] ^= m_mask >> 24; + m_mask = rotate(m_mask); + } + } + processMessage(); + } + } + } + + error: + m_socket->close(); + } + catch (Balau::EAgain & e) { + taskSwitch(); + } +} + +void Balau::WebSocketWorker::processMessage() { + +} + +void Balau::WebSocketServerBase::sendError(IO<Handle> out, const char * serverName) { + const char * status = Http::getStatusMsg(400); + String errorMsg; + errorMsg.set( +"HTTP/1.0 400 %s\r\n" +"Content-Type: text/plain; charset=UTF-8\r\n" +"Connection: close\r\n" +"Server: %s\r\n" +"\r\n" +"400 - %s", + status, serverName, status); + out->writeString(errorMsg); +} + +static const Balau::String magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + +bool Balau::WebSocketServerBase::Do(HttpServer * server, Http::Request & req, HttpServer::Action::ActionMatch & match, IO<Handle> out) throw (GeneralException) { + WebSocketWorker * worker = NULL; + + if (!req.upgrade) + goto error; + + if (req.headers["Upgrade"] != "websocket") + goto error; + + if (req.headers["Sec-WebSocket-Key"] == "") + goto error; + + worker = spawnWorker(out, req.uri); + if (!worker->parse(req)) + goto error; + + TaskMan::registerTask(worker); + { + HttpServer::Response response(server, req, out); + + String & key = req.headers["Sec-WebSocket-Key"]; + uint8_t * toHash = (uint8_t *)alloca(key.strlen() + magic.strlen()); + memcpy(toHash, key.to_charp(), key.strlen()); + memcpy(toHash + key.strlen(), magic.to_charp(), magic.strlen()); + + SHA1 h; + uint8_t digest[20]; + h.update(toHash, key.strlen() + magic.strlen()); + h.final(digest); + + String accept = Base64::encode(digest, 20); + + response.SetResponseCode(101); + response.AddHeader("Upgrade: websocket"); + response.AddHeader("Connection: Upgrade"); + response.AddHeader("Sec-WebSocket-Accept", accept); + response.AddHeader("Sec-WebSocket-Version: 13"); + response.SetContentType(""); + response.Flush(); + } + + return false; + +error: + if (worker) + TaskMan::registerTask(worker); + sendError(out, server->getServerName().to_charp()); + return false; +} diff --git a/src/Base64.cc b/src/Base64.cc new file mode 100644 index 0000000..3b638c6 --- /dev/null +++ b/src/Base64.cc @@ -0,0 +1,121 @@ +#include <functional> +#include "Base64.h" + +static char cb64[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +static char lookup[] = { +// x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 xA xB xC xD xE xF + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 0x + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 1x + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63, // 2x + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1, // 3x + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, // 4x + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1, // 5x + -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, // 6x + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1, // 7x + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 8x + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 9x + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Ax + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Bx + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Cx + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Dx + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Ex + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Fx +}; + +const double Balau::Base64::ratio = 4 / 3; + +void Balau::Base64::encode_block(unsigned char in_tab[3], int len, char out[5]) { + out[0] = cb64[in_tab[0] >> 2]; + out[1] = cb64[((in_tab[0] & 3) << 4) | ((in_tab[1] & 240) >> 4)]; + out[2] = len > 1 ? cb64[((in_tab[1] & 15) << 2) | ((in_tab[2] & 192) >> 6)] : '='; + out[3] = len > 2 ? cb64[in_tab[2] & 63] : '='; + out[4] = 0; +} + +Balau::String Balau::Base64::encode(const uint8_t * data, int stream_size) { + String encoded; + encoded.reserve(stream_size * ratio + 1); + unsigned char in_tab[3]; + int len, i, s_pos; + + s_pos = 0; + + while (stream_size > 0) { + in_tab[0] = 0; + in_tab[1] = 0; + in_tab[2] = 0; + + len = stream_size >= 3 ? 3 : stream_size; + + for (i = 0; i < len; i++) { + in_tab[i] = data[s_pos + i]; + } + + char block[5]; + encode_block(in_tab, len, block); + + encoded += block; + + s_pos += 3; + stream_size -= 3; + } + + return encoded; +} + +int Balau::Base64::stri(char x) { + return lookup[(unsigned char) x]; +} + +int Balau::Base64::decode_block(char s1, char s2, char s3, char s4, unsigned char * out_tab) { + int len, sb1, sb2, sb3, sb4; + + len = s3 == '=' ? 1 : s4 == '=' ? 2 : 3; + s3 = (s3 == '=') || (s3 == 0) ? 'A' : s3; + s4 = (s4 == '=') || (s4 == 0) ? 'A' : s4; + + sb1 = stri(s1); + sb2 = stri(s2); + sb3 = stri(s3); + sb4 = stri(s4); + + out_tab[0] = (sb1 << 2) | (sb2 >> 4); + out_tab[1] = ((sb2 << 4) & 255) | (sb3 >> 2); + out_tab[2] = ((sb3 << 6) & 240) | sb4; + + return len; +} + +int Balau::Base64::decode(const String & str_in, uint8_t * data_out) { + int s_len = str_in.strlen(), len = 0, i, t_len, idx; + char s1, s2, s3, s4; + unsigned char t_out[3]; + unsigned char * out = (unsigned char *) malloc(s_len * 3 / 4 + 4); + unsigned char * p = out; + std::function<char()> readNext = [&]() { + char r = '='; + + if (idx >= s_len) + return r; + + do { + r = str_in[idx++]; + } while (r == '\r' || r == '\n' || r == ' ' || r == '\t'); + + return r; + }; + + for (idx = 0; idx < s_len;) { + s1 = readNext(); + s2 = readNext(); + s3 = readNext(); + s4 = readNext(); + t_len = decode_block(s1, s2, s3, s4, t_out); + + for (i = 0; i < t_len; i++) *(p++) = t_out[i]; + + len += t_len; + } + + return len; +} diff --git a/src/HttpServer.cc b/src/HttpServer.cc index 4a3e086..8eea5bb 100644 --- a/src/HttpServer.cc +++ b/src/HttpServer.cc @@ -18,6 +18,7 @@ class OutputCheck : public Balau::Handle { virtual bool isClosed() { return m_h->isClosed(); } virtual bool isEOF() { return m_h->isEOF(); } virtual bool canWrite() { return true; } + virtual bool canRead() { return m_h->canRead(); } virtual const char * getName() { return m_name.to_charp(); } virtual ssize_t write(const void * buf, size_t count) throw (Balau::GeneralException) { if (!count) @@ -25,6 +26,9 @@ class OutputCheck : public Balau::Handle { m_wrote = true; return m_h->write(buf, count); } + virtual ssize_t read(void * buf, size_t count) throw (Balau::GeneralException) { + return m_h->read(buf, count); + } bool wrote() { return m_wrote; } private: Balau::IO<Balau::Handle> m_h; @@ -131,7 +135,7 @@ const Balau::String SetDefaultTemplateTask::m_defaultErrorTemplate( }; -Balau::HttpWorker::HttpWorker(IO<Handle> io, void * _server) : m_socket(new WriteOnly(io)), m_strm(new BStream(io)) { +Balau::HttpWorker::HttpWorker(IO<Handle> io, void * _server) : m_socket(io), m_strm(new BStream(io)) { m_server = (HttpServer *) _server; m_name.set("HttpWorker(%s)", m_socket->getName()); // get stuff from server, such as port number, root document, base URL, default 400/404 actions, etc... @@ -258,6 +262,7 @@ bool Balau::HttpWorker::handleClient() { Http::StringMap variables; Http::FileList files; bool persistent = false; + bool upgrade = false; // read client's request do { @@ -439,6 +444,9 @@ bool Balau::HttpWorker::handleClient() { persistent = true; } else if (t == "TE") { Printer::elog(E_HTTPSERVER, "%s got the 'TE' connection marker (which is still unknown)", m_name.to_charp()); + } else if (t == "Upgrade") { + upgrade = true; + persistent = true; } else { Printer::elog(E_HTTPSERVER, "%s has an improper Connection HTTP header (%s)", m_name.to_charp(), t.to_charp()); send400(); @@ -551,6 +559,7 @@ bool Balau::HttpWorker::handleClient() { auto f = m_server->findAction(uri.to_charp(), host.to_charp()); if (f.action) { + m_strm->detach(); IO<OutputCheck> out(new OutputCheck(m_socket)); Http::Request req; req.method = method; @@ -560,6 +569,7 @@ bool Balau::HttpWorker::handleClient() { req.headers = httpHeaders; req.files = files; req.persistent = persistent; + req.upgrade = upgrade; req.version = httpVersion; try { if (!f.action->Do(m_server, req, f.matches, out)) @@ -684,17 +694,20 @@ void Balau::HttpServer::Response::Flush() { headers->writeString(response); headers->writeString(" "); headers->writeString(Http::getStatusMsg(m_responseCode)); - headers->writeString("\r\nContent-Type: "); - headers->writeString(m_type); - headers->writeString("\r\nContent-Length: "); - String len(m_buffer->getSize()); - headers->writeString(len); + if (m_type != "") { + headers->writeString("\r\nContent-Type: "); + headers->writeString(m_type); + } + if (!m_noSize) { + headers->writeString("\r\nContent-Length: "); + String len(m_buffer->getSize()); + headers->writeString(len); + } headers->writeString("\r\nServer: "); headers->writeString(m_server->getServerName()); headers->writeString("\r\n"); - if ((m_req.version == "1.1") && !m_req.persistent) { + if ((m_req.version == "1.1") && !m_req.persistent) headers->writeString("Connection: close\r\n"); - } while (!m_extraHeaders.empty()) { String s = m_extraHeaders.front(); diff --git a/src/SHA1.cc b/src/SHA1.cc new file mode 100644 index 0000000..21bf018 --- /dev/null +++ b/src/SHA1.cc @@ -0,0 +1,147 @@ +// based of public domain implementation found here: http://svn.ghostscript.com/jbig2dec/trunk/sha1.c + +/* +Test Vectors (from FIPS PUB 180-1) +"abc" + A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D +"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq" + 84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1 +A million repetitions of "a" + 34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F +*/ + +#include "SHA1.h" + +#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits)))) + +/* blk0() and blk() perform the initial expand. */ +/* I got the idea of expanding during the round function from SSLeay */ +/* FIXME: can we do this in an endian-proof way? */ +#ifdef WORDS_BIGENDIAN +#define blk0(i) block->l[i] +#else +#define blk0(i) (block->l[i] = (rol(block->l[i],24)&0xFF00FF00) \ + |(rol(block->l[i],8)&0x00FF00FF)) +#endif +#define blk(i) (block->l[i&15] = rol(block->l[(i+13)&15]^block->l[(i+8)&15] \ + ^block->l[(i+2)&15]^block->l[i&15],1)) + +/* (R0+R1), R2, R3, R4 are the different operations used in SHA1 */ +#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30); +#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30); +#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30); +#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30); +#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30); + +/* Hash a single 512-bit block. This is the core of the algorithm. */ +void Balau::SHA1::transform(uint32_t state[5], const uint8_t buffer[64]) +{ + uint32_t a, b, c, d, e; + typedef union { + uint8_t c[64]; + uint32_t l[16]; + } CHAR64LONG16; + CHAR64LONG16* block; + + block = (CHAR64LONG16*)buffer; + + /* Copy m_state[] to working vars */ + a = state[0]; + b = state[1]; + c = state[2]; + d = state[3]; + e = state[4]; + + /* 4 rounds of 20 operations each. Loop unrolled. */ + R0(a,b,c,d,e, 0); R0(e,a,b,c,d, 1); R0(d,e,a,b,c, 2); R0(c,d,e,a,b, 3); + R0(b,c,d,e,a, 4); R0(a,b,c,d,e, 5); R0(e,a,b,c,d, 6); R0(d,e,a,b,c, 7); + R0(c,d,e,a,b, 8); R0(b,c,d,e,a, 9); R0(a,b,c,d,e,10); R0(e,a,b,c,d,11); + R0(d,e,a,b,c,12); R0(c,d,e,a,b,13); R0(b,c,d,e,a,14); R0(a,b,c,d,e,15); + R1(e,a,b,c,d,16); R1(d,e,a,b,c,17); R1(c,d,e,a,b,18); R1(b,c,d,e,a,19); + R2(a,b,c,d,e,20); R2(e,a,b,c,d,21); R2(d,e,a,b,c,22); R2(c,d,e,a,b,23); + R2(b,c,d,e,a,24); R2(a,b,c,d,e,25); R2(e,a,b,c,d,26); R2(d,e,a,b,c,27); + R2(c,d,e,a,b,28); R2(b,c,d,e,a,29); R2(a,b,c,d,e,30); R2(e,a,b,c,d,31); + R2(d,e,a,b,c,32); R2(c,d,e,a,b,33); R2(b,c,d,e,a,34); R2(a,b,c,d,e,35); + R2(e,a,b,c,d,36); R2(d,e,a,b,c,37); R2(c,d,e,a,b,38); R2(b,c,d,e,a,39); + R3(a,b,c,d,e,40); R3(e,a,b,c,d,41); R3(d,e,a,b,c,42); R3(c,d,e,a,b,43); + R3(b,c,d,e,a,44); R3(a,b,c,d,e,45); R3(e,a,b,c,d,46); R3(d,e,a,b,c,47); + R3(c,d,e,a,b,48); R3(b,c,d,e,a,49); R3(a,b,c,d,e,50); R3(e,a,b,c,d,51); + R3(d,e,a,b,c,52); R3(c,d,e,a,b,53); R3(b,c,d,e,a,54); R3(a,b,c,d,e,55); + R3(e,a,b,c,d,56); R3(d,e,a,b,c,57); R3(c,d,e,a,b,58); R3(b,c,d,e,a,59); + R4(a,b,c,d,e,60); R4(e,a,b,c,d,61); R4(d,e,a,b,c,62); R4(c,d,e,a,b,63); + R4(b,c,d,e,a,64); R4(a,b,c,d,e,65); R4(e,a,b,c,d,66); R4(d,e,a,b,c,67); + R4(c,d,e,a,b,68); R4(b,c,d,e,a,69); R4(a,b,c,d,e,70); R4(e,a,b,c,d,71); + R4(d,e,a,b,c,72); R4(c,d,e,a,b,73); R4(b,c,d,e,a,74); R4(a,b,c,d,e,75); + R4(e,a,b,c,d,76); R4(d,e,a,b,c,77); R4(c,d,e,a,b,78); R4(b,c,d,e,a,79); + + /* Add the working vars back into context.state[] */ + state[0] += a; + state[1] += b; + state[2] += c; + state[3] += d; + state[4] += e; + + /* Wipe variables */ + a = b = c = d = e = 0; +} + +/* SHA1Init - Initialize new context */ +void Balau::SHA1::reset() +{ + /* SHA1 initialization constants */ + m_state[0] = 0x67452301; + m_state[1] = 0xEFCDAB89; + m_state[2] = 0x98BADCFE; + m_state[3] = 0x10325476; + m_state[4] = 0xC3D2E1F0; + m_count[0] = m_count[1] = 0; +} + +/* Run your data through this. */ +void Balau::SHA1::update(const uint8_t* data, const size_t len) +{ + size_t i, j; + + j = (m_count[0] >> 3) & 63; + if ((m_count[0] += len << 3) < (len << 3)) m_count[1]++; + m_count[1] += (len >> 29); + if ((j + len) > 63) { + memcpy(&m_buffer[j], data, (i = 64-j)); + transform(m_state, m_buffer); + for ( ; i + 63 < len; i += 64) { + transform(m_state, data + i); + } + j = 0; + } + else i = 0; + memcpy(&m_buffer[j], &data[i], len - i); + +} + +/* Add padding and return the message digest. */ +void Balau::SHA1::final(uint8_t * digest) +{ + uint32_t i; + uint8_t finalcount[8]; + + for (i = 0; i < 8; i++) { + finalcount[i] = (unsigned char)((m_count[(i >= 4 ? 0 : 1)] + >> ((3-(i & 3)) * 8) ) & 255); /* Endian independent */ + } + update((uint8_t *)"\200", 1); + while ((m_count[0] & 504) != 448) { + update((uint8_t *)"\0", 1); + } + update(finalcount, 8); /* Should cause a SHA1_Transform() */ + for (i = 0; i < DIGEST_SIZE; i++) { + digest[i] = (uint8_t) + ((m_state[i>>2] >> ((3-(i & 3)) * 8) ) & 255); + } + + /* Wipe variables */ + i = 0; + memset(m_buffer, 0, 64); + memset(m_state, 0, 20); + memset(m_count, 0, 8); + memset(finalcount, 0, 8); /* SWR */ +} diff --git a/src/Selectable.cc b/src/Selectable.cc index 798a448..4213d34 100644 --- a/src/Selectable.cc +++ b/src/Selectable.cc @@ -52,6 +52,8 @@ void Balau::Selectable::SelectableEvent::gotOwner(Task * task) { } else { Printer::elog(E_SELECT, "...with a new task (%p -> %p); stopping first", m_task, task); m_evt.stop(); + m_evt.set<SelectableEvent, &SelectableEvent::evt_cb>(this); + m_evt.set(m_fd, m_evtType); } m_task = task; m_evt.set(task->getLoop()); @@ -96,6 +98,7 @@ ssize_t Balau::Selectable::read(void * buf, size_t count) throw (GeneralExceptio ssize_t r = recv(getSocket(m_fd), (char *) buf, count, 0); if (r >= 0) { + m_evtR->resetMaybe(); if (r == 0) close(); return r; @@ -137,8 +140,10 @@ ssize_t Balau::Selectable::write(const void * buf, size_t count) throw (GeneralE EAssert(r != 0, "send() returned 0 (broken pipe ?)"); - if (r > 0) + if (r > 0) { + m_evtW->resetMaybe(); return r; + } #ifndef _WIN32 int err = errno; diff --git a/src/Socket.cc b/src/Socket.cc index 7673c45..66dae18 100644 --- a/src/Socket.cc +++ b/src/Socket.cc @@ -522,6 +522,7 @@ Balau::IO<Balau::Socket> Balau::Socket::accept() throw (GeneralException) { Task::operationYield(m_evtR, Task::INTERRUPTIBLE); } else { String msg = getErrorMessage(); + m_evtR->stop(); throw GeneralException(String("Unexpected error accepting a connection: #") + errno + "(" + msg + ")"); } } else { |