#include "server.h"
#include "proxy.h"
#include "helper.h"
#include "epoll.h"
#include "tmpbuffer.h"
#include <time.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include <socket.h>
// Buffer sizes (including the terminating NUL) of the per-server string
// fields in server_t; longer configuration values are truncated with a warning.
#define ADDRLEN 40
#define BINDLEN 200
#define PWLEN 40
#define BASELEN 100
#define ALIASLEN 40
// TCP port used for every AD connection.
// NOTE(review): 3268 is presumably the AD Global Catalog port — confirm intended.
#define AD_PORT 3268
// LDAP message ID used for the bind request sent on (re)connect.
// msgId() never hands out values below 2, so search requests cannot collide with it.
#define MSGID_BIND 1
// Configuration and connection state of one configured AD server.
typedef struct {
	size_t aliasLen;      // length of the normalized alias string
	size_t baseLen;       // length of the normalized search base string
	char ip[4];           // IPv4 address used for cached reconnects; presumably written by helper_connect4()
	time_t lastLookup;    // compared with time(NULL) in server_ensureConnected() to decide whether to re-resolve addr
	char addr[ADDRLEN];   // server name as configured; also the lookup key in server_create()
	char bind[BINDLEN];   // bind DN
	char password[PWLEN]; // bind password
	char base[BASELEN];   // normalized search base on the AD side
	char alias[ALIASLEN]; // normalized DN suffix the proxy exposes for this base
	epoll_server_t con;   // connection state registered with epoll
} server_t;
// Capacity of the server table allocated by server_init().
#define MAX_SERVERS 10

// Table of configured AD servers (allocated lazily) and the number of used slots.
static server_t *servers = NULL;
static int serverCount = 0;

