Implement PSA-based PSK-to-MS derivation in mbedtls_ssl_derive_keys
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 4c0d0c1..7e861a5 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -607,6 +607,28 @@
 #endif
 #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
 
+#if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED) && \
+    defined(MBEDTLS_USE_PSA_CRYPTO)
+/* Return 1 if a nonzero opaque PSK applies to this handshake, else 0. */
+static int ssl_use_opaque_psk( mbedtls_ssl_context const *ssl )
+{
+    int opaque;
+    if( ssl->conf->f_psk == NULL )
+    {
+        /* No PSK callback: the static configuration decides. */
+        opaque = ( ssl->conf->psk_opaque != 0 );
+    }
+    else
+    {
+        /* A PSK callback is set, so the static configuration is
+         * irrelevant: only the handshake's PSK counts. */
+        opaque = ( ssl->handshake->psk_opaque != 0 );
+    }
+    return( opaque );
+}
+#endif /* MBEDTLS_USE_PSA_CRYPTO &&
+          MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED */
+
 int mbedtls_ssl_derive_keys( mbedtls_ssl_context *ssl )
 {
     int ret = 0;
@@ -758,21 +780,70 @@
         }
 #endif /* MBEDTLS_SSL_EXTENDED_MS_ENABLED */
 
-        ret = handshake->tls_prf( handshake->premaster, handshake->pmslen,
-                                  lbl, salt, salt_len,
-                                  session->master, 48 );
-        if( ret != 0 )
+#if defined(MBEDTLS_USE_PSA_CRYPTO) &&          \
+    defined(MBEDTLS_KEY_EXCHANGE_PSK_ENABLED)
+        if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_PSK &&
+            ssl->minor_ver == MBEDTLS_SSL_MINOR_VERSION_3 &&
+            ssl_use_opaque_psk( ssl ) == 1 )
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "prf", ret );
-            return( ret );
+            /* Perform PSK-to-MS expansion in a single step. */
+            psa_status_t status;
+            psa_algorithm_t alg;
+            psa_crypto_generator_t generator = PSA_CRYPTO_GENERATOR_INIT;
+            psa_key_slot_t psk;
+
+            MBEDTLS_SSL_DEBUG_MSG( 2, ( "perform PSA-based PSK-to-MS expansion" ) );
+
+            psk = ssl->conf->psk_opaque;
+            if( ssl->handshake->psk_opaque != 0 )
+                psk = ssl->handshake->psk_opaque;
+
+            if( md_type == MBEDTLS_MD_SHA384 )
+                alg = PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_384);
+            else
+                alg = PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_256);
+
+            status = psa_key_derivation( &generator, psk, alg,
+                                         salt, salt_len,
+                                         (unsigned char const *) lbl,
+                                         (size_t) strlen( lbl ),
+                                         48 );
+            if( status != PSA_SUCCESS )
+            {
+                psa_generator_abort( &generator );
+                return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+            }
+
+            status = psa_generator_read( &generator, session->master, 48 );
+            if( status != PSA_SUCCESS )
+            {
+                psa_generator_abort( &generator );
+                return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+            }
+
+            status = psa_generator_abort( &generator );
+            if( status != PSA_SUCCESS )
+                return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
         }
+        else
+#endif
+        {
+            ret = handshake->tls_prf( handshake->premaster, handshake->pmslen,
+                                      lbl, salt, salt_len,
+                                      session->master, 48 );
+            if( ret != 0 )
+            {
+                MBEDTLS_SSL_DEBUG_RET( 1, "prf", ret );
+                return( ret );
+            }
 
-        MBEDTLS_SSL_DEBUG_BUF( 3, "premaster secret",
-                               handshake->premaster,
-                               handshake->pmslen );
+            MBEDTLS_SSL_DEBUG_BUF( 3, "premaster secret",
+                                   handshake->premaster,
+                                   handshake->pmslen );
 
-        mbedtls_platform_zeroize( handshake->premaster,
-                                  sizeof(handshake->premaster) );
+            mbedtls_platform_zeroize( handshake->premaster,
+                                      sizeof(handshake->premaster) );
+        }
     }
 
     /*