summaryrefslogblamecommitdiffstats
path: root/server.c
blob: 5ec6148a2c036ed9a09cb52e2f12be8c2d52ff27 (plain) (tree)
1
2
3
4
5
6




                      
                    







                   
                    
                       

                    





                                                   
                                                
                                                                                

                                                          
                                                     


                                                                                










                                          











                                                                                 






















                                                                                                                               




















                                                                                                                                                    








































                                                                                                                                              
                         



                                                        
                                                                                                           

                                                         
         
                    



                       
                                               




                                                                                     
                                                                                                               
                                           

                                                                                              
                                           

                 
                    

 
                                                                          








                                                                                      
                                                



                                                                       








                                                                                                     



                                                  









                                                                                  




                                      
















                                                                                    





















                                                                             
                                                                    





                                               







                                      
                                           




                                                       




                                                                               
                                        


                                    



































                                                                                                                               



                                                                         

                                                                          

                                       























                                                                                                                   
                                 

                                                                               
                         






                                                                                                   
                 
                                                                                                     
         

 
                                                                                         
 
                                                                  



                                                                                 
                                     

                                                
                                                 







                                                                                



                                                    


                                                                 
                                          
                    

 
                                                         

                                                
























                                                                                                                                


                                                





                                                                                                                        



                                           
                                                    
 



                                                                             
                                                           
                                     

                                                  









                                                                                





                                      
                                              
                                                                                            
                                                                                         













                                                                              
                                                                               
 



                                                                        
                                         



                                                                                              





                                                                                               


                                                                                                                                                          

                 











































                                                                                                                      

 
#include "server.h"
#include "proxy.h"
#include "helper.h"
#include "epoll.h"
#include "tmpbuffer.h"
#include "openssl.h"
#include <time.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include <socket.h>

// Default ports used by server_connectInternal when no port was configured
// for a server: AD_PORT for plain connections, AD_PORT_SSL when the server
// has an SSL context (i.e. a valid fingerprint was configured).
#define AD_PORT 3268
#define AD_PORT_SSL 636
// Fixed LDAP message ID of the service bind sent by server_ensureConnected.
// msgId() never returns values below 2, so its IDs cannot collide with this one.
#define MSGID_BIND 1

// Server table: allocated lazily by server_init() with room for MAX_SERVERS
// entries, of which the first serverCount are in use.
#define MAX_SERVERS 10
static server_t *servers = NULL;
static int serverCount = 0;

// File-local helpers (defined further below).
static void server_init();
static server_t *server_create(const char *server);
static void server_free(epoll_server_t *server);
static void server_callback(void *data, int haveIn, int haveOut, int doCleanup);
static void server_haveIn(epoll_server_t *server);
static void server_haveOut(epoll_server_t * const server);
static BOOL server_ensureConnected(server_t *server);
static BOOL server_ensureSendBuffer(epoll_server_t * const s, const size_t len);
static int server_connectInternal(server_t *server);
static BOOL server_connectSsl(epoll_server_t *server);

// Generate a message ID for a request to AD.
// IDs start at 1337 and increase by one per call. The values 0 and 1 are never
// returned (1 is MSGID_BIND). The counter wraps back to 2 instead of exceeding
// 2^31-1, the largest MessageID LDAP permits (RFC 4511, MessageID ::= INTEGER
// (0 .. maxInt)).
// Uses an unsynchronized function-local static counter.
static inline uint32_t msgId(void)
{
	static uint32_t id = 1336;
	if (++id < 2 || id > (uint32_t)INT32_MAX) id = 2;
	return id;
}

// Setting up server(s)

void server_setPort(const char *server, const char *portStr)
{
	server_t *entry = server_create(server);
	if (entry == NULL) return;
	int port = atoi(portStr);
	if (port < 1 || port > 65535) {
		printf("Warning: Invalid port '%s' for '%s'\n", portStr, server);
		return;
	}
	entry->port = (uint16_t)port;
}

