psa: Move RSA sign/verify hash to the PSA RSA specific file

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
diff --git a/library/psa_crypto_rsa.c b/library/psa_crypto_rsa.c
index fa64001..2886146 100644
--- a/library/psa_crypto_rsa.c
+++ b/library/psa_crypto_rsa.c
@@ -319,6 +319,212 @@
 }
 #endif /* defined(BUILTIN_KEY_TYPE_RSA_KEY_PAIR) */
 
+/****************************************************************/
+/* Sign/verify hashes */
+/****************************************************************/
+
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
+    defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
+
+/* Decode the hash algorithm from alg and store the mbedtls encoding in
+ * md_alg. Verify that the hash length is acceptable: it must fit in an
+ * unsigned int and, whenever alg involves a hash (hashed PKCS#1 v1.5 or
+ * PSS), it must equal the digest size of that hash. */
+static psa_status_t psa_rsa_decode_md_type( psa_algorithm_t alg,
+                                            size_t hash_length,
+                                            mbedtls_md_type_t *md_alg )
+{
+    psa_algorithm_t hash_alg = PSA_ALG_SIGN_GET_HASH( alg );
+    const mbedtls_md_info_t *md_info = mbedtls_md_info_from_psa( hash_alg );
+    int hash_required = 0;
+    *md_alg = mbedtls_md_get_type( md_info );
+
+    /* The Mbed TLS RSA module uses an unsigned int for hash length
+     * parameters. Validate that it fits so that we don't risk an
+     * overflow later. */
+#if SIZE_MAX > UINT_MAX
+    if( hash_length > UINT_MAX )
+        return( PSA_ERROR_INVALID_ARGUMENT );
+#endif
+
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
+    /* PKCS#1 v1.5 signs a hash unless the raw variant is requested. */
+    if( PSA_ALG_IS_RSA_PKCS1V15_SIGN( alg ) &&
+        alg != PSA_ALG_RSA_PKCS1V15_SIGN_RAW )
+        hash_required = 1;
+#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
+    /* PSS requires a hash internally. */
+    if( PSA_ALG_IS_RSA_PSS( alg ) )
+        hash_required = 1;
+#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
+
+    /* The hash must be available and the input exactly one digest long,
+     * for PSS as well as for hashed PKCS#1 v1.5. */
+    if( hash_required && md_info == NULL )
+        return( PSA_ERROR_NOT_SUPPORTED );
+    if( hash_required && mbedtls_md_get_size( md_info ) != hash_length )
+        return( PSA_ERROR_INVALID_ARGUMENT );
+
+    return( PSA_SUCCESS );
+}
+
+/* Sign a hash with an RSA key, using PKCS#1 v1.5 or PSS as selected by alg.
+ *
+ * The RSA context is loaded from key_buffer with
+ * mbedtls_psa_rsa_load_representation() and released before returning.
+ * On success, *signature_length is set to mbedtls_rsa_get_len() of the key.
+ * Returns PSA_ERROR_BUFFER_TOO_SMALL if signature_size is below that length
+ * and PSA_ERROR_INVALID_ARGUMENT if alg is not handled by this build. */
+psa_status_t mbedtls_psa_rsa_sign_hash(
+    const psa_key_attributes_t *attributes,
+    const uint8_t *key_buffer, size_t key_buffer_size,
+    psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
+    uint8_t *signature, size_t signature_size, size_t *signature_length )
+{
+    mbedtls_rsa_context *rsa = NULL;
+    mbedtls_md_type_t md_alg;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
+    status = mbedtls_psa_rsa_load_representation(
+                 attributes->core.type, key_buffer, key_buffer_size, &rsa );
+    if( status != PSA_SUCCESS )
+        return( status );
+
+    status = psa_rsa_decode_md_type( alg, hash_length, &md_alg );
+    if( status != PSA_SUCCESS )
+        goto exit;
+
+    if( mbedtls_rsa_get_len( rsa ) > signature_size )
+    {
+        status = PSA_ERROR_BUFFER_TOO_SMALL;
+        goto exit;
+    }
+
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
+    if( PSA_ALG_IS_RSA_PKCS1V15_SIGN( alg ) )
+    {
+        /* md_alg is given to the signing call, not to the padding. */
+        mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+        ret = mbedtls_rsa_pkcs1_sign(
+                  rsa, mbedtls_psa_get_random, MBEDTLS_PSA_RANDOM_STATE,
+                  MBEDTLS_RSA_PRIVATE, md_alg,
+                  (unsigned int) hash_length, hash, signature );
+    }
+    else
+#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
+    if( PSA_ALG_IS_RSA_PSS( alg ) )
+    {
+        /* md_alg is given to the padding, not to the signing call. */
+        mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V21, md_alg );
+        ret = mbedtls_rsa_rsassa_pss_sign(
+                  rsa, mbedtls_psa_get_random, MBEDTLS_PSA_RANDOM_STATE,
+                  MBEDTLS_RSA_PRIVATE, MBEDTLS_MD_NONE,
+                  (unsigned int) hash_length, hash, signature );
+    }
+    else
+#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
+    {
+        status = PSA_ERROR_INVALID_ARGUMENT;
+        goto exit;
+    }
+
+    /* Only report a length when the signing call succeeded. */
+    status = mbedtls_to_psa_error( ret );
+    if( ret == 0 )
+        *signature_length = mbedtls_rsa_get_len( rsa );
+
+exit:
+    /* Release the RSA context obtained from the load call. */
+    mbedtls_rsa_free( rsa );
+    mbedtls_free( rsa );
+
+    return( status );
+}
+
+/* Verify an RSA signature of a hash, using PKCS#1 v1.5 or PSS as selected
+ * by alg.
+ *
+ * The RSA context is loaded from key_buffer with
+ * mbedtls_psa_rsa_load_representation() and released before returning.
+ * Returns PSA_ERROR_INVALID_SIGNATURE if signature_length differs from
+ * mbedtls_rsa_get_len() of the key or if the padding is invalid, and
+ * PSA_ERROR_INVALID_ARGUMENT if alg is not handled by this build. */
+psa_status_t mbedtls_psa_rsa_verify_hash(
+    const psa_key_attributes_t *attributes,
+    const uint8_t *key_buffer, size_t key_buffer_size,
+    psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
+    const uint8_t *signature, size_t signature_length )
+{
+    mbedtls_rsa_context *rsa = NULL;
+    mbedtls_md_type_t md_alg;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
+    status = mbedtls_psa_rsa_load_representation(
+                 attributes->core.type, key_buffer, key_buffer_size, &rsa );
+    if( status != PSA_SUCCESS )
+        goto exit;
+
+    status = psa_rsa_decode_md_type( alg, hash_length, &md_alg );
+    if( status != PSA_SUCCESS )
+        goto exit;
+
+    if( mbedtls_rsa_get_len( rsa ) != signature_length )
+    {
+        status = PSA_ERROR_INVALID_SIGNATURE;
+        goto exit;
+    }
+
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
+    if( PSA_ALG_IS_RSA_PKCS1V15_SIGN( alg ) )
+    {
+        /* md_alg is given to the verification call, not to the padding. */
+        mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+        ret = mbedtls_rsa_pkcs1_verify(
+                  rsa, mbedtls_psa_get_random, MBEDTLS_PSA_RANDOM_STATE,
+                  MBEDTLS_RSA_PUBLIC, md_alg,
+                  (unsigned int) hash_length, hash, signature );
+    }
+    else
+#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
+    if( PSA_ALG_IS_RSA_PSS( alg ) )
+    {
+        /* md_alg is given to the padding, not to the verification call. */
+        mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V21, md_alg );
+        ret = mbedtls_rsa_rsassa_pss_verify(
+                  rsa, mbedtls_psa_get_random, MBEDTLS_PSA_RANDOM_STATE,
+                  MBEDTLS_RSA_PUBLIC, MBEDTLS_MD_NONE,
+                  (unsigned int) hash_length, hash, signature );
+    }
+    else
+#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
+    {
+        status = PSA_ERROR_INVALID_ARGUMENT;
+        goto exit;
+    }
+
+    /* Mbed TLS reports "invalid padding" separately from "valid padding
+     * but otherwise invalid signature". PSA makes no such distinction, so
+     * both end up as an invalid signature. */
+    if( ret == MBEDTLS_ERR_RSA_INVALID_PADDING )
+        status = PSA_ERROR_INVALID_SIGNATURE;
+    else
+        status = mbedtls_to_psa_error( ret );
+
+exit:
+    mbedtls_rsa_free( rsa );
+    mbedtls_free( rsa );
+
+    return( status );
+}
+
+#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
+        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) */
+
 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) || \
     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)