Merge pull request #4927 from yuhaoth/pr/add-tls13-serverhello-utils

TLS 1.3: ServerHello: add utility functions used by ServerHello
Regarding the merge job, there was only one failure, and it is the one we currently encounter on almost all PRs (the "Session resume using tickets, DTLS: openssl client" test case, see #5012); we can therefore consider that this PR passed CI.
diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h
index 2349245..c1b1836 100644
--- a/include/mbedtls/ssl.h
+++ b/include/mbedtls/ssl.h
@@ -593,6 +593,9 @@
 
 #define MBEDTLS_PREMASTER_SIZE     sizeof( union mbedtls_ssl_premaster_secret )
 
+/* Length in number of bytes of the TLS sequence number */
+#define MBEDTLS_SSL_SEQUENCE_NUMBER_LEN 8
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -1566,7 +1569,7 @@
     size_t MBEDTLS_PRIVATE(out_buf_len);         /*!< length of output buffer          */
 #endif
 
-    unsigned char MBEDTLS_PRIVATE(cur_out_ctr)[8]; /*!<  Outgoing record sequence  number. */
+    unsigned char MBEDTLS_PRIVATE(cur_out_ctr)[MBEDTLS_SSL_SEQUENCE_NUMBER_LEN]; /*!<  Outgoing record sequence  number. */
 
 #if defined(MBEDTLS_SSL_PROTO_DTLS)
     uint16_t MBEDTLS_PRIVATE(mtu);               /*!< path mtu, used to fragment outgoing messages */
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 3f3f505..7c44382 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -307,6 +307,10 @@
       + ( MBEDTLS_SSL_CID_OUT_LEN_MAX ) )
 #endif
 
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL)
+#define MBEDTLS_TLS1_3_MD_MAX_SIZE         MBEDTLS_MD_MAX_SIZE
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL */
+
 #if defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH)
 /**
  * \brief          Return the maximum fragment length (payload, in bytes) for
@@ -573,8 +577,8 @@
                                               flight being received          */
     mbedtls_ssl_transform *alt_transform_out;   /*!<  Alternative transform for
                                               resending messages             */
-    unsigned char alt_out_ctr[8];       /*!<  Alternative record epoch/counter
-                                              for resending messages         */
+    unsigned char alt_out_ctr[MBEDTLS_SSL_SEQUENCE_NUMBER_LEN]; /*!<  Alternative record epoch/counter
+                                                                      for resending messages         */
 
 #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
     /* The state of CID configuration in this handshake. */
@@ -675,6 +679,13 @@
     int extensions_present;             /*!< extension presence; Each bitfield
                                              represents an extension and defined
                                              as \c MBEDTLS_SSL_EXT_XXX */
+
+    union
+    {
+        unsigned char early    [MBEDTLS_TLS1_3_MD_MAX_SIZE];
+        unsigned char handshake[MBEDTLS_TLS1_3_MD_MAX_SIZE];
+        unsigned char app      [MBEDTLS_TLS1_3_MD_MAX_SIZE];
+    } tls1_3_master_secrets;
 #endif /* MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL */
 
 #if defined(MBEDTLS_SSL_SESSION_TICKETS)