// Store the bind DN used for the service bind of the given server, creating
// the server entry if needed. Overlong values are truncated to BINDLEN - 1
// characters and a warning is printed.
void server_setBind(const char *server, const char *bind)
{
	server_t * const target = server_create(server);
	if (target != NULL) {
		const int needed = snprintf(target->bind, BINDLEN, "%s", bind);
		if (needed >= BINDLEN) {
			printf("Warning: BindDN for %s is too long.\n", server);
		}
	}
}

// Store the password used for the service bind of the given server, creating
// the server entry if needed. Overlong values are truncated to PWLEN - 1
// characters and a warning is printed.
void server_setPassword(const char *server, const char *password)
{
	server_t * const target = server_create(server);
	if (target != NULL) {
		const int needed = snprintf(target->password, PWLEN, "%s", password);
		if (needed >= PWLEN) {
			printf("Warning: BindPW for %s is too long.\n", server);
		}
	}
}

// Store the search base of the given server, creating the entry if needed.
// The value is truncated to BASELEN - 1 characters (with a warning) and then
// normalized in place with normalize_dn. baseLen records the normalized
// length, which server_getFromBase uses for suffix matching.
void server_setBase(const char *server, const char *base)
{
	server_t * const target = server_create(server);
	if (target == NULL) {
		return;
	}
	const int needed = snprintf(target->base, BASELEN, "%s", base);
	if (needed >= BASELEN) {
		printf("Warning: SearchBase for %s is too long.\n", server);
	}
	const size_t rawLen = strlen(target->base);
	target->baseLen = normalize_dn(target->base, target->base, min(rawLen, BASELEN - 1));
	target->base[target->baseLen] = '\0';
}

// Store the home directory template of the given server, creating the entry
// if needed. Per the TODO below, the template is used as a format string
// elsewhere, so it is sanitized here:
//  - backslashes become forward slashes;
//  - at most MAX_CONVERSIONS conversions are kept ("%%" is a literal percent
//    and does not count). Everything from the first surplus conversion on,
//    including the '%' that introduces it, is overwritten with '_';
//  - a lone '%' at the very end (e.g. left over from truncation) becomes '_'.
void server_setHomeTemplate(const char *server, const char *hometemplate)
{
	enum { MAX_CONVERSIONS = 5 };
	server_t *entry = server_create(server);
	if (entry == NULL) return;
	if (snprintf(entry->homeTemplate, MOUNTLEN, "%s", hometemplate) >= MOUNTLEN) printf("Warning: Home Template for %s is too long.\n", server);
	// TODO: Better template system. Using a format string is too lazy
	BOOL afterPercent = FALSE;
	char *s = entry->homeTemplate;
	int count = 0;
	while (*s) {
		if (afterPercent) {
			if (*s != '%') {
				count++;
				// Also neutralize the '%' that introduced a surplus conversion;
				// otherwise "%_" would remain, which is an invalid conversion.
				if (count > MAX_CONVERSIONS) s[-1] = '_';
			}
			afterPercent = FALSE;
		} else if (*s == '%') {
			afterPercent = TRUE;
		}
		if (count > MAX_CONVERSIONS) *s = '_';
		if (*s == '\\') *s = '/';
		s++;
	}
	if (afterPercent) {
		// A trailing lone '%' is an incomplete conversion specification.
		s[-1] = '_';
		printf("Warning: Home Template for %s ends with a single '%%'; replaced with '_'.\n", server);
	}
	if (count > MAX_CONVERSIONS) printf("WARNING: Too many '%%' in Home Template for %s. Don't forget to replace literal '%%' with '%%%%'\n", server);
}

