Share the ServerHello write_body function between HRR and ServerHello

Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 13fcb65..528409b 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -917,6 +917,72 @@
     return( 0 );
 }
 
+/*
+ * Write the key_share extension of a HelloRetryRequest into [buf, end).
+ *
+ *  struct {
+ *    select (Handshake.msg_type) {
+ *      ...
+ *      case hello_retry_request:
+ *          NamedGroup selected_group;
+ *      ...
+ *    };
+ *  } KeyShare;
+ *
+ * On success *out_len is the number of bytes written: 6, or 0 when the
+ * extension is skipped. Returns 0 on success, or an error code otherwise.
+ */
+static int ssl_tls13_write_hrr_key_share_ext( mbedtls_ssl_context *ssl,
+                                              unsigned char *buf,
+                                              unsigned char *end,
+                                              size_t *out_len )
+{
+    uint16_t selected_group = ssl->handshake->hrr_selected_group;
+
+    *out_len = 0;
+
+    /* For a pure PSK-based ciphersuite there is no key share to declare. */
+    if( ! mbedtls_ssl_conf_tls13_some_ephemeral_enabled( ssl ) )
+        return( 0 );
+
+    /* We should only send the key_share extension if the client's initial
+     * key share was not acceptable. */
+    if( ssl->handshake->offered_group_id != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 4, ( "Skip key_share extension in HRR" ) );
+        return( 0 );
+    }
+
+    if( selected_group == 0 )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "no matching named group found" ) );
+        return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+    }
+
+    if( ! mbedtls_ssl_named_group_is_offered( ssl, selected_group ) ||
+        ! mbedtls_ssl_named_group_is_supported( selected_group ) )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 4, ( "should never happen" ) );
+        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+    }
+
+    /* Room for extension_type, extension_data_length and selected_group
+     * (2 bytes each). buf is an output buffer, so use the write-side check
+     * rather than the READ variant. */
+    MBEDTLS_SSL_CHK_BUF_PTR( buf, end, 6 );
+
+    MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_KEY_SHARE, buf, 0 );
+    MBEDTLS_PUT_UINT16_BE( 2, buf, 2 );
+    MBEDTLS_PUT_UINT16_BE( selected_group, buf, 4 );
+
+    MBEDTLS_SSL_DEBUG_MSG( 3,
+        ( "HRR selected_group: %s (%x)",
+            mbedtls_ssl_named_group_to_str( selected_group ),
+            (unsigned int) selected_group ) );
+
+    *out_len = 6;
+    return( 0 );
+}
 
 /*
  * Structure of ServerHello message:
@@ -933,7 +999,8 @@
 static int ssl_tls13_write_server_hello_body( mbedtls_ssl_context *ssl,
                                               unsigned char *buf,
                                               unsigned char *end,
-                                              size_t *out_len )
+                                              size_t *out_len,
+                                              int is_hrr )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char *p = buf;
@@ -959,8 +1026,16 @@
      * opaque Random[MBEDTLS_SERVER_HELLO_RANDOM_LEN];
      */
     MBEDTLS_SSL_CHK_BUF_PTR( p, end, MBEDTLS_SERVER_HELLO_RANDOM_LEN );
-    memcpy( p, &ssl->handshake->randbytes[MBEDTLS_CLIENT_HELLO_RANDOM_LEN],
-               MBEDTLS_SERVER_HELLO_RANDOM_LEN );
+    if( is_hrr )
+    {
+        memcpy( p, mbedtls_ssl_tls13_hello_retry_request_magic,
+                   MBEDTLS_SERVER_HELLO_RANDOM_LEN );
+    }
+    else
+    {
+        memcpy( p, &ssl->handshake->randbytes[MBEDTLS_CLIENT_HELLO_RANDOM_LEN],
+                   MBEDTLS_SERVER_HELLO_RANDOM_LEN );
+    }
     MBEDTLS_SSL_DEBUG_BUF( 3, "server hello, random bytes",
                            p, MBEDTLS_SERVER_HELLO_RANDOM_LEN );
     p += MBEDTLS_SERVER_HELLO_RANDOM_LEN;
