diff options
-rw-r--r-- | includes/BWebSocket.h | 11 | ||||
-rw-r--r-- | src/BWebSocket.cc | 163 |
2 files changed, 123 insertions, 51 deletions
diff --git a/includes/BWebSocket.h b/includes/BWebSocket.h index 0d0a56f..5714053 100644 --- a/includes/BWebSocket.h +++ b/includes/BWebSocket.h @@ -61,11 +61,22 @@ private: uint64_t m_remainingBytes; uint32_t m_mask; uint8_t m_opcode; + + uint8_t * m_payloadCTRL = NULL; + uint64_t m_payloadLenCTRL; + uint64_t m_totalLenCTRL; + uint64_t m_remainingBytesCTRL; + uint32_t m_maskCTRL; + uint8_t m_opcodeCTRL; + bool m_hasMask; + bool m_hasMaskCTRL; + bool m_fin; bool m_firstFragment = true; bool m_enforceServer = false; bool m_enforceClient = false; + bool m_inCTRL = false; enum { OPCODE_CONT = 0, OPCODE_TEXT = 1, diff --git a/src/BWebSocket.cc b/src/BWebSocket.cc index 3dded95..9a3b420 100644 --- a/src/BWebSocket.cc +++ b/src/BWebSocket.cc @@ -79,6 +79,7 @@ void Balau::WebSocketFrame::send(Balau::IO<Balau::Handle> socket) { Balau::WebSocketWorker::~WebSocketWorker() { free(m_payload); + free(m_payloadCTRL); delete m_sending; while (!m_sendQueue.isEmpty()) delete m_sendQueue.pop(); @@ -87,6 +88,37 @@ Balau::WebSocketWorker::~WebSocketWorker() { void Balau::WebSocketWorker::Do() { uint8_t c; + uint8_t ** payloadP; + uint64_t * payloadLenP; + uint64_t * totalLenP; + uint64_t * remainingBytesP; + uint32_t * maskP; + uint8_t * opcodeP; + bool * hasMaskP; + int fin; + + std::function<void()> switchPacketType = [&]() { + if (m_inCTRL) { + payloadP = &m_payloadCTRL; + payloadLenP = &m_payloadLenCTRL; + totalLenP = &m_totalLenCTRL; + remainingBytesP = &m_remainingBytesCTRL; + maskP = &m_maskCTRL; + opcodeP = &m_opcodeCTRL; + hasMaskP = &m_hasMaskCTRL; + } else { + payloadP = &m_payload; + payloadLenP = &m_payloadLen; + totalLenP = &m_totalLen; + remainingBytesP = &m_remainingBytes; + maskP = &m_mask; + opcodeP = &m_opcode; + hasMaskP = &m_hasMask; + } + }; + + switchPacketType(); + waitFor(m_sendQueue.getEvent()); m_sendQueue.getEvent()->resetMaybe(); @@ -100,7 +132,8 @@ void Balau::WebSocketWorker::Do() { m_sending = m_sendQueue.pop(); else break; - if (m_socket->isClosed()) return; + if (m_socket->isClosed()) + return; } delete m_sending; @@ -109,77 +142,105 @@ void Balau::WebSocketWorker::Do() { switch (m_state) { case READ_H: m_socket->read(&c, 1); - if (m_socket->isClosed()) return; - m_fin = c & 0x80; - if ((c >> 4) & 7) goto error; + if (m_socket->isClosed()) + return; + fin = c & 0x80; + if ((c >> 4) & 7) + goto error; c &= 15; - if (!m_firstFragment && c) goto error; - if (m_firstFragment) - m_opcode = c; + if (!m_firstFragment && c) + goto error; + if (m_firstFragment) { + bool wasInCtrl = m_inCTRL; + m_inCTRL = c & 8; + if (wasInCtrl != m_inCTRL) + switchPacketType(); + *opcodeP = c; + } else { + bool wasInCtrl = m_inCTRL; + m_inCTRL = false; + if (wasInCtrl != m_inCTRL) + switchPacketType(); + } + if (!m_inCTRL) + m_fin = fin; + else if (!fin) + goto error; m_state = READ_PLB; case READ_PLB: m_socket->read(&c, 1); - if (m_socket->isClosed()) return; - m_hasMask = c & 0x80; - if (m_enforceServer && !m_hasMask) goto error; - if (m_enforceClient && m_hasMask) goto error; - m_payloadLen = c & 0x7f; + if (m_socket->isClosed()) + return; + *hasMaskP = c & 0x80; + if (m_enforceServer && !*hasMaskP) + goto error; + if (m_enforceClient && *hasMaskP) + goto error; + *payloadLenP = c & 0x7f; m_state = READ_PLL; - if (m_payloadLen == 126) { - m_payloadLen = 0; - m_remainingBytes = 2; - } else if (m_payloadLen == 127) { - m_payloadLen = 0; - m_remainingBytes = 8; + if (*payloadLenP == 126) { + *payloadLenP = 0; + *remainingBytesP = 2; + } else if (*payloadLenP == 127) { + *payloadLenP = 0; + *remainingBytesP = 8; } else { - m_remainingBytes = 0; + *remainingBytesP = 0; } case READ_PLL: - while (m_remainingBytes) { + while (*remainingBytesP) { m_socket->read(&c, 1); - if (m_socket->isClosed()) return; - m_payloadLen <<= 8; - m_payloadLen += c; - m_remainingBytes--; + if (m_socket->isClosed()) + return; + *payloadLenP <<= 8; + *payloadLenP += c; + *remainingBytesP -= 1; } m_state = READ_MK; - if (m_firstFragment) - m_totalLen = m_payloadLen; + if (m_firstFragment || m_inCTRL) + *totalLenP = *payloadLenP; else - m_totalLen += m_payloadLen; - if (m_hasMask) m_remainingBytes = 4; + *totalLenP += *payloadLenP; + if (*hasMaskP) *remainingBytesP = 4; case READ_MK: - while (m_remainingBytes) { + while (*remainingBytesP) { m_socket->read(&c, 1); - if (m_socket->isClosed()) return; - m_mask <<= 8; - m_mask += c; - m_remainingBytes--; + if (m_socket->isClosed()) + return; + *maskP <<= 8; + *maskP += c; + *remainingBytesP -= 1; } m_state = READ_PL; - m_remainingBytes = m_payloadLen; - if (m_totalLen >= MAX_WEBSOCKET_LIMIT) + *remainingBytesP = *payloadLenP; + if (*totalLenP >= MAX_WEBSOCKET_LIMIT) goto error; - m_payload = (uint8_t *)realloc(m_payload, m_totalLen + (m_opcode == OPCODE_TEXT ? 1 : 0)); + *payloadP = (uint8_t *)realloc(*payloadP, *totalLenP + (*opcodeP == OPCODE_TEXT ? 1 : 0)); case READ_PL: - while (m_remainingBytes) { - int r = m_socket->read(m_payload + m_totalLen - m_remainingBytes, m_remainingBytes); - if (m_socket->isClosed()) return; - if (r < 0) goto error; - m_remainingBytes -= r; + while (*remainingBytesP) { + int r = m_socket->read(*payloadP + *totalLenP - *remainingBytesP, *remainingBytesP); + if (m_socket->isClosed()) + return; + if (r < 0) + goto error; + *remainingBytesP -= r; } - m_firstFragment = m_fin; - - if (m_fin) { - if (m_hasMask) { - for (int i = 0; i < m_totalLen; i++) { - m_payload[i] ^= m_mask >> 24; - m_mask = rotate(m_mask); + if (!m_inCTRL) + m_firstFragment = m_fin; + + if (m_fin || m_inCTRL) { + uint8_t * payload = *payloadP; + uint64_t totalLen = *totalLenP; + if (*hasMaskP) { + uint32_t mask = *maskP; + for (int i = 0; i < totalLen; i++) { + payload[i] ^= mask >> 24; + mask = rotate(mask); } } - if (m_opcode == OPCODE_TEXT) - m_payload[m_payloadLen] = 0; + if (*opcodeP == OPCODE_TEXT) + payload[totalLen] = 0; processMessage(); } @@ -213,7 +274,7 @@ void Balau::WebSocketWorker::processMessage() { } void Balau::WebSocketWorker::processPing() { - sendFrame(new WebSocketFrame(m_payload, m_payloadLen, OPCODE_PING, m_enforceClient)); + sendFrame(new WebSocketFrame(m_payloadCTRL, m_payloadLenCTRL, OPCODE_PING, m_enforceClient)); } void Balau::WebSocketWorker::processPong() { |