From bedd2e7ccb1595c23e159eaa952ae1b0b5a3d2ad Mon Sep 17 00:00:00 2001 From: Simon Rettberg Date: Sat, 15 Mar 2014 01:49:50 +0100 Subject: Lean and mean initial commit Not much functionality yet --- server.c | 353 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 353 insertions(+) create mode 100644 server.c (limited to 'server.c') diff --git a/server.c b/server.c new file mode 100644 index 0000000..b303e67 --- /dev/null +++ b/server.c @@ -0,0 +1,353 @@ +#include "server.h" +#include "proxy.h" +#include "helper.h" +#include "epoll.h" +#include "tmpbuffer.h" +#include +#include +#include +#include +#include +#include +#include + +#define ADDRLEN 40 +#define BINDLEN 200 +#define PWLEN 40 +#define BASELEN 100 +#define ALIASLEN 40 + +#define AD_PORT 3268 +#define MSGID_BIND 1 + +typedef struct { + size_t aliasLen; + size_t baseLen; + char ip[4]; + time_t lastLookup; + char addr[ADDRLEN]; + char bind[BINDLEN]; + char password[PWLEN]; + char base[BASELEN]; + char alias[ALIASLEN]; + epoll_server_t con; +} server_t; + +#define MAX_SERVERS 10 +static server_t *servers = NULL; +static int serverCount = 0; + +static void server_init(); +static server_t *server_create(const char *server); +static void server_callback(void *data, int haveIn, int haveOut, int doCleanup); +static void server_flush(epoll_server_t * const server); +static BOOL server_ensureConnected(const int index); +static void server_ensureSendBuffer(epoll_server_t * const s, const size_t len); + +// Generate a message ID for request to AD +static inline uint32_t msgId() +{ + static uint32_t id = 1336; + if (++id < 2) id = 2; + return id; +} + +// Setting up server(s) + +void server_setBind(const char *server, const char *bind) +{ + server_t *entry = server_create(server); + if (entry == NULL) return; + if (snprintf(entry->bind, BINDLEN, "%s", bind) >= BINDLEN) printf("Warning: BindDN for %s is too long.\n", server); +} + +void server_setPassword(const char *server, const char *password) +{ + server_t *entry = server_create(server); + if (entry == NULL) return; + if (snprintf(entry->password, PWLEN, "%s", password) >= PWLEN) printf("Warning: BindPW for %s is too long.\n", server); +} + +void server_setBase(const char *server, const char *base) +{ + server_t *entry = server_create(server); + if (entry == NULL) return; + if (snprintf(entry->base, BASELEN, "%s", base) >= BASELEN) printf("Warning: SearchBase for %s is too long.\n", server); + entry->baseLen = normalize_dn(entry->base, entry->base, min(strlen(entry->base), BASELEN - 1)); + entry->base[entry->baseLen] = '\0'; +} + +void server_setAlias(const char *server, const char *alias) +{ + server_t *entry = server_create(server); + if (entry == NULL) return; + if (snprintf(entry->alias, ALIASLEN, "%s", alias) >= ALIASLEN) printf("Warning: SearchBase Alias for %s is too long.\n", server); + entry->aliasLen = normalize_dn(entry->alias, entry->alias, min(strlen(entry->alias), ALIASLEN - 1)); + entry->alias[entry->aliasLen] = '\0'; +} + +void server_initServers() +{ + int i; + printf("%d servers configured.\n", serverCount); + for (i = 0; i < serverCount; ++i) { + printf("%s:\n Bind: %s\n Base: %s\n Proxy Alias: %s\n", servers[i].addr, servers[i].bind, servers[i].base, servers[i].alias); + server_ensureConnected(i); + } +} + +// What the proxy calls + +int server_aliasToBase(struct string *in, struct string *out) +{ + int i; + char buffer[TMPLEN]; + const size_t searchLen = normalize_dn(buffer, in->s, min(in->l, TMPLEN - 1)); + buffer[searchLen] = '\0'; + // Now buffer contains the normalized wanted alias. Try to find a match in the server list + for (i = 0; i < serverCount; ++i) { + if (searchLen < servers[i].aliasLen) continue; + if (strcmp(servers[i].alias, buffer + (searchLen - servers[i].aliasLen)) == 0) { + // Found, handle + tmpbuffer_format(out, "%.*s%s", (int)(searchLen - servers[i].aliasLen), buffer, servers[i].base); + return i; + } + } + return -1; +} + +int server_baseToAlias(struct string *in, struct string *out) +{ + int i; + char buffer[TMPLEN]; + const size_t searchLen = normalize_dn(buffer, in->s, min(in->l, TMPLEN - 1)); + buffer[searchLen] = '\0'; + // Now buffer contains the normalized wanted base. Try to find a match in the server list + for (i = 0; i < serverCount; ++i) { + printf("Comparing %s (%s) to %s\n", buffer, buffer + (searchLen - servers[i].baseLen), servers[i].base); + if (searchLen < servers[i].baseLen) continue; + if (strcmp(servers[i].base, buffer + (searchLen - servers[i].baseLen)) == 0) { + // Found, handle + tmpbuffer_format(out, "%.*s%s", (int)(searchLen - servers[i].baseLen), buffer, servers[i].alias); + printf("Match, returning %s\n", out->s); + return i; + } + } + return -1; +} + +uint32_t server_searchRequest(int server, struct SearchRequest *req) +{ + if (!server_ensureConnected(server)) return 0; + const uint32_t msgid = msgId(); + const size_t bodyLen = fmt_ldapsearchrequest(NULL, req); + const size_t headerLen = fmt_ldapmessage(NULL, msgid, SearchRequest, bodyLen); + char buffer[bodyLen + 50]; + char *bufoff = buffer + 50; + fmt_ldapsearchrequest(bufoff, req); + fmt_ldapmessage(bufoff - headerLen, msgid, SearchRequest, bodyLen); + epoll_server_t * const s = &servers[server].con; + server_send(s, bufoff - headerLen, headerLen + bodyLen, FALSE); + return msgid; +} + +// +// Private stuff + +static void server_init() +{ + if (servers != NULL) return; + servers = calloc(MAX_SERVERS, sizeof(server_t)); +} + +static server_t *server_create(const char *server) +{ + int i; + server_init(); + for (i = 0; i < serverCount; ++i) { + if (strcmp(servers[i].addr, server) == 0) return &servers[i]; + } + if (serverCount >= MAX_SERVERS) { + printf("Cannot add server %s: Too many servers.\n", server); + return NULL; + } + snprintf(servers[serverCount].addr, ADDRLEN, "%s", server); + servers[serverCount].con.fd = -1; + return &servers[serverCount++]; +} + +static void server_free(epoll_server_t *server) +{ + server->bound = FALSE; + if (server->fd != -1) close(server->fd); + server->fd = -1; + server->sbPos = server->sbFill = 0; +} + +static void server_callback(void *data, int haveIn, int haveOut, int doCleanup) +{ + epoll_server_t *server = (epoll_server_t *)data; + if (doCleanup) { + server_free(server); + return; + } + if (haveIn) { + for (;;) { + if (server->rbPos >= MAXMSGLEN) { + printf("[AD->Proxy] Read buffer overflow. Disconnecting.\n"); + server_free(server); + return; + } + const size_t buflen = MAXMSGLEN - server->rbPos; + const ssize_t ret = read(server->fd, server->readBuffer + server->rbPos, buflen); + printf("AD read %d (err %d)\n", (int)ret, errno); + if (ret < 0 && errno == EINTR) continue; + if (ret < 0 && errno == EAGAIN) break; + if (ret <= 0) { + printf("AD gone while reading.\n"); + server_free(server); + return; + } + server->rbPos += ret; + // Request complete? + for (;;) { + size_t consumed, len; + consumed = scan_asn1SEQUENCE(server->readBuffer, server->readBuffer + server->rbPos, &len); + if (consumed == 0) break; // Length-Header not complete + len += consumed; + if (len > server->rbPos) break; // Body not complete + printf("[AD] Received complete reply...\n"); + if (proxy_fromServer(server, len) == -1) { + printf("Error parsing reply from AD.\n"); + server_free(server); + return; + } + // Shift remaining buffer contents + if (len == server->rbPos) { + server->rbPos = 0; + break; + } + memmove(server->readBuffer, server->readBuffer + len, server->rbPos - len); + server->rbPos -= len; + } + if ((ssize_t)buflen > ret) break; // Read less than buffer len, epoll will fire again + } + } + if (haveOut) server_flush(server); +} + +int server_send(epoll_server_t *server, const char *buffer, size_t len, const BOOL cork) +{ + if (server->sbFill == 0 && !cork) { + // Nothing in send buffer, fire away + const int ret = write(server->fd, buffer, len); + if (ret == 0 || (ret < 0 && errno != EINTR && errno != EAGAIN)) { + printf("Server gone when trying to send.\n"); + return -1; + } + server->lastActive = time(NULL); + if (ret == (int)len) return 0; + // Couldn't send everything, continue with buffering logic below + if (ret > 0) { + printf("[AD] Partial send (%d of %d)\n", ret, (int)len); + buffer += ret; + len -= (size_t)ret; + } + } + // Buffer... + server_ensureSendBuffer(server, len); + // Finally append to buffer + memcpy(server->sendBuffer + server->sbFill, buffer, len); + server->sbFill += len; + if (!cork) server_flush(server); + return 0; +} + +static void server_flush(epoll_server_t * const server) +{ + while (server->sbPos < server->sbFill) { + const int tosend = server->sbFill - server->sbPos; + const int ret = write(server->fd, server->sendBuffer + server->sbPos, tosend); + if (ret < 0 && errno == EINTR) continue; + if (ret < 0 && errno == EAGAIN) return; + if (ret <= 0) { + printf("Connection to AD Server failed while flushing (ret: %d, errno: %d)\n", ret, errno); + return; + } + server->lastActive = time(NULL); + server->sbPos += ret; + if (ret != tosend) return; + } + server->sbPos = server->sbFill = 0; +} + +static BOOL server_ensureConnected(const int index) +{ + server_t * const server = &servers[index]; + epoll_server_t * const con = &server->con; + if (con->fd != -1 && con->lastActive + 120 > time(NULL)) return TRUE; + if (con->fd != -1) close(con->fd); + con->bound = FALSE; + printf("Connecting to AD %s...\n", server->addr); + con->sbPos = con->sbFill = 0; + int sock; + if (server->lastLookup + 300 < time(NULL)) { + sock = helper_connect4(server->addr, AD_PORT, server->ip); + if (sock == -1) { + printf("Could not resolve/connect to AD server %s\n", server->addr); + return FALSE; + } + } else { + sock = socket_tcp4b(); + if (sock == -1) { + printf("Could not allocate socket for connection to AD\n"); + return FALSE; + } + if (socket_connect4(sock, server->ip, AD_PORT) == -1) { + printf("Could not connect to cached IP of %s\n", server->addr); + close(sock); + return FALSE; + } + } + printf("Connected, binding....\n"); + helper_nonblock(sock); + con->fd = sock; + con->callback = &server_callback; + if (ePoll_add(EPOLLIN | EPOLLOUT | EPOLLET, (epoll_item_t*)con) == -1) { + printf("epoll_add failed for ad server %s\n", server->addr); + close(con->fd); + con->fd = -1; + return FALSE; + } + // Now bind + const size_t bodyLen = fmt_ldapbindrequest(NULL, 3, server->bind, server->password); + const size_t headerLen = fmt_ldapmessage(NULL, MSGID_BIND, BindResponse, bodyLen); + char buffer[bodyLen + 50]; + char *bufoff = buffer + 50; + if (headerLen >= 50) { + printf("[AD] bind too long for %s\n", server->addr); + close(con->fd); + con->fd = -1; + return FALSE; + } + fmt_ldapbindrequest(bufoff, 3, server->bind, server->password); + fmt_ldapmessage(bufoff - headerLen, MSGID_BIND, BindRequest, bodyLen); + server_send(con, bufoff - headerLen, bodyLen + headerLen, FALSE); + return TRUE; +} + +static void server_ensureSendBuffer(epoll_server_t * const s, const size_t len) +{ + if (len > 1000000) bail("server_ensureSendBuffer: request too large!"); + if (s->sbLen - s->sbFill < len) { + if (s->sbPos != 0) { + memmove(s->sendBuffer, s->sendBuffer + s->sbPos, s->sbFill - s->sbPos); + s->sbFill -= s->sbPos; + s->sbPos = 0; + } + if (s->sbLen - s->sbFill < len) { + helper_realloc(&s->sendBuffer, &s->sbLen, s->sbLen + len + 1000, "server_ensureSendBuffer"); + } + } +} + -- cgit v1.2.3-55-g7522