// Configure the expected certificate fingerprint of the given server (creating
// its entry if needed). If it parses, create the SSL client context that
// makes connections to this server use SSL.
// The fingerprint is FINGERPRINTLEN bytes written as hex digits; ':' and ' '
// separators are ignored, and trailing whitespace is allowed. Anything else
// (bad digit, too short, too long) prints a warning and leaves SSL disabled.
// Once an SSL context exists, further calls do nothing.
void server_setFingerprint(const char *server, const char *fingerprint)
{
	server_t *entry = server_create(server);
	if (entry == NULL || entry->sslContext != NULL) return;
	int chars = 0, val = -1;
	while (*fingerprint != '\0' && chars / 2 < FINGERPRINTLEN) {
		if (*fingerprint == ':' || *fingerprint == ' ') {
			fingerprint++;
			continue;
		}
		val = -1;
		if (*fingerprint >= '0' && *fingerprint <= '9') {
			val = *fingerprint - '0';
		} else if (*fingerprint >= 'a' && *fingerprint <= 'f') {
			val = *fingerprint - 'a' + 10;
		} else if (*fingerprint >= 'A' && *fingerprint <= 'F') {
			val = *fingerprint - 'A' + 10;
		} else {
			break;
		}
		if (chars % 2 == 0) {
			// Assign (rather than OR) the high nibble so bytes left behind by an
			// earlier, rejected call for this server cannot corrupt the value.
			entry->fingerprint[chars / 2] = val << 4;
		} else {
			entry->fingerprint[chars / 2] |= val;
		}
		fingerprint++;
		chars++;
	}
	// Trailing separators/whitespace are fine, but extra digits (e.g. a SHA-256
	// fingerprint) must be rejected instead of silently truncated.
	while (*fingerprint == ':' || *fingerprint == ' ' || *fingerprint == '\t'
			|| *fingerprint == '\r' || *fingerprint == '\n') {
		fingerprint++;
	}
	if (chars / 2 != FINGERPRINTLEN || val == -1 || *fingerprint != '\0') {
		printf("Warning: Fingerprint for %s is invalid (adsha1 should be a SHA-1 hash of the cert in hex representation.)\n", server);
		return;
	}
	printf("Using fingerprint ");
	for (int i = 0; i < FINGERPRINTLEN - 1; ++i) {
		// Cast via unsigned char so bytes >= 0x80 never sign-extend.
		printf("%02x:", (int)(unsigned char)entry->fingerprint[i]);
	}
	printf("%02x for %s\n", (int)(unsigned char)entry->fingerprint[FINGERPRINTLEN-1], server);
	ssl_init();
	entry->sslContext = ssl_newClientCtx();
	if (entry->sslContext == NULL) {
		// Without a context, connections to this server silently fall back to plain text.
		printf("Warning: Could not create SSL context for %s; connection will NOT be encrypted!\n", server);
	}
}

BOOL server_initServers()
{
	int i;
	printf("%d servers configured.\n", serverCount);
	for (i = 0; i < serverCount; ++i) {
		printf("%s:\n  Bind: %s\n  Base: %s\n", servers[i].addr, servers[i].bind, servers[i].base);
		if (!server_ensureConnected(&servers[i]))
			return FALSE;
	}
	return TRUE;
}

// What the proxy calls

server_t *server_getFromBase(struct string *in)
{
	int i;
	char buffer[TMPLEN];
	const size_t searchLen = normalize_dn(buffer, in->s, min(in->l, TMPLEN - 1));
	buffer[searchLen] = '\0';
	// Now buffer contains the normalized wanted bind/domain/whatev. Try to find a match in the server list
	for (i = 0; i < serverCount; ++i) {
		if (searchLen < servers[i].baseLen) continue;
		if (strcmp(servers[i].base, buffer + (searchLen - servers[i].baseLen)) == 0) {
			return &servers[i];
		}
	}
	return NULL;
}

