tls13: srv: Move PSK ciphersuite selection up

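Move the PSK ciphersuite selection up into the
main ClientHello parsing function, so that
ciphersuite selection happens only in that
function. The ciphersuite found for the matched
PSK is now stored in struct psk_attributes and
applied only when a PSK key exchange mode is
selected.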

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 391b8d4..ad1be2f 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -438,8 +438,9 @@
 struct psk_attributes {
     int type;
     int key_exchange_mode;
+    const mbedtls_ssl_ciphersuite_t *ciphersuite_info; /* Set by pre_shared_key parser */
 };
-#define PSK_ATTRIBUTES_INIT { 0, 0, 0 }
+#define PSK_ATTRIBUTES_INIT { 0, 0, NULL }
 
 /* Parser for pre_shared_key extension in client hello
  *    struct {
@@ -522,7 +523,7 @@
         int psk_ciphersuite_id;
         psa_algorithm_t psk_hash_alg;
         int allowed_key_exchange_modes;
-        const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
+
 #if defined(MBEDTLS_SSL_SESSION_TICKETS)
         mbedtls_ssl_session session;
         mbedtls_ssl_session_init(&session);
@@ -595,9 +596,9 @@
 
         ssl_tls13_select_ciphersuite(ssl, ciphersuites, ciphersuites_end,
                                      psk_ciphersuite_id, psk_hash_alg,
-                                     &ciphersuite_info);
+                                     &psk->ciphersuite_info);
 
-        if (ciphersuite_info == NULL) {
+        if (psk->ciphersuite_info == NULL) {
 #if defined(MBEDTLS_SSL_SESSION_TICKETS)
             mbedtls_ssl_session_free(&session);
 #endif
@@ -614,7 +615,7 @@
 
         ret = ssl_tls13_offered_psks_check_binder_match(
             ssl, binder, binder_len, psk->type,
-            mbedtls_md_psa_alg_from_type((mbedtls_md_type_t) ciphersuite_info->mac));
+            mbedtls_md_psa_alg_from_type((mbedtls_md_type_t) psk->ciphersuite_info->mac));
         if (ret != SSL_TLS1_3_BINDER_MATCH) {
             /* For security reasons, the handshake should be aborted when we
              * fail to validate a binder value. See RFC 8446 section 4.2.11.2
@@ -633,12 +634,6 @@
 
         matched_identity = identity_id;
 
-        /* Update handshake parameters */
-        ssl->handshake->ciphersuite_info = ciphersuite_info;
-        ssl->session_negotiate->ciphersuite = ciphersuite_info->id;
-        MBEDTLS_SSL_DEBUG_MSG(2, ("overwrite ciphersuite: %04x - %s",
-                                  ((unsigned) ciphersuite_info->id),
-                                  ciphersuite_info->name));
 #if defined(MBEDTLS_SSL_SESSION_TICKETS)
         if (psk->type == MBEDTLS_SSL_TLS1_3_PSK_RESUMPTION) {
             ret = ssl_tls13_session_copy_ticket(ssl->session_negotiate,
@@ -1720,10 +1715,18 @@
     }
 
 #if defined(MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_SOME_PSK_ENABLED)
-    if ((handshake->key_exchange_mode !=
-         MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_EPHEMERAL) &&
-        (psk.type == MBEDTLS_SSL_TLS1_3_PSK_RESUMPTION)) {
-        handshake->resume = 1;
+    if (handshake->key_exchange_mode &
+        MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_ALL) {
+        handshake->ciphersuite_info = psk.ciphersuite_info;
+        ssl->session_negotiate->ciphersuite = psk.ciphersuite_info->id;
+
+        MBEDTLS_SSL_DEBUG_MSG(2, ("Select PSK ciphersuite: %04x - %s",
+                                  ((unsigned) psk.ciphersuite_info->id),
+                                  psk.ciphersuite_info->name));
+
+        if (psk.type == MBEDTLS_SSL_TLS1_3_PSK_RESUMPTION) {
+            handshake->resume = 1;
+        }
     }
 #endif