From 2f93a2e442c251d0d9de5e828a66acd40086e28d Mon Sep 17 00:00:00 2001 From: "Nicolas \"Pixel\" Noble" Date: Thu, 7 Aug 2014 15:51:33 -0700 Subject: Adding full c-ares support into Balau - untested. --- src/Socket.cc | 176 ++++++++++++++++++++++----------------------------------- src/TaskMan.cc | 143 ++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 198 insertions(+), 121 deletions(-) (limited to 'src') diff --git a/src/Socket.cc b/src/Socket.cc index 7f91e7b..f27b85a 100644 --- a/src/Socket.cc +++ b/src/Socket.cc @@ -1,3 +1,4 @@ +#include #ifndef _WIN32 #include #include @@ -160,51 +161,6 @@ static const char * inet_ntop(int af, const void * src, char * dst, socklen_t si #endif -namespace Balau { - -struct DNSRequest { - struct addrinfo * res; - int error; - Balau::Events::Custom evt; -}; - -}; - -namespace { - -class AsyncOpResolv : public Balau::AsyncOperation { - public: - AsyncOpResolv(const char * name, const char * service, struct addrinfo * hints, Balau::DNSRequest * request) - : m_name(name ? ::strdup(name) : NULL) - , m_service(service ? ::strdup(service) : NULL) - , m_hints(*hints) - , m_request(request) - { } - virtual ~AsyncOpResolv() { free(m_name); free(m_service); } - virtual bool needsMainQueue() { return false; } - virtual bool needsFinishWorker() { return true; } - virtual void run() { - m_request->error = getaddrinfo(m_name, m_service, &m_hints, &m_request->res); - } - virtual void done() { - m_request->evt.doSignal(); - delete this; - } - private: - char * m_name; - char * m_service; - struct addrinfo m_hints; - Balau::DNSRequest * m_request; -}; - -}; - -static Balau::DNSRequest * resolveName(const char * name, const char * service = NULL, struct addrinfo * hints = NULL) { - Balau::DNSRequest * req = new Balau::DNSRequest(); - Balau::createAsyncOp(new AsyncOpResolv(name, service, hints, req)); - return req; -} - Balau::Socket::Socket() throw (GeneralException) { #ifdef _WIN32 int fd = _open_osfhandle(WSASocket(AF_INET6, SOCK_STREAM, 0, 0, 0, 0), 0); @@ -271,52 +227,75 @@ bool Balau::Socket::canRead() { return true; } bool Balau::Socket::canWrite() { return true; } const char * Balau::Socket::getName() { return m_name.to_charp(); } -bool Balau::Socket::resolved() { - return m_req && m_req->evt.gotSignal(); +void Balau::Socket::resolve(const char * hostname) { + if (!m_resolving) { + m_resolving = 2; + Task * t = Task::getCurrentTask(); + auto callback = [&](int status, int timeouts, struct hostent * hostent, int family, ptrdiff_t srcOffset, void * destAddr, size_t sizeofDest, bool & failed) { + if (status == ARES_SUCCESS) { + IAssert(hostent->h_addrtype == family, "We asked for socket family %i, but got %i instead", family, hostent->h_addrtype); + memcpy(destAddr, ((uint8_t *)hostent->h_addr_list[0]) + srcOffset, sizeofDest); + } + else { + failed = true; + } + if (--m_resolving == 0) { + m_resolveEvent.doSignal(); + m_resolving = false; + m_resolved = true; + } + }; + + t->getTaskMan()->getHostByName(hostname, AF_INET, [&](int status, int timeouts, struct hostent * hostent) { callback(status, timeouts, hostent, AF_INET, offsetof(struct sockaddr_in, sin_addr), &m_resolvedAddr4, sizeof(m_resolvedAddr4), m_resolve4Failed); }); + t->getTaskMan()->getHostByName(hostname, AF_INET6, [&](int status, int timeouts, struct hostent * hostent) { callback(status, timeouts, hostent, AF_INET6, offsetof(struct sockaddr_in6, sin6_addr), &m_resolvedAddr6, sizeof(m_resolvedAddr6), m_resolve6Failed); }); + + Task::operationYield(&m_resolveEvent, Task::INTERRUPTIBLE); + } +} + +void Balau::Socket::initAddr(sockaddr_in6 & out) { + out.sin6_family = AF_INET6; + out.sin6_port = 0; + out.sin6_flowinfo = 0; + out.sin6_addr = in6addr_any; +} + +void Balau::Socket::resolved(sockaddr_in6 & out) { + if (!m_resolve6Failed) { + memcpy(&out.sin6_addr, &m_resolvedAddr6, sizeof(struct in6_addr)); + } + else { + memset(&out.sin6_addr, 0, sizeof(struct in6_addr)); + // v4 mapped IPv6 address + out.sin6_addr.s6_addr[10] = 0xff; + out.sin6_addr.s6_addr[11] = 0xff; + memcpy(out.sin6_addr.s6_addr + 12, &m_resolvedAddr4, sizeof(struct in_addr)); + } + m_resolving = false; + m_resolved = false; } bool Balau::Socket::setLocal(const char * hostname, int port) { AAssert(m_localAddr.sin6_family == 0, "Can't call setLocal twice"); - if (hostname && hostname[0] && !m_req) { - struct addrinfo hints; - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_INET6; - hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = IPPROTO_TCP; - hints.ai_flags = AI_V4MAPPED; + if (hostname && hostname[0]) + resolve(hostname); - m_req = resolveName(hostname, NULL, &hints); - Task::operationYield(&m_req->evt, Task::INTERRUPTIBLE); - } + initAddr(m_localAddr); - if (m_req) { - AAssert(m_req->evt.gotSignal(), "Please don't call setLocal after a EAgain without checking its resolution status first."); - struct addrinfo * res = m_req->res; - if (m_req->error != 0) { - Printer::elog(E_SOCKET, "Got a resolution error for host %s: %s (%i)", hostname, gai_strerror(m_req->error), m_req->error); - if (res) - freeaddrinfo(res); - delete m_req; - m_req = NULL; + if (m_resolving || m_resolved) { + AAssert(m_resolved && !m_resolving, "Please don't call setLocal after a EAgain without checking its resolution status first."); + if (m_resolve4Failed && m_resolve6Failed) { + Printer::elog(E_SOCKET, "Got a resolution error for host %s", hostname); + m_resolved = false; return false; } - IAssert(res, "That really shouldn't happen..."); - EAssert(res->ai_family == AF_INET6, "getaddrinfo returned a familiy which isn't AF_INET6; %i", res->ai_family); - EAssert(res->ai_protocol == IPPROTO_TCP, "getaddrinfo returned a protocol which isn't IPPROTO_TCP; %i", res->ai_protocol); - EAssert(res->ai_addrlen == sizeof(sockaddr_in6), "getaddrinfo returned an addrlen which isn't that of sizeof(sockaddr_in6); %i", res->ai_addrlen); - memcpy(&m_localAddr.sin6_addr, &((sockaddr_in6 *) res->ai_addr)->sin6_addr, sizeof(struct in6_addr)); - freeaddrinfo(res); - delete m_req; - m_req = NULL; - } else { - m_localAddr.sin6_addr = in6addr_any; + resolved(m_localAddr); } if (port) m_localAddr.sin6_port = htons(port); - m_localAddr.sin6_family = AF_INET6; #ifndef _WIN32 int enable = 1; setsockopt(getFD(), SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)); @@ -334,47 +313,24 @@ bool Balau::Socket::connect(const char * hostname, int port) { AAssert(hostname, "You can't call Socket::connect() without a hostname"); AAssert(!isClosed(), "You can't call Socket::connect() on a closed socket"); - if (!m_connecting && !m_req) { - Printer::elog(E_SOCKET, "Resolving %s", hostname); - IAssert(m_remoteAddr.sin6_family == 0, "That shouldn't happen...; family = %i", m_remoteAddr.sin6_family); - - struct addrinfo hints; - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_INET6; - hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = IPPROTO_TCP; - hints.ai_flags = AI_V4MAPPED; + if (!m_connecting && !m_resolving) + resolve(hostname); - m_req = resolveName(hostname, NULL, &hints); - Task::operationYield(&m_req->evt, Task::INTERRUPTIBLE); - } - - if (!m_connecting && m_req) { - AAssert(m_req->evt.gotSignal(), "Please don't call connect after a EAgain without checking its resolution status first."); - struct addrinfo * res = m_req->res; - if (m_req->error != 0) { - Printer::elog(E_SOCKET, "Got a resolution error for host %s: %s (%i)", hostname, gai_strerror(m_req->error), m_req->error); - if (res) - freeaddrinfo(res); - delete m_req; - m_req = NULL; + if (!m_connecting && (m_resolving || m_resolved)) { + AAssert(m_resolved && !m_resolving, "Please don't call connect after a EAgain without checking its resolution status first."); + if (m_resolve4Failed && m_resolve6Failed) { + Printer::elog(E_SOCKET, "Got a resolution error for host %s", hostname); + m_resolved = false; return false; } - IAssert(res, "That really shouldn't happen..."); Printer::elog(E_SOCKET, "Got a resolution answer"); - EAssert(res->ai_family == AF_INET6, "getaddrinfo returned a familiy which isn't AF_INET6; %i", res->ai_family); - EAssert(res->ai_protocol == IPPROTO_TCP, "getaddrinfo returned a protocol which isn't IPPROTO_TCP; %i", res->ai_protocol); - EAssert(res->ai_addrlen == sizeof(sockaddr_in6), "getaddrinfo returned an addrlen which isn't that of sizeof(sockaddr_in6); %i", res->ai_addrlen); - memcpy(&m_remoteAddr.sin6_addr, &((sockaddr_in6 *) res->ai_addr)->sin6_addr, sizeof(struct in6_addr)); + initAddr(m_remoteAddr); + resolved(m_remoteAddr); m_remoteAddr.sin6_port = htons(port); - m_remoteAddr.sin6_family = AF_INET6; m_connecting = true; - - freeaddrinfo(res); - delete m_req; - m_req = NULL; + m_resolved = false; } else { // if we end up there, it means our yield earlier threw an EAgain exception. AAssert(gotR(), "Please don't call connect after a EAgain without checking its signal first."); diff --git a/src/TaskMan.cc b/src/TaskMan.cc index fce7ffe..5e73ba1 100644 --- a/src/TaskMan.cc +++ b/src/TaskMan.cc @@ -1,10 +1,3 @@ -#ifdef _WIN32 -#include -#include -#endif - -#undef ERROR - #include "Async.h" #include "TaskMan.h" #include "Task.h" @@ -12,8 +5,16 @@ #include "Local.h" #include "CurlTask.h" +#include #include +#ifdef _WIN32 +#include +#include +#endif + +#undef ERROR + static Balau::AsyncManager s_async; static CURLSH * s_curlShared = NULL; @@ -39,9 +40,9 @@ class Stopper : public Balau::Task { int m_code; }; -class CurlSharedManager : public Balau::AtStart, Balau::AtExit { +class CurlAndCaresSharedManager : public Balau::AtStart, Balau::AtExit { public: - CurlSharedManager() : AtStart(0), AtExit(0) { } + CurlAndCaresSharedManager() : AtStart(0), AtExit(0) { } struct SharedLocks { Balau::RWLock share, cookie, dns, ssl_session; }; @@ -83,8 +84,12 @@ class CurlSharedManager : public Balau::AtStart, Balau::AtExit { curl_share_setopt(s_curlShared, CURLSHOPT_USERDATA, &locks); curl_share_setopt(s_curlShared, CURLSHOPT_LOCKFUNC, lock_function); curl_share_setopt(s_curlShared, CURLSHOPT_UNLOCKFUNC, unlock_function); + + ares_library_init(ARES_LIB_INIT_ALL); } void doExit() { + ares_library_cleanup(); + curl_share_cleanup(s_curlShared); curl_global_cleanup(); } @@ -93,7 +98,7 @@ class CurlSharedManager : public Balau::AtStart, Balau::AtExit { }; static AsyncStarter s_asyncStarter; -static CurlSharedManager s_curlSharedmManager; +static CurlAndCaresSharedManager s_curlSharedmManager; void Stopper::Do() { getTaskMan()->stopMe(m_code); @@ -247,6 +252,21 @@ Balau::TaskMan::TaskMan() { m_curlTimer.set(m_loop); m_curlTimer.set(this); + + m_aresTimer.set(m_loop); + m_aresTimer.set(this); + + ares_options aresOptions; + + aresOptions.sock_state_cb = aresSocketCallbackStatic; + aresOptions.sock_state_cb_data = this; + + ares_init_options(&m_aresChannel, &aresOptions, ARES_OPT_SOCK_STATE_CB); + + for (int i = 0; i < ARES_MAX_SOCKETS; i++) { + m_aresSockets[i] = ARES_SOCKET_BAD; + m_aresSocketEvents[i] = NULL; + } } #ifdef _WIN32 @@ -258,7 +278,7 @@ inline static int toSocket(int fd) { return fd; } #endif int Balau::TaskMan::curlSocketCallbackStatic(CURL * easy, curl_socket_t s, int what, void * userp, void * socketp) { - TaskMan * taskMan = (TaskMan *)userp; + TaskMan * taskMan = (TaskMan *) userp; return taskMan->curlSocketCallback(easy, s, what, socketp); } @@ -331,6 +351,95 @@ void Balau::TaskMan::curlMultiTimerEventCallback(ev::timer & w, int revents) { curl_multi_socket_action(m_curlMulti, CURL_SOCKET_TIMEOUT, 0, &m_curlStillRunning); } +void Balau::TaskMan::aresSocketCallbackStatic(void * data, curl_socket_t s, int read, int write) { + TaskMan * taskMan = (TaskMan *) data; + return taskMan->aresSocketCallback(s, read, write); +} + +void Balau::TaskMan::aresSocketCallback(curl_socket_t s, int read, int write) { + int fd = fromSocket(s); + int i; + int freeSlot = ARES_MAX_SOCKETS; + + int what = CURL_POLL_NONE; + + for (i = 0; i < ARES_MAX_SOCKETS; i++) { + if (m_aresSockets[i] == s) + break; + if (m_aresSockets[i] == ARES_SOCKET_BAD) + freeSlot = i; + } + + if (i == ARES_MAX_SOCKETS) + i = freeSlot; + + IAssert(i != ARES_MAX_SOCKETS, "ares socket error - please increase ARES_MAX_SOCKETS"); + + if (!read && !write) { + what = CURL_POLL_REMOVE; + } else if (read && !write) { + what = CURL_POLL_IN; + } else if (!read && write) { + what = CURL_POLL_OUT; + } else if (read && write) { + what = CURL_POLL_INOUT; + } + + struct timeval tv; + bool hasTimer = ares_timeout(m_aresChannel, NULL, &tv); + + m_aresTimer.stop(); + if (hasTimer) { + m_aresTimer.set((ev_tstamp)(tv.tv_sec * 1000 + tv.tv_usec / 1000 + 1)); + m_aresTimer.start(); + } + + ev::io * evt = m_aresSocketEvents[i]; + if (!evt) { + if (what == CURL_POLL_REMOVE) + return; + evt = new ev::io; + evt->set(this); + evt->set(m_loop); + m_aresSocketEvents[i] = evt; + m_aresSockets[i] = s; + } + + switch (what) { + case CURL_POLL_IN: + evt->stop(); + evt->set(fd, ev::READ); + evt->start(); + break; + case CURL_POLL_OUT: + evt->stop(); + evt->set(fd, ev::WRITE); + evt->start(); + break; + case CURL_POLL_INOUT: + evt->stop(); + evt->set(fd, ev::READ | ev::WRITE); + evt->start(); + break; + case CURL_POLL_REMOVE: + evt->stop(); + delete evt; + m_aresSocketEvents[i] = NULL; + m_aresSockets[i] = ARES_SOCKET_BAD; + } + + return; +} + +void Balau::TaskMan::aresSocketEventCallback(ev::io & w, int revents) { + ares_socket_t s = toSocket(w.fd); + ares_process_fd(m_aresChannel, revents & (ev::READ | ev::ERROR) ? s : ARES_SOCKET_BAD, revents & (ev::WRITE | ev::ERROR) ? s : ARES_SOCKET_BAD); +} + +void Balau::TaskMan::aresTimerEventCallback(ev::timer & w, int revents) { + ares_process(m_aresChannel, NULL, NULL); +} + #ifdef _WIN32 namespace { @@ -362,6 +471,7 @@ Balau::TaskMan::~TaskMan() { m_evt.stop(); ev_loop_destroy(m_loop); curl_multi_cleanup(m_curlMulti); + ares_destroy(m_aresChannel); } void * Balau::TaskMan::getStack() { @@ -553,6 +663,17 @@ void Balau::TaskMan::unregisterCurlHandle(Balau::CurlTask * curlTask) { curl_multi_remove_handle(m_curlMulti, curlTask->m_curlHandle); } +void Balau::TaskMan::getHostByName(const Balau::String & name, int family, AresHostCallback callback) { + AresHostCallback * dup = new AresHostCallback(callback); + ares_gethostbyname(m_aresChannel, name.to_charp(), family, aresHostCallback, dup); +} + +void Balau::TaskMan::aresHostCallback(void * arg, int status, int timeouts, struct hostent * hostent) { + AresHostCallback * callback = (AresHostCallback *) arg; + (*callback)(status, timeouts, hostent); + delete callback; +} + void Balau::TaskMan::iRegisterTask(Balau::Task * t, Balau::Task * stick, Events::TaskEvent * event) { if (stick) { IAssert(!event, "inconsistent"); -- cgit v1.2.3