summary refs log tree commit diff
path: root/includes/BWebSocket.h
blob: 571405302d1c4ea95138b03abdab34052f620a79 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#pragma once

#include <Task.h>
#include <StacklessTask.h>
#include <BStream.h>
#include <HttpServer.h>

namespace Balau {

class WebSocketActionBase;

/**
 * One outgoing WebSocket frame, stored in a single heap buffer.
 *
 * The buffer (m_data) is released with free() in the destructor, so the
 * object is the unique owner of that memory. Copying is therefore disabled
 * (an implicit member-wise copy would free the same pointer twice); frames
 * can still be moved.
 *
 * NOTE(review): the buffer is presumably laid out as
 * [header (m_headerSize bytes)][payload (m_len bytes)] -- getPtr() returns
 * m_data + m_headerSize -- but the allocating constructor lives outside
 * this header; confirm there.
 */
class WebSocketFrame {
  public:
    /**
     * Builds a frame from the bytes of a string.
     * Delegates to the (data, len, opcode, mask) constructor.
     * opcode defaults to 1, which is OPCODE_TEXT in WebSocketWorker.
     */
    WebSocketFrame(const String & str, uint8_t opcode = 1, bool mask = false)
        : WebSocketFrame(reinterpret_cast<const uint8_t *>(str.to_charp()), str.strlen(), opcode, mask) { }

    /**
     * Builds a frame of len payload bytes without source data (a null data
     * pointer is forwarded to the main constructor).
     * NOTE(review): presumably the caller then fills the payload through
     * getPtr() / operator[] -- confirm against the constructor definition.
     */
    WebSocketFrame(size_t len, uint8_t opcode = 1, bool mask = false)
        : WebSocketFrame(nullptr, len, opcode, mask) { }

    /** Main constructor; defined out of line. */
    WebSocketFrame(const uint8_t * data, size_t len, uint8_t opcode = 1, bool mask = false);

    /** Releases the frame buffer; free() on a null pointer is a no-op. */
    ~WebSocketFrame() { free(m_data); }

    // The destructor frees m_data, so a member-wise copy would double-free.
    WebSocketFrame(const WebSocketFrame &) = delete;
    WebSocketFrame & operator=(const WebSocketFrame &) = delete;

    /** Takes over other's buffer and leaves other empty (null buffer). */
    WebSocketFrame(WebSocketFrame && other) noexcept
        : m_data(other.m_data)
        , m_len(other.m_len)
        , m_headerSize(other.m_headerSize)
        , m_mask(other.m_mask)
        , m_bytesSent(other.m_bytesSent) {
        other.m_data = nullptr;
        other.m_len = 0;
        other.m_headerSize = 0;
        other.m_bytesSent = 0;
    }

    /** Frees the current buffer, then takes over other's buffer. */
    WebSocketFrame & operator=(WebSocketFrame && other) noexcept {
        if (this != &other) {
            free(m_data);
            m_data = other.m_data;
            m_len = other.m_len;
            m_headerSize = other.m_headerSize;
            m_mask = other.m_mask;
            m_bytesSent = other.m_bytesSent;
            other.m_data = nullptr;
            other.m_len = 0;
            other.m_headerSize = 0;
            other.m_bytesSent = 0;
        }
        return *this;
    }

    /** Byte accessor; defined out of line. */
    uint8_t & operator[](size_t idx);

    /** Returns the address m_headerSize bytes into the frame buffer. */
    uint8_t * getPtr() { return m_data + m_headerSize; }

    /** Sends the frame on the given socket; defined out of line. */
    void send(IO<Handle> socket);

  private:
    uint8_t * m_data = nullptr;
    size_t m_len = 0;
    size_t m_headerSize = 0;
    // Multi-character literal: its numeric value is implementation-defined.
    uint32_t m_mask = 'BLAH';
    size_t m_bytesSent = 0;
};

/**
 * Base class for one WebSocket connection, run as a StacklessTask.
 *
 * Subclasses implement receiveMessage(); the constructor and destructor are
 * protected, so the class can only be created through a derived class.
 * Every scalar state member has a default initializer so that no member
 * starts out with an indeterminate value.
 */
class WebSocketWorker : public StacklessTask {
  public:
    /** Hook receiving the HTTP request; the default implementation accepts it by returning true. */
    virtual bool parse(Http::Request & req) { return true; }

    /**
     * Pushes the frame onto the send queue.
     * NOTE(review): ownership of the pointer presumably passes to the
     * worker -- confirm against the code that drains m_sendQueue.
     */
    void sendFrame(WebSocketFrame * frame) { m_sendQueue.push(frame); }

    // Defined out of line. The dynamic exception specifications are kept
    // because the out-of-line definitions must carry the same specification.
    void enforceServer(void) throw (GeneralException);
    void enforceClient(void) throw (GeneralException);

