#include "server.h"
#include "proxy.h"
#include "helper.h"
#include "epoll.h"
#include "tmpbuffer.h"
#include <time.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include <socket.h>
// TCP port used for every AD connection.
// NOTE(review): 3268 is presumably the AD Global Catalog (plain LDAP) port — confirm.
#define AD_PORT 3268
// Fixed LDAP message ID of the service-account bind sent by server_ensureConnected;
// msgId() never returns values below 2, so it cannot collide.
#define MSGID_BIND 1
// Capacity of the servers table (server_init allocates exactly this many entries).
#define MAX_SERVERS 10
// Configured AD servers: lazily allocated by server_init, the first serverCount entries are in use.
static server_t *servers = NULL;
static int serverCount = 0;
static void server_init(void);
static server_t *server_create(const char *server);
static void server_free(epoll_server_t *server);
static void server_callback(void *data, int haveIn, int haveOut, int doCleanup);
static void server_flush(epoll_server_t * const server);
static BOOL server_ensureConnected(server_t *server);
static void server_ensureSendBuffer(epoll_server_t * const s, const size_t len);
// Hand out the next LDAP message ID for a request we send to AD.
// IDs are consecutive, starting at 1337. After 32-bit wrap-around, 0 and 1 are skipped,
// so the result is always >= 2 (1 is reserved as MSGID_BIND).
static inline uint32_t msgId()
{
	static uint32_t counter = 1336;
	counter += 1;
	counter = (counter < 2) ? 2 : counter;
	return counter;
}
// Setting up server(s)
/* Set the bind DN used for the service connection to @server.
 * The entry is created on first use; over-long values are truncated
 * to BINDLEN-1 characters with a warning. */
void server_setBind(const char *server, const char *bind)
{
	server_t * const target = server_create(server);
	if (target == NULL) {
		return;
	}
	const int written = snprintf(target->bind, BINDLEN, "%s", bind);
	if (written >= BINDLEN) {
		printf("Warning: BindDN for %s is too long.\n", server);
	}
}
/* Set the bind password used for the service connection to @server.
 * The entry is created on first use; over-long values are truncated
 * to PWLEN-1 characters with a warning. */
void server_setPassword(const char *server, const char *password)
{
	server_t * const target = server_create(server);
	if (target == NULL) {
		return;
	}
	const int written = snprintf(target->password, PWLEN, "%s", password);
	if (written >= PWLEN) {
		printf("Warning: BindPW for %s is too long.\n", server);
	}
}
/* Set the search base of @server. The base is stored in normalized form
 * (normalize_dn, in place) together with its length, so server_getFromBase
 * can compare it directly against a normalized request DN. */
void server_setBase(const char *server, const char *base)
{
	server_t * const target = server_create(server);
	if (target == NULL) {
		return;
	}
	if (snprintf(target->base, BASELEN, "%s", base) >= BASELEN) {
		printf("Warning: SearchBase for %s is too long.\n", server);
	}
	const size_t rawLen = strlen(target->base);
	target->baseLen = normalize_dn(target->base, target->base, min(rawLen, BASELEN - 1));
	target->base[target->baseLen] = '\0';
}
/* Set the home directory template of @server.
 * NOTE(review): per the TODO below, the template is consumed elsewhere as a
 * printf-style format string. It is sanitized here so that it contains at most
 * 5 conversions, no incomplete '%' directive and only forward slashes
 * (backslashes are converted). */
