Merge pull request #5771 from superna9999/5761-rsa-decrypt-rework-pk-wrap-as-opaque

RSA decrypt 0: Rework `mbedtls_pk_wrap_as_opaque()`
diff --git a/include/mbedtls/pk.h b/include/mbedtls/pk.h
index dc808e8..5225c57 100644
--- a/include/mbedtls/pk.h
+++ b/include/mbedtls/pk.h
@@ -922,28 +922,24 @@
  * \warning         This is a temporary utility function for tests. It might
  *                  change or be removed at any time without notice.
  *
- * \note            ECDSA & RSA keys are supported.
- *                  For both key types, signing with the specified hash
- *                  is the only allowed use of that key with PK API.
- *                  The RSA key supports RSA-PSS signing with the specified
- *                  hash with the PK EXT API.
- *                  In addition, the ECDSA key is also allowed for ECDH key
- *                  agreement derivation operation using the PSA API.
- *
  * \param pk        Input: the EC or RSA key to import to a PSA key.
  *                  Output: a PK context wrapping that PSA key.
  * \param key       Output: a PSA key identifier.
  *                  It's the caller's responsibility to call
  *                  psa_destroy_key() on that key identifier after calling
  *                  mbedtls_pk_free() on the PK context.
- * \param hash_alg  The hash algorithm to allow for use with that key.
+ * \param alg       The primary PSA algorithm to allow for use with that key.
+ * \param usage     The PSA key usage flags to set on that key.
+ * \param alg2      The secondary algorithm to allow, or PSA_ALG_NONE for none.
  *
  * \return          \c 0 if successful.
  * \return          An Mbed TLS error code otherwise.
  */
 int mbedtls_pk_wrap_as_opaque( mbedtls_pk_context *pk,
                                mbedtls_svc_key_id_t *key,
-                               psa_algorithm_t hash_alg );
+                               psa_algorithm_t alg,
+                               psa_key_usage_t usage,
+                               psa_algorithm_t alg2 );
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 #ifdef __cplusplus
diff --git a/library/pk.c b/library/pk.c
index bba2ef7..42ff432 100644
--- a/library/pk.c
+++ b/library/pk.c
@@ -720,12 +720,16 @@
  */
 int mbedtls_pk_wrap_as_opaque( mbedtls_pk_context *pk,
                                mbedtls_svc_key_id_t *key,
-                               psa_algorithm_t hash_alg )
+                               psa_algorithm_t alg,
+                               psa_key_usage_t usage,
+                               psa_algorithm_t alg2 )
 {
 #if !defined(MBEDTLS_ECP_C) && !defined(MBEDTLS_RSA_C)
     ((void) pk);
     ((void) key);
-    ((void) hash_alg);
+    ((void) alg);
+    ((void) usage);
+    ((void) alg2);
 #else
 #if defined(MBEDTLS_ECP_C)
     if( mbedtls_pk_get_type( pk ) == MBEDTLS_PK_ECKEY )
@@ -752,10 +756,10 @@
         /* prepare the key attributes */
         psa_set_key_type( &attributes, key_type );
         psa_set_key_bits( &attributes, bits );
-        psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_SIGN_HASH |
-                                              PSA_KEY_USAGE_DERIVE);
-        psa_set_key_algorithm( &attributes, PSA_ALG_ECDSA( hash_alg ) );
-        psa_set_key_enrollment_algorithm( &attributes, PSA_ALG_ECDH );
+        psa_set_key_usage_flags( &attributes, usage );
+        psa_set_key_algorithm( &attributes, alg );
+        if( alg2 != PSA_ALG_NONE )
+            psa_set_key_enrollment_algorithm( &attributes, alg2 );
 
         /* import private key into PSA */
         status = psa_import_key( &attributes, d, d_len, key );
@@ -786,11 +790,10 @@
         /* prepare the key attributes */
         psa_set_key_type( &attributes, PSA_KEY_TYPE_RSA_KEY_PAIR );
         psa_set_key_bits( &attributes, mbedtls_pk_get_bitlen( pk ) );
