diff options
Diffstat (limited to 'src/crypto/rsa.c')
| -rw-r--r-- | src/crypto/rsa.c | 413 |
1 files changed, 219 insertions, 194 deletions
diff --git a/src/crypto/rsa.c b/src/crypto/rsa.c index 16c67d822..be055d881 100644 --- a/src/crypto/rsa.c +++ b/src/crypto/rsa.c @@ -22,6 +22,7 @@ */ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL ); +FILE_SECBOOT ( PERMITTED ); #include <stdint.h> #include <stdlib.h> @@ -47,6 +48,28 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL ); #define EINFO_EACCES_VERIFY \ __einfo_uniqify ( EINFO_EACCES, 0x01, "RSA signature incorrect" ) +/** An RSA context */ +struct rsa_context { + /** Allocated memory */ + void *dynamic; + /** Modulus */ + bigint_element_t *modulus0; + /** Modulus size */ + unsigned int size; + /** Modulus length */ + size_t max_len; + /** Exponent */ + bigint_element_t *exponent0; + /** Exponent size */ + unsigned int exponent_size; + /** Input buffer */ + bigint_element_t *input0; + /** Output buffer */ + bigint_element_t *output0; + /** Temporary working space for modular exponentiation */ + void *tmp; +}; + /** * Identify RSA prefix * @@ -69,10 +92,9 @@ rsa_find_prefix ( struct digest_algorithm *digest ) { * * @v context RSA context */ -static void rsa_free ( struct rsa_context *context ) { +static inline void rsa_free ( struct rsa_context *context ) { free ( context->dynamic ); - context->dynamic = NULL; } /** @@ -88,8 +110,7 @@ static int rsa_alloc ( struct rsa_context *context, size_t modulus_len, unsigned int size = bigint_required_size ( modulus_len ); unsigned int exponent_size = bigint_required_size ( exponent_len ); bigint_t ( size ) *modulus; - bigint_t ( exponent_size ) *exponent; - size_t tmp_len = bigint_mod_exp_tmp_len ( modulus, exponent ); + size_t tmp_len = bigint_mod_exp_tmp_len ( modulus ); struct { bigint_t ( size ) modulus; bigint_t ( exponent_size ) exponent; @@ -98,9 +119,6 @@ static int rsa_alloc ( struct rsa_context *context, size_t modulus_len, uint8_t tmp[tmp_len]; } __attribute__ (( packed )) *dynamic; - /* Free any existing dynamic storage */ - rsa_free ( context ); - /* Allocate dynamic storage */ dynamic = malloc ( sizeof ( *dynamic ) ); if ( ! dynamic ) @@ -121,34 +139,6 @@ static int rsa_alloc ( struct rsa_context *context, size_t modulus_len, } /** - * Parse RSA integer - * - * @v integer Integer to fill in - * @v raw ASN.1 cursor - * @ret rc Return status code - */ -static int rsa_parse_integer ( struct asn1_cursor *integer, - const struct asn1_cursor *raw ) { - - /* Enter integer */ - memcpy ( integer, raw, sizeof ( *integer ) ); - asn1_enter ( integer, ASN1_INTEGER ); - - /* Skip initial sign byte if applicable */ - if ( ( integer->len > 1 ) && - ( *( ( uint8_t * ) integer->data ) == 0x00 ) ) { - integer->data++; - integer->len--; - } - - /* Fail if cursor or integer are invalid */ - if ( ! integer->len ) - return -EINVAL; - - return 0; -} - -/** * Parse RSA modulus and exponent * * @v modulus Modulus to fill in @@ -159,7 +149,6 @@ static int rsa_parse_integer ( struct asn1_cursor *integer, static int rsa_parse_mod_exp ( struct asn1_cursor *modulus, struct asn1_cursor *exponent, const struct asn1_cursor *raw ) { - struct asn1_bit_string bits; struct asn1_cursor cursor; int is_private; int rc; @@ -178,8 +167,8 @@ static int rsa_parse_mod_exp ( struct asn1_cursor *modulus, asn1_skip_any ( &cursor ); /* Enter privateKey, if present */ - if ( asn1_check_algorithm ( &cursor, - &rsa_encryption_algorithm ) == 0 ) { + if ( asn1_check_algorithm ( &cursor, &rsa_encryption_algorithm, + NULL ) == 0 ) { /* Skip privateKeyAlgorithm */ asn1_skip_any ( &cursor ); @@ -203,17 +192,15 @@ static int rsa_parse_mod_exp ( struct asn1_cursor *modulus, asn1_skip ( &cursor, ASN1_SEQUENCE ); /* Enter subjectPublicKey */ - if ( ( rc = asn1_integral_bit_string ( &cursor, &bits ) ) != 0 ) - return rc; - cursor.data = bits.data; - cursor.len = bits.len; + asn1_enter_bits ( &cursor, NULL ); /* Enter RSAPublicKey */ asn1_enter ( &cursor, ASN1_SEQUENCE ); } /* Extract modulus */ - if ( ( rc = rsa_parse_integer ( modulus, &cursor ) ) != 0 ) + memcpy ( modulus, &cursor, sizeof ( *modulus ) ); + if ( ( rc = asn1_enter_unsigned ( modulus ) ) != 0 ) return rc; asn1_skip_any ( &cursor ); @@ -222,7 +209,8 @@ static int rsa_parse_mod_exp ( struct asn1_cursor *modulus, asn1_skip ( &cursor, ASN1_INTEGER ); /* Extract publicExponent/privateExponent */ - if ( ( rc = rsa_parse_integer ( exponent, &cursor ) ) != 0 ) + memcpy ( exponent, &cursor, sizeof ( *exponent ) ); + if ( ( rc = asn1_enter_unsigned ( exponent ) ) != 0 ) return rc; return 0; @@ -231,29 +219,23 @@ static int rsa_parse_mod_exp ( struct asn1_cursor *modulus, /** * Initialise RSA cipher * - * @v ctx RSA context + * @v context RSA context * @v key Key - * @v key_len Length of key * @ret rc Return status code */ -static int rsa_init ( void *ctx, const void *key, size_t key_len ) { - struct rsa_context *context = ctx; +static int rsa_init ( struct rsa_context *context, + const struct asn1_cursor *key ) { struct asn1_cursor modulus; struct asn1_cursor exponent; - struct asn1_cursor cursor; int rc; /* Initialise context */ memset ( context, 0, sizeof ( *context ) ); - /* Initialise cursor */ - cursor.data = key; - cursor.len = key_len; - /* Parse modulus and exponent */ - if ( ( rc = rsa_parse_mod_exp ( &modulus, &exponent, &cursor ) ) != 0 ){ + if ( ( rc = rsa_parse_mod_exp ( &modulus, &exponent, key ) ) != 0 ){ DBGC ( context, "RSA %p invalid modulus/exponent:\n", context ); - DBGC_HDA ( context, 0, cursor.data, cursor.len ); + DBGC_HDA ( context, 0, key->data, key->len ); goto err_parse; } @@ -281,18 +263,6 @@ static int rsa_init ( void *ctx, const void *key, size_t key_len ) { } /** - * Calculate RSA maximum output length - * - * @v ctx RSA context - * @ret max_len Maximum output length - */ -static size_t rsa_max_len ( void *ctx ) { - struct rsa_context *context = ctx; - - return context->max_len; -} - -/** * Perform RSA cipher operation * * @v context RSA context @@ -320,111 +290,158 @@ static void rsa_cipher ( struct rsa_context *context, /** * Encrypt using RSA * - * @v ctx RSA context + * @v key Key * @v plaintext Plaintext - * @v plaintext_len Length of plaintext * @v ciphertext Ciphertext * @ret ciphertext_len Length of ciphertext, or negative error */ -static int rsa_encrypt ( void *ctx, const void *plaintext, - size_t plaintext_len, void *ciphertext ) { - struct rsa_context *context = ctx; +static int rsa_encrypt ( const struct asn1_cursor *key, + const struct asn1_cursor *plaintext, + struct asn1_builder *ciphertext ) { + struct rsa_context context; void *temp; uint8_t *encoded; - size_t max_len = ( context->max_len - 11 ); - size_t random_nz_len = ( max_len - plaintext_len + 8 ); + size_t max_len; + size_t random_nz_len; int rc; + DBGC ( &context, "RSA %p encrypting:\n", &context ); + DBGC_HDA ( &context, 0, plaintext->data, plaintext->len ); + + /* Initialise context */ + if ( ( rc = rsa_init ( &context, key ) ) != 0 ) + goto err_init; + + /* Calculate lengths */ + max_len = ( context.max_len - 11 ); + random_nz_len = ( max_len - plaintext->len + 8 ); + /* Sanity check */ - if ( plaintext_len > max_len ) { - DBGC ( context, "RSA %p plaintext too long (%zd bytes, max " - "%zd)\n", context, plaintext_len, max_len ); - return -ERANGE; + if ( plaintext->len > max_len ) { + DBGC ( &context, "RSA %p plaintext too long (%zd bytes, max " + "%zd)\n", &context, plaintext->len, max_len ); + rc = -ERANGE; + goto err_sanity; } - DBGC ( context, "RSA %p encrypting:\n", context ); - DBGC_HDA ( context, 0, plaintext, plaintext_len ); /* Construct encoded message (using the big integer output * buffer as temporary storage) */ - temp = context->output0; + temp = context.output0; encoded = temp; encoded[0] = 0x00; encoded[1] = 0x02; if ( ( rc = get_random_nz ( &encoded[2], random_nz_len ) ) != 0 ) { - DBGC ( context, "RSA %p could not generate random data: %s\n", - context, strerror ( rc ) ); - return rc; + DBGC ( &context, "RSA %p could not generate random data: %s\n", + &context, strerror ( rc ) ); + goto err_random; } encoded[ 2 + random_nz_len ] = 0x00; - memcpy ( &encoded[ context->max_len - plaintext_len ], - plaintext, plaintext_len ); + memcpy ( &encoded[ context.max_len - plaintext->len ], + plaintext->data, plaintext->len ); + + /* Create space for ciphertext */ + if ( ( rc = asn1_grow ( ciphertext, context.max_len ) ) != 0 ) + goto err_grow; /* Encipher the encoded message */ - rsa_cipher ( context, encoded, ciphertext ); - DBGC ( context, "RSA %p encrypted:\n", context ); - DBGC_HDA ( context, 0, ciphertext, context->max_len ); + rsa_cipher ( &context, encoded, ciphertext->data ); + DBGC ( &context, "RSA %p encrypted:\n", &context ); + DBGC_HDA ( &context, 0, ciphertext->data, context.max_len ); + + /* Free context */ + rsa_free ( &context ); - return context->max_len; + return 0; + + err_grow: + err_random: + err_sanity: + rsa_free ( &context ); + err_init: + return rc; } /** * Decrypt using RSA * - * @v ctx RSA context + * @v key Key * @v ciphertext Ciphertext - * @v ciphertext_len Ciphertext length * @v plaintext Plaintext - * @ret plaintext_len Plaintext length, or negative error + * @ret rc Return status code */ -static int rsa_decrypt ( void *ctx, const void *ciphertext, - size_t ciphertext_len, void *plaintext ) { - struct rsa_context *context = ctx; +static int rsa_decrypt ( const struct asn1_cursor *key, + const struct asn1_cursor *ciphertext, + struct asn1_builder *plaintext ) { + struct rsa_context context; void *temp; uint8_t *encoded; uint8_t *end; uint8_t *zero; uint8_t *start; - size_t plaintext_len; + size_t len; + int rc; + + DBGC ( &context, "RSA %p decrypting:\n", &context ); + DBGC_HDA ( &context, 0, ciphertext->data, ciphertext->len ); + + /* Initialise context */ + if ( ( rc = rsa_init ( &context, key ) ) != 0 ) + goto err_init; /* Sanity check */ - if ( ciphertext_len != context->max_len ) { - DBGC ( context, "RSA %p ciphertext incorrect length (%zd " + if ( ciphertext->len != context.max_len ) { + DBGC ( &context, "RSA %p ciphertext incorrect length (%zd " "bytes, should be %zd)\n", - context, ciphertext_len, context->max_len ); - return -ERANGE; + &context, ciphertext->len, context.max_len ); + rc = -ERANGE; + goto err_sanity; } - DBGC ( context, "RSA %p decrypting:\n", context ); - DBGC_HDA ( context, 0, ciphertext, ciphertext_len ); /* Decipher the message (using the big integer input buffer as * temporary storage) */ - temp = context->input0; + temp = context.input0; encoded = temp; - rsa_cipher ( context, ciphertext, encoded ); + rsa_cipher ( &context, ciphertext->data, encoded ); /* Parse the message */ - end = ( encoded + context->max_len ); - if ( ( encoded[0] != 0x00 ) || ( encoded[1] != 0x02 ) ) - goto invalid; + end = ( encoded + context.max_len ); + if ( ( encoded[0] != 0x00 ) || ( encoded[1] != 0x02 ) ) { + rc = -EINVAL; + goto err_invalid; + } zero = memchr ( &encoded[2], 0, ( end - &encoded[2] ) ); - if ( ! zero ) - goto invalid; + if ( ! zero ) { + DBGC ( &context, "RSA %p invalid decrypted message:\n", + &context ); + DBGC_HDA ( &context, 0, encoded, context.max_len ); + rc = -EINVAL; + goto err_invalid; + } start = ( zero + 1 ); - plaintext_len = ( end - start ); + len = ( end - start ); + + /* Create space for plaintext */ + if ( ( rc = asn1_grow ( plaintext, len ) ) != 0 ) + goto err_grow; /* Copy out message */ - memcpy ( plaintext, start, plaintext_len ); - DBGC ( context, "RSA %p decrypted:\n", context ); - DBGC_HDA ( context, 0, plaintext, plaintext_len ); + memcpy ( plaintext->data, start, len ); + DBGC ( &context, "RSA %p decrypted:\n", &context ); + DBGC_HDA ( &context, 0, plaintext->data, len ); - return plaintext_len; + /* Free context */ + rsa_free ( &context ); - invalid: - DBGC ( context, "RSA %p invalid decrypted message:\n", context ); - DBGC_HDA ( context, 0, encoded, context->max_len ); - return -EINVAL; + return 0; + + err_grow: + err_invalid: + err_sanity: + rsa_free ( &context ); + err_init: + return rc; } /** @@ -458,9 +475,9 @@ static int rsa_encode_digest ( struct rsa_context *context, /* Sanity check */ max_len = ( context->max_len - 11 ); if ( digestinfo_len > max_len ) { - DBGC ( context, "RSA %p %s digestInfo too long (%zd bytes, max" - "%zd)\n", - context, digest->name, digestinfo_len, max_len ); + DBGC ( context, "RSA %p %s digestInfo too long (%zd bytes, " + "max %zd)\n", context, digest->name, digestinfo_len, + max_len ); return -ERANGE; } DBGC ( context, "RSA %p encoding %s digest:\n", @@ -488,137 +505,149 @@ static int rsa_encode_digest ( struct rsa_context *context, /** * Sign digest value using RSA * - * @v ctx RSA context + * @v key Key * @v digest Digest algorithm * @v value Digest value * @v signature Signature - * @ret signature_len Signature length, or negative error + * @ret rc Return status code */ -static int rsa_sign ( void *ctx, struct digest_algorithm *digest, - const void *value, void *signature ) { - struct rsa_context *context = ctx; - void *temp; +static int rsa_sign ( const struct asn1_cursor *key, + struct digest_algorithm *digest, const void *value, + struct asn1_builder *signature ) { + struct rsa_context context; int rc; - DBGC ( context, "RSA %p signing %s digest:\n", context, digest->name ); - DBGC_HDA ( context, 0, value, digest->digestsize ); + DBGC ( &context, "RSA %p signing %s digest:\n", + &context, digest->name ); + DBGC_HDA ( &context, 0, value, digest->digestsize ); - /* Encode digest (using the big integer output buffer as - * temporary storage) - */ - temp = context->output0; - if ( ( rc = rsa_encode_digest ( context, digest, value, temp ) ) != 0 ) - return rc; + /* Initialise context */ + if ( ( rc = rsa_init ( &context, key ) ) != 0 ) + goto err_init; + + /* Create space for encoded digest and signature */ + if ( ( rc = asn1_grow ( signature, context.max_len ) ) != 0 ) + goto err_grow; + + /* Encode digest */ + if ( ( rc = rsa_encode_digest ( &context, digest, value, + signature->data ) ) != 0 ) + goto err_encode; /* Encipher the encoded digest */ - rsa_cipher ( context, temp, signature ); - DBGC ( context, "RSA %p signed %s digest:\n", context, digest->name ); - DBGC_HDA ( context, 0, signature, context->max_len ); + rsa_cipher ( &context, signature->data, signature->data ); + DBGC ( &context, "RSA %p signed %s digest:\n", &context, digest->name ); + DBGC_HDA ( &context, 0, signature->data, signature->len ); + + /* Free context */ + rsa_free ( &context ); + + return 0; - return context->max_len; + err_encode: + err_grow: + rsa_free ( &context ); + err_init: + return rc; } /** * Verify signed digest value using RSA * - * @v ctx RSA context + * @v key Key * @v digest Digest algorithm * @v value Digest value * @v signature Signature - * @v signature_len Signature length * @ret rc Return status code */ -static int rsa_verify ( void *ctx, struct digest_algorithm *digest, - const void *value, const void *signature, - size_t signature_len ) { - struct rsa_context *context = ctx; +static int rsa_verify ( const struct asn1_cursor *key, + struct digest_algorithm *digest, const void *value, + const struct asn1_cursor *signature ) { + struct rsa_context context; void *temp; void *expected; void *actual; int rc; + DBGC ( &context, "RSA %p verifying %s digest:\n", + &context, digest->name ); + DBGC_HDA ( &context, 0, value, digest->digestsize ); + DBGC_HDA ( &context, 0, signature->data, signature->len ); + + /* Initialise context */ + if ( ( rc = rsa_init ( &context, key ) ) != 0 ) + goto err_init; + /* Sanity check */ - if ( signature_len != context->max_len ) { - DBGC ( context, "RSA %p signature incorrect length (%zd " + if ( signature->len != context.max_len ) { + DBGC ( &context, "RSA %p signature incorrect length (%zd " "bytes, should be %zd)\n", - context, signature_len, context->max_len ); - return -ERANGE; + &context, signature->len, context.max_len ); + rc = -ERANGE; + goto err_sanity; } - DBGC ( context, "RSA %p verifying %s digest:\n", - context, digest->name ); - DBGC_HDA ( context, 0, value, digest->digestsize ); - DBGC_HDA ( context, 0, signature, signature_len ); /* Decipher the signature (using the big integer input buffer * as temporary storage) */ - temp = context->input0; + temp = context.input0; expected = temp; - rsa_cipher ( context, signature, expected ); - DBGC ( context, "RSA %p deciphered signature:\n", context ); - DBGC_HDA ( context, 0, expected, context->max_len ); + rsa_cipher ( &context, signature->data, expected ); + DBGC ( &context, "RSA %p deciphered signature:\n", &context ); + DBGC_HDA ( &context, 0, expected, context.max_len ); /* Encode digest (using the big integer output buffer as * temporary storage) */ - temp = context->output0; + temp = context.output0; actual = temp; - if ( ( rc = rsa_encode_digest ( context, digest, value, actual ) ) !=0 ) - return rc; + if ( ( rc = rsa_encode_digest ( &context, digest, value, + actual ) ) != 0 ) + goto err_encode; /* Verify the signature */ - if ( memcmp ( actual, expected, context->max_len ) != 0 ) { - DBGC ( context, "RSA %p signature verification failed\n", - context ); - return -EACCES_VERIFY; + if ( memcmp ( actual, expected, context.max_len ) != 0 ) { + DBGC ( &context, "RSA %p signature verification failed\n", + &context ); + rc = -EACCES_VERIFY; + goto err_verify; } - DBGC ( context, "RSA %p signature verified successfully\n", context ); - return 0; -} + /* Free context */ + rsa_free ( &context ); -/** - * Finalise RSA cipher - * - * @v ctx RSA context - */ -static void rsa_final ( void *ctx ) { - struct rsa_context *context = ctx; + DBGC ( &context, "RSA %p signature verified successfully\n", &context ); + return 0; - rsa_free ( context ); + err_verify: + err_encode: + err_sanity: + rsa_free ( &context ); + err_init: + return rc; } /** * Check for matching RSA public/private key pair * * @v private_key Private key - * @v private_key_len Private key length * @v public_key Public key - * @v public_key_len Public key length * @ret rc Return status code */ -static int rsa_match ( const void *private_key, size_t private_key_len, - const void *public_key, size_t public_key_len ) { +static int rsa_match ( const struct asn1_cursor *private_key, + const struct asn1_cursor *public_key ) { struct asn1_cursor private_modulus; struct asn1_cursor private_exponent; - struct asn1_cursor private_cursor; struct asn1_cursor public_modulus; struct asn1_cursor public_exponent; - struct asn1_cursor public_cursor; int rc; - /* Initialise cursors */ - private_cursor.data = private_key; - private_cursor.len = private_key_len; - public_cursor.data = public_key; - public_cursor.len = public_key_len; - /* Parse moduli and exponents */ if ( ( rc = rsa_parse_mod_exp ( &private_modulus, &private_exponent, - &private_cursor ) ) != 0 ) + private_key ) ) != 0 ) return rc; if ( ( rc = rsa_parse_mod_exp ( &public_modulus, &public_exponent, - &public_cursor ) ) != 0 ) + public_key ) ) != 0 ) return rc; /* Compare moduli */ @@ -631,14 +660,10 @@ static int rsa_match ( const void *private_key, size_t private_key_len, /** RSA public-key algorithm */ struct pubkey_algorithm rsa_algorithm = { .name = "rsa", - .ctxsize = RSA_CTX_SIZE, - .init = rsa_init, - .max_len = rsa_max_len, .encrypt = rsa_encrypt, .decrypt = rsa_decrypt, .sign = rsa_sign, .verify = rsa_verify, - .final = rsa_final, .match = rsa_match, }; |
