summaryrefslogtreecommitdiffstats
path: root/src/crypto/rsa.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/crypto/rsa.c')
-rw-r--r--src/crypto/rsa.c413
1 files changed, 219 insertions, 194 deletions
diff --git a/src/crypto/rsa.c b/src/crypto/rsa.c
index 16c67d822..be055d881 100644
--- a/src/crypto/rsa.c
+++ b/src/crypto/rsa.c
@@ -22,6 +22,7 @@
*/
FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL );
+FILE_SECBOOT ( PERMITTED );
#include <stdint.h>
#include <stdlib.h>
@@ -47,6 +48,28 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL );
#define EINFO_EACCES_VERIFY \
__einfo_uniqify ( EINFO_EACCES, 0x01, "RSA signature incorrect" )
+/** An RSA context */
+struct rsa_context {
+ /** Allocated memory */
+ void *dynamic;
+ /** Modulus */
+ bigint_element_t *modulus0;
+ /** Modulus size */
+ unsigned int size;
+ /** Modulus length */
+ size_t max_len;
+ /** Exponent */
+ bigint_element_t *exponent0;
+ /** Exponent size */
+ unsigned int exponent_size;
+ /** Input buffer */
+ bigint_element_t *input0;
+ /** Output buffer */
+ bigint_element_t *output0;
+ /** Temporary working space for modular exponentiation */
+ void *tmp;
+};
+
/**
* Identify RSA prefix
*
@@ -69,10 +92,9 @@ rsa_find_prefix ( struct digest_algorithm *digest ) {
*
* @v context RSA context
*/
-static void rsa_free ( struct rsa_context *context ) {
+static inline void rsa_free ( struct rsa_context *context ) {
free ( context->dynamic );
- context->dynamic = NULL;
}
/**
@@ -88,8 +110,7 @@ static int rsa_alloc ( struct rsa_context *context, size_t modulus_len,
unsigned int size = bigint_required_size ( modulus_len );
unsigned int exponent_size = bigint_required_size ( exponent_len );
bigint_t ( size ) *modulus;
- bigint_t ( exponent_size ) *exponent;
- size_t tmp_len = bigint_mod_exp_tmp_len ( modulus, exponent );
+ size_t tmp_len = bigint_mod_exp_tmp_len ( modulus );
struct {
bigint_t ( size ) modulus;
bigint_t ( exponent_size ) exponent;
@@ -98,9 +119,6 @@ static int rsa_alloc ( struct rsa_context *context, size_t modulus_len,
uint8_t tmp[tmp_len];
} __attribute__ (( packed )) *dynamic;
- /* Free any existing dynamic storage */
- rsa_free ( context );
-
/* Allocate dynamic storage */
dynamic = malloc ( sizeof ( *dynamic ) );
if ( ! dynamic )
@@ -121,34 +139,6 @@ static int rsa_alloc ( struct rsa_context *context, size_t modulus_len,
}
/**
- * Parse RSA integer
- *
- * @v integer Integer to fill in
- * @v raw ASN.1 cursor
- * @ret rc Return status code
- */
-static int rsa_parse_integer ( struct asn1_cursor *integer,
- const struct asn1_cursor *raw ) {
-
- /* Enter integer */
- memcpy ( integer, raw, sizeof ( *integer ) );
- asn1_enter ( integer, ASN1_INTEGER );
-
- /* Skip initial sign byte if applicable */
- if ( ( integer->len > 1 ) &&
- ( *( ( uint8_t * ) integer->data ) == 0x00 ) ) {
- integer->data++;
- integer->len--;
- }
-
- /* Fail if cursor or integer are invalid */
- if ( ! integer->len )
- return -EINVAL;
-
- return 0;
-}
-
-/**
* Parse RSA modulus and exponent
*
* @v modulus Modulus to fill in
@@ -159,7 +149,6 @@ static int rsa_parse_integer ( struct asn1_cursor *integer,
static int rsa_parse_mod_exp ( struct asn1_cursor *modulus,
struct asn1_cursor *exponent,
const struct asn1_cursor *raw ) {
- struct asn1_bit_string bits;
struct asn1_cursor cursor;
int is_private;
int rc;
@@ -178,8 +167,8 @@ static int rsa_parse_mod_exp ( struct asn1_cursor *modulus,
asn1_skip_any ( &cursor );
/* Enter privateKey, if present */
- if ( asn1_check_algorithm ( &cursor,
- &rsa_encryption_algorithm ) == 0 ) {
+ if ( asn1_check_algorithm ( &cursor, &rsa_encryption_algorithm,
+ NULL ) == 0 ) {
/* Skip privateKeyAlgorithm */
asn1_skip_any ( &cursor );
@@ -203,17 +192,15 @@ static int rsa_parse_mod_exp ( struct asn1_cursor *modulus,
asn1_skip ( &cursor, ASN1_SEQUENCE );
/* Enter subjectPublicKey */
- if ( ( rc = asn1_integral_bit_string ( &cursor, &bits ) ) != 0 )
- return rc;
- cursor.data = bits.data;
- cursor.len = bits.len;
+ asn1_enter_bits ( &cursor, NULL );
/* Enter RSAPublicKey */
asn1_enter ( &cursor, ASN1_SEQUENCE );
}
/* Extract modulus */
- if ( ( rc = rsa_parse_integer ( modulus, &cursor ) ) != 0 )
+ memcpy ( modulus, &cursor, sizeof ( *modulus ) );
+ if ( ( rc = asn1_enter_unsigned ( modulus ) ) != 0 )
return rc;
asn1_skip_any ( &cursor );
@@ -222,7 +209,8 @@ static int rsa_parse_mod_exp ( struct asn1_cursor *modulus,
asn1_skip ( &cursor, ASN1_INTEGER );
/* Extract publicExponent/privateExponent */
- if ( ( rc = rsa_parse_integer ( exponent, &cursor ) ) != 0 )
+ memcpy ( exponent, &cursor, sizeof ( *exponent ) );
+ if ( ( rc = asn1_enter_unsigned ( exponent ) ) != 0 )
return rc;
return 0;
@@ -231,29 +219,23 @@ static int rsa_parse_mod_exp ( struct asn1_cursor *modulus,
/**
* Initialise RSA cipher
*
- * @v ctx RSA context
+ * @v context RSA context
* @v key Key
- * @v key_len Length of key
* @ret rc Return status code
*/
-static int rsa_init ( void *ctx, const void *key, size_t key_len ) {
- struct rsa_context *context = ctx;
+static int rsa_init ( struct rsa_context *context,
+ const struct asn1_cursor *key ) {
struct asn1_cursor modulus;
struct asn1_cursor exponent;
- struct asn1_cursor cursor;
int rc;
/* Initialise context */
memset ( context, 0, sizeof ( *context ) );
- /* Initialise cursor */
- cursor.data = key;
- cursor.len = key_len;
-
/* Parse modulus and exponent */
- if ( ( rc = rsa_parse_mod_exp ( &modulus, &exponent, &cursor ) ) != 0 ){
+ if ( ( rc = rsa_parse_mod_exp ( &modulus, &exponent, key ) ) != 0 ){
DBGC ( context, "RSA %p invalid modulus/exponent:\n", context );
- DBGC_HDA ( context, 0, cursor.data, cursor.len );
+ DBGC_HDA ( context, 0, key->data, key->len );
goto err_parse;
}
@@ -281,18 +263,6 @@ static int rsa_init ( void *ctx, const void *key, size_t key_len ) {
}
/**
- * Calculate RSA maximum output length
- *
- * @v ctx RSA context
- * @ret max_len Maximum output length
- */
-static size_t rsa_max_len ( void *ctx ) {
- struct rsa_context *context = ctx;
-
- return context->max_len;
-}
-
-/**
* Perform RSA cipher operation
*
* @v context RSA context
@@ -320,111 +290,158 @@ static void rsa_cipher ( struct rsa_context *context,
/**
* Encrypt using RSA
*
- * @v ctx RSA context
+ * @v key Key
* @v plaintext Plaintext
- * @v plaintext_len Length of plaintext
* @v ciphertext Ciphertext
* @ret ciphertext_len Length of ciphertext, or negative error
*/
-static int rsa_encrypt ( void *ctx, const void *plaintext,
- size_t plaintext_len, void *ciphertext ) {
- struct rsa_context *context = ctx;
+static int rsa_encrypt ( const struct asn1_cursor *key,
+ const struct asn1_cursor *plaintext,
+ struct asn1_builder *ciphertext ) {
+ struct rsa_context context;
void *temp;
uint8_t *encoded;
- size_t max_len = ( context->max_len - 11 );
- size_t random_nz_len = ( max_len - plaintext_len + 8 );
+ size_t max_len;
+ size_t random_nz_len;
int rc;
+ DBGC ( &context, "RSA %p encrypting:\n", &context );
+ DBGC_HDA ( &context, 0, plaintext->data, plaintext->len );
+
+ /* Initialise context */
+ if ( ( rc = rsa_init ( &context, key ) ) != 0 )
+ goto err_init;
+
+ /* Calculate lengths */
+ max_len = ( context.max_len - 11 );
+ random_nz_len = ( max_len - plaintext->len + 8 );
+
/* Sanity check */
- if ( plaintext_len > max_len ) {
- DBGC ( context, "RSA %p plaintext too long (%zd bytes, max "
- "%zd)\n", context, plaintext_len, max_len );
- return -ERANGE;
+ if ( plaintext->len > max_len ) {
+ DBGC ( &context, "RSA %p plaintext too long (%zd bytes, max "
+ "%zd)\n", &context, plaintext->len, max_len );
+ rc = -ERANGE;
+ goto err_sanity;
}
- DBGC ( context, "RSA %p encrypting:\n", context );
- DBGC_HDA ( context, 0, plaintext, plaintext_len );
/* Construct encoded message (using the big integer output
* buffer as temporary storage)
*/
- temp = context->output0;
+ temp = context.output0;
encoded = temp;
encoded[0] = 0x00;
encoded[1] = 0x02;
if ( ( rc = get_random_nz ( &encoded[2], random_nz_len ) ) != 0 ) {
- DBGC ( context, "RSA %p could not generate random data: %s\n",
- context, strerror ( rc ) );
- return rc;
+ DBGC ( &context, "RSA %p could not generate random data: %s\n",
+ &context, strerror ( rc ) );
+ goto err_random;
}
encoded[ 2 + random_nz_len ] = 0x00;
- memcpy ( &encoded[ context->max_len - plaintext_len ],
- plaintext, plaintext_len );
+ memcpy ( &encoded[ context.max_len - plaintext->len ],
+ plaintext->data, plaintext->len );
+
+ /* Create space for ciphertext */
+ if ( ( rc = asn1_grow ( ciphertext, context.max_len ) ) != 0 )
+ goto err_grow;
/* Encipher the encoded message */
- rsa_cipher ( context, encoded, ciphertext );
- DBGC ( context, "RSA %p encrypted:\n", context );
- DBGC_HDA ( context, 0, ciphertext, context->max_len );
+ rsa_cipher ( &context, encoded, ciphertext->data );
+ DBGC ( &context, "RSA %p encrypted:\n", &context );
+ DBGC_HDA ( &context, 0, ciphertext->data, context.max_len );
+
+ /* Free context */
+ rsa_free ( &context );
- return context->max_len;
+ return 0;
+
+ err_grow:
+ err_random:
+ err_sanity:
+ rsa_free ( &context );
+ err_init:
+ return rc;
}
/**
* Decrypt using RSA
*
- * @v ctx RSA context
+ * @v key Key
* @v ciphertext Ciphertext
- * @v ciphertext_len Ciphertext length
* @v plaintext Plaintext
- * @ret plaintext_len Plaintext length, or negative error
+ * @ret rc Return status code
*/
-static int rsa_decrypt ( void *ctx, const void *ciphertext,
- size_t ciphertext_len, void *plaintext ) {
- struct rsa_context *context = ctx;
+static int rsa_decrypt ( const struct asn1_cursor *key,
+ const struct asn1_cursor *ciphertext,
+ struct asn1_builder *plaintext ) {
+ struct rsa_context context;
void *temp;
uint8_t *encoded;
uint8_t *end;
uint8_t *zero;
uint8_t *start;
- size_t plaintext_len;
+ size_t len;
+ int rc;
+
+ DBGC ( &context, "RSA %p decrypting:\n", &context );
+ DBGC_HDA ( &context, 0, ciphertext->data, ciphertext->len );
+
+ /* Initialise context */
+ if ( ( rc = rsa_init ( &context, key ) ) != 0 )
+ goto err_init;
/* Sanity check */
- if ( ciphertext_len != context->max_len ) {
- DBGC ( context, "RSA %p ciphertext incorrect length (%zd "
+ if ( ciphertext->len != context.max_len ) {
+ DBGC ( &context, "RSA %p ciphertext incorrect length (%zd "
"bytes, should be %zd)\n",
- context, ciphertext_len, context->max_len );
- return -ERANGE;
+ &context, ciphertext->len, context.max_len );
+ rc = -ERANGE;
+ goto err_sanity;
}
- DBGC ( context, "RSA %p decrypting:\n", context );
- DBGC_HDA ( context, 0, ciphertext, ciphertext_len );
/* Decipher the message (using the big integer input buffer as
* temporary storage)
*/
- temp = context->input0;
+ temp = context.input0;
encoded = temp;
- rsa_cipher ( context, ciphertext, encoded );
+ rsa_cipher ( &context, ciphertext->data, encoded );
/* Parse the message */
- end = ( encoded + context->max_len );
- if ( ( encoded[0] != 0x00 ) || ( encoded[1] != 0x02 ) )
- goto invalid;
+ end = ( encoded + context.max_len );
+ if ( ( encoded[0] != 0x00 ) || ( encoded[1] != 0x02 ) ) {
+ rc = -EINVAL;
+ goto err_invalid;
+ }
zero = memchr ( &encoded[2], 0, ( end - &encoded[2] ) );
- if ( ! zero )
- goto invalid;
+ if ( ! zero ) {
+ DBGC ( &context, "RSA %p invalid decrypted message:\n",
+ &context );
+ DBGC_HDA ( &context, 0, encoded, context.max_len );
+ rc = -EINVAL;
+ goto err_invalid;
+ }
start = ( zero + 1 );
- plaintext_len = ( end - start );
+ len = ( end - start );
+
+ /* Create space for plaintext */
+ if ( ( rc = asn1_grow ( plaintext, len ) ) != 0 )
+ goto err_grow;
/* Copy out message */
- memcpy ( plaintext, start, plaintext_len );
- DBGC ( context, "RSA %p decrypted:\n", context );
- DBGC_HDA ( context, 0, plaintext, plaintext_len );
+ memcpy ( plaintext->data, start, len );
+ DBGC ( &context, "RSA %p decrypted:\n", &context );
+ DBGC_HDA ( &context, 0, plaintext->data, len );
- return plaintext_len;
+ /* Free context */
+ rsa_free ( &context );
- invalid:
- DBGC ( context, "RSA %p invalid decrypted message:\n", context );
- DBGC_HDA ( context, 0, encoded, context->max_len );
- return -EINVAL;
+ return 0;
+
+ err_grow:
+ err_invalid:
+ err_sanity:
+ rsa_free ( &context );
+ err_init:
+ return rc;
}
/**
@@ -458,9 +475,9 @@ static int rsa_encode_digest ( struct rsa_context *context,
/* Sanity check */
max_len = ( context->max_len - 11 );
if ( digestinfo_len > max_len ) {
- DBGC ( context, "RSA %p %s digestInfo too long (%zd bytes, max"
- "%zd)\n",
- context, digest->name, digestinfo_len, max_len );
+ DBGC ( context, "RSA %p %s digestInfo too long (%zd bytes, "
+ "max %zd)\n", context, digest->name, digestinfo_len,
+ max_len );
return -ERANGE;
}
DBGC ( context, "RSA %p encoding %s digest:\n",
@@ -488,137 +505,149 @@ static int rsa_encode_digest ( struct rsa_context *context,
/**
* Sign digest value using RSA
*
- * @v ctx RSA context
+ * @v key Key
* @v digest Digest algorithm
* @v value Digest value
* @v signature Signature
- * @ret signature_len Signature length, or negative error
+ * @ret rc Return status code
*/
-static int rsa_sign ( void *ctx, struct digest_algorithm *digest,
- const void *value, void *signature ) {
- struct rsa_context *context = ctx;
- void *temp;
+static int rsa_sign ( const struct asn1_cursor *key,
+ struct digest_algorithm *digest, const void *value,
+ struct asn1_builder *signature ) {
+ struct rsa_context context;
int rc;
- DBGC ( context, "RSA %p signing %s digest:\n", context, digest->name );
- DBGC_HDA ( context, 0, value, digest->digestsize );
+ DBGC ( &context, "RSA %p signing %s digest:\n",
+ &context, digest->name );
+ DBGC_HDA ( &context, 0, value, digest->digestsize );
- /* Encode digest (using the big integer output buffer as
- * temporary storage)
- */
- temp = context->output0;
- if ( ( rc = rsa_encode_digest ( context, digest, value, temp ) ) != 0 )
- return rc;
+ /* Initialise context */
+ if ( ( rc = rsa_init ( &context, key ) ) != 0 )
+ goto err_init;
+
+ /* Create space for encoded digest and signature */
+ if ( ( rc = asn1_grow ( signature, context.max_len ) ) != 0 )
+ goto err_grow;
+
+ /* Encode digest */
+ if ( ( rc = rsa_encode_digest ( &context, digest, value,
+ signature->data ) ) != 0 )
+ goto err_encode;
/* Encipher the encoded digest */
- rsa_cipher ( context, temp, signature );
- DBGC ( context, "RSA %p signed %s digest:\n", context, digest->name );
- DBGC_HDA ( context, 0, signature, context->max_len );
+ rsa_cipher ( &context, signature->data, signature->data );
+ DBGC ( &context, "RSA %p signed %s digest:\n", &context, digest->name );
+ DBGC_HDA ( &context, 0, signature->data, signature->len );
+
+ /* Free context */
+ rsa_free ( &context );
+
+ return 0;
- return context->max_len;
+ err_encode:
+ err_grow:
+ rsa_free ( &context );
+ err_init:
+ return rc;
}
/**
* Verify signed digest value using RSA
*
- * @v ctx RSA context
+ * @v key Key
* @v digest Digest algorithm
* @v value Digest value
* @v signature Signature
- * @v signature_len Signature length
* @ret rc Return status code
*/
-static int rsa_verify ( void *ctx, struct digest_algorithm *digest,
- const void *value, const void *signature,
- size_t signature_len ) {
- struct rsa_context *context = ctx;
+static int rsa_verify ( const struct asn1_cursor *key,
+ struct digest_algorithm *digest, const void *value,
+ const struct asn1_cursor *signature ) {
+ struct rsa_context context;
void *temp;
void *expected;
void *actual;
int rc;
+ DBGC ( &context, "RSA %p verifying %s digest:\n",
+ &context, digest->name );
+ DBGC_HDA ( &context, 0, value, digest->digestsize );
+ DBGC_HDA ( &context, 0, signature->data, signature->len );
+
+ /* Initialise context */
+ if ( ( rc = rsa_init ( &context, key ) ) != 0 )
+ goto err_init;
+
/* Sanity check */
- if ( signature_len != context->max_len ) {
- DBGC ( context, "RSA %p signature incorrect length (%zd "
+ if ( signature->len != context.max_len ) {
+ DBGC ( &context, "RSA %p signature incorrect length (%zd "
"bytes, should be %zd)\n",
- context, signature_len, context->max_len );
- return -ERANGE;
+ &context, signature->len, context.max_len );
+ rc = -ERANGE;
+ goto err_sanity;
}
- DBGC ( context, "RSA %p verifying %s digest:\n",
- context, digest->name );
- DBGC_HDA ( context, 0, value, digest->digestsize );
- DBGC_HDA ( context, 0, signature, signature_len );
/* Decipher the signature (using the big integer input buffer
* as temporary storage)
*/
- temp = context->input0;
+ temp = context.input0;
expected = temp;
- rsa_cipher ( context, signature, expected );
- DBGC ( context, "RSA %p deciphered signature:\n", context );
- DBGC_HDA ( context, 0, expected, context->max_len );
+ rsa_cipher ( &context, signature->data, expected );
+ DBGC ( &context, "RSA %p deciphered signature:\n", &context );
+ DBGC_HDA ( &context, 0, expected, context.max_len );
/* Encode digest (using the big integer output buffer as
* temporary storage)
*/
- temp = context->output0;
+ temp = context.output0;
actual = temp;
- if ( ( rc = rsa_encode_digest ( context, digest, value, actual ) ) !=0 )
- return rc;
+ if ( ( rc = rsa_encode_digest ( &context, digest, value,
+ actual ) ) != 0 )
+ goto err_encode;
/* Verify the signature */
- if ( memcmp ( actual, expected, context->max_len ) != 0 ) {
- DBGC ( context, "RSA %p signature verification failed\n",
- context );
- return -EACCES_VERIFY;
+ if ( memcmp ( actual, expected, context.max_len ) != 0 ) {
+ DBGC ( &context, "RSA %p signature verification failed\n",
+ &context );
+ rc = -EACCES_VERIFY;
+ goto err_verify;
}
- DBGC ( context, "RSA %p signature verified successfully\n", context );
- return 0;
-}
+ /* Free context */
+ rsa_free ( &context );
-/**
- * Finalise RSA cipher
- *
- * @v ctx RSA context
- */
-static void rsa_final ( void *ctx ) {
- struct rsa_context *context = ctx;
+ DBGC ( &context, "RSA %p signature verified successfully\n", &context );
+ return 0;
- rsa_free ( context );
+ err_verify:
+ err_encode:
+ err_sanity:
+ rsa_free ( &context );
+ err_init:
+ return rc;
}
/**
* Check for matching RSA public/private key pair
*
* @v private_key Private key
- * @v private_key_len Private key length
* @v public_key Public key
- * @v public_key_len Public key length
* @ret rc Return status code
*/
-static int rsa_match ( const void *private_key, size_t private_key_len,
- const void *public_key, size_t public_key_len ) {
+static int rsa_match ( const struct asn1_cursor *private_key,
+ const struct asn1_cursor *public_key ) {
struct asn1_cursor private_modulus;
struct asn1_cursor private_exponent;
- struct asn1_cursor private_cursor;
struct asn1_cursor public_modulus;
struct asn1_cursor public_exponent;
- struct asn1_cursor public_cursor;
int rc;
- /* Initialise cursors */
- private_cursor.data = private_key;
- private_cursor.len = private_key_len;
- public_cursor.data = public_key;
- public_cursor.len = public_key_len;
-
/* Parse moduli and exponents */
if ( ( rc = rsa_parse_mod_exp ( &private_modulus, &private_exponent,
- &private_cursor ) ) != 0 )
+ private_key ) ) != 0 )
return rc;
if ( ( rc = rsa_parse_mod_exp ( &public_modulus, &public_exponent,
- &public_cursor ) ) != 0 )
+ public_key ) ) != 0 )
return rc;
/* Compare moduli */
@@ -631,14 +660,10 @@ static int rsa_match ( const void *private_key, size_t private_key_len,
/** RSA public-key algorithm */
struct pubkey_algorithm rsa_algorithm = {
.name = "rsa",
- .ctxsize = RSA_CTX_SIZE,
- .init = rsa_init,
- .max_len = rsa_max_len,
.encrypt = rsa_encrypt,
.decrypt = rsa_decrypt,
.sign = rsa_sign,
.verify = rsa_verify,
- .final = rsa_final,
.match = rsa_match,
};