Add psk/psk_ephemeral key exchange mode checks to the TLS 1.3 server

Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index d67f453..bf49af2 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -692,24 +692,121 @@
 static int ssl_tls13_client_hello_has_exts_for_ephemeral_key_exchange(
         mbedtls_ssl_context *ssl )
 {
-    return( ssl_tls13_client_hello_has_exts( ssl,
-                                             MBEDTLS_SSL_EXT_SUPPORTED_GROUPS |
-                                             MBEDTLS_SSL_EXT_KEY_SHARE        |
-                                             MBEDTLS_SSL_EXT_SIG_ALG ) );
+    return( ssl_tls13_client_hello_has_exts(
+                ssl,
+                MBEDTLS_SSL_EXT_SUPPORTED_GROUPS |
+                MBEDTLS_SSL_EXT_KEY_SHARE        |
+                MBEDTLS_SSL_EXT_SIG_ALG ) );
 }
 
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+MBEDTLS_CHECK_RETURN_CRITICAL
+static int ssl_tls13_client_hello_has_exts_for_psk_key_exchange(
+               mbedtls_ssl_context *ssl )
+{
+    return( ssl_tls13_client_hello_has_exts(
+                ssl,
+                MBEDTLS_SSL_EXT_PRE_SHARED_KEY          |
+                MBEDTLS_SSL_EXT_PSK_KEY_EXCHANGE_MODES ) );
+}
+
+MBEDTLS_CHECK_RETURN_CRITICAL
+static int ssl_tls13_client_hello_has_exts_for_psk_ephemeral_key_exchange(
+               mbedtls_ssl_context *ssl )
+{
+    return( ssl_tls13_client_hello_has_exts(
+                ssl,
+                MBEDTLS_SSL_EXT_SUPPORTED_GROUPS        |
+                MBEDTLS_SSL_EXT_KEY_SHARE               |
+                MBEDTLS_SSL_EXT_PRE_SHARED_KEY          |
+                MBEDTLS_SSL_EXT_PSK_KEY_EXCHANGE_MODES ) );
+}
+#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
+
 MBEDTLS_CHECK_RETURN_CRITICAL
 static int ssl_tls13_check_ephemeral_key_exchange( mbedtls_ssl_context *ssl )
 {
-    if( !mbedtls_ssl_conf_tls13_ephemeral_enabled( ssl ) )
-        return( 0 );
+    return( mbedtls_ssl_conf_tls13_ephemeral_enabled( ssl ) &&
+            ssl_tls13_client_hello_has_exts_for_ephemeral_key_exchange( ssl ) );
+}
 
-    if( !ssl_tls13_client_hello_has_exts_for_ephemeral_key_exchange( ssl ) )
-        return( 0 );
+MBEDTLS_CHECK_RETURN_CRITICAL
+static int ssl_tls13_check_psk_key_exchange( mbedtls_ssl_context *ssl )
+{
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+    return( mbedtls_ssl_conf_tls13_psk_enabled( ssl ) &&
+            mbedtls_ssl_tls13_psk_enabled( ssl ) &&
+            ssl_tls13_client_hello_has_exts_for_psk_key_exchange( ssl ) );
+#else
+    ((void) ssl);
+    return( 0 );
+#endif
+}
 