// Send an LDAP SearchRequest for req over the server's persistent,
// service-bound connection, (re)connecting first if necessary.
// Returns the message ID used, or 0 if the request could not be sent.
uint32_t server_searchRequest(server_t *server, struct SearchRequest *req)
{
	if (!server_ensureConnected(server)) return 0;
	const uint32_t msgid = msgId();
	const size_t bodyLen = fmt_ldapsearchrequest(NULL, req);
	const size_t headerLen = fmt_ldapmessage(NULL, msgid, SearchRequest, bodyLen);
	// The LDAPMessage header is written directly in front of the body, into the
	// 50 spare bytes at the start of buffer. A longer header would write below
	// the start of buffer; the bind functions guard against this the same way.
	if (headerLen >= 50) {
		printf("[AD] search request too long for %s\n", server->addr);
		return 0;
	}
	char buffer[bodyLen + 50];
	char *bufoff = buffer + 50;
	fmt_ldapsearchrequest(bufoff, req);
	fmt_ldapmessage(bufoff - headerLen, msgid, SearchRequest, bodyLen);
	epoll_server_t * const s = &server->con;
	// Don't hand out an ID for a request that never left: the caller would wait
	// for a reply forever.
	if (!server_send(s, bufoff - headerLen, headerLen + bodyLen, FALSE)) return 0;
	return msgid;
}

// Open a separate connection to the given AD server and send a simple bind
// request for binddn/password on it.
// The connection object is heap-allocated with dynamic = TRUE. server_free
// releases it entirely, either when the epoll callback cleans it up or kills
// it, or right here on any setup failure.
// Returns the message ID of the bind request, or 0 on failure.
uint32_t server_tryUserBind(server_t *server, struct string *binddn, struct string *password)
{
	epoll_server_t *con = calloc(1, sizeof *con);
	if (con == NULL) {
		printf("[ADB] Out of memory while connecting to %s\n", server->addr);
		return 0;
	}
	con->serverData = server;
	con->fd = -1;
	con->bound = FALSE;
	con->dynamic = TRUE;
	printf("Connecting to AD '%s' for %.*ss bind...\n", server->addr, (int)binddn->l, binddn->s);
	con->sbPos = con->sbFill = 0;
	int sock = server_connectInternal(server);
	if (sock == -1) {
		server_free(con);
		return 0;
	}
	printf("[ADB] Connected, binding....\n");
	helper_nonblock(sock);
	con->fd = sock;
	con->callback = &server_callback;
	if (ePoll_add(EPOLLIN | EPOLLOUT | EPOLLET, (epoll_item_t*)con) == -1) {
		printf("[ADB] epoll_add failed for ad server %s\n", server->addr);
		server_free(con);
		return 0;
	}
	// SSL
	if (!server_connectSsl(con)) {
		server_free(con);
		return 0;
	}
	// Now bind - TODO: SASL (DIGEST-MD5?)
	// (The argument 3 is presumably the LDAP protocol version - confirm in fmt_ldapbindrequeststring.)
	const uint32_t id = msgId();
	const size_t bodyLen = fmt_ldapbindrequeststring(NULL, 3, binddn, password);
	const size_t headerLen = fmt_ldapmessage(NULL, id, BindRequest, bodyLen);
	// The header is written into the 50 spare bytes in front of the body.
	if (headerLen >= 50) {
		printf("[ADB] bind too long for %s\n", server->addr);
		server_free(con);
		return 0;
	}
	char buffer[bodyLen + 50];
	char *bufoff = buffer + 50;
	fmt_ldapbindrequeststring(bufoff, 3, binddn, password);
	fmt_ldapmessage(bufoff - headerLen, id, BindRequest, bodyLen);
	if (!server_send(con, bufoff - headerLen, bodyLen + headerLen, FALSE)) {
		// The bind never left, so no reply will come for this ID. Release the
		// connection now instead of leaving it registered until the next event.
		printf("[ADB] Sending bind request to %s failed\n", server->addr);
		server_free(con);
		return 0;
	}
	return id;
}

//
// Private stuff

// Lazily allocate the zero-initialized server table (MAX_SERVERS entries).
// Does nothing once the table exists. On allocation failure, servers stays NULL.
static void server_init()
{
	if (servers == NULL) {
		servers = calloc(MAX_SERVERS, sizeof(server_t));
	}
}

