Refactor: merge ServerHello and HelloRetryRequest parsing, based on review comments

Signed-off-by: XiaokangQian <xiaokang.qian@arm.com>
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index fbdb671..f3126d2 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -421,6 +421,69 @@
 }
 #endif /* MBEDTLS_ECDH_C */
 
+/*
+ * ssl_tls13_hrr_check_key_share_ext()
+ *      Check the key_share extension of a HelloRetryRequest: read the
+ *      2-byte selected_group, validate it, and store it in
+ *      ssl->handshake->offered_group_id for the next ClientHello.
+ *      Returns 0 on success, MBEDTLS_ERR_SSL_BAD_CONFIG if no group list
+ *      is available, MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER if the group is
+ *      rejected, or a decoding error if the extension is truncated.
+ */
+static int ssl_tls13_hrr_check_key_share_ext( mbedtls_ssl_context *ssl,
+                                              const unsigned char *buf,
+                                              const unsigned char *end )
+{
+    const mbedtls_ecp_curve_info *curve_info = NULL;
+    const unsigned char *p = buf;
+    int tls_id;
+    int found = 0;
+
+    const uint16_t *group_list = mbedtls_ssl_get_groups( ssl );
+    if( group_list == NULL )
+        return( MBEDTLS_ERR_SSL_BAD_CONFIG );
+
+    MBEDTLS_SSL_DEBUG_BUF( 3, "key_share extension", p, end - buf );
+
+    /* Read selected_group; reject a truncated extension before reading. */
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 2 );
+    tls_id = MBEDTLS_GET_UINT16_BE( p, 0 );
+    MBEDTLS_SSL_DEBUG_MSG( 3, ( "selected_group ( %d )", tls_id ) );
+
+    /* Upon receipt of this extension in a HelloRetryRequest, the client
+     * MUST first verify that the selected_group field corresponds to a
+     * group which was provided in the "supported_groups" extension in the
+     * original ClientHello; here that list is mbedtls_ssl_get_groups().
+     * Otherwise the client MUST abort with an "illegal_parameter" alert.
+     */
+    for( ; *group_list != 0; group_list++ )
+    {
+        curve_info = mbedtls_ecp_curve_info_from_tls_id( *group_list );
+        if( curve_info != NULL && curve_info->tls_id == tls_id )
+        {
+            found = 1;
+            break;
+        }
+    }
+
+    /* The client MUST also verify that selected_group does not correspond
+     * to the group whose key share was already sent in the original
+     * ClientHello; if it does, abort with an "illegal_parameter" alert.
+     */
+    if( found == 0 || tls_id == ssl->handshake->offered_group_id )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "Invalid key share in HRR" ) );
+        MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
+                                      MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
+        return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
+    }
+
+    /* Remember server's preference for next ClientHello */
+    ssl->handshake->offered_group_id = tls_id;
+
+    return( 0 );
+}
+
 /*
  * ssl_tls13_parse_key_share_ext()
  *      Parse key_share extension in Server Hello
@@ -943,7 +1006,8 @@
  */
 static int ssl_tls13_parse_server_hello( mbedtls_ssl_context *ssl,
                                          const unsigned char *buf,
-                                         const unsigned char *end )
+                                         const unsigned char *end,
+                                         int hrr )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     const unsigned char *p = buf;
@@ -951,6 +1015,10 @@
     const unsigned char *extensions_end;
     uint16_t cipher_suite;
     const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
