Merge pull request #5524 from mprse/tls_ecdh_2c

TLS ECDH 2c: ECDHE in TLS 1.3 (client-side)
diff --git a/include/mbedtls/psa_util.h b/include/mbedtls/psa_util.h
index a70ee96..b4c7ba8 100644
--- a/include/mbedtls/psa_util.h
+++ b/include/mbedtls/psa_util.h
@@ -29,7 +29,7 @@
 
 #include "mbedtls/build_info.h"
 
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#if defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3)
 
 #include "psa/crypto.h"
 
@@ -363,6 +363,6 @@
 
 #endif /* !defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG) */
 
-#endif /* MBEDTLS_PSA_CRYPTO_C */
+#endif /* defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3) */
 
 #endif /* MBEDTLS_PSA_UTIL_H */
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 2e3c1ef..9ed5dfa 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -27,8 +27,9 @@
 #include "mbedtls/ssl.h"
 #include "mbedtls/cipher.h"
 
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#if defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3)
 #include "psa/crypto.h"
+#include "mbedtls/psa_util.h"
 #endif
 
 #if defined(MBEDTLS_MD5_C)
@@ -629,13 +630,13 @@
 #if defined(MBEDTLS_ECDH_C) || defined(MBEDTLS_ECDSA_C)
     mbedtls_ecdh_context ecdh_ctx;              /*!<  ECDH key exchange       */
 
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#if defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3)
     psa_key_type_t ecdh_psa_type;
     uint16_t ecdh_bits;
     mbedtls_svc_key_id_t ecdh_psa_privkey;
     unsigned char ecdh_psa_peerkey[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
     size_t ecdh_psa_peerkey_len;
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
+#endif /* MBEDTLS_USE_PSA_CRYPTO || MBEDTLS_SSL_PROTO_TLS1_3 */
 #endif /* MBEDTLS_ECDH_C || MBEDTLS_ECDSA_C */
 
 #if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED)
@@ -2109,7 +2110,9 @@
                                     psa_algorithm_t *alg,
                                     psa_key_type_t *key_type,
                                     size_t *key_size );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3)
 /**
  * \brief       Convert given PSA status to mbedtls error code.
  *
@@ -2133,6 +2136,6 @@
             return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
     }
 }
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
+#endif /* MBEDTLS_USE_PSA_CRYPTO || MBEDTLS_SSL_PROTO_TLS1_3 */
 
 #endif /* ssl_misc.h */
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index 30b1ed4..eac3d4f 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -146,33 +146,61 @@
                 unsigned char *end,
                 size_t *out_len )
 {
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    const mbedtls_ecp_curve_info *curve_info =
-        mbedtls_ecp_curve_info_from_tls_id( named_group );
+    psa_status_t status = PSA_ERROR_GENERIC_ERROR;
+    int ret = MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE;
+    psa_key_attributes_t key_attributes;
+    size_t own_pubkey_len;
+    mbedtls_ssl_handshake_params *handshake = ssl->handshake;
+    size_t ecdh_bits = 0;
 
-    if( curve_info == NULL )
-        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+    MBEDTLS_SSL_DEBUG_MSG( 1, ( "Perform PSA-based ECDH computation." ) );
 
-    MBEDTLS_SSL_DEBUG_MSG( 3, ( "offer curve %s", curve_info->name ) );
+    /* Convert EC group to PSA key type. */
+    if( ( handshake->ecdh_psa_type =
+        mbedtls_psa_parse_tls_ecc_group( named_group, &ecdh_bits ) ) == 0 )
+            return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
 
-    if( ( ret = mbedtls_ecdh_setup_no_everest( &ssl->handshake->ecdh_ctx,
-                                               curve_info->grp_id ) ) != 0 )
+    if( ecdh_bits > 0xffff )
+        return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
+    ssl->handshake->ecdh_bits = (uint16_t) ecdh_bits;
+
+    key_attributes = psa_key_attributes_init();
+    psa_set_key_usage_flags( &key_attributes, PSA_KEY_USAGE_DERIVE );
+    psa_set_key_algorithm( &key_attributes, PSA_ALG_ECDH );
+    psa_set_key_type( &key_attributes, handshake->ecdh_psa_type );
+    psa_set_key_bits( &key_attributes, handshake->ecdh_bits );
+
+    /* Generate ECDH private key. */
+    status = psa_generate_key( &key_attributes,
+                                &handshake->ecdh_psa_privkey );
+    if( status != PSA_SUCCESS )
     {
-        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ecdh_setup_no_everest", ret );
+        ret = psa_ssl_status_to_mbedtls( status );
+        MBEDTLS_SSL_DEBUG_RET( 1, "psa_generate_key", ret );
         return( ret );
+
     }
 
-    ret = mbedtls_ecdh_tls13_make_params( &ssl->handshake->ecdh_ctx, out_len,
-                                          buf, end - buf,
-                                          ssl->conf->f_rng, ssl->conf->p_rng );
-    if( ret != 0 )
+    /* Export the public part of the ECDH private key from PSA. */
+    status = psa_export_public_key( handshake->ecdh_psa_privkey,
+                                    buf, (size_t)( end - buf ),
+                                    &own_pubkey_len );
+    if( status != PSA_SUCCESS )
     {
-        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ecdh_tls13_make_params", ret );
+        ret = psa_ssl_status_to_mbedtls( status );
+        MBEDTLS_SSL_DEBUG_RET( 1, "psa_export_public_key", ret );
         return( ret );
+
     }
 
-    MBEDTLS_SSL_DEBUG_ECDH( 3, &ssl->handshake->ecdh_ctx,
-                            MBEDTLS_DEBUG_ECDH_Q );
+    if( own_pubkey_len > (size_t)( end - buf ) )
+    {
+            MBEDTLS_SSL_DEBUG_MSG( 1, ( "No space in the buffer for ECDH public key." ) );
+        return( MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL );
+    }
+
+    *out_len = own_pubkey_len;
+
     return( 0 );
 }
 #endif /* MBEDTLS_ECDH_C */