static server_t *server_create(const char *server)
{
	int i;
	server_init();
	for (i = 0; i < serverCount; ++i) {
		if (strcmp(servers[i].addr, server) == 0) return &servers[i];
	}
	if (serverCount >= MAX_SERVERS) {
		printf("Cannot add server %s: Too many servers.\n", server);
		return NULL;
	}
	snprintf(servers[serverCount].addr, ADDRLEN, "%s", server);
	servers[serverCount].con.fd = -1;
	servers[serverCount].con.serverData = &servers[serverCount];
	return &servers[serverCount++];
}

// Tear down a connection: free its SSL object, close its socket and reset the
// per-connection state so the object can be connected again.
// Dynamic connections (dynamic = TRUE, as created by server_tryUserBind) are
// additionally freed together with their send buffer. `server` must not be
// used after the call in that case.
static void server_free(epoll_server_t *server)
{
	server->bound = FALSE;
	if (server->ssl != NULL) {
		SSL_free(server->ssl);
		server->ssl = NULL;
	}
	if (server->fd != -1) {
		close(server->fd);
		server->fd = -1;
	}
	server->sbPos = server->sbFill = 0;
	// Reset the rest of the connection state too. The persistent per-server
	// connection is reused by server_ensureConnected:
	// - a leftover kill flag would make server_callback tear the new
	//   connection down on its very first event;
	// - leftover read-buffer bytes would be parsed as the start of the new
	//   connection's first reply.
	server->rbPos = 0;
	server->kill = FALSE;
	server->writeBlocked = FALSE;
	server->sslConnected = FALSE;
	if (server->dynamic) {
		printf("Freeing Bind-AD-Connection\n");
		free(server->sendBuffer);
		free(server);
	}
}

// epoll callback for AD connections, both persistent and dynamic.
// The connection is freed on cleanup or once it has been marked for killing.
// For SSL connections the handshake is driven first. After that, both the read
// and the write path run on every event, since SSL may need either direction
// to make progress.
static void server_callback(void *data, int haveIn, int haveOut, int doCleanup)
{
	epoll_server_t *server = (epoll_server_t *)data;
	if (doCleanup || server->kill) {
		server_free(server);
		return;
	}
	if (server->ssl == NULL) {
		// Plain connection
		if (haveIn) server_haveIn(server);
		// Don't flush to a connection the read path just found dead: writing
		// to a closed socket is pointless and may raise SIGPIPE.
		if (haveOut && !server->kill) server_haveOut(server);
		if (server->kill) server_free(server);
		return;
	}
	// SSL
	if (!server->sslConnected) {
		// Still SSL-Connecting
		if (!ssl_connectServer(server)) {
			printf("SSL Server connect failed!\n");
			server_free(server);
			return;
		}
		if (!server->sslConnected) return;
	}
	// Since we don't know if the incoming data is just wrapped application data or ssl protocol stuff, we always call both
	server_haveIn(server);
	if (!server->kill) server_haveOut(server);
	if (server->kill) server_free(server);
}

