Merge pull request #5636 from mprse/tls_ecdh_2b

TLS ECDH 2b: client-side static ECDH (1.2)
diff --git a/include/mbedtls/ssl_ticket.h b/include/mbedtls/ssl_ticket.h
index 8559309..98fd287 100644
--- a/include/mbedtls/ssl_ticket.h
+++ b/include/mbedtls/ssl_ticket.h
@@ -34,6 +34,10 @@
 #include "mbedtls/ssl.h"
 #include "mbedtls/cipher.h"
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#include "psa/crypto.h"
+#endif
+
 #if defined(MBEDTLS_THREADING_C)
 #include "mbedtls/threading.h"
 #endif
@@ -53,7 +57,14 @@
     unsigned char MBEDTLS_PRIVATE(name)[MBEDTLS_SSL_TICKET_KEY_NAME_BYTES];
                                                      /*!< random key identifier              */
     uint32_t MBEDTLS_PRIVATE(generation_time);       /*!< key generation timestamp (seconds) */
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
     mbedtls_cipher_context_t MBEDTLS_PRIVATE(ctx);   /*!< context for auth enc/decryption    */
+#else
+    mbedtls_svc_key_id_t MBEDTLS_PRIVATE(key);       /*!< key used for auth enc/decryption   */
+    psa_algorithm_t MBEDTLS_PRIVATE(alg);            /*!< algorithm of auth enc/decryption   */
+    psa_key_type_t MBEDTLS_PRIVATE(key_type);        /*!< key type                           */
+    size_t MBEDTLS_PRIVATE(key_bits);                /*!< key length in bits                 */
+#endif
 }
 mbedtls_ssl_ticket_key;
 
diff --git a/library/pk_wrap.c b/library/pk_wrap.c
index 03516b5..ad1f84c 100644
--- a/library/pk_wrap.c
+++ b/library/pk_wrap.c
@@ -282,6 +282,85 @@
 }
 #endif
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+/*
+ * RSA decryption through PSA: the RSA context is serialized to DER,
+ * imported into PSA as a temporary PKCS#1 v1.5 decryption key, used once
+ * and destroyed again. f_rng and p_rng are ignored on this path.
+ */
+static int rsa_decrypt_wrap( void *ctx,
+                    const unsigned char *input, size_t ilen,
+                    unsigned char *output, size_t *olen, size_t osize,
+                    int (*f_rng)(void *, unsigned char *, size_t), void *p_rng )
+{
+    mbedtls_rsa_context * rsa = (mbedtls_rsa_context *) ctx;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    psa_status_t status;
+    mbedtls_pk_context key;
+    int key_len;
+    unsigned char buf[MBEDTLS_PK_RSA_PRV_DER_MAX_BYTES];
+
+    ((void) f_rng);
+    ((void) p_rng);
+
+#if !defined(MBEDTLS_RSA_ALT)
+    if( rsa->padding != MBEDTLS_RSA_PKCS_V15 )
+        return( MBEDTLS_ERR_RSA_INVALID_PADDING );
+#endif /* !MBEDTLS_RSA_ALT */
+
+    if( ilen != mbedtls_rsa_get_len( rsa ) )
+        return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
+
+    /* mbedtls_pk_write_key_der() expects a full PK context;
+     * re-construct one to make it happy */
+    key.pk_info = &mbedtls_rsa_info;
+    key.pk_ctx = ctx;
+    key_len = mbedtls_pk_write_key_der( &key, buf, sizeof( buf ) );
+    if( key_len <= 0 )
+    {
+        /* buf may already hold part of the private key: go through
+         * cleanup so that it is wiped before returning. */
+        ret = MBEDTLS_ERR_PK_BAD_INPUT_DATA;
+        goto cleanup;
+    }
+
+    psa_set_key_type( &attributes, PSA_KEY_TYPE_RSA_KEY_PAIR );
+    psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_DECRYPT );
+    psa_set_key_algorithm( &attributes, PSA_ALG_RSA_PKCS1V15_CRYPT );
+
+    /* The DER encoding occupies the last key_len bytes of buf. */
+    status = psa_import_key( &attributes,
+                             buf + sizeof( buf ) - key_len, key_len,
+                             &key_id );
+    if( status != PSA_SUCCESS )
+    {
+        ret = mbedtls_pk_error_from_psa( status );
+        goto cleanup;
+    }
+
+    status = psa_asymmetric_decrypt( key_id, PSA_ALG_RSA_PKCS1V15_CRYPT,
+                                     input, ilen,
+                                     NULL, 0,
+                                     output, osize, olen );
+    if( status != PSA_SUCCESS )
+    {
+        ret = mbedtls_pk_error_from_psa_rsa( status );
+        goto cleanup;
+    }
+
+    ret = 0;
+
+cleanup:
+    mbedtls_platform_zeroize( buf, sizeof( buf ) );
+    status = psa_destroy_key( key_id );
+    if( ret == 0 && status != PSA_SUCCESS )
+        ret = mbedtls_pk_error_from_psa( status );
+
+    return( ret );
+}
+#else /* MBEDTLS_USE_PSA_CRYPTO */
 static int rsa_decrypt_wrap( void *ctx,
                     const unsigned char *input, size_t ilen,
                     unsigned char *output, size_t *olen, size_t osize,
@@ -295,6 +374,7 @@
     return( mbedtls_rsa_pkcs1_decrypt( rsa, f_rng, p_rng,
                 olen, input, output, osize ) );
 }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 static int rsa_encrypt_wrap( void *ctx,
                     const unsigned char *input, size_t ilen,
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index c757ac8..094fca8 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -3116,21 +3116,113 @@
 
         MBEDTLS_SSL_DEBUG_MSG( 2, ( "ECDHE curve: %s", (*curve)->name ) );
 
-        if( ( ret = mbedtls_ecdh_setup( &ssl->handshake->ecdh_ctx,
-                                        (*curve)->grp_id ) ) != 0 )
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_RSA ||
+            ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA )
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ecp_group_load", ret );
-            return( ret );
-        }
+            psa_status_t status = PSA_ERROR_GENERIC_ERROR;
+            psa_key_attributes_t key_attributes;
+            mbedtls_ssl_handshake_params *handshake = ssl->handshake;
+            size_t ecdh_bits = 0;
+            uint8_t *p = ssl->out_msg + ssl->out_msglen;
+            const size_t header_size = 4; // curve_type(1), namedcurve(2),
+                                          // data length(1)
+            const size_t data_length_size = 1;
 
