From d611cc597822049b1bd091b6bf2f136e07ae53cf Mon Sep 17 00:00:00 2001 From: Simon Rettberg Date: Tue, 28 Apr 2015 15:54:45 +0200 Subject: SSL support when talking to ADS --- server.c | 348 +++++++++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 252 insertions(+), 96 deletions(-) (limited to 'server.c') diff --git a/server.c b/server.c index 39f8dce..5ec6148 100644 --- a/server.c +++ b/server.c @@ -3,6 +3,7 @@ #include "helper.h" #include "epoll.h" #include "tmpbuffer.h" +#include "openssl.h" #include #include #include @@ -12,6 +13,7 @@ #include #define AD_PORT 3268 +#define AD_PORT_SSL 636 #define MSGID_BIND 1 #define MAX_SERVERS 10 @@ -22,9 +24,12 @@ static void server_init(); static server_t *server_create(const char *server); static void server_free(epoll_server_t *server); static void server_callback(void *data, int haveIn, int haveOut, int doCleanup); -static void server_flush(epoll_server_t * const server); +static void server_haveIn(epoll_server_t *server); +static void server_haveOut(epoll_server_t * const server); static BOOL server_ensureConnected(server_t *server); -static void server_ensureSendBuffer(epoll_server_t * const s, const size_t len); +static BOOL server_ensureSendBuffer(epoll_server_t * const s, const size_t len); +static int server_connectInternal(server_t *server); +static BOOL server_connectSsl(epoll_server_t *server); // Generate a message ID for request to AD static inline uint32_t msgId() @@ -36,6 +41,18 @@ static inline uint32_t msgId() // Setting up server(s) +void server_setPort(const char *server, const char *portStr) +{ + server_t *entry = server_create(server); + if (entry == NULL) return; + int port = atoi(portStr); + if (port < 1 || port > 65535) { + printf("Warning: Invalid port '%s' for '%s'\n", portStr, server); + return; + } + entry->port = (uint16_t)port; +} + void server_setBind(const char *server, const char *bind) { server_t *entry = server_create(server); @@ -80,6 +97,47 @@ void server_setHomeTemplate(const char *server, const char *hometemplate) if (count > 5) printf("WARNING: Too many '%%' in Home Template for %s. Don't forget to replace literal '%%' with '%%%%'\n", server); } +void server_setFingerprint(const char *server, const char *fingerprint) +{ + server_t *entry = server_create(server); + if (entry == NULL || entry->sslContext != NULL) return; + int chars = 0, val = -1; + while (*fingerprint != '\0' && chars / 2 < FINGERPRINTLEN) { + if (*fingerprint == ':' || *fingerprint == ' ') { + fingerprint++; + continue; + } + val = -1; + if (*fingerprint >= '0' && *fingerprint <= '9') { + val = *fingerprint - '0'; + } else if (*fingerprint >= 'a' && *fingerprint <= 'f') { + val = *fingerprint - 'a' + 10; + } else if (*fingerprint >= 'A' && *fingerprint <= 'F') { + val = *fingerprint - 'A' + 10; + } else { + break; + } + if (chars % 2 == 0) { + entry->fingerprint[chars / 2] |= val << 4; + } else { + entry->fingerprint[chars / 2] |= val; + } + fingerprint++; + chars++; + } + if (chars / 2 != FINGERPRINTLEN || val == -1) { + printf("Warning: Fingerprint for %s is invalid (adsha1 should be a SHA-1 hash of the cert in hex representation.)\n", server); + return; + } + printf("Using fingerprint "); + for (int i = 0; i < FINGERPRINTLEN - 1; ++i) { + printf("%02x:", (int)entry->fingerprint[i]); + } + printf("%02x for %s\n", (int)entry->fingerprint[FINGERPRINTLEN-1], server); + ssl_init(); + entry->sslContext = ssl_newClientCtx(); +} + BOOL server_initServers() { int i; @@ -134,27 +192,10 @@ uint32_t server_tryUserBind(server_t *server, struct string *binddn, struct stri con->dynamic = TRUE; printf("Connecting to AD '%s' for %.*ss bind...\n", server->addr, (int)binddn->l, binddn->s); con->sbPos = con->sbFill = 0; - int sock; - if (server->lastLookup + 300 < time(NULL)) { - sock = helper_connect4(server->addr, AD_PORT, server->ip); - if (sock == -1) { - printf("[ADB] Could not resolve/connect to AD server %s\n", server->addr); - server_free(con); - return 0; - } - } else { - sock = socket_tcp4b(); - if (sock == -1) { - printf("[ADB] Could not allocate socket for connection to AD\n"); - server_free(con); - return 0; - } - if (socket_connect4(sock, server->ip, AD_PORT) == -1) { - printf("[ADB] Could not connect to cached IP of %s\n", server->addr); - close(sock); - server_free(con); - return 0; - } + int sock = server_connectInternal(server); + if (sock == -1) { + server_free(con); + return 0; } printf("[ADB] Connected, binding....\n"); helper_nonblock(sock); @@ -165,6 +206,11 @@ uint32_t server_tryUserBind(server_t *server, struct string *binddn, struct stri server_free(con); return 0; } + // SSL + if (!server_connectSsl(con)) { + server_free(con); + return 0; + } // Now bind - TODO: SASL (DIGEST-MD5?) const uint32_t id = msgId(); const size_t bodyLen = fmt_ldapbindrequeststring(NULL, 3, binddn, password); @@ -211,8 +257,14 @@ static server_t *server_create(const char *server) static void server_free(epoll_server_t *server) { server->bound = FALSE; - if (server->fd != -1) close(server->fd); - server->fd = -1; + if (server->ssl != NULL) { + SSL_free(server->ssl); + server->ssl = NULL; + } + if (server->fd != -1) { + close(server->fd); + server->fd = -1; + } server->sbPos = server->sbFill = 0; if (server->dynamic) { printf("Freeing Bind-AD-Connection\n"); @@ -224,61 +276,97 @@ static void server_free(epoll_server_t *server) static void server_callback(void *data, int haveIn, int haveOut, int doCleanup) { epoll_server_t *server = (epoll_server_t *)data; - if (doCleanup) { + if (doCleanup || server->kill) { server_free(server); return; } - if (haveIn) { - for (;;) { - if (server->rbPos >= MAXMSGLEN) { - printf("[AD->Proxy] Read buffer overflow. Disconnecting.\n"); - server_free(server); - return; - } - const size_t buflen = MAXMSGLEN - server->rbPos; - const ssize_t ret = read(server->fd, server->readBuffer + server->rbPos, buflen); + if (server->ssl == NULL) { + // Plain connection + if (haveIn) server_haveIn(server); + if (haveOut) server_haveOut(server); + if (server->kill) server_free(server); + return; + } + // SSL + if (!server->sslConnected) { + // Still SSL-Connecting + if (!ssl_connectServer(server)) { + printf("SSL Server connect failed!\n"); + server_free(server); + return; + } + if (!server->sslConnected) return; + } + // Since we don't know if the incoming data is just wrapped application data or ssl protocol stuff, we always call both + server_haveIn(server); + server_haveOut(server); + if (server->kill) server_free(server); +} + +static void server_haveIn(epoll_server_t *server) +{ + for (;;) { + if (server->rbPos >= MAXMSGLEN) { + printf("[AD->Proxy] Read buffer overflow. Disconnecting.\n"); + server->kill = TRUE; + return; + } + const size_t buflen = MAXMSGLEN - server->rbPos; + ssize_t ret; + if (server->ssl == NULL) { + // Plain + ret = read(server->fd, server->readBuffer + server->rbPos, buflen); printf("AD read %d (err %d)\n", (int)ret, errno); if (ret < 0 && errno == EINTR) continue; if (ret < 0 && errno == EAGAIN) break; if (ret <= 0) { - printf("AD gone while reading.\n"); - server_free(server); + printf("AD Server gone while reading.\n"); + server->kill = TRUE; return; } - server->rbPos += ret; - // Request complete? - for (;;) { - size_t consumed, len; - consumed = scan_asn1SEQUENCE(server->readBuffer, server->readBuffer + server->rbPos, &len); - if (consumed == 0) break; // Length-Header not complete - len += consumed; - if (len > server->rbPos) break; // Body not complete - printf("[AD] Received complete reply...\n"); - if (!proxy_fromServer(server, len)) { - if (server->dynamic) { - server_free(server); - return; - } - printf("Error parsing reply from AD.\n"); - // Let's try to go on with the next message.... - } - // Shift remaining buffer contents - if (len == server->rbPos) { - server->rbPos = 0; - break; + } else { + // SSL + ret = SSL_read(server->ssl, server->readBuffer + server->rbPos, buflen); + if (ret <= 0) { + int err = SSL_get_error(server->ssl, ret); + if (SSL_BLOCKED(err)) break; + printf("AD Server gone while reading (%d, %d).\n", (int)ret, err); + server->kill = TRUE; + return; + } + } + server->rbPos += ret; + // Request complete? + for (;;) { + size_t consumed, len; + consumed = scan_asn1SEQUENCE(server->readBuffer, server->readBuffer + server->rbPos, &len); + if (consumed == 0) break; // Length-Header not complete + len += consumed; + if (len > server->rbPos) break; // Body not complete + printf("[AD] Received complete reply...\n"); + if (!proxy_fromServer(server, len)) { + if (server->dynamic) { + server->kill = TRUE; + return; } - memmove(server->readBuffer, server->readBuffer + len, server->rbPos - len); - server->rbPos -= len; + printf("Error parsing reply from AD.\n"); + // Let's try to go on with the next message.... } - if ((ssize_t)buflen > ret) break; // Read less than buffer len, epoll will fire again + // Shift remaining buffer contents + if (len == server->rbPos) { + server->rbPos = 0; + break; + } + memmove(server->readBuffer, server->readBuffer + len, server->rbPos - len); + server->rbPos -= len; } + if ((ssize_t)buflen > ret) break; // Read less than buffer len, epoll will fire again } - if (haveOut) server_flush(server); } BOOL server_send(epoll_server_t *server, const char *buffer, size_t len, const BOOL cork) { - if (server->sbFill == 0 && !cork) { + if (server->ssl == NULL && server->sbFill == 0 && !cork) { // Nothing in send buffer, fire away const int ret = write(server->fd, buffer, len); if (ret == 0 || (ret < 0 && errno != EINTR && errno != EAGAIN)) { @@ -295,28 +383,54 @@ BOOL server_send(epoll_server_t *server, const char *buffer, size_t len, const B } } // Buffer... - server_ensureSendBuffer(server, len); + if (!server_ensureSendBuffer(server, len)) { + server->kill = TRUE; + return FALSE; + } // Finally append to buffer memcpy(server->sendBuffer + server->sbFill, buffer, len); server->sbFill += len; - if (!cork) server_flush(server); + if (!cork) server_haveOut(server); return TRUE; } -static void server_flush(epoll_server_t * const server) +static void server_haveOut(epoll_server_t * const server) { while (server->sbPos < server->sbFill) { - const int tosend = server->sbFill - server->sbPos; - const int ret = write(server->fd, server->sendBuffer + server->sbPos, tosend); - if (ret < 0 && errno == EINTR) continue; - if (ret < 0 && errno == EAGAIN) return; - if (ret <= 0) { - printf("Connection to AD Server failed while flushing (ret: %d, errno: %d)\n", ret, errno); - return; + const ssize_t tosend = server->sbFill - server->sbPos; + ssize_t ret; + if (server->ssl == NULL) { + // Plain + ret = write(server->fd, server->sendBuffer + server->sbPos, tosend); + if (ret < 0 && errno == EINTR) continue; + if (ret < 0 && errno == EAGAIN) return; + if (ret <= 0) { + printf("Connection to AD Server failed while flushing (ret: %d, errno: %d)\n", (int)ret, errno); + return; + } + } else { + // SSL + ret = SSL_write(server->ssl, server->sendBuffer + server->sbPos, tosend); + if (ret <= 0) { + int err = SSL_get_error(server->ssl, ret); + if (SSL_BLOCKED(err)) { + server->writeBlocked = TRUE; + return; // Blocking + } + printf("SSL server gone while sending (%d)\n", err); + ERR_print_errors_fp(stdout); + server->kill = TRUE; + return; // Closed + } } server->lastActive = time(NULL); server->sbPos += ret; - if (ret != tosend) return; + if (server->ssl != NULL) { + memmove(server->sendBuffer, server->sendBuffer + server->sbPos, server->sbFill - server->sbPos); + server->sbFill -= server->sbPos; + server->sbPos = 0; + } + if (server->ssl == NULL && ret != tosend) return; } server->sbPos = server->sbFill = 0; } @@ -329,25 +443,8 @@ static BOOL server_ensureConnected(server_t *server) con->bound = FALSE; printf("Connecting to AD '%s'...\n", server->addr); con->sbPos = con->sbFill = 0; - int sock; - if (server->lastLookup + 300 < time(NULL)) { - sock = helper_connect4(server->addr, AD_PORT, server->ip); - if (sock == -1) { - printf("Could not resolve/connect to AD server %s\n", server->addr); - return FALSE; - } - } else { - sock = socket_tcp4b(); - if (sock == -1) { - printf("Could not allocate socket for connection to AD\n"); - return FALSE; - } - if (socket_connect4(sock, server->ip, AD_PORT) == -1) { - printf("Could not connect to cached IP of %s\n", server->addr); - close(sock); - return FALSE; - } - } + int sock = server_connectInternal(server); + if (sock == -1) return FALSE; printf("Connected, binding....\n"); helper_nonblock(sock); con->fd = sock; @@ -358,6 +455,12 @@ static BOOL server_ensureConnected(server_t *server) con->fd = -1; return FALSE; } + // SSL + if (!server_connectSsl(con)) { + close(con->fd); + con->fd = -1; + return FALSE; + } // Now bind - TODO: SASL (DIGEST-MD5?) const size_t bodyLen = fmt_ldapbindrequest(NULL, 3, server->bind, server->password); const size_t headerLen = fmt_ldapmessage(NULL, MSGID_BIND, BindRequest, bodyLen); @@ -375,18 +478,71 @@ static BOOL server_ensureConnected(server_t *server) return TRUE; } -static void server_ensureSendBuffer(epoll_server_t * const s, const size_t len) +static BOOL server_ensureSendBuffer(epoll_server_t * const s, const size_t len) { - if (len > 1000000) bail("server_ensureSendBuffer: request too large!"); + if (len > 1000000) { + printf("server_ensureSendBuffer: request too large!\n"); + return FALSE; + } if (s->sbLen - s->sbFill < len) { + if (s->writeBlocked) { + printf("SSL Write blocked and buffer to small (%d)\n", (int)s->sbLen); + return FALSE; + } if (s->sbPos != 0) { memmove(s->sendBuffer, s->sendBuffer + s->sbPos, s->sbFill - s->sbPos); s->sbFill -= s->sbPos; s->sbPos = 0; } if (s->sbLen - s->sbFill < len) { - helper_realloc(&s->sendBuffer, &s->sbLen, s->sbLen + len + 1000, "server_ensureSendBuffer"); + if (helper_realloc(&s->sendBuffer, &s->sbLen, s->sbLen + len + (s->ssl == NULL ? 1000 : 6000), "server_ensureSendBuffer") == -1) { + return FALSE; + } } } + return TRUE; +} + +static int server_connectInternal(server_t *server) +{ + int sock; + const uint16_t port = server->port != 0 ? server->port : (server->sslContext == NULL ? AD_PORT : AD_PORT_SSL); + if (server->lastLookup + 300 < time(NULL)) { + sock = helper_connect4(server->addr, port, server->ip); + if (sock == -1) { + printf("Could not resolve/connect to AD server %s\n", server->addr); + return -1; + } + } else { + sock = socket_tcp4b(); + if (sock == -1) { + printf("Could not allocate socket for connection to AD\n"); + return -1; + } + if (socket_connect4(sock, server->ip, port) == -1) { + printf("Could not connect to cached IP of %s\n", server->addr); + server->lastLookup = 0; + close(sock); + return -1; + } + } + return sock; +} + +static BOOL server_connectSsl(epoll_server_t *server) +{ + if (server->serverData->sslContext == NULL) return TRUE; + server->ssl = ssl_new(server->fd, server->serverData->sslContext); + if (server->ssl == NULL) { + printf("Could not get SSL client from context\n"); + return FALSE; + } + if (!ssl_connectServer(server)) { + printf("SSL connect failed.\n"); + SSL_free(server->ssl); + server->ssl = NULL; + return FALSE; + } + return TRUE; } -- cgit v1.2.3-55-g7522