ssl_client.c: Adapt ClientHello extension writing to the TLS 1.2 case

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
diff --git a/library/ssl_client.c b/library/ssl_client.c
index 33c02e6..4c0e0ee 100644
--- a/library/ssl_client.c
+++ b/library/ssl_client.c
@@ -282,15 +282,33 @@
                                         size_t *out_len )
 {
     int ret;
-    unsigned char *p = buf;
-    int tls12_uses_ec = 0;
+    mbedtls_ssl_handshake_params *handshake = ssl->handshake;
+#if defined(MBEDTLS_SSL_PROTO_TLS1_2)
+    unsigned char propose_tls12 = 0;
+#endif
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3)
+    unsigned char propose_tls13 = 0;
+#endif
 
+    unsigned char *p = buf;
     unsigned char *p_extensions_len; /* Pointer to extensions length */
     size_t output_len;               /* Length of buffer used by function */
     size_t extensions_len;           /* Length of the list of extensions*/
+    int tls12_uses_ec = 0;
 
     *out_len = 0;
 
+#if defined(MBEDTLS_SSL_PROTO_TLS1_2)
+    propose_tls12 = ( handshake->min_minor_ver <= MBEDTLS_SSL_MINOR_VERSION_3 )
+                    &&
+                    ( MBEDTLS_SSL_MINOR_VERSION_3 <= ssl->minor_ver );
+#endif
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3)
+    propose_tls13 = ( handshake->min_minor_ver <= MBEDTLS_SSL_MINOR_VERSION_4 )
+                    &&
+                    ( MBEDTLS_SSL_MINOR_VERSION_4 <= ssl->minor_ver );
+#endif
+
     /*
      * Write client_version (TLS 1.2) or legacy_version (TLS 1.3)
      *
@@ -316,11 +334,11 @@
      * opaque Random[32];
      *
      * The random bytes have been prepared by ssl_prepare_client_hello() into
-     * the ssl->handshake->randbytes buffer and are copied here into the
-     * output buffer.
+     * the handshake->randbytes buffer and are copied here into the output
+     * buffer.
      */
     MBEDTLS_SSL_CHK_BUF_PTR( p, end, MBEDTLS_CLIENT_HELLO_RANDOM_LEN );
-    memcpy( p, ssl->handshake->randbytes, MBEDTLS_CLIENT_HELLO_RANDOM_LEN );
+    memcpy( p, handshake->randbytes, MBEDTLS_CLIENT_HELLO_RANDOM_LEN );
     MBEDTLS_SSL_DEBUG_BUF( 3, "client hello, random bytes",
                            p, MBEDTLS_CLIENT_HELLO_RANDOM_LEN );
     p += MBEDTLS_CLIENT_HELLO_RANDOM_LEN;
@@ -359,19 +377,19 @@
     {
         unsigned char cookie_len = 0;
 
-        if( ssl->handshake->cookie != NULL )
+        if( handshake->cookie != NULL )
         {
             MBEDTLS_SSL_DEBUG_BUF( 3, "client hello, cookie",
-                                   ssl->handshake->cookie,
-                                   ssl->handshake->verify_cookie_len );
-            cookie_len = ssl->handshake->verify_cookie_len;
+                                   handshake->cookie,
+                                   handshake->verify_cookie_len );
+            cookie_len = handshake->verify_cookie_len;
         }
 
         MBEDTLS_SSL_CHK_BUF_PTR( p, end, cookie_len + 1 );
         *p++ = cookie_len;
         if( cookie_len > 0 )
         {
-            memcpy( p, ssl->handshake->cookie, cookie_len );
+            memcpy( p, handshake->cookie, cookie_len );
             p += cookie_len;
         }
     }
@@ -403,7 +421,7 @@
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
     /* Keeping track of the included extensions */
-    ssl->handshake->extensions_present = MBEDTLS_SSL_EXT_NONE;
+    handshake->extensions_present = MBEDTLS_SSL_EXT_NONE;
 #endif
 
     /* First write extensions, then the total length */