-    ssl->handshake->key_exchange_mode =
-        MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_EPHEMERAL;
-    return( 1 );
+MBEDTLS_CHECK_RETURN_CRITICAL
+static int ssl_tls13_check_psk_ephemeral_key_exchange( mbedtls_ssl_context *ssl )
+{
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+    return( mbedtls_ssl_conf_tls13_psk_ephemeral_enabled( ssl ) &&
+            mbedtls_ssl_tls13_psk_ephemeral_enabled( ssl ) &&
+            ssl_tls13_client_hello_has_exts_for_psk_ephemeral_key_exchange( ssl ) );
+#else
+    ((void) ssl);
+    return( 0 );
+#endif
+}
+
+static int ssl_tls13_determine_key_exchange_mode( mbedtls_ssl_context *ssl )
+{
+    /*
+     * Determine the key exchange mode to use.
+     * There are three key exchange modes supported in TLS 1.3:
+     * - (EC)DHE with certificate-based authentication (ephemeral),
+     * - (EC)DHE with PSK (psk_ephemeral),
+     * - plain PSK (psk).
+     *
+     * The PSK-based key exchanges may additionally be used with 0-RTT.
+     *
+     * Our built-in order of preference is:
+     *  1) Plain PSK mode (psk)
+     *  2) (EC)DHE-PSK mode (psk_ephemeral)
+     *  3) Certificate mode (ephemeral)
+     */
+
+    ssl->handshake->key_exchange_mode = MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_NONE;
+
+    if( ssl_tls13_check_psk_key_exchange( ssl ) )
+    {
+        ssl->handshake->key_exchange_mode =
+            MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK;
+        MBEDTLS_SSL_DEBUG_MSG( 2, ( "key exchange mode: psk" ) );
+    }
+    else
+    if( ssl_tls13_check_psk_ephemeral_key_exchange( ssl ) )
+    {
+        ssl->handshake->key_exchange_mode =
+            MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_EPHEMERAL;
+        MBEDTLS_SSL_DEBUG_MSG( 2, ( "key exchange mode: psk_ephemeral" ) );
+    }
+    else
+    if( ssl_tls13_check_ephemeral_key_exchange( ssl ) )
+    {
+        ssl->handshake->key_exchange_mode =
+            MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_EPHEMERAL;
+        MBEDTLS_SSL_DEBUG_MSG( 2, ( "key exchange mode: ephemeral" ) );
+    }
+    else
+    {
+        MBEDTLS_SSL_DEBUG_MSG(
+                1,
+                ( "ClientHello message misses mandatory extensions." ) );
+        MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_MISSING_EXTENSION ,
+                                      MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
+        return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
+    }
+
+    return( 0 );
+
 }
 
 #if defined(MBEDTLS_X509_CRT_PARSE_C) && \
@@ -1216,6 +1313,9 @@
     ssl_tls13_debug_print_client_hello_exts( ssl );
 #endif /* MBEDTLS_DEBUG_C */
 
+    ret = ssl_tls13_determine_key_exchange_mode( ssl );
+    if( ret < 0 )
+        return( ret );
     mbedtls_ssl_add_hs_hdr_to_checksum( ssl,
                                         MBEDTLS_SSL_HS_CLIENT_HELLO,
                                         p - buf );
@@ -1226,8 +1326,7 @@
      * - The content up to but excluding the PSK extension, if present.
      */
     /* If we've settled on a PSK-based exchange, parse PSK identity ext */
-    if( mbedtls_ssl_conf_tls13_some_psk_enabled( ssl ) &&
-        ( ssl->handshake->extensions_present & MBEDTLS_SSL_EXT_PRE_SHARED_KEY ) )
+    if( mbedtls_ssl_tls13_key_exchange_mode_with_psk( ssl ) )
     {
         ssl->handshake->update_checksum( ssl, buf,
                                          pre_shared_key_ext_start - buf );
@@ -1258,19 +1357,6 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
     /*
-     * Here we only support the ephemeral or (EC)DHE key echange mode
-     */
-    if( !ssl_tls13_check_ephemeral_key_exchange( ssl ) )
-    {
-        MBEDTLS_SSL_DEBUG_MSG(
-                1,
-                ( "ClientHello message misses mandatory extensions." ) );
-        MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_MISSING_EXTENSION ,
-                                      MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
-        return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
-    }
-
-    /*
      * Server certificate selection
      */
     if( ssl->conf->f_cert_cb && ( ret = ssl->conf->f_cert_cb( ssl ) ) != 0 )