summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--includes/BWebSocket.h11
-rw-r--r--src/BWebSocket.cc163
2 files changed, 123 insertions, 51 deletions
diff --git a/includes/BWebSocket.h b/includes/BWebSocket.h
index 0d0a56f..5714053 100644
--- a/includes/BWebSocket.h
+++ b/includes/BWebSocket.h
@@ -61,11 +61,22 @@ private:
uint64_t m_remainingBytes;
uint32_t m_mask;
uint8_t m_opcode;
+
+ uint8_t * m_payloadCTRL = NULL;
+ uint64_t m_payloadLenCTRL;
+ uint64_t m_totalLenCTRL;
+ uint64_t m_remainingBytesCTRL;
+ uint32_t m_maskCTRL;
+ uint8_t m_opcodeCTRL;
+
bool m_hasMask;
+ bool m_hasMaskCTRL;
+
bool m_fin;
bool m_firstFragment = true;
bool m_enforceServer = false;
bool m_enforceClient = false;
+ bool m_inCTRL = false;
enum {
OPCODE_CONT = 0,
OPCODE_TEXT = 1,
diff --git a/src/BWebSocket.cc b/src/BWebSocket.cc
index 3dded95..9a3b420 100644
--- a/src/BWebSocket.cc
+++ b/src/BWebSocket.cc
@@ -79,6 +79,7 @@ void Balau::WebSocketFrame::send(Balau::IO<Balau::Handle> socket) {
Balau::WebSocketWorker::~WebSocketWorker() {
free(m_payload);
+ free(m_payloadCTRL);
delete m_sending;
while (!m_sendQueue.isEmpty())
delete m_sendQueue.pop();
@@ -87,6 +88,37 @@ Balau::WebSocketWorker::~WebSocketWorker() {
void Balau::WebSocketWorker::Do() {
uint8_t c;
+ uint8_t ** payloadP;
+ uint64_t * payloadLenP;
+ uint64_t * totalLenP;
+ uint64_t * remainingBytesP;
+ uint32_t * maskP;
+ uint8_t * opcodeP;
+ bool * hasMaskP;
+ int fin;
+
+ std::function<void()> switchPacketType = [&]() {
+ if (m_inCTRL) {
+ payloadP = &m_payloadCTRL;
+ payloadLenP = &m_payloadLenCTRL;
+ totalLenP = &m_totalLenCTRL;
+ remainingBytesP = &m_remainingBytesCTRL;
+ maskP = &m_maskCTRL;
+ opcodeP = &m_opcodeCTRL;
+ hasMaskP = &m_hasMaskCTRL;
+ } else {
+ payloadP = &m_payload;
+ payloadLenP = &m_payloadLen;
+ totalLenP = &m_totalLen;
+ remainingBytesP = &m_remainingBytes;
+ maskP = &m_mask;
+ opcodeP = &m_opcode;
+ hasMaskP = &m_hasMask;
+ }
+ };
+
+ switchPacketType();
+
waitFor(m_sendQueue.getEvent());
m_sendQueue.getEvent()->resetMaybe();
@@ -100,7 +132,8 @@ void Balau::WebSocketWorker::Do() {
m_sending = m_sendQueue.pop();
else
break;
- if (m_socket->isClosed()) return;
+ if (m_socket->isClosed())
+ return;
}
delete m_sending;
@@ -109,77 +142,105 @@ void Balau::WebSocketWorker::Do() {
switch (m_state) {
case READ_H:
m_socket->read(&c, 1);
- if (m_socket->isClosed()) return;
- m_fin = c & 0x80;
- if ((c >> 4) & 7) goto error;
+ if (m_socket->isClosed())
+ return;
+ fin = c & 0x80;
+ if ((c >> 4) & 7)
+ goto error;
c &= 15;
- if (!m_firstFragment && c) goto error;
- if (m_firstFragment)
- m_opcode = c;
+ if (!m_firstFragment && c)
+ goto error;
+ if (m_firstFragment) {
+ bool wasInCtrl = m_inCTRL;
+ m_inCTRL = c & 8;
+ if (wasInCtrl != m_inCTRL)
+ switchPacketType();
+ *opcodeP = c;
+ } else {
+ bool wasInCtrl = m_inCTRL;
+ m_inCTRL = false;
+ if (wasInCtrl != m_inCTRL)
+ switchPacketType();
+ }
+ if (!m_inCTRL)
+ m_fin = fin;
+ else if (!fin)
+ goto error;
m_state = READ_PLB;
case READ_PLB:
m_socket->read(&c, 1);
- if (m_socket->isClosed()) return;
- m_hasMask = c & 0x80;
- if (m_enforceServer && !m_hasMask) goto error;
- if (m_enforceClient && m_hasMask) goto error;
- m_payloadLen = c & 0x7f;
+ if (m_socket->isClosed())
+ return;
+ *hasMaskP = c & 0x80;
+ if (m_enforceServer && !*hasMaskP)
+ goto error;
+ if (m_enforceClient && *hasMaskP)
+ goto error;
+ *payloadLenP = c & 0x7f;
m_state = READ_PLL;
- if (m_payloadLen == 126) {
- m_payloadLen = 0;
- m_remainingBytes = 2;
- } else if (m_payloadLen == 127) {
- m_payloadLen = 0;
- m_remainingBytes = 8;
+ if (*payloadLenP == 126) {
+ *payloadLenP = 0;
+ *remainingBytesP = 2;
+ } else if (*payloadLenP == 127) {
+ *payloadLenP = 0;
+ *remainingBytesP = 8;
} else {
- m_remainingBytes = 0;
+ *remainingBytesP = 0;
}
case READ_PLL:
- while (m_remainingBytes) {
+ while (*remainingBytesP) {
m_socket->read(&c, 1);
- if (m_socket->isClosed()) return;
- m_payloadLen <<= 8;
- m_payloadLen += c;
- m_remainingBytes--;
+ if (m_socket->isClosed())
+ return;
+ *payloadLenP <<= 8;
+ *payloadLenP += c;
+ *remainingBytesP -= 1;
}
m_state = READ_MK;
- if (m_firstFragment)
- m_totalLen = m_payloadLen;
+ if (m_firstFragment || m_inCTRL)
+ *totalLenP = *payloadLenP;
else
- m_totalLen += m_payloadLen;
- if (m_hasMask) m_remainingBytes = 4;
+ *totalLenP += *payloadLenP;
+ if (*hasMaskP) *remainingBytesP = 4;
case READ_MK:
- while (m_remainingBytes) {
+ while (*remainingBytesP) {
m_socket->read(&c, 1);
- if (m_socket->isClosed()) return;
- m_mask <<= 8;
- m_mask += c;
- m_remainingBytes--;
+ if (m_socket->isClosed())
+ return;
+ *maskP <<= 8;
+ *maskP += c;
+ *remainingBytesP -= 1;
}
m_state = READ_PL;
- m_remainingBytes = m_payloadLen;
- if (m_totalLen >= MAX_WEBSOCKET_LIMIT)
+ *remainingBytesP = *payloadLenP;
+ if (*totalLenP >= MAX_WEBSOCKET_LIMIT)
goto error;
- m_payload = (uint8_t *)realloc(m_payload, m_totalLen + (m_opcode == OPCODE_TEXT ? 1 : 0));
+ *payloadP = (uint8_t *)realloc(*payloadP, *totalLenP + (*opcodeP == OPCODE_TEXT ? 1 : 0));
case READ_PL:
- while (m_remainingBytes) {
- int r = m_socket->read(m_payload + m_totalLen - m_remainingBytes, m_remainingBytes);
- if (m_socket->isClosed()) return;
- if (r < 0) goto error;
- m_remainingBytes -= r;
+ while (*remainingBytesP) {
+ int r = m_socket->read(*payloadP + *totalLenP - *remainingBytesP, *remainingBytesP);
+ if (m_socket->isClosed())
+ return;
+ if (r < 0)
+ goto error;
+ *remainingBytesP -= r;
}
- m_firstFragment = m_fin;
-
- if (m_fin) {
- if (m_hasMask) {
- for (int i = 0; i < m_totalLen; i++) {
- m_payload[i] ^= m_mask >> 24;
- m_mask = rotate(m_mask);
+ if (!m_inCTRL)
+ m_firstFragment = m_fin;
+
+ if (m_fin || m_inCTRL) {
+ uint8_t * payload = *payloadP;
+ uint64_t totalLen = *totalLenP;
+ if (*hasMaskP) {
+ uint32_t mask = *maskP;
+ for (int i = 0; i < totalLen; i++) {
+ payload[i] ^= mask >> 24;
+ mask = rotate(mask);
}
}
- if (m_opcode == OPCODE_TEXT)
- m_payload[m_payloadLen] = 0;
+ if (*opcodeP == OPCODE_TEXT)
+ payload[totalLen] = 0;
processMessage();
}
@@ -213,7 +274,7 @@ void Balau::WebSocketWorker::processMessage() {
}
void Balau::WebSocketWorker::processPing() {
- sendFrame(new WebSocketFrame(m_payload, m_payloadLen, OPCODE_PING, m_enforceClient));
+ sendFrame(new WebSocketFrame(m_payloadCTRL, m_payloadLenCTRL, OPCODE_PING, m_enforceClient));
}
void Balau::WebSocketWorker::processPong() {