@@ -866,14 +877,14 @@
 
 typedef struct
 {
-    uint8_t ctr[8];         /* In TLS:  The implicit record sequence number.
-                             * In DTLS: The 2-byte epoch followed by
-                             *          the 6-byte sequence number.
-                             * This is stored as a raw big endian byte array
-                             * as opposed to a uint64_t because we rarely
-                             * need to perform arithmetic on this, but do
-                             * need it as a Byte array for the purpose of
-                             * MAC computations.                             */
+    uint8_t ctr[MBEDTLS_SSL_SEQUENCE_NUMBER_LEN];  /* In TLS:  The implicit record sequence number.
+                                                    * In DTLS: The 2-byte epoch followed by
+                                                    *          the 6-byte sequence number.
+                                                    * This is stored as a raw big endian byte array
+                                                    * as opposed to a uint64_t because we rarely
+                                                    * need to perform arithmetic on this, but do
+                                                    * need it as a Byte array for the purpose of
+                                                    * MAC computations.                             */
     uint8_t type;           /* The record content type.                      */
     uint8_t ver[2];         /* SSL/TLS version as present on the wire.
                              * Convert to internal presentation of versions
@@ -956,6 +967,14 @@
  */
 void mbedtls_ssl_handshake_free( mbedtls_ssl_context *ssl );
 
+/* set inbound transform of ssl context */
+void mbedtls_ssl_set_inbound_transform( mbedtls_ssl_context *ssl,
+                                        mbedtls_ssl_transform *transform );
+
+/* set outbound transform of ssl context */
+void mbedtls_ssl_set_outbound_transform( mbedtls_ssl_context *ssl,
+                                         mbedtls_ssl_transform *transform );
+
 int mbedtls_ssl_handshake_client_step( mbedtls_ssl_context *ssl );
 int mbedtls_ssl_handshake_server_step( mbedtls_ssl_context *ssl );
 void mbedtls_ssl_handshake_wrapup( mbedtls_ssl_context *ssl );
@@ -1510,13 +1529,19 @@
 int mbedtls_ssl_tls13_finish_handshake_msg( mbedtls_ssl_context *ssl,
                                             size_t buf_len,
                                             size_t msg_len );
-/*
- * Update checksum with handshake header
- */
+
 void mbedtls_ssl_tls13_add_hs_hdr_to_checksum( mbedtls_ssl_context *ssl,
                                                unsigned hs_type,
                                                size_t total_hs_len );
 
+/*
+ * Update checksum of handshake messages.
+ */
+void mbedtls_ssl_tls1_3_add_hs_msg_to_checksum( mbedtls_ssl_context *ssl,
+                                                unsigned hs_type,
+                                                unsigned char const *msg,
+                                                size_t msg_len );
+
 #if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
 /*
  * Write TLS 1.3 Signature Algorithm extension
@@ -1530,4 +1555,11 @@
 
 #endif /* MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL */
 
+/* Get handshake transcript */
+int mbedtls_ssl_get_handshake_transcript( mbedtls_ssl_context *ssl,
+                                          const mbedtls_md_type_t md,
+                                          unsigned char *dst,
+                                          size_t dst_len,
+                                          size_t *olen );
+
 #endif /* ssl_misc.h */
diff --git a/library/ssl_msg.c b/library/ssl_msg.c
index 3bf4a60..3795c65 100644
--- a/library/ssl_msg.c
+++ b/library/ssl_msg.c
@@ -2101,7 +2101,7 @@
 static int ssl_swap_epochs( mbedtls_ssl_context *ssl )
 {
     mbedtls_ssl_transform *tmp_transform;
-    unsigned char tmp_out_ctr[8];
+    unsigned char tmp_out_ctr[MBEDTLS_SSL_SEQUENCE_NUMBER_LEN];
 
     if( ssl->transform_out == ssl->handshake->alt_transform_out )
     {
@@ -2117,9 +2117,11 @@
     ssl->handshake->alt_transform_out = tmp_transform;
 
     /* Swap epoch + sequence_number */
-    memcpy( tmp_out_ctr,                 ssl->cur_out_ctr,            8 );
-    memcpy( ssl->cur_out_ctr,            ssl->handshake->alt_out_ctr, 8 );
-    memcpy( ssl->handshake->alt_out_ctr, tmp_out_ctr,                 8 );
+    memcpy( tmp_out_ctr, ssl->cur_out_ctr, sizeof( tmp_out_ctr ) );
+    memcpy( ssl->cur_out_ctr, ssl->handshake->alt_out_ctr,
+            sizeof( ssl->cur_out_ctr ) );
+    memcpy( ssl->handshake->alt_out_ctr, tmp_out_ctr,
+            sizeof( ssl->handshake->alt_out_ctr ) );
 
     /* Adjust to the newly activated transform */
     mbedtls_ssl_update_out_pointers( ssl, ssl->transform_out );
@@ -2562,7 +2564,7 @@
         mbedtls_ssl_write_version( ssl->major_ver, ssl->minor_ver,
                            ssl->conf->transport, ssl->out_hdr + 1 );
 
-        memcpy( ssl->out_ctr, ssl->cur_out_ctr, 8 );
+        memcpy( ssl->out_ctr, ssl->cur_out_ctr, MBEDTLS_SSL_SEQUENCE_NUMBER_LEN );
         MBEDTLS_PUT_UINT16_BE( len, ssl->out_len, 0);
 
         if( ssl->transform_out != NULL )
@@ -2574,7 +2576,7 @@
             rec.data_len    = ssl->out_msglen;
             rec.data_offset = ssl->out_msg - rec.buf;
 
-            memcpy( &rec.ctr[0], ssl->out_ctr, 8 );
+            memcpy( &rec.ctr[0], ssl->out_ctr, sizeof( rec.ctr ) );
             mbedtls_ssl_write_version( ssl->major_ver, ssl->minor_ver,
                                        ssl->conf->transport, rec.ver );
             rec.type = ssl->out_msgtype;
@@ -3649,9 +3651,12 @@
 #endif
         {
             unsigned i;
-            for( i = 8; i > mbedtls_ssl_ep_len( ssl ); i-- )
+            for( i = MBEDTLS_SSL_SEQUENCE_NUMBER_LEN;
+                 i > mbedtls_ssl_ep_len( ssl ); i-- )
+            {
                 if( ++ssl->in_ctr[i - 1] != 0 )
                     break;
+            }
 
             /* The loop goes to its end iff the counter is wrapping */
             if( i == mbedtls_ssl_ep_len( ssl ) )
@@ -4791,7 +4796,7 @@
     }
     else
 #endif /* MBEDTLS_SSL_PROTO_DTLS */
-    memset( ssl->in_ctr, 0, 8 );
+    memset( ssl->in_ctr, 0, MBEDTLS_SSL_SEQUENCE_NUMBER_LEN );
 
     mbedtls_ssl_update_in_pointers( ssl );
 
@@ -4827,12 +4832,12 @@
     {
         ssl->out_ctr = ssl->out_hdr +  3;
 #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
-        ssl->out_cid = ssl->out_ctr +  8;
+        ssl->out_cid = ssl->out_ctr + MBEDTLS_SSL_SEQUENCE_NUMBER_LEN;
         ssl->out_len = ssl->out_cid;
         if( transform != NULL )
             ssl->out_len += transform->out_cid_len;
 #else /* MBEDTLS_SSL_DTLS_CONNECTION_ID */
-        ssl->out_len = ssl->out_ctr + 8;
+        ssl->out_len = ssl->out_ctr + MBEDTLS_SSL_SEQUENCE_NUMBER_LEN;
 #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */
         ssl->out_iv  = ssl->out_len + 2;
     }
@@ -4881,17 +4886,17 @@
          * ssl_parse_record_header(). */
         ssl->in_ctr = ssl->in_hdr +  3;
 #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
-        ssl->in_cid = ssl->in_ctr +  8;
+        ssl->in_cid = ssl->in_ctr + MBEDTLS_SSL_SEQUENCE_NUMBER_LEN;
         ssl->in_len = ssl->in_cid; /* Default: no CID */
 #else /* MBEDTLS_SSL_DTLS_CONNECTION_ID */
-        ssl->in_len = ssl->in_ctr + 8;
+        ssl->in_len = ssl->in_ctr + MBEDTLS_SSL_SEQUENCE_NUMBER_LEN;
 #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */
         ssl->in_iv  = ssl->in_len + 2;
     }
     else
 #endif
     {
-        ssl->in_ctr = ssl->in_hdr - 8;
+        ssl->in_ctr = ssl->in_hdr - MBEDTLS_SSL_SEQUENCE_NUMBER_LEN;
         ssl->in_len = ssl->in_hdr + 3;
 #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
         ssl->in_cid = ssl->in_len;
@@ -5065,9 +5070,11 @@
     }
 
     in_ctr_cmp = memcmp( ssl->in_ctr + ep_len,
-                        ssl->conf->renego_period + ep_len, 8 - ep_len );
-    out_ctr_cmp = memcmp( ssl->cur_out_ctr + ep_len,
-                          ssl->conf->renego_period + ep_len, 8 - ep_len );
+                         &ssl->conf->renego_period[ep_len],
+                         MBEDTLS_SSL_SEQUENCE_NUMBER_LEN - ep_len );
+    out_ctr_cmp = memcmp( &ssl->cur_out_ctr[ep_len],
+                          &ssl->conf->renego_period[ep_len],
+                          sizeof( ssl->cur_out_ctr ) - ep_len );
 
     if( in_ctr_cmp <= 0 && out_ctr_cmp <= 0 )
     {
@@ -5551,6 +5558,20 @@
     mbedtls_platform_zeroize( transform, sizeof( mbedtls_ssl_transform ) );
 }
 
+void mbedtls_ssl_set_inbound_transform( mbedtls_ssl_context *ssl,
+                                        mbedtls_ssl_transform *transform )
+{
+    ssl->transform_in = transform;
+    memset( ssl->in_ctr, 0, MBEDTLS_SSL_SEQUENCE_NUMBER_LEN );
+}
+
+void mbedtls_ssl_set_outbound_transform( mbedtls_ssl_context *ssl,
+                                         mbedtls_ssl_transform *transform )
+{
+    ssl->transform_out = transform;
+    memset( ssl->cur_out_ctr, 0, sizeof( ssl->cur_out_ctr ) );
+}
+
 #if defined(MBEDTLS_SSL_PROTO_DTLS)
 
 void mbedtls_ssl_buffering_free( mbedtls_ssl_context *ssl )
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index b8c4314..e27fdff 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -1220,7 +1220,8 @@
             return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
         }
 
