#include "server.h"
#include "proxy.h"
#include "helper.h"
#include "epoll.h"
#include "tmpbuffer.h"
#include "openssl.h"
#include <time.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include <socket.h>
#include <fcntl.h>
#include <ctype.h>
// Default port when no port is configured and the server has no SSL context (see server_connectInternal).
// NOTE(review): 3268 is conventionally the AD Global Catalog port, while 636 below is plain LDAPS
// rather than Global-Catalog-over-SSL (3269) - confirm this asymmetry is intended.
#define AD_PORT 3268
// Default port when no port is configured and the server has an SSL context.
#define AD_PORT_SSL 636
// Fixed LDAP message ID used for the bind on the shared connection; msgId() never returns 0 or 1,
// so regular requests cannot collide with it.
#define MSGID_BIND 1
// Capacity of the servers array (allocated in server_init, checked in server_create).
#define MAX_SERVERS 10
// Lazily allocated array of MAX_SERVERS entries; the first serverCount entries are in use.
static server_t *servers = NULL;
static int serverCount = 0;
// Set by server_initServers() after all servers were connected. Before that, the shared
// socket is made non-blocking only after the SSL handshake; afterwards, before it.
static BOOL connectionInitDone = FALSE;
static void server_init();
static server_t *server_create(const char *server);
static void server_callback(void *data, int haveIn, int haveOut, int doCleanup);
static void server_haveIn(epoll_server_t *server);
static BOOL server_haveOut(epoll_server_t * const server);
static BOOL server_ensureConnected(server_t *server);
static BOOL server_ensureSendBuffer(epoll_server_t * const s, const size_t len);
static int server_connectInternal(server_t *server);
static BOOL server_connectSsl(epoll_server_t *server);
/*
 * Hand out the next LDAP message ID for a request sent to an AD server.
 * The sequence starts at 1337. On wrap-around it skips 0 and 1, so it never
 * clashes with MSGID_BIND.
 */
static inline uint32_t msgId()
{
	static uint32_t counter = 1336;
	counter++;
	if (counter <= 1) {
		counter = 2;
	}
	return counter;
}
// Setting up server(s)
void server_setPort(const char *server, const char *portStr)
{
server_t *entry = server_create(server);
if (entry == NULL) return;
int port = atoi(portStr);
if (port < 1 || port > 65535) {
printf("Warning: Invalid port '%s' for '%s'\n", portStr, server);
return;
}
entry->port = (uint16_t)port;
}
void server_setPlainLdap(const char *server, const char *enabledStr)
{
server_t *entry = server_create(server);
if (entry == NULL) return;
entry->plainLdap = atoi(enabledStr) != 0 || strcmp(enabledStr, "true") == 0
|| strcmp(enabledStr, "True") == 0 || strcmp(enabledStr, "TRUE") == 0;
}
/*
 * Store the bind DN used when binding the shared connection of `server`.
 * Values that don't fit BINDLEN are truncated and a warning is printed.
 */
void server_setBind(const char *server, const char *bind)
{
	server_t * const entry = server_create(server);
	if (entry != NULL) {
		const int needed = snprintf(entry->bind, BINDLEN, "%s", bind);
		if (needed >= BINDLEN) {
			printf("Warning: BindDN for %s is too long.\n", server);
		}
	}
}
/*
 * Store the bind password used when binding the shared connection of `server`.
 * Values that don't fit PWLEN are truncated and a warning is printed.
 */
void server_setPassword(const char *server, const char *password)
{
	server_t * const entry = server_create(server);
	if (entry != NULL) {
		const int needed = snprintf(entry->password, PWLEN, "%s", password);
		if (needed >= PWLEN) {
			printf("Warning: BindPW for %s is too long.\n", server);
		}
	}
}
/*
 * Store the search base of `server`, normalized in place with normalize_dn().
 * server_getFromBase() matches requests against this normalized value.
 * baseLen receives the normalized length.
 */
void server_setBase(const char *server, const char *base)
{
	server_t * const entry = server_create(server);
	if (entry == NULL) {
		return;
	}
	const int needed = snprintf(entry->base, BASELEN, "%s", base);
	if (needed >= BASELEN) {
		printf("Warning: SearchBase for %s is too long.\n", server);
	}
	const size_t rawLen = min(strlen(entry->base), BASELEN - 1);
	entry->baseLen = normalize_dn(entry->base, entry->base, rawLen);
	entry->base[entry->baseLen] = '\0';
}
/*
 * Configure the CA bundle file used to verify the SSL connection to `server`.
 * An empty or NULL path is ignored.
 * A path that is unreadable or longer than MAXPATH is fatal (exit(1)). A
 * truncated path would silently point to a different file than the one
 * checked here.
 * On success the SSL library is initialized via ssl_init().
 */
