summaryrefslogblamecommitdiffstats
path: root/openssl.c
blob: bc49eb88be5bc06469edb0cd9c832cc1f8f29b91 (plain) (tree)
1
2
3
4
5
6
7
8
9
10

                    


                           
 



                                              
                             
                                 
 

                                                                       















                                                        

                                                                           








                                                                       
                                                                    


                                                                                                
                                                                                                    


                   
                                               





                                                                       
                                                                                                      
                                                                                    






                                                                                                       
         



                                        










                                         

















                                                          





                                                        
                                                                                                                            




                                            
                       

                                                          


                                              
                                                              
                 









                                                                                  


                                                           
                                                                                           






                                                                                       
                                                                                                                       



                                                              

                                                                                                                          




                                       



                                                              
                                                        
                                        


                                                                                              
                        


                    























                                                                      

                                                 
 

                                                                                 









                                                                                               
                                                                                                           









                                                                                                     
         









                                                                                                               
#include "openssl.h"
#include "helper.h"
#include <string.h>
#include <openssl/conf.h>
#include <openssl/x509v3.h>

#if OPENSSL_VERSION_NUMBER < 0x10100000L
#define ASN1_STRING_get0_data ASN1_STRING_data
#endif

// Guard flag read by ssl_init(); library setup is skipped while it is TRUE.
static BOOL initDone = FALSE;
// SHA-1 digest looked up by ssl_init() and used by ssl_checkCertificateHash()
// to compute certificate fingerprints. NULL until ssl_init() resolves it.
static const EVP_MD *sha1 = NULL;

// Host name check for peer certificates; defined at the end of this file.
static BOOL spc_verify_cert_hostname(X509 *cert, const char *hostname);

/*
 * Drain the calling thread's OpenSSL error queue, printing one line per
 * queued error to stdout. If bailMsg is non-NULL, bail(bailMsg) is called
 * once the queue is empty.
 *
 * Passing bailMsg == NULL is also useful just to empty the queue:
 * SSL_get_error() consults it, so leftovers from one failed operation can
 * make a later, unrelated operation look like a fatal SSL error.
 */
void ssl_printErrors(char *bailMsg)
{
	unsigned long err;
	char msg[256];
	while ((err = ERR_get_error()) != 0) {
		// ERR_error_string(err, NULL) formats into a shared static buffer and
		// is not reentrant; the _n variant writes into our own buffer and
		// always NUL-terminates within the given size.
		ERR_error_string_n(err, msg, sizeof(msg));
		printf("OpenSSL: %s\n", msg);
	}
	if (bailMsg != NULL) bail(bailMsg);
}

BOOL ssl_init()
{
	if (initDone) return TRUE;
	SSL_load_error_strings();
	SSL_library_init();
	OpenSSL_add_all_algorithms();
	sha1 = EVP_get_digestbyname("sha1");
	if (sha1 == NULL) ssl_printErrors("Could not load SHA-1 digest\n");
	return TRUE;
}

/*
 * Create a TLS server context presenting the PEM certificate in certfile,
 * paired with the PEM private key in keyfile.
 *
 * SSLv2 and SSLv3 are disabled, and partial writes are enabled. Every failing
 * step is reported through ssl_printErrors() with a bail message naming that
 * step.
 */
SSL_CTX* ssl_newServerCtx(char *certfile, char *keyfile)
{
	const SSL_METHOD *m = SSLv23_server_method();
	if (m == NULL) ssl_printErrors("newServerCtx: method is NULL");
	SSL_CTX *ctx = SSL_CTX_new(m);
	if (ctx == NULL) ssl_printErrors("newServerCtx: ctx is NULL");
	// SSLv2/SSLv3 are broken legacy protocols; never negotiate them.
	SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
	// Check each load on its own so the message says which file was bad,
	// rather than only surfacing indirectly through the key/cert check.
	if (SSL_CTX_use_certificate_file(ctx, certfile, SSL_FILETYPE_PEM) != 1)
		ssl_printErrors("newServerCtx: could not load certificate file");
	if (SSL_CTX_use_PrivateKey_file(ctx, keyfile, SSL_FILETYPE_PEM) != 1)
		ssl_printErrors("newServerCtx: could not load private key file");
	// Ensures the private key actually belongs to the loaded certificate.
	if (!SSL_CTX_check_private_key(ctx)) ssl_printErrors("Could not load cert/private key");
	SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE); // SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER
	return ctx;
}

SSL_CTX* ssl_newClientCtx(const char *cabundle)
{
	const SSL_METHOD *m = SSLv23_client_method();
	if (m == NULL) ssl_printErrors("newClientCtx: method is NULL");
	SSL_CTX *ctx = SSL_CTX_new(m);
	if (ctx == NULL) ssl_printErrors("newClientCtx: ctx is NULL");
	SSL_CTX_set_options(ctx, SSL_OP_NO_SSLv2);
	SSL_CTX_set_mode(ctx, SSL_MODE_ENABLE_PARTIAL_WRITE); // | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER
	if (cabundle != NULL && cabundle[0] != '\0' && strcmp(cabundle, "*") != 0) {
		if (SSL_CTX_load_verify_locations(ctx, cabundle, NULL) == 0) {
			ssl_printErrors("Loading trusted certs failed");
			exit(1);
		}
		SSL_CTX_set_default_verify_paths(ctx);
		printf("Loaded ca-bundle '%s'\n", cabundle);
		//SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL); <- do this manually after SSL_connect
	}
	return ctx;
}