void server_setHomeTemplate(const char *server, const char *hometemplate)
{
	server_t *entry = server_create(server);
	if (entry == NULL) return;
	if (snprintf(entry->homeTemplate, MOUNTLEN, "%s", hometemplate) >= MOUNTLEN) printf("Warning: Home Template for %s is too long.\n", server);
	// TODO: Better template system. Using a format string is too lazy
	BOOL b = FALSE; // TRUE if the previous character opened a '%' directive
	char *s = entry->homeTemplate;
	int count = 0; // number of conversions ("%x", but not the literal "%%") seen so far
	while (*s) {
		if (b) {
			if (*s != '%') {
				count++;
				// Blank out the introducing '%' too. Replacing only the current
				// character would leave "%_" behind, an invalid conversion
				// specification (undefined behavior in printf).
				if (count > 5) s[-1] = '_';
			}
			b = FALSE;
		} else if (*s == '%') b = TRUE;
		// Once over the limit, everything from here on is blanked
		if (count > 5) *s = '_';
		if (*s == '\\') *s = '/';
		s++;
	}
	// A lone trailing '%' (e.g. left by truncation above) is an incomplete directive as well
	if (b) s[-1] = '_';
	if (count > 5) printf("WARNING: Too many '%%' in Home Template for %s. Don't forget to replace literal '%%' with '%%%%'\n", server);
}
void server_initServers()
{
int i;
printf("%d servers configured.\n", serverCount);
for (i = 0; i < serverCount; ++i) {
printf("%s:\n Bind: %s\n Base: %s\n", servers[i].addr, servers[i].bind, servers[i].base);
server_ensureConnected(&servers[i]);
}
}
// What the proxy calls
server_t *server_getFromBase(struct string *in)
{
int i;
char buffer[TMPLEN];
const size_t searchLen = normalize_dn(buffer, in->s, min(in->l, TMPLEN - 1));
buffer[searchLen] = '\0';
// Now buffer contains the normalized wanted bind/domain/whatev. Try to find a match in the server list
for (i = 0; i < serverCount; ++i) {
if (searchLen < servers[i].baseLen) continue;
if (strcmp(servers[i].base, buffer + (searchLen - servers[i].baseLen)) == 0) {
return &servers[i];
}
}
return NULL;
}
/* Send a SearchRequest for @req over the shared connection of @server,
 * (re)connecting and binding first if necessary.
 * Returns the LDAP message ID used for the request, or 0 if it could not be sent. */
uint32_t server_searchRequest(server_t *server, struct SearchRequest *req)
{
	if (!server_ensureConnected(server)) return 0;
	const uint32_t msgid = msgId();
	const size_t bodyLen = fmt_ldapsearchrequest(NULL, req);
	const size_t headerLen = fmt_ldapmessage(NULL, msgid, SearchRequest, bodyLen);
	// The LDAPMessage header is written into the 50 bytes of headroom in front of
	// the body. Refuse (as the bind paths do) rather than write before the buffer.
	if (headerLen >= 50) {
		printf("[AD] search request header too long for %s\n", server->addr);
		return 0;
	}
	char buffer[bodyLen + 50];
	char *bufoff = buffer + 50;
	fmt_ldapsearchrequest(bufoff, req);
	fmt_ldapmessage(bufoff - headerLen, msgid, SearchRequest, bodyLen);
	epoll_server_t * const s = &server->con;
	// Report failure instead of handing out an ID for a request that never left
	if (server_send(s, bufoff - headerLen, headerLen + bodyLen, FALSE) == -1) return 0;
	return msgid;
}
/* Open a dedicated connection to @server and send a bind request with the
 * user's @binddn and @password (the literal 3 passed to the fmt_ helpers is
 * presumably the LDAP protocol version — confirm).
 * The connection is marked dynamic, so server_free() releases it completely.
 * NOTE(review): the connection is plain TCP; no TLS is visible here.
 * Returns the message ID of the bind request, or 0 on failure. */