-        psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_SIGN_HASH );
-        psa_set_key_algorithm( &attributes,
-                               PSA_ALG_RSA_PKCS1V15_SIGN( hash_alg ) );
-        psa_set_key_enrollment_algorithm( &attributes,
-                                          PSA_ALG_RSA_PSS( hash_alg ) );
+        psa_set_key_usage_flags( &attributes, usage );
+        psa_set_key_algorithm( &attributes, alg );
+        if( alg2 != PSA_ALG_NONE )
+            psa_set_key_enrollment_algorithm( &attributes, alg2 );
 
         /* import private key into PSA */
         status = psa_import_key( &attributes,
diff --git a/programs/ssl/ssl_client2.c b/programs/ssl/ssl_client2.c
index 2f33d8f..f741d99 100644
--- a/programs/ssl/ssl_client2.c
+++ b/programs/ssl/ssl_client2.c
@@ -1698,8 +1698,22 @@
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( opt.key_opaque != 0 )
     {
-        if( ( ret = mbedtls_pk_wrap_as_opaque( &pkey, &key_slot,
-                                               PSA_ALG_ANY_HASH ) ) != 0 )
+        psa_algorithm_t psa_alg, psa_alg2;
+
+        if( mbedtls_pk_get_type( &pkey ) == MBEDTLS_PK_ECKEY )
+        {
+            psa_alg = PSA_ALG_ECDSA( PSA_ALG_ANY_HASH );
+            psa_alg2 = PSA_ALG_NONE;
+        }
+        else
+        {
+            psa_alg = PSA_ALG_RSA_PKCS1V15_SIGN( PSA_ALG_ANY_HASH );
+            psa_alg2 = PSA_ALG_RSA_PSS( PSA_ALG_ANY_HASH );
+        }
+
+        if( ( ret = mbedtls_pk_wrap_as_opaque( &pkey, &key_slot, psa_alg,
+                                               PSA_KEY_USAGE_SIGN_HASH,
+                                               psa_alg2 ) ) != 0 )
         {
             mbedtls_printf( " failed\n  !  "
                             "mbedtls_pk_wrap_as_opaque returned -0x%x\n\n", (unsigned int)  -ret );
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index 46c6b76..d728b95 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -2564,11 +2564,29 @@
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( opt.key_opaque != 0 )
     {
+        psa_algorithm_t psa_alg, psa_alg2;
+        psa_key_usage_t psa_usage;
+
         if ( mbedtls_pk_get_type( &pkey ) == MBEDTLS_PK_ECKEY ||
              mbedtls_pk_get_type( &pkey ) == MBEDTLS_PK_RSA )
         {
+            if( mbedtls_pk_get_type( &pkey ) == MBEDTLS_PK_ECKEY )
+            {
+                psa_alg = PSA_ALG_ECDSA( PSA_ALG_ANY_HASH );
+                psa_alg2 = PSA_ALG_ECDH;
+                psa_usage = PSA_KEY_USAGE_SIGN_HASH | PSA_KEY_USAGE_DERIVE;
+            }
+            else
+            {
+                psa_alg = PSA_ALG_RSA_PKCS1V15_SIGN( PSA_ALG_ANY_HASH );
+                psa_alg2 = PSA_ALG_NONE;
+                psa_usage = PSA_KEY_USAGE_SIGN_HASH;
+            }
+
             if( ( ret = mbedtls_pk_wrap_as_opaque( &pkey, &key_slot,
-                                                PSA_ALG_ANY_HASH ) ) != 0 )
+                                                   psa_alg,
+                                                   psa_usage,
+                                                   psa_alg2 ) ) != 0 )
             {
                 mbedtls_printf( " failed\n  !  "
                                 "mbedtls_pk_wrap_as_opaque returned -0x%x\n\n", (unsigned int)  -ret );
@@ -2579,8 +2597,23 @@
         if ( mbedtls_pk_get_type( &pkey2 ) == MBEDTLS_PK_ECKEY ||
              mbedtls_pk_get_type( &pkey2 ) == MBEDTLS_PK_RSA )
         {
+            if( mbedtls_pk_get_type( &pkey2 ) == MBEDTLS_PK_ECKEY )
+            {
+                psa_alg = PSA_ALG_ECDSA( PSA_ALG_ANY_HASH );
+                psa_alg2 = PSA_ALG_ECDH;
+                psa_usage = PSA_KEY_USAGE_SIGN_HASH | PSA_KEY_USAGE_DERIVE;
+            }
+            else
+            {
+                psa_alg = PSA_ALG_RSA_PKCS1V15_SIGN( PSA_ALG_ANY_HASH );
+                psa_alg2 = PSA_ALG_NONE;
+                psa_usage = PSA_KEY_USAGE_SIGN_HASH;
+            }
+
             if( ( ret = mbedtls_pk_wrap_as_opaque( &pkey2, &key_slot2,
-                                                PSA_ALG_ANY_HASH ) ) != 0 )
+                                                   psa_alg,
+                                                   psa_usage,
+                                                   psa_alg2 ) ) != 0 )
             {
                 mbedtls_printf( " failed\n  !  "
                                 "mbedtls_pk_wrap_as_opaque returned -0x%x\n\n", (unsigned int)  -ret );
diff --git a/tests/suites/test_suite_pk.function b/tests/suites/test_suite_pk.function
index 4b3af4c..32c2644 100644
--- a/tests/suites/test_suite_pk.function
+++ b/tests/suites/test_suite_pk.function
@@ -1080,6 +1080,7 @@
     unsigned char pkey_legacy[200];
     unsigned char pkey_psa[200];
     unsigned char *pkey_legacy_start, *pkey_psa_start;
+    psa_algorithm_t alg_psa;
     size_t sig_len, klen_legacy, klen_psa;
     int ret;
     mbedtls_svc_key_id_t key_id;
@@ -1107,6 +1108,7 @@
         TEST_ASSERT( mbedtls_rsa_gen_key( mbedtls_pk_rsa( pk ),
                         mbedtls_test_rnd_std_rand, NULL,
                         parameter_arg, 3 ) == 0 );
+        alg_psa = PSA_ALG_RSA_PKCS1V15_SIGN( PSA_ALG_SHA_256 );
     }
     else
 #endif /* MBEDTLS_RSA_C && MBEDTLS_GENPRIME */
@@ -1122,6 +1124,7 @@
         TEST_ASSERT( mbedtls_ecp_gen_key( grpid,
                         (mbedtls_ecp_keypair*) pk.pk_ctx,
                         mbedtls_test_rnd_std_rand, NULL ) == 0 );
+        alg_psa = PSA_ALG_ECDSA( PSA_ALG_SHA_256 );
     }
     else
 #endif /* MBEDTLS_ECDSA_C */