/*
 * Allocate an SSL object from ctx and bind it to the socket clientFd.
 *
 * Returns the new object. If allocation or fd binding fails, any queued
 * OpenSSL errors are printed and NULL is returned; a partially set-up object
 * is freed first, so nothing leaks.
 */
SSL *ssl_new(int clientFd, SSL_CTX *ctx)
{
	SSL *handle = SSL_new(ctx);
	if (handle != NULL && SSL_set_fd(handle, clientFd))
		return handle;

	// Either SSL_new() or SSL_set_fd() failed; report what OpenSSL queued.
	ssl_printErrors(NULL);
	if (handle != NULL)
		SSL_free(handle);
	return NULL;
}

/*
 * Advance the server-side TLS handshake for client by one step.
 *
 * Returns TRUE when the handshake has completed (client->sslAccepted is then
 * set) or when SSL_BLOCKED() classifies the error as "not finished yet"
 * (presumably WANT_READ/WANT_WRITE; see its definition). Returns FALSE when
 * the handshake failed.
 */
BOOL ssl_acceptClient(epoll_client_t *client)
{
	if (client->sslAccepted) return TRUE;
	int ret = SSL_accept(client->ssl);
	if (ret == 1) {
		client->sslAccepted = TRUE;
		return TRUE;
	}
	// SSL_accept() signals failure with 0 or a negative value;
	// SSL_get_error() classifies both.
	int err = SSL_get_error(client->ssl, ret);
	if (SSL_BLOCKED(err)) return TRUE;
	// The handshake failed. Discard whatever OpenSSL queued for it:
	// SSL_get_error() looks at this thread's error queue, so leftovers would
	// make the next I/O on an unrelated connection look like a fatal error.
	// Cleared silently to avoid logging every failed client handshake.
	ERR_clear_error();
	return FALSE;
}

/*
 * Advance the client-side TLS handshake with server by one step.
 *
 * When the handshake completes, the peer certificate must pass
 * ssl_checkCertificateHash() before server->sslConnected is set.
 *
 * Returns TRUE when connected, or when SSL_BLOCKED() classifies the error as
 * "not finished yet". Returns FALSE on handshake failure or when the
 * certificate is rejected.
 */
BOOL ssl_connectServer(epoll_server_t *server)
{
	if (server->sslConnected) return TRUE;
	int ret = SSL_connect(server->ssl);
	if (ret == 1) {
		if (!ssl_checkCertificateHash(server)) {
			printf("Warning: Certificate invalid, refusing to talk to server (%s)\n", server->serverData->addr);
			return FALSE;
		}
		server->sslConnected = TRUE;
		return TRUE;
	}
	// SSL_connect() signals failure with 0 or a negative value.
	int err = SSL_get_error(server->ssl, ret);
	if (SSL_BLOCKED(err)) return TRUE;
	if (err != SSL_ERROR_SSL) {
		printf("SSL connect error %d\n", err);
	}
	// Print and drain the error queue for every failure kind, not only
	// SSL_ERROR_SSL. SSL_ERROR_SYSCALL may queue entries too, and stale
	// entries break SSL_get_error() on later operations.
	ssl_printErrors(NULL);
	return FALSE;
}

/*
 * Decide whether the certificate presented by server's TLS peer is acceptable.
 *
 * - If a cabundle is configured (non-empty string), the certificate must
 *   match server->serverData->addr by host name, and OpenSSL's chain
 *   verification result must be X509_V_OK.
 * - Otherwise, if a fingerprint is configured (any non-zero byte), the SHA-1
 *   digest of the certificate must equal it.
 * - Otherwise the certificate is accepted as-is.
 *
 * Returns FALSE for a non-SSL connection, a peer without a certificate, or a
 * fingerprint that cannot be computed.
 */