@@ -273,7 +301,7 @@
         /* Pointer to group */
         unsigned char *group = p;
         /* Length of key_exchange */
-        size_t key_exchange_len;
+        size_t key_exchange_len = 0;
 
         /* Check there is space for header of KeyShareEntry
          * - group                  (2 bytes)
@@ -281,8 +309,7 @@
          */
         MBEDTLS_SSL_CHK_BUF_PTR( p, end, 4 );
         p += 4;
-        ret = ssl_tls13_generate_and_write_ecdh_key_exchange( ssl, group_id,
-                                                              p, end,
+        ret = ssl_tls13_generate_and_write_ecdh_key_exchange( ssl, group_id, p, end,
                                                               &key_exchange_len );
         p += key_exchange_len;
         if( ret != 0 )
@@ -333,59 +360,24 @@
 
 #if defined(MBEDTLS_ECDH_C)
 
-static int ssl_tls13_check_ecdh_params( const mbedtls_ssl_context *ssl )
-{
-    const mbedtls_ecp_curve_info *curve_info;
-    mbedtls_ecp_group_id grp_id;
-#if defined(MBEDTLS_ECDH_LEGACY_CONTEXT)
-    grp_id = ssl->handshake->ecdh_ctx.grp.id;
-#else
-    grp_id = ssl->handshake->ecdh_ctx.grp_id;
-#endif
-
-    curve_info = mbedtls_ecp_curve_info_from_grp_id( grp_id );
-    if( curve_info == NULL )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) );
-        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
-    }
-
-    MBEDTLS_SSL_DEBUG_MSG( 2, ( "ECDH curve: %s", curve_info->name ) );
-
-    if( mbedtls_ssl_check_curve( ssl, grp_id ) != 0 )
-        return( -1 );
-
-    MBEDTLS_SSL_DEBUG_ECDH( 3, &ssl->handshake->ecdh_ctx,
-                            MBEDTLS_DEBUG_ECDH_QP );
-
-    return( 0 );
-}
-
 static int ssl_tls13_read_public_ecdhe_share( mbedtls_ssl_context *ssl,
                                               const unsigned char *buf,
                                               size_t buf_len )
 {
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    uint8_t *p = (uint8_t*)buf;
+    mbedtls_ssl_handshake_params *handshake = ssl->handshake;
 
-    ret = mbedtls_ecdh_tls13_read_public( &ssl->handshake->ecdh_ctx,
-                                          buf, buf_len );
-    if( ret != 0 )
-    {
-        MBEDTLS_SSL_DEBUG_RET( 1, ( "mbedtls_ecdh_tls13_read_public" ), ret );
+    /* Get size of the TLS opaque key_exchange field of the KeyShareEntry struct. */
+    uint16_t peerkey_len = MBEDTLS_GET_UINT16_BE( p, 0 );
+    p += 2;
 
-        MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
-                                      MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
-        return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
-    }
+    /* Check if key size is consistent with given buffer length. */
+    if ( peerkey_len > ( buf_len - 2 ) )
+        return( MBEDTLS_ERR_SSL_DECODE_ERROR );
 
-    if( ssl_tls13_check_ecdh_params( ssl ) != 0 )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "ssl_tls13_check_ecdh_params() failed!" ) );
-
-        MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
-                                      MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
-        return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
-    }
+    /* Store peer's ECDH public key. */
+    memcpy( handshake->ecdh_psa_peerkey, p, peerkey_len );
+    handshake->ecdh_psa_peerkey_len = peerkey_len;
 
     return( 0 );
 }
