#include "server.h"
#include "proxy.h"
#include "helper.h"
#include "epoll.h"
#include "tmpbuffer.h"
#include "openssl.h"
#include <time.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include <socket.h>
// Default port used when a server has no explicit port and no SSL context
#define AD_PORT 3268
// Default port used when a server has no explicit port but does have an SSL context
#define AD_PORT_SSL 636
// Fixed message ID used for the bind request on a server's shared connection
#define MSGID_BIND 1
// Number of entries allocated for the server table
#define MAX_SERVERS 10
// Server table (allocated on first use by server_init) and number of entries in use
static server_t *servers = NULL;
static int serverCount = 0;
// Forward declarations of the file-local helpers defined further down
static void server_init();
static server_t *server_create(const char *server);
static void server_free(epoll_server_t *server);
static void server_callback(void *data, int haveIn, int haveOut, int doCleanup);
static void server_haveIn(epoll_server_t *server);
static void server_haveOut(epoll_server_t * const server);
static BOOL server_ensureConnected(server_t *server);
static BOOL server_ensureSendBuffer(epoll_server_t * const s, const size_t len);
static int server_connectInternal(server_t *server);
static BOOL server_connectSsl(epoll_server_t *server);
// Hand out the next message ID for a request to AD.
// The counter is incremented before use; values 0 and 1 are never returned,
// after a wrap-around counting resumes at 2.
static inline uint32_t msgId()
{
	static uint32_t counter = 1336;
	counter += 1;
	if (counter < 2) {
		counter = 2;
	}
	return counter;
}
// Setting up server(s)

