Tune API of internal function mgf_mask in RSA

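mgf_mask() now takes the hash algorithm (mbedtls_md_type_t) instead of a
caller-prepared MD context, and sets up and frees its own MD context
internally. This is a first step towards making a version of this
function that uses PSA when MD is not available.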

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/library/rsa.c b/library/rsa.c
index 17a7d9e..74390af 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -1095,11 +1095,13 @@
  * \param dlen      length of destination buffer
  * \param src       source of the mask generation
  * \param slen      length of the source buffer
- * \param md_ctx    message digest context to use
+ * \param md_alg    message digest to use
  */
 static int mgf_mask( unsigned char *dst, size_t dlen, unsigned char *src,
-                      size_t slen, mbedtls_md_context_t *md_ctx )
+                      size_t slen, mbedtls_md_type_t md_alg )
 {
+    const mbedtls_md_info_t *md_info;
+    mbedtls_md_context_t md_ctx;
     unsigned char mask[MBEDTLS_MD_MAX_SIZE];
     unsigned char counter[4];
     unsigned char *p;
@@ -1107,10 +1109,19 @@
     size_t i, use_len;
     int ret = 0;
 
+    mbedtls_md_init( &md_ctx );
     memset( mask, 0, MBEDTLS_MD_MAX_SIZE );
     memset( counter, 0, 4 );
 
-    hlen = mbedtls_md_get_size( md_ctx->md_info );
+    md_info = mbedtls_md_info_from_type( md_alg );
+    if( md_info == NULL )
+        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    mbedtls_md_init( &md_ctx );
+    if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 )
+        goto exit;
+
+    hlen = mbedtls_md_get_size( md_info );
 
     /* Generate and apply dbMask */
     p = dst;
@@ -1121,13 +1132,13 @@
         if( dlen < hlen )
             use_len = dlen;
 
-        if( ( ret = mbedtls_md_starts( md_ctx ) ) != 0 )
+        if( ( ret = mbedtls_md_starts( &md_ctx ) ) != 0 )
             goto exit;
-        if( ( ret = mbedtls_md_update( md_ctx, src, slen ) ) != 0 )
+        if( ( ret = mbedtls_md_update( &md_ctx, src, slen ) ) != 0 )
             goto exit;
-        if( ( ret = mbedtls_md_update( md_ctx, counter, 4 ) ) != 0 )
+        if( ( ret = mbedtls_md_update( &md_ctx, counter, 4 ) ) != 0 )
             goto exit;
-        if( ( ret = mbedtls_md_finish( md_ctx, mask ) ) != 0 )
+        if( ( ret = mbedtls_md_finish( &md_ctx, mask ) ) != 0 )
             goto exit;
 
         for( i = 0; i < use_len; ++i )
@@ -1139,6 +1150,7 @@
     }
 
 exit:
+    mbedtls_md_free( &md_ctx );
     mbedtls_platform_zeroize( mask, sizeof( mask ) );
 
     return( ret );
@@ -1208,12 +1220,12 @@
 
     /* maskedDB: Apply dbMask to DB */
     if( ( ret = mgf_mask( output + hlen + 1, olen - hlen - 1, output + 1, hlen,
-                          &md_ctx ) ) != 0 )
+                          ctx->hash_id ) ) != 0 )
         goto exit;
 
     /* maskedSeed: Apply seedMask to seed */
     if( ( ret = mgf_mask( output + 1, hlen, output + hlen + 1, olen - hlen - 1,
-                          &md_ctx ) ) != 0 )
+                          ctx->hash_id ) ) != 0 )
         goto exit;
 
 exit:
@@ -1384,10 +1396,10 @@
 
     /* seed: Apply seedMask to maskedSeed */
     if( ( ret = mgf_mask( buf + 1, hlen, buf + hlen + 1, ilen - hlen - 1,
-                          &md_ctx ) ) != 0 ||
+                          ctx->hash_id ) ) != 0 ||
     /* DB: Apply dbMask to maskedDB */
         ( ret = mgf_mask( buf + hlen + 1, ilen - hlen - 1, buf + 1, hlen,
-                          &md_ctx ) ) != 0 )
+                          ctx->hash_id ) ) != 0 )
     {
         mbedtls_md_free( &md_ctx );
         goto cleanup;
@@ -1648,7 +1660,7 @@
 
     /* maskedDB: Apply dbMask to DB */
     if( ( ret = mgf_mask( sig + offset, olen - hlen - 1 - offset, p, hlen,
-                          &md_ctx ) ) != 0 )
+                          ctx->hash_id ) ) != 0 )
         goto exit;
 
     msb = mbedtls_mpi_bitlen( &ctx->N ) - 1;
@@ -2029,7 +2041,7 @@
     if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 )
         goto exit;
 
-    ret = mgf_mask( p, siglen - hlen - 1, hash_start, hlen, &md_ctx );
+    ret = mgf_mask( p, siglen - hlen - 1, hash_start, hlen, mgf1_hash_id );
     if( ret != 0 )
         goto exit;