uint32_t server_tryUserBind(server_t *server, struct string *binddn, struct string *password)
{
	epoll_server_t *con = calloc(1, sizeof(*con));
	if (con == NULL) {
		printf("[ADB] Could not allocate connection for %s\n", server->addr);
		return 0;
	}
	con->serverData = server;
	con->fd = -1;
	con->bound = FALSE;
	con->dynamic = TRUE;
	printf("Connecting to AD '%s' for %.*ss bind...\n", server->addr, (int)binddn->l, binddn->s);
	con->sbPos = con->sbFill = 0;
	int sock;
	if (server->lastLookup + 300 < time(NULL)) {
		// Cached address is older than 300s: resolve the name again and connect
		sock = helper_connect4(server->addr, AD_PORT, server->ip);
		if (sock == -1) {
			printf("[ADB] Could not resolve/connect to AD server %s\n", server->addr);
			server_free(con);
			return 0;
		}
	} else {
		sock = socket_tcp4b();
		if (sock == -1) {
			printf("[ADB] Could not allocate socket for connection to AD\n");
			server_free(con);
			return 0;
		}
		if (socket_connect4(sock, server->ip, AD_PORT) == -1) {
			printf("[ADB] Could not connect to cached IP of %s\n", server->addr);
			close(sock);
			server_free(con);
			return 0;
		}
	}
	printf("[ADB] Connected, binding....\n");
	helper_nonblock(sock);
	con->fd = sock;
	con->callback = &server_callback;
	if (ePoll_add(EPOLLIN | EPOLLOUT | EPOLLET, (epoll_item_t*)con) == -1) {
		printf("[ADB] epoll_add failed for ad server %s\n", server->addr);
		server_free(con);
		return 0;
	}
	// Now bind - TODO: SASL (DIGEST-MD5?)
	const uint32_t id = msgId();
	const size_t bodyLen = fmt_ldapbindrequeststring(NULL, 3, binddn, password);
	const size_t headerLen = fmt_ldapmessage(NULL, id, BindRequest, bodyLen);
	// The header goes into the 50 bytes of headroom in front of the body
	if (headerLen >= 50) {
		printf("[ADB] bind too long for %s\n", server->addr);
		server_free(con);
		return 0;
	}
	char buffer[bodyLen + 50];
	char *bufoff = buffer + 50;
	fmt_ldapbindrequeststring(bufoff, 3, binddn, password);
	fmt_ldapmessage(bufoff - headerLen, id, BindRequest, bodyLen);
	// On failure, release the connection instead of leaving a dead one registered
	// and returning an ID no reply will ever carry
	if (server_send(con, bufoff - headerLen, bodyLen + headerLen, FALSE) == -1) {
		printf("[ADB] Could not send bind request to %s\n", server->addr);
		server_free(con);
		return 0;
	}
	return id;
}
//
// Private stuff
/* Lazily allocate the zero-initialized server table (MAX_SERVERS entries).
 * Repeated calls are no-ops. */
static void server_init(void)
{
	if (servers != NULL) return;
	servers = calloc(MAX_SERVERS, sizeof(*servers));
	// Without the table nothing can be configured; server_create would
	// dereference NULL. bail() is assumed not to return (TODO confirm).
	if (servers == NULL) bail("server_init: out of memory");
}
static server_t *server_create(const char *server)
{
int i;
server_init();
for (i = 0; i < serverCount; ++i) {
if (strcmp(servers[i].addr, server) == 0) return &servers[i];
}
if (serverCount >= MAX_SERVERS) {
printf("Cannot add server %s: Too many servers.\n", server);
return NULL;
}
snprintf(servers[serverCount].addr, ADDRLEN, "%s", server);
servers[serverCount].con.fd = -1;
servers[serverCount].con.serverData = &servers[serverCount];
return &servers[serverCount++];
}
/* Tear down an AD connection: close the socket and drop all buffered I/O state.
 * Dynamic connections (heap-allocated in server_tryUserBind) are freed together
 * with their send buffer. The shared connection embedded in server_t is only
 * reset, so server_ensureConnected can reconnect it later. */
static void server_free(epoll_server_t *server)
{
	server->bound = FALSE;
	if (server->fd != -1) close(server->fd);
	server->fd = -1;
	server->sbPos = server->sbFill = 0;
	// Also discard any partially received message. Otherwise a reused shared
	// connection would glue stale bytes in front of the next connection's data.
	// After a read-buffer overflow (rbPos >= MAXMSGLEN), every reconnect would
	// immediately hit the overflow check in server_callback again.
	server->rbPos = 0;
	if (server->dynamic) {
		printf("Freeing Bind-AD-Connection\n");
		free(server->sendBuffer);
		free(server);
	}
}
/* epoll callback for AD connections, both shared and dynamic.
 * doCleanup: release the connection via server_free().
 * haveIn:    read what is available and pass every complete LDAP message
 *            (an ASN.1 SEQUENCE) to proxy_fromServer().
 * haveOut:   flush queued outgoing data. */
