diff options
-rw-r--r-- | Makefile | 2 | ||||
-rw-r--r-- | includes/Socket.h | 90 | ||||
-rw-r--r-- | src/Socket.cc | 400 | ||||
-rw-r--r-- | tests/test-Sockets.cc | 86 |
4 files changed, 578 insertions, 0 deletions
@@ -117,6 +117,7 @@ Printer.cc \ \ Handle.cc \ Input.cc \ +Socket.cc \ \ Task.cc \ TaskMan.cc \ @@ -152,6 +153,7 @@ test-String.cc \ test-Tasks.cc \ test-Threads.cc \ test-Handles.cc \ +test-Sockets.cc \ LIB = libBalau.a diff --git a/includes/Socket.h b/includes/Socket.h new file mode 100644 index 0000000..1f20fbb --- /dev/null +++ b/includes/Socket.h @@ -0,0 +1,90 @@ +#pragma once + +#include <netdb.h> +#include <Handle.h> +#include <Task.h> +#include <Printer.h> + +namespace Balau { + +class Socket : public Handle { + public: + + Socket() throw (GeneralException); + virtual void close() throw (GeneralException); + virtual ssize_t read(void * buf, size_t count) throw (GeneralException); + virtual ssize_t write(const void * buf, size_t count) throw (GeneralException); + virtual bool isClosed(); + virtual bool isEOF(); + virtual bool canRead(); + virtual bool canWrite(); + virtual const char * getName(); + + bool setLocal(const char * hostname = NULL, int port = 0); + bool connect(const char * hostname, int port); + IO<Socket> accept() throw (GeneralException); + bool listen(); + private: + Socket(int fd); + class SocketEvent : public Events::BaseEvent { + public: + SocketEvent(int fd, int evt = EV_READ | EV_WRITE) : m_task(NULL) { Printer::elog(E_SOCKET, "Got a new SocketEvent at %p", this); m_evt.set<SocketEvent, &SocketEvent::evt_cb>(this); m_evt.set(fd, evt); } + virtual ~SocketEvent() { m_evt.stop(); } + void stop() { reset(); m_evt.stop(); } + private: + void evt_cb(ev::io & w, int revents) { Printer::elog(E_SOCKET, "Got a libev callback on a SocketEvent at %p", this); doSignal(); } + virtual void gotOwner(Task * task); + + ev::io m_evt; + Task * m_task; + }; + + int m_fd; + String m_name; + bool m_connected; + bool m_connecting; + bool m_listening; + sockaddr_in6 m_localAddr, m_remoteAddr; + SocketEvent * m_evtR, * m_evtW; +}; + +template<class Worker> +class Listener : public Task { + public: + Listener(int port, const char * local = NULL) : m_stop(false) { + m_listener.setLocal(local, port); + m_listener.listen(); + m_name = String(ClassName(this).c_str()) + " - " + m_listener.getName(); + Printer::elog(E_SOCKET, "Created a listener task at %p", this); + } + virtual void Do() { + waitFor(&m_evt); + setOkayToEAgain(true); + while (!m_stop) { + IO<Socket> io; + try { + io = m_listener.accept(); + } + catch (EAgain) { + Printer::elog(E_SOCKET, "Listener task at %p (%s) got an EAgain - stop = %s", this, ClassName(this).c_str(), m_stop ? "true" : "false"); + if (!m_stop) + yield(); + continue; + } + new Worker(io); + } + } + void stop() { + Printer::elog(E_SOCKET, "Listener task at %p (%s) is asked to stop.", this, ClassName(this).c_str()); + m_stop = true; + m_evt.trigger(); + } + virtual const char * getName() { return m_name.to_charp(); } + private: + Socket m_listener; + Events::Async m_evt; + volatile bool m_stop; + String m_name; +}; + +}; diff --git a/src/Socket.cc b/src/Socket.cc new file mode 100644 index 0000000..762aae3 --- /dev/null +++ b/src/Socket.cc @@ -0,0 +1,400 @@ +#include <arpa/inet.h> +#include <sys/types.h> +#include <sys/socket.h> +#include <netdb.h> +#include <unistd.h> +#include <fcntl.h> +#include <errno.h> +#include "Socket.h" +#include "Threads.h" +#include "Printer.h" +#include "Main.h" + +void Balau::Socket::SocketEvent::gotOwner(Task * task) { + Printer::elog(E_SOCKET, "Arming SocketEvent at %p", this); + if (!m_task) { + Printer::elog(E_SOCKET, "...with a new task (%p)", task); + } else if (task == m_task) { + m_evt.start(); + return; + } else { + Printer::elog(E_SOCKET, "...with a new task (%p -> %p); stopping first", m_task, task); + m_evt.stop(); + } + m_task = task; + m_evt.set(task->getLoop()); + m_evt.start(); +} + +struct DNSRequest { + const char * name; + const char * service; + struct addrinfo * res; + struct addrinfo * hints; + Balau::Events::Async * evt; + int error; +}; + +#if 0 +// TODO: use getaddrinfo_a, if available. +#else +class ResolverThread : public Balau::Thread, public Balau::AtStart { + public: + ResolverThread() : AtStart(8) { } + virtual ~ResolverThread(); + void pushRequest(DNSRequest * req) { m_queue.push(req); } + private: + virtual void * proc(); + virtual void doStart(); + Balau::Queue<DNSRequest *> m_queue; +}; + +void ResolverThread::doStart() { + threadStart(); +} + +ResolverThread::~ResolverThread() { + DNSRequest req; + memset(&req, 0, sizeof(req)); + pushRequest(&req); +} + +void * ResolverThread::proc() { + DNSRequest * req; + DNSRequest stop; + memset(&stop, 0, sizeof(stop)); + while (true) { + req = m_queue.pop(); + if (memcmp(&stop, req, sizeof(stop)) == 0) + break; + Balau::Printer::elog(Balau::E_SOCKET, "Resolver thread got a request for `%s'", req->name); + req->error = getaddrinfo(req->name, req->service, req->hints, &req->res); + Balau::Printer::elog(Balau::E_SOCKET, "Resolver thread got an answer; sending signal"); + req->evt->trigger(); + } + return NULL; +} +#endif + +static ResolverThread resolverThread; + +static DNSRequest resolveName(const char * name, const char * service = NULL, struct addrinfo * hints = NULL) { + Balau::Events::Async evt; + DNSRequest req; + memset(&req, 0, sizeof(req)); + + req.name = name; + req.service = service; + req.hints = hints; + req.evt = &evt; + Balau::Printer::elog(Balau::E_SOCKET, "Sending a request to the resolver thread"); + Balau::Task::prepare(&evt); + resolverThread.pushRequest(&req); + Balau::Task::yield(&evt); + + return req; +} + +Balau::Socket::Socket() throw (GeneralException) : m_fd(socket(AF_INET6, SOCK_STREAM, 0)), m_connected(false), m_connecting(false), m_listening(false) { + m_name = "Socket(unconnected)"; + Assert(m_fd >= 0); + m_evtR = new SocketEvent(m_fd, EV_READ); + m_evtW = new SocketEvent(m_fd, EV_WRITE); + fcntl(m_fd, F_SETFL, O_NONBLOCK); + memset(&m_localAddr, 0, sizeof(m_localAddr)); + memset(&m_remoteAddr, 0, sizeof(m_remoteAddr)); + Printer::elog(E_SOCKET, "Creating a socket at %p", this); +} + +Balau::Socket::Socket(int fd) : m_fd(fd), m_connected(true), m_connecting(false), m_listening(false) { + socklen_t len; + + len = sizeof(m_localAddr); + getsockname(m_fd, (sockaddr *) &m_localAddr, &len); + + len = sizeof(m_remoteAddr); + getpeername(m_fd, (sockaddr *) &m_remoteAddr, &len); + + char prtLocal[INET6_ADDRSTRLEN], prtRemote[INET6_ADDRSTRLEN]; + const char * rLocal, * rRemote; + + len = sizeof(m_localAddr); + rLocal = inet_ntop(AF_INET6, &m_localAddr.sin6_addr, prtLocal, len); + rRemote = inet_ntop(AF_INET6, &m_remoteAddr.sin6_addr, prtRemote, len); + + Assert(rLocal); + Assert(rRemote); + + m_evtR = new SocketEvent(m_fd, EV_READ); + m_evtW = new SocketEvent(m_fd, EV_WRITE); + fcntl(m_fd, F_SETFL, O_NONBLOCK); + + m_name.set("Socket(Connected - [%s]:%i <- [%s]:%i)", rLocal, htons(m_localAddr.sin6_port), rRemote, htons(m_remoteAddr.sin6_port)); + Printer::elog(E_SOCKET, "Created a new socket from listener at %p; %s", this, m_name.to_charp()); +} + +void Balau::Socket::close() throw (GeneralException) { +#ifdef _WIN32 + closesocket(m_fd); +#else + ::close(m_fd); +#endif + Printer::elog(E_SOCKET, "Closing socket at %p", this); + m_connected = false; + m_connecting = false; + m_listening = false; + m_fd = -1; + delete m_evtR; + delete m_evtW; + m_evtR = m_evtW = NULL; +} + +bool Balau::Socket::isClosed() { return m_fd < 0; } +bool Balau::Socket::isEOF() { return isClosed(); } +bool Balau::Socket::canRead() { return true; } +bool Balau::Socket::canWrite() { return true; } +const char * Balau::Socket::getName() { return m_name.to_charp(); } + +bool Balau::Socket::setLocal(const char * hostname, int port) { + Assert(m_localAddr.sin6_family == 0); + + if (hostname && hostname[0]) { + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET6; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + hints.ai_flags = AI_ADDRCONFIG | AI_V4MAPPED; + + DNSRequest req = resolveName(hostname, NULL, &hints); + struct addrinfo * res = req.res; + if (req.error != 0) { + freeaddrinfo(res); + return false; + } + if (!res) { + freeaddrinfo(res); + return false; + } + Assert(res->ai_family == AF_INET6); + Assert(res->ai_protocol == IPPROTO_TCP); + Assert(res->ai_addrlen == sizeof(sockaddr_in6)); + memcpy(&m_localAddr.sin6_addr, &((sockaddr_in6 *) res->ai_addr)->sin6_addr, sizeof(struct in6_addr)); + freeaddrinfo(res); + } else { + m_localAddr.sin6_addr = in6addr_any; + } + + if (port) + m_localAddr.sin6_port = htons(port); + + m_localAddr.sin6_family = AF_INET6; +#ifndef _WIN32 + int enable = 1; + setsockopt(m_fd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)); +#endif + return bind(m_fd, (struct sockaddr *) &m_localAddr, sizeof(m_localAddr)) == 0; +} + +bool Balau::Socket::connect(const char * hostname, int port) { + Assert(!m_listening); + Assert(!m_connected); + Assert(hostname); + + if (!m_connecting) { + Printer::elog(E_SOCKET, "Resolving %s", hostname); + Assert(m_remoteAddr.sin6_family == 0); + + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET6; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + hints.ai_flags = AI_ADDRCONFIG | AI_V4MAPPED; + + DNSRequest req = resolveName(hostname, NULL, &hints); + struct addrinfo * res = req.res; + if (req.error != 0) { + freeaddrinfo(res); + return false; + } + if (!res) { + freeaddrinfo(res); + return false; + } + Printer::elog(E_SOCKET, "Got a resolution answer"); + Assert(res->ai_family == AF_INET6); + Assert(res->ai_protocol == IPPROTO_TCP); + Assert(res->ai_addrlen == sizeof(sockaddr_in6)); + memcpy(&m_remoteAddr.sin6_addr, &((sockaddr_in6 *) res->ai_addr)->sin6_addr, sizeof(struct in6_addr)); + + m_remoteAddr.sin6_port = htons(port); + m_remoteAddr.sin6_family = AF_INET6; + + m_connecting = true; + + freeaddrinfo(res); + } else { + // if we end up there, it means our yield earlier thrown a EAgain exception. + Assert(m_evtR->gotSignal()); + } + + int spins = 0; + + do { + Printer::elog(E_SOCKET, "Connecting now..."); + if (::connect(m_fd, (sockaddr *) &m_remoteAddr, sizeof(m_remoteAddr)) == 0) { + m_connected = true; + m_connecting = false; + + socklen_t len; + + len = sizeof(m_localAddr); + getsockname(m_fd, (sockaddr *) &m_localAddr, &len); + + len = sizeof(m_remoteAddr); + getpeername(m_fd, (sockaddr *) &m_remoteAddr, &len); + + char prtLocal[INET6_ADDRSTRLEN], prtRemote[INET6_ADDRSTRLEN]; + const char * rLocal, * rRemote; + + len = sizeof(m_localAddr); + rLocal = inet_ntop(AF_INET6, &m_localAddr.sin6_addr, prtLocal, len); + rRemote = inet_ntop(AF_INET6, &m_remoteAddr.sin6_addr, prtRemote, len); + + Assert(rLocal); + Assert(rRemote); + + m_name.set("Socket(Connected - [%s]:%i -> [%s]:%i)", rLocal, htons(m_localAddr.sin6_port), rRemote, htons(m_remoteAddr.sin6_port)); + Printer::elog(E_SOCKET, "Connected; %s", m_name.to_charp()); + + m_evtW->stop(); + return true; + } + +#ifdef _WIN32 + if (WSAGetLastError() != WSAEWOULDBLOCK) { +#else + if (errno != EINPROGRESS) { +#endif + Printer::elog(E_SOCKET, "Connect() failed with the following error code: %i (%s)", errno, strerror(errno)); + return false; + } else { + Assert(spins == 0); + } + + Task::yield(m_evtW, true); + // if we're still here, it means the parent task doesn't want to be thrown an exception + Assert(m_evtW->gotSignal()); + + } while (spins++ < 2); + + return false; +} + +bool Balau::Socket::listen() { + Assert(!m_listening); + Assert(!m_connecting); + Assert(!m_connected); + + if (::listen(m_fd, 16) == 0) { + m_listening = true; + + socklen_t len; + + len = sizeof(m_localAddr); + getsockname(m_fd, (sockaddr *) &m_localAddr, &len); + + char prtLocal[INET6_ADDRSTRLEN]; + const char * rLocal; + + len = sizeof(m_localAddr); + rLocal = inet_ntop(AF_INET6, &m_localAddr.sin6_addr, prtLocal, len); + + Assert(rLocal); + + m_name.set("Socket(Listener - [%s]:%i)", rLocal, htons(m_localAddr.sin6_port)); + } + + return m_listening; +} + +Balau::IO<Balau::Socket> Balau::Socket::accept() throw (GeneralException) { + Assert(m_listening); + Assert(m_fd >= 0); + + while(true) { + sockaddr_in6 remoteAddr; + socklen_t len; + int s = ::accept(m_fd, (sockaddr *) &remoteAddr, &len); + + if (s < 0) { + if ((errno == EAGAIN) || (errno == EINTR) || (errno == EWOULDBLOCK)) { + Task::yield(m_evtR, true); + } else { + throw GeneralException(String("Unexpected error accepting a connection: #") + errno + "(" + strerror(errno) + ")"); + } + } else { + Printer::elog(E_SOCKET, "Listener at %p got a new connection", this); + m_evtR->stop(); + return IO<Socket>(new Socket(s)); + } + } +} + +ssize_t Balau::Socket::read(void * buf, size_t count) throw (GeneralException) { + if (count == 0) + return 0; + + Assert(m_connected); + Assert(m_fd >= 0); + + int spins = 0; + + do { + ssize_t r = ::recv(m_fd, (char *) buf, count, 0); + + if (r >= 0) { + if (r == 0) + close(); + return r; + } + + if ((errno == EAGAIN) || (errno == EINTR) || (errno == EWOULDBLOCK)) { + Task::yield(m_evtR, true); + } else { + m_evtR->stop(); + return r; + } + } while (spins++ < 2); + + return -1; +} + +ssize_t Balau::Socket::write(const void * buf, size_t count) throw (GeneralException) { + if (count == 0) + return 0; + + Assert(m_connected); + Assert(m_fd >= 0); + + int spins = 0; + + do { + ssize_t r = ::send(m_fd, (const char *) buf, count, 0); + + Assert(r != 0); + + if (r > 0) + return r; + + if ((errno == EAGAIN) || (errno == EINTR) || (errno == EWOULDBLOCK)) { + Task::yield(m_evtW, true); + } else { + m_evtW->stop(); + return r; + } + } while (spins++ < 2); + + return -1; +} diff --git a/tests/test-Sockets.cc b/tests/test-Sockets.cc new file mode 100644 index 0000000..fd6eaa0 --- /dev/null +++ b/tests/test-Sockets.cc @@ -0,0 +1,86 @@ +#include <Main.h> +#include <Socket.h> + +BALAU_STARTUP; + +using namespace Balau; + +class Worker : public Task { + public: + Worker(IO<Socket> io); + virtual const char * getName(); + virtual void Do(); + IO<Socket> m_io; + String m_name; +}; + +Worker::Worker(IO<Socket> io) : m_io(io) { + m_name = m_io->getName(); + Printer::log(M_STATUS, "Got connection: %s", m_name.to_charp()); +} + +const char * Worker::getName() { + return m_name.to_charp(); +} + +void Worker::Do() { + char x, y; + + int r; + r = m_io->read(&x, 1); + Assert(x == 'x'); + Assert(r == 1); + y = 'y'; + r = m_io->write(&y, 1); + Assert(r == 1); +} + +Listener<Worker> * listener; + +class Client : public Task { + public: + virtual const char * getName() { return "Test client"; } + virtual void Do() { + Events::Timeout evt(0.1); + waitFor(&evt); + yield(); + + char x, y; + IO<Socket> s(new Socket()); + bool c = s->connect("localhost", 1234); + Assert(c); + x = 'x'; + int r; + r = s->write(&x, 1); + Assert(r == 1); + r = s->read(&y, 1); + Assert(y == 'y'); + Assert(r == 1); + listener->stop(); + } +}; + +void MainTask::Do() { + Printer::enable(M_ALL); + Printer::log(M_STATUS, "Test::Sockets running."); + + Events::TaskEvent evtSvr(listener = new Listener<Worker>(1234)); + Events::TaskEvent evtCln(new Client); + Printer::log(M_STATUS, "Created %s", listener->getName()); + waitFor(&evtSvr); + waitFor(&evtCln); + bool svrDone = false, clnDone = false; + while (!svrDone || !clnDone) { + yield(); + if (evtSvr.gotSignal()) { + evtSvr.ack(); + svrDone = true; + } + if (evtCln.gotSignal()) { + evtCln.ack(); + clnDone = true; + } + } + + Printer::log(M_STATUS, "Test::Sockets passed."); +} |