Add USE_PSA version of PK test functions

While at it, also fix buffer size for functions that already depend on
USE_PSA: it should be PSA_HASH_MAX_SIZE for functions that always use
PSA, and the new macro MBEDTLS_USE_PSA_MD_MAX_SIZE for functions that
use it or not depending on USE_PSA.

The only case where MBEDTLS_MD_MAX_SIZE is OK is when the function
always uses MD - currently this is the case with
pk_sign_verify_restart() as it is incompatible with USE_PSA anyway.

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/library/use_psa_helpers.h b/library/use_psa_helpers.h
index 6b63ce8..e9a1335 100644
--- a/library/use_psa_helpers.h
+++ b/library/use_psa_helpers.h
@@ -1,8 +1,10 @@
 /**
  *  Internal macros for parts of the code governed by MBEDTLS_USE_PSA_CRYPTO.
- *  These macros allow checking if an algorithm is available, either via the
- *  legacy API or the PSA Crypto API, depending on MBEDTLS_USE_PSA_CRYPTO.
- *  When possible, they're named after the corresponding PSA_WANT_ macro.
+ *  Some macros allow checking if an algorithm is available, either via the
+ *  legacy API or the PSA Crypto API, depending on MBEDTLS_USE_PSA_CRYPTO;
+ *  when possible, they're named after the corresponding PSA_WANT_ macro.
+ *  Other macros provide max sizes or similar information in a USE_PSA-aware
+ *  way; they're name after a similar constant from the legacy API or PSA.
  *
  *  Copyright The Mbed TLS Contributors
  *  SPDX-License-Identifier: Apache-2.0
@@ -55,4 +57,11 @@
 #define MBEDTLS_USE_PSA_WANT_ALG_SHA_512
 #endif
 
+/* Hash information */
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#define MBEDTLS_USE_PSA_MD_MAX_SIZE PSA_HASH_MAX_SIZE
+#else
+#define MBEDTLS_USE_PSA_MD_MAX_SIZE MBEDTLS_MD_MAX_SIZE
+#endif
+
 #endif /* MBEDTLS_USE_PSA_HELPERS_H */
diff --git a/tests/suites/test_suite_pk.function b/tests/suites/test_suite_pk.function
index 78338d7..7fe4594 100644
--- a/tests/suites/test_suite_pk.function
+++ b/tests/suites/test_suite_pk.function
@@ -8,6 +8,7 @@
 #include "mbedtls/rsa.h"
 
 #include "md_internal.h"
+#include "use_psa_helpers.h"
 
 #include <limits.h>
 #include <stdint.h>
@@ -481,7 +482,7 @@
                              char * input_E, data_t * result_str,
                              int result )
 {
-    unsigned char hash_result[MBEDTLS_MD_MAX_SIZE];
+    unsigned char hash_result[MBEDTLS_USE_PSA_MD_MAX_SIZE];
     mbedtls_rsa_context *rsa;
     mbedtls_pk_context pk;
     mbedtls_pk_restart_ctx *rs_ctx = NULL;
@@ -498,7 +499,7 @@
 
     mbedtls_pk_init( &pk );
 
-    memset( hash_result, 0x00, MBEDTLS_MD_MAX_SIZE );
+    memset( hash_result, 0x00, sizeof( hash_result ) );
 
     TEST_ASSERT( mbedtls_pk_setup( &pk, mbedtls_pk_info_from_type( MBEDTLS_PK_RSA ) ) == 0 );
     rsa = mbedtls_pk_rsa( pk );
@@ -508,8 +509,20 @@
     TEST_ASSERT( mbedtls_test_read_mpi( &rsa->E, radix_E, input_E ) == 0 );
 
 
-    if( mbedtls_md_info_from_type( digest ) != NULL )
-        TEST_ASSERT( mbedtls_md( mbedtls_md_info_from_type( digest ), message_str->x, message_str->len, hash_result ) == 0 );
+    if( digest != MBEDTLS_MD_NONE )
+    {
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        size_t hash_len;
+        psa_algorithm_t hash_alg = mbedtls_psa_translate_md( digest );
+        TEST_EQUAL( PSA_SUCCESS, psa_hash_compute( hash_alg,
+                        message_str->x, message_str->len,
+                        hash_result, sizeof( hash_result ), &hash_len ) );
+
+#else
+        TEST_EQUAL( 0, mbedtls_md( mbedtls_md_info_from_type( digest ),
+                            message_str->x, message_str->len, hash_result ) );
+#endif
+    }
 
     TEST_ASSERT( mbedtls_pk_verify( &pk, digest, hash_result, 0,
                             result_str->x, mbedtls_pk_get_len( &pk ) ) == result );
@@ -534,7 +547,7 @@
                                  int mgf1_hash_id, int salt_len, int sig_len,
                                  int result )
 {
-    unsigned char hash_result[MBEDTLS_MD_MAX_SIZE];
+    unsigned char hash_result[MBEDTLS_USE_PSA_MD_MAX_SIZE];
     mbedtls_rsa_context *rsa;
     mbedtls_pk_context pk;
     mbedtls_pk_rsassa_pss_options pss_opts;
@@ -557,10 +570,17 @@
 
     if( digest != MBEDTLS_MD_NONE )
     {
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        psa_algorithm_t hash_alg = mbedtls_psa_translate_md( digest );
+        TEST_EQUAL( PSA_SUCCESS, psa_hash_compute( hash_alg,
+                    message_str->x, message_str->len,
+                    hash_result, sizeof( hash_result ), &hash_len ) );
+#else
         const mbedtls_md_info_t *md_info = mbedtls_md_info_from_type( digest );
         TEST_ASSERT( mbedtls_md( md_info, message_str->x, message_str->len,
                                  hash_result ) == 0 );
         hash_len = mbedtls_md_get_size( md_info );
+#endif
     }
     else
     {
@@ -1317,7 +1337,7 @@
     mbedtls_pk_context pk;
     size_t sig_len;
     unsigned char sig[MBEDTLS_PK_SIGNATURE_MAX_SIZE];
-    unsigned char hash[MBEDTLS_MD_MAX_SIZE];
+    unsigned char hash[PSA_HASH_MAX_SIZE];
     size_t hash_len = mbedtls_md_internal_get_size( md_alg );
     void const *options = NULL;
     mbedtls_pk_rsassa_pss_options rsassa_pss_options;
@@ -1361,7 +1381,7 @@
     unsigned char sig[MBEDTLS_PK_SIGNATURE_MAX_SIZE];
     unsigned char pkey[PSA_EXPORT_PUBLIC_KEY_MAX_SIZE];
     unsigned char *pkey_start;
-    unsigned char hash[MBEDTLS_MD_MAX_SIZE];
+    unsigned char hash[PSA_HASH_MAX_SIZE];
     psa_algorithm_t psa_md_alg = mbedtls_psa_translate_md( md_alg );
     psa_algorithm_t psa_alg;
     size_t hash_len = PSA_HASH_LENGTH( psa_md_alg );