// Read all currently available data from the AD connection into its read
// buffer. Each complete message (an ASN.1 SEQUENCE) is passed to
// proxy_fromServer, and an incomplete trailing message stays buffered.
// server->kill is set on read errors, EOF, or read-buffer overflow. For
// dynamic connections it is also set when proxy_fromServer rejects a reply;
// on the persistent connection a rejected reply is logged and skipped.
static void server_haveIn(epoll_server_t *server)
{
	for (;;) {
		if (server->rbPos >= MAXMSGLEN) {
			printf("[AD->Proxy] Read buffer overflow. Disconnecting.\n");
			server->kill = TRUE;
			return;
		}
		const size_t buflen = MAXMSGLEN - server->rbPos;
		ssize_t ret;
		if (server->ssl == NULL) {
			// Plain
			ret = read(server->fd, server->readBuffer + server->rbPos, buflen);
			printf("AD read %d (err %d)\n", (int)ret, errno);
			if (ret < 0 && errno == EINTR) continue;
			if (ret < 0 && errno == EAGAIN) break;
			if (ret <= 0) {
				printf("AD Server gone while reading.\n");
				server->kill = TRUE;
				return;
			}
		} else {
			// SSL
			ret = SSL_read(server->ssl, server->readBuffer + server->rbPos, buflen);
			if (ret <= 0) {
				int err = SSL_get_error(server->ssl, ret);
				if (SSL_BLOCKED(err)) break;
				printf("AD Server gone while reading (%d, %d).\n", (int)ret, err);
				server->kill = TRUE;
				return;
			}
		}
		server->rbPos += ret;
		// Dispatch every complete message currently in the buffer
		for (;;) {
			size_t consumed, len;
			consumed = scan_asn1SEQUENCE(server->readBuffer, server->readBuffer + server->rbPos, &len);
			if (consumed == 0) break; // Length-Header not complete
			len += consumed;
			if (len > server->rbPos) break; // Body not complete
			printf("[AD] Received complete reply...\n");
			if (!proxy_fromServer(server, len)) {
				if (server->dynamic) {
					server->kill = TRUE;
					return;
				}
				printf("Error parsing reply from AD.\n");
				// Let's try to go on with the next message....
			}
			// Shift remaining buffer contents
			if (len == server->rbPos) {
				server->rbPos = 0;
				break;
			}
			memmove(server->readBuffer, server->readBuffer + len, server->rbPos - len);
			server->rbPos -= len;
		}
		// Plain socket: a short read means the kernel buffer was drained, so the
		// edge-triggered epoll will fire again when new data arrives.
		// SSL: SSL_read returns at most one TLS record per call, so a short read
		// says nothing about pending data. Breaking here could stall the
		// connection, so keep reading until SSL reports that it would block.
		if (server->ssl == NULL && (ssize_t)buflen > ret) break;
	}
}

// Queue len bytes of buffer for sending to the AD server.
// - Plain connection, empty send buffer, cork unset: the data is written to
//   the socket directly. Only an unwritten remainder (partial write,
//   EINTR/EAGAIN) is buffered.
// - SSL or corked sends: the data is always appended to the send buffer.
// Unless cork is set, the send buffer is then flushed as far as possible.
// Returns FALSE if the direct write failed hard, or if the data could not be
// buffered (in which case server->kill is set).
BOOL server_send(epoll_server_t *server, const char *buffer, size_t len, const BOOL cork)
{
	if (!cork && server->sbFill == 0 && server->ssl == NULL) {
		const int written = write(server->fd, buffer, len);
		const BOOL hardError = written < 0 && errno != EINTR && errno != EAGAIN;
		if (written == 0 || hardError) {
			printf("Server gone when trying to send.\n");
			return FALSE;
		}
		server->lastActive = time(NULL);
		if (written == (int)len) {
			return TRUE;
		}
		if (written > 0) {
			// Only the unsent tail goes into the buffer
			printf("[AD] Partial send (%d of %d)\n", written, (int)len);
			buffer += written;
			len -= (size_t)written;
		}
	}
	if (!server_ensureSendBuffer(server, len)) {
		server->kill = TRUE;
		return FALSE;
	}
	memcpy(server->sendBuffer + server->sbFill, buffer, len);
	server->sbFill += len;
	if (cork) {
		return TRUE;
	}
	server_haveOut(server);
	return TRUE;
}