+#if defined(MBEDTLS_SSL_COOKIE_C)
+    size_t cookie_len;
+    unsigned char *cookie;
+#endif /* MBEDTLS_SSL_COOKIE_C */
 
     /*
      * Check there is space for minimal fields
@@ -1093,6 +1161,32 @@
 
         switch( extension_type )
         {
+#if defined(MBEDTLS_SSL_COOKIE_C)
+            case MBEDTLS_TLS_EXT_COOKIE:
+
+                /* Retrieve length field of cookie */
+                MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_end, 2 );
+                cookie_len = MBEDTLS_GET_UINT16_BE( p, 0 );
+                cookie = (unsigned char *) ( p + 2 );
+
+                MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_end, cookie_len + 2 );
+                MBEDTLS_SSL_DEBUG_BUF( 3, "cookie extension", cookie, cookie_len );
+
+                mbedtls_free( ssl->handshake->verify_cookie );
+                ssl->handshake->verify_cookie = mbedtls_calloc( 1, cookie_len );
+                if( ssl->handshake->verify_cookie == NULL )
+                {
+                    MBEDTLS_SSL_DEBUG_MSG( 1,
+                            ( "alloc failed ( %" MBEDTLS_PRINTF_SIZET " bytes )",
+                              cookie_len ) );
+                    return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
+                }
+
+                memcpy( ssl->handshake->verify_cookie, cookie, cookie_len );
+                ssl->handshake->verify_cookie_len = (unsigned char) cookie_len;
+                break;
+#endif /* MBEDTLS_SSL_COOKIE_C */
+
             case MBEDTLS_TLS_EXT_SUPPORTED_VERSIONS:
                 MBEDTLS_SSL_DEBUG_MSG( 3,
                             ( "found supported_versions extension" ) );
@@ -1116,8 +1210,13 @@
 #if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
             case MBEDTLS_TLS_EXT_KEY_SHARE:
                 MBEDTLS_SSL_DEBUG_MSG( 3, ( "found key_shares extension" ) );