  protected:
    /**
     * Wraps the socket in a BStream and builds the task name as
     * "WebSocket:" + url + ":" + the stream's own name.
     */
    WebSocketWorker(IO<Handle> socket, const String & url)
        : m_socket(new BStream(socket)) {
        m_name = String("WebSocket:") + url + ":" + m_socket->getName();
    }
    ~WebSocketWorker();

    /** Closes the underlying stream. */
    void disconnect() { m_socket->close(); }

    /** Called with a received message; binary tells text and binary messages apart. */
    virtual void receiveMessage(const uint8_t * msg, size_t len, bool binary) = 0;

    /** Task body; defined out of line. */
    virtual void Do();

  private:
    void processMessage();
    void processPing();
    void processPong();

    /** Returns the name built in the constructor. */
    virtual const char * getName() const { return m_name.to_charp(); }

    String m_name;
    IO<BStream> m_socket;

    // Receive state. NOTE(review): the names presumably stand for header,
    // payload-length byte, extended payload length, masking key and payload
    // -- confirm against Do().
    enum {
        READ_H,
        READ_PLB,
        READ_PLL,
        READ_MK,
        READ_PL,
    } m_status = READ_H;

    // 4 MiB. NOTE(review): presumably the largest accepted message size --
    // confirm where it is checked.
    enum { MAX_WEBSOCKET_LIMIT = 4 * 1024 * 1024 };

    uint8_t * m_payload = nullptr;
    WebSocketFrame * m_sending = nullptr;
    TQueue<WebSocketFrame> m_sendQueue;
    uint64_t m_payloadLen = 0;
    uint64_t m_totalLen = 0;
    uint64_t m_remainingBytes = 0;
    uint32_t m_mask = 0;
    uint8_t m_opcode = 0;

    // A second set of the same fields with a CTRL suffix.
    // NOTE(review): presumably used while m_inCTRL is set, i.e. for a
    // control frame received in the middle of a fragmented message -- confirm.
    uint8_t * m_payloadCTRL = nullptr;
    uint64_t m_payloadLenCTRL = 0;
    uint64_t m_totalLenCTRL = 0;
    uint64_t m_remainingBytesCTRL = 0;
    uint32_t m_maskCTRL = 0;
    uint8_t m_opcodeCTRL = 0;

    bool m_hasMask = false;
    bool m_hasMaskCTRL = false;

    bool m_fin = false;
    bool m_firstFragment = true;
    bool m_enforceServer = false;
    bool m_enforceClient = false;
    bool m_inCTRL = false;

    // Frame opcodes; the numeric values match the opcode table of RFC 6455.
    enum {
        OPCODE_CONT  =  0,
        OPCODE_TEXT  =  1,
        OPCODE_BIN   =  2,
        OPCODE_CLOSE =  8,
        OPCODE_PING  =  9,
        OPCODE_PONG  = 10,
    };

    // NOTE(review): this header only forward-declares WebSocketActionBase and
    // never defines it; the class that is defined here is
    // WebSocketServerBase -- confirm whether the friend name is stale.
    friend class WebSocketActionBase;
};

/**
 * HttpServer action serving WebSocket connections for URLs matching a regex.
 *
 * Not instantiable on its own: the constructor is protected and
 * spawnWorker() is pure virtual. Derived classes supply the worker that
 * handles a connection.
 */
class WebSocketServerBase : public HttpServer::Action {
  protected:
    /** Forwards the URL-matching regex to the Action base class. */
    WebSocketServerBase(const Regex & regex)
        : Action(regex)
    {
    }

    /** Factory hook: creates the worker for the given socket and URL. */
    virtual WebSocketWorker * spawnWorker(IO<Handle> socket, const String & url) = 0;

  private:
    /** Defined out of line. */
    void sendError(IO<Handle> out, const char * serverName);

    /** Action entry point; defined out of line. */
    bool Do(HttpServer * server, Http::Request & req, HttpServer::Action::ActionMatch & match, IO<Handle> out) throw (GeneralException);
};

/**
 * WebSocket server action whose workers are of type T.
 *
 * T must be constructible from (IO<Handle>, const String &) and convertible
 * to WebSocketWorker *, since spawnWorker() returns `new T(socket, url)`
 * through that pointer type.
 */
template<class T>
class WebSocketServer : public WebSocketServerBase {
  protected:
    /** Forwards the URL-matching regex to WebSocketServerBase. */
    WebSocketServer(const Regex & regex)
        : WebSocketServerBase(regex)
    {
    }

    /** Heap-allocates a T for the connection and returns it as a WebSocketWorker. */
    virtual WebSocketWorker * spawnWorker(IO<Handle> socket, const String & url) {
        WebSocketWorker * worker = new T(socket, url);
        return worker;
    }
};

};