@@ -504,7 +496,16 @@
 #if defined(MBEDTLS_ECDH_C)
     if( mbedtls_ssl_tls13_named_group_is_ecdhe( group ) )
     {
-        /* Complete ECDHE key agreement */
+        const mbedtls_ecp_curve_info *curve_info =
+            mbedtls_ecp_curve_info_from_tls_id( group );
+        if( curve_info == NULL )
+        {
+            MBEDTLS_SSL_DEBUG_MSG( 1, ( "Invalid TLS curve group id" ) );
+            return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+        }
+
+        MBEDTLS_SSL_DEBUG_MSG( 2, ( "ECDH curve: %s", curve_info->name ) );
+
         ret = ssl_tls13_read_public_ecdhe_share( ssl, p, end - p );
         if( ret != 0 )
             return( ret );
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index 8b0d93e..f006438 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -1612,6 +1612,7 @@
     size_t hash_len;
     const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
     uint16_t cipher_suite = ssl->session_negotiate->ciphersuite;
+    psa_status_t status = PSA_ERROR_GENERIC_ERROR;
     ciphersuite_info = mbedtls_ssl_ciphersuite_from_id( cipher_suite );
 
     MBEDTLS_SSL_DEBUG_MSG( 3, ( "Reset SSL session for HRR" ) );
@@ -1665,6 +1666,19 @@
 #if defined(MBEDTLS_SHA256_C) || defined(MBEDTLS_SHA384_C)
     ssl->handshake->update_checksum( ssl, hash_transcript, hash_len );
 #endif /* MBEDTLS_SHA256_C || MBEDTLS_SHA384_C */
+
+    /* Destroy generated private key. */
+    status = psa_destroy_key( ssl->handshake->ecdh_psa_privkey );
+
+    if( status != PSA_SUCCESS )
+    {
+        ret = psa_ssl_status_to_mbedtls( status );
+        MBEDTLS_SSL_DEBUG_RET( 1, "psa_destroy_key", ret );
+        return( ret );
+    }
+
+    ssl->handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+
     return( ret );
 }
 
diff --git a/library/ssl_tls13_keys.c b/library/ssl_tls13_keys.c
index 885dd16..10b3b7e 100644
--- a/library/ssl_tls13_keys.c
+++ b/library/ssl_tls13_keys.c
@@ -1244,10 +1244,11 @@
 int mbedtls_ssl_tls13_key_schedule_stage_handshake( mbedtls_ssl_context *ssl )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_ECDHE_ENABLED) && defined(MBEDTLS_ECDH_C)
+    psa_status_t status = PSA_ERROR_GENERIC_ERROR;
+#endif /* MBEDTLS_KEY_EXCHANGE_SOME_ECDHE_ENABLED && MBEDTLS_ECDH_C */
     mbedtls_ssl_handshake_params *handshake = ssl->handshake;
     mbedtls_md_type_t const md_type = handshake->ciphersuite_info->mac;
-    size_t ephemeral_len = 0;
-    unsigned char ecdhe[MBEDTLS_ECP_MAX_BYTES];
 #if defined(MBEDTLS_DEBUG_C)
     mbedtls_md_info_t const * const md_info = mbedtls_md_info_from_type( md_type );
     size_t const md_size = mbedtls_md_get_size( md_info );
