psa_import_key: validate symmetric key size
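When importing a symmetric key, check that the key size is valid
for the given key type, and reject invalid sizes with PSA_ERROR_INVALID_ARGUMENT.
Key types that are not supported can no longer be imported: they are
rejected with PSA_ERROR_NOT_SUPPORTED.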
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index e41e512..4d2f8d0 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -346,6 +346,87 @@
}
}
+/** Check that \p bits is a valid size for a symmetric key of type \p type
+ * and allocate a buffer for the key material with mbedtls_calloc().
+ *
+ * \param type      The key type. Raw data is always accepted; HMAC, AES,
+ *                  Camellia, DES and ARC4 are accepted only when the
+ *                  corresponding module is enabled in the configuration.
+ * \param bits      The key size in bits, converted to bytes with
+ *                  PSA_BITS_TO_BYTES(). psa_import_key() only passes
+ *                  sizes derived from inputs of at most SIZE_MAX / 8
+ *                  bytes; other callers need a similar bound so that
+ *                  the conversion cannot overflow.
+ * \param[out] raw  On success, \c raw->bytes holds the key size in bytes
+ *                  and \c raw->data points to the newly allocated buffer.
+ *                  If the type or size is rejected, \p raw is left
+ *                  unmodified. If the allocation fails, \c raw->data is
+ *                  NULL and \c raw->bytes is 0.
+ *
+ * \retval PSA_SUCCESS                    Success.
+ * \retval PSA_ERROR_INVALID_ARGUMENT     \p bits is not valid for \p type.
+ * \retval PSA_ERROR_NOT_SUPPORTED        \p type is not a supported
+ *                                        symmetric key type.
+ * \retval PSA_ERROR_INSUFFICIENT_MEMORY  The allocation failed.
+ */
+static psa_status_t prepare_raw_data_slot( psa_key_type_t type,
+                                           size_t bits,
+                                           struct raw_data *raw )
+{
+    /* Check that the bit size is acceptable for the key type */
+    switch( type )
+    {
+        case PSA_KEY_TYPE_RAW_DATA:
+#if defined(MBEDTLS_MD_C)
+        case PSA_KEY_TYPE_HMAC:
+#endif
+            /* Any size, including 0 bits, is acceptable. */
+            break;
+#if defined(MBEDTLS_AES_C)
+        case PSA_KEY_TYPE_AES:
+            if( bits != 128 && bits != 192 && bits != 256 )
+                return( PSA_ERROR_INVALID_ARGUMENT );
+            break;
+#endif
+#if defined(MBEDTLS_CAMELLIA_C)
+        case PSA_KEY_TYPE_CAMELLIA:
+            if( bits != 128 && bits != 192 && bits != 256 )
+                return( PSA_ERROR_INVALID_ARGUMENT );
+            break;
+#endif
+#if defined(MBEDTLS_DES_C)
+        case PSA_KEY_TYPE_DES:
+            if( bits != 64 && bits != 128 && bits != 192 )
+                return( PSA_ERROR_INVALID_ARGUMENT );
+            break;
+#endif
+#if defined(MBEDTLS_ARC4_C)
+        case PSA_KEY_TYPE_ARC4:
+            /* ARC4 keys are 1 to 256 bytes long. */
+            if( bits < 8 || bits > 2048 )
+                return( PSA_ERROR_INVALID_ARGUMENT );
+            break;
+#endif
+        default:
+            return( PSA_ERROR_NOT_SUPPORTED );
+    }
+
+    /* Allocate memory for the key. The C standard lets calloc() return
+     * NULL for a zero-size request, and mbedtls_calloc() typically
+     * forwards to it, so importing an empty key could fail spuriously
+     * with PSA_ERROR_INSUFFICIENT_MEMORY on some platforms. Always
+     * request at least one byte: an empty key then still gets a valid,
+     * non-NULL buffer, which psa_import_key() passes to memcpy(). */
+    raw->bytes = PSA_BITS_TO_BYTES( bits );
+    raw->data = mbedtls_calloc( 1, raw->bytes != 0 ? raw->bytes : 1 );
+    if( raw->data == NULL )
+    {
+        raw->bytes = 0;
+        return( PSA_ERROR_INSUFFICIENT_MEMORY );
+    }
+    return( PSA_SUCCESS );
+}
+
psa_status_t psa_import_key( psa_key_slot_t key,
psa_key_type_t type,
const uint8_t *data,
@@ -361,14 +412,16 @@
if( PSA_KEY_TYPE_IS_RAW_BYTES( type ) )
{
+ psa_status_t status;
/* Ensure that a bytes-to-bit conversion won't overflow. */
if( data_length > SIZE_MAX / 8 )
return( PSA_ERROR_NOT_SUPPORTED );
- slot->data.raw.data = mbedtls_calloc( 1, data_length );
- if( slot->data.raw.data == NULL )
- return( PSA_ERROR_INSUFFICIENT_MEMORY );
+ status = prepare_raw_data_slot( type,
+ PSA_BYTES_TO_BITS( data_length ),
+ &slot->data.raw );
+ if( status != PSA_SUCCESS )
+ return( status );
memcpy( slot->data.raw.data, data, data_length );
- slot->data.raw.bytes = data_length;
}
else
#if defined(MBEDTLS_PK_PARSE_C)