@@ -1026,7 +1101,10 @@
 
     if( mbedtls_ssl_conf_tls13_some_ephemeral_enabled( ssl ) )
     {
-        ret = ssl_tls13_write_key_share_ext( ssl, p, end, &output_len );
+        if( is_hrr )
+            ret = ssl_tls13_write_hrr_key_share_ext( ssl, p, end, &output_len );
+        else
+            ret = ssl_tls13_write_key_share_ext( ssl, p, end, &output_len );
         if( ret != 0 )
             return( ret );
         p += output_len;
@@ -1079,7 +1157,8 @@
 
     MBEDTLS_SSL_PROC_CHK( ssl_tls13_write_server_hello_body( ssl, buf,
                                                              buf + buf_len,
-                                                             &msg_len ) );
+                                                             &msg_len,
+                                                             0 ) );
 
     mbedtls_ssl_add_hs_msg_to_checksum(
         ssl, MBEDTLS_SSL_HS_SERVER_HELLO, buf, msg_len );
@@ -1319,213 +1398,23 @@
     return( 0 );
 }
 
-#if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
-static int ssl_tls13_write_hrr_key_share_ext( mbedtls_ssl_context *ssl,
-                                              unsigned char *buf,
-                                              unsigned char *end,
-                                              size_t *out_len )
-{
-    uint16_t selected_group = ssl->handshake->hrr_selected_group;
-    /* key_share Extension
-     *
-     *  struct {
-     *    select (Handshake.msg_type) {
-     *      ...
-     *      case hello_retry_request:
-     *          NamedGroup selected_group;
-     *      ...
-     *    };
-     * } KeyShare;
-     */
-
-    *out_len = 0;
-
-    /* For a pure PSK-based ciphersuite there is no key share to declare. */
-    if( ! mbedtls_ssl_conf_tls13_some_ephemeral_enabled( ssl ) )
-        return( 0 );
-
-    /* We should only send the key_share extension if the client's initial
-     * key share was not acceptable. */
-    if( ssl->handshake->offered_group_id != 0 )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 4, ( "Skip key_share extension in HRR" ) );
-        return( 0 );
-    }
-
-    if( selected_group == 0 )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "no matching named group found" ) );
-        return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-    }
-
-    if( ! mbedtls_ssl_named_group_is_offered( ssl, selected_group ) ||
-        ! mbedtls_ssl_named_group_is_supported( selected_group ) )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 4, ( "should never happen" ) );
-        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
-    }
-
-    /* extension header, extension length, NamedGroup value */
-    MBEDTLS_SSL_CHK_BUF_READ_PTR( buf, end, 6 );
-
-    /* Write extension header */
-    MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_KEY_SHARE, buf, 0 );
-
-    /* Write extension length */
-    MBEDTLS_PUT_UINT16_BE( 2, buf, 2 );
-
-    /* Write selected group */
-    MBEDTLS_PUT_UINT16_BE( selected_group, buf, 4 );
-
-    MBEDTLS_SSL_DEBUG_MSG( 3,
-        ( "HRR selected_group: %s (%x)",
-            mbedtls_ssl_named_group_to_str( selected_group ),
-            selected_group ) );
-
-    *out_len = 6;
-    return( 0 );
-
-}
-#endif /* MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED */
-
-static int ssl_tls13_write_hello_retry_request_body( mbedtls_ssl_context *ssl,
-                                                     unsigned char *buf,
-                                                     unsigned char *end,
-                                                     size_t *out_len )
-{
-    int ret;
-    unsigned char *p = buf;
-    unsigned char *start = buf;
-    size_t output_len;
-    unsigned char *extension_start;
-
-    MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write hello retry request" ) );
-
-    *out_len = 0;
-
-    /*
-     * struct {
-     *    ProtocolVersion legacy_version = 0x0303;
-     *    Random random ( with magic value );
-     *    opaque legacy_session_id_echo<0..32>;
-     *    CipherSuite cipher_suite;
-     *    uint8 legacy_compression_method = 0;
-     *    Extension extensions<0..2^16-1>;
-     * } ServerHello; --- aka HelloRetryRequest
-     */
-
-
-    /*
-     * Write legacy_version
-     *    ProtocolVersion legacy_version = 0x0303;    // TLS v1.2
-     *
-     *  For TLS 1.3 we use the legacy version number {0x03, 0x03}
-     *  instead of the true version number.
-     */
-    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 2 );
-    MBEDTLS_PUT_UINT16_BE( 0x0303, p, 0 );
-    p += 2;
-
-    /* write magic string (as a replacement for the random value) */
-    MBEDTLS_SSL_CHK_BUF_PTR( p, end, MBEDTLS_SERVER_HELLO_RANDOM_LEN );
-    memcpy( p, mbedtls_ssl_tls13_hello_retry_request_magic,
-            MBEDTLS_SERVER_HELLO_RANDOM_LEN );
-    MBEDTLS_SSL_DEBUG_BUF( 3, "client hello, random bytes",
-                           p, MBEDTLS_SERVER_HELLO_RANDOM_LEN );
-    p += MBEDTLS_SERVER_HELLO_RANDOM_LEN;
-
-    /*
-     * Write legacy_session_id_echo
-     */
-    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 1 + ssl->session_negotiate->id_len );
-    *p++ = (unsigned char)ssl->session_negotiate->id_len;
-    if( ssl->session_negotiate->id_len > 0 )
-    {
-        memcpy( p, &ssl->session_negotiate->id[0],
-                ssl->session_negotiate->id_len );
-        p += ssl->session_negotiate->id_len;
-        MBEDTLS_SSL_DEBUG_MSG( 3, ( "session id length ( %"
-                                        MBEDTLS_PRINTF_SIZET " )",
-                                    ssl->session_negotiate->id_len ) );
-        MBEDTLS_SSL_DEBUG_BUF( 3, "session id", ssl->session_negotiate->id,
-                               ssl->session_negotiate->id_len );
-    }
-
-    /*
-     * Write ciphersuite
-     */
-    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 2 );
-    MBEDTLS_PUT_UINT16_BE( ssl->session_negotiate->ciphersuite, p, 0 );
-    p += 2;
-    MBEDTLS_SSL_DEBUG_MSG( 3,
-        ( "server hello, chosen ciphersuite: %s ( id=%d )",
-          mbedtls_ssl_get_ciphersuite_name(
-            ssl->session_negotiate->ciphersuite ),
-          ssl->session_negotiate->ciphersuite ) );
-
-    /* write legacy_compression_method = ( 0 ) */
-    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 1 );
-    *p++ = 0x0;
-
-    /* Extensions */
-    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 2 );
-    extension_start = p;
-    p += 2;
-
-    /* Add supported_version extension */
-    if( ( ret = ssl_tls13_write_server_hello_supported_versions_ext(
-                                            ssl, p, end, &output_len ) ) != 0 )
-    {
-        MBEDTLS_SSL_DEBUG_RET( 1, "ssl_tls13_write_selected_version_ext",
-                               ret );
-        return( ret );
-    }
-    p += output_len;
-
-#if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
-    /* Add key_share extension, if necessary */
-    if( mbedtls_ssl_conf_tls13_some_ephemeral_enabled( ssl ) )
-    {
-        ret = ssl_tls13_write_hrr_key_share_ext( ssl, p, end, &output_len );
-        if( ret != 0 )
-        {
-            MBEDTLS_SSL_DEBUG_RET( 1, "ssl_tls13_write_hrr_key_share_ext", ret );
-            return( ret );
-        }
-        p += output_len;
-    }
-#endif /* MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED */
-
-    /* Write length information */
-    MBEDTLS_PUT_UINT16_BE( p - extension_start - 2, extension_start, 0 );
-
-    MBEDTLS_SSL_DEBUG_BUF( 4, "hello retry request extensions",
-                           extension_start, p - extension_start );
-
-    *out_len = p - start;
-
-    MBEDTLS_SSL_DEBUG_BUF( 3, "hello retry request", start, *out_len );
-
-    *out_len = p - buf;
-
-    MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= write hello retry request" ) );
-    return( 0 );
-}
-
 static int ssl_tls13_write_hello_retry_request( mbedtls_ssl_context *ssl )
 {
     int ret;
     unsigned char *buf;
     size_t buf_len, msg_len;
 
+    MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write hello retry request" ) );
+
     MBEDTLS_SSL_PROC_CHK( ssl_tls13_write_hello_retry_request_coordinate( ssl ) );
 
     MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_start_handshake_msg( ssl,
                        MBEDTLS_SSL_HS_SERVER_HELLO, &buf, &buf_len ) );
 
-    MBEDTLS_SSL_PROC_CHK( ssl_tls13_write_hello_retry_request_body(
-                              ssl, buf, buf + buf_len, &msg_len ) );
-
+    MBEDTLS_SSL_PROC_CHK( ssl_tls13_write_server_hello_body( ssl, buf,
+                                                             buf + buf_len,
+                                                             &msg_len,
+                                                             1 ) );
     mbedtls_ssl_add_hs_msg_to_checksum(
         ssl, MBEDTLS_SSL_HS_SERVER_HELLO, buf, msg_len );
 
@@ -1538,7 +1427,7 @@
     mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_HELLO );
 
 cleanup:
-
+    MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= write hello retry request" ) );
     return( ret );
 }