// Flush as much of the send buffer as the connection accepts right now.
// Plain connections write from sbPos and stop at EAGAIN or at the first short
// write. SSL connections compact the buffer after every successful write, and
// a write that would block is recorded in writeBlocked.
// When everything has been sent, the buffer positions are reset.
// server->kill is set only on SSL write errors.
static void server_haveOut(epoll_server_t * const server)
{
	while (server->sbPos < server->sbFill) {
		const ssize_t tosend = server->sbFill - server->sbPos;
		ssize_t ret;
		if (server->ssl == NULL) {
			// Plain
			ret = write(server->fd, server->sendBuffer + server->sbPos, tosend);
			if (ret < 0 && errno == EINTR) continue;
			if (ret < 0 && errno == EAGAIN) return;
			if (ret <= 0) {
				printf("Connection to AD Server failed while flushing (ret: %d, errno: %d)\n", (int)ret, errno);
				return;
			}
		} else {
			// SSL
			ret = SSL_write(server->ssl, server->sendBuffer + server->sbPos, tosend);
			if (ret <= 0) {
				int err = SSL_get_error(server->ssl, ret);
				if (SSL_BLOCKED(err)) {
					// The retry must use the same buffer, so server_ensureSendBuffer
					// must not move or grow it until this write completes.
					server->writeBlocked = TRUE;
					return; // Blocking
				}
				printf("SSL server gone while sending (%d)\n", err);
				ERR_print_errors_fp(stdout);
				server->kill = TRUE;
				return; // Closed
			}
			// The pending write went through, so the buffer may be moved or grown
			// again. Otherwise one blocked write would permanently prevent
			// server_ensureSendBuffer from growing the buffer.
			server->writeBlocked = FALSE;
		}
		server->lastActive = time(NULL);
		server->sbPos += ret;
		if (server->ssl != NULL) {
			memmove(server->sendBuffer, server->sendBuffer + server->sbPos, server->sbFill - server->sbPos);
			server->sbFill -= server->sbPos;
			server->sbPos = 0;
		}
		// Short plain write: the socket buffer is full, wait for the next EPOLLOUT.
		if (server->ssl == NULL && ret != tosend) return;
	}
	server->sbPos = server->sbFill = 0;
}

// Make sure the persistent connection of `server` is usable.
// An open connection that was active within the last 120 seconds is reused.
// Otherwise the connection is torn down and re-established (with SSL if
// configured), and the service bind (message ID MSGID_BIND, configured bind
// DN and password) is queued on it.
// Returns FALSE if a working connection could not be set up.
static BOOL server_ensureConnected(server_t *server)
{
	epoll_server_t * const con = &server->con;
	if (con->fd != -1 && con->lastActive + 120 > time(NULL)) return TRUE;
	// Full teardown via server_free rather than a bare close(). Closing only the
	// fd leaked the SSL object and kept the stale fd number, which a failed
	// reconnect would later close a second time. Then reset the remaining
	// per-connection state left over from the previous connection.
	server_free(con);
	con->kill = FALSE;
	con->rbPos = 0;
	con->writeBlocked = FALSE;
	con->sslConnected = FALSE;
	printf("Connecting to AD '%s'...\n", server->addr);
	int sock = server_connectInternal(server);
	if (sock == -1) return FALSE;
	printf("Connected, binding....\n");
	helper_nonblock(sock);
	con->fd = sock;
	// Count the fresh connection as active. Otherwise a request arriving before
	// the first successful write (e.g. during the SSL handshake) would see a
	// stale lastActive and tear this connection down again.
	con->lastActive = time(NULL);
	con->callback = &server_callback;
	if (ePoll_add(EPOLLIN | EPOLLOUT | EPOLLET, (epoll_item_t*)con) == -1) {
		printf("epoll_add failed for ad server %s\n", server->addr);
		server_free(con);
		return FALSE;
	}
	// SSL
	if (!server_connectSsl(con)) {
		server_free(con);
		return FALSE;
	}
	// Now bind - TODO: SASL (DIGEST-MD5?)
	const size_t bodyLen = fmt_ldapbindrequest(NULL, 3, server->bind, server->password);
	const size_t headerLen = fmt_ldapmessage(NULL, MSGID_BIND, BindRequest, bodyLen);
	// The header is written into the 50 spare bytes in front of the body.
	if (headerLen >= 50) {
		printf("[AD] bind too long for %s\n", server->addr);
		server_free(con);
		return FALSE;
	}
	char buffer[bodyLen + 50];
	char *bufoff = buffer + 50;
	fmt_ldapbindrequest(bufoff, 3, server->bind, server->password);
	fmt_ldapmessage(bufoff - headerLen, MSGID_BIND, BindRequest, bodyLen);
	if (!server_send(con, bufoff - headerLen, bodyLen + headerLen, FALSE)) {
		// Without the service bind the connection is useless; start over next time.
		server_free(con);
		return FALSE;
	}
	return TRUE;
}

