summaryrefslogtreecommitdiffstats
path: root/server.c
diff options
context:
space:
mode:
Diffstat (limited to 'server.c')
-rw-r--r--server.c107
1 files changed, 94 insertions, 13 deletions
diff --git a/server.c b/server.c
index 54fd154..9bf9e58 100644
--- a/server.c
+++ b/server.c
@@ -33,6 +33,7 @@ static BOOL server_ensureConnected(server_t *server);
static BOOL server_ensureSendBuffer(epoll_server_t * const s, const size_t len);
static int server_connectInternal(server_t *server);
static BOOL server_connectSsl(epoll_server_t *server);
+static BOOL server_handleStartTlsResponse(epoll_server_t *server, const size_t len);
// Generate a message ID for request to AD
static inline uint32_t msgId()
@@ -93,6 +94,14 @@ void server_setGenUidNumber(const char *server, const char *enabledStr)
plog(DEBUG_VERBOSE, "Using UID mapping for %s: %d", server, (int)entry->genUidNumber);
}
+void server_setUseStartTls(const char *server, const char *enabledStr)
+{
+ server_t *entry = server_create(server);
+ if (entry == NULL) return;
+ entry->useStartTls = parseBool(enabledStr);
+ plog(DEBUG_VERBOSE, "STARTTLS for %s: %d", server, (int)entry->useStartTls);
+}
+
static void strtolower(char *str)
{
while (*str != '\0') {
@@ -167,12 +176,14 @@ void server_setCaBundle(const char *server, const char *file)
server_t *entry = server_create(server);
if (entry == NULL) return;
if (file == NULL || *file == '\0') return;
- int fh = open(file, O_RDONLY);
- if (fh == -1) {
- printf("Error: cabundle '%s' not readable.\n", file);
- exit(1);
+ if (strcmp(file, "*") != 0) {
+ int fh = open(file, O_RDONLY);
+ if (fh == -1) {
+ printf("Error: cabundle '%s' not readable.\n", file);
+ exit(1);
+ }
+ close(fh);
}
- close(fh);
if (snprintf(entry->cabundle, MAXPATH, "%s", file) >= MAXPATH) printf("Warning: CaBundle for %s is too long.\n", server);
ssl_init();
}
@@ -201,7 +212,7 @@ void server_setHomeTemplate(const char *server, const char *hometemplate)
void server_setHomeAttribute(const char *server, const char *homeattribute)
{
server_t *entry = server_create(server);
- if (entry == NULL || entry->sslContext != NULL) return;
+ if (entry == NULL) return;
free((void*)entry->map.homemount.s);
char *tmp = strdup(homeattribute);
strtolower(tmp);
@@ -238,7 +249,7 @@ void server_setFingerprint(const char *server, const char *fingerprint)
chars++;
}
if (chars / 2 != FINGERPRINTLEN || val == -1) {
- printf("Warning: Fingerprint for %s is invalid (adsha1 should be a SHA-1 hash of the cert in hex representation.)\n", server);
+ printf("Warning: Fingerprint for %s is invalid (should be a SHA-1 hash of the cert in hex representation.)\n", server);
return;
}
printf("Using fingerprint ");
@@ -284,6 +295,10 @@ BOOL server_initServers()
if (server->cabundle[0] != '\0' || memcmp(server->fingerprint, "\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) != 0) {
// Have cabundle or fingerprint - use SSL to talk to server
server->sslContext = ssl_newClientCtx(server->cabundle);
+ if (strcmp(server->cabundle, "*") == 0) {
+ server->cabundle[0] = '\0';
+ }
+ printf("Using SSL for Uplink\n");
}
printf("%s:\n Bind: %s\n Base: %s\n", server->addr, server->bind, server->base);
printf("Plain LDAP-LDAP: %d\n", (int)server->plainLdap);
@@ -363,6 +378,7 @@ uint32_t server_tryUserBind(server_t *server, struct string *binddn, struct stri
con->fd = -1;
con->bound = FALSE;
con->dynamic = TRUE;
+ con->startTlsId = 0;
con->sbPos = con->sbFill = 0;
int sock = server_connectInternal(server);
if (sock == -1) {
@@ -444,6 +460,7 @@ void server_free(epoll_server_t *server)
proxy_removeServer(server);
server->bound = FALSE;
server->kill = FALSE;
+ server->startTlsId = 0;
if (server->ssl != NULL) {
SSL_free(server->ssl);
server->ssl = NULL;
@@ -530,7 +547,13 @@ static void server_haveIn(epoll_server_t *server)
len += consumed;
if (len > server->rbPos) break; // Body not complete
//printf("[AD] Received complete reply (need %d, have %d)...\n", (int)len, (int)server->rbPos);
- if (!proxy_fromServer(server, len)) {
+ // Complete STARTTLS
+ if (server->ssl == NULL && server->startTlsId != 0) {
+ if (!server_handleStartTlsResponse(server, len)) {
+ server->kill = TRUE;
+ return;
+ }
+ } else if (!proxy_fromServer(server, len)) {
if (server->dynamic) {
server->kill = TRUE;
return;
@@ -588,6 +611,7 @@ static BOOL server_haveOut(epoll_server_t * const server)
{
if (server->kill) return FALSE;
if (server->ssl != NULL && !server->sslConnected) return TRUE;
+ if (server->ssl == NULL && server->startTlsId != 0) return TRUE; // We asked the server for TLS, don't send anything until we get a reply
// Bind not sent/acknowledged yet - send bind if pending, otherwise do nothing
if (!server->bound) {
if (server->bindLen == 0) return TRUE;
@@ -745,17 +769,17 @@ static int server_connectInternal(server_t *server)
if (server->lastLookup + 300 < time(NULL)) {
sock = helper_connect4(server->addr, port, server->ip);
if (sock == -1) {
- printf("[Proxy] Could not resolve hostname or connect to AD server %s (errno=%d)\n", server->addr, errno);
+ printf("[Server] Could not resolve hostname or connect to AD server %s (errno=%d)\n", server->addr, errno);
return -1;
}
} else {
sock = helper_newSocket();
if (sock == -1) {
- printf("[Proxy] Could not allocate socket for connection to AD server %s (errno=%d)\n", server->addr, errno);
+ printf("[Server] Could not allocate socket for connection to AD server %s (errno=%d)\n", server->addr, errno);
return -1;
}
if (socket_connect4(sock, server->ip, port) == -1) {
- printf("[Proxy] Could not connect to cached IP (%s) of %s (errno=%d)\n", server->ip, server->addr, errno);
+ printf("[Server] Could not connect to cached IP (%s) of %s (errno=%d)\n", server->ip, server->addr, errno);
server->lastLookup = 0;
close(sock);
return -1;
@@ -768,13 +792,29 @@ static BOOL server_connectSsl(epoll_server_t *server)
{
if (server->serverData->sslContext == NULL) return TRUE;
server->sslConnected = FALSE;
+ if (server->serverData->useStartTls && server->startTlsId == 0) {
+ static const char oid[] = "1.3.6.1.4.1.1466.20037";
+ static const size_t oidlen = sizeof(oid) - 1;
+ char buf[200];
+ size_t len;
+ uint32_t mid = msgId();
+ len = fmt_asn1string(NULL, CONTEXT_SPECIFIC, 0, 0, oid, oidlen);
+ len = fmt_ldapmessage(buf, mid, ExtendedRequest, len);
+ len += fmt_asn1string(buf + len, CONTEXT_SPECIFIC, 0, 0, oid, oidlen);
+ if (!server_send(server, buf, len, FALSE)) {
+ printf("[Server] Sending STARTTLS request failed\n");
+ return FALSE;
+ }
+ server->startTlsId = mid;
+ return TRUE;
+ }
server->ssl = ssl_new(server->fd, server->serverData->sslContext);
if (server->ssl == NULL) {
- printf("[Proxy] Could not get SSL client from context\n");
+ printf("[Server] Could not get SSL client from context\n");
return FALSE;
}
if (!ssl_connectServer(server)) {
- printf("[Proxy] SSL connect to AD server %s failed.\n", server->serverData->addr);
+ printf("[Server] SSL connect to AD server %s failed.\n", server->serverData->addr);
SSL_free(server->ssl);
server->ssl = NULL;
return FALSE;
@@ -782,3 +822,44 @@ static BOOL server_connectSsl(epoll_server_t *server)
return TRUE;
}
+static BOOL server_handleStartTlsResponse(epoll_server_t *server, const size_t maxLen)
+{
+ unsigned long messageId, op;
+ size_t len;
+ const size_t res = scan_ldapmessage(server->readBuffer, server->readBuffer + maxLen, &messageId, &op, &len);
+ if (res == 0) {
+ printf("[Server] Error parsing STARTLS reply\n");
+ return FALSE;
+ }
+ if (messageId != server->startTlsId) {
+ printf("[Server] STARTTLS reply doesn't have expected message id\n");
+ return FALSE;
+ }
+ if (op != ExtendedResponse) {
+ printf("[Server] STARTTLS reply doesn't have op ExtendedResponse\n");
+ return FALSE;
+ }
+ enum asn1_tagclass tc;
+ enum asn1_tagtype tt;
+ unsigned long tag;
+ long value;
+ len = scan_asn1int(server->readBuffer + res, server->readBuffer + maxLen, &tc, &tt, &tag, &value);
+ if (len == 0) {
+ printf("[Server] Could not read resultCode int from STARTTLS reply\n");
+ return FALSE;
+ }
+ if (tag != ENUMERATED && tag != INTEGER) {
+ printf("[Server] STARTTLS reply doesn't contain resultCode\n");
+ return FALSE;
+ }
+ if (value != 0) {
+ printf("[Server] STARTTLS reply has resultCode != 0\n");
+ return FALSE;
+ }
+ if (!server_connectSsl(server)) {
+ printf("[Server] ...after successful STARTTLS reply\n");
+ return FALSE;
+ }
+ return TRUE;
+}
+