-        memcpy( ssl->cur_out_ctr + 2, ssl->in_ctr + 2, 6 );
+        memcpy( &ssl->cur_out_ctr[2], ssl->in_ctr + 2,
+                sizeof( ssl->cur_out_ctr ) - 2 );
 
 #if defined(MBEDTLS_SSL_DTLS_ANTI_REPLAY)
         if( mbedtls_ssl_dtls_replay_check( ssl ) != 0 )
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index f33f106..4d305ee 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -2820,10 +2820,12 @@
 
         /* Remember current epoch settings for resending */
         ssl->handshake->alt_transform_out = ssl->transform_out;
-        memcpy( ssl->handshake->alt_out_ctr, ssl->cur_out_ctr, 8 );
+        memcpy( ssl->handshake->alt_out_ctr, ssl->cur_out_ctr,
+                sizeof( ssl->handshake->alt_out_ctr ) );
 
         /* Set sequence_number to zero */
-        memset( ssl->cur_out_ctr + 2, 0, 6 );
+        memset( &ssl->cur_out_ctr[2], 0, sizeof( ssl->cur_out_ctr ) - 2 );
+
 
         /* Increment epoch */
         for( i = 2; i > 0; i-- )
@@ -2839,7 +2841,7 @@
     }
     else
 #endif /* MBEDTLS_SSL_PROTO_DTLS */