@@ -1139,8 +1142,9 @@
     pkey_legacy_start = pkey_legacy + sizeof( pkey_legacy ) - klen_legacy;
 
     /* Turn PK context into an opaque one. */
-    TEST_ASSERT( mbedtls_pk_wrap_as_opaque( &pk, &key_id,
-                                            PSA_ALG_SHA_256 ) == 0 );
+    TEST_ASSERT( mbedtls_pk_wrap_as_opaque( &pk, &key_id, alg_psa,
+                                            PSA_KEY_USAGE_SIGN_HASH,
+                                            PSA_ALG_NONE ) == 0 );
 
     PSA_ASSERT( psa_get_key_attributes( key_id, &attributes ) );
     TEST_EQUAL( psa_get_key_type( &attributes ), expected_type );
@@ -1241,6 +1245,7 @@
     unsigned char *pkey_start;
     unsigned char hash[MBEDTLS_MD_MAX_SIZE];
     psa_algorithm_t psa_md_alg = mbedtls_psa_translate_md( md_alg );
+    psa_algorithm_t psa_alg;
     size_t hash_len = PSA_HASH_LENGTH( psa_md_alg );
     const mbedtls_md_info_t *md_info = mbedtls_md_info_from_type( md_alg );
     void const *options = NULL;
@@ -1266,8 +1271,17 @@
     /* mbedtls_pk_write_pubkey_der() writes backwards in the data buffer. */
     pkey_start = pkey + sizeof( pkey ) - pkey_len;
 
