Use the PSA-based HKDF functions in the TLS 1.3 key schedule
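Use the `mbedtls_psa_hkdf_extract` and `mbedtls_psa_hkdf_expand`
functions for HKDF-Extract and HKDF-Expand-Label in
`ssl_tls13_keys.c` when `MBEDTLS_USE_PSA_CRYPTO` is enabled; otherwise
keep using the legacy `mbedtls_hkdf_extract`/`mbedtls_hkdf_expand`
functions. To make this possible, the PSA HKDF helpers are now
compiled under `MBEDTLS_USE_PSA_CRYPTO` instead of `MBEDTLS_TEST_HOOKS`.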
Signed-off-by: Gabor Mezei <gabor.mezei@arm.com>
diff --git a/library/ssl_tls13_keys.c b/library/ssl_tls13_keys.c
index a5af590..d6a027a 100644
--- a/library/ssl_tls13_keys.c
+++ b/library/ssl_tls13_keys.c
@@ -136,7 +136,7 @@
*dst_len = total_hkdf_lbl_len;
}
-#if defined( MBEDTLS_TEST_HOOKS )
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
MBEDTLS_STATIC_TESTABLE
psa_status_t mbedtls_psa_hkdf_extract( psa_algorithm_t alg,
@@ -312,7 +312,7 @@
return( ( status == PSA_SUCCESS ) ? destroy_status : status );
}
-#endif /* MBEDTLS_TEST_HOOKS */
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
int mbedtls_ssl_tls13_hkdf_expand_label(
mbedtls_md_type_t hash_alg,
@@ -321,10 +321,15 @@
const unsigned char *ctx, size_t ctx_len,
unsigned char *buf, size_t buf_len )
{
- const mbedtls_md_info_t *md_info;
unsigned char hkdf_label[ SSL_TLS1_3_KEY_SCHEDULE_MAX_HKDF_LABEL_LEN ];
size_t hkdf_label_len;
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+ psa_algorithm_t alg;
+#else
+ const mbedtls_md_info_t *md_info;
+#endif
+
if( label_len > MBEDTLS_SSL_TLS1_3_KEY_SCHEDULE_MAX_LABEL_LEN )
{
/* Should never happen since this is an internal
@@ -345,9 +350,17 @@
return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
}
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+ alg = mbedtls_psa_translate_md( hash_alg );
+ if( ! PSA_ALG_IS_HASH( alg ) )
+ return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+ alg = PSA_ALG_HMAC( alg );
+#else
md_info = mbedtls_md_info_from_type( hash_alg );
if( md_info == NULL )
return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
ssl_tls13_hkdf_encode_label( buf_len,
label, label_len,
@@ -355,10 +368,18 @@
hkdf_label,
&hkdf_label_len );
- return( mbedtls_hkdf_expand( md_info,
- secret, secret_len,
- hkdf_label, hkdf_label_len,
- buf, buf_len ) );
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+ return( psa_ssl_status_to_mbedtls(
+ mbedtls_psa_hkdf_expand( alg,
+ secret, secret_len,
+ hkdf_label, hkdf_label_len,
+ buf, buf_len ) ) );
+#else
+ return mbedtls_hkdf_expand( md_info,
+ secret, secret_len,
+ hkdf_label, hkdf_label_len,
+ buf, buf_len );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
}
/*
@@ -479,12 +500,22 @@
unsigned char tmp_secret[ MBEDTLS_MD_MAX_SIZE ] = { 0 };
unsigned char tmp_input [ MBEDTLS_ECP_MAX_BYTES ] = { 0 };
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+ size_t secret_len;
+ psa_algorithm_t alg = mbedtls_psa_translate_md( hash_alg );
+ if( ! PSA_ALG_IS_HASH( alg ) )
+ return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+ alg = PSA_ALG_HMAC( alg );
+ hlen = PSA_HASH_LENGTH( alg );
+#else
const mbedtls_md_info_t *md_info;
md_info = mbedtls_md_info_from_type( hash_alg );
if( md_info == NULL )
return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
hlen = mbedtls_md_get_size( md_info );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
/* For non-initial runs, call Derive-Secret( ., "derived", "")
* on the old secret. */
@@ -514,14 +545,18 @@
/* HKDF-Extract takes a salt and input key material.
* The salt is the old secret, and the input key material
* is the input secret (PSK / ECDHE). */
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+ ret = psa_ssl_status_to_mbedtls(
+ mbedtls_psa_hkdf_extract( alg,
+ tmp_secret, hlen,
+ tmp_input, ilen,
+ secret_new, hlen, &secret_len ) );
+#else
ret = mbedtls_hkdf_extract( md_info,
tmp_secret, hlen,
tmp_input, ilen,
secret_new );
- if( ret != 0 )
- goto cleanup;
-
- ret = 0;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
cleanup: