diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Socket.cc | 400 |
1 files changed, 400 insertions, 0 deletions
diff --git a/src/Socket.cc b/src/Socket.cc new file mode 100644 index 0000000..762aae3 --- /dev/null +++ b/src/Socket.cc @@ -0,0 +1,400 @@ +#include <arpa/inet.h> +#include <sys/types.h> +#include <sys/socket.h> +#include <netdb.h> +#include <unistd.h> +#include <fcntl.h> +#include <errno.h> +#include "Socket.h" +#include "Threads.h" +#include "Printer.h" +#include "Main.h" + +void Balau::Socket::SocketEvent::gotOwner(Task * task) { + Printer::elog(E_SOCKET, "Arming SocketEvent at %p", this); + if (!m_task) { + Printer::elog(E_SOCKET, "...with a new task (%p)", task); + } else if (task == m_task) { + m_evt.start(); + return; + } else { + Printer::elog(E_SOCKET, "...with a new task (%p -> %p); stopping first", m_task, task); + m_evt.stop(); + } + m_task = task; + m_evt.set(task->getLoop()); + m_evt.start(); +} + +struct DNSRequest { + const char * name; + const char * service; + struct addrinfo * res; + struct addrinfo * hints; + Balau::Events::Async * evt; + int error; +}; + +#if 0 +// TODO: use getaddrinfo_a, if available. +#else +class ResolverThread : public Balau::Thread, public Balau::AtStart { + public: + ResolverThread() : AtStart(8) { } + virtual ~ResolverThread(); + void pushRequest(DNSRequest * req) { m_queue.push(req); } + private: + virtual void * proc(); + virtual void doStart(); + Balau::Queue<DNSRequest *> m_queue; +}; + +void ResolverThread::doStart() { + threadStart(); +} + +ResolverThread::~ResolverThread() { + DNSRequest req; + memset(&req, 0, sizeof(req)); + pushRequest(&req); +} + +void * ResolverThread::proc() { + DNSRequest * req; + DNSRequest stop; + memset(&stop, 0, sizeof(stop)); + while (true) { + req = m_queue.pop(); + if (memcmp(&stop, req, sizeof(stop)) == 0) + break; + Balau::Printer::elog(Balau::E_SOCKET, "Resolver thread got a request for `%s'", req->name); + req->error = getaddrinfo(req->name, req->service, req->hints, &req->res); + Balau::Printer::elog(Balau::E_SOCKET, "Resolver thread got an answer; sending signal"); + req->evt->trigger(); + } + return NULL; +} +#endif + +static ResolverThread resolverThread; + +static DNSRequest resolveName(const char * name, const char * service = NULL, struct addrinfo * hints = NULL) { + Balau::Events::Async evt; + DNSRequest req; + memset(&req, 0, sizeof(req)); + + req.name = name; + req.service = service; + req.hints = hints; + req.evt = &evt; + Balau::Printer::elog(Balau::E_SOCKET, "Sending a request to the resolver thread"); + Balau::Task::prepare(&evt); + resolverThread.pushRequest(&req); + Balau::Task::yield(&evt); + + return req; +} + +Balau::Socket::Socket() throw (GeneralException) : m_fd(socket(AF_INET6, SOCK_STREAM, 0)), m_connected(false), m_connecting(false), m_listening(false) { + m_name = "Socket(unconnected)"; + Assert(m_fd >= 0); + m_evtR = new SocketEvent(m_fd, EV_READ); + m_evtW = new SocketEvent(m_fd, EV_WRITE); + fcntl(m_fd, F_SETFL, O_NONBLOCK); + memset(&m_localAddr, 0, sizeof(m_localAddr)); + memset(&m_remoteAddr, 0, sizeof(m_remoteAddr)); + Printer::elog(E_SOCKET, "Creating a socket at %p", this); +} + +Balau::Socket::Socket(int fd) : m_fd(fd), m_connected(true), m_connecting(false), m_listening(false) { + socklen_t len; + + len = sizeof(m_localAddr); + getsockname(m_fd, (sockaddr *) &m_localAddr, &len); + + len = sizeof(m_remoteAddr); + getpeername(m_fd, (sockaddr *) &m_remoteAddr, &len); + + char prtLocal[INET6_ADDRSTRLEN], prtRemote[INET6_ADDRSTRLEN]; + const char * rLocal, * rRemote; + + len = sizeof(m_localAddr); + rLocal = inet_ntop(AF_INET6, &m_localAddr.sin6_addr, prtLocal, len); + rRemote = inet_ntop(AF_INET6, &m_remoteAddr.sin6_addr, prtRemote, len); + + Assert(rLocal); + Assert(rRemote); + + m_evtR = new SocketEvent(m_fd, EV_READ); + m_evtW = new SocketEvent(m_fd, EV_WRITE); + fcntl(m_fd, F_SETFL, O_NONBLOCK); + + m_name.set("Socket(Connected - [%s]:%i <- [%s]:%i)", rLocal, htons(m_localAddr.sin6_port), rRemote, htons(m_remoteAddr.sin6_port)); + Printer::elog(E_SOCKET, "Created a new socket from listener at %p; %s", this, m_name.to_charp()); +} + +void Balau::Socket::close() throw (GeneralException) { +#ifdef _WIN32 + closesocket(m_fd); +#else + ::close(m_fd); +#endif + Printer::elog(E_SOCKET, "Closing socket at %p", this); + m_connected = false; + m_connecting = false; + m_listening = false; + m_fd = -1; + delete m_evtR; + delete m_evtW; + m_evtR = m_evtW = NULL; +} + +bool Balau::Socket::isClosed() { return m_fd < 0; } +bool Balau::Socket::isEOF() { return isClosed(); } +bool Balau::Socket::canRead() { return true; } +bool Balau::Socket::canWrite() { return true; } +const char * Balau::Socket::getName() { return m_name.to_charp(); } + +bool Balau::Socket::setLocal(const char * hostname, int port) { + Assert(m_localAddr.sin6_family == 0); + + if (hostname && hostname[0]) { + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET6; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + hints.ai_flags = AI_ADDRCONFIG | AI_V4MAPPED; + + DNSRequest req = resolveName(hostname, NULL, &hints); + struct addrinfo * res = req.res; + if (req.error != 0) { + freeaddrinfo(res); + return false; + } + if (!res) { + freeaddrinfo(res); + return false; + } + Assert(res->ai_family == AF_INET6); + Assert(res->ai_protocol == IPPROTO_TCP); + Assert(res->ai_addrlen == sizeof(sockaddr_in6)); + memcpy(&m_localAddr.sin6_addr, &((sockaddr_in6 *) res->ai_addr)->sin6_addr, sizeof(struct in6_addr)); + freeaddrinfo(res); + } else { + m_localAddr.sin6_addr = in6addr_any; + } + + if (port) + m_localAddr.sin6_port = htons(port); + + m_localAddr.sin6_family = AF_INET6; +#ifndef _WIN32 + int enable = 1; + setsockopt(m_fd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)); +#endif + return bind(m_fd, (struct sockaddr *) &m_localAddr, sizeof(m_localAddr)) == 0; +} + +bool Balau::Socket::connect(const char * hostname, int port) { + Assert(!m_listening); + Assert(!m_connected); + Assert(hostname); + + if (!m_connecting) { + Printer::elog(E_SOCKET, "Resolving %s", hostname); + Assert(m_remoteAddr.sin6_family == 0); + + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET6; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + hints.ai_flags = AI_ADDRCONFIG | AI_V4MAPPED; + + DNSRequest req = resolveName(hostname, NULL, &hints); + struct addrinfo * res = req.res; + if (req.error != 0) { + freeaddrinfo(res); + return false; + } + if (!res) { + freeaddrinfo(res); + return false; + } + Printer::elog(E_SOCKET, "Got a resolution answer"); + Assert(res->ai_family == AF_INET6); + Assert(res->ai_protocol == IPPROTO_TCP); + Assert(res->ai_addrlen == sizeof(sockaddr_in6)); + memcpy(&m_remoteAddr.sin6_addr, &((sockaddr_in6 *) res->ai_addr)->sin6_addr, sizeof(struct in6_addr)); + + m_remoteAddr.sin6_port = htons(port); + m_remoteAddr.sin6_family = AF_INET6; + + m_connecting = true; + + freeaddrinfo(res); + } else { + // if we end up there, it means our yield earlier thrown a EAgain exception. + Assert(m_evtR->gotSignal()); + } + + int spins = 0; + + do { + Printer::elog(E_SOCKET, "Connecting now..."); + if (::connect(m_fd, (sockaddr *) &m_remoteAddr, sizeof(m_remoteAddr)) == 0) { + m_connected = true; + m_connecting = false; + + socklen_t len; + + len = sizeof(m_localAddr); + getsockname(m_fd, (sockaddr *) &m_localAddr, &len); + + len = sizeof(m_remoteAddr); + getpeername(m_fd, (sockaddr *) &m_remoteAddr, &len); + + char prtLocal[INET6_ADDRSTRLEN], prtRemote[INET6_ADDRSTRLEN]; + const char * rLocal, * rRemote; + + len = sizeof(m_localAddr); + rLocal = inet_ntop(AF_INET6, &m_localAddr.sin6_addr, prtLocal, len); + rRemote = inet_ntop(AF_INET6, &m_remoteAddr.sin6_addr, prtRemote, len); + + Assert(rLocal); + Assert(rRemote); + + m_name.set("Socket(Connected - [%s]:%i -> [%s]:%i)", rLocal, htons(m_localAddr.sin6_port), rRemote, htons(m_remoteAddr.sin6_port)); + Printer::elog(E_SOCKET, "Connected; %s", m_name.to_charp()); + + m_evtW->stop(); + return true; + } + +#ifdef _WIN32 + if (WSAGetLastError() != WSAEWOULDBLOCK) { +#else + if (errno != EINPROGRESS) { +#endif + Printer::elog(E_SOCKET, "Connect() failed with the following error code: %i (%s)", errno, strerror(errno)); + return false; + } else { + Assert(spins == 0); + } + + Task::yield(m_evtW, true); + // if we're still here, it means the parent task doesn't want to be thrown an exception + Assert(m_evtW->gotSignal()); + + } while (spins++ < 2); + + return false; +} + +bool Balau::Socket::listen() { + Assert(!m_listening); + Assert(!m_connecting); + Assert(!m_connected); + + if (::listen(m_fd, 16) == 0) { + m_listening = true; + + socklen_t len; + + len = sizeof(m_localAddr); + getsockname(m_fd, (sockaddr *) &m_localAddr, &len); + + char prtLocal[INET6_ADDRSTRLEN]; + const char * rLocal; + + len = sizeof(m_localAddr); + rLocal = inet_ntop(AF_INET6, &m_localAddr.sin6_addr, prtLocal, len); + + Assert(rLocal); + + m_name.set("Socket(Listener - [%s]:%i)", rLocal, htons(m_localAddr.sin6_port)); + } + + return m_listening; +} + +Balau::IO<Balau::Socket> Balau::Socket::accept() throw (GeneralException) { + Assert(m_listening); + Assert(m_fd >= 0); + + while(true) { + sockaddr_in6 remoteAddr; + socklen_t len; + int s = ::accept(m_fd, (sockaddr *) &remoteAddr, &len); + + if (s < 0) { + if ((errno == EAGAIN) || (errno == EINTR) || (errno == EWOULDBLOCK)) { + Task::yield(m_evtR, true); + } else { + throw GeneralException(String("Unexpected error accepting a connection: #") + errno + "(" + strerror(errno) + ")"); + } + } else { + Printer::elog(E_SOCKET, "Listener at %p got a new connection", this); + m_evtR->stop(); + return IO<Socket>(new Socket(s)); + } + } +} + +ssize_t Balau::Socket::read(void * buf, size_t count) throw (GeneralException) { + if (count == 0) + return 0; + + Assert(m_connected); + Assert(m_fd >= 0); + + int spins = 0; + + do { + ssize_t r = ::recv(m_fd, (char *) buf, count, 0); + + if (r >= 0) { + if (r == 0) + close(); + return r; + } + + if ((errno == EAGAIN) || (errno == EINTR) || (errno == EWOULDBLOCK)) { + Task::yield(m_evtR, true); + } else { + m_evtR->stop(); + return r; + } + } while (spins++ < 2); + + return -1; +} + +ssize_t Balau::Socket::write(const void * buf, size_t count) throw (GeneralException) { + if (count == 0) + return 0; + + Assert(m_connected); + Assert(m_fd >= 0); + + int spins = 0; + + do { + ssize_t r = ::send(m_fd, (const char *) buf, count, 0); + + Assert(r != 0); + + if (r > 0) + return r; + + if ((errno == EAGAIN) || (errno == EINTR) || (errno == EWOULDBLOCK)) { + Task::yield(m_evtW, true); + } else { + m_evtW->stop(); + return r; + } + } while (spins++ < 2); + + return -1; +} |