// Set the port to use for the given server.
// The server entry is looked up by name and created if it does not exist yet.
// server:  server name/address identifying the entry
// portStr: decimal port number; must be in the range 1..65535
// An invalid port only produces a warning; the entry keeps its previous port.
void server_setPort(const char *server, const char *portStr)
{
	server_t *entry = server_create(server);
	if (entry == NULL) return;
	// strtol accepts the same input as atoi but has defined behaviour when the
	// number does not fit; such values clamp to LONG_MIN/LONG_MAX and are
	// rejected by the range check below.
	const long port = strtol(portStr, NULL, 10);
	if (port < 1 || port > 65535) {
		printf("Warning: Invalid port '%s' for '%s'\n", portStr, server);
		return;
	}
	entry->port = (uint16_t)port;
}
// Store the bind DN for the given server (entry is created if missing).
// The value is truncated to fit the entry's buffer; truncation is reported
// with a warning.
void server_setBind(const char *server, const char *bind)
{
	server_t * const target = server_create(server);
	if (target == NULL) {
		return;
	}
	const int needed = snprintf(target->bind, BINDLEN, "%s", bind);
	if (needed >= BINDLEN) {
		printf("Warning: BindDN for %s is too long.\n", server);
	}
}
// Store the bind password for the given server (entry is created if missing).
// The value is truncated to fit the entry's buffer; truncation is reported
// with a warning.
void server_setPassword(const char *server, const char *password)
{
	server_t * const target = server_create(server);
	if (target == NULL) {
		return;
	}
	const int needed = snprintf(target->password, PWLEN, "%s", password);
	if (needed >= PWLEN) {
		printf("Warning: BindPW for %s is too long.\n", server);
	}
}
// Store the search base for the given server (entry is created if missing).
// The copied value is run through normalize_dn in place; the resulting
// length is kept in baseLen and the string is re-terminated at that length.
void server_setBase(const char *server, const char *base)
{
	server_t * const target = server_create(server);
	if (target == NULL) {
		return;
	}
	const int needed = snprintf(target->base, BASELEN, "%s", base);
	if (needed >= BASELEN) {
		printf("Warning: SearchBase for %s is too long.\n", server);
	}
	const size_t rawLen = strlen(target->base);
	target->baseLen = normalize_dn(target->base, target->base, min(rawLen, BASELEN - 1));
	target->base[target->baseLen] = '\0';
}
// Store the home directory template for the given server (entry is created
// if missing).
// The template is sanitised in place:
//  - backslashes are turned into forward slashes
//  - at most five '%x' specifiers are kept ('%%' does not count); from the
//    sixth specifier on, everything is overwritten with '_', including the
//    '%' that introduces the sixth specifier, so no dangling '%' is left
//    behind in the stored string.
// Truncation and too many specifiers are reported with a warning.
void server_setHomeTemplate(const char *server, const char *hometemplate)
{
	enum { MAX_SPECIFIERS = 5 };
	server_t *entry = server_create(server);
	if (entry == NULL) return;
	if (snprintf(entry->homeTemplate, MOUNTLEN, "%s", hometemplate) >= MOUNTLEN) printf("Warning: Home Template for %s is too long.\n", server);
	// TODO: Better template system. Using a format string is too lazy
	BOOL afterPercent = FALSE;
	int count = 0;
	for (char *s = entry->homeTemplate; *s != '\0'; ++s) {
		if (afterPercent) {
			afterPercent = FALSE;
			if (*s != '%') {
				count++;
				// First specifier beyond the limit: also blank out the '%'
				// directly before it (afterPercent guarantees s[-1] exists).
				if (count == MAX_SPECIFIERS + 1) s[-1] = '_';
			}
		} else if (*s == '%') {
			afterPercent = TRUE;
		}
		if (count > MAX_SPECIFIERS) {
			*s = '_';
		} else if (*s == '\\') {
			*s = '/';
		}
	}
	if (count > MAX_SPECIFIERS) printf("WARNING: Too many '%%' in Home Template for %s. Don't forget to replace literal '%%' with '%%%%'\n", server);
}
// Parse a certificate fingerprint given as hex string and enable SSL for the
// given server (entry is created if missing).
// ':' and ' ' are accepted as separators between hex digits. Exactly
// FINGERPRINTLEN bytes (two hex digits each) are required; anything else is
// rejected with a warning and no SSL context is created.
// Does nothing if the entry already has an SSL context.
void server_setFingerprint(const char *server, const char *fingerprint)
{
	server_t *entry = server_create(server);
	if (entry == NULL || entry->sslContext != NULL) return;
	int chars = 0, val = -1;
	while (*fingerprint != '\0' && chars / 2 < FINGERPRINTLEN) {
		if (*fingerprint == ':' || *fingerprint == ' ') {
			fingerprint++;
			continue;
		}
		val = -1;
		if (*fingerprint >= '0' && *fingerprint <= '9') {
			val = *fingerprint - '0';
		} else if (*fingerprint >= 'a' && *fingerprint <= 'f') {
			val = *fingerprint - 'a' + 10;
		} else if (*fingerprint >= 'A' && *fingerprint <= 'F') {
			val = *fingerprint - 'A' + 10;
		} else {
			// Not a hex digit: stop, val == -1 marks the input as invalid
			break;
		}
		if (chars % 2 == 0) {
			// High nibble starts a new byte: assign instead of OR-ing so that
			// bits left over from an earlier, rejected call cannot leak in.
			entry->fingerprint[chars / 2] = val << 4;
		} else {
			entry->fingerprint[chars / 2] |= val;
		}
		fingerprint++;
		chars++;
	}
	if (chars / 2 != FINGERPRINTLEN || val == -1) {
		printf("Warning: Fingerprint for %s is invalid (adsha1 should be a SHA-1 hash of the cert in hex representation.)\n", server);
		return;
	}
	printf("Using fingerprint ");
	for (int i = 0; i < FINGERPRINTLEN - 1; ++i) {
		printf("%02x:", (int)entry->fingerprint[i]);
	}
	printf("%02x for %s\n", (int)entry->fingerprint[FINGERPRINTLEN-1], server);
	ssl_init();
	entry->sslContext = ssl_newClientCtx();
	if (entry->sslContext == NULL) {
		// Without a context the connection code treats this server as plain
		printf("Warning: Could not create SSL context for %s\n", server);
	}
}
BOOL server_initServers()
{
int i;
printf("%d servers configured.\n", serverCount);
for (i = 0; i < serverCount; ++i) {
printf("%s:\n Bind: %s\n Base: %s\n", servers[i].addr, servers[i].bind, servers[i].base);
if (!server_ensureConnected(&servers[i]))
return FALSE;
}
return TRUE;
}
// What the proxy calls
server_t *server_getFromBase(struct string *in)
{
int i;
char buffer[TMPLEN];
const size_t searchLen = normalize_dn(buffer, in->s, min(in->l, TMPLEN - 1));
buffer[searchLen] = '\0';
// Now buffer contains the normalized wanted bind/domain/whatev. Try to find a match in the server list
for (i = 0; i < serverCount; ++i) {
if (searchLen < servers[i].baseLen) continue;
if (strcmp(servers[i].base, buffer + (searchLen - servers[i].baseLen)) == 0) {
return &servers[i];
}
}
return NULL;
}
// Send a search request to the given server over its shared connection.
// The connection is (re-)established first if necessary.
// Returns the message ID used for the request, or 0 if the server is not
// reachable, the message header does not fit, or sending failed.
uint32_t server_searchRequest(server_t *server, struct SearchRequest *req)
{
	if (!server_ensureConnected(server)) return 0;
	const uint32_t msgid = msgId();
	const size_t bodyLen = fmt_ldapsearchrequest(NULL, req);
	const size_t headerLen = fmt_ldapmessage(NULL, msgid, SearchRequest, bodyLen);
	// The header is written into the 50 bytes in front of the body; refuse
	// anything larger, same as the bind code paths do.
	if (headerLen >= 50) {
		printf("[AD] search request too long for %s\n", server->addr);
		return 0;
	}
	char buffer[bodyLen + 50];
	char *bufoff = buffer + 50;
	fmt_ldapsearchrequest(bufoff, req);
	fmt_ldapmessage(bufoff - headerLen, msgid, SearchRequest, bodyLen);
	epoll_server_t * const s = &server->con;
	if (!server_send(s, bufoff - headerLen, headerLen + bodyLen, FALSE)) {
		// Nothing was queued, so no reply will ever arrive for this ID
		return 0;
	}
	return msgid;
}
// Open a dedicated connection to the given server and send a bind request
// with the supplied DN and password.
// The connection object is heap allocated and marked dynamic, so that
// server_free releases it completely once it is no longer needed.
// Returns the message ID of the bind request, or 0 on any failure
// (allocation, connect, epoll registration, SSL setup, oversized header,
// send error). On failure the connection object has already been released.
uint32_t server_tryUserBind(server_t *server, struct string *binddn, struct string *password)
{
	epoll_server_t *con = calloc(1, sizeof(epoll_server_t));
	if (con == NULL) {
		printf("[ADB] Out of memory allocating connection for %s\n", server->addr);
		return 0;
	}
	con->serverData = server;
	con->fd = -1;
	con->bound = FALSE;
	con->dynamic = TRUE;
	printf("Connecting to AD '%s' for %.*ss bind...\n", server->addr, (int)binddn->l, binddn->s);
	con->sbPos = con->sbFill = 0;
	int sock = server_connectInternal(server);
	if (sock == -1) {
		server_free(con);
		return 0;
	}
	printf("[ADB] Connected, binding....\n");
	helper_nonblock(sock);
	con->fd = sock;
	con->callback = &server_callback;
	if (ePoll_add(EPOLLIN | EPOLLOUT | EPOLLET, (epoll_item_t*)con) == -1) {
		printf("[ADB] epoll_add failed for ad server %s\n", server->addr);
		server_free(con);
		return 0;
	}
	// SSL
	if (!server_connectSsl(con)) {
		server_free(con);
		return 0;
	}
	// Now bind - TODO: SASL (DIGEST-MD5?)
	const uint32_t id = msgId();
	const size_t bodyLen = fmt_ldapbindrequeststring(NULL, 3, binddn, password);
	const size_t headerLen = fmt_ldapmessage(NULL, id, BindRequest, bodyLen);
	// The header is written into the 50 bytes in front of the body
	if (headerLen >= 50) {
		printf("[ADB] bind too long for %s\n", server->addr);
		server_free(con);
		return 0;
	}
	char buffer[bodyLen + 50];
	char *bufoff = buffer + 50;
	fmt_ldapbindrequeststring(bufoff, 3, binddn, password);
	fmt_ldapmessage(bufoff - headerLen, id, BindRequest, bodyLen);
	if (!server_send(con, bufoff - headerLen, bodyLen + headerLen, FALSE)) {
		// Request never left; release the connection like the other error paths
		server_free(con);
		return 0;
	}
	return id;
}
//
// Private stuff
// Allocate the zero-initialised server table on first use.
// Subsequent calls are no-ops once the table exists.
static void server_init()
{
	if (servers == NULL) {
		servers = calloc(MAX_SERVERS, sizeof(server_t));
	}
}
// Look up the entry for the given server name, creating it if it does not
// exist yet.
// A new entry gets the (possibly truncated) name as address, an unconnected
// shared connection (fd == -1) and a back pointer from the connection to the
// entry.
// Returns NULL if the server table could not be allocated or is full.
static server_t *server_create(const char *server)
{
	int i;
	server_init();
	if (servers == NULL) {
		// Allocation in server_init failed
		printf("Cannot add server %s: Out of memory.\n", server);
		return NULL;
	}
	for (i = 0; i < serverCount; ++i) {
		if (strcmp(servers[i].addr, server) == 0) return &servers[i];
	}
	if (serverCount >= MAX_SERVERS) {
		printf("Cannot add server %s: Too many servers.\n", server);
		return NULL;
	}
	snprintf(servers[serverCount].addr, ADDRLEN, "%s", server);
	servers[serverCount].con.fd = -1;
	servers[serverCount].con.serverData = &servers[serverCount];
	return &servers[serverCount++];
}
// Tear down a connection: drop the bound state, release the SSL object,
// close the socket and discard any buffered outgoing data.
// Connections flagged as dynamic additionally have their send buffer and the
// connection object itself freed; the pointer must not be used afterwards.
static void server_free(epoll_server_t *server)
{
	server->bound = FALSE;
	if (server->ssl != NULL) {
		SSL_free(server->ssl);
		server->ssl = NULL;
	}
	if (server->fd != -1) {
		close(server->fd);
		server->fd = -1;
	}
	server->sbFill = 0;
	server->sbPos = 0;
	if (!server->dynamic) {
		return;
	}
	printf("Freeing Bind-AD-Connection\n");
	free(server->sendBuffer);
	free(server);
}
// Event callback registered for AD connections.
// data:      the epoll_server_t the event belongs to
// haveIn:    non-zero if data can be read (only honoured for plain connections)
// haveOut:   non-zero if data can be written (only honoured for plain connections)
// doCleanup: non-zero if the connection should be torn down
// The connection is released via server_free on cleanup, when its kill flag
// is set, or when the SSL handshake fails.
static void server_callback(void *data, int haveIn, int haveOut, int doCleanup)
{
	epoll_server_t * const con = data;
	if (doCleanup || con->kill) {
		server_free(con);
		return;
	}
	if (con->ssl != NULL) {
		// SSL
		if (!con->sslConnected) {
			// Handshake still in progress, try to advance it
			if (!ssl_connectServer(con)) {
				printf("SSL Server connect failed!\n");
				server_free(con);
				return;
			}
			if (!con->sslConnected) {
				return;
			}
		}
		// Incoming data may be application data or SSL protocol traffic,
		// so both directions are always serviced
		server_haveIn(con);
		server_haveOut(con);
	} else {
		// Plain connection
		if (haveIn) {
			server_haveIn(con);
		}
		if (haveOut) {
			server_haveOut(con);
		}
	}
	if (con->kill) {
		server_free(con);
	}
}
// Read everything currently available from the AD connection into its read
// buffer and hand every complete message to proxy_fromServer.
// Sets the kill flag on the connection if the read buffer is full, the peer
// is gone, or (for dynamic connections only) a reply cannot be processed.
static void server_haveIn(epoll_server_t *server)
{
	for (;;) {
		if (server->rbPos >= MAXMSGLEN) {
			printf("[AD->Proxy] Read buffer overflow. Disconnecting.\n");
			server->kill = TRUE;
			return;
		}
		const size_t buflen = MAXMSGLEN - server->rbPos;
		ssize_t ret;
		if (server->ssl == NULL) {
			// Plain
			ret = read(server->fd, server->readBuffer + server->rbPos, buflen);
			// Save errno right away, printf below is allowed to change it
			const int readErr = errno;
			printf("AD read %d (err %d)\n", (int)ret, readErr);
			if (ret < 0 && readErr == EINTR) continue;
			if (ret < 0 && readErr == EAGAIN) break;
			if (ret <= 0) {
				printf("AD Server gone while reading.\n");
				server->kill = TRUE;
				return;
			}
		} else {
			// SSL
			ret = SSL_read(server->ssl, server->readBuffer + server->rbPos, buflen);
			if (ret <= 0) {
				int err = SSL_get_error(server->ssl, ret);
				if (SSL_BLOCKED(err)) break;
				printf("AD Server gone while reading (%d, %d).\n", (int)ret, err);
				server->kill = TRUE;
				return;
			}
		}
		server->rbPos += ret;
		// Request complete?
		for (;;) {
			size_t consumed, len;
			consumed = scan_asn1SEQUENCE(server->readBuffer, server->readBuffer + server->rbPos, &len);
			if (consumed == 0) break; // Length-Header not complete
			len += consumed;
			if (len > server->rbPos) break; // Body not complete
			printf("[AD] Received complete reply...\n");
			if (!proxy_fromServer(server, len)) {
				if (server->dynamic) {
					server->kill = TRUE;
					return;
				}
				printf("Error parsing reply from AD.\n");
				// Let's try to go on with the next message....
			}
			// Shift remaining buffer contents
			if (len == server->rbPos) {
				server->rbPos = 0;
				break;
			}
			memmove(server->readBuffer, server->readBuffer + len, server->rbPos - len);
			server->rbPos -= len;
		}
		// Plain socket: a short read means the socket is drained, epoll will
		// fire again. A short SSL_read says nothing about pending data, and
		// the connection is registered edge-triggered, so for SSL keep
		// reading until SSL_read reports that it would block.
		if (server->ssl == NULL && (ssize_t)buflen > ret) break;
	}
}
// Send data to the AD server, buffering whatever cannot be written directly.
// server: connection to send on
// buffer: data to send
// len:    number of bytes in buffer
// cork:   if TRUE, only append to the send buffer without trying to flush
// On a plain connection with an empty send buffer (and cork not set) the data
// is written to the socket directly; any unsent remainder is appended to the
// send buffer. Otherwise the data always goes through the send buffer.
// Returns FALSE if the direct write failed or the send buffer could not be
// provided (the latter also sets the kill flag), TRUE otherwise.
BOOL server_send(epoll_server_t *server, const char *buffer, size_t len, const BOOL cork)
{
	const BOOL direct = (server->ssl == NULL) && (server->sbFill == 0) && !cork;
	if (direct) {
		// Nothing in send buffer, fire away
		const int sent = write(server->fd, buffer, len);
		if (sent == 0 || (sent < 0 && !(errno == EINTR || errno == EAGAIN))) {
			printf("Server gone when trying to send.\n");
			return FALSE;
		}
		server->lastActive = time(NULL);
		if (sent == (int)len) {
			return TRUE;
		}
		// Couldn't send everything, the rest goes through the buffer
		if (sent > 0) {
			printf("[AD] Partial send (%d of %d)\n", sent, (int)len);
			buffer += sent;
			len -= (size_t)sent;
		}
	}
	// Make room in the send buffer
	if (!server_ensureSendBuffer(server, len)) {
		server->kill = TRUE;
		return FALSE;
	}
	// Append and, unless corked, try to flush right away
	memcpy(server->sendBuffer + server->sbFill, buffer, len);
	server->sbFill += len;
	if (!cork) {
		server_haveOut(server);
	}
	return TRUE;
}
// Flush as much of the send buffer as the connection accepts right now.
// Plain connections stop at the first short or would-block write.
// SSL connections compact the buffer after every successful write and set
// writeBlocked while an SSL_write has to be retried; the flag is cleared
// again once a write succeeds.
// Any unrecoverable write error sets the kill flag on the connection.
static void server_haveOut(epoll_server_t * const server)
{
	while (server->sbPos < server->sbFill) {
		const ssize_t tosend = server->sbFill - server->sbPos;
		ssize_t ret;
		if (server->ssl == NULL) {
			// Plain
			ret = write(server->fd, server->sendBuffer + server->sbPos, tosend);
			if (ret < 0 && errno == EINTR) continue;
			if (ret < 0 && errno == EAGAIN) return;
			if (ret <= 0) {
				printf("Connection to AD Server failed while flushing (ret: %d, errno: %d)\n", (int)ret, errno);
				// Same handling as the SSL branch and the read path: mark the
				// connection dead instead of leaving it with unsendable data
				server->kill = TRUE;
				return;
			}
		} else {
			// SSL
			ret = SSL_write(server->ssl, server->sendBuffer + server->sbPos, tosend);
			if (ret <= 0) {
				int err = SSL_get_error(server->ssl, ret);
				if (SSL_BLOCKED(err)) {
					server->writeBlocked = TRUE;
					return; // Blocking
				}
				printf("SSL server gone while sending (%d)\n", err);
				ERR_print_errors_fp(stdout);
				server->kill = TRUE;
				return; // Closed
			}
			// The pending write went through, the buffer may be moved and
			// grown again
			server->writeBlocked = FALSE;
		}
		server->lastActive = time(NULL);
		server->sbPos += ret;
		if (server->ssl != NULL) {
			memmove(server->sendBuffer, server->sendBuffer + server->sbPos, server->sbFill - server->sbPos);
			server->sbFill -= server->sbPos;
			server->sbPos = 0;
		}
		if (server->ssl == NULL && ret != tosend) return;
	}
	server->sbPos = server->sbFill = 0;
}
// Make sure the shared connection of the given server is usable.
// A connection that has an open socket and was active less than 120 seconds
// ago is kept as is. Otherwise the old connection is torn down, a new one is
// opened, registered with epoll, optionally wrapped in SSL, and a bind
// request with the configured bind DN and password is sent.
// Returns TRUE if a connection is in place, FALSE if reconnecting failed.
static BOOL server_ensureConnected(server_t *server)
{
	epoll_server_t * const con = &server->con;
	if (con->fd != -1 && con->lastActive + 120 > time(NULL)) return TRUE;
	// Release everything belonging to the previous connection (socket AND SSL
	// object). In this file only server_tryUserBind marks connections as
	// dynamic, so server_free does not free the embedded struct itself.
	server_free(con);
	// A new connection must not inherit per-connection state from the old one
	con->kill = FALSE;
	con->sslConnected = FALSE;
	con->writeBlocked = FALSE;
	con->rbPos = 0;
	printf("Connecting to AD '%s'...\n", server->addr);
	int sock = server_connectInternal(server);
	if (sock == -1) return FALSE;
	printf("Connected, binding....\n");
	helper_nonblock(sock);
	con->fd = sock;
	con->callback = &server_callback;
	if (ePoll_add(EPOLLIN | EPOLLOUT | EPOLLET, (epoll_item_t*)con) == -1) {
		printf("epoll_add failed for ad server %s\n", server->addr);
		server_free(con);
		return FALSE;
	}
	// SSL
	if (!server_connectSsl(con)) {
		server_free(con);
		return FALSE;
	}
	// Now bind - TODO: SASL (DIGEST-MD5?)
	const size_t bodyLen = fmt_ldapbindrequest(NULL, 3, server->bind, server->password);
	const size_t headerLen = fmt_ldapmessage(NULL, MSGID_BIND, BindRequest, bodyLen);
	// The header is written into the 50 bytes in front of the body
	if (headerLen >= 50) {
		printf("[AD] bind too long for %s\n", server->addr);
		server_free(con);
		return FALSE;
	}
	char buffer[bodyLen + 50];
	char *bufoff = buffer + 50;
	fmt_ldapbindrequest(bufoff, 3, server->bind, server->password);
	fmt_ldapmessage(bufoff - headerLen, MSGID_BIND, BindRequest, bodyLen);
	if (!server_send(con, bufoff - headerLen, bodyLen + headerLen, FALSE)) {
		// Bind could not be sent, the connection is of no use
		server_free(con);
		return FALSE;
	}
	return TRUE;
}
// Make sure the send buffer of the connection has room for len more bytes.
// Requests above 1000000 bytes are refused. If there is not enough free
// space, already sent data at the front is dropped by compacting the buffer,
// and if that is still not enough the buffer is enlarged.
// While an SSL write is blocked the buffer is neither moved nor enlarged.
// Returns TRUE if the space is available, FALSE otherwise.
static BOOL server_ensureSendBuffer(epoll_server_t * const s, const size_t len)
{
	if (len > 1000000) {
		printf("server_ensureSendBuffer: request too large!\n");
		return FALSE;
	}
	if (!(s->sbLen - s->sbFill < len)) {
		// Enough room already
		return TRUE;
	}
	if (s->writeBlocked) {
		printf("SSL Write blocked and buffer to small (%d)\n", (int)s->sbLen);
		return FALSE;
	}
	if (s->sbPos != 0) {
		// Move unsent data to the front
		memmove(s->sendBuffer, s->sendBuffer + s->sbPos, s->sbFill - s->sbPos);
		s->sbFill -= s->sbPos;
		s->sbPos = 0;
	}
	if (!(s->sbLen - s->sbFill < len)) {
		return TRUE;
	}
	// Still too small: grow, with some headroom
	if (helper_realloc(&s->sendBuffer, &s->sbLen, s->sbLen + len + (s->ssl == NULL ? 1000 : 6000), "server_ensureSendBuffer") == -1) {
		return FALSE;
	}
	return TRUE;
}
// Open a TCP connection to the given server.
// The port is the configured one, or AD_PORT / AD_PORT_SSL depending on
// whether the server has an SSL context.
// If the last successful lookup is more than 300 seconds old, the address is
// resolved and connected via helper_connect4, which also receives server->ip;
// otherwise the connection goes straight to server->ip.
// Returns the connected socket, or -1 on failure.
static int server_connectInternal(server_t *server)
{
	int sock;
	const uint16_t port = server->port != 0 ? server->port : (server->sslContext == NULL ? AD_PORT : AD_PORT_SSL);
	if (server->lastLookup + 300 < time(NULL)) {
		sock = helper_connect4(server->addr, port, server->ip);
		if (sock == -1) {
			printf("Could not resolve/connect to AD server %s\n", server->addr);
			return -1;
		}
		// Remember when the address was last resolved; without this the
		// cached-IP branch below could never be taken.
		// NOTE(review): assumes helper_connect4 stores the resolved address
		// in server->ip - confirm against helper.c
		server->lastLookup = time(NULL);
	} else {
		sock = socket_tcp4b();
		if (sock == -1) {
			printf("Could not allocate socket for connection to AD\n");
			return -1;
		}
		if (socket_connect4(sock, server->ip, port) == -1) {
			printf("Could not connect to cached IP of %s\n", server->addr);
			// Force a fresh lookup on the next attempt
			server->lastLookup = 0;
			close(sock);
			return -1;
		}
	}
	return sock;
}
// Set up SSL on a freshly connected socket if the owning server has an SSL
// context; connections to servers without one are left plain.
// Returns TRUE if no SSL is needed or ssl_connectServer succeeded, FALSE if
// the SSL object could not be created or ssl_connectServer failed. On failure
// no SSL object remains attached to the connection.
static BOOL server_connectSsl(epoll_server_t *server)
{
	if (server->serverData->sslContext == NULL) {
		return TRUE;
	}
	server->ssl = ssl_new(server->fd, server->serverData->sslContext);
	if (server->ssl == NULL) {
		printf("Could not get SSL client from context\n");
		return FALSE;
	}
	if (ssl_connectServer(server)) {
		return TRUE;
	}
	printf("SSL connect failed.\n");
	SSL_free(server->ssl);
	server->ssl = NULL;
	return FALSE;
}