summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSimon Rettberg2015-04-28 15:54:45 +0200
committerSimon Rettberg2015-04-28 15:54:45 +0200
commitd611cc597822049b1bd091b6bf2f136e07ae53cf (patch)
tree6a31137cda1d6727123e668268d891d429b53c9d
parent"Support" feature query (done by sssd) (diff)
downloadldadp-d611cc597822049b1bd091b6bf2f136e07ae53cf.tar.gz
ldadp-d611cc597822049b1bd091b6bf2f136e07ae53cf.tar.xz
ldadp-d611cc597822049b1bd091b6bf2f136e07ae53cf.zip
SSL support when talking to ADS
-rw-r--r--config/config.example13
-rw-r--r--ldadp.c48
-rw-r--r--openssl.c61
-rw-r--r--openssl.h8
-rw-r--r--server.c348
-rw-r--r--server.h4
-rw-r--r--types.h31
7 files changed, 387 insertions, 126 deletions
diff --git a/config/config.example b/config/config.example
index 6ad38a6..574d328 100644
--- a/config/config.example
+++ b/config/config.example
@@ -1,3 +1,4 @@
+# Configure an ADS we proxy. hostname will be the section name
[dc0.example.com]
# bind DN towards this AD if client doesn't specify one
binddn=CN=blabla,OU=Foo,DC=public,DC=ads,DC=example,DC=com
@@ -7,7 +8,15 @@ bindpw=geheim
base=DC=public,DC=ads,DC=example,DC=com
# optional: template for home directory mount point to pass to client. use %s as the users account name. only used if AD doesn't supply the homeDirectory attribute (or it doesn't contain a UNC path)
home=\\windows-server\users\%s
-# For using SSL between client and proxy, uncomment these. For plaintext, remove or comment out
+# Set this to use SSL when talking to the ADS. SSL is not enabled by default, so make sure your ADS has it.
+fingerprint=76:EC:9D:18:99:0D:8F:E1:99:D2:07:09:48:DF:82:4F:28:47:32:14
+# Optinally set remote port. Default is 3268 for plain connection, 636 for SSL connection.
+port=6666
+
+# Configure the proxy)
+[local]
+# Local TCP port to listen on
+port=1234
+# For using SSL between client and proxy, set these. For plaintext, remove or comment out
cert=/my/cert.pem
privkey=/my/privatekey.pem
-
diff --git a/ldadp.c b/ldadp.c
index 584da8a..8d320a8 100644
--- a/ldadp.c
+++ b/ldadp.c
@@ -84,7 +84,7 @@ static void listen_callback(void *data, int haveIn, int haveOut, int doCleanup)
printf("Accepted connection.\n");
SSL *ssl = NULL;
if (listen->sslContext != NULL) {
- ssl = ssl_startAccept(sock, listen->sslContext);
+ ssl = ssl_new(sock, listen->sslContext);
if (ssl == NULL) {
close(sock);
return;
@@ -106,26 +106,32 @@ static void listen_callback(void *data, int haveIn, int haveOut, int doCleanup)
static int loadConfig_handler(void *stuff, const char *section, const char *key, const char *value)
{
- if (strcmp(key, "binddn") == 0) {
- server_setBind(section, value);
- }
- if (strcmp(key, "bindpw") == 0) {
- server_setPassword(section, value);
- }
- if (strcmp(key, "base") == 0) {
- server_setBase(section, value);
- }
- if (strcmp(key, "home") == 0 && *value != '\0') {
- server_setHomeTemplate(section, value);
- }
- if (strcmp(key, "port") == 0) {
- localPort = atoi(value);
- }
- if (strcmp(key, "cert") == 0) {
- certFile = strdup(value);
- }
- if (strcmp(key, "privkey") == 0) {
- keyFile = strdup(value);
+ if (strcmp(section, "local") == 0) {
+ if (strcmp(key, "port") == 0) {
+ localPort = atoi(value);
+ } else if (strcmp(key, "cert") == 0) {
+ certFile = strdup(value);
+ } else if (strcmp(key, "privkey") == 0) {
+ keyFile = strdup(value);
+ } else {
+ printf("Unknown local config option '%s'\n", key);
+ }
+ } else {
+ if (strcmp(key, "binddn") == 0) {
+ server_setBind(section, value);
+ } else if (strcmp(key, "bindpw") == 0) {
+ server_setPassword(section, value);
+ } else if (strcmp(key, "base") == 0) {
+ server_setBase(section, value);
+ } else if (strcmp(key, "home") == 0 && *value != '\0') {
+ server_setHomeTemplate(section, value);
+ } else if (strcmp(key, "fingerprint") == 0 && *value != '\0') {
+ server_setFingerprint(section, value);
+ } else if (strcmp(key, "port") == 0) {
+ server_setPort(section, value);
+ } else {
+ printf("Unknown ADS config option '%s' for server '%s'\n", key, section);
+ }
}
return 1;
}
diff --git a/openssl.c b/openssl.c
index 32c7bca..c8e4142 100644
--- a/openssl.c
+++ b/openssl.c
@@ -2,6 +2,7 @@
#include "helper.h"
static BOOL initDone = FALSE;
+static const EVP_MD *sha1 = NULL;
void ssl_printErrors(char *bailMsg)
{
@@ -19,6 +20,8 @@ BOOL ssl_init()
SSL_load_error_strings();
SSL_library_init();
OpenSSL_add_all_algorithms();
+ sha1 = EVP_get_digestbyname("sha1");
+ if (sha1 == NULL) ssl_printErrors("Could not load SHA-1 digest\n");
return TRUE;
}
@@ -29,13 +32,26 @@ SSL_CTX* ssl_newServerCtx(char *certfile, char *keyfile)
SSL_CTX *ctx = SSL_CTX_new(m);
if (ctx == NULL) ssl_printErrors("newServerCtx: ctx is NULL");
SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2);
+ SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv3);
SSL_CTX_use_certificate_file(ctx, certfile, SSL_FILETYPE_PEM);
SSL_CTX_use_PrivateKey_file(ctx, keyfile, SSL_FILETYPE_PEM);
if (!SSL_CTX_check_private_key(ctx)) ssl_printErrors("Could not load cert/private key");
+ SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
return ctx;
}
-SSL *ssl_startAccept(int clientFd, SSL_CTX *ctx)
+SSL_CTX* ssl_newClientCtx()
+{
+ const SSL_METHOD *m = SSLv23_client_method();
+ if (m == NULL) ssl_printErrors("newClientCtx: method is NULL");
+ SSL_CTX *ctx = SSL_CTX_new(m);
+ if (ctx == NULL) ssl_printErrors("newClientCtx: ctx is NULL");
+ SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2);
+ SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE);
+ return ctx;
+}
+
+SSL *ssl_new(int clientFd, SSL_CTX *ctx)
{
SSL *ssl = SSL_new(ctx);
if (ssl == NULL) {
@@ -47,7 +63,6 @@ SSL *ssl_startAccept(int clientFd, SSL_CTX *ctx)
SSL_free(ssl);
return NULL;
}
- SSL_set_mode(ssl, SSL_MODE_ENABLE_PARTIAL_WRITE);
return ssl;
}
@@ -66,3 +81,45 @@ BOOL ssl_acceptClient(epoll_client_t *client)
return FALSE;
}
+BOOL ssl_connectServer(epoll_server_t *server)
+{
+ if (server->sslConnected) return TRUE;
+ int ret = SSL_connect(server->ssl);
+ if (ret == 1) {
+ if (!ssl_checkCertificateHash(server)) {
+ printf("Warning: Fingerprint of %s doesn't match value given in config, refusing to talk to server!\n", server->serverData->addr);
+ return FALSE;
+ }
+ server->sslConnected = TRUE;
+ return TRUE;
+ }
+ if (ret < 0) {
+ int err = SSL_get_error(server->ssl, ret);
+ if (SSL_BLOCKED(err)) return TRUE;
+ }
+ return FALSE;
+}
+
+BOOL ssl_checkCertificateHash(epoll_server_t *server)
+{
+ if (server->ssl == NULL) {
+ printf("Bug: Asked to check certificate of non-SSL connection\n");
+ return FALSE;
+ }
+ for (int i = 0; i < FINGERPRINTLEN; ++i) {
+ if (server->serverData->fingerprint[i] != 0) {
+ unsigned char md[EVP_MAX_MD_SIZE];
+ unsigned int n = 20;
+ X509 *cert = SSL_get_peer_certificate(server->ssl);
+ if (cert == NULL) {
+ printf("Warning: Server %s has no certificate!\n", server->serverData->addr);
+ return FALSE;
+ }
+ X509_free(cert);
+ X509_digest(cert, sha1, md, &n);
+ return n == 20 && memcmp(md, server->serverData->fingerprint, n) == 0;
+ }
+ }
+ return TRUE;
+}
+
diff --git a/openssl.h b/openssl.h
index a564b97..a37c58e 100644
--- a/openssl.h
+++ b/openssl.h
@@ -13,9 +13,15 @@ BOOL ssl_init();
SSL_CTX* ssl_newServerCtx(char *certfile, char *keyfile);
-SSL *ssl_startAccept(int clientFd, SSL_CTX *ctx);
+SSL_CTX* ssl_newClientCtx();
+
+SSL *ssl_new(int clientFd, SSL_CTX *ctx);
BOOL ssl_acceptClient(epoll_client_t *client);
+BOOL ssl_connectServer(epoll_server_t *server);
+
+BOOL ssl_checkCertificateHash(epoll_server_t *server);
+
#endif
diff --git a/server.c b/server.c
index 39f8dce..5ec6148 100644
--- a/server.c
+++ b/server.c
@@ -3,6 +3,7 @@
#include "helper.h"
#include "epoll.h"
#include "tmpbuffer.h"
+#include "openssl.h"
#include <time.h>
#include <stdio.h>
#include <stdlib.h>
@@ -12,6 +13,7 @@
#include <socket.h>
#define AD_PORT 3268
+#define AD_PORT_SSL 636
#define MSGID_BIND 1
#define MAX_SERVERS 10
@@ -22,9 +24,12 @@ static void server_init();
static server_t *server_create(const char *server);
static void server_free(epoll_server_t *server);
static void server_callback(void *data, int haveIn, int haveOut, int doCleanup);
-static void server_flush(epoll_server_t * const server);
+static void server_haveIn(epoll_server_t *server);
+static void server_haveOut(epoll_server_t * const server);
static BOOL server_ensureConnected(server_t *server);
-static void server_ensureSendBuffer(epoll_server_t * const s, const size_t len);
+static BOOL server_ensureSendBuffer(epoll_server_t * const s, const size_t len);
+static int server_connectInternal(server_t *server);
+static BOOL server_connectSsl(epoll_server_t *server);
// Generate a message ID for request to AD
static inline uint32_t msgId()
@@ -36,6 +41,18 @@ static inline uint32_t msgId()
// Setting up server(s)
+void server_setPort(const char *server, const char *portStr)
+{
+ server_t *entry = server_create(server);
+ if (entry == NULL) return;
+ int port = atoi(portStr);
+ if (port < 1 || port > 65535) {
+ printf("Warning: Invalid port '%s' for '%s'\n", portStr, server);
+ return;
+ }
+ entry->port = (uint16_t)port;
+}
+
void server_setBind(const char *server, const char *bind)
{
server_t *entry = server_create(server);
@@ -80,6 +97,47 @@ void server_setHomeTemplate(const char *server, const char *hometemplate)
if (count > 5) printf("WARNING: Too many '%%' in Home Template for %s. Don't forget to replace literal '%%' with '%%%%'\n", server);
}
+void server_setFingerprint(const char *server, const char *fingerprint)
+{
+ server_t *entry = server_create(server);
+ if (entry == NULL || entry->sslContext != NULL) return;
+ int chars = 0, val = -1;
+ while (*fingerprint != '\0' && chars / 2 < FINGERPRINTLEN) {
+ if (*fingerprint == ':' || *fingerprint == ' ') {
+ fingerprint++;
+ continue;
+ }
+ val = -1;
+ if (*fingerprint >= '0' && *fingerprint <= '9') {
+ val = *fingerprint - '0';
+ } else if (*fingerprint >= 'a' && *fingerprint <= 'f') {
+ val = *fingerprint - 'a' + 10;
+ } else if (*fingerprint >= 'A' && *fingerprint <= 'F') {
+ val = *fingerprint - 'A' + 10;
+ } else {
+ break;
+ }
+ if (chars % 2 == 0) {
+ entry->fingerprint[chars / 2] |= val << 4;
+ } else {
+ entry->fingerprint[chars / 2] |= val;
+ }
+ fingerprint++;
+ chars++;
+ }
+ if (chars / 2 != FINGERPRINTLEN || val == -1) {
+ printf("Warning: Fingerprint for %s is invalid (adsha1 should be a SHA-1 hash of the cert in hex representation.)\n", server);
+ return;
+ }
+ printf("Using fingerprint ");
+ for (int i = 0; i < FINGERPRINTLEN - 1; ++i) {
+ printf("%02x:", (int)entry->fingerprint[i]);
+ }
+ printf("%02x for %s\n", (int)entry->fingerprint[FINGERPRINTLEN-1], server);
+ ssl_init();
+ entry->sslContext = ssl_newClientCtx();
+}
+
BOOL server_initServers()
{
int i;
@@ -134,27 +192,10 @@ uint32_t server_tryUserBind(server_t *server, struct string *binddn, struct stri
con->dynamic = TRUE;
printf("Connecting to AD '%s' for %.*ss bind...\n", server->addr, (int)binddn->l, binddn->s);
con->sbPos = con->sbFill = 0;
- int sock;
- if (server->lastLookup + 300 < time(NULL)) {
- sock = helper_connect4(server->addr, AD_PORT, server->ip);
- if (sock == -1) {
- printf("[ADB] Could not resolve/connect to AD server %s\n", server->addr);
- server_free(con);
- return 0;
- }
- } else {
- sock = socket_tcp4b();
- if (sock == -1) {
- printf("[ADB] Could not allocate socket for connection to AD\n");
- server_free(con);
- return 0;
- }
- if (socket_connect4(sock, server->ip, AD_PORT) == -1) {
- printf("[ADB] Could not connect to cached IP of %s\n", server->addr);
- close(sock);
- server_free(con);
- return 0;
- }
+ int sock = server_connectInternal(server);
+ if (sock == -1) {
+ server_free(con);
+ return 0;
}
printf("[ADB] Connected, binding....\n");
helper_nonblock(sock);
@@ -165,6 +206,11 @@ uint32_t server_tryUserBind(server_t *server, struct string *binddn, struct stri
server_free(con);
return 0;
}
+ // SSL
+ if (!server_connectSsl(con)) {
+ server_free(con);
+ return 0;
+ }
// Now bind - TODO: SASL (DIGEST-MD5?)
const uint32_t id = msgId();
const size_t bodyLen = fmt_ldapbindrequeststring(NULL, 3, binddn, password);
@@ -211,8 +257,14 @@ static server_t *server_create(const char *server)
static void server_free(epoll_server_t *server)
{
server->bound = FALSE;
- if (server->fd != -1) close(server->fd);
- server->fd = -1;
+ if (server->ssl != NULL) {
+ SSL_free(server->ssl);
+ server->ssl = NULL;
+ }
+ if (server->fd != -1) {
+ close(server->fd);
+ server->fd = -1;
+ }
server->sbPos = server->sbFill = 0;
if (server->dynamic) {
printf("Freeing Bind-AD-Connection\n");
@@ -224,61 +276,97 @@ static void server_free(epoll_server_t *server)
static void server_callback(void *data, int haveIn, int haveOut, int doCleanup)
{
epoll_server_t *server = (epoll_server_t *)data;
- if (doCleanup) {
+ if (doCleanup || server->kill) {
server_free(server);
return;
}
- if (haveIn) {
- for (;;) {
- if (server->rbPos >= MAXMSGLEN) {
- printf("[AD->Proxy] Read buffer overflow. Disconnecting.\n");
- server_free(server);
- return;
- }
- const size_t buflen = MAXMSGLEN - server->rbPos;
- const ssize_t ret = read(server->fd, server->readBuffer + server->rbPos, buflen);
+ if (server->ssl == NULL) {
+ // Plain connection
+ if (haveIn) server_haveIn(server);
+ if (haveOut) server_haveOut(server);
+ if (server->kill) server_free(server);
+ return;
+ }
+ // SSL
+ if (!server->sslConnected) {
+ // Still SSL-Connecting
+ if (!ssl_connectServer(server)) {
+ printf("SSL Server connect failed!\n");
+ server_free(server);
+ return;
+ }
+ if (!server->sslConnected) return;
+ }
+ // Since we don't know if the incoming data is just wrapped application data or ssl protocol stuff, we always call both
+ server_haveIn(server);
+ server_haveOut(server);
+ if (server->kill) server_free(server);
+}
+
+static void server_haveIn(epoll_server_t *server)
+{
+ for (;;) {
+ if (server->rbPos >= MAXMSGLEN) {
+ printf("[AD->Proxy] Read buffer overflow. Disconnecting.\n");
+ server->kill = TRUE;
+ return;
+ }
+ const size_t buflen = MAXMSGLEN - server->rbPos;
+ ssize_t ret;
+ if (server->ssl == NULL) {
+ // Plain
+ ret = read(server->fd, server->readBuffer + server->rbPos, buflen);
printf("AD read %d (err %d)\n", (int)ret, errno);
if (ret < 0 && errno == EINTR) continue;
if (ret < 0 && errno == EAGAIN) break;
if (ret <= 0) {
- printf("AD gone while reading.\n");
- server_free(server);
+ printf("AD Server gone while reading.\n");
+ server->kill = TRUE;
return;
}
- server->rbPos += ret;
- // Request complete?
- for (;;) {
- size_t consumed, len;
- consumed = scan_asn1SEQUENCE(server->readBuffer, server->readBuffer + server->rbPos, &len);
- if (consumed == 0) break; // Length-Header not complete
- len += consumed;
- if (len > server->rbPos) break; // Body not complete
- printf("[AD] Received complete reply...\n");
- if (!proxy_fromServer(server, len)) {
- if (server->dynamic) {
- server_free(server);
- return;
- }
- printf("Error parsing reply from AD.\n");
- // Let's try to go on with the next message....
- }
- // Shift remaining buffer contents
- if (len == server->rbPos) {
- server->rbPos = 0;
- break;
+ } else {
+ // SSL
+ ret = SSL_read(server->ssl, server->readBuffer + server->rbPos, buflen);
+ if (ret <= 0) {
+ int err = SSL_get_error(server->ssl, ret);
+ if (SSL_BLOCKED(err)) break;
+ printf("AD Server gone while reading (%d, %d).\n", (int)ret, err);
+ server->kill = TRUE;
+ return;
+ }
+ }
+ server->rbPos += ret;
+ // Request complete?
+ for (;;) {
+ size_t consumed, len;
+ consumed = scan_asn1SEQUENCE(server->readBuffer, server->readBuffer + server->rbPos, &len);
+ if (consumed == 0) break; // Length-Header not complete
+ len += consumed;
+ if (len > server->rbPos) break; // Body not complete
+ printf("[AD] Received complete reply...\n");
+ if (!proxy_fromServer(server, len)) {
+ if (server->dynamic) {
+ server->kill = TRUE;
+ return;
}
- memmove(server->readBuffer, server->readBuffer + len, server->rbPos - len);
- server->rbPos -= len;
+ printf("Error parsing reply from AD.\n");
+ // Let's try to go on with the next message....
}
- if ((ssize_t)buflen > ret) break; // Read less than buffer len, epoll will fire again
+ // Shift remaining buffer contents
+ if (len == server->rbPos) {
+ server->rbPos = 0;
+ break;
+ }
+ memmove(server->readBuffer, server->readBuffer + len, server->rbPos - len);
+ server->rbPos -= len;
}
+ if ((ssize_t)buflen > ret) break; // Read less than buffer len, epoll will fire again
}
- if (haveOut) server_flush(server);
}
BOOL server_send(epoll_server_t *server, const char *buffer, size_t len, const BOOL cork)
{
- if (server->sbFill == 0 && !cork) {
+ if (server->ssl == NULL && server->sbFill == 0 && !cork) {
// Nothing in send buffer, fire away
const int ret = write(server->fd, buffer, len);
if (ret == 0 || (ret < 0 && errno != EINTR && errno != EAGAIN)) {
@@ -295,28 +383,54 @@ BOOL server_send(epoll_server_t *server, const char *buffer, size_t len, const B
}
}
// Buffer...
- server_ensureSendBuffer(server, len);
+ if (!server_ensureSendBuffer(server, len)) {
+ server->kill = TRUE;
+ return FALSE;
+ }
// Finally append to buffer
memcpy(server->sendBuffer + server->sbFill, buffer, len);
server->sbFill += len;
- if (!cork) server_flush(server);
+ if (!cork) server_haveOut(server);
return TRUE;
}
-static void server_flush(epoll_server_t * const server)
+static void server_haveOut(epoll_server_t * const server)
{
while (server->sbPos < server->sbFill) {
- const int tosend = server->sbFill - server->sbPos;
- const int ret = write(server->fd, server->sendBuffer + server->sbPos, tosend);
- if (ret < 0 && errno == EINTR) continue;
- if (ret < 0 && errno == EAGAIN) return;
- if (ret <= 0) {
- printf("Connection to AD Server failed while flushing (ret: %d, errno: %d)\n", ret, errno);
- return;
+ const ssize_t tosend = server->sbFill - server->sbPos;
+ ssize_t ret;
+ if (server->ssl == NULL) {
+ // Plain
+ ret = write(server->fd, server->sendBuffer + server->sbPos, tosend);
+ if (ret < 0 && errno == EINTR) continue;
+ if (ret < 0 && errno == EAGAIN) return;
+ if (ret <= 0) {
+ printf("Connection to AD Server failed while flushing (ret: %d, errno: %d)\n", (int)ret, errno);
+ return;
+ }
+ } else {
+ // SSL
+ ret = SSL_write(server->ssl, server->sendBuffer + server->sbPos, tosend);
+ if (ret <= 0) {
+ int err = SSL_get_error(server->ssl, ret);
+ if (SSL_BLOCKED(err)) {
+ server->writeBlocked = TRUE;
+ return; // Blocking
+ }
+ printf("SSL server gone while sending (%d)\n", err);
+ ERR_print_errors_fp(stdout);
+ server->kill = TRUE;
+ return; // Closed
+ }
}
server->lastActive = time(NULL);
server->sbPos += ret;
- if (ret != tosend) return;
+ if (server->ssl != NULL) {
+ memmove(server->sendBuffer, server->sendBuffer + server->sbPos, server->sbFill - server->sbPos);
+ server->sbFill -= server->sbPos;
+ server->sbPos = 0;
+ }
+ if (server->ssl == NULL && ret != tosend) return;
}
server->sbPos = server->sbFill = 0;
}
@@ -329,25 +443,8 @@ static BOOL server_ensureConnected(server_t *server)
con->bound = FALSE;
printf("Connecting to AD '%s'...\n", server->addr);
con->sbPos = con->sbFill = 0;
- int sock;
- if (server->lastLookup + 300 < time(NULL)) {
- sock = helper_connect4(server->addr, AD_PORT, server->ip);
- if (sock == -1) {
- printf("Could not resolve/connect to AD server %s\n", server->addr);
- return FALSE;
- }
- } else {
- sock = socket_tcp4b();
- if (sock == -1) {
- printf("Could not allocate socket for connection to AD\n");
- return FALSE;
- }
- if (socket_connect4(sock, server->ip, AD_PORT) == -1) {
- printf("Could not connect to cached IP of %s\n", server->addr);
- close(sock);
- return FALSE;
- }
- }
+ int sock = server_connectInternal(server);
+ if (sock == -1) return FALSE;
printf("Connected, binding....\n");
helper_nonblock(sock);
con->fd = sock;
@@ -358,6 +455,12 @@ static BOOL server_ensureConnected(server_t *server)
con->fd = -1;
return FALSE;
}
+ // SSL
+ if (!server_connectSsl(con)) {
+ close(con->fd);
+ con->fd = -1;
+ return FALSE;
+ }
// Now bind - TODO: SASL (DIGEST-MD5?)
const size_t bodyLen = fmt_ldapbindrequest(NULL, 3, server->bind, server->password);
const size_t headerLen = fmt_ldapmessage(NULL, MSGID_BIND, BindRequest, bodyLen);
@@ -375,18 +478,71 @@ static BOOL server_ensureConnected(server_t *server)
return TRUE;
}
-static void server_ensureSendBuffer(epoll_server_t * const s, const size_t len)
+static BOOL server_ensureSendBuffer(epoll_server_t * const s, const size_t len)
{
- if (len > 1000000) bail("server_ensureSendBuffer: request too large!");
+ if (len > 1000000) {
+ printf("server_ensureSendBuffer: request too large!\n");
+ return FALSE;
+ }
if (s->sbLen - s->sbFill < len) {
+ if (s->writeBlocked) {
+ printf("SSL Write blocked and buffer to small (%d)\n", (int)s->sbLen);
+ return FALSE;
+ }
if (s->sbPos != 0) {
memmove(s->sendBuffer, s->sendBuffer + s->sbPos, s->sbFill - s->sbPos);
s->sbFill -= s->sbPos;
s->sbPos = 0;
}
if (s->sbLen - s->sbFill < len) {
- helper_realloc(&s->sendBuffer, &s->sbLen, s->sbLen + len + 1000, "server_ensureSendBuffer");
+ if (helper_realloc(&s->sendBuffer, &s->sbLen, s->sbLen + len + (s->ssl == NULL ? 1000 : 6000), "server_ensureSendBuffer") == -1) {
+ return FALSE;
+ }
}
}
+ return TRUE;
+}
+
+static int server_connectInternal(server_t *server)
+{
+ int sock;
+ const uint16_t port = server->port != 0 ? server->port : (server->sslContext == NULL ? AD_PORT : AD_PORT_SSL);
+ if (server->lastLookup + 300 < time(NULL)) {
+ sock = helper_connect4(server->addr, port, server->ip);
+ if (sock == -1) {
+ printf("Could not resolve/connect to AD server %s\n", server->addr);
+ return -1;
+ }
+ } else {
+ sock = socket_tcp4b();
+ if (sock == -1) {
+ printf("Could not allocate socket for connection to AD\n");
+ return -1;
+ }
+ if (socket_connect4(sock, server->ip, port) == -1) {
+ printf("Could not connect to cached IP of %s\n", server->addr);
+ server->lastLookup = 0;
+ close(sock);
+ return -1;
+ }
+ }
+ return sock;
+}
+
+static BOOL server_connectSsl(epoll_server_t *server)
+{
+ if (server->serverData->sslContext == NULL) return TRUE;
+ server->ssl = ssl_new(server->fd, server->serverData->sslContext);
+ if (server->ssl == NULL) {
+ printf("Could not get SSL client from context\n");
+ return FALSE;
+ }
+ if (!ssl_connectServer(server)) {
+ printf("SSL connect failed.\n");
+ SSL_free(server->ssl);
+ server->ssl = NULL;
+ return FALSE;
+ }
+ return TRUE;
}
diff --git a/server.h b/server.h
index d2d84ef..6c4d889 100644
--- a/server.h
+++ b/server.h
@@ -6,6 +6,8 @@
struct string;
struct SearchRequest;
+void server_setPort(const char *server, const char *portStr);
+
void server_setBind(const char *server, const char *bind);
void server_setPassword(const char *server, const char *password);
@@ -14,6 +16,8 @@ void server_setBase(const char *server, const char *base);
void server_setHomeTemplate(const char *server, const char *hometemplate);
+void server_setFingerprint(const char *server, const char *fingerprint);
+
BOOL server_initServers();
BOOL server_send(epoll_server_t *server, const char *buffer, size_t len, const BOOL cork);
diff --git a/types.h b/types.h
index 71cd36d..373b52a 100644
--- a/types.h
+++ b/types.h
@@ -12,21 +12,28 @@
#define BASELEN 250
#define SIDLEN 28
#define MOUNTLEN 100
+#define FINGERPRINTLEN 20
#define REQLEN 4000
#define MAXMSGLEN 100000
#define BOOL uint8_t
-#define TRUE 1
-#define FALSE 0
+#define TRUE (1)
+#define FALSE (0)
typedef struct _server_t_ server_t;
+/**
+ * General epoll struct, to be implemented by every epoll struct.
+ */
typedef struct {
void (*callback)(void *data, int haveIn, int haveOut, int doCleanup);
int fd;
} epoll_item_t;
+/**
+ * epoll struct for listening sockets.
+ */
typedef struct {
void (*callback)(void *data, int haveIn, int haveOut, int doCleanup);
int fd;
@@ -34,6 +41,9 @@ typedef struct {
SSL_CTX *sslContext; // Listening for SSL connections, NULL otherwise
} epoll_listen_t;
+/**
+ * epoll struct for a client we're serving.
+ */
typedef struct {
void (*callback)(void *data, int haveIn, int haveOut, int doCleanup);
int fd;
@@ -51,22 +61,32 @@ typedef struct {
char readBuffer[REQLEN]; // Static, queries > 4000 bytes simply not supported
} epoll_client_t;
+/**
+ * epoll struct for a connection to AD.
+ */
typedef struct {
void (*callback)(void *data, int haveIn, int haveOut, int doCleanup);
int fd;
+ //
// Send buffer (me to server)
size_t sbPos, sbFill, sbLen;
+ SSL *ssl; // NULL if not encrypted
char *sendBuffer; // Dynamically allocated, might or might not get huge
// Recv buffer (server's response)
size_t rbPos;
char readBuffer[MAXMSGLEN];
- BOOL bound;
+ BOOL bound; // Already bound to server?
BOOL dynamic;
- //unsigned long messageId; // ID of message currently being received
+ BOOL sslConnected;
+ BOOL kill; // Should the connection be killed?
+ BOOL writeBlocked; // An SSL_write returned WANT_*, so we must not reallocate the current send buffer
time_t lastActive;
server_t *serverData;
} epoll_server_t;
+/**
+ * Configuration data for an ADS we're proxying.
+ */
struct _server_t_ {
size_t baseLen;
char ip[4];
@@ -77,6 +97,9 @@ struct _server_t_ {
char base[BASELEN];
char sid[SIDLEN];
char homeTemplate[MOUNTLEN];
+ unsigned char fingerprint[FINGERPRINTLEN];
+ uint16_t port;
+ SSL_CTX *sslContext;
epoll_server_t con;
};