Introduce sent/received extensions fields

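Also remove the local `extensions_present` bitmask in favour of `ssl->handshake->received_extensions`, and move the received-extension checks into `mbedtls_tls13_check_received_extensions()`.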

Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index 5eac1f1..7b66be1 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -448,7 +448,6 @@
     {
         size_t cert_data_len, extensions_len;
         const unsigned char *extensions_end;
-        uint32_t extensions_present;
 
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p, certificate_list_end, 3 );
         cert_data_len = MBEDTLS_GET_UINT24_BE( p, 0 );
@@ -508,7 +507,7 @@
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p, certificate_list_end, extensions_len );
 
         extensions_end = p + extensions_len;
-        extensions_present = MBEDTLS_SSL_EXT_NONE;
+        ssl->handshake->received_extensions = MBEDTLS_SSL_EXT_NONE;
 
         while( p < extensions_end )
         {
@@ -528,26 +527,12 @@
 
             MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_end, extension_data_len );
 
-            /* RFC 8446 page 35
-             *
-             * If an implementation receives an extension which it recognizes and
-             * which is not specified for the message in which it appears, it MUST
-             * abort the handshake with an "illegal_parameter" alert.
-             */
-            extensions_present |= mbedtls_tls13_get_extension_mask( extension_type );
-            MBEDTLS_SSL_DEBUG_MSG( 3,
-                        ( "encrypted extensions : received %s(%u) extension",
-                        mbedtls_tls13_get_extension_name( extension_type ),
-                        extension_type ) );
-            if( ( extensions_present & MBEDTLS_SSL_TLS1_3_ALLOWED_EXTS_OF_CT ) == 0 )
-            {
-                MBEDTLS_SSL_DEBUG_MSG(
-                    3, ( "forbidden extension received." ) );
-                MBEDTLS_SSL_PEND_FATAL_ALERT(
-                    MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
-                    MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-                return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-            }
+            ret = mbedtls_tls13_check_received_extensions(
+                  ssl, MBEDTLS_SSL_HS_CERTIFICATE, extension_type,
+                  MBEDTLS_SSL_TLS1_3_ALLOWED_EXTS_OF_CT );
+            if( ret != 0 )
+                return( ret );
+
             switch( extension_type )
             {
                 default:
@@ -561,7 +546,8 @@
             p += extension_data_len;
         }
 
-        MBEDTLS_SSL_TLS1_3_PRINT_EXTS( 3, "Certificate", extensions_present );
+        MBEDTLS_SSL_TLS1_3_PRINT_EXTS(
+            3, MBEDTLS_SSL_HS_CERTIFICATE, ssl->handshake->received_extensions );
     }
 
 exit:
@@ -1691,9 +1677,42 @@
     return( "unknown" );
 }
 
+/*
+ * Return a printable name for a TLS 1.3 handshake message type, used to
+ * label debug messages.
+ *
+ * HelloRetryRequest has no IANA handshake type of its own, so callers
+ * identify it with the negated ServerHello type (-MBEDTLS_SSL_HS_SERVER_HELLO).
+ *
+ * Never returns NULL: unrecognized types map to "unknown", so the result is
+ * always safe to pass to a "%s" conversion.
+ */
+static const char *ssl_tls13_get_hs_msg_name( int hs_msg_type )
+{
+    switch( hs_msg_type )
+    {
+        case MBEDTLS_SSL_HS_CLIENT_HELLO:
+            return( "ClientHello" );
+        case MBEDTLS_SSL_HS_SERVER_HELLO:
+            return( "ServerHello" );
+        case -MBEDTLS_SSL_HS_SERVER_HELLO: /* HRR does not have IANA value. */
+            return( "HelloRetryRequest" );
+        case MBEDTLS_SSL_HS_NEW_SESSION_TICKET:
+            return( "NewSessionTicket" );
+        case MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS:
+            return( "EncryptedExtensions" );
+        case MBEDTLS_SSL_HS_CERTIFICATE:
+            return( "Certificate" );
+        case MBEDTLS_SSL_HS_CERTIFICATE_REQUEST:
+            return( "CertificateRequest" );
+    }
+    /* Never hand NULL to a "%s" conversion: that is undefined behaviour. */
+    return( "unknown" );
+}
+
 void mbedtls_ssl_tls13_print_extensions( const mbedtls_ssl_context *ssl,
                                          int level, const char *file, int line,
-                                         const char *hs_msg_name,
+                                         int hs_msg_type,
                                          uint32_t extensions_present )
 {
     static const struct{
@@ -1724,7 +1732,8 @@
             { MBEDTLS_SSL_EXT_KEY_SHARE, "key_share" } };
 
     mbedtls_debug_print_msg( ssl, level, file, line,
-                             "extension list of %s:", hs_msg_name );
+                             "extension list of %s:",
+                             ssl_tls13_get_hs_msg_name( hs_msg_type ) );
 
     for( unsigned i = 0;
          i < sizeof( mask_to_str_table ) / sizeof( mask_to_str_table[0] );
@@ -1742,4 +1751,96 @@
 
 #endif /* MBEDTLS_DEBUG_C */
 
+/* RFC 8446 section 4.2
+ *
+ * If an implementation receives an extension which it recognizes and which is
+ * not specified for the message in which it appears, it MUST abort the handshake
+ * with an "illegal_parameter" alert.
+ *
+ * Implementations MUST NOT send extension responses if the remote endpoint did
+ * not send the corresponding extension requests, with the exception of the
+ * "cookie" extension in the HelloRetryRequest. Upon receiving such an
+ * extension, an endpoint MUST abort the handshake with an
+ * "unsupported_extension" alert.
+ */
+
+/*
+ * Validate one extension received in handshake message hs_msg_type against
+ * the two rules quoted above. An extension that passes the first rule is
+ * recorded in ssl->handshake->received_extensions.
+ *
+ * \param ssl            SSL context: handshake->sent_extensions is read and
+ *                       handshake->received_extensions is updated.
+ * \param hs_msg_type    Type of the message carrying the extension;
+ *                       -MBEDTLS_SSL_HS_SERVER_HELLO denotes HelloRetryRequest.
+ * \param extension_type Extension type as read from the message.
+ * \param allowed_mask   Extension mask of what this message may contain.
+ *
+ * \return 0 if the extension is acceptable.
+ * \return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE otherwise, with a fatal
+ *         "illegal_parameter" alert pending if the extension is not allowed in
+ *         this message, or a fatal "unsupported_extension" alert pending if it
+ *         responds to an extension that was never sent.
+ */
+int mbedtls_tls13_check_received_extensions( mbedtls_ssl_context *ssl,
+                                             int hs_msg_type,
+                                             uint32_t extension_type,
+                                             uint32_t allowed_mask )
+{
+    uint32_t extension_mask;
+
+#if defined(MBEDTLS_DEBUG_C)
+    const char *hs_msg_name = ssl_tls13_get_hs_msg_name( hs_msg_type );
+#endif
+
+    extension_mask = mbedtls_tls13_get_extension_mask( extension_type );
+
+    MBEDTLS_SSL_DEBUG_MSG( 3,
+                ( "%s : received %s(%u) extension",
+                  hs_msg_name,
+                  mbedtls_tls13_get_extension_name( extension_type ),
+                  (unsigned int)extension_type ) );
+
+    /* Rule 1: the extension must be permitted in this message at all. */
+    if( ( extension_mask & allowed_mask ) == 0 )
+    {
+        MBEDTLS_SSL_DEBUG_MSG(
+            3, ( "%s : forbidden extension received.", hs_msg_name ) );
+        MBEDTLS_SSL_PEND_FATAL_ALERT(
+            MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
+            MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+        return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+    }
+
+    ssl->handshake->received_extensions |= extension_mask;
+
+    /* Exception to rule 2: a cookie in HelloRetryRequest does not answer an
+     * extension we sent, so it is not checked against sent_extensions. */
+    if( hs_msg_type == -MBEDTLS_SSL_HS_SERVER_HELLO &&
+        extension_mask == MBEDTLS_SSL_EXT_COOKIE )
+        return( 0 );
+
+    /* Rule 2: extensions in these messages must answer ones we sent; other
+     * message types need no further check. */
+    switch( hs_msg_type )
+    {
+        case MBEDTLS_SSL_HS_SERVER_HELLO:
+        case -MBEDTLS_SSL_HS_SERVER_HELLO: /* HRR does not have IANA value. */
+        case MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS:
+        case MBEDTLS_SSL_HS_CERTIFICATE:
+            if( ( ~ssl->handshake->sent_extensions & extension_mask ) == 0 )
+                return( 0 );
+            break;
+        default:
+            return( 0 );
+    }
+
+    MBEDTLS_SSL_DEBUG_MSG(
+            3, ( "%s : unsolicited extension received.", hs_msg_name ) );
+    MBEDTLS_SSL_PEND_FATAL_ALERT(
+        MBEDTLS_SSL_ALERT_MSG_UNSUPPORTED_EXT,
+        MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+    return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+}
+
 #endif /* MBEDTLS_SSL_TLS_C && MBEDTLS_SSL_PROTO_TLS1_3 */