-        if( ( ret = mbedtls_ecdh_make_params(
-                  &ssl->handshake->ecdh_ctx, &len,
-                  ssl->out_msg + ssl->out_msglen,
-                  MBEDTLS_SSL_OUT_CONTENT_LEN - ssl->out_msglen,
-                  ssl->conf->f_rng, ssl->conf->p_rng ) ) != 0 )
+            MBEDTLS_SSL_DEBUG_MSG( 1, ( "Perform PSA-based ECDH computation." ) );
+
+            /* Convert EC group to PSA key type. */
+            handshake->ecdh_psa_type = mbedtls_psa_parse_tls_ecc_group(
+                        (*curve)->tls_id, &ecdh_bits );
+
+            if( handshake->ecdh_psa_type == 0 || ecdh_bits > 0xffff )
+            {
+                MBEDTLS_SSL_DEBUG_MSG( 1, ( "Invalid ecc group parse." ) );
+                return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
+            }
+            handshake->ecdh_bits = (uint16_t) ecdh_bits;
+
+            key_attributes = psa_key_attributes_init();
+            psa_set_key_usage_flags( &key_attributes, PSA_KEY_USAGE_DERIVE );
+            psa_set_key_algorithm( &key_attributes, PSA_ALG_ECDH );
+            psa_set_key_type( &key_attributes, handshake->ecdh_psa_type );
+            psa_set_key_bits( &key_attributes, handshake->ecdh_bits );
+
+            /*
+             * ECParameters curve_params
+             *
+             * First byte is curve_type, always named_curve
+             */
+            *p++ = MBEDTLS_ECP_TLS_NAMED_CURVE;
+
+            /*
+             * Next two bytes are the namedcurve value
+             */
+            MBEDTLS_PUT_UINT16_BE( (*curve)->tls_id, p, 0 );
+            p += 2;
+
+            /* Generate ECDH private key. */
+            status = psa_generate_key( &key_attributes,
+                                       &handshake->ecdh_psa_privkey );
+            if( status != PSA_SUCCESS )
+            {
+                ret = psa_ssl_status_to_mbedtls( status );
+                MBEDTLS_SSL_DEBUG_RET( 1, "psa_generate_key", ret );
+                return( ret );
+            }
+
+            /*
+             * ECPoint  public
+             *
+             * First byte is data length.
+             * It will be filled later. p holds now the data length location.
+             */
+
+            /* Export the public part of the ECDH private key from PSA.
+             * Make one byte space for the length.
+             */
+            unsigned char *own_pubkey = p + data_length_size;
+
+            size_t own_pubkey_max_len = (size_t)( MBEDTLS_SSL_OUT_CONTENT_LEN
+                                        - ( own_pubkey - ssl->out_msg ) );
+
+            status = psa_export_public_key( handshake->ecdh_psa_privkey,
+                                            own_pubkey, own_pubkey_max_len,
+                                            &len );
+            if( status != PSA_SUCCESS )
+            {
+                ret = psa_ssl_status_to_mbedtls( status );
+                MBEDTLS_SSL_DEBUG_RET( 1, "psa_export_public_key", ret );
+                (void) psa_destroy_key( handshake->ecdh_psa_privkey );
+                handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+                return( ret );
+            }
+
+            /* Store the length of the exported public key. */
+            *p = (uint8_t) len;
+
+            /* Determine full message length. */
+            len += header_size;
+        }
+        else
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ecdh_make_params", ret );
-            return( ret );
+            if( ( ret = mbedtls_ecdh_setup( &ssl->handshake->ecdh_ctx,
+                                            (*curve)->grp_id ) ) != 0 )
+            {
+                MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ecp_group_load", ret );
+                return( ret );
+            }
+
+            if( ( ret = mbedtls_ecdh_make_params(
+                    &ssl->handshake->ecdh_ctx, &len,
+                    ssl->out_msg + ssl->out_msglen,
+                    MBEDTLS_SSL_OUT_CONTENT_LEN - ssl->out_msglen,
+                    ssl->conf->f_rng, ssl->conf->p_rng ) ) != 0 )
+            {
+                MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ecdh_make_params", ret );
+                return( ret );
+            }
         }
 
 #if defined(MBEDTLS_KEY_EXCHANGE_WITH_SERVER_SIGNATURE_ENABLED)
@@ -3801,6 +3893,78 @@
     }
     else
 #endif /* MBEDTLS_KEY_EXCHANGE_DHE_RSA_ENABLED */