// Make room for len more bytes behind the data already in the send buffer.
// The already-sent prefix is reclaimed first. If that is not enough, the
// buffer is grown with some slack (1000 bytes plain, 6000 bytes SSL).
// Refuses (returns FALSE) for requests above 1,000,000 bytes, when an SSL
// write is blocked (the buffer must not move then), or when growing fails.
static BOOL server_ensureSendBuffer(epoll_server_t * const s, const size_t len)
{
	if (len > 1000000) {
		printf("server_ensureSendBuffer: request too large!\n");
		return FALSE;
	}
	// Fast path: enough free space at the end already
	if (s->sbLen - s->sbFill >= len) {
		return TRUE;
	}
	if (s->writeBlocked) {
		printf("SSL Write blocked and buffer to small (%d)\n", (int)s->sbLen);
		return FALSE;
	}
	// Drop the already-sent prefix by moving pending data to the front
	if (s->sbPos != 0) {
		memmove(s->sendBuffer, s->sendBuffer + s->sbPos, s->sbFill - s->sbPos);
		s->sbFill -= s->sbPos;
		s->sbPos = 0;
		if (s->sbLen - s->sbFill >= len) {
			return TRUE;
		}
	}
	const int slack = (s->ssl == NULL) ? 1000 : 6000;
	if (helper_realloc(&s->sendBuffer, &s->sbLen, s->sbLen + len + slack, "server_ensureSendBuffer") == -1) {
		return FALSE;
	}
	return TRUE;
}

// Open a blocking TCP connection to the server and return the socket, or -1.
// Uses the configured port, or AD_PORT / AD_PORT_SSL depending on whether an
// SSL context exists. The name is resolved again (via helper_connect4) at most
// every 300 seconds. In between, the cached server->ip is used. If connecting
// to the cached IP fails, the cache is invalidated so the next attempt
// resolves again.
static int server_connectInternal(server_t *server)
{
	int sock;
	const uint16_t port = server->port != 0 ? server->port : (server->sslContext == NULL ? AD_PORT : AD_PORT_SSL);
	if (server->lastLookup + 300 < time(NULL)) {
		sock = helper_connect4(server->addr, port, server->ip);
		if (sock == -1) {
			printf("Could not resolve/connect to AD server %s\n", server->addr);
			return -1;
		}
		// Remember when the lookup happened. Nothing else sets lastLookup, so
		// without this the cached-IP branch below was never taken.
		// Assumes helper_connect4 stores the resolved address in server->ip
		// (the cached branch connects to it) - TODO confirm in helper.c.
		server->lastLookup = time(NULL);
	} else {
		sock = socket_tcp4b();
		if (sock == -1) {
			printf("Could not allocate socket for connection to AD\n");
			return -1;
		}
		if (socket_connect4(sock, server->ip, port) == -1) {
			printf("Could not connect to cached IP of %s\n", server->addr);
			server->lastLookup = 0;
			close(sock);
			return -1;
		}
	}
	return sock;
}

// If the connection's server has an SSL context, wrap the socket in a new SSL
// object and start the client handshake via ssl_connectServer. Servers
// without a context are left as plain connections and TRUE is returned.
// Returns FALSE, with server->ssl left NULL, if the SSL object cannot be
// created or ssl_connectServer fails.
static BOOL server_connectSsl(epoll_server_t *server)
{
	if (server->serverData->sslContext != NULL) {
		server->ssl = ssl_new(server->fd, server->serverData->sslContext);
		if (server->ssl == NULL) {
			printf("Could not get SSL client from context\n");
			return FALSE;
		}
		if (!ssl_connectServer(server)) {
			printf("SSL connect failed.\n");
			SSL_free(server->ssl);
			server->ssl = NULL;
			return FALSE;
		}
	}
	return TRUE;
}