-    memset( ssl->cur_out_ctr, 0, 8 );
+    memset( ssl->cur_out_ctr, 0, sizeof( ssl->cur_out_ctr ) );
 
     ssl->transform_out = ssl->transform_negotiate;
     ssl->session_out = ssl->session_negotiate;
@@ -5792,11 +5794,11 @@
     }
 #endif /* MBEDTLS_SSL_PROTO_DTLS */
 
-    used += 8;
+    used += MBEDTLS_SSL_SEQUENCE_NUMBER_LEN;
     if( used <= buf_len )
     {
-        memcpy( p, ssl->cur_out_ctr, 8 );
-        p += 8;
+        memcpy( p, ssl->cur_out_ctr, MBEDTLS_SSL_SEQUENCE_NUMBER_LEN );
+        p += MBEDTLS_SSL_SEQUENCE_NUMBER_LEN;
     }
 
 #if defined(MBEDTLS_SSL_PROTO_DTLS)
@@ -6052,11 +6054,10 @@
     ssl->disable_datagram_packing = *p++;
 #endif /* MBEDTLS_SSL_PROTO_DTLS */
 
-    if( (size_t)( end - p ) < 8 )
+    if( (size_t)( end - p ) < sizeof( ssl->cur_out_ctr ) )
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
-
-    memcpy( ssl->cur_out_ctr, p, 8 );
-    p += 8;
+    memcpy( ssl->cur_out_ctr, p, sizeof( ssl->cur_out_ctr ) );
+    p += sizeof( ssl->cur_out_ctr );
 
 #if defined(MBEDTLS_SSL_PROTO_DTLS)
     if( (size_t)( end - p ) < 2 )