+#if defined(MBEDTLS_USE_PSA_CRYPTO) &&                           \
+        ( defined(MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED) ||     \
+          defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED) )
+    if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_RSA ||
+        ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA )
+    {
+        size_t data_len = 0;
+        size_t buf_len = 0;
+        psa_status_t status = PSA_ERROR_GENERIC_ERROR;
+        mbedtls_ssl_handshake_params *handshake = ssl->handshake;
+
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "Read the peer's public key." ) );
+
+        /*
+         * We must have at least two bytes (1 for length, at least 1 for data).
+         * Check this before reading the length byte, so that a truncated
+         * message is never read past its end.
+         */
+        if( end - p < 2 )
+        {
+            MBEDTLS_SSL_DEBUG_MSG( 1, ( "Invalid buffer length" ) );
+            return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
+        }
+
+        data_len = (size_t)( *p++ );
+        buf_len = (size_t)( end - p );
+
+        /*
+         * The announced key must be non-empty, lie within the received
+         * message and fit into handshake->ecdh_psa_peerkey, which it is
+         * copied into below (the sizeof() bound requires that field to be
+         * an array). */
+        if( data_len < 1 || data_len > buf_len ||
+            data_len > sizeof( handshake->ecdh_psa_peerkey ) )
+        {
+            MBEDTLS_SSL_DEBUG_MSG( 1, ( "Invalid data length" ) );
+            return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
+        }
+
+        /* Store peer's ECDH public key. */
+        memcpy( handshake->ecdh_psa_peerkey, p, data_len );
+        handshake->ecdh_psa_peerkey_len = data_len;
+
+        /* Compute ECDH shared secret. */
+        status = psa_raw_key_agreement(
+                    PSA_ALG_ECDH, handshake->ecdh_psa_privkey,
+                    handshake->ecdh_psa_peerkey, handshake->ecdh_psa_peerkey_len,
+                    handshake->premaster, sizeof( handshake->premaster ),
+                    &handshake->pmslen );
+        if( status != PSA_SUCCESS )
+        {
+            ret = psa_ssl_status_to_mbedtls( status );
+            MBEDTLS_SSL_DEBUG_RET( 1, "psa_raw_key_agreement", ret );
+            (void) psa_destroy_key( handshake->ecdh_psa_privkey );
+            handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+            return( ret );
+        }
+
+        status = psa_destroy_key( handshake->ecdh_psa_privkey );
+
+        if( status != PSA_SUCCESS )
+        {
+            ret = psa_ssl_status_to_mbedtls( status );
+            MBEDTLS_SSL_DEBUG_RET( 1, "psa_destroy_key", ret );
+            return( ret );
+        }
+        handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+    }
+    else
+#endif /* MBEDTLS_USE_PSA_CRYPTO &&
+            ( MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED ||
+              MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED ) */
 #if defined(MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED) ||                     \
     defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED) ||                   \
     defined(MBEDTLS_KEY_EXCHANGE_ECDH_RSA_ENABLED) ||                      \
diff --git a/library/ssl_ticket.c b/library/ssl_ticket.c
index b04e184..7f65849 100644
--- a/library/ssl_ticket.c
+++ b/library/ssl_ticket.c
@@ -73,6 +73,10 @@
     unsigned char buf[MAX_KEY_BYTES];
     mbedtls_ssl_ticket_key *key = ctx->keys + index;
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+#endif
+
 #if defined(MBEDTLS_HAVE_TIME)
     key->generation_time = (uint32_t) mbedtls_time( NULL );
 #endif
@@ -83,10 +87,23 @@
     if( ( ret = ctx->f_rng( ctx->p_rng, buf, sizeof( buf ) ) ) != 0 )
         return( ret );
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_set_key_usage_flags( &attributes,
+                             PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT );
+    psa_set_key_algorithm( &attributes, key->alg );
+    psa_set_key_type( &attributes, key->key_type );
+    psa_set_key_bits( &attributes, key->key_bits );
+
+    ret = psa_ssl_status_to_mbedtls(
+            psa_import_key( &attributes, buf,
+                            PSA_BITS_TO_BYTES( key->key_bits ),
+                            &key->key ) );
+#else
     /* With GCM and CCM, same context can encrypt & decrypt */
     ret = mbedtls_cipher_setkey( &key->ctx, buf,
                                  mbedtls_cipher_get_key_bitlen( &key->ctx ),
                                  MBEDTLS_ENCRYPT );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
     mbedtls_platform_zeroize( buf, sizeof( buf ) );
 
@@ -106,6 +123,10 @@
         uint32_t current_time = (uint32_t) mbedtls_time( NULL );
         uint32_t key_time = ctx->keys[ctx->active].generation_time;
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+#endif
+
         if( current_time >= key_time &&
             current_time - key_time < ctx->ticket_lifetime )
         {
@@ -114,6 +135,13 @@
 
         ctx->active = 1 - ctx->active;
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        if( ( status = psa_destroy_key( ctx->keys[ctx->active].key ) ) != PSA_SUCCESS )
+        {
+            return psa_ssl_status_to_mbedtls( status );
+        }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
         return( ssl_ticket_gen_key( ctx, ctx->active ) );
     }
     else