-                if( ( ret = ssl_tls13_parse_key_share_ext( ssl,
-                                            p, p + extension_data_len ) ) != 0 )
+                if( hrr )
+                    ret = ssl_tls13_hrr_check_key_share_ext( ssl,
+                                            p, p + extension_data_len );
+                else
+                    ret = ssl_tls13_parse_key_share_ext( ssl,
+                                            p, p + extension_data_len );
+                if( ret != 0 )
                 {
                     MBEDTLS_SSL_DEBUG_RET( 1,
                                            "ssl_tls13_parse_key_share_ext",
@@ -1259,268 +1358,6 @@
     return( ret );
 }
 
-static int ssl_hrr_parse( mbedtls_ssl_context *ssl,
-                          const unsigned char *buf,
-                          const unsigned char *end )
-{
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    int cipher_suite;
-
-    /* pointer to the end of the buffer for length checks */
-    const unsigned char *p = buf;
-    const unsigned char *extensions_end;
-    size_t extensions_len; /* stores length of all extensions */
-
-    const mbedtls_ssl_ciphersuite_t* ciphersuite_info; /* pointer to ciphersuite */
-
-#if defined(MBEDTLS_SSL_COOKIE_C)
-    size_t cookie_len;
-    unsigned char *cookie;
-#endif /* MBEDTLS_SSL_COOKIE_C */
-
-    /* Check for minimal length
-     * struct {
-     *    ProtocolVersion legacy_version = 0x0303;
-     *    Random random;
-     *    opaque legacy_session_id_echo<0..32>;
-     *    CipherSuite cipher_suite;
-     *    uint8 legacy_compression_method = 0;
-     *    Extension extensions<6..2 ^ 16 - 1>;
-     * } ServerHello;
-     *
-     * 38 = 32 ( random bytes ) + 2 ( ciphersuite ) + 2 ( version ) +
-     *       1 ( legacy_compression_method ) +
-     *       1 ( minimum for legacy_session_id_echo )
-     */
-    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 38 );
-
-    MBEDTLS_SSL_DEBUG_BUF( 4, "hello retry request", p, end - p );
-
-    MBEDTLS_SSL_DEBUG_BUF( 3, "hello retry request, version", p, 2 );
-
-    /* The version field must contain 0x303 */
-    if( MBEDTLS_GET_UINT16_BE( p, 0 ) != 0x303 )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "Unsupported version of TLS." ) );
-        MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_PROTOCOL_VERSION,
-                                      MBEDTLS_ERR_SSL_BAD_PROTOCOL_VERSION );
-        return( MBEDTLS_ERR_SSL_BAD_PROTOCOL_VERSION );
-    }
-
-    /* skip version */
-    p += 2;
-
-    /* Internally we use the correct 1.3 version */
-    ssl->major_ver = MBEDTLS_SSL_MAJOR_VERSION_3;
-    ssl->minor_ver = MBEDTLS_SSL_MINOR_VERSION_4;
-
-    /* store server-provided random values */
-    memcpy( ssl->handshake->randbytes + MBEDTLS_SERVER_HELLO_RANDOM_LEN,
-            p, MBEDTLS_SERVER_HELLO_RANDOM_LEN );
-    MBEDTLS_SSL_DEBUG_BUF( 3, "hello retry request, random bytes",
-                           p + 2, MBEDTLS_SERVER_HELLO_RANDOM_LEN );
-
-    /* skip random bytes */
-    p += MBEDTLS_SERVER_HELLO_RANDOM_LEN;
-
-    /* ...
-     * opaque legacy_session_id_echo<0..32>;
-     * ...
-     */
-    if( ssl_tls13_check_server_hello_session_id_echo( ssl, &p, end ) != 0 )
-    {
-        MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE,
-                                      MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-        return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-    }
-
-    /* read server-selected ciphersuite, which follows random bytes */
-    cipher_suite = MBEDTLS_GET_UINT16_BE( p, 0 );
-
-    /* skip ciphersuite */
-    p += 2;
-
-    /*
-     * Check whether we have offered this ciphersuite
-     * Via the force_ciphersuite version we may have instructed the client
-     * to use a difference ciphersuite.
-     */
-    ciphersuite_info = mbedtls_ssl_ciphersuite_from_id( cipher_suite );
-    if( ciphersuite_info == NULL ||
-        ssl_tls13_cipher_suite_is_offered( ssl, cipher_suite ) == 0 )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "ciphersuite(%04x) not found or not offered",
-                                    (unsigned int)cipher_suite ) );
-
-        MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
-                                      MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
-        return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
-    }
-
-    /* Configure ciphersuites */
-    mbedtls_ssl_optimize_checksum( ssl, ciphersuite_info );
-    ssl->handshake->ciphersuite_info = ciphersuite_info;
-    ssl->session_negotiate->ciphersuite = cipher_suite;
-
-    MBEDTLS_SSL_DEBUG_MSG( 3,
-        ( "hello retry request, chosen ciphersuite: ( %04x ) - %s",
-          (unsigned int)cipher_suite, ciphersuite_info->name ) );
-
-#if defined(MBEDTLS_HAVE_TIME)
-    ssl->session_negotiate->start = time( NULL );
-#endif /* MBEDTLS_HAVE_TIME */
-
-    /* Ensure that compression method is set to zero */
-    if( p[0] != 0 )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad hello retry request message" ) );
-        MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
-                                      MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
-        return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
-    }
-
-    /* skip compression */
-    p++;
-
-    /* Are we reading beyond the message buffer? */
-    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 2 );
-
-    extensions_len = MBEDTLS_GET_UINT16_BE( p, 0 );
-    p += 2; /* skip extension length */
-
-    /* Are we reading beyond the message buffer? */
-    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, extensions_len );
-    extensions_end = p + extensions_len;
-
-    MBEDTLS_SSL_DEBUG_MSG( 3,
-            ( "hello retry request, total extension length: %"
-              MBEDTLS_PRINTF_SIZET , extensions_len ) );
-    MBEDTLS_SSL_DEBUG_BUF( 3, "extensions", p, extensions_len );
-
-    while ( p < extensions_end )
-    {
-        unsigned int extension_type;
-        const unsigned char *extensions_data_end;
-        unsigned int extension_data_len; /* size of an individual extension */
-
-        MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_end, 4 );
-        extension_type = MBEDTLS_GET_UINT16_BE( p, 0 );
-        extension_data_len = MBEDTLS_GET_UINT16_BE( p + 2, 0 );
-
-        p += 4;
-        MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_end, extension_data_len );
-        extensions_data_end = p + extension_data_len;
-
-        switch( extension_type )
-        {
-#if defined(MBEDTLS_SSL_COOKIE_C)
-            case MBEDTLS_TLS_EXT_COOKIE:
-
-                /* Retrieve length field of cookie */
-                MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_data_end, 2 );
-                cookie_len = MBEDTLS_GET_UINT16_BE( p, 0 );
-                cookie = (unsigned char *) ( p + 2 );
-
-                MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_data_end, cookie_len + 2 );
-                MBEDTLS_SSL_DEBUG_BUF( 3, "cookie extension", cookie, cookie_len );
-
-                mbedtls_free( ssl->handshake->verify_cookie );
-                ssl->handshake->verify_cookie = mbedtls_calloc( 1, cookie_len );
-                if( ssl->handshake->verify_cookie == NULL )
-                {
-                    MBEDTLS_SSL_DEBUG_MSG( 1,
-                            ( "alloc failed ( %" MBEDTLS_PRINTF_SIZET " bytes )",
-                              cookie_len ) );
-                    return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
-                }
-
-                memcpy( ssl->handshake->verify_cookie, cookie, cookie_len );
-                ssl->handshake->verify_cookie_len = (unsigned char) cookie_len;
-                break;
-#endif /* MBEDTLS_SSL_COOKIE_C */
-
-            case MBEDTLS_TLS_EXT_SUPPORTED_VERSIONS:
-                MBEDTLS_SSL_DEBUG_MSG( 3, ( "found supported_versions extension" ) );
-
-                ret = ssl_tls13_parse_supported_versions_ext( ssl,
-                                                              p,
-                                                              p + extension_data_len );
-                if( ret != 0 )
-                    return( ret );
-                break;
-
-#if defined(MBEDTLS_ECDH_C) || defined(MBEDTLS_ECDSA_C)
-            case MBEDTLS_TLS_EXT_KEY_SHARE:
-            {
-                /* Variables for parsing the key_share */
-                const uint16_t* grp_id;
-                const mbedtls_ecp_curve_info *curve_info = NULL;
-                int tls_id;
-                int found = 0;
-
-                MBEDTLS_SSL_DEBUG_BUF( 3, "key_share extension", p, extension_data_len );
-
-                /* Read selected_group */
-                tls_id = MBEDTLS_GET_UINT16_BE( p, 0 );
-                MBEDTLS_SSL_DEBUG_MSG( 3, ( "selected_group ( %d )", tls_id ) );
-
-                /* Upon receipt of this extension in a HelloRetryRequest, the client
-                 * MUST first verify that the selected_group field corresponds to a
-                 * group which was provided in the "supported_groups" extension in the
-                 * original ClientHello.
-                 * The supported_group was based on the info in ssl->conf->group_list.
-                 *
-                 * If the server provided a key share that was not sent in the ClientHello
-                 * then the client MUST abort the handshake with an "illegal_parameter" alert.
-                 */
-                for( grp_id = ssl->conf->group_list; *grp_id != MBEDTLS_ECP_DP_NONE; grp_id++ )
-                {
-                    curve_info = mbedtls_ecp_curve_info_from_tls_id( *grp_id );
-                    if( curve_info == NULL || curve_info->tls_id != tls_id )
-                        continue;
-
-                    /* We found a match */
-                    found = 1;
-                    break;
-                }
-
-                /* Client MUST verify that the selected_group field does not
-                 * correspond to a group which was provided in the "key_share"
-                 * extension in the original ClientHello. If the server sent an
-                 * HRR message with a key share already provided in the
-                 * ClientHello then the client MUST abort the handshake with
-                 * an "illegal_parameter" alert.
-                 */
-                if( found == 0 || tls_id == ssl->handshake->offered_group_id )
-                {
-                    MBEDTLS_SSL_DEBUG_MSG( 1, ( "Invalid key share in HRR" ) );
-                    MBEDTLS_SSL_PEND_FATAL_ALERT(
-                            MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
-                            MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
-                    return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
-                }
-
-                /* Remember server's preference for next ClientHello */
-                ssl->handshake->offered_group_id= tls_id;
-                break;
-            }
-
-#endif /* MBEDTLS_ECDH_C || MBEDTLS_ECDSA_C */
-            default:
-                MBEDTLS_SSL_DEBUG_MSG( 3,
-                        ( "unknown extension found: %u ( ignoring )",
-                          extension_type ) );
-        }
-
-        /* Jump to next extension */
-        //extensions_len -= 4 + extension_data_len;
-        //ext += 4 + extension_data_len;
-        p += extension_data_len;
-    }
-
-    return( 0 );
-}
-
 static int ssl_hrr_postprocess( mbedtls_ssl_context *ssl )
 {
 #if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
@@ -1575,6 +1412,7 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char *buf = NULL;
     size_t buf_len = 0;
+    int hrr = -1;
 
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> %s", __func__ ) );
 
@@ -1586,31 +1424,28 @@
     ssl->major_ver = MBEDTLS_SSL_MAJOR_VERSION_3;
     ssl->handshake->extensions_present = MBEDTLS_SSL_EXT_NONE;
 
-    ret = ssl_tls13_server_hello_coordinate( ssl, &buf, &buf_len );
+    hrr = ssl_tls13_server_hello_coordinate( ssl, &buf, &buf_len );
     /* Parsing step
      * We know what message to expect by now and call
      * the respective parsing function.
      */
-    if( ret == SSL_SERVER_HELLO_COORDINATE_HELLO )
-    {
-        MBEDTLS_SSL_PROC_CHK( ssl_tls13_parse_server_hello( ssl, buf,
-                                                            buf + buf_len ) );
-
-        mbedtls_ssl_tls13_add_hs_msg_to_checksum( ssl,
-                                                  MBEDTLS_SSL_HS_SERVER_HELLO,
-                                                  buf, buf_len );
-
-        MBEDTLS_SSL_PROC_CHK( ssl_tls13_finalize_server_hello( ssl ) );
-    }
-    else if( ret == SSL_SERVER_HELLO_COORDINATE_HRR )
-    {
-        MBEDTLS_SSL_PROC_CHK( ssl_hrr_parse( ssl, buf, buf + buf_len ) );
+    MBEDTLS_SSL_DEBUG_MSG( 2, ( " hrr = %d ", hrr ) );
+    MBEDTLS_SSL_PROC_CHK( ssl_tls13_parse_server_hello( ssl, buf,
+                                                        buf + buf_len,
+                                                        hrr ) );
+    if( hrr == SSL_SERVER_HELLO_COORDINATE_HRR )
         MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_reset_transcript_for_hrr( ssl ) );
 
-        mbedtls_ssl_tls13_add_hs_msg_to_checksum( ssl,
-                                                  MBEDTLS_SSL_HS_SERVER_HELLO,
-                                                  buf, buf_len );
+    mbedtls_ssl_tls13_add_hs_msg_to_checksum( ssl,
+                                              MBEDTLS_SSL_HS_SERVER_HELLO,
+                                              buf, buf_len );
 
+    if( hrr == SSL_SERVER_HELLO_COORDINATE_HELLO )
+    {
+        MBEDTLS_SSL_PROC_CHK( ssl_tls13_finalize_server_hello( ssl ) );
+    }
+    else if( hrr == SSL_SERVER_HELLO_COORDINATE_HRR )
+    {
         MBEDTLS_SSL_PROC_CHK( ssl_hrr_postprocess( ssl ) );
     }