+    if( key_pk_type == MBEDTLS_PK_RSA )
+        psa_alg = PSA_ALG_RSA_PKCS1V15_SIGN( psa_md_alg );
+    else if( key_pk_type == MBEDTLS_PK_RSASSA_PSS )
+        psa_alg = PSA_ALG_RSA_PSS( psa_md_alg );
+    else
+        TEST_ASSUME( ! "PK key type not supported in this configuration" );
+
     /* Turn PK context into an opaque one. */
-    TEST_EQUAL( mbedtls_pk_wrap_as_opaque( &pk, &key_id, psa_md_alg ), 0 );
+    TEST_EQUAL( mbedtls_pk_wrap_as_opaque( &pk, &key_id, psa_alg,
+                                           PSA_KEY_USAGE_SIGN_HASH,
+                                           PSA_ALG_NONE ), 0 );
 
     memset( hash, 0x2a, sizeof( hash ) );
     memset( sig, 0, sizeof( sig ) );
diff --git a/tests/suites/test_suite_x509write.function b/tests/suites/test_suite_x509write.function
index 485bbe2..f5001bd 100644
--- a/tests/suites/test_suite_x509write.function
+++ b/tests/suites/test_suite_x509write.function
@@ -170,7 +170,7 @@
 {
     mbedtls_pk_context key;
     mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
-    psa_algorithm_t md_alg_psa;
+    psa_algorithm_t md_alg_psa, alg_psa;
     mbedtls_x509write_csr req;
     unsigned char buf[4096];
     int ret;
@@ -187,7 +187,17 @@
     mbedtls_pk_init( &key );
     TEST_ASSERT( mbedtls_pk_parse_keyfile( &key, key_file, NULL,
                         mbedtls_test_rnd_std_rand, NULL ) == 0 );
-    TEST_ASSERT( mbedtls_pk_wrap_as_opaque( &key, &key_id, md_alg_psa ) == 0 );
+
+    if( mbedtls_pk_get_type( &key ) == MBEDTLS_PK_ECKEY )
+        alg_psa = PSA_ALG_ECDSA( md_alg_psa );
+    else if( mbedtls_pk_get_type( &key ) == MBEDTLS_PK_RSA )
+        alg_psa = PSA_ALG_RSA_PKCS1V15_SIGN( md_alg_psa );
+    else
+        TEST_ASSUME( ! "PK key type not supported in this configuration" );
+
+    TEST_ASSERT( mbedtls_pk_wrap_as_opaque( &key, &key_id, alg_psa,
+                                            PSA_KEY_USAGE_SIGN_HASH,
+                                            PSA_ALG_NONE ) == 0 );
 
     mbedtls_x509write_csr_init( &req );
     mbedtls_x509write_csr_set_md_alg( &req, md_type );
@@ -280,12 +290,21 @@
     /* For Opaque PK contexts, wrap key as an Opaque RSA context. */
     if( pk_wrap == 2 )
     {
-        psa_algorithm_t md_alg_psa =
-            mbedtls_psa_translate_md( (mbedtls_md_type_t) md_type );
+        psa_algorithm_t alg_psa, md_alg_psa;
 
+        md_alg_psa = mbedtls_psa_translate_md( (mbedtls_md_type_t) md_type );
         TEST_ASSERT( md_alg_psa != MBEDTLS_MD_NONE );
-        TEST_ASSERT( mbedtls_pk_wrap_as_opaque( &issuer_key, &key_id,
-                                                md_alg_psa ) == 0 );
+
+        if( mbedtls_pk_get_type( &issuer_key ) == MBEDTLS_PK_ECKEY )
+            alg_psa = PSA_ALG_ECDSA( md_alg_psa );
+        else if( mbedtls_pk_get_type( &issuer_key ) == MBEDTLS_PK_RSA )
+            alg_psa = PSA_ALG_RSA_PKCS1V15_SIGN( md_alg_psa );
+        else
+            TEST_ASSUME( ! "PK key type not supported in this configuration" );
+
+        TEST_ASSERT( mbedtls_pk_wrap_as_opaque( &issuer_key, &key_id, alg_psa,
+                                                PSA_KEY_USAGE_SIGN_HASH,
+                                                PSA_ALG_NONE ) == 0 );
     }
 #endif /* MBEDTLS_USE_PSA_CRYPTO */