@@ -131,15 +159,44 @@
 {
     const unsigned char idx = 1 - ctx->active;
     mbedtls_ssl_ticket_key * const key = ctx->keys + idx;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    const size_t bitlen = key->key_bits;
+#else
     const int bitlen = mbedtls_cipher_get_key_bitlen( &key->ctx );
-    int ret;
+#endif
+
     if( nlength < TICKET_KEY_NAME_BYTES || klength * 8 < (size_t)bitlen )
         return( MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA );
 
-    /* With GCM and CCM, same context can encrypt & decrypt */
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if( ( status = psa_destroy_key( key->key ) ) != PSA_SUCCESS )
+    {
+        ret = psa_ssl_status_to_mbedtls( status );
+        return( ret );
+    }
+
+    psa_set_key_usage_flags( &attributes,
+                             PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT );
+    psa_set_key_algorithm( &attributes, key->alg );
+    psa_set_key_type( &attributes, key->key_type );
+    psa_set_key_bits( &attributes, key->key_bits );
+
+    if( ( status = psa_import_key( &attributes, k,
+                                   PSA_BITS_TO_BYTES( key->key_bits ),
+                                   &key->key ) ) != PSA_SUCCESS )
+    {
+        ret = psa_ssl_status_to_mbedtls( status );
+        return( ret );
+    }
+#else
     ret = mbedtls_cipher_setkey( &key->ctx, k, bitlen, MBEDTLS_ENCRYPT );
     if( ret != 0 )
         return( ret );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
     ctx->active = idx;
     ctx->ticket_lifetime = lifetime;
@@ -161,15 +218,22 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     const mbedtls_cipher_info_t *cipher_info;
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_algorithm_t alg;
+    psa_key_type_t key_type;
+    size_t key_bits;
+#endif
+
     ctx->f_rng = f_rng;
     ctx->p_rng = p_rng;
 
     ctx->ticket_lifetime = lifetime;
 
-    cipher_info = mbedtls_cipher_info_from_type( cipher);
+    cipher_info = mbedtls_cipher_info_from_type( cipher );
 
     if( mbedtls_cipher_info_get_mode( cipher_info ) != MBEDTLS_MODE_GCM &&
-        mbedtls_cipher_info_get_mode( cipher_info ) != MBEDTLS_MODE_CCM )
+        mbedtls_cipher_info_get_mode( cipher_info ) != MBEDTLS_MODE_CCM &&
+        mbedtls_cipher_info_get_mode( cipher_info ) != MBEDTLS_MODE_CHACHAPOLY )
     {
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
     }
@@ -178,26 +242,24 @@
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-    ret = mbedtls_cipher_setup_psa( &ctx->keys[0].ctx,
-                                    cipher_info, TICKET_AUTH_TAG_BYTES );
-    if( ret != 0 && ret != MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE )
-        return( ret );
-    /* We don't yet expect to support all ciphers through PSA,
-     * so allow fallback to ordinary mbedtls_cipher_setup(). */
-    if( ret == MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE )
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
+    if( mbedtls_ssl_cipher_to_psa( cipher_info->type, TICKET_AUTH_TAG_BYTES,
+                                   &alg, &key_type, &key_bits ) != PSA_SUCCESS )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    ctx->keys[0].alg = alg;
+    ctx->keys[0].key_type = key_type;
+    ctx->keys[0].key_bits = key_bits;
+
+    ctx->keys[1].alg = alg;
+    ctx->keys[1].key_type = key_type;
+    ctx->keys[1].key_bits = key_bits;
+#else
     if( ( ret = mbedtls_cipher_setup( &ctx->keys[0].ctx, cipher_info ) ) != 0 )
         return( ret );
 
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    ret = mbedtls_cipher_setup_psa( &ctx->keys[1].ctx,
-                                    cipher_info, TICKET_AUTH_TAG_BYTES );
-    if( ret != 0 && ret != MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE )
-        return( ret );
-    if( ret == MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE )
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
     if( ( ret = mbedtls_cipher_setup( &ctx->keys[1].ctx, cipher_info ) ) != 0 )
         return( ret );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
     if( ( ret = ssl_ticket_gen_key( ctx, 0 ) ) != 0 ||
         ( ret = ssl_ticket_gen_key( ctx, 1 ) ) != 0 )
@@ -238,6 +300,10 @@
     unsigned char *state = state_len_bytes + TICKET_CRYPT_LEN_BYTES;
     size_t clear_len, ciph_len;
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+#endif
+
     *tlen = 0;
 
     if( ctx == NULL || ctx->f_rng == NULL )
