summaryrefslogblamecommitdiffstats
path: root/server.c
blob: 39f8dceb6132d455523238c9a9a33ea213f5540f (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13












                      


                    





                                                   
                                                

                                                                                
                                                     


































                                                                                                                               




















                                                                                                                                                    
                         



                                                        
                                                                                                           

                                                         
         
                    



                       
                                               




                                                                                     
                                                                                                               
                                           

                                                                                              
                                           

                 
                    

 
                                                                          








                                                                                      
                                                



                                                                       
























































                                                                                                     





















                                                                             
                                                                    








                                                




                                                       


































                                                                                                                           
                                                                     



                                                                    
                                                                                 
                                                                                       














                                                                                                             
                                                                                         





                                                                                 
                                     

                                                
                                                 












                                                                                
                    



















                                                                                                                   
                                                    
 



                                                                             
                                                           





























                                                                                            
                                              
                                                                                            
                                                                                         




























                                                                                                                    
#include "server.h"
#include "proxy.h"
#include "helper.h"
#include "epoll.h"
#include "tmpbuffer.h"
#include <time.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include <socket.h>

// TCP port the AD server is contacted on. NOTE(review): 3268 is presumably the
// AD Global Catalog LDAP port rather than plain LDAP (389) — confirm intended.
#define AD_PORT 3268
// Message ID of the service-account bind sent by server_ensureConnected();
// msgId() never hands out this value.
#define MSGID_BIND 1

// Capacity of the server table; server_create() refuses to add more entries.
#define MAX_SERVERS 10
// Lazily allocated table of MAX_SERVERS entries (see server_init()).
static server_t *servers = NULL;
// Number of entries of `servers` currently in use.
static int serverCount = 0;

// Forward declarations of file-private helpers.
static void server_init();
static server_t *server_create(const char *server);
static void server_free(epoll_server_t *server);
static void server_callback(void *data, int haveIn, int haveOut, int doCleanup);
static void server_flush(epoll_server_t * const server);
static BOOL server_ensureConnected(server_t *server);
static void server_ensureSendBuffer(epoll_server_t * const s, const size_t len);

// Generate a message ID for a request to AD.
// IDs 0 and 1 are never handed out: 0 is reserved by LDAP for unsolicited
// notifications and 1 is MSGID_BIND (the service-account bind). RFC 4511
// restricts MessageID to 0..maxInt (2^31 - 1), so the counter wraps back to 2
// before leaving that range instead of continuing up to 2^32 - 1.
// Not thread-safe: the counter is a function-local static.
static inline uint32_t msgId(void)
{
	static uint32_t id = 1336;
	if (++id < 2 || id > 2147483647u) id = 2;
	return id;
}

// Setting up server(s)

// Set the bind DN used for the service-account connection to `server`,
// creating the server entry on first use. Over-long DNs are truncated with a
// warning.
void server_setBind(const char *server, const char *bind)
{
	server_t * const entry = server_create(server);
	if (entry != NULL) {
		const int written = snprintf(entry->bind, BINDLEN, "%s", bind);
		if (written >= BINDLEN) {
			printf("Warning: BindDN for %s is too long.\n", server);
		}
	}
}

// Set the bind password used for the service-account connection to `server`,
// creating the server entry on first use. Over-long passwords are truncated
// with a warning.
void server_setPassword(const char *server, const char *password)
{
	server_t * const entry = server_create(server);
	if (entry != NULL) {
		const int written = snprintf(entry->password, PWLEN, "%s", password);
		if (written >= PWLEN) {
			printf("Warning: BindPW for %s is too long.\n", server);
		}
	}
}

// Set the search base for `server` (creating the entry on first use) and
// store it in normalized form, so server_getFromBase() can compare it
// byte-wise against normalized request DNs.
void server_setBase(const char *server, const char *base)
{
	server_t * const entry = server_create(server);
	if (entry == NULL) {
		return;
	}
	const int written = snprintf(entry->base, BASELEN, "%s", base);
	if (written >= BASELEN) {
		printf("Warning: SearchBase for %s is too long.\n", server);
	}
	// Normalize in place; the length is capped to the buffer's usable size.
	entry->baseLen = normalize_dn(entry->base, entry->base,
			min(strlen(entry->base), BASELEN - 1));
	entry->base[entry->baseLen] = '\0';
}

// Store the home-directory template for `server` (creating the entry on first
// use). NOTE(review): the '%' handling below suggests the template is later
// used as a printf-style format string with a limited number of arguments —
// confirm in the consumer. To keep it usable as a format string:
//  - at most 5 conversion specifiers are kept; from the 6th one on, the
//    introducing '%' and every following character are overwritten with '_',
//  - a lone '%' at the very end is overwritten with '_',
//  - backslashes are turned into forward slashes.
// A "%%" pair is a literal percent sign and does not count as a specifier.
void server_setHomeTemplate(const char *server, const char *hometemplate)
{
	server_t *entry = server_create(server);
	if (entry == NULL) return;
	if (snprintf(entry->homeTemplate, MOUNTLEN, "%s", hometemplate) >= MOUNTLEN) printf("Warning: Home Template for %s is too long.\n", server);
	// TODO: Better template system. Using a format string is too lazy
	BOOL afterPercent = FALSE; // previous character was an unpaired '%'
	char *s = entry->homeTemplate;
	int count = 0; // conversion specifiers seen so far
	for (; *s != '\0'; ++s) {
		if (afterPercent) {
			afterPercent = FALSE;
			if (*s != '%') {
				count++;
				// Also neutralize the '%' that introduced the excess specifier;
				// leaving it would produce "%_", itself an invalid conversion.
				if (count > 5) s[-1] = '_';
			}
		} else if (*s == '%') {
			afterPercent = TRUE;
		}
		if (count > 5) *s = '_';
		if (*s == '\\') *s = '/';
	}
	// A trailing '%' would start an incomplete conversion specification.
	if (afterPercent) s[-1] = '_';
	if (count > 5) printf("WARNING: Too many '%%' in Home Template for %s. Don't forget to replace literal '%%' with '%%%%'\n", server);
}

BOOL server_initServers()
{
	int i;
	printf("%d servers configured.\n", serverCount);
	for (i = 0; i < serverCount; ++i) {
		printf("%s:\n  Bind: %s\n  Base: %s\n", servers[i].addr, servers[i].bind, servers[i].base);
		if (!server_ensureConnected(&servers[i]))
			return FALSE;
	}
	return TRUE;
}

// What the proxy calls

server_t *server_getFromBase(struct string *in)
{
	int i;
	char buffer[TMPLEN];
	const size_t searchLen = normalize_dn(buffer, in->s, min(in->l, TMPLEN - 1));
	buffer[searchLen] = '\0';
	// Now buffer contains the normalized wanted bind/domain/whatev. Try to find a match in the server list
	for (i = 0; i < serverCount; ++i) {
		if (searchLen < servers[i].baseLen) continue;
		if (strcmp(servers[i].base, buffer + (searchLen - servers[i].baseLen)) == 0) {
			return &servers[i];
		}
	}
	return NULL;
}

// Send a search request over the shared service-account connection of
// `server`, (re)connecting first if necessary.
// The body is formatted into a stack buffer behind 50 bytes of headroom; the
// LDAP message header is then written directly in front of it.
// Returns the message ID used, or 0 if no connection could be established,
// the header would not fit into the headroom, or sending failed. If sending
// failed, the connection is dropped so the next call reconnects.
uint32_t server_searchRequest(server_t *server, struct SearchRequest *req)
{
	if (!server_ensureConnected(server)) return 0;
	const uint32_t msgid = msgId();
	const size_t bodyLen = fmt_ldapsearchrequest(NULL, req);
	const size_t headerLen = fmt_ldapmessage(NULL, msgid, SearchRequest, bodyLen);
	// Same guard as the bind paths: the header must fit into the headroom.
	if (headerLen >= 50) {
		printf("[AD] search request header too long for %s\n", server->addr);
		return 0;
	}
	char buffer[bodyLen + 50];
	char *bufoff = buffer + 50;
	fmt_ldapsearchrequest(bufoff, req);
	fmt_ldapmessage(bufoff - headerLen, msgid, SearchRequest, bodyLen);
	epoll_server_t * const s = &server->con;
	if (!server_send(s, bufoff - headerLen, headerLen + bodyLen, FALSE)) {
		// The connection is dead; close it so server_ensureConnected() opens a new one.
		server_free(s);
		return 0;
	}
	return msgid;
}

// Open a dedicated connection to the AD server of `server` and send a bind
// request with the given user DN and password. The connection is heap
// allocated and marked dynamic, so server_free() releases it; replies are
// handled asynchronously by server_callback via epoll.
// Returns the message ID of the bind request, or 0 on failure. On failure the
// connection has already been torn down.
uint32_t server_tryUserBind(server_t *server, struct string *binddn, struct string *password)
{
	epoll_server_t *con = calloc(1, sizeof *con);
	if (con == NULL) {
		printf("[ADB] Could not allocate connection for %s\n", server->addr);
		return 0;
	}
	con->serverData = server;
	con->fd = -1;
	con->bound = FALSE;
	con->dynamic = TRUE;
	printf("Connecting to AD '%s' for %.*ss bind...\n", server->addr, (int)binddn->l, binddn->s);
	con->sbPos = con->sbFill = 0;
	int sock;
	// Re-resolve the address if the cached lookup is older than 300 seconds,
	// otherwise connect to the cached IP directly.
	// NOTE(review): lastLookup is not updated anywhere in this file — verify
	// helper_connect4 (or the caller) maintains it, else the cache never hits.
	if (server->lastLookup + 300 < time(NULL)) {
		sock = helper_connect4(server->addr, AD_PORT, server->ip);
		if (sock == -1) {
			printf("[ADB] Could not resolve/connect to AD server %s\n", server->addr);
			server_free(con);
			return 0;
		}
	} else {
		sock = socket_tcp4b();
		if (sock == -1) {
			printf("[ADB] Could not allocate socket for connection to AD\n");
			server_free(con);
			return 0;
		}
		if (socket_connect4(sock, server->ip, AD_PORT) == -1) {
			printf("[ADB] Could not connect to cached IP of %s\n", server->addr);
			close(sock);
			server_free(con);
			return 0;
		}
	}
	printf("[ADB] Connected, binding....\n");
	helper_nonblock(sock);
	con->fd = sock;
	con->callback = &server_callback;
	if (ePoll_add(EPOLLIN | EPOLLOUT | EPOLLET, (epoll_item_t*)con) == -1) {
		printf("[ADB] epoll_add failed for ad server %s\n", server->addr);
		server_free(con);
		return 0;
	}
	// Now bind - TODO: SASL (DIGEST-MD5?)
	// NOTE(review): the literal 3 is presumably the LDAP protocol version.
	const uint32_t id = msgId();
	const size_t bodyLen = fmt_ldapbindrequeststring(NULL, 3, binddn, password);
	const size_t headerLen = fmt_ldapmessage(NULL, id, BindRequest, bodyLen);
	// Body goes behind 50 bytes of headroom; the header is written in front of it.
	char buffer[bodyLen + 50];
	char *bufoff = buffer + 50;
	if (headerLen >= 50) {
		printf("[ADB] bind too long for %s\n", server->addr);
		server_free(con);
		return 0;
	}
	fmt_ldapbindrequeststring(bufoff, 3, binddn, password);
	fmt_ldapmessage(bufoff - headerLen, id, BindRequest, bodyLen);
	if (!server_send(con, bufoff - headerLen, bodyLen + headerLen, FALSE)) {
		// No bind request went out; nobody would ever see a reply for `id`.
		printf("[ADB] Sending bind to %s failed\n", server->addr);
		server_free(con);
		return 0;
	}
	return id;
}

//
// Private stuff

// Lazily allocate the zero-initialized server table (MAX_SERVERS entries).
// Calling it again after a successful allocation is a no-op.
static void server_init()
{
	if (servers != NULL) return;
	servers = calloc(MAX_SERVERS, sizeof *servers);
	// NOTE(review): assumes bail() reports and terminates, as its use in
	// server_ensureSendBuffer suggests — confirm.
	if (servers == NULL) bail("server_init: Cannot allocate server list");
}

// Look up the entry for server address `server`, creating it if it does not
// exist yet. A new entry gets its address set, no connection (fd -1), and a
// back-pointer from its connection to itself.
// Returns NULL in three cases:
//  - the table could not be allocated,
//  - the address does not fit into the entry. Storing it truncated would make
//    later lookups by the full name miss, so every setter would create a
//    duplicate entry.
//  - MAX_SERVERS entries already exist.
static server_t *server_create(const char *server)
{
	int i;
	server_init();
	if (servers == NULL) return NULL;
	if (strlen(server) >= (size_t)ADDRLEN) {
		printf("Cannot add server %s: Address too long.\n", server);
		return NULL;
	}
	for (i = 0; i < serverCount; ++i) {
		if (strcmp(servers[i].addr, server) == 0) return &servers[i];
	}
	if (serverCount >= MAX_SERVERS) {
		printf("Cannot add server %s: Too many servers.\n", server);
		return NULL;
	}
	snprintf(servers[serverCount].addr, ADDRLEN, "%s", server);
	servers[serverCount].con.fd = -1;
	servers[serverCount].con.serverData = &servers[serverCount];
	return &servers[serverCount++];
}

// Tear down an AD connection. It closes the socket (if open), marks the
// connection unbound, and discards all pending send data and partially
// received reply data.
// Dynamic (per-user bind) connections are heap allocated. For those, the send
// buffer and the connection object itself are freed, so `server` must not be
// used afterwards. The shared per-server connection is only reset and can be
// reconnected.
static void server_free(epoll_server_t *server)
{
	server->bound = FALSE;
	if (server->fd != -1) close(server->fd);
	server->fd = -1;
	server->sbPos = server->sbFill = 0;
	// Drop any partial reply; otherwise it would be glued to the first bytes
	// received on the next connection.
	server->rbPos = 0;
	if (server->dynamic) {
		printf("Freeing Bind-AD-Connection\n");
		free(server->sendBuffer);
		free(server);
	}
}

// epoll callback for AD connections (registered by server_ensureConnected and
// server_tryUserBind). `data` is the epoll_server_t.
//  - doCleanup: tear the connection down via server_free() and return.
//  - haveIn:    read everything available into readBuffer. Each complete
//               ASN.1 SEQUENCE (one LDAP message) is handed to
//               proxy_fromServer(), then removed from the buffer.
//  - haveOut:   flush pending send data.
// Any path that calls server_free() returns immediately, since a dynamic
// connection is freed there.
static void server_callback(void *data, int haveIn, int haveOut, int doCleanup)
{
	epoll_server_t *server = (epoll_server_t *)data;
	if (doCleanup) {
		server_free(server);
		return;
	}
	if (haveIn) {
		for (;;) {
			// Buffer full without a complete message: it can never complete.
			if (server->rbPos >= MAXMSGLEN) {
				printf("[AD->Proxy] Read buffer overflow. Disconnecting.\n");
				server_free(server);
				return;
			}
			const size_t buflen = MAXMSGLEN - server->rbPos;
			const ssize_t ret = read(server->fd, server->readBuffer + server->rbPos, buflen);
			// Debug output; errno is only meaningful here when ret < 0.
			printf("AD read %d (err %d)\n", (int)ret, errno);
			if (ret < 0 && errno == EINTR) continue;
			if (ret < 0 && errno == EAGAIN) break; // socket drained
			if (ret <= 0) {
				// EOF or hard error
				printf("AD gone while reading.\n");
				server_free(server);
				return;
			}
			server->rbPos += ret;
			// Reply complete? Process every complete message in the buffer.
			for (;;) {
				size_t consumed, len;
				// consumed = size of the SEQUENCE header, len = size of its body
				consumed = scan_asn1SEQUENCE(server->readBuffer, server->readBuffer + server->rbPos, &len);
				if (consumed == 0) break; // Length-Header not complete
				len += consumed; // total size of this message
				if (len > server->rbPos) break; // Body not complete
				printf("[AD] Received complete reply...\n");
				if (!proxy_fromServer(server, len)) {
					// A per-user bind connection is dropped on a bad reply;
					// the shared connection skips the message and carries on.
					if (server->dynamic) {
						server_free(server);
						return;
					}
					printf("Error parsing reply from AD.\n");
					// Let's try to go on with the next message....
				}
				// Shift remaining buffer contents
				if (len == server->rbPos) {
					server->rbPos = 0;
					break;
				}
				memmove(server->readBuffer, server->readBuffer + len, server->rbPos - len);
				server->rbPos -= len;
			}
			if ((ssize_t)buflen > ret) break; // Read less than buffer len, epoll will fire again
		}
	}
	if (haveOut) server_flush(server);
}

// Send `len` bytes of `buffer` to the AD server.
// If nothing is pending and `cork` is FALSE, the data is written directly.
// Whatever the socket does not accept is appended to the send buffer and
// flushed. With `cork` TRUE, or while older data is still pending, the data is
// only appended to the send buffer. With `cork` FALSE the buffer is then
// flushed.
// Returns FALSE only if the direct write shows the connection is gone. Later
// flush failures are only logged by server_flush().
BOOL server_send(epoll_server_t *server, const char *buffer, size_t len, const BOOL cork)
{
	if (server->sbFill == 0 && !cork) {
		// Nothing to send and nothing pending: write(2) would return 0, which
		// must not be mistaken for a closed connection.
		if (len == 0) return TRUE;
		// Nothing in send buffer, fire away
		const ssize_t ret = write(server->fd, buffer, len);
		if (ret == 0 || (ret < 0 && errno != EINTR && errno != EAGAIN)) {
			printf("Server gone when trying to send.\n");
			return FALSE;
		}
		server->lastActive = time(NULL);
		if (ret > 0 && (size_t)ret == len) return TRUE;
		// Couldn't send everything, continue with buffering logic below
		if (ret > 0) {
			printf("[AD] Partial send (%d of %d)\n", (int)ret, (int)len);
			buffer += ret;
			len -= (size_t)ret;
		}
	}
	// Buffer...
	server_ensureSendBuffer(server, len);
	// Finally append to buffer
	memcpy(server->sendBuffer + server->sbFill, buffer, len);
	server->sbFill += len;
	if (!cork) server_flush(server);
	return TRUE;
}

// Write as much of the pending send buffer (sbPos..sbFill) as the socket
// accepts. It stops when the socket would block or takes only part of a
// write; the rest is sent on a later call. Once everything is sent, the
// buffer is reset to empty. Write errors are only logged, and the unsent
// data stays in place.
static void server_flush(epoll_server_t * const server)
{
	while (server->sbPos < server->sbFill) {
		const size_t tosend = (size_t)(server->sbFill - server->sbPos);
		const ssize_t ret = write(server->fd, server->sendBuffer + server->sbPos, tosend);
		if (ret < 0 && errno == EINTR) continue;
		if (ret < 0 && errno == EAGAIN) return;
		if (ret <= 0) {
			printf("Connection to AD Server failed while flushing (ret: %d, errno: %d)\n", (int)ret, errno);
			return;
		}
		server->lastActive = time(NULL);
		server->sbPos += ret;
		// Short write: socket buffer is full, wait for the next EPOLLOUT.
		if ((size_t)ret != tosend) return;
	}
	server->sbPos = server->sbFill = 0;
}

// Make sure the shared service-account connection of `server` is usable.
// An open connection that was active within the last 120 seconds is reused.
// Otherwise the old socket (if any) is closed, a new one is opened, registered
// with epoll, and a bind request with the configured bind DN and password is
// sent (message ID MSGID_BIND).
// Returns TRUE if the connection is reused or the bind request was sent.
// Returns FALSE otherwise, in which case con->fd is -1.
static BOOL server_ensureConnected(server_t *server)
{
	epoll_server_t * const con = &server->con;
	if (con->fd != -1 && con->lastActive + 120 > time(NULL)) return TRUE;
	if (con->fd != -1) {
		close(con->fd);
		// Forget the descriptor right away. If reconnecting fails below, the
		// next call must not close this (possibly reused) descriptor number again.
		con->fd = -1;
	}
	con->bound = FALSE;
	printf("Connecting to AD '%s'...\n", server->addr);
	con->sbPos = con->sbFill = 0;
	con->rbPos = 0; // discard any partial reply from the old connection
	int sock;
	// Re-resolve if the cached lookup is older than 300 seconds.
	// NOTE(review): lastLookup is not updated anywhere in this file — verify
	// helper_connect4 (or the caller) maintains it, else the cache never hits.
	if (server->lastLookup + 300 < time(NULL)) {
		sock = helper_connect4(server->addr, AD_PORT, server->ip);
		if (sock == -1) {
			printf("Could not resolve/connect to AD server %s\n", server->addr);
			return FALSE;
		}
	} else {
		sock = socket_tcp4b();
		if (sock == -1) {
			printf("Could not allocate socket for connection to AD\n");
			return FALSE;
		}
		if (socket_connect4(sock, server->ip, AD_PORT) == -1) {
			printf("Could not connect to cached IP of %s\n", server->addr);
			close(sock);
			return FALSE;
		}
	}
	printf("Connected, binding....\n");
	helper_nonblock(sock);
	con->fd = sock;
	con->callback = &server_callback;
	if (ePoll_add(EPOLLIN | EPOLLOUT | EPOLLET, (epoll_item_t*)con) == -1) {
		printf("epoll_add failed for ad server %s\n", server->addr);
		close(con->fd);
		con->fd = -1;
		return FALSE;
	}
	// Now bind - TODO: SASL (DIGEST-MD5?)
	// NOTE(review): the literal 3 is presumably the LDAP protocol version.
	const size_t bodyLen = fmt_ldapbindrequest(NULL, 3, server->bind, server->password);
	const size_t headerLen = fmt_ldapmessage(NULL, MSGID_BIND, BindRequest, bodyLen);
	// Body goes behind 50 bytes of headroom; the header is written in front of it.
	char buffer[bodyLen + 50];
	char *bufoff = buffer + 50;
	if (headerLen >= 50) {
		printf("[AD] bind too long for %s\n", server->addr);
		close(con->fd);
		con->fd = -1;
		return FALSE;
	}
	fmt_ldapbindrequest(bufoff, 3, server->bind, server->password);
	fmt_ldapmessage(bufoff - headerLen, MSGID_BIND, BindRequest, bodyLen);
	if (!server_send(con, bufoff - headerLen, bodyLen + headerLen, FALSE)) {
		// The bind never went out; report failure instead of a half-open connection.
		printf("[AD] sending bind to %s failed\n", server->addr);
		close(con->fd);
		con->fd = -1;
		return FALSE;
	}
	return TRUE;
}

// Guarantee room for `len` more bytes at the end of the send buffer.
// It first reclaims the already-sent prefix (sbPos) by moving the unsent data
// to the front. Only if that is not enough is the buffer grown, by `len` plus
// 1000 bytes of slack (presumably so small follow-up appends need no realloc).
// Single requests above 1000000 bytes are rejected via bail().
static void server_ensureSendBuffer(epoll_server_t * const s, const size_t len)
{
	if (len > 1000000) bail("server_ensureSendBuffer: request too large!");
	if (s->sbLen - s->sbFill >= len) return; // already enough free space
	if (s->sbPos != 0) {
		// Compact: move the unsent bytes to the start of the buffer.
		memmove(s->sendBuffer, s->sendBuffer + s->sbPos, s->sbFill - s->sbPos);
		s->sbFill -= s->sbPos;
		s->sbPos = 0;
	}
	if (s->sbLen - s->sbFill >= len) return; // compaction was sufficient
	helper_realloc(&s->sendBuffer, &s->sbLen, s->sbLen + len + 1000, "server_ensureSendBuffer");
}