Use hash algorithm instead of HMAC algorithm as HKDF parameter

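Change `mbedtls_psa_hkdf_extract` and `mbedtls_psa_hkdf_expand` to take a
hash algorithm as their algorithm parameter instead of an HMAC algorithm,
for consistency with the other functions in ssl_tls13_keys.c, which already
take a hash algorithm. The HMAC algorithm is now derived inside the two
functions with `PSA_ALG_HMAC( hash_alg )`, so callers pass `hash_alg`
directly.

Also size the binder key and early secret buffers with `PSA_MAC_MAX_SIZE`
instead of `MBEDTLS_MD_MAX_SIZE`.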

Signed-off-by: Gabor Mezei <gabor.mezei@arm.com>
diff --git a/library/ssl_tls13_invasive.h b/library/ssl_tls13_invasive.h
index 8a3a501..4e39f90 100644
--- a/library/ssl_tls13_invasive.h
+++ b/library/ssl_tls13_invasive.h
@@ -30,10 +30,7 @@
  *  \brief  Take the input keying material \p ikm and extract from it a
  *          fixed-length pseudorandom key \p prk.
  *
- *  \param       alg       The HMAC algorithm to use
- *                         (\c #PSA_ALG_HMAC( PSA_ALG_XXX ) value such that
- *                         PSA_ALG_XXX is a hash algorithm and
- *                         #PSA_ALG_IS_HMAC(\p alg) is true).
+ *  \param       hash_alg  Hash algorithm to use.
  *  \param       salt      An optional salt value (a non-secret random value);
  *                         if the salt is not provided, a string of all zeros
- *                         of the length of the hash provided by \p alg is used
+ *                         of the length of the hash provided by \p hash_alg is used
@@ -51,7 +48,7 @@
  *  \return An PSA_ERROR_* error for errors returned from the underlying
  *          PSA layer.
  */
-psa_status_t mbedtls_psa_hkdf_extract( psa_algorithm_t alg,
+psa_status_t mbedtls_psa_hkdf_extract( psa_algorithm_t hash_alg,
                                        const unsigned char *salt, size_t salt_len,
                                        const unsigned char *ikm, size_t ikm_len,
                                        unsigned char *prk, size_t prk_size,
@@ -61,9 +58,7 @@
  *  \brief  Expand the supplied \p prk into several additional pseudorandom
  *          keys, which is the output of the HKDF.
  *
- *  \param  alg       The HMAC algorithm to use (\c #PSA_ALG_HMAC( PSA_ALG_XXX )
- *                    value such that PSA_ALG_XXX is a hash algorithm and
- *                    #PSA_ALG_IS_HMAC(\p alg) is true).
+ *  \param  hash_alg  Hash algorithm to use.
  *  \param  prk       A pseudorandom key of \p prk_len bytes. \p prk is
  *                    usually the output from the HKDF extract step.
  *  \param  prk_len   The length in bytes of \p prk.
@@ -80,7 +75,7 @@
  *  \return An PSA_ERROR_* error for errors returned from the underlying
  *          PSA layer.
  */
-psa_status_t mbedtls_psa_hkdf_expand( psa_algorithm_t alg,
+psa_status_t mbedtls_psa_hkdf_expand( psa_algorithm_t hash_alg,
                                       const unsigned char *prk, size_t prk_len,
                                       const unsigned char *info, size_t info_len,
                                       unsigned char *okm, size_t okm_len );
diff --git a/library/ssl_tls13_keys.c b/library/ssl_tls13_keys.c
index 2ce654b..5c851c7 100644
--- a/library/ssl_tls13_keys.c
+++ b/library/ssl_tls13_keys.c
@@ -137,7 +137,7 @@
 }
 
 MBEDTLS_STATIC_TESTABLE
-psa_status_t mbedtls_psa_hkdf_extract( psa_algorithm_t alg,
+psa_status_t mbedtls_psa_hkdf_extract( psa_algorithm_t hash_alg,
                                        const unsigned char *salt, size_t salt_len,
                                        const unsigned char *ikm, size_t ikm_len,
                                        unsigned char *prk, size_t prk_size,
@@ -148,6 +148,7 @@
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_status_t destroy_status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_algorithm_t alg = PSA_ALG_HMAC( hash_alg );
 
     if( salt == NULL || salt_len == 0 )
     {
@@ -190,7 +191,7 @@
 }
 
 MBEDTLS_STATIC_TESTABLE
-psa_status_t mbedtls_psa_hkdf_expand( psa_algorithm_t alg,
+psa_status_t mbedtls_psa_hkdf_expand( psa_algorithm_t hash_alg,
                                       const unsigned char *prk, size_t prk_len,
                                       const unsigned char *info, size_t info_len,
                                       unsigned char *okm, size_t okm_len )
@@ -206,6 +207,7 @@
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_status_t destroy_status = PSA_ERROR_CORRUPTION_DETECTED;
     unsigned char t[PSA_MAC_MAX_SIZE];
+    psa_algorithm_t alg = PSA_ALG_HMAC( hash_alg );
 
     if( okm == NULL )
     {
@@ -350,7 +352,7 @@
                                  &hkdf_label_len );
 
     return( psa_ssl_status_to_mbedtls(
-                mbedtls_psa_hkdf_expand( PSA_ALG_HMAC( hash_alg ),
+                mbedtls_psa_hkdf_expand( hash_alg,
                                          secret, secret_len,
                                          hkdf_label, hkdf_label_len,
                                          buf, buf_len ) ) );
@@ -521,7 +523,7 @@
      * The salt is the old secret, and the input key material
      * is the input secret (PSK / ECDHE). */
     ret = psa_ssl_status_to_mbedtls(
-            mbedtls_psa_hkdf_extract( PSA_ALG_HMAC( hash_alg ),
+            mbedtls_psa_hkdf_extract( hash_alg,
                                       tmp_secret, hlen,
                                       tmp_input, ilen,
                                       secret_new, hlen, &secret_len ) );
@@ -914,8 +916,8 @@
                                unsigned char *result )
 {
     int ret = 0;
-    unsigned char binder_key[MBEDTLS_MD_MAX_SIZE];
-    unsigned char early_secret[MBEDTLS_MD_MAX_SIZE];
+    unsigned char binder_key[PSA_MAC_MAX_SIZE];
+    unsigned char early_secret[PSA_MAC_MAX_SIZE];
     size_t const hash_len = PSA_HASH_LENGTH( hash_alg );
     size_t actual_len;