void server_setCaBundle(const char *server, const char *file)
{
	server_t *entry = server_create(server);
	if (entry == NULL) return;
	if (file == NULL || *file == '\0') return;
	if (strlen(file) >= MAXPATH) {
		printf("Error: CaBundle path for %s is too long.\n", server);
		exit(1);
	}
	// Only probe readability here; the bundle itself is loaded later when the SSL context is created
	int fh = open(file, O_RDONLY);
	if (fh == -1) {
		printf("Error: cabundle '%s' not readable.\n", file);
		exit(1);
	}
	close(fh);
	snprintf(entry->cabundle, MAXPATH, "%s", file);
	ssl_init();
}
/*
 * Store the home directory template of `server`. The template is used as a
 * printf-style format string.
 *
 * Sanitizing:
 *  - Backslashes become '/'.
 *  - At most 5 conversions are allowed ("%%" is a literal percent and is not
 *    counted). From the 6th conversion on, everything, including the
 *    introducing '%', is replaced with '_'.
 *  - A trailing unpaired '%' is replaced with '_'.
 * Both "%_" and a lone trailing '%' are invalid conversion specifications,
 * which is undefined behavior in the printf family.
 */
void server_setHomeTemplate(const char *server, const char *hometemplate)
{
	server_t *entry = server_create(server);
	if (entry == NULL) return;
	if (snprintf(entry->homeTemplate, MOUNTLEN, "%s", hometemplate) >= MOUNTLEN) printf("Warning: Home Template for %s is too long.\n", server);
	// TODO: Better template system. Using a format string is too lazy
	BOOL afterPercent = FALSE;
	char *s = entry->homeTemplate;
	int count = 0;
	while (*s) {
		if (afterPercent) {
			if (*s != '%') count++;
			afterPercent = FALSE;
			// s[-1] is the '%' that introduced this conversion; neutralize it too so no "%_" survives
			if (count > 5) s[-1] = '_';
		} else if (*s == '%') {
			afterPercent = TRUE;
		}
		if (count > 5) *s = '_';
		if (*s == '\\') *s = '/';
		s++;
	}
	if (afterPercent && s[-1] == '%') {
		s[-1] = '_';
		printf("Warning: Home Template for %s ends with a lone '%%', replaced by '_'\n", server);
	}
	if (count > 5) printf("WARNING: Too many '%%' in Home Template for %s. Don't forget to replace literal '%%' with '%%%%'\n", server);
}
/*
 * Store the name of the home directory attribute of `server`. Two copies are
 * kept: the original spelling (homeAttr) and a lower-cased one (homeAttrLower).
 * NOTE(review): like server_setFingerprint, this is ignored once an SSL
 * context exists. That looks copy-pasted; confirm it is intended.
 * On allocation failure an error is printed and the previous values are kept.
 */
void server_setHomeAttribute(const char *server, const char *homeattribute)
{
	server_t *entry = server_create(server);
	if (entry == NULL || entry->sslContext != NULL) return;
	char *copy = strdup(homeattribute);
	char *lower = strdup(homeattribute);
	if (copy == NULL || lower == NULL) {
		printf("Error: Out of memory while storing home attribute for %s\n", server);
		free(copy);
		free(lower);
		return;
	}
	for (char *p = lower; *p != '\0'; ++p) {
		// Cast to unsigned char: passing a negative char to tolower() is undefined behavior
		*p = (char)tolower((unsigned char)*p);
	}
	free((void*)entry->homeAttr.s);
	free((void*)entry->homeAttrLower.s);
	entry->homeAttr.s = copy;
	entry->homeAttrLower.s = lower;
	entry->homeAttr.l = strlen(homeattribute);
	entry->homeAttrLower.l = entry->homeAttr.l;
}
/*
 * Configure the expected certificate fingerprint ("adsha1") of `server`.
 * The value is hex digits, optionally separated by ':' or ' ', and must
 * contain exactly 2 * FINGERPRINTLEN digits.
 * The digits are parsed into a scratch buffer and copied into the entry only
 * when complete. A malformed value therefore leaves the entry untouched
 * instead of partially filled, and a repeated call does not OR new bits into
 * an old value.
 * Ignored once the server's SSL context exists. On success the SSL library
 * is initialized via ssl_init().
 */