@@ -275,6 +341,17 @@
     MBEDTLS_PUT_UINT16_BE( clear_len, state_len_bytes, 0 );
 
     /* Encrypt and authenticate */
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if( ( status = psa_aead_encrypt( key->key, key->alg, iv, TICKET_IV_BYTES,
+                                     key_name, TICKET_ADD_DATA_LEN,
+                                     state, clear_len,
+                                     state, end - state,
+                                     &ciph_len ) ) != PSA_SUCCESS )
+    {
+        ret = psa_ssl_status_to_mbedtls( status );
+        goto cleanup;
+    }
+#else
     if( ( ret = mbedtls_cipher_auth_encrypt_ext( &key->ctx,
                     iv, TICKET_IV_BYTES,
                     /* Additional data: key name, IV and length */
@@ -285,6 +362,8 @@
     {
         goto cleanup;
     }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
     if( ciph_len != clear_len + TICKET_AUTH_TAG_BYTES )
     {
         ret = MBEDTLS_ERR_SSL_INTERNAL_ERROR;
@@ -335,6 +414,10 @@
     unsigned char *ticket = enc_len_p + TICKET_CRYPT_LEN_BYTES;
     size_t enc_len, clear_len;
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+#endif
+
     if( ctx == NULL || ctx->f_rng == NULL )
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
 
@@ -367,6 +450,16 @@
     }
 
     /* Decrypt and authenticate */
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if( ( status = psa_aead_decrypt( key->key, key->alg, iv, TICKET_IV_BYTES,
+                                     key_name, TICKET_ADD_DATA_LEN,
+                                     ticket, enc_len + TICKET_AUTH_TAG_BYTES,
+                                     ticket, enc_len, &clear_len ) ) != PSA_SUCCESS )
+    {
+        ret = psa_ssl_status_to_mbedtls( status );
+        goto cleanup;
+    }
+#else
     if( ( ret = mbedtls_cipher_auth_decrypt_ext( &key->ctx,
                     iv, TICKET_IV_BYTES,
                     /* Additional data: key name, IV and length */
@@ -380,6 +473,8 @@
 
         goto cleanup;
     }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
     if( clear_len != enc_len )
     {
         ret = MBEDTLS_ERR_SSL_INTERNAL_ERROR;
@@ -418,8 +513,13 @@
  */
 void mbedtls_ssl_ticket_free( mbedtls_ssl_ticket_context *ctx )
 {
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_destroy_key( ctx->keys[0].key );
+    psa_destroy_key( ctx->keys[1].key );
+#else
     mbedtls_cipher_free( &ctx->keys[0].ctx );
     mbedtls_cipher_free( &ctx->keys[1].ctx );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 #if defined(MBEDTLS_THREADING_C)
     mbedtls_mutex_free( &ctx->mutex );
diff --git a/library/ssl_tls13_invasive.h b/library/ssl_tls13_invasive.h
index aa35784..a025dbe 100644
--- a/library/ssl_tls13_invasive.h
+++ b/library/ssl_tls13_invasive.h
@@ -29,6 +29,37 @@
 #if defined(MBEDTLS_PSA_CRYPTO_C)
 
 /**
+ *  \brief  Take the input keying material \p ikm and extract from it a
+ *          fixed-length pseudorandom key \p prk (the HKDF-Extract step).
+ *
+ *  \param       alg       The HMAC algorithm to use
+ *                         (\c #PSA_ALG_HMAC( PSA_ALG_XXX ) value such that
+ *                         PSA_ALG_XXX is a hash algorithm and
+ *                         #PSA_ALG_IS_HMAC(\p alg) is true).
+ *  \param       salt      An optional salt value (a non-secret random value);
+ *                         if \p salt is NULL or \p salt_len is 0, a string of
+ *                         zeros as long as the output of the hash used by
+ *                         \p alg is used as the salt.
+ *  \param       salt_len  The length in bytes of \p salt; must be 0 if \p salt is NULL.
+ *  \param       ikm       The input keying material.
+ *  \param       ikm_len   The length in bytes of \p ikm.
+ *  \param[out]  prk       A pseudorandom key of \p prk_len bytes.
+ *  \param       prk_size  Size of the \p prk buffer in bytes.
+ *  \param[out]  prk_len   On success, the length in bytes of the
+ *                         pseudorandom key in \p prk.
+ *
+ *  \return 0 on success.
+ *  \return #PSA_ERROR_INVALID_ARGUMENT when the parameters are invalid.
+ *  \return A PSA_ERROR_* error for errors returned from the underlying
+ *          PSA layer.
+ */
+psa_status_t mbedtls_psa_hkdf_extract( psa_algorithm_t alg,
+                                       const unsigned char *salt, size_t salt_len,
+                                       const unsigned char *ikm, size_t ikm_len,
+                                       unsigned char *prk, size_t prk_size,
+                                       size_t *prk_len );
+
+/**
  *  \brief  Expand the supplied \p prk into several additional pseudorandom
  *          keys, which is the output of the HKDF.
  *
diff --git a/library/ssl_tls13_keys.c b/library/ssl_tls13_keys.c
index 10b3b7e..a5af590 100644
--- a/library/ssl_tls13_keys.c
+++ b/library/ssl_tls13_keys.c
@@ -139,6 +139,59 @@
 #if defined( MBEDTLS_TEST_HOOKS )
 
 MBEDTLS_STATIC_TESTABLE
+psa_status_t mbedtls_psa_hkdf_extract( psa_algorithm_t alg,
+                                       const unsigned char *salt, size_t salt_len,
+                                       const unsigned char *ikm, size_t ikm_len,
+                                       unsigned char *prk, size_t prk_size,
+                                       size_t *prk_len )
+{
+    unsigned char null_salt[PSA_MAC_MAX_SIZE] = { '\0' };
+    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_status_t destroy_status = PSA_ERROR_CORRUPTION_DETECTED;
+
+    if( salt == NULL || salt_len == 0 ) /* no salt: use hash-length zeros */
+    {
+        size_t hash_len;
+
+        if( salt_len != 0 ) /* NULL salt with a non-zero length */
+        {
+            return( PSA_ERROR_INVALID_ARGUMENT );
+        }
+
+        hash_len = PSA_HASH_LENGTH( alg );
+
+        if( hash_len == 0 )
+        {
+            return( PSA_ERROR_INVALID_ARGUMENT );
+        }
+
+        /* salt_len <= sizeof( null_salt ) because
+           PSA_HASH_LENGTH( alg ) <= PSA_MAC_MAX_SIZE. */
+        salt = null_salt;
+        salt_len = hash_len;
+    }
+
+    psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_SIGN_MESSAGE );
+    psa_set_key_algorithm( &attributes, alg );
+    psa_set_key_type( &attributes, PSA_KEY_TYPE_HMAC ); /* salt is the HMAC key */
+
+    status = psa_import_key( &attributes, salt, salt_len, &key );
+    if( status != PSA_SUCCESS )
+    {
+        goto cleanup;
+    }
+
+    status = psa_mac_compute( key, alg, ikm, ikm_len, prk, prk_size, prk_len );
+
+cleanup:
+    destroy_status = psa_destroy_key( key );
+
+    return( ( status == PSA_SUCCESS ) ? destroy_status : status ); /* earliest failure wins */
+}
+
+MBEDTLS_STATIC_TESTABLE
 psa_status_t mbedtls_psa_hkdf_expand( psa_algorithm_t alg,
                                       const unsigned char *prk, size_t prk_len,
                                       const unsigned char *info, size_t info_len,
diff --git a/tests/ssl-opt.sh b/tests/ssl-opt.sh
index 9e99c1f..d5334ce 100755
--- a/tests/ssl-opt.sh
+++ b/tests/ssl-opt.sh
@@ -2816,7 +2816,6 @@
             -c "a session has been resumed"
 
 requires_config_enabled MBEDTLS_SSL_PROTO_TLS1_2
-requires_config_disabled MBEDTLS_USE_PSA_CRYPTO
 run_test    "Session resume using tickets: manual rotation" \
             "$P_SRV debug_level=3 tickets=1 ticket_rotate=1" \
             "$P_CLI debug_level=3 tickets=1 reconnect=1" \
@@ -3109,6 +3108,22 @@
             -s "a session has been resumed" \
             -c "a session has been resumed"
 
+requires_config_enabled MBEDTLS_SSL_PROTO_TLS1_2
+requires_config_enabled MBEDTLS_CHACHAPOLY_C
+run_test    "Session resume using tickets: CHACHA20-POLY1305" \
+            "$P_SRV debug_level=3 tickets=1 ticket_aead=CHACHA20-POLY1305" \
+            "$P_CLI debug_level=3 tickets=1 reconnect=1" \
+            0 \
+            -c "client hello, adding session ticket extension" \
+            -s "found session ticket extension" \
+            -s "server hello, adding session ticket extension" \
+            -c "found session_ticket extension" \
+            -c "parse new session ticket" \
+            -S "session successfully restored from cache" \
+            -s "session successfully restored from ticket" \
+            -s "a session has been resumed" \
+            -c "a session has been resumed"
+
 # Tests for Session Tickets with DTLS
 
 requires_config_enabled MBEDTLS_SSL_PROTO_TLS1_2
diff --git a/tests/suites/test_suite_hkdf.function b/tests/suites/test_suite_hkdf.function
index feb1717..1ad6f3d 100644
--- a/tests/suites/test_suite_hkdf.function
+++ b/tests/suites/test_suite_hkdf.function
@@ -30,71 +30,57 @@
 /* END_CASE */
 
 /* BEGIN_CASE */
-void test_hkdf_extract( int md_alg, char *hex_ikm_string,
-                        char *hex_salt_string, char *hex_prk_string )
+void test_hkdf_extract( int md_alg,
+                        data_t *ikm,
+                        data_t *salt,
+                        data_t *prk )
 {
     int ret;
-    unsigned char *ikm = NULL;
-    unsigned char *salt = NULL;
-    unsigned char *prk = NULL;
     unsigned char *output_prk = NULL;
-    size_t ikm_len, salt_len, prk_len, output_prk_len;
+    size_t output_prk_len;
 
     const mbedtls_md_info_t *md = mbedtls_md_info_from_type( md_alg );
     TEST_ASSERT( md != NULL );
 
     output_prk_len = mbedtls_md_get_size( md );
-    output_prk = mbedtls_calloc( 1, output_prk_len );
+    ASSERT_ALLOC( output_prk, output_prk_len ); /* PRK is one digest long */
 
-    ikm = mbedtls_test_unhexify_alloc( hex_ikm_string, &ikm_len );
-    salt = mbedtls_test_unhexify_alloc( hex_salt_string, &salt_len );
-    prk = mbedtls_test_unhexify_alloc( hex_prk_string, &prk_len );
-
-    ret = mbedtls_hkdf_extract( md, salt, salt_len, ikm, ikm_len, output_prk );
+    ret = mbedtls_hkdf_extract( md, salt->x, salt->len,
+                                ikm->x, ikm->len, output_prk );
     TEST_ASSERT( ret == 0 );
 
-    ASSERT_COMPARE( output_prk, output_prk_len, prk, prk_len );
+    ASSERT_COMPARE( output_prk, output_prk_len, prk->x, prk->len ); /* known-answer check */
 
 exit:
-    mbedtls_free(ikm);
-    mbedtls_free(salt);
-    mbedtls_free(prk);
     mbedtls_free(output_prk);
 }
 /* END_CASE */
 
 /* BEGIN_CASE */
-void test_hkdf_expand( int md_alg, char *hex_info_string,
-                       char *hex_prk_string, char *hex_okm_string )
+void test_hkdf_expand( int md_alg,
+                       data_t *info,
+                       data_t *prk,
+                       data_t *okm )
 {
     enum { OKM_LEN  = 1024 };
     int ret;
-    unsigned char *info = NULL;
-    unsigned char *prk = NULL;
-    unsigned char *okm = NULL;
     unsigned char *output_okm = NULL;
-    size_t info_len, prk_len, okm_len;
 
     const mbedtls_md_info_t *md = mbedtls_md_info_from_type( md_alg );
     TEST_ASSERT( md != NULL );
 
     ASSERT_ALLOC( output_okm, OKM_LEN );
 
-    prk = mbedtls_test_unhexify_alloc( hex_prk_string, &prk_len );
-    info = mbedtls_test_unhexify_alloc( hex_info_string, &info_len );
-    okm = mbedtls_test_unhexify_alloc( hex_okm_string, &okm_len );
-    TEST_ASSERT( prk_len == mbedtls_md_get_size( md ) );
-    TEST_ASSERT( okm_len < OKM_LEN );
+    TEST_ASSERT( prk->len == mbedtls_md_get_size( md ) ); /* PRK is one digest long */
+    TEST_ASSERT( okm->len < OKM_LEN ); /* expected OKM fits in output_okm */
 
-    ret = mbedtls_hkdf_expand( md, prk, prk_len, info, info_len,
+    ret = mbedtls_hkdf_expand( md, prk->x, prk->len,
+                               info->x, info->len,
                                output_okm, OKM_LEN );
     TEST_ASSERT( ret == 0 );
-    ASSERT_COMPARE( output_okm, okm_len, okm, okm_len );
+    ASSERT_COMPARE( output_okm, okm->len, okm->x, okm->len ); /* compare the expected prefix only */
 
 exit:
-    mbedtls_free(info);
-    mbedtls_free(prk);
-    mbedtls_free(okm);
     mbedtls_free(output_okm);
 }
 /* END_CASE */
@@ -113,7 +99,7 @@
     fake_md_info.type = MBEDTLS_MD_NONE;
     fake_md_info.size = hash_len;
 
-    prk = mbedtls_calloc( MBEDTLS_MD_MAX_SIZE, 1 );
+    ASSERT_ALLOC( prk, MBEDTLS_MD_MAX_SIZE);
     salt_len = 0;
     ikm_len = 0;
 
diff --git a/tests/suites/test_suite_pk.function b/tests/suites/test_suite_pk.function
index 29f8622..8eff010 100644
--- a/tests/suites/test_suite_pk.function
+++ b/tests/suites/test_suite_pk.function
@@ -756,6 +756,8 @@
     mbedtls_pk_context pk;
     size_t olen;
 
+    USE_PSA_INIT( );
+
     mbedtls_pk_init( &pk );
     mbedtls_mpi_init( &N ); mbedtls_mpi_init( &P );
     mbedtls_mpi_init( &Q ); mbedtls_mpi_init( &E );
@@ -794,6 +796,7 @@
     mbedtls_mpi_free( &N ); mbedtls_mpi_free( &P );
     mbedtls_mpi_free( &Q ); mbedtls_mpi_free( &E );
     mbedtls_pk_free( &pk );
+    USE_PSA_DONE( );
 }
 /* END_CASE */
 
diff --git a/tests/suites/test_suite_ssl.data b/tests/suites/test_suite_ssl.data
index eb1b8f4..0c6e313 100644
--- a/tests/suites/test_suite_ssl.data
+++ b/tests/suites/test_suite_ssl.data
@@ -4392,6 +4392,37 @@
 # Handshake secret to Master Secret
 ssl_tls13_key_evolution:MBEDTLS_MD_SHA256:"fb9fc80689b3a5d02c33243bf69a1b1b20705588a794304a6e7120155edf149a":"":"7f2882bb9b9a46265941653e9c2f19067118151e21d12e57a7b6aca1f8150c8d"
 
+SSL TLS 1.3 Key schedule: HKDF RFC5869 Test Vector #1 Extract
+depends_on:PSA_WANT_ALG_SHA_256
+psa_hkdf_extract:PSA_ALG_HMAC(PSA_ALG_SHA_256):"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":"000102030405060708090a0b0c":"077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5"
+
+SSL TLS 1.3 Key schedule: HKDF RFC5869 Test Vector #2 Extract
+depends_on:PSA_WANT_ALG_SHA_256
+psa_hkdf_extract:PSA_ALG_HMAC(PSA_ALG_SHA_256):"000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f":"606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeaf":"06a6b88c5853361a06104c9ceb35b45cef760014904671014a193f40c15fc244"
+
+SSL TLS 1.3 Key schedule: HKDF RFC5869 Test Vector #3 Extract
+depends_on:PSA_WANT_ALG_SHA_256
+psa_hkdf_extract:PSA_ALG_HMAC(PSA_ALG_SHA_256):"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":"":"19ef24a32c717b167f33a91d6f648bdf96596776afdb6377ac434c1c293ccb04"
+
+SSL TLS 1.3 Key schedule: HKDF RFC5869 Test Vector #4 Extract
+depends_on:PSA_WANT_ALG_SHA_1
+psa_hkdf_extract:PSA_ALG_HMAC(PSA_ALG_SHA_1):"0b0b0b0b0b0b0b0b0b0b0b":"000102030405060708090a0b0c":"9b6c18c432a7bf8f0e71c8eb88f4b30baa2ba243"
+
+SSL TLS 1.3 Key schedule: HKDF RFC5869 Test Vector #5 Extract
+depends_on:PSA_WANT_ALG_SHA_1
+psa_hkdf_extract:PSA_ALG_HMAC(PSA_ALG_SHA_1):"000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f":"606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeaf":"8adae09a2a307059478d309b26c4115a224cfaf6"
+
+SSL TLS 1.3 Key schedule: HKDF RFC5869 Test Vector #6 Extract
+depends_on:PSA_WANT_ALG_SHA_1
+psa_hkdf_extract:PSA_ALG_HMAC(PSA_ALG_SHA_1):"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b":"":"da8c8a73c7fa77288ec6f5e7c297786aa0d32d01"
+
+SSL TLS 1.3 Key schedule: HKDF RFC5869 Test Vector #7 Extract
+depends_on:PSA_WANT_ALG_SHA_1
+psa_hkdf_extract:PSA_ALG_HMAC(PSA_ALG_SHA_1):"0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c":"":"2adccada18779e7c2077ad2eb19d3f3e731385dd"
+
+SSL TLS 1.3 Key schedule: HKDF extract fails with wrong hash alg
+psa_hkdf_extract_ret:0:PSA_ERROR_INVALID_ARGUMENT
+
 SSL TLS 1.3 Key schedule: HKDF RFC5869 Test Vector #1 Expand
 depends_on:PSA_WANT_ALG_SHA_256
 psa_hkdf_expand:PSA_ALG_HMAC(PSA_ALG_SHA_256):"f0f1f2f3f4f5f6f7f8f9":"077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5":"3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865"
diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function
index 67f4d6e..855cfc7 100644
--- a/tests/suites/test_suite_ssl.function
+++ b/tests/suites/test_suite_ssl.function
@@ -3886,35 +3886,84 @@
 /* END_CASE */
 
 /* BEGIN_CASE depends_on:MBEDTLS_TEST_HOOKS:MBEDTLS_SSL_PROTO_TLS1_3 */
-void psa_hkdf_expand( int alg, char *hex_info_string,
-                      char *hex_prk_string, char *hex_okm_string )
+void psa_hkdf_extract( int alg,
+                       data_t *ikm,
+                       data_t *salt,
+                       data_t *prk )
+{
+    unsigned char *output_prk = NULL;
+    size_t output_prk_size, output_prk_len;
+
+    PSA_INIT( );
+
+    output_prk_size = PSA_HASH_LENGTH( alg );
+    ASSERT_ALLOC( output_prk, output_prk_size );
+
+    PSA_ASSERT( mbedtls_psa_hkdf_extract( alg, salt->x, salt->len,
+                                          ikm->x, ikm->len,
+                                          output_prk, output_prk_size,
+                                          &output_prk_len ) );
+
+    ASSERT_COMPARE( output_prk, output_prk_len, prk->x, prk->len );
+
+exit:
+    mbedtls_free( output_prk );
+
+    PSA_DONE( );
+}
+/* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_TEST_HOOKS:MBEDTLS_SSL_PROTO_TLS1_3 */
+void psa_hkdf_extract_ret( int alg, int ret )
+{
+    int output_ret;
+    unsigned char *salt = NULL;
+    unsigned char *ikm = NULL;
+    unsigned char *prk = NULL;
+    size_t salt_len, ikm_len, prk_len;
+
+    PSA_INIT( );
+
+    ASSERT_ALLOC( prk, PSA_MAC_MAX_SIZE);
+    salt_len = 0;
+    ikm_len = 0;
+    prk_len = 0;
+
+    output_ret = mbedtls_psa_hkdf_extract( alg, salt, salt_len,
+                                           ikm, ikm_len,
+                                           prk, PSA_MAC_MAX_SIZE, &prk_len );
+    TEST_ASSERT( output_ret == ret );
+    TEST_ASSERT( prk_len == 0 );
+
+exit:
+    mbedtls_free( prk );
+
+    PSA_DONE( );
+}
+/* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_TEST_HOOKS:MBEDTLS_SSL_PROTO_TLS1_3 */
+void psa_hkdf_expand( int alg,
+                      data_t *info,
+                      data_t *prk,
+                      data_t *okm )
 {
     enum { OKM_LEN  = 1024 };
-    unsigned char *info = NULL;
-    unsigned char *prk = NULL;
-    unsigned char *okm = NULL;
     unsigned char *output_okm = NULL;
-    size_t info_len, prk_len, okm_len;
 
     PSA_INIT( );
 
     ASSERT_ALLOC( output_okm, OKM_LEN );
+    TEST_ASSERT( prk->len == PSA_HASH_LENGTH( alg ) );
+    TEST_ASSERT( okm->len < OKM_LEN );
 
-    prk = mbedtls_test_unhexify_alloc( hex_prk_string, &prk_len );
-    info = mbedtls_test_unhexify_alloc( hex_info_string, &info_len );
-    okm = mbedtls_test_unhexify_alloc( hex_okm_string, &okm_len );
-    TEST_ASSERT( prk_len == PSA_HASH_LENGTH( alg ) );
-    TEST_ASSERT( okm_len < OKM_LEN );
-
-    PSA_ASSERT( mbedtls_psa_hkdf_expand( alg, prk, prk_len, info, info_len,
+    PSA_ASSERT( mbedtls_psa_hkdf_expand( alg, prk->x, prk->len,
+                                         info->x, info->len,
                                          output_okm, OKM_LEN ) );
 
-    ASSERT_COMPARE( output_okm, okm_len, okm, okm_len );
+    ASSERT_COMPARE( output_okm, okm->len, okm->x, okm->len );
 
 exit:
-    mbedtls_free( info );
-    mbedtls_free( prk );
-    mbedtls_free( okm );
     mbedtls_free( output_okm );
 
     PSA_DONE( );