static void server_callback(void *data, int haveIn, int haveOut, int doCleanup)
{
	epoll_server_t *server = (epoll_server_t *)data;
	if (doCleanup) {
		server_free(server);
		return;
	}
	if (haveIn) {
		for (;;) {
			if (server->rbPos >= MAXMSGLEN) {
				printf("[AD->Proxy] Read buffer overflow. Disconnecting.\n");
				server_free(server);
				return;
			}
			const size_t buflen = MAXMSGLEN - server->rbPos;
			const ssize_t ret = read(server->fd, server->readBuffer + server->rbPos, buflen);
			// Save errno right away: the printf() below is allowed to change it
			const int err = errno;
			printf("AD read %d (err %d)\n", (int)ret, err);
			if (ret < 0 && err == EINTR) continue;
			if (ret < 0 && (err == EAGAIN || err == EWOULDBLOCK)) break;
			if (ret <= 0) {
				printf("AD gone while reading.\n");
				server_free(server);
				return;
			}
			server->rbPos += ret;
			// Dispatch every complete message currently in the buffer
			for (;;) {
				size_t consumed, len;
				consumed = scan_asn1SEQUENCE(server->readBuffer, server->readBuffer + server->rbPos, &len);
				if (consumed == 0) break; // Length-Header not complete
				len += consumed;
				if (len > server->rbPos) break; // Body not complete
				printf("[AD] Received complete reply...\n");
				if (proxy_fromServer(server, len) == -1) {
					if (server->dynamic) {
						server_free(server);
						return;
					}
					printf("Error parsing reply from AD.\n");
					// Let's try to go on with the next message....
				}
				// Shift remaining buffer contents
				if (len == server->rbPos) {
					server->rbPos = 0;
					break;
				}
				memmove(server->readBuffer, server->readBuffer + len, server->rbPos - len);
				server->rbPos -= len;
			}
			// Short read: the socket was drained. The fd is registered edge-triggered
			// (EPOLLET), so a new event fires when more data arrives.
			if ((ssize_t)buflen > ret) break;
		}
	}
	if (haveOut) server_flush(server);
}
/* Queue or send @len bytes from @buffer on @server.
 * If nothing is pending and @cork is FALSE, the data is written directly and
 * any unsent remainder is buffered. With @cork set, data is only appended to
 * the send buffer and not flushed.
 * Returns -1 if the direct write shows the connection is gone, 0 otherwise. */
int server_send(epoll_server_t *server, const char *buffer, size_t len, const BOOL cork)
{
	if (!cork && server->sbFill == 0) {
		// Fast path: send buffer empty, try to write straight to the socket
		const int written = write(server->fd, buffer, len);
		const BOOL fatal = written == 0 || (written < 0 && errno != EINTR && errno != EAGAIN);
		if (fatal) {
			printf("Server gone when trying to send.\n");
			return -1;
		}
		server->lastActive = time(NULL);
		if (written == (int)len) {
			return 0;
		}
		if (written > 0) {
			// Only part went out; queue the rest below
			printf("[AD] Partial send (%d of %d)\n", written, (int)len);
			buffer += written;
			len -= (size_t)written;
		}
	}
	// Slow path: append behind whatever is already queued
	server_ensureSendBuffer(server, len);
	memcpy(server->sendBuffer + server->sbFill, buffer, len);
	server->sbFill += len;
	if (!cork) {
		server_flush(server);
	}
	return 0;
}
/* Write as much queued send data as the socket accepts.
 * Stops on EAGAIN or a partial write; server_callback calls this again when
 * the fd becomes writable. Once everything is sent, the buffer is rewound
 * to its start. Fatal write errors are only logged. */
static void server_flush(epoll_server_t * const server)
{
	while (server->sbPos < server->sbFill) {
		const int pending = server->sbFill - server->sbPos;
		const int sent = write(server->fd, server->sendBuffer + server->sbPos, pending);
		if (sent < 0) {
			if (errno == EINTR) {
				continue;
			}
			if (errno == EAGAIN) {
				return;
			}
		}
		if (sent <= 0) {
			printf("Connection to AD Server failed while flushing (ret: %d, errno: %d)\n", sent, errno);
			return;
		}
		server->lastActive = time(NULL);
		server->sbPos += sent;
		if (sent != pending) {
			return;
		}
	}
	server->sbPos = server->sbFill = 0;
}
/* Make sure the shared connection of @server is usable. It is (re)opened if it
 * is closed or has been inactive for 120s or more. A fresh connection is added
 * to epoll and a bind with the configured service credentials (message ID
 * MSGID_BIND) is sent; the reply is processed later via server_callback.
 * Returns TRUE if a connection is (believed to be) up, FALSE otherwise. */