BOOL ssl_checkCertificateHash(epoll_server_t *server)
{
	if (server->ssl == NULL) {
		printf("Bug: Asked to check certificate of non-SSL connection\n");
		return FALSE;
	}
	// Get server cert; we own this reference and must X509_free() it.
	X509 *cert = SSL_get_peer_certificate(server->ssl);
	if (cert == NULL) {
		printf("Error: Server %s has no certificate!\n", server->serverData->addr);
		return FALSE;
	}
	// Do we have a cabundle set?
	if (server->serverData->cabundle[0] != '\0') {
		BOOL hostOk = spc_verify_cert_hostname(cert, server->serverData->addr);
		X509_free(cert);
		if (!hostOk) {
			printf("Error: Server certificate's host name doesn't match '%s'\n", server->serverData->addr);
			return FALSE;
		}
		long res = SSL_get_verify_result(server->ssl);
		if (X509_V_OK != res) {
			printf("Error: Server %s's certificate cannot be verified with given cabundle %s (result: %ld)\n",
					server->serverData->addr, server->serverData->cabundle, res);
			return FALSE;
		}
		return TRUE;
	}
	// No cabundle: an all-zero fingerprint means "no fingerprint configured".
	BOOL haveFingerprint = FALSE;
	for (int i = 0; i < FINGERPRINTLEN; ++i) {
		if (server->serverData->fingerprint[i] != 0) {
			haveFingerprint = TRUE;
			break;
		}
	}
	if (!haveFingerprint) {
		X509_free(cert);
		return TRUE;
	}
	unsigned char md[EVP_MAX_MD_SIZE];
	unsigned int n = 0;
	// sha1 stays NULL if ssl_init() never ran, so do not hand it to
	// X509_digest(). If the digest fails, md is never compared.
	const BOOL digestOk = sha1 != NULL && X509_digest(cert, sha1, md, &n) == 1;
	X509_free(cert);
	if (!digestOk) {
		printf("Error: Could not compute fingerprint of %s's certificate\n", server->serverData->addr);
		ssl_printErrors(NULL);
		return FALSE;
	}
	// A SHA-1 digest is 20 bytes; assumes fingerprint holds at least that many
	// bytes (FINGERPRINTLEN) - TODO confirm.
	return n == 20 && memcmp(md, server->serverData->fingerprint, n) == 0;
}

/*
 * Case-insensitively compare a certificate name pattern with a host name.
 *
 * A pattern without a leading '*' must equal string exactly. A pattern of
 * the form "*.rest" matches:
 *   - "rest" itself (more lenient than RFC 6125, kept for compatibility), and
 *   - "<label>.rest", where <label> is one non-empty DNS label, i.e. contains
 *     no '.'. So "*.example.com" matches "a.example.com" but not
 *     "a.b.example.com" or ".example.com" (RFC 6125 section 6.4.3).
 * Any other placement of '*' never matches.
 */
static BOOL wcmatch(const char *pattern, const char *string)
{
	if (pattern[0] != '*')
		return strcasecmp(string, pattern) == 0;
	if (pattern[1] != '.')
		return FALSE;
	if (strcasecmp(string, pattern + 2) == 0)
		return TRUE;
	// The wildcard must cover exactly the first label. That label has to be
	// non-empty, and everything from its terminating '.' on must equal ".rest".
	const char *dot = strchr(string, '.');
	if (dot == NULL || dot == string)
		return FALSE;
	return strcasecmp(dot, pattern + 1) == 0;
}

// Based on
// https://wiki.openssl.org/index.php/Hostname_validation
/*
 * Check whether cert was issued for hostname.
 *
 * Every dNSName entry of the subjectAltName extension is compared with
 * wcmatch(). The subject Common Name is consulted only when the certificate
 * carries no dNSName at all. This follows RFC 6125 section 6.4.4 and matches
 * the default behaviour of OpenSSL's X509_check_host().
 *
 * A dNSName or CN containing an embedded NUL byte is rejected. Otherwise
 * "victim.example\0.attacker.example" would compare as "victim.example".
 */
static BOOL spc_verify_cert_hostname(X509 *cert, const char *hostname)
{
	BOOL sawDnsName = FALSE;

	// Try to extract the names within the SAN extension from the certificate
	STACK_OF(GENERAL_NAME) *san_names = X509_get_ext_d2i(cert, NID_subject_alt_name, NULL, NULL);
	if (san_names != NULL) {
		BOOL ok = FALSE;
		BOOL malformed = FALSE;
		const int san_names_nb = sk_GENERAL_NAME_num(san_names);

		for (int i = 0; i < san_names_nb && !ok && !malformed; i++) {
			const GENERAL_NAME *current_name = sk_GENERAL_NAME_value(san_names, i);
			if (current_name->type != GEN_DNS)
				continue;
			sawDnsName = TRUE;

			const char *dns_name = (const char*)ASN1_STRING_get0_data(current_name->d.dNSName);
			if (dns_name == NULL
					|| (size_t)ASN1_STRING_length(current_name->d.dNSName) != strlen(dns_name)) {
				malformed = TRUE;
			} else if (wcmatch(dns_name, hostname)) {
				ok = TRUE;
			}
		}
		sk_GENERAL_NAME_pop_free(san_names, GENERAL_NAME_free);
		if (ok)
			return TRUE;
		if (malformed)
			return FALSE;
	}
	// The certificate names its hosts via SAN, and none matched.
	if (sawDnsName)
		return FALSE;

	// No DNS names in SAN: fall back to the subject's Common Name.
	X509_NAME *subj = X509_get_subject_name(cert);
	if (subj == NULL)
		return FALSE;
	char name[256];
	// Returns the length written, or -1 when there is no CN. Keep it signed:
	// casting -1 to size_t would pass a "> 0" test and lead to reading the
	// never-written buffer.
	const int len = X509_NAME_get_text_by_NID(subj, NID_commonName, name, sizeof(name));
	if (len <= 0)
		return FALSE;
	return strlen(name) == (size_t)len && wcmatch(name, hostname);
}