@@ -6995,4 +6996,106 @@
 
 #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+int mbedtls_ssl_get_handshake_transcript( mbedtls_ssl_context *ssl,
+                                          const mbedtls_md_type_t md,
+                                          unsigned char *dst,
+                                          size_t dst_len,
+                                          size_t *olen )
+{
+    ((void) ssl);
+    ((void) md);
+    ((void) dst);
+    ((void) dst_len);
+    *olen = 0;
+    return( MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE);
+}
+#else /* MBEDTLS_USE_PSA_CRYPTO */
+
+#if defined(MBEDTLS_SHA384_C)
+static int ssl_get_handshake_transcript_sha384( mbedtls_ssl_context *ssl,
+                                                unsigned char *dst,
+                                                size_t dst_len,
+                                                size_t *olen )
+{
+    int ret;
+    mbedtls_sha512_context sha512;
+
+    if( dst_len < 48 )
+        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+
+    mbedtls_sha512_init( &sha512 );
+    mbedtls_sha512_clone( &sha512, &ssl->handshake->fin_sha512 );
+
+    if( ( ret = mbedtls_sha512_finish( &sha512, dst ) ) != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_sha512_finish", ret );
+        goto exit;
+    }
+
+    *olen = 48;
+
+exit:
+
+    mbedtls_sha512_free( &sha512 );
+    return( ret );
+}
+#endif /* MBEDTLS_SHA384_C */
+
+#if defined(MBEDTLS_SHA256_C)
+static int ssl_get_handshake_transcript_sha256( mbedtls_ssl_context *ssl,
+                                                unsigned char *dst,
+                                                size_t dst_len,
+                                                size_t *olen )
+{
+    int ret;
+    mbedtls_sha256_context sha256;
+
+    if( dst_len < 32 )
+        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+
+    mbedtls_sha256_init( &sha256 );
+    mbedtls_sha256_clone( &sha256, &ssl->handshake->fin_sha256 );
+
+    if( ( ret = mbedtls_sha256_finish( &sha256, dst ) ) != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_sha256_finish", ret );
+        goto exit;
+    }
+
+    *olen = 32;
+
+exit:
+
+    mbedtls_sha256_free( &sha256 );
+    return( ret );
+}
+#endif /* MBEDTLS_SHA256_C */
+
+int mbedtls_ssl_get_handshake_transcript( mbedtls_ssl_context *ssl,
+                                          const mbedtls_md_type_t md,
+                                          unsigned char *dst,
+                                          size_t dst_len,
+                                          size_t *olen )
+{
+    switch( md )
+    {
+
+#if defined(MBEDTLS_SHA384_C)
+    case MBEDTLS_MD_SHA384:
+        return( ssl_get_handshake_transcript_sha384( ssl, dst, dst_len, olen ) );
+#endif /* MBEDTLS_SHA384_C */
+
+#if defined(MBEDTLS_SHA256_C)
+    case MBEDTLS_MD_SHA256:
+        return( ssl_get_handshake_transcript_sha256( ssl, dst, dst_len, olen ) );
+#endif /* MBEDTLS_SHA256_C */
+
+    default:
+        break;
+    }
+    return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+}
+#endif /* !MBEDTLS_USE_PSA_CRYPTO */
+
 #endif /* MBEDTLS_SSL_TLS_C */
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index 99ab269..b3a4a09 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -104,6 +104,15 @@
     return( ret );
 }
 
