Extract common code into hash_mprime()

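Both RSASSA-PSS signing and verification compute Hash(M') (RFC 8017
section 9.1); move that shared code into a single helper. This will
also make it easier to provide a PSA-based version for builds where
the MD module is not available.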

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/library/rsa.c b/library/rsa.c
index ead5b74..af1a9d5 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -1155,6 +1155,48 @@
 
     return( ret );
 }
+
+/**
+ * Generate H = Hash(M') as in RFC 8017 section 9.1.1 steps 5 and 6,
+ * where M' = 0x00 00 00 00 00 00 00 00 || hash || salt.
+ * \param hash      the input hash
+ * \param hlen      length of the input hash
+ * \param salt      the input salt
+ * \param slen      length of the input salt
+ * \param out       the output buffer - must be large enough for \c md_alg
+ * \param md_alg    message digest to use
+ * \return          0 on success, MBEDTLS_ERR_RSA_BAD_INPUT_DATA if no digest
+ *                  info exists for \c md_alg, or an error from the MD layer.
+ */
+static int hash_mprime( const unsigned char *hash, size_t hlen,
+                        const unsigned char *salt, size_t slen,
+                        unsigned char *out, mbedtls_md_type_t md_alg )
+{
+    const unsigned char zeros[8] = { 0, 0, 0, 0, 0, 0, 0, 0 };
+    mbedtls_md_context_t md_ctx;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    const mbedtls_md_info_t *md_info = mbedtls_md_info_from_type( md_alg );
+    if( md_info == NULL )
+        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    mbedtls_md_init( &md_ctx );
+    if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 )
+        goto exit;
+    if( ( ret = mbedtls_md_starts( &md_ctx ) ) != 0 )
+        goto exit;
+    if( ( ret = mbedtls_md_update( &md_ctx, zeros, sizeof( zeros ) ) ) != 0 )
+        goto exit;
+    if( ( ret = mbedtls_md_update( &md_ctx, hash, hlen ) ) != 0 )
+        goto exit;
+    if( ( ret = mbedtls_md_update( &md_ctx, salt, slen ) ) != 0 )
+        goto exit;
+    if( ( ret = mbedtls_md_finish( &md_ctx, out ) ) != 0 )
+        goto exit;
+
+exit:
+    mbedtls_md_free( &md_ctx );
+    return( ret );
+}
 #endif /* MBEDTLS_PKCS1_V21 */
 
 #if defined(MBEDTLS_PKCS1_V21)
@@ -1544,7 +1586,7 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     size_t msb;
     const mbedtls_md_info_t *md_info;
-    mbedtls_md_context_t md_ctx;
+
     RSA_VALIDATE_RET( ctx != NULL );
     RSA_VALIDATE_RET( ( md_alg  == MBEDTLS_MD_NONE &&
                         hashlen == 0 ) ||
@@ -1616,30 +1658,20 @@
 
     p += slen;
 
-    mbedtls_md_init( &md_ctx );
-    if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 )
-        goto exit;
-
     /* Generate H = Hash( M' ) */
-    if( ( ret = mbedtls_md_starts( &md_ctx ) ) != 0 )
-        goto exit;
-    if( ( ret = mbedtls_md_update( &md_ctx, p, 8 ) ) != 0 )
-        goto exit;
-    if( ( ret = mbedtls_md_update( &md_ctx, hash, hashlen ) ) != 0 )
-        goto exit;
-    if( ( ret = mbedtls_md_update( &md_ctx, salt, slen ) ) != 0 )
-        goto exit;
-    if( ( ret = mbedtls_md_finish( &md_ctx, p ) ) != 0 )
-        goto exit;
+    ret = hash_mprime( hash, hashlen, salt, slen, p, ctx->hash_id );
+    if( ret != 0 )
+        return( ret );
 
     /* Compensate for boundary condition when applying mask */
     if( msb % 8 == 0 )
         offset = 1;
 
     /* maskedDB: Apply dbMask to DB */
-    if( ( ret = mgf_mask( sig + offset, olen - hlen - 1 - offset, p, hlen,
-                          ctx->hash_id ) ) != 0 )
-        goto exit;
+    ret = mgf_mask( sig + offset, olen - hlen - 1 - offset, p, hlen,
+                    ctx->hash_id );
+    if( ret != 0 )
+        return( ret );
 
     msb = mbedtls_mpi_bitlen( &ctx->N ) - 1;
     sig[0] &= 0xFF >> ( olen * 8 - msb );
@@ -1647,12 +1679,6 @@
     p += hlen;
     *p++ = 0xBC;
 
-exit:
-    mbedtls_md_free( &md_ctx );
-
-    if( ret != 0 )
-        return( ret );
-
     return mbedtls_rsa_private( ctx, f_rng, p_rng, sig, sig );
 }
 
@@ -1949,11 +1975,9 @@
     unsigned char *p;
     unsigned char *hash_start;
     unsigned char result[MBEDTLS_MD_MAX_SIZE];
-    unsigned char zeros[8];
     unsigned int hlen;
     size_t observed_salt_len, msb;
     const mbedtls_md_info_t *md_info;
-    mbedtls_md_context_t md_ctx;
     unsigned char buf[MBEDTLS_MPI_MAX_SIZE] = {0};
 
     RSA_VALIDATE_RET( ctx != NULL );
@@ -1994,8 +2018,6 @@
 
     hlen = mbedtls_md_get_size( md_info );
 
-    memset( zeros, 0, 8 );
-
     /*
      * Note: EMSA-PSS verification is over the length of N - 1 bits
      */
@@ -2038,36 +2060,15 @@
     /*
      * Generate H = Hash( M' )
      */
-    mbedtls_md_init( &md_ctx );
-    if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 0 ) ) != 0 )
-        goto exit;
-
-    ret = mbedtls_md_starts( &md_ctx );
-    if ( ret != 0 )
-        goto exit;
-    ret = mbedtls_md_update( &md_ctx, zeros, 8 );
-    if ( ret != 0 )
-        goto exit;
-    ret = mbedtls_md_update( &md_ctx, hash, hashlen );
-    if ( ret != 0 )
-        goto exit;
-    ret = mbedtls_md_update( &md_ctx, p, observed_salt_len );
-    if ( ret != 0 )
-        goto exit;
-    ret = mbedtls_md_finish( &md_ctx, result );
-    if ( ret != 0 )
-        goto exit;
+    ret = hash_mprime( hash, hashlen, p, observed_salt_len,
+                       result, mgf1_hash_id );
+    if( ret != 0 )
+        return( ret );
 
     if( memcmp( hash_start, result, hlen ) != 0 )
-    {
-        ret = MBEDTLS_ERR_RSA_VERIFY_FAILED;
-        goto exit;
-    }
+        return( MBEDTLS_ERR_RSA_VERIFY_FAILED );
 
-exit:
-    mbedtls_md_free( &md_ctx );
-
-    return( ret );
+    return( 0 );
 }
 
 /*