void server_setFingerprint(const char *server, const char *fingerprint)
{
	server_t *entry = server_create(server);
	if (entry == NULL || entry->sslContext != NULL) return;
	unsigned char parsed[FINGERPRINTLEN] = { 0 };
	int chars = 0;
	const char *p = fingerprint;
	while (*p != '\0' && chars < 2 * FINGERPRINTLEN) {
		const char c = *p++;
		if (c == ':' || c == ' ') continue;
		int val;
		if (c >= '0' && c <= '9') {
			val = c - '0';
		} else if (c >= 'a' && c <= 'f') {
			val = c - 'a' + 10;
		} else if (c >= 'A' && c <= 'F') {
			val = c - 'A' + 10;
		} else {
			break;
		}
		// Even digit positions are the high nibble, odd positions the low nibble
		parsed[chars / 2] |= (unsigned char)(chars % 2 == 0 ? val << 4 : val);
		chars++;
	}
	if (chars != 2 * FINGERPRINTLEN) {
		printf("Warning: Fingerprint for %s is invalid (adsha1 should be a SHA-1 hash of the cert in hex representation.)\n", server);
		return;
	}
	memcpy(entry->fingerprint, parsed, FINGERPRINTLEN);
	printf("Using fingerprint ");
	for (int i = 0; i < FINGERPRINTLEN - 1; ++i) {
		printf("%02x:", (unsigned int)parsed[i]);
	}
	printf("%02x for %s\n", (unsigned int)parsed[FINGERPRINTLEN - 1], server);
	ssl_init();
}
BOOL server_initServers()
{
int i;
printf("%d servers configured.\n", serverCount);
for (i = 0; i < serverCount; ++i) {
if (servers[i].cabundle[0] != '\0' || memcmp(servers[i].fingerprint, "\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) != 0) {
servers[i].sslContext = ssl_newClientCtx(servers[i].cabundle);
}
printf("%s:\n Bind: %s\n Base: %s\n", servers[i].addr, servers[i].bind, servers[i].base);
printf("Plain LDAP-LDAP: %d\n", (int)servers[i].plainLdap);
if (!server_ensureConnected(&servers[i]))
return FALSE;
}
connectionInitDone = TRUE;
return TRUE;
}
// What the proxy calls
server_t *server_getFromBase(struct string *in)
{
int i;
char buffer[TMPLEN];
const size_t searchLen = normalize_dn(buffer, in->s, min(in->l, TMPLEN - 1));
buffer[searchLen] = '\0';
// Now buffer contains the normalized wanted bind/domain/whatev. Try to find a match in the server list
for (i = 0; i < serverCount; ++i) {
if (searchLen < servers[i].baseLen) continue;
if (strcmp(servers[i].base, buffer + (searchLen - servers[i].baseLen)) == 0) {
return &servers[i];
}
}
return NULL;
}
/*
 * Send a search request over the shared connection of `server`.
 * If the first attempt is impossible or fails, the connection is re-checked
 * via server_ensureConnected() and the request is sent once more.
 * Returns the LDAP message ID used, or 0 on failure.
 */
uint32_t server_searchRequest(server_t *server, struct SearchRequest *req)
{
	epoll_server_t * const con = &server->con;
	if (!server_ensureConnected(server)) return 0;
	printf("fd: %d, Kill: %d, Bound: %d, idle: %d\n", con->fd, (int)con->kill, (int)con->bound, (int)(time(NULL) - con->lastActive));
	if (con->fd != -1 && !con->kill) {
		const uint32_t sentId = server_searchRequestOnConnection(con, req);
		if (sentId != 0) return sentId;
	}
	return server_ensureConnected(server) ? server_searchRequestOnConnection(con, req) : 0;
}
/*
 * Encode `req` as an LDAP SearchRequest with a fresh message ID and send it
 * over connection `server`.
 * The body is formatted first, then the message header is written
 * immediately in front of it, into room reserved at the start of the buffer.
 * Returns the message ID, or 0 if the header would not fit or sending failed.
 */
uint32_t server_searchRequestOnConnection(epoll_server_t *server, struct SearchRequest *req)
{
	enum { HEADER_ROOM = 50 };
	const uint32_t msgid = msgId();
	const size_t bodyLen = fmt_ldapsearchrequest(NULL, req);
	const size_t headerLen = fmt_ldapmessage(NULL, msgid, SearchRequest, bodyLen);
	// Same guard as the bind paths: the header must fit into the reserved room
	if (headerLen > HEADER_ROOM) {
		printf("[Proxy] Search request header too long for %s\n", server->serverData->addr);
		return 0;
	}
	char buffer[bodyLen + HEADER_ROOM];
	char *bufoff = buffer + HEADER_ROOM;
	fmt_ldapsearchrequest(bufoff, req);
	fmt_ldapmessage(bufoff - headerLen, msgid, SearchRequest, bodyLen);
	if (!server_send(server, bufoff - headerLen, headerLen + bodyLen, FALSE)) return 0;
	return msgid;
}
/*
 * Open a new, dynamically allocated connection to `server` and send a simple
 * bind (version 3) with the client-supplied `binddn` and `password`.
 * On success stores the connection in *newcon and returns the bind's message
 * ID. On any failure the connection is released via server_free() and 0 is
 * returned.
 * The bind is rejected, not truncated, when it doesn't fit into bindBuffer
 * (BINDLEN bytes). A truncated bind would be a malformed PDU. The length check
 * runs before the stack buffer is sized from the client-controlled lengths.
 */
uint32_t server_tryUserBind(server_t *server, struct string *binddn, struct string *password, epoll_server_t **newcon)
{
	epoll_server_t *con = calloc(1, sizeof(epoll_server_t));
	if (con == NULL) {
		printf("[Proxy] Out of memory allocating connection for user bind.\n");
		return 0;
	}
	con->serverData = server;
	con->fd = -1;
	con->bound = FALSE;
	con->dynamic = TRUE;
	con->sbPos = con->sbFill = 0;
	int sock = server_connectInternal(server);
	if (sock == -1) {
		printf("[Proxy] Could not connect to AD for user bind.\n");
		server_free(con);
		return 0;
	}
	helper_nonblock(sock);
	con->fd = sock;
	con->callback = &server_callback;
	if (ePoll_add(EPOLLIN | EPOLLOUT | EPOLLET, (epoll_item_t*)con) == -1) {
		printf("[Proxy] epoll_add failed for AD server %s on user bind\n", server->addr);
		server_free(con);
		return 0;
	}
	// SSL
	if (!server_connectSsl(con)) {
		printf("[Proxy] SSL handshake failed for AD server %s on user bind\n", server->addr);
		server_free(con);
		return 0;
	}
	// Now bind - TODO: SASL (DIGEST-MD5?)
	const uint32_t id = msgId();
	const size_t bodyLen = fmt_ldapbindrequeststring(NULL, 3, binddn, password);
	const size_t headerLen = fmt_ldapmessage(NULL, id, BindRequest, bodyLen);
	if (headerLen >= 50 || bodyLen + headerLen > BINDLEN) {
		printf("[Proxy] User bind too long for %s\n", server->addr);
		server_free(con);
		return 0;
	}
	char buffer[bodyLen + 50];
	char *bufoff = buffer + 50;
	fmt_ldapbindrequeststring(bufoff, 3, binddn, password);
	fmt_ldapmessage(bufoff - headerLen, id, BindRequest, bodyLen);
	con->bindLen = (int)(bodyLen + headerLen);
	memcpy(con->bindBuffer, bufoff - headerLen, con->bindLen);
	if (!server_haveOut(con)) {
		printf("[Server] Could not send user bindrequest to server %s\n", server->addr);
		server_free(con);
		return 0;
	}
	*newcon = con;
	return id;
}
// Private stuff

/*
 * Allocate the zeroed server table (MAX_SERVERS entries) on first use.
 * Runs during configuration; running out of memory there is fatal
 * (exit(1)), consistent with the other fatal configuration errors in this file.
 */
static void server_init()
{
	if (servers != NULL) return;
	servers = calloc(MAX_SERVERS, sizeof *servers);
	if (servers == NULL) {
		printf("Error: Could not allocate server table.\n");
		exit(1);
	}
}
static server_t *server_create(const char *server)
{
int i;
server_init();
for (i = 0; i < serverCount; ++i) {
if (strcmp(servers[i].addr, server) == 0) return &servers[i];
}
if (serverCount >= MAX_SERVERS) {
printf("Cannot add server %s: Too many servers.\n", server);
return NULL;
}
snprintf(servers[serverCount].addr, ADDRLEN, "%s", server);
servers[serverCount].con.fd = -1;
servers[serverCount].con.serverData = &servers[serverCount];
return &servers[serverCount++];
}
/*
 * Tear down an AD server connection.
 * It is first detached from the proxy via proxy_removeServer(). Then the SSL
 * object and socket are released and all per-connection state is reset.
 * Dynamic (per-user bind) connections are freed entirely. The shared
 * connection object is kept so it can be reconnected later.
 * Read position, pending bind bytes and SSL state are reset too. Otherwise a
 * reconnect would start with stale partial data from the old stream.
 */
void server_free(epoll_server_t *server)
{
	proxy_removeServer(server);
	server->bound = FALSE;
	server->kill = FALSE;
	server->writeBlocked = FALSE;
	server->sslConnected = FALSE;
	if (server->ssl != NULL) {
		SSL_free(server->ssl);
		server->ssl = NULL;
	}
	if (server->fd != -1) {
		close(server->fd);
		server->fd = -1;
	}
	server->sbPos = server->sbFill = 0;
	server->rbPos = 0;
	server->bindLen = 0;
	if (server->dynamic) {
		printf("[Server] Freeing Bind-AD-Connection\n");
		free(server->sendBuffer);
		free(server);
	} else {
		printf("[Server] Closed shared anonymous connection\n");
	}
}
/*
 * epoll callback for AD server connections.
 * Plain sockets run the read/write handlers according to the reported
 * readiness.
 * SSL sockets first drive the handshake. After it completes, both handlers
 * always run, because incoming bytes may be SSL protocol traffic rather than
 * application data.
 * Connections flagged for cleanup or killed are freed at the end.
 */
static void server_callback(void *data, int haveIn, int haveOut, int doCleanup)
{
	epoll_server_t * const con = (epoll_server_t *)data;
	if (con->ssl != NULL) {
		if (!con->sslConnected) {
			if (!ssl_connectServer(con) || doCleanup || con->kill) {
				printf("[Proxy] SSL handshake for AD server %s failed.\n", con->serverData->addr);
				server_free(con);
				return;
			}
			if (!con->sslConnected) {
				// Handshake still in progress
				return;
			}
		}
		server_haveIn(con);
		server_haveOut(con);
	} else {
		if (haveIn) server_haveIn(con);
		if (haveOut) server_haveOut(con);
	}
	if (doCleanup || con->kill) {
		server_free(con);
	}
}
// Read available data from an AD server connection (plain or SSL) into
// server->readBuffer and pass every complete LDAP message (one ASN.1 SEQUENCE,
// framed via scan_asn1SEQUENCE) to proxy_fromServer().
// Sets server->kill and returns in these cases:
//  - EOF or a read error;
//  - a non-retryable SSL error;
//  - a buffered message that does not fit into MAXMSGLEN;
//  - a parse failure on a dynamic connection.
// server_callback() frees killed connections afterwards.
static void server_haveIn(epoll_server_t *server)
{
// No application data can be read before the SSL handshake has completed
if (server->ssl != NULL && !server->sslConnected) return;
for (;;) {
// Buffer is full but still holds no complete message
if (server->rbPos >= MAXMSGLEN) {
printf("[Proxy] Buffer overflow while reading from AD server. Disconnecting.\n");
server->kill = TRUE;
return;
}
const size_t buflen = MAXMSGLEN - server->rbPos;
ssize_t ret;
if (server->ssl == NULL) {
// Plain
ret = read(server->fd, server->readBuffer + server->rbPos, buflen);
if (ret < 0 && errno == EINTR) continue;
// Socket drained; wait for the next epoll event
if (ret < 0 && errno == EAGAIN) break;
//if (ret < 0) printf("[Proxy] AD Server %s gone while reading (ret=%d, errno=%d).\n", server->serverData->addr, (int)ret, errno);
// EOF (0) or hard read error
if (ret <= 0) {
server->kill = TRUE;
return;
}
} else {
// SSL
ret = SSL_read(server->ssl, server->readBuffer + server->rbPos, buflen);
if (ret <= 0) {
int err = SSL_get_error(server->ssl, ret);
// Presumably SSL_ERROR_WANT_READ/WANT_WRITE: retry on the next epoll event
if (SSL_BLOCKED(err)) break;
//if (err != 0) printf("[Proxy] AD Server %s gone while reading (ret=%d, err=%d).\n", server->serverData->addr, (int)ret, err);
server->kill = TRUE;
return;
}
}
server->rbPos += ret;
// Request complete? Dispatch every complete message currently buffered.
for (;;) {
size_t consumed, len;
// consumed = size of the SEQUENCE header, len = body length
consumed = scan_asn1SEQUENCE(server->readBuffer, server->readBuffer + server->rbPos, &len);
if (consumed == 0) break; // Length-Header not complete
len += consumed;
if (len > server->rbPos) break; // Body not complete
//printf("[AD] Received complete reply (need %d, have %d)...\n", (int)len, (int)server->rbPos);
if (!proxy_fromServer(server, len)) {
// A dynamic (user bind) connection is dropped on bad input; the shared one tries to continue
if (server->dynamic) {
server->kill = TRUE;
return;
}
printf("[Proxy] Error parsing message from AD.\n");
// Let's try to go on with the next message....
}
// Shift remaining buffer contents
if (len == server->rbPos) {
server->rbPos = 0;
break;
}
memmove(server->readBuffer, server->readBuffer + len, server->rbPos - len);
server->rbPos -= len;
}
// Plain: a short read means the socket is drained. SSL keeps looping until SSL_read
// reports it would block (presumably because OpenSSL may hold already-decrypted data).
if (server->ssl == NULL && (ssize_t)buflen > ret) break; // Read less than buffer len, epoll will fire again
}
}
/*
 * Send `len` bytes to an AD server connection, or queue them.
 * With cork == FALSE, plain connections try a direct write first. Whatever
 * cannot be written is appended to the send buffer and flushed through
 * server_haveOut().
 * With cork == TRUE, the data is only appended to the send buffer.
 * A zero-length call with cork == FALSE just triggers a flush.
 * Returns FALSE, with server->kill set when the connection is unusable, on
 * fatal errors.
 */
BOOL server_send(epoll_server_t *server, const char *buffer, size_t len, const BOOL cork)
{
	if (server->kill) return FALSE;
	// Direct write only when nothing is queued ahead of this data. That covers
	// buffered requests and the remainder of a partially written bind; writing
	// around those would interleave PDUs on the wire. Zero-length sends skip
	// this path, because write() returning 0 would be mistaken for a closed peer.
	if (len != 0 && server->ssl == NULL && server->sbFill == 0 && server->bindLen == 0 && !cork) {
		const ssize_t ret = write(server->fd, buffer, len);
		if (ret == 0 || (ret < 0 && errno != EINTR && errno != EAGAIN)) {
			printf("[Proxy] AD Server %s gone when trying to send.\n", server->serverData->addr);
			server->kill = TRUE;
			return FALSE;
		}
		server->lastActive = time(NULL);
		if (ret == (ssize_t)len) return TRUE;
		// Couldn't send everything, continue with buffering logic below
		if (ret > 0) {
			buffer += ret;
			len -= (size_t)ret;
		}
	}
	// Buffer...
	if (!server_ensureSendBuffer(server, len)) {
		server->kill = TRUE;
		return FALSE;
	}
	// Finally append to buffer (memcpy with a possibly NULL pointer and len 0 is not allowed)
	if (len != 0) {
		memcpy(server->sendBuffer + server->sbFill, buffer, len);
		server->sbFill += len;
	}
	if (!cork) {
		server->lastActive = time(NULL);
		return server_haveOut(server);
	}
	return TRUE;
}
/*
 * Flush pending output of an AD server connection.
 * While the connection is not bound, only the pending bind request
 * (bindBuffer/bindLen) is written. Queued requests stay buffered until the
 * bind is known to have succeeded.
 * A bind write that would block is retried on the next writability event
 * instead of killing the connection. For SSL the retry reuses the unchanged
 * bindBuffer (same pointer and length), as OpenSSL requires.
 * Returns FALSE and sets server->kill on fatal errors, TRUE otherwise.
 */
static BOOL server_haveOut(epoll_server_t * const server)
{
	if (server->kill) return FALSE;
	if (server->ssl != NULL && !server->sslConnected) return TRUE;
	// Bind not sent/acknowledged yet - send bind if pending, otherwise do nothing
	if (!server->bound) {
		if (server->bindLen == 0) return TRUE;
		int ret;
		if (server->ssl == NULL) {
			do {
				ret = (int)write(server->fd, server->bindBuffer, server->bindLen);
			} while (ret < 0 && errno == EINTR);
			// Socket buffer full: try again when epoll reports writability
			if (ret < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) return TRUE;
		} else {
			ret = SSL_write(server->ssl, server->bindBuffer, server->bindLen);
			if (ret <= 0) {
				const int err = SSL_get_error(server->ssl, ret);
				if (SSL_BLOCKED(err)) return TRUE;
			}
		}
		if (ret <= 0) {
			printf("[Server] Flushing bind to LDAP/AD failed...\n");
			server->kill = TRUE;
			return FALSE;
		}
		if (ret < server->bindLen) {
			memmove(server->bindBuffer, server->bindBuffer + ret, server->bindLen - ret);
		}
		server->bindLen -= ret;
		return TRUE;
	}
	// Only flush the regular send buffer (containing searches) if we know the bind succeeded
	while (server->sbPos < server->sbFill) {
		const ssize_t tosend = server->sbFill - server->sbPos;
		ssize_t ret;
		if (server->ssl == NULL) {
			// Plain
			ret = write(server->fd, server->sendBuffer + server->sbPos, tosend);
			if (ret < 0 && errno == EINTR) continue;
			if (ret < 0 && errno == EAGAIN) return TRUE;
			if (ret <= 0) {
				printf("[Proxy] AD Server %s gone while flushing send buffer (ret=%d, errno=%d)\n", server->serverData->addr, (int)ret, errno);
				server->kill = TRUE;
				return FALSE;
			}
		} else {
			// SSL
			ret = SSL_write(server->ssl, server->sendBuffer + server->sbPos, tosend);
			if (ret <= 0) {
				int err = SSL_get_error(server->ssl, ret);
				if (SSL_BLOCKED(err)) {
					// Must retry with the same buffer; server_ensureSendBuffer refuses to move it while set
					server->writeBlocked = TRUE;
					return TRUE; // Blocking
				} else if (err == SSL_ERROR_SSL) {
					ssl_printErrors(NULL);
				}
				printf("[Proxy] AD Server %s gone while flushing send buffer (ret=%d, err=%d)\n", server->serverData->addr, (int)ret, err);
				ERR_print_errors_fp(stdout);
				server->kill = TRUE;
				return FALSE; // Closed
			}
			server->writeBlocked = FALSE;
		}
		server->sbPos += ret;
		if (server->ssl != NULL) {
			// SSL: keep unsent data at the start of the buffer
			memmove(server->sendBuffer, server->sendBuffer + server->sbPos, server->sbFill - server->sbPos);
			server->sbFill -= server->sbPos;
			server->sbPos = 0;
		}
		// Plain short write: the socket buffer is full, wait for the next event
		if (server->ssl == NULL && ret != tosend) return TRUE;
	}
	server->sbPos = server->sbFill = 0;
	return TRUE;
}
static BOOL server_ensureConnected(server_t *server)
{
epoll_server_t * const con = &server->con;
if (con->fd != -1 && con->lastActive + 120 > time(NULL)) return TRUE;
if (con->fd != -1) {
server_free(con);
}
int sock = server_connectInternal(server);
if (sock == -1) {
printf("[Server] Creating socket for shared connection failed.\n");
return FALSE;
}
if (connectionInitDone) {
helper_nonblock(sock);
}
con->fd = sock;
con->callback = &server_callback;
if (ePoll_add(EPOLLIN | EPOLLOUT | EPOLLET, (epoll_item_t*)con) == -1) {
printf("[Proxy] epoll_add failed for ad server %s\n", server->addr);
close(con->fd);
con->fd = -1;
return FALSE;
}
// SSL
if (!server_connectSsl(con)) {
printf("[Proxy] SSL handshake failed for shared connection of %s\n", server->addr);
close(con->fd);
con->fd = -1;
return FALSE;
}
if (!connectionInitDone) {
helper_nonblock(sock);
}
// Now bind - TODO: SASL (DIGEST-MD5?)
const size_t bodyLen = fmt_ldapbindrequest(NULL, 3, server->bind, server->password);
const size_t headerLen = fmt_ldapmessage(NULL, MSGID_BIND, BindRequest, bodyLen);
char buffer[bodyLen + 50];
char *bufoff = buffer + 50;
if (headerLen >= 50) {
printf("[Proxy] bind too long for %s\n", server->addr);
close(con->fd);
con->fd = -1;
return FALSE;
}
con->kill = FALSE;
fmt_ldapbindrequest(bufoff, 3, server->bind, server->password);
fmt_ldapmessage(bufoff - headerLen, MSGID_BIND, BindRequest, bodyLen);
con->bindLen = (int)(bodyLen + headerLen);
if (con->bindLen < 0 || con->bindLen > BINDLEN) {
printf("[Server] Error: bind too long");
con->bindLen = BINDLEN;
}
memcpy(con->bindBuffer, bufoff - headerLen, con->bindLen);
if (!server_haveOut(con)) {
printf("[Server] Sending bindrequest for shared connection failed for server %s\n", server->addr);
return FALSE;
}
con->lastActive = time(NULL);
return TRUE;
}
/*
 * Make room for `len` more bytes at the end of the send buffer of `s`.
 * Requests above 1000000 bytes are refused.
 * Space is made in two steps: first unsent data is moved to the front of the
 * buffer, then the buffer is grown with helper_realloc (extra slack of 1000
 * bytes for plain, 6000 for SSL connections).
 * Neither step is allowed while an SSL write is blocked (writeBlocked). In
 * that case a buffer that is too small is an error.
 * Returns TRUE when the space is available.
 */
static BOOL server_ensureSendBuffer(epoll_server_t * const s, const size_t len)
{
	if (len > 1000000) {
		printf("server_ensureSendBuffer: request too large!\n");
		return FALSE;
	}
	if (s->sbLen - s->sbFill >= len) {
		return TRUE;
	}
	if (s->writeBlocked) {
		printf("[Proxy] SSL write to AD server blocked and buffer to small (%d bytes)\n", (int)s->sbLen);
		return FALSE;
	}
	if (s->sbPos != 0) {
		memmove(s->sendBuffer, s->sendBuffer + s->sbPos, s->sbFill - s->sbPos);
		s->sbFill -= s->sbPos;
		s->sbPos = 0;
	}
	if (s->sbLen - s->sbFill >= len) {
		return TRUE;
	}
	const int rc = helper_realloc(&s->sendBuffer, &s->sbLen, s->sbLen + len + (s->ssl == NULL ? 1000 : 6000), "server_ensureSendBuffer");
	return rc == -1 ? FALSE : TRUE;
}
/*
 * Open a TCP connection to `server` and return the socket, or -1 on failure.
 * Port selection: the configured port if set, otherwise AD_PORT without an
 * SSL context and AD_PORT_SSL with one.
 * The resolved address (server->ip) is cached for 300 seconds. Within that
 * window the cached IP is connected to directly. If that fails, the cache is
 * invalidated and -1 is returned.
 * Previously lastLookup was never set after a successful lookup, so the cache
 * was never used.
 */
static int server_connectInternal(server_t *server)
{
	int sock;
	const uint16_t port = server->port != 0 ? server->port : (server->sslContext == NULL ? AD_PORT : AD_PORT_SSL);
	if (server->lastLookup + 300 < time(NULL)) {
		// Presumably helper_connect4 resolves server->addr and stores the address in server->ip
		sock = helper_connect4(server->addr, port, server->ip);
		if (sock == -1) {
			printf("[Proxy] Could not resolve hostname or connect to AD server %s\n", server->addr);
			return -1;
		}
		server->lastLookup = time(NULL);
	} else {
		sock = socket_tcp4b();
		if (sock == -1) {
			printf("[Proxy] Could not allocate socket for connection to AD server %s\n", server->addr);
			return -1;
		}
		if (socket_connect4(sock, server->ip, port) == -1) {
			// NOTE(review): assumes server->ip is the 4-byte binary IPv4 address taken by
			// socket_connect4 (not NUL-terminated), so it is printed as a dotted quad, not with %s
			printf("[Proxy] Could not connect to cached IP (%d.%d.%d.%d) of %s\n",
				(int)(unsigned char)server->ip[0], (int)(unsigned char)server->ip[1],
				(int)(unsigned char)server->ip[2], (int)(unsigned char)server->ip[3], server->addr);
			server->lastLookup = 0;
			close(sock);
			return -1;
		}
	}
	return sock;
}
/*
 * Start SSL on connection `server` if its server entry has an SSL context.
 * Without a context this is a no-op that reports success.
 * Otherwise a new SSL object is attached to the socket and
 * ssl_connectServer() is called. On failure the SSL object is freed and
 * server->ssl reset to NULL. The socket itself is left for the caller to
 * close.
 */
static BOOL server_connectSsl(epoll_server_t *server)
{
	if (server->serverData->sslContext == NULL) {
		return TRUE;
	}
	server->sslConnected = FALSE;
	server->ssl = ssl_new(server->fd, server->serverData->sslContext);
	if (server->ssl == NULL) {
		printf("[Proxy] Could not get SSL client from context\n");
		return FALSE;
	}
	if (ssl_connectServer(server)) {
		return TRUE;
	}
	printf("[Proxy] SSL connect to AD server %s failed.\n", server->serverData->addr);
	SSL_free(server->ssl);
	server->ssl = NULL;
	return FALSE;
}