+void mbedtls_ssl_tls1_3_add_hs_msg_to_checksum( mbedtls_ssl_context *ssl,
+                                                unsigned hs_type,
+                                                unsigned char const *msg,
+                                                size_t msg_len )
+{
+    mbedtls_ssl_tls13_add_hs_hdr_to_checksum( ssl, hs_type, msg_len );
+    ssl->handshake->update_checksum( ssl, msg, msg_len );
+}
+
 void mbedtls_ssl_tls13_add_hs_hdr_to_checksum( mbedtls_ssl_context *ssl,
                                                unsigned hs_type,
                                                size_t total_hs_len )
diff --git a/library/ssl_tls13_keys.c b/library/ssl_tls13_keys.c
index 7aec21d..b07c1c3 100644
--- a/library/ssl_tls13_keys.c
+++ b/library/ssl_tls13_keys.c
@@ -21,14 +21,16 @@
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL)
 
-#include "mbedtls/hkdf.h"
-#include "ssl_misc.h"
-#include "ssl_tls13_keys.h"
-#include "mbedtls/debug.h"
-
 #include <stdint.h>
 #include <string.h>
 
+#include "mbedtls/hkdf.h"
+#include "mbedtls/debug.h"
+#include "mbedtls/error.h"
+
+#include "ssl_misc.h"
+#include "ssl_tls13_keys.h"
+
 #define MBEDTLS_SSL_TLS1_3_LABEL( name, string )       \
     .name = string,
 
@@ -820,4 +822,28 @@
     return( 0 );
 }
 
+int mbedtls_ssl_tls1_3_key_schedule_stage_early( mbedtls_ssl_context *ssl )
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    mbedtls_md_type_t md_type;
+
+    if( ssl->handshake->ciphersuite_info == NULL )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "cipher suite info not found" ) );
+        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+    }
+
+    md_type = ssl->handshake->ciphersuite_info->mac;
+
+    ret = mbedtls_ssl_tls1_3_evolve_secret( md_type, NULL, NULL, 0,
+                                ssl->handshake->tls1_3_master_secrets.early );
+    if( ret != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_tls1_3_evolve_secret", ret );
+        return( ret );
+    }
+
+    return( 0 );
+}
+
 #endif /* MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL */
diff --git a/library/ssl_tls13_keys.h b/library/ssl_tls13_keys.h
index ca892b1..866aae9 100644
--- a/library/ssl_tls13_keys.h
+++ b/library/ssl_tls13_keys.h
@@ -531,4 +531,26 @@
                                           mbedtls_ssl_key_set const *traffic_keys,
                                           mbedtls_ssl_context *ssl );
 
+/*
+ * TLS 1.3 key schedule evolutions
+ *
+ *   Early -> Handshake -> Application
+ *
+ * Small wrappers around mbedtls_ssl_tls1_3_evolve_secret().
+ */
+
+/**
+ * \brief Begin TLS 1.3 key schedule by calculating early secret.
+ *
+ *        The TLS 1.3 key schedule can be viewed as a simple state machine
+ *        with states Initial -> Early -> Handshake -> Application, and
+ *        this function represents the Initial -> Early transition.
+ *
+ * \param ssl  The SSL context to operate on.
+ *
+ * \returns    \c 0 on success.
+ * \returns    A negative error code on failure.
+ */
+int mbedtls_ssl_tls1_3_key_schedule_stage_early( mbedtls_ssl_context *ssl );
+
 #endif /* MBEDTLS_SSL_TLS1_3_KEYS_H */