Use the PSA-based HKDF functions

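Use the `mbedtls_psa_hkdf_extract` and `mbedtls_psa_hkdf_expand`
functions for HKDF in `ssl_tls13_keys.c` when `MBEDTLS_USE_PSA_CRYPTO`
is enabled, keeping the legacy `mbedtls_hkdf_*` path otherwise. The
PSA-based HKDF helpers are now compiled under `MBEDTLS_USE_PSA_CRYPTO`
instead of `MBEDTLS_TEST_HOOKS`.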

Signed-off-by: Gabor Mezei <gabor.mezei@arm.com>
diff --git a/library/ssl_tls13_keys.c b/library/ssl_tls13_keys.c
index a5af590..d6a027a 100644
--- a/library/ssl_tls13_keys.c
+++ b/library/ssl_tls13_keys.c
@@ -136,7 +136,7 @@
     *dst_len = total_hkdf_lbl_len;
 }
 
-#if defined( MBEDTLS_TEST_HOOKS )
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
 
 MBEDTLS_STATIC_TESTABLE
 psa_status_t mbedtls_psa_hkdf_extract( psa_algorithm_t alg,
@@ -312,7 +312,7 @@
     return( ( status == PSA_SUCCESS ) ? destroy_status : status );
 }
 
-#endif /* MBEDTLS_TEST_HOOKS */
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 int mbedtls_ssl_tls13_hkdf_expand_label(
                      mbedtls_md_type_t hash_alg,
@@ -321,10 +321,15 @@
                      const unsigned char *ctx, size_t ctx_len,
                      unsigned char *buf, size_t buf_len )
 {
-    const mbedtls_md_info_t *md_info;
     unsigned char hkdf_label[ SSL_TLS1_3_KEY_SCHEDULE_MAX_HKDF_LABEL_LEN ];
     size_t hkdf_label_len;
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_algorithm_t alg;
+#else
+    const mbedtls_md_info_t *md_info;
+#endif
+
     if( label_len > MBEDTLS_SSL_TLS1_3_KEY_SCHEDULE_MAX_LABEL_LEN )
     {
         /* Should never happen since this is an internal
@@ -345,9 +350,17 @@
         return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
     }
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    alg = mbedtls_psa_translate_md( hash_alg );
+    if( ! PSA_ALG_IS_HASH( alg ) )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    alg = PSA_ALG_HMAC( alg );
+#else
     md_info = mbedtls_md_info_from_type( hash_alg );
     if( md_info == NULL )
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
     ssl_tls13_hkdf_encode_label( buf_len,
                                  label, label_len,
@@ -355,10 +368,18 @@
                                  hkdf_label,
                                  &hkdf_label_len );
 
-    return( mbedtls_hkdf_expand( md_info,
-                                 secret, secret_len,
-                                 hkdf_label, hkdf_label_len,
-                                 buf, buf_len ) );
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    return( psa_ssl_status_to_mbedtls(
+                mbedtls_psa_hkdf_expand( alg,
+                                         secret, secret_len,
+                                         hkdf_label, hkdf_label_len,
+                                         buf, buf_len ) ) );
+#else
+    return mbedtls_hkdf_expand( md_info,
+                                secret, secret_len,
+                                hkdf_label, hkdf_label_len,
+                                buf, buf_len );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 }
 
 /*
@@ -479,12 +500,22 @@
     unsigned char tmp_secret[ MBEDTLS_MD_MAX_SIZE ] = { 0 };
     unsigned char tmp_input [ MBEDTLS_ECP_MAX_BYTES ] = { 0 };
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    size_t secret_len;
+    psa_algorithm_t alg = mbedtls_psa_translate_md( hash_alg );
+    if( ! PSA_ALG_IS_HASH( alg ) )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    alg = PSA_ALG_HMAC( alg );
+    hlen = PSA_HASH_LENGTH( alg );
+#else
     const mbedtls_md_info_t *md_info;
     md_info = mbedtls_md_info_from_type( hash_alg );
     if( md_info == NULL )
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
 
     hlen = mbedtls_md_get_size( md_info );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
     /* For non-initial runs, call Derive-Secret( ., "derived", "")
      * on the old secret. */
@@ -514,14 +545,18 @@
     /* HKDF-Extract takes a salt and input key material.
      * The salt is the old secret, and the input key material
      * is the input secret (PSK / ECDHE). */
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    ret = psa_ssl_status_to_mbedtls(
+            mbedtls_psa_hkdf_extract( alg,
+                                      tmp_secret, hlen,
+                                      tmp_input, ilen,
+                                      secret_new, hlen, &secret_len ) );
+#else
     ret = mbedtls_hkdf_extract( md_info,
                     tmp_secret, hlen,
                     tmp_input, ilen,
                     secret_new );
-    if( ret != 0 )
-        goto cleanup;
-
-    ret = 0;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
  cleanup: