diff options
| author | Nicolas "Pixel" Noble <pixel@nobis-crew.org> | 2014-08-07 15:51:33 -0700 | 
|---|---|---|
| committer | Nicolas "Pixel" Noble <pixel@nobis-crew.org> | 2014-08-07 15:51:33 -0700 | 
| commit | 2f93a2e442c251d0d9de5e828a66acd40086e28d (patch) | |
| tree | 987f6181c3ae5cdc402e7d400266b68c2fa55985 /src | |
| parent | d36ef7155563734d372d4bb950c6945ce7fb8b1a (diff) | |
Adding full c-ares support into Balau - untested.
Diffstat (limited to 'src')
| -rw-r--r-- | src/Socket.cc | 176 | ||||
| -rw-r--r-- | src/TaskMan.cc | 143 | 
2 files changed, 198 insertions, 121 deletions
diff --git a/src/Socket.cc b/src/Socket.cc index 7f91e7b..f27b85a 100644 --- a/src/Socket.cc +++ b/src/Socket.cc @@ -1,3 +1,4 @@ +#include <ares.h>  #ifndef _WIN32  #include <arpa/inet.h>  #include <sys/socket.h> @@ -160,51 +161,6 @@ static const char * inet_ntop(int af, const void * src, char * dst, socklen_t si  #endif -namespace Balau { - -struct DNSRequest { -    struct addrinfo * res; -    int error; -    Balau::Events::Custom evt; -}; - -}; - -namespace { - -class AsyncOpResolv : public Balau::AsyncOperation { -  public: -      AsyncOpResolv(const char * name, const char * service, struct addrinfo * hints, Balau::DNSRequest * request) -        : m_name(name ? ::strdup(name) : NULL) -        , m_service(service ? ::strdup(service) : NULL) -        , m_hints(*hints) -        , m_request(request) -        { } -      virtual ~AsyncOpResolv() { free(m_name); free(m_service); } -    virtual bool needsMainQueue() { return false; } -    virtual bool needsFinishWorker() { return true; } -    virtual void run() { -        m_request->error = getaddrinfo(m_name, m_service, &m_hints, &m_request->res); -    } -    virtual void done() { -        m_request->evt.doSignal(); -        delete this; -    } -  private: -    char * m_name; -    char * m_service; -    struct addrinfo m_hints; -    Balau::DNSRequest * m_request; -}; - -}; - -static Balau::DNSRequest * resolveName(const char * name, const char * service = NULL, struct addrinfo * hints = NULL) { -    Balau::DNSRequest * req = new Balau::DNSRequest(); -    Balau::createAsyncOp(new AsyncOpResolv(name, service, hints, req)); -    return req; -} -  Balau::Socket::Socket() throw (GeneralException) {  #ifdef _WIN32      int fd = _open_osfhandle(WSASocket(AF_INET6, SOCK_STREAM, 0, 0, 0, 0), 0); @@ -271,52 +227,75 @@ bool Balau::Socket::canRead() { return true; }  bool Balau::Socket::canWrite() { return true; }  const char * Balau::Socket::getName() { return m_name.to_charp(); } -bool Balau::Socket::resolved() { -    return m_req && m_req->evt.gotSignal(); +void Balau::Socket::resolve(const char * hostname) { +    if (!m_resolving) { +        m_resolving = 2; +        Task * t = Task::getCurrentTask(); +        auto callback = [&](int status, int timeouts, struct hostent * hostent, int family, ptrdiff_t srcOffset, void * destAddr, size_t sizeofDest, bool & failed) { +            if (status == ARES_SUCCESS) { +                IAssert(hostent->h_addrtype == family, "We asked for socket family %i, but got %i instead", family, hostent->h_addrtype); +                memcpy(destAddr, ((uint8_t *)hostent->h_addr_list[0]) + srcOffset, sizeofDest); +            } +            else { +                failed = true; +            } +            if (--m_resolving == 0) { +                m_resolveEvent.doSignal(); +                m_resolving = false; +                m_resolved = true; +            } +        }; + +        t->getTaskMan()->getHostByName(hostname, AF_INET, [&](int status, int timeouts, struct hostent * hostent) { callback(status, timeouts, hostent, AF_INET, offsetof(struct sockaddr_in, sin_addr), &m_resolvedAddr4, sizeof(m_resolvedAddr4), m_resolve4Failed); }); +        t->getTaskMan()->getHostByName(hostname, AF_INET6, [&](int status, int timeouts, struct hostent * hostent) { callback(status, timeouts, hostent, AF_INET6, offsetof(struct sockaddr_in6, sin6_addr), &m_resolvedAddr6, sizeof(m_resolvedAddr6), m_resolve6Failed); }); + +        Task::operationYield(&m_resolveEvent, Task::INTERRUPTIBLE); +    } +} + +void Balau::Socket::initAddr(sockaddr_in6 & out) { +    out.sin6_family = AF_INET6; +    out.sin6_port = 0; +    out.sin6_flowinfo = 0; +    out.sin6_addr = in6addr_any; +} + +void Balau::Socket::resolved(sockaddr_in6 & out) { +    if (!m_resolve6Failed) { +        memcpy(&out.sin6_addr, &m_resolvedAddr6, sizeof(struct in6_addr)); +    } +    else { +        memset(&out.sin6_addr, 0, sizeof(struct in6_addr)); +        // v4 mapped IPv6 address +        out.sin6_addr.s6_addr[10] = 0xff; +        out.sin6_addr.s6_addr[11] = 0xff; +        memcpy(out.sin6_addr.s6_addr + 12, &m_resolvedAddr4, sizeof(struct in_addr)); +    } +    m_resolving = false; +    m_resolved = false;  }  bool Balau::Socket::setLocal(const char * hostname, int port) {      AAssert(m_localAddr.sin6_family == 0, "Can't call setLocal twice"); -    if (hostname && hostname[0] && !m_req) { -        struct addrinfo hints; -        memset(&hints, 0, sizeof(hints)); -        hints.ai_family = AF_INET6; -        hints.ai_socktype = SOCK_STREAM; -        hints.ai_protocol = IPPROTO_TCP; -        hints.ai_flags = AI_V4MAPPED; +    if (hostname && hostname[0]) +        resolve(hostname); -        m_req = resolveName(hostname, NULL, &hints); -        Task::operationYield(&m_req->evt, Task::INTERRUPTIBLE); -    } +    initAddr(m_localAddr); -    if (m_req) { -        AAssert(m_req->evt.gotSignal(), "Please don't call setLocal after a EAgain without checking its resolution status first."); -        struct addrinfo * res = m_req->res; -        if (m_req->error != 0) { -            Printer::elog(E_SOCKET, "Got a resolution error for host %s: %s (%i)", hostname, gai_strerror(m_req->error), m_req->error); -            if (res) -                freeaddrinfo(res); -            delete m_req; -            m_req = NULL; +    if (m_resolving || m_resolved) { +        AAssert(m_resolved && !m_resolving, "Please don't call setLocal after a EAgain without checking its resolution status first."); +        if (m_resolve4Failed && m_resolve6Failed) { +            Printer::elog(E_SOCKET, "Got a resolution error for host %s", hostname); +            m_resolved = false;              return false;          } -        IAssert(res, "That really shouldn't happen..."); -        EAssert(res->ai_family == AF_INET6, "getaddrinfo returned a familiy which isn't AF_INET6; %i", res->ai_family); -        EAssert(res->ai_protocol == IPPROTO_TCP, "getaddrinfo returned a protocol which isn't IPPROTO_TCP; %i", res->ai_protocol); -        EAssert(res->ai_addrlen == sizeof(sockaddr_in6), "getaddrinfo returned an addrlen which isn't that of sizeof(sockaddr_in6); %i", res->ai_addrlen); -        memcpy(&m_localAddr.sin6_addr, &((sockaddr_in6 *) res->ai_addr)->sin6_addr, sizeof(struct in6_addr)); -        freeaddrinfo(res); -        delete m_req; -        m_req = NULL; -    } else { -        m_localAddr.sin6_addr = in6addr_any; +        resolved(m_localAddr);      }      if (port)          m_localAddr.sin6_port = htons(port); -    m_localAddr.sin6_family = AF_INET6;  #ifndef _WIN32      int enable = 1;      setsockopt(getFD(), SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)); @@ -334,47 +313,24 @@ bool Balau::Socket::connect(const char * hostname, int port) {      AAssert(hostname, "You can't call Socket::connect() without a hostname");      AAssert(!isClosed(), "You can't call Socket::connect() on a closed socket"); -    if (!m_connecting && !m_req) { -        Printer::elog(E_SOCKET, "Resolving %s", hostname); -        IAssert(m_remoteAddr.sin6_family == 0, "That shouldn't happen...; family = %i", m_remoteAddr.sin6_family); - -        struct addrinfo hints; -        memset(&hints, 0, sizeof(hints)); -        hints.ai_family = AF_INET6; -        hints.ai_socktype = SOCK_STREAM; -        hints.ai_protocol = IPPROTO_TCP; -        hints.ai_flags = AI_V4MAPPED; +    if (!m_connecting && !m_resolving) +        resolve(hostname); -        m_req = resolveName(hostname, NULL, &hints); -        Task::operationYield(&m_req->evt, Task::INTERRUPTIBLE); -    } - -    if (!m_connecting && m_req) { -        AAssert(m_req->evt.gotSignal(), "Please don't call connect after a EAgain without checking its resolution status first."); -        struct addrinfo * res = m_req->res; -        if (m_req->error != 0) { -            Printer::elog(E_SOCKET, "Got a resolution error for host %s: %s (%i)", hostname, gai_strerror(m_req->error), m_req->error); -            if (res) -                freeaddrinfo(res); -            delete m_req; -            m_req = NULL; +    if (!m_connecting && (m_resolving || m_resolved)) { +        AAssert(m_resolved && !m_resolving, "Please don't call connect after a EAgain without checking its resolution status first."); +        if (m_resolve4Failed && m_resolve6Failed) { +            Printer::elog(E_SOCKET, "Got a resolution error for host %s", hostname); +            m_resolved = false;              return false;          } -        IAssert(res, "That really shouldn't happen...");          Printer::elog(E_SOCKET, "Got a resolution answer"); -        EAssert(res->ai_family == AF_INET6, "getaddrinfo returned a familiy which isn't AF_INET6; %i", res->ai_family); -        EAssert(res->ai_protocol == IPPROTO_TCP, "getaddrinfo returned a protocol which isn't IPPROTO_TCP; %i", res->ai_protocol); -        EAssert(res->ai_addrlen == sizeof(sockaddr_in6), "getaddrinfo returned an addrlen which isn't that of sizeof(sockaddr_in6); %i", res->ai_addrlen); -        memcpy(&m_remoteAddr.sin6_addr, &((sockaddr_in6 *) res->ai_addr)->sin6_addr, sizeof(struct in6_addr)); +        initAddr(m_remoteAddr); +        resolved(m_remoteAddr);          m_remoteAddr.sin6_port = htons(port); -        m_remoteAddr.sin6_family = AF_INET6;          m_connecting = true; - -        freeaddrinfo(res); -        delete m_req; -        m_req = NULL; +        m_resolved = false;      } else {          // if we end up there, it means our yield earlier threw an EAgain exception.          AAssert(gotR(), "Please don't call connect after a EAgain without checking its signal first."); diff --git a/src/TaskMan.cc b/src/TaskMan.cc index fce7ffe..5e73ba1 100644 --- a/src/TaskMan.cc +++ b/src/TaskMan.cc @@ -1,10 +1,3 @@ -#ifdef _WIN32 -#include <windows.h> -#include <io.h> -#endif - -#undef ERROR -  #include "Async.h"  #include "TaskMan.h"  #include "Task.h" @@ -12,8 +5,16 @@  #include "Local.h"  #include "CurlTask.h" +#include <ares.h>  #include <curl/curl.h> +#ifdef _WIN32 +#include <windows.h> +#include <io.h> +#endif + +#undef ERROR +  static Balau::AsyncManager s_async;  static CURLSH * s_curlShared = NULL; @@ -39,9 +40,9 @@ class Stopper : public Balau::Task {      int m_code;  }; -class CurlSharedManager : public Balau::AtStart, Balau::AtExit { +class CurlAndCaresSharedManager : public Balau::AtStart, Balau::AtExit {    public: -      CurlSharedManager() : AtStart(0), AtExit(0) { } +      CurlAndCaresSharedManager() : AtStart(0), AtExit(0) { }      struct SharedLocks {          Balau::RWLock share, cookie, dns, ssl_session;      }; @@ -83,8 +84,12 @@ class CurlSharedManager : public Balau::AtStart, Balau::AtExit {          curl_share_setopt(s_curlShared, CURLSHOPT_USERDATA, &locks);          curl_share_setopt(s_curlShared, CURLSHOPT_LOCKFUNC, lock_function);          curl_share_setopt(s_curlShared, CURLSHOPT_UNLOCKFUNC, unlock_function); + +        ares_library_init(ARES_LIB_INIT_ALL);      }      void doExit() { +        ares_library_cleanup(); +          curl_share_cleanup(s_curlShared);          curl_global_cleanup();      } @@ -93,7 +98,7 @@ class CurlSharedManager : public Balau::AtStart, Balau::AtExit {  };  static AsyncStarter s_asyncStarter; -static CurlSharedManager s_curlSharedmManager; +static CurlAndCaresSharedManager s_curlSharedmManager;  void Stopper::Do() {      getTaskMan()->stopMe(m_code); @@ -247,6 +252,21 @@ Balau::TaskMan::TaskMan() {      m_curlTimer.set(m_loop);      m_curlTimer.set<TaskMan, &TaskMan::curlMultiTimerEventCallback>(this); + +    m_aresTimer.set(m_loop); +    m_aresTimer.set<TaskMan, &TaskMan::aresTimerEventCallback>(this); + +    ares_options aresOptions; + +    aresOptions.sock_state_cb = aresSocketCallbackStatic; +    aresOptions.sock_state_cb_data = this; + +    ares_init_options(&m_aresChannel, &aresOptions, ARES_OPT_SOCK_STATE_CB); + +    for (int i = 0; i < ARES_MAX_SOCKETS; i++) { +        m_aresSockets[i] = ARES_SOCKET_BAD; +        m_aresSocketEvents[i] = NULL; +    }  }  #ifdef _WIN32 @@ -258,7 +278,7 @@ inline static int toSocket(int fd) { return fd; }  #endif  int Balau::TaskMan::curlSocketCallbackStatic(CURL * easy, curl_socket_t s, int what, void * userp, void * socketp) { -    TaskMan * taskMan = (TaskMan *)userp; +    TaskMan * taskMan = (TaskMan *) userp;      return taskMan->curlSocketCallback(easy, s, what, socketp);  } @@ -331,6 +351,95 @@ void Balau::TaskMan::curlMultiTimerEventCallback(ev::timer & w, int revents) {      curl_multi_socket_action(m_curlMulti, CURL_SOCKET_TIMEOUT, 0, &m_curlStillRunning);  } +void Balau::TaskMan::aresSocketCallbackStatic(void * data, curl_socket_t s, int read, int write) { +    TaskMan * taskMan = (TaskMan *) data; +    return taskMan->aresSocketCallback(s, read, write); +} + +void Balau::TaskMan::aresSocketCallback(curl_socket_t s, int read, int write) { +    int fd = fromSocket(s); +    int i; +    int freeSlot = ARES_MAX_SOCKETS; + +    int what = CURL_POLL_NONE; + +    for (i = 0; i < ARES_MAX_SOCKETS; i++) { +        if (m_aresSockets[i] == s) +            break; +        if (m_aresSockets[i] == ARES_SOCKET_BAD) +            freeSlot = i; +    } + +    if (i == ARES_MAX_SOCKETS) +        i = freeSlot; + +    IAssert(i != ARES_MAX_SOCKETS, "ares socket error - please increase ARES_MAX_SOCKETS"); + +    if (!read && !write) { +        what = CURL_POLL_REMOVE; +    } else if (read && !write) { +        what = CURL_POLL_IN; +    } else if (!read && write) { +        what = CURL_POLL_OUT; +    } else if (read && write) { +        what = CURL_POLL_INOUT; +    } + +    struct timeval tv; +    bool hasTimer = ares_timeout(m_aresChannel, NULL, &tv); + +    m_aresTimer.stop(); +    if (hasTimer) { +        m_aresTimer.set((ev_tstamp)(tv.tv_sec * 1000 + tv.tv_usec / 1000 + 1)); +        m_aresTimer.start(); +    } + +    ev::io * evt = m_aresSocketEvents[i]; +    if (!evt) { +        if (what == CURL_POLL_REMOVE) +            return; +        evt = new ev::io; +        evt->set<TaskMan, &TaskMan::aresSocketEventCallback>(this); +        evt->set(m_loop); +        m_aresSocketEvents[i] = evt; +        m_aresSockets[i] = s; +    } + +    switch (what) { +    case CURL_POLL_IN: +        evt->stop(); +        evt->set(fd, ev::READ); +        evt->start(); +        break; +    case CURL_POLL_OUT: +        evt->stop(); +        evt->set(fd, ev::WRITE); +        evt->start(); +        break; +    case CURL_POLL_INOUT: +        evt->stop(); +        evt->set(fd, ev::READ | ev::WRITE); +        evt->start(); +        break; +    case CURL_POLL_REMOVE: +        evt->stop(); +        delete evt; +        m_aresSocketEvents[i] = NULL; +        m_aresSockets[i] = ARES_SOCKET_BAD; +    } + +    return; +} + +void Balau::TaskMan::aresSocketEventCallback(ev::io & w, int revents) { +    ares_socket_t s = toSocket(w.fd); +    ares_process_fd(m_aresChannel, revents & (ev::READ | ev::ERROR) ? s : ARES_SOCKET_BAD, revents & (ev::WRITE | ev::ERROR) ? s : ARES_SOCKET_BAD); +} + +void Balau::TaskMan::aresTimerEventCallback(ev::timer & w, int revents) { +    ares_process(m_aresChannel, NULL, NULL); +} +  #ifdef _WIN32  namespace { @@ -362,6 +471,7 @@ Balau::TaskMan::~TaskMan() {      m_evt.stop();      ev_loop_destroy(m_loop);      curl_multi_cleanup(m_curlMulti); +    ares_destroy(m_aresChannel);  }  void * Balau::TaskMan::getStack() { @@ -553,6 +663,17 @@ void Balau::TaskMan::unregisterCurlHandle(Balau::CurlTask * curlTask) {      curl_multi_remove_handle(m_curlMulti, curlTask->m_curlHandle);  } +void Balau::TaskMan::getHostByName(const Balau::String & name, int family, AresHostCallback callback) { +    AresHostCallback * dup = new AresHostCallback(callback); +    ares_gethostbyname(m_aresChannel, name.to_charp(), family, aresHostCallback, dup); +} + +void Balau::TaskMan::aresHostCallback(void * arg, int status, int timeouts, struct hostent * hostent) { +    AresHostCallback * callback = (AresHostCallback *) arg; +    (*callback)(status, timeouts, hostent); +    delete callback; +} +  void Balau::TaskMan::iRegisterTask(Balau::Task * t, Balau::Task * stick, Events::TaskEvent * event) {      if (stick) {          IAssert(!event, "inconsistent");  | 