static BOOL server_ensureConnected(server_t *server)
{
	epoll_server_t * const con = &server->con;
	if (con->fd != -1 && con->lastActive + 120 > time(NULL)) return TRUE;
	if (con->fd != -1) {
		close(con->fd);
		// Forget the closed descriptor immediately. If reconnecting fails below,
		// the next call must not close() this number again: by then it may
		// belong to an unrelated socket.
		con->fd = -1;
	}
	con->bound = FALSE;
	printf("Connecting to AD '%s'...\n", server->addr);
	con->sbPos = con->sbFill = 0;
	// Bytes of a half-received reply from the old connection are meaningless now
	con->rbPos = 0;
	int sock;
	if (server->lastLookup + 300 < time(NULL)) {
		// Cached address is older than 300s: resolve the name again and connect
		sock = helper_connect4(server->addr, AD_PORT, server->ip);
		if (sock == -1) {
			printf("Could not resolve/connect to AD server %s\n", server->addr);
			return FALSE;
		}
	} else {
		sock = socket_tcp4b();
		if (sock == -1) {
			printf("Could not allocate socket for connection to AD\n");
			return FALSE;
		}
		if (socket_connect4(sock, server->ip, AD_PORT) == -1) {
			printf("Could not connect to cached IP of %s\n", server->addr);
			close(sock);
			return FALSE;
		}
	}
	printf("Connected, binding....\n");
	helper_nonblock(sock);
	con->fd = sock;
	con->callback = &server_callback;
	if (ePoll_add(EPOLLIN | EPOLLOUT | EPOLLET, (epoll_item_t*)con) == -1) {
		printf("epoll_add failed for ad server %s\n", server->addr);
		close(con->fd);
		con->fd = -1;
		return FALSE;
	}
	// Now bind - TODO: SASL (DIGEST-MD5?)
	const size_t bodyLen = fmt_ldapbindrequest(NULL, 3, server->bind, server->password);
	const size_t headerLen = fmt_ldapmessage(NULL, MSGID_BIND, BindRequest, bodyLen);
	// The header goes into the 50 bytes of headroom in front of the body
	if (headerLen >= 50) {
		printf("[AD] bind too long for %s\n", server->addr);
		close(con->fd);
		con->fd = -1;
		return FALSE;
	}
	char buffer[bodyLen + 50];
	char *bufoff = buffer + 50;
	fmt_ldapbindrequest(bufoff, 3, server->bind, server->password);
	fmt_ldapmessage(bufoff - headerLen, MSGID_BIND, BindRequest, bodyLen);
	// Don't report success for a connection whose bind could not even be sent
	if (server_send(con, bufoff - headerLen, bodyLen + headerLen, FALSE) == -1) {
		printf("[AD] Could not send bind request to %s\n", server->addr);
		close(con->fd);
		con->fd = -1;
		return FALSE;
	}
	return TRUE;
}
/* Guarantee room for @len more bytes at the end of the send buffer of @s.
 * Already-sent data before sbPos is reclaimed first by moving pending bytes to
 * the front. The buffer is only grown (with 1000 bytes of slack) if that is
 * not enough. Requests over 1000000 bytes go to bail(). */
static void server_ensureSendBuffer(epoll_server_t * const s, const size_t len)
{
	if (len > 1000000) bail("server_ensureSendBuffer: request too large!");
	if (s->sbLen - s->sbFill >= len) {
		return;
	}
	if (s->sbPos != 0) {
		// Compact: slide the unsent part to offset 0
		memmove(s->sendBuffer, s->sendBuffer + s->sbPos, s->sbFill - s->sbPos);
		s->sbFill -= s->sbPos;
		s->sbPos = 0;
	}
	if (s->sbLen - s->sbFill < len) {
		helper_realloc(&s->sendBuffer, &s->sbLen, s->sbLen + len + 1000, "server_ensureSendBuffer");
	}
}