@@ -427,27 +445,44 @@
 #endif /* MBEDTLS_SSL_ALPN */
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
-    ret = mbedtls_ssl_tls13_write_client_hello_exts( ssl, p, end, &output_len );
-    if( ret != 0 )
-        return( ret );
-    p += output_len;
+    if( propose_tls13 )
+    {
+        ret = mbedtls_ssl_tls13_write_client_hello_exts( ssl, p, end,
+                                                         &output_len );
+        if( ret != 0 )
+            return( ret );
+        p += output_len;
+    }
 #endif
 
+#if defined(MBEDTLS_ECDH_C) || defined(MBEDTLS_ECDSA_C) || \
+    defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED)
+    if(
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
-#if defined(MBEDTLS_ECDH_C)
-    if( mbedtls_ssl_conf_tls13_some_ephemeral_enabled( ssl ) )
+        ( propose_tls13 &&
+          mbedtls_ssl_conf_tls13_some_ephemeral_enabled( ssl ) ) ||
+#endif
+#if defined(MBEDTLS_SSL_PROTO_TLS1_2)
+        ( propose_tls12 && tls12_uses_ec ) ||
+#endif
+        0 )
     {
         ret = mbedtls_ssl_write_supported_groups_ext( ssl, p, end, &output_len );
         if( ret != 0 )
             return( ret );
         p += output_len;
     }
-#endif /* MBEDTLS_ECDH_C */
-#endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
+#endif /* MBEDTLS_ECDH_C || MBEDTLS_ECDSA_C || MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED */
 
-#if defined(MBEDTLS_SSL_PROTO_TLS1_3)
 #if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
-    if( mbedtls_ssl_conf_tls13_ephemeral_enabled( ssl ) )
+    if(
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3)
+        ( propose_tls13 && mbedtls_ssl_conf_tls13_ephemeral_enabled( ssl ) ) ||
+#endif
+#if defined(MBEDTLS_SSL_PROTO_TLS1_2)
+        propose_tls12 ||
+#endif
+       0 )
     {
         ret = mbedtls_ssl_write_sig_alg_ext( ssl, p, end, &output_len );
         if( ret != 0 )
@@ -455,16 +490,32 @@
         p += output_len;
     }
 #endif /* MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED */
-#endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
 
-    /* Add more extensions here */
+#if defined(MBEDTLS_SSL_PROTO_TLS1_2)
+    if( propose_tls12 )
+    {
+        ret = mbedtls_ssl_tls12_write_client_hello_exts( ssl, p, end,
+                                                         tls12_uses_ec,
+                                                         &output_len );
+        if( ret != 0 )
+            return( ret );
+        p += output_len;
+    }
+#endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
 
     /* Write the length of the list of extensions. */
     extensions_len = p - p_extensions_len - 2;
-    MBEDTLS_PUT_UINT16_BE( extensions_len, p_extensions_len, 0 );
-    MBEDTLS_SSL_DEBUG_MSG( 3, ( "client hello, total extension length: %" MBEDTLS_PRINTF_SIZET ,
-                                extensions_len ) );
-    MBEDTLS_SSL_DEBUG_BUF( 3, "client hello extensions", p_extensions_len, extensions_len );
+
+    if( extensions_len == 0 )
+       p = p_extensions_len;
+    else
+    {
+        MBEDTLS_PUT_UINT16_BE( extensions_len, p_extensions_len, 0 );
+        MBEDTLS_SSL_DEBUG_MSG( 3, ( "client hello, total extension length: %" \
+                                    MBEDTLS_PRINTF_SIZET, extensions_len ) );
+        MBEDTLS_SSL_DEBUG_BUF( 3, "client hello extensions",
+                                  p_extensions_len, extensions_len );
+    }
 
     *out_len = p - buf;
     return( 0 );