summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile2
-rw-r--r--includes/Socket.h90
-rw-r--r--src/Socket.cc400
-rw-r--r--tests/test-Sockets.cc86
4 files changed, 578 insertions, 0 deletions
diff --git a/Makefile b/Makefile
index be4afec..498f331 100644
--- a/Makefile
+++ b/Makefile
@@ -117,6 +117,7 @@ Printer.cc \
\
Handle.cc \
Input.cc \
+Socket.cc \
\
Task.cc \
TaskMan.cc \
@@ -152,6 +153,7 @@ test-String.cc \
test-Tasks.cc \
test-Threads.cc \
test-Handles.cc \
+test-Sockets.cc \
LIB = libBalau.a
diff --git a/includes/Socket.h b/includes/Socket.h
new file mode 100644
index 0000000..1f20fbb
--- /dev/null
+++ b/includes/Socket.h
@@ -0,0 +1,90 @@
+#pragma once
+
+#include <netdb.h>
+#include <Handle.h>
+#include <Task.h>
+#include <Printer.h>
+
+namespace Balau {
+
+class Socket : public Handle {
+ public:
+
+ Socket() throw (GeneralException);
+ virtual void close() throw (GeneralException);
+ virtual ssize_t read(void * buf, size_t count) throw (GeneralException);
+ virtual ssize_t write(const void * buf, size_t count) throw (GeneralException);
+ virtual bool isClosed();
+ virtual bool isEOF();
+ virtual bool canRead();
+ virtual bool canWrite();
+ virtual const char * getName();
+
+ bool setLocal(const char * hostname = NULL, int port = 0);
+ bool connect(const char * hostname, int port);
+ IO<Socket> accept() throw (GeneralException);
+ bool listen();
+ private:
+ Socket(int fd);
+ class SocketEvent : public Events::BaseEvent {
+ public:
+ SocketEvent(int fd, int evt = EV_READ | EV_WRITE) : m_task(NULL) { Printer::elog(E_SOCKET, "Got a new SocketEvent at %p", this); m_evt.set<SocketEvent, &SocketEvent::evt_cb>(this); m_evt.set(fd, evt); }
+ virtual ~SocketEvent() { m_evt.stop(); }
+ void stop() { reset(); m_evt.stop(); }
+ private:
+ void evt_cb(ev::io & w, int revents) { Printer::elog(E_SOCKET, "Got a libev callback on a SocketEvent at %p", this); doSignal(); }
+ virtual void gotOwner(Task * task);
+
+ ev::io m_evt;
+ Task * m_task;
+ };
+
+ int m_fd;
+ String m_name;
+ bool m_connected;
+ bool m_connecting;
+ bool m_listening;
+ sockaddr_in6 m_localAddr, m_remoteAddr;
+ SocketEvent * m_evtR, * m_evtW;
+};
+
+template<class Worker>
+class Listener : public Task {
+ public:
+ Listener(int port, const char * local = NULL) : m_stop(false) {
+ m_listener.setLocal(local, port);
+ m_listener.listen();
+ m_name = String(ClassName(this).c_str()) + " - " + m_listener.getName();
+ Printer::elog(E_SOCKET, "Created a listener task at %p", this);
+ }
+ virtual void Do() {
+ waitFor(&m_evt);
+ setOkayToEAgain(true);
+ while (!m_stop) {
+ IO<Socket> io;
+ try {
+ io = m_listener.accept();
+ }
+ catch (EAgain) {
+ Printer::elog(E_SOCKET, "Listener task at %p (%s) got an EAgain - stop = %s", this, ClassName(this).c_str(), m_stop ? "true" : "false");
+ if (!m_stop)
+ yield();
+ continue;
+ }
+ new Worker(io);
+ }
+ }
+ void stop() {
+ Printer::elog(E_SOCKET, "Listener task at %p (%s) is asked to stop.", this, ClassName(this).c_str());
+ m_stop = true;
+ m_evt.trigger();
+ }
+ virtual const char * getName() { return m_name.to_charp(); }
+ private:
+ Socket m_listener;
+ Events::Async m_evt;
+ volatile bool m_stop;
+ String m_name;
+};
+
+};
diff --git a/src/Socket.cc b/src/Socket.cc
new file mode 100644
index 0000000..762aae3
--- /dev/null
+++ b/src/Socket.cc
@@ -0,0 +1,400 @@
+#include <arpa/inet.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netdb.h>
+#include <unistd.h>
+#include <fcntl.h>
+#include <errno.h>
+#include "Socket.h"
+#include "Threads.h"
+#include "Printer.h"
+#include "Main.h"
+
+void Balau::Socket::SocketEvent::gotOwner(Task * task) {
+ Printer::elog(E_SOCKET, "Arming SocketEvent at %p", this);
+ if (!m_task) {
+ Printer::elog(E_SOCKET, "...with a new task (%p)", task);
+ } else if (task == m_task) {
+ m_evt.start();
+ return;
+ } else {
+ Printer::elog(E_SOCKET, "...with a new task (%p -> %p); stopping first", m_task, task);
+ m_evt.stop();
+ }
+ m_task = task;
+ m_evt.set(task->getLoop());
+ m_evt.start();
+}
+
+struct DNSRequest {
+ const char * name;
+ const char * service;
+ struct addrinfo * res;
+ struct addrinfo * hints;
+ Balau::Events::Async * evt;
+ int error;
+};
+
+#if 0
+// TODO: use getaddrinfo_a, if available.
+#else
+class ResolverThread : public Balau::Thread, public Balau::AtStart {
+ public:
+ ResolverThread() : AtStart(8) { }
+ virtual ~ResolverThread();
+ void pushRequest(DNSRequest * req) { m_queue.push(req); }
+ private:
+ virtual void * proc();
+ virtual void doStart();
+ Balau::Queue<DNSRequest *> m_queue;
+};
+
+void ResolverThread::doStart() {
+ threadStart();
+}
+
+ResolverThread::~ResolverThread() {
+ DNSRequest req;
+ memset(&req, 0, sizeof(req));
+ pushRequest(&req);
+}
+
+void * ResolverThread::proc() {
+ DNSRequest * req;
+ DNSRequest stop;
+ memset(&stop, 0, sizeof(stop));
+ while (true) {
+ req = m_queue.pop();
+ if (memcmp(&stop, req, sizeof(stop)) == 0)
+ break;
+ Balau::Printer::elog(Balau::E_SOCKET, "Resolver thread got a request for `%s'", req->name);
+ req->error = getaddrinfo(req->name, req->service, req->hints, &req->res);
+ Balau::Printer::elog(Balau::E_SOCKET, "Resolver thread got an answer; sending signal");
+ req->evt->trigger();
+ }
+ return NULL;
+}
+#endif
+
+static ResolverThread resolverThread;
+
+static DNSRequest resolveName(const char * name, const char * service = NULL, struct addrinfo * hints = NULL) {
+ Balau::Events::Async evt;
+ DNSRequest req;
+ memset(&req, 0, sizeof(req));
+
+ req.name = name;
+ req.service = service;
+ req.hints = hints;
+ req.evt = &evt;
+ Balau::Printer::elog(Balau::E_SOCKET, "Sending a request to the resolver thread");
+ Balau::Task::prepare(&evt);
+ resolverThread.pushRequest(&req);
+ Balau::Task::yield(&evt);
+
+ return req;
+}
+
+Balau::Socket::Socket() throw (GeneralException) : m_fd(socket(AF_INET6, SOCK_STREAM, 0)), m_connected(false), m_connecting(false), m_listening(false) {
+ m_name = "Socket(unconnected)";
+ Assert(m_fd >= 0);
+ m_evtR = new SocketEvent(m_fd, EV_READ);
+ m_evtW = new SocketEvent(m_fd, EV_WRITE);
+ fcntl(m_fd, F_SETFL, O_NONBLOCK);
+ memset(&m_localAddr, 0, sizeof(m_localAddr));
+ memset(&m_remoteAddr, 0, sizeof(m_remoteAddr));
+ Printer::elog(E_SOCKET, "Creating a socket at %p", this);
+}
+
+Balau::Socket::Socket(int fd) : m_fd(fd), m_connected(true), m_connecting(false), m_listening(false) {
+ socklen_t len;
+
+ len = sizeof(m_localAddr);
+ getsockname(m_fd, (sockaddr *) &m_localAddr, &len);
+
+ len = sizeof(m_remoteAddr);
+ getpeername(m_fd, (sockaddr *) &m_remoteAddr, &len);
+
+ char prtLocal[INET6_ADDRSTRLEN], prtRemote[INET6_ADDRSTRLEN];
+ const char * rLocal, * rRemote;
+
+ len = sizeof(m_localAddr);
+ rLocal = inet_ntop(AF_INET6, &m_localAddr.sin6_addr, prtLocal, len);
+ rRemote = inet_ntop(AF_INET6, &m_remoteAddr.sin6_addr, prtRemote, len);
+
+ Assert(rLocal);
+ Assert(rRemote);
+
+ m_evtR = new SocketEvent(m_fd, EV_READ);
+ m_evtW = new SocketEvent(m_fd, EV_WRITE);
+ fcntl(m_fd, F_SETFL, O_NONBLOCK);
+
+ m_name.set("Socket(Connected - [%s]:%i <- [%s]:%i)", rLocal, htons(m_localAddr.sin6_port), rRemote, htons(m_remoteAddr.sin6_port));
+ Printer::elog(E_SOCKET, "Created a new socket from listener at %p; %s", this, m_name.to_charp());
+}
+
+void Balau::Socket::close() throw (GeneralException) {
+#ifdef _WIN32
+ closesocket(m_fd);
+#else
+ ::close(m_fd);
+#endif
+ Printer::elog(E_SOCKET, "Closing socket at %p", this);
+ m_connected = false;
+ m_connecting = false;
+ m_listening = false;
+ m_fd = -1;
+ delete m_evtR;
+ delete m_evtW;
+ m_evtR = m_evtW = NULL;
+}
+
+bool Balau::Socket::isClosed() { return m_fd < 0; }
+bool Balau::Socket::isEOF() { return isClosed(); }
+bool Balau::Socket::canRead() { return true; }
+bool Balau::Socket::canWrite() { return true; }
+const char * Balau::Socket::getName() { return m_name.to_charp(); }
+
+bool Balau::Socket::setLocal(const char * hostname, int port) {
+ Assert(m_localAddr.sin6_family == 0);
+
+ if (hostname && hostname[0]) {
+ struct addrinfo hints;
+ memset(&hints, 0, sizeof(hints));
+ hints.ai_family = AF_INET6;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_protocol = IPPROTO_TCP;
+ hints.ai_flags = AI_ADDRCONFIG | AI_V4MAPPED;
+
+ DNSRequest req = resolveName(hostname, NULL, &hints);
+ struct addrinfo * res = req.res;
+ if (req.error != 0) {
+ freeaddrinfo(res);
+ return false;
+ }
+ if (!res) {
+ freeaddrinfo(res);
+ return false;
+ }
+ Assert(res->ai_family == AF_INET6);
+ Assert(res->ai_protocol == IPPROTO_TCP);
+ Assert(res->ai_addrlen == sizeof(sockaddr_in6));
+ memcpy(&m_localAddr.sin6_addr, &((sockaddr_in6 *) res->ai_addr)->sin6_addr, sizeof(struct in6_addr));
+ freeaddrinfo(res);
+ } else {
+ m_localAddr.sin6_addr = in6addr_any;
+ }
+
+ if (port)
+ m_localAddr.sin6_port = htons(port);
+
+ m_localAddr.sin6_family = AF_INET6;
+#ifndef _WIN32
+ int enable = 1;
+ setsockopt(m_fd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable));
+#endif
+ return bind(m_fd, (struct sockaddr *) &m_localAddr, sizeof(m_localAddr)) == 0;
+}
+
+bool Balau::Socket::connect(const char * hostname, int port) {
+ Assert(!m_listening);
+ Assert(!m_connected);
+ Assert(hostname);
+
+ if (!m_connecting) {
+ Printer::elog(E_SOCKET, "Resolving %s", hostname);
+ Assert(m_remoteAddr.sin6_family == 0);
+
+ struct addrinfo hints;
+ memset(&hints, 0, sizeof(hints));
+ hints.ai_family = AF_INET6;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_protocol = IPPROTO_TCP;
+ hints.ai_flags = AI_ADDRCONFIG | AI_V4MAPPED;
+
+ DNSRequest req = resolveName(hostname, NULL, &hints);
+ struct addrinfo * res = req.res;
+ if (req.error != 0) {
+ freeaddrinfo(res);
+ return false;
+ }
+ if (!res) {
+ freeaddrinfo(res);
+ return false;
+ }
+ Printer::elog(E_SOCKET, "Got a resolution answer");
+ Assert(res->ai_family == AF_INET6);
+ Assert(res->ai_protocol == IPPROTO_TCP);
+ Assert(res->ai_addrlen == sizeof(sockaddr_in6));
+ memcpy(&m_remoteAddr.sin6_addr, &((sockaddr_in6 *) res->ai_addr)->sin6_addr, sizeof(struct in6_addr));
+
+ m_remoteAddr.sin6_port = htons(port);
+ m_remoteAddr.sin6_family = AF_INET6;
+
+ m_connecting = true;
+
+ freeaddrinfo(res);
+ } else {
+ // if we end up there, it means our yield earlier thrown a EAgain exception.
+ Assert(m_evtR->gotSignal());
+ }
+
+ int spins = 0;
+
+ do {
+ Printer::elog(E_SOCKET, "Connecting now...");
+ if (::connect(m_fd, (sockaddr *) &m_remoteAddr, sizeof(m_remoteAddr)) == 0) {
+ m_connected = true;
+ m_connecting = false;
+
+ socklen_t len;
+
+ len = sizeof(m_localAddr);
+ getsockname(m_fd, (sockaddr *) &m_localAddr, &len);
+
+ len = sizeof(m_remoteAddr);
+ getpeername(m_fd, (sockaddr *) &m_remoteAddr, &len);
+
+ char prtLocal[INET6_ADDRSTRLEN], prtRemote[INET6_ADDRSTRLEN];
+ const char * rLocal, * rRemote;
+
+ len = sizeof(m_localAddr);
+ rLocal = inet_ntop(AF_INET6, &m_localAddr.sin6_addr, prtLocal, len);
+ rRemote = inet_ntop(AF_INET6, &m_remoteAddr.sin6_addr, prtRemote, len);
+
+ Assert(rLocal);
+ Assert(rRemote);
+
+ m_name.set("Socket(Connected - [%s]:%i -> [%s]:%i)", rLocal, htons(m_localAddr.sin6_port), rRemote, htons(m_remoteAddr.sin6_port));
+ Printer::elog(E_SOCKET, "Connected; %s", m_name.to_charp());
+
+ m_evtW->stop();
+ return true;
+ }
+
+#ifdef _WIN32
+ if (WSAGetLastError() != WSAEWOULDBLOCK) {
+#else
+ if (errno != EINPROGRESS) {
+#endif
+ Printer::elog(E_SOCKET, "Connect() failed with the following error code: %i (%s)", errno, strerror(errno));
+ return false;
+ } else {
+ Assert(spins == 0);
+ }
+
+ Task::yield(m_evtW, true);
+ // if we're still here, it means the parent task doesn't want to be thrown an exception
+ Assert(m_evtW->gotSignal());
+
+ } while (spins++ < 2);
+
+ return false;
+}
+
+bool Balau::Socket::listen() {
+ Assert(!m_listening);
+ Assert(!m_connecting);
+ Assert(!m_connected);
+
+ if (::listen(m_fd, 16) == 0) {
+ m_listening = true;
+
+ socklen_t len;
+
+ len = sizeof(m_localAddr);
+ getsockname(m_fd, (sockaddr *) &m_localAddr, &len);
+
+ char prtLocal[INET6_ADDRSTRLEN];
+ const char * rLocal;
+
+ len = sizeof(m_localAddr);
+ rLocal = inet_ntop(AF_INET6, &m_localAddr.sin6_addr, prtLocal, len);
+
+ Assert(rLocal);
+
+ m_name.set("Socket(Listener - [%s]:%i)", rLocal, htons(m_localAddr.sin6_port));
+ }
+
+ return m_listening;
+}
+
+Balau::IO<Balau::Socket> Balau::Socket::accept() throw (GeneralException) {
+ Assert(m_listening);
+ Assert(m_fd >= 0);
+
+ while(true) {
+ sockaddr_in6 remoteAddr;
+ socklen_t len;
+ int s = ::accept(m_fd, (sockaddr *) &remoteAddr, &len);
+
+ if (s < 0) {
+ if ((errno == EAGAIN) || (errno == EINTR) || (errno == EWOULDBLOCK)) {
+ Task::yield(m_evtR, true);
+ } else {
+ throw GeneralException(String("Unexpected error accepting a connection: #") + errno + "(" + strerror(errno) + ")");
+ }
+ } else {
+ Printer::elog(E_SOCKET, "Listener at %p got a new connection", this);
+ m_evtR->stop();
+ return IO<Socket>(new Socket(s));
+ }
+ }
+}
+
+ssize_t Balau::Socket::read(void * buf, size_t count) throw (GeneralException) {
+ if (count == 0)
+ return 0;
+
+ Assert(m_connected);
+ Assert(m_fd >= 0);
+
+ int spins = 0;
+
+ do {
+ ssize_t r = ::recv(m_fd, (char *) buf, count, 0);
+
+ if (r >= 0) {
+ if (r == 0)
+ close();
+ return r;
+ }
+
+ if ((errno == EAGAIN) || (errno == EINTR) || (errno == EWOULDBLOCK)) {
+ Task::yield(m_evtR, true);
+ } else {
+ m_evtR->stop();
+ return r;
+ }
+ } while (spins++ < 2);
+
+ return -1;
+}
+
+ssize_t Balau::Socket::write(const void * buf, size_t count) throw (GeneralException) {
+ if (count == 0)
+ return 0;
+
+ Assert(m_connected);
+ Assert(m_fd >= 0);
+
+ int spins = 0;
+
+ do {
+ ssize_t r = ::send(m_fd, (const char *) buf, count, 0);
+
+ Assert(r != 0);
+
+ if (r > 0)
+ return r;
+
+ if ((errno == EAGAIN) || (errno == EINTR) || (errno == EWOULDBLOCK)) {
+ Task::yield(m_evtW, true);
+ } else {
+ m_evtW->stop();
+ return r;
+ }
+ } while (spins++ < 2);
+
+ return -1;
+}
diff --git a/tests/test-Sockets.cc b/tests/test-Sockets.cc
new file mode 100644
index 0000000..fd6eaa0
--- /dev/null
+++ b/tests/test-Sockets.cc
@@ -0,0 +1,86 @@
+#include <Main.h>
+#include <Socket.h>
+
+BALAU_STARTUP;
+
+using namespace Balau;
+
+class Worker : public Task {
+ public:
+ Worker(IO<Socket> io);
+ virtual const char * getName();
+ virtual void Do();
+ IO<Socket> m_io;
+ String m_name;
+};
+
+Worker::Worker(IO<Socket> io) : m_io(io) {
+ m_name = m_io->getName();
+ Printer::log(M_STATUS, "Got connection: %s", m_name.to_charp());
+}
+
+const char * Worker::getName() {
+ return m_name.to_charp();
+}
+
+void Worker::Do() {
+ char x, y;
+
+ int r;
+ r = m_io->read(&x, 1);
+ Assert(x == 'x');
+ Assert(r == 1);
+ y = 'y';
+ r = m_io->write(&y, 1);
+ Assert(r == 1);
+}
+
+Listener<Worker> * listener;
+
+class Client : public Task {
+ public:
+ virtual const char * getName() { return "Test client"; }
+ virtual void Do() {
+ Events::Timeout evt(0.1);
+ waitFor(&evt);
+ yield();
+
+ char x, y;
+ IO<Socket> s(new Socket());
+ bool c = s->connect("localhost", 1234);
+ Assert(c);
+ x = 'x';
+ int r;
+ r = s->write(&x, 1);
+ Assert(r == 1);
+ r = s->read(&y, 1);
+ Assert(y == 'y');
+ Assert(r == 1);
+ listener->stop();
+ }
+};
+
+void MainTask::Do() {
+ Printer::enable(M_ALL);
+ Printer::log(M_STATUS, "Test::Sockets running.");
+
+ Events::TaskEvent evtSvr(listener = new Listener<Worker>(1234));
+ Events::TaskEvent evtCln(new Client);
+ Printer::log(M_STATUS, "Created %s", listener->getName());
+ waitFor(&evtSvr);
+ waitFor(&evtCln);
+ bool svrDone = false, clnDone = false;
+ while (!svrDone || !clnDone) {
+ yield();
+ if (evtSvr.gotSignal()) {
+ evtSvr.ack();
+ svrDone = true;
+ }
+ if (evtCln.gotSignal()) {
+ evtCln.ack();
+ clnDone = true;
+ }
+ }
+
+ Printer::log(M_STATUS, "Test::Sockets passed.");
+}