@@ -1264,15 +1265,28 @@
         if( mbedtls_ssl_tls13_named_group_is_ecdhe( handshake->offered_group_id ) )
         {
 #if defined(MBEDTLS_ECDH_C)
-            ret = mbedtls_ecdh_calc_secret( &handshake->ecdh_ctx,
-                                            &ephemeral_len, ecdhe, sizeof( ecdhe ),
-                                            ssl->conf->f_rng,
-                                            ssl->conf->p_rng );
-            if( ret != 0 )
-            {
-                MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ecdh_calc_secret", ret );
-                return( ret );
-            }
+        /* Compute ECDH shared secret. */
+        status = psa_raw_key_agreement(
+                    PSA_ALG_ECDH, handshake->ecdh_psa_privkey,
+                    handshake->ecdh_psa_peerkey, handshake->ecdh_psa_peerkey_len,
+                    handshake->premaster, sizeof( handshake->premaster ),
+                    &handshake->pmslen );
+        if( status != PSA_SUCCESS )
+        {
+            ret = psa_ssl_status_to_mbedtls( status );
+            MBEDTLS_SSL_DEBUG_RET( 1, "psa_raw_key_agreement", ret );
+            return( ret );
+        }
+
+        status = psa_destroy_key( handshake->ecdh_psa_privkey );
+        if( status != PSA_SUCCESS )
+        {
+            ret = psa_ssl_status_to_mbedtls( status );
+            MBEDTLS_SSL_DEBUG_RET( 1, "psa_destroy_key", ret );
+            return( ret );
+        }
+
+        handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
 #endif /* MBEDTLS_ECDH_C */
         }
         else if( mbedtls_ssl_tls13_named_group_is_dhe( handshake->offered_group_id ) )
@@ -1290,7 +1304,7 @@
      */
     ret = mbedtls_ssl_tls13_evolve_secret( md_type,
                                            handshake->tls13_master_secrets.early,
-                                           ecdhe, ephemeral_len,
+                                           handshake->premaster, handshake->pmslen,
                                            handshake->tls13_master_secrets.handshake );
     if( ret != 0 )
     {
@@ -1302,7 +1316,7 @@
                            handshake->tls13_master_secrets.handshake, md_size );
 
 #if defined(MBEDTLS_KEY_EXCHANGE_SOME_ECDHE_ENABLED)
-    mbedtls_platform_zeroize( ecdhe, sizeof( ecdhe ) );
+    mbedtls_platform_zeroize( handshake->premaster, sizeof( handshake->premaster ) );
 #endif /* MBEDTLS_KEY_EXCHANGE_SOME_ECDHE_ENABLED */
     return( 0 );
 }
diff --git a/programs/ssl/ssl_client2.c b/programs/ssl/ssl_client2.c
index 39e89ec..f83af07 100644
--- a/programs/ssl/ssl_client2.c
+++ b/programs/ssl/ssl_client2.c
@@ -21,9 +21,9 @@
 
 #include "ssl_test_lib.h"
 
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#if defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3)
 #include "test/psa_crypto_helpers.h"
-#endif
+#endif /* MBEDTLS_USE_PSA_CRYPTO || MBEDTLS_SSL_PROTO_TLS1_3 */
 
 #if defined(MBEDTLS_SSL_TEST_IMPOSSIBLE)
 int main( void )
@@ -698,6 +698,8 @@
     psa_key_attributes_t key_attributes;
 #endif
     psa_status_t status;
+#elif defined(MBEDTLS_SSL_PROTO_TLS1_3)
+    psa_status_t status;
 #endif
 
 #if defined(MBEDTLS_X509_CRT_PARSE_C)
@@ -770,7 +772,7 @@
     memset( (void * ) alpn_list, 0, sizeof( alpn_list ) );
 #endif
 
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#if defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3)
     status = psa_crypto_init();
     if( status != PSA_SUCCESS )
     {
@@ -779,7 +781,7 @@
         ret = MBEDTLS_ERR_SSL_HW_ACCEL_FAILED;
         goto exit;
     }
-#endif  /* MBEDTLS_USE_PSA_CRYPTO */
+#endif  /* MBEDTLS_USE_PSA_CRYPTO || MBEDTLS_SSL_PROTO_TLS1_3 */
 #if defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG)
     mbedtls_test_enable_insecure_external_rng( );
 #endif  /* MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG */
@@ -3085,7 +3087,7 @@
 #endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED &&
           MBEDTLS_USE_PSA_CRYPTO */
 
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#if defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3)
     const char* message = mbedtls_test_helper_is_psa_leaking();
     if( message )
     {
@@ -3093,11 +3095,11 @@
             ret = 1;
         mbedtls_printf( "PSA memory leak detected: %s\n",  message);
     }
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
+#endif /* MBEDTLS_USE_PSA_CRYPTO || MBEDTLS_SSL_PROTO_TLS1_3 */
 
     /* For builds with MBEDTLS_TEST_USE_PSA_CRYPTO_RNG psa crypto
      * resources are freed by rng_free(). */
-#if defined(MBEDTLS_USE_PSA_CRYPTO) && \
+#if (defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3)) && \
     !defined(MBEDTLS_TEST_USE_PSA_CRYPTO_RNG)
     mbedtls_psa_crypto_free( );
 #endif