// Internal helpers, defined further down in this file.
static void server_init(void);
static server_t *server_create(const char *server);
static BOOL server_ensureConnected(const int index);
static void server_ensureSendBuffer(epoll_server_t * const s, const size_t len);
static void server_flush(epoll_server_t * const server);
static void server_callback(void *data, int haveIn, int haveOut, int doCleanup);
// Generate a message ID for a request to AD.
// IDs increase monotonically starting at 1337. 0 is never returned
// (server_searchRequest() uses 0 to signal failure) and 1 is MSGID_BIND,
// so the counter restarts at 2. RFC 4511 restricts MessageID to
// 0..2147483647, so we also wrap before leaving that range.
static inline uint32_t msgId()
{
	static uint32_t id = 1336;
	if (++id < 2 || id > 0x7fffffffu) id = 2;
	return id;
}
// Setting up server(s)
// Store the bind DN for the given server, creating the server entry if needed.
// Overlong values are truncated and a warning is printed.
void server_setBind(const char *server, const char *bind)
{
	server_t * const target = server_create(server);
	if (target == NULL) {
		return;
	}
	const int needed = snprintf(target->bind, BINDLEN, "%s", bind);
	if (needed >= BINDLEN) {
		printf("Warning: BindDN for %s is too long.\n", server);
	}
}
// Store the bind password for the given server, creating the server entry if needed.
// Overlong values are truncated and a warning is printed.
void server_setPassword(const char *server, const char *password)
{
	server_t * const target = server_create(server);
	if (target == NULL) {
		return;
	}
	const int needed = snprintf(target->password, PWLEN, "%s", password);
	if (needed >= PWLEN) {
		printf("Warning: BindPW for %s is too long.\n", server);
	}
}
// Store the AD search base for the given server (creating the entry if
// needed), normalize it in place and remember its normalized length.
void server_setBase(const char *server, const char *base)
{
	server_t * const target = server_create(server);
	if (target == NULL) {
		return;
	}
	if (snprintf(target->base, BASELEN, "%s", base) >= BASELEN) {
		printf("Warning: SearchBase for %s is too long.\n", server);
	}
	// In-place normalization; presumably normalize_dn never grows the string — TODO confirm
	target->baseLen = normalize_dn(target->base, target->base, min(strlen(target->base), BASELEN - 1));
	target->base[target->baseLen] = '\0';
}
// Store the proxy-side alias for the given server's search base (creating
// the entry if needed), normalize it in place and remember its length.
void server_setAlias(const char *server, const char *alias)
{
	server_t * const target = server_create(server);
	if (target == NULL) {
		return;
	}
	if (snprintf(target->alias, ALIASLEN, "%s", alias) >= ALIASLEN) {
		printf("Warning: SearchBase Alias for %s is too long.\n", server);
	}
	const size_t rawLen = strlen(target->alias);
	target->aliasLen = normalize_dn(target->alias, target->alias, min(rawLen, ALIASLEN - 1));
	target->alias[target->aliasLen] = '\0';
}
void server_initServers()
{
int i;
printf("%d servers configured.\n", serverCount);
for (i = 0; i < serverCount; ++i) {
printf("%s:\n Bind: %s\n Base: %s\n Proxy Alias: %s\n", servers[i].addr, servers[i].bind, servers[i].base, servers[i].alias);
server_ensureConnected(i);
}
}
// What the proxy calls
int server_aliasToBase(struct string *in, struct string *out)
{
int i;
char buffer[TMPLEN];
const size_t searchLen = normalize_dn(buffer, in->s, min(in->l, TMPLEN - 1));
buffer[searchLen] = '\0';
// Now buffer contains the normalized wanted alias. Try to find a match in the server list
for (i = 0; i < serverCount; ++i) {
if (searchLen < servers[i].aliasLen) continue;
if (strcmp(servers[i].alias, buffer + (searchLen - servers[i].aliasLen)) == 0) {
// Found, handle
tmpbuffer_format(out, "%.*s%s", (int)(searchLen - servers[i].aliasLen), buffer, servers[i].base);
return i;
}
}
return -1;
}
// Map a DN under one of the AD search bases to the corresponding proxy DN.
// The input is normalized (truncated to TMPLEN - 1 characters), then each
// server's search base is tried as a suffix. On a match the suffix is replaced
// by that server's alias and the result is written to out.
// Returns the index of the matching server, or -1 if no base matches.
int server_baseToAlias(struct string *in, struct string *out)
{
	int i;
	char buffer[TMPLEN];
	const size_t searchLen = normalize_dn(buffer, in->s, min(in->l, TMPLEN - 1));
	buffer[searchLen] = '\0';
	// Now buffer contains the normalized wanted base. Try to find a match in the server list
	for (i = 0; i < serverCount; ++i) {
		// This check must come before any use of searchLen - baseLen: the
		// subtraction is unsigned and would wrap, pointing far outside buffer.
		if (searchLen < servers[i].baseLen) continue;
		const char * const suffix = buffer + (searchLen - servers[i].baseLen);
		printf("Comparing %s (%s) to %s\n", buffer, suffix, servers[i].base);
		if (strcmp(servers[i].base, suffix) == 0) {
			// Found, handle
			tmpbuffer_format(out, "%.*s%s", (int)(searchLen - servers[i].baseLen), buffer, servers[i].alias);
			printf("Match, returning %s\n", out->s);
			return i;
		}
	}
	return -1;
}
uint32_t server_searchRequest(int server, struct SearchRequest *req)
{
if (!server_ensureConnected(server)) return 0;
const uint32_t msgid = msgId();
const size_t bodyLen = fmt_ldapsearchrequest(NULL, req);
const size_t headerLen = fmt_ldapmessage(NULL, msgid, SearchRequest, bodyLen);
char buffer[bodyLen + 50];
char *bufoff = buffer + 50;
fmt_ldapsearchrequest(bufoff, req);
fmt_ldapmessage(bufoff - headerLen, msgid, SearchRequest, bodyLen);
epoll_server_t * const s = &servers[server].con;
server_send(s, bufoff - headerLen, headerLen + bodyLen, FALSE);
return msgid;
}
//
// Private stuff
// Allocate the zero-filled server table on first use; later calls do nothing.
// On allocation failure servers stays NULL.
static void server_init()
{
	if (servers == NULL) {
		servers = calloc(MAX_SERVERS, sizeof(*servers));
	}
}
static server_t *server_create(const char *server)
{
int i;
server_init();
for (i = 0; i < serverCount; ++i) {
if (strcmp(servers[i].addr, server) == 0) return &servers[i];
}
if (serverCount >= MAX_SERVERS) {
printf("Cannot add server %s: Too many servers.\n", server);
return NULL;
}
snprintf(servers[serverCount].addr, ADDRLEN, "%s", server);
servers[serverCount].con.fd = -1;
return &servers[serverCount++];
}
// Tear down an AD connection: close the socket (if open), mark the
// connection as unbound and discard all buffered outgoing AND incoming data,
// so that a later reconnect starts from a clean state.
static void server_free(epoll_server_t *server)
{
	server->bound = FALSE;
	if (server->fd != -1) close(server->fd);
	server->fd = -1;
	server->sbPos = server->sbFill = 0;
	// A partially received reply from the old connection must not be glued
	// onto the first bytes read from the next one.
	server->rbPos = 0;
}
// epoll callback for an AD server connection (registered in
// server_ensureConnected(); data is the connection's epoll_server_t).
//  - doCleanup: tear the connection down via server_free() and return.
//  - haveIn: read into readBuffer until the socket would block. Every
//    complete ASN.1 SEQUENCE (one LDAP message) is passed to
//    proxy_fromServer(). A buffer overflow, EOF, read error or parse
//    failure tears the connection down.
//  - haveOut: try to flush the pending send buffer.
static void server_callback(void *data, int haveIn, int haveOut, int doCleanup)
{
	epoll_server_t *server = (epoll_server_t *)data;
	if (doCleanup) {
		server_free(server);
		return;
	}
	if (haveIn) {
		for (;;) {
			// Buffer full but no complete message could be extracted from it
			if (server->rbPos >= MAXMSGLEN) {
				printf("[AD->Proxy] Read buffer overflow. Disconnecting.\n");
				server_free(server);
				return;
			}
			const size_t buflen = MAXMSGLEN - server->rbPos;
			const ssize_t ret = read(server->fd, server->readBuffer + server->rbPos, buflen);
			// NOTE(review): errno is only meaningful here when ret < 0
			printf("AD read %d (err %d)\n", (int)ret, errno);
			if (ret < 0 && errno == EINTR) continue;
			// Socket drained; wait for the next epoll event
			if (ret < 0 && errno == EAGAIN) break;
			// EOF (ret == 0) or a hard read error
			if (ret <= 0) {
				printf("AD gone while reading.\n");
				server_free(server);
				return;
			}
			server->rbPos += ret;
			// Request complete?
			// Extract as many complete messages as the buffer now holds.
			for (;;) {
				size_t consumed, len;
				// consumed is the size of the SEQUENCE header (0 while the header
				// is incomplete); len receives the content length, so
				// consumed + len is the size of the whole message.
				consumed = scan_asn1SEQUENCE(server->readBuffer, server->readBuffer + server->rbPos, &len);
				if (consumed == 0) break; // Length-Header not complete
				len += consumed;
				if (len > server->rbPos) break; // Body not complete
				printf("[AD] Received complete reply...\n");
				if (proxy_fromServer(server, len) == -1) {
					printf("Error parsing reply from AD.\n");
					server_free(server);
					return;
				}
				// Shift remaining buffer contents
				if (len == server->rbPos) {
					server->rbPos = 0;
					break;
				}
				memmove(server->readBuffer, server->readBuffer + len, server->rbPos - len);
				server->rbPos -= len;
			}
			if ((ssize_t)buflen > ret) break; // Read less than buffer len, epoll will fire again
		}
	}
	if (haveOut) server_flush(server);
}
// Send len bytes from buffer over the AD server connection.
// If cork is FALSE and nothing is queued yet, a direct write() is tried
// first. Whatever could not be written right away (or everything, when cork
// is TRUE or data is already queued) is appended to the connection's send
// buffer, which is then flushed unless corked.
// Returns 0 if the data was sent or queued, and -1 if the direct write found
// the connection dead.
int server_send(epoll_server_t *server, const char *buffer, size_t len, const BOOL cork)
{
	if (len == 0) {
		// write() of zero bytes returns 0, which the direct path below would
		// mistake for a closed connection. Nothing to append; just push out
		// whatever is already queued.
		if (!cork) server_flush(server);
		return 0;
	}
	if (server->sbFill == 0 && !cork) {
		// Nothing in send buffer, fire away
		// write() returns ssize_t; keep it that wide so large lengths are not truncated
		const ssize_t ret = write(server->fd, buffer, len);
		if (ret == 0 || (ret < 0 && errno != EINTR && errno != EAGAIN)) {
			printf("Server gone when trying to send.\n");
			return -1;
		}
		server->lastActive = time(NULL);
		if (ret > 0 && (size_t)ret == len) return 0;
		// Couldn't send everything (partial write, EINTR or EAGAIN), continue with buffering logic below
		if (ret > 0) {
			printf("[AD] Partial send (%d of %d)\n", (int)ret, (int)len);
			buffer += ret;
			len -= (size_t)ret;
		}
	}
	// Buffer...
	server_ensureSendBuffer(server, len);
	// Finally append to buffer
	memcpy(server->sendBuffer + server->sbFill, buffer, len);
	server->sbFill += len;
	if (!cork) server_flush(server);
	return 0;
}
// Write as much of the queued send buffer as the socket accepts.
// Stops on EAGAIN, on a short write, or on a hard error (which is only
// logged). Once everything has been written, the buffer is rewound so the
// space can be reused.
static void server_flush(epoll_server_t * const server)
{
	for (;;) {
		if (server->sbPos >= server->sbFill) {
			// Fully drained: rewind for reuse
			server->sbPos = server->sbFill = 0;
			return;
		}
		const int chunk = server->sbFill - server->sbPos;
		const int written = write(server->fd, server->sendBuffer + server->sbPos, chunk);
		if (written < 0) {
			if (errno == EINTR) continue;
			if (errno == EAGAIN) return;
		}
		if (written <= 0) {
			printf("Connection to AD Server failed while flushing (ret: %d, errno: %d)\n", written, errno);
			return;
		}
		server->lastActive = time(NULL);
		server->sbPos += written;
		if (written != chunk) return; // Short write: socket buffer is full for now
	}
}
static BOOL server_ensureConnected(const int index)
{
server_t * const server = &servers[index];
epoll_server_t * const con = &server->con;
if (con->fd != -1 && con->lastActive + 120 > time(NULL)) return TRUE;
if (con->fd != -1) close(con->fd);
con->bound = FALSE;
printf("Connecting to AD %s...\n", server->addr);
con->sbPos = con->sbFill = 0;
int sock;
if (server->lastLookup + 300 < time(NULL)) {
sock = helper_connect4(server->addr, AD_PORT, server->ip);
if (sock == -1) {
printf("Could not resolve/connect to AD server %s\n", server->addr);
return FALSE;
}
} else {
sock = socket_tcp4b();
if (sock == -1) {
printf("Could not allocate socket for connection to AD\n");
return FALSE;
}
if (socket_connect4(sock, server->ip, AD_PORT) == -1) {
printf("Could not connect to cached IP of %s\n", server->addr);
close(sock);
return FALSE;
}
}
printf("Connected, binding....\n");
helper_nonblock(sock);
con->fd = sock;
con->callback = &server_callback;
if (ePoll_add(EPOLLIN | EPOLLOUT | EPOLLET, (epoll_item_t*)con) == -1) {
printf("epoll_add failed for ad server %s\n", server->addr);
close(con->fd);
con->fd = -1;
return FALSE;
}
// Now bind
const size_t bodyLen = fmt_ldapbindrequest(NULL, 3, server->bind, server->password);
const size_t headerLen = fmt_ldapmessage(NULL, MSGID_BIND, BindResponse, bodyLen);
char buffer[bodyLen + 50];
char *bufoff = buffer + 50;
if (headerLen >= 50) {
printf("[AD] bind too long for %s\n", server->addr);
close(con->fd);
con->fd = -1;
return FALSE;
}
fmt_ldapbindrequest(bufoff, 3, server->bind, server->password);
fmt_ldapmessage(bufoff - headerLen, MSGID_BIND, BindRequest, bodyLen);
server_send(con, bufoff - headerLen, bodyLen + headerLen, FALSE);
return TRUE;
}
// Make sure at least len more bytes can be appended at sendBuffer + sbFill.
// First reclaims already-sent space at the front by moving pending data to
// offset 0. If that is not enough, grows the buffer via helper_realloc()
// with 1000 bytes of slack. Requests above 1000000 bytes go to bail()
// (presumably fatal — confirm).
static void server_ensureSendBuffer(epoll_server_t * const s, const size_t len)
{
	if (len > 1000000) bail("server_ensureSendBuffer: request too large!");
	if (s->sbLen - s->sbFill >= len) return; // Enough room at the tail already
	if (s->sbPos > 0) {
		const size_t pending = s->sbFill - s->sbPos;
		memmove(s->sendBuffer, s->sendBuffer + s->sbPos, pending);
		s->sbFill = pending;
		s->sbPos = 0;
	}
	if (s->sbLen - s->sbFill >= len) return; // Compaction freed enough space
	helper_realloc(&s->sendBuffer, &s->sbLen, s->sbLen + len + 1000, "server_ensureSendBuffer");
}