Address review comments on the PSK client code

Improve comments
Refine ciphersuite-related code in PSK handling
Refine get_psk_offered()

Change-Id: Ic3b0b5f86eb1e71f11bb499961aa8494284f1840
Signed-off-by: XiaokangQian <xiaokang.qian@arm.com>
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index 62f00fa..6e82631 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -615,15 +615,16 @@
     const unsigned char *psk_identity;
     size_t psk_identity_len;
     unsigned char *p = buf;
-    int num_modes = 0;
+    int ke_modes_len = 0;
 
-    ((void) num_modes );
+    ((void) ke_modes_len );
     *out_len = 0;
+
     /* Skip writing extension if no PSK key exchange mode
-     * is enabled in the config.
+     * is enabled in the config or there is no PSK to offer.
      */
     if( !mbedtls_ssl_conf_tls13_some_psk_enabled( ssl ) ||
-         mbedtls_ssl_get_psk_to_offer( ssl, &psk, &psk_len,
+         mbedtls_ssl_get_psk_to_offer( ssl, NULL, &psk, &psk_len,
                                       &psk_identity, &psk_identity_len ) != 0 )
     {
         MBEDTLS_SSL_DEBUG_MSG( 3, ( "skip psk_key_exchange_modes extension" ) );
@@ -637,18 +638,17 @@
     MBEDTLS_SSL_DEBUG_MSG(
             3, ( "client hello, adding psk_key_exchange_modes extension" ) );
 
-    /* Extension Type */
     MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_PSK_KEY_EXCHANGE_MODES, p, 0 );
 
-    /* Skip extension length (2 byte) and
-     * PSK mode list length (1 byte) for now.
+    /* Skip extension length (2 bytes) and
+     * ke_modes length (1 byte) for now.
      */
     p += 5;
 
     if( mbedtls_ssl_conf_tls13_psk_enabled( ssl ) )
     {
         *p++ = MBEDTLS_SSL_TLS1_3_PSK_MODE_PURE;
-        num_modes++;
+        ke_modes_len++;
 
         MBEDTLS_SSL_DEBUG_MSG( 4, ( "Adding pure PSK key exchange mode" ) );
     }
@@ -656,17 +656,14 @@
     if( mbedtls_ssl_conf_tls13_psk_ephemeral_enabled( ssl ) )
     {
         *p++ = MBEDTLS_SSL_TLS1_3_PSK_MODE_ECDHE;
-        num_modes++;
+        ke_modes_len++;
 
         MBEDTLS_SSL_DEBUG_MSG( 4, ( "Adding PSK-ECDHE key exchange mode" ) );
     }
 
-    /* Add extension length: PSK mode list length byte + actual
-     * PSK mode list length
-     */
-    MBEDTLS_PUT_UINT16_BE( num_modes + 1, buf, 2 );
-    /* Add PSK mode list length */
-    buf[4] = num_modes;
+    /* Now write the extension and ke_modes length */
+    MBEDTLS_PUT_UINT16_BE( ke_modes_len + 1, buf, 2 );
+    buf[4] = ke_modes_len;
 
     *out_len = p - buf;
     ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_PSK_KEY_EXCHANGE_MODES;
@@ -685,22 +682,12 @@
  * opaque PskBinderEntry<32..255>;
  *
  * struct {
- *   select ( Handshake.msg_type ) {
  *
- *     case client_hello:
- *       PskIdentity identities<7..2^16-1>;
- *       PskBinderEntry binders<33..2^16-1>;
- *
- *     case server_hello:
- *       uint16 selected_identity;
- *   };
+ *     PskIdentity identities<7..2^16-1>;
+ *     PskBinderEntry binders<33..2^16-1>;
  *
  * } PreSharedKeyExtension;
  *
- *
- * part = 0 ==> everything up to the PSK binder list,
- *              returning the binder list length in `binder_list_length`.
- * part = 1 ==> the PSK binder list
  */
 
 #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
@@ -715,12 +702,12 @@
     size_t psk_len;
     const unsigned char *psk_identity;
     size_t psk_identity_len;
-    const mbedtls_ssl_ciphersuite_t *suite_info = NULL;
+    const mbedtls_ssl_ciphersuite_t *ciphersuite_info = NULL;
     const int *ciphersuites;
+    psa_algorithm_t psa_hash_alg;
     int hash_len = 0;
     size_t identities_len, l_binders_len;
     uint32_t obfuscated_ticket_age = 0;
-    psa_algorithm_t psa_hash_alg;
 
     *out_len = 0;
     *binders_len = 0;
@@ -738,7 +725,7 @@
      * - Otherwise, skip the PSK extension.
      */
 
-    if( mbedtls_ssl_get_psk_to_offer( ssl, &psk, &psk_len,
+    if( mbedtls_ssl_get_psk_to_offer( ssl, NULL, &psk, &psk_len,
                                       &psk_identity, &psk_identity_len ) != 0 )
     {
         MBEDTLS_SSL_DEBUG_MSG( 3, ( "skip pre_shared_key extensions" ) );
@@ -751,22 +738,27 @@
     ciphersuites = ssl->conf->ciphersuite_list;
     for ( int i = 0; ciphersuites[i] != 0; i++ )
     {
-        suite_info = mbedtls_ssl_ciphersuite_from_id( ciphersuites[i] );
+        ciphersuite_info = mbedtls_ssl_ciphersuite_from_id( ciphersuites[i] );
 
-        if( suite_info == NULL )
+        if( mbedtls_ssl_validate_ciphersuite(
+                                ssl, ciphersuite_info,
+                                MBEDTLS_SSL_VERSION_TLS1_3,
+                                MBEDTLS_SSL_VERSION_TLS1_3 ) != 0 )
             continue;
 
         /* In this implementation we only add one pre-shared-key extension. */
         ssl->session_negotiate->ciphersuite = ciphersuites[i];
-        ssl->handshake->ciphersuite_info = suite_info;
         break;
     }
 
-    if( suite_info != NULL )
-    {
-        psa_hash_alg = mbedtls_psa_translate_md( suite_info->mac );
-        hash_len = PSA_HASH_LENGTH( psa_hash_alg );
-    }
+    ciphersuite_info = mbedtls_ssl_ciphersuite_from_id(
+            ssl->session_negotiate->ciphersuite );
+    /* No suitable ciphersuite for the PSK */
+    if( ciphersuite_info  == NULL )
+        return( 0 );
+
+    psa_hash_alg = mbedtls_psa_translate_md( ciphersuite_info->mac );
+    hash_len = PSA_HASH_LENGTH( psa_hash_alg );
     if( hash_len == -1 )
         return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
 
@@ -818,64 +810,28 @@
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char *p = buf;
-    const mbedtls_ssl_ciphersuite_t *suite_info = NULL;
-    const int *ciphersuites;
+    const mbedtls_ssl_ciphersuite_t *ciphersuite_info = NULL;
+    psa_algorithm_t psa_hash_alg;
     int hash_len = 0;
-    const unsigned char *psk;
-    size_t psk_len;
-    const unsigned char *psk_identity;
-    size_t psk_identity_len;
+    const unsigned char *psk = NULL;
+    size_t psk_len = 0;
     int psk_type;
     unsigned char transcript[MBEDTLS_MD_MAX_SIZE];
     size_t transcript_len;
-    psa_algorithm_t psa_hash_alg;
 
-    /* Check if we have any PSKs to offer. If so, return the first.
-     *
-     * NOTE: Ultimately, we want to be able to offer multiple PSKs,
-     *       in which case we want to iterate over them here.
-     *
-     * As it stands, however, we only ever offer one, chosen
-     * by the following heuristic:
-     * - If a ticket has been configured, offer the corresponding PSK.
-     * - If no ticket has been configured by an external PSK has been
-     *   configured, offer that.
-     * - Otherwise, skip the PSK extension.
-     */
+    ciphersuite_info = mbedtls_ssl_ciphersuite_from_id(
+            ssl->session_negotiate->ciphersuite );
+    if( ciphersuite_info  == NULL )
+        return( 0 );
 
-    if( mbedtls_ssl_get_psk_to_offer( ssl, &psk, &psk_len,
-                                      &psk_identity, &psk_identity_len ) != 0 )
-    {
-        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
-    }
-
-    /*
-     * Ciphersuite list
-     */
-    ciphersuites = ssl->conf->ciphersuite_list;
-    for ( int i = 0; ciphersuites[i] != 0; i++ )
-    {
-        suite_info = mbedtls_ssl_ciphersuite_from_id( ciphersuites[i] );
-
-        if( suite_info == NULL )
-            continue;
-
-        /* In this implementation we only add one pre-shared-key extension. */
-        ssl->session_negotiate->ciphersuite = ciphersuites[i];
-        ssl->handshake->ciphersuite_info = suite_info;
-        break;
-    }
-
-    if( suite_info != NULL )
-    {
-        psa_hash_alg = mbedtls_psa_translate_md( suite_info->mac );
-        hash_len = PSA_HASH_LENGTH( psa_hash_alg );
-    }
+    psa_hash_alg = mbedtls_psa_translate_md( ciphersuite_info->mac );
+    hash_len = PSA_HASH_LENGTH( psa_hash_alg );
     if( ( hash_len == -1 ) || ( ( end - buf ) != 3 + hash_len ) )
         return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
 
     MBEDTLS_SSL_DEBUG_MSG( 3, ( "client hello, adding PSK binder list" ) );
 
+    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 3 + hash_len );
     /* 2 bytes length field for array of psk binders */
     MBEDTLS_PUT_UINT16_BE( hash_len + 1, p, 0 );
     p += 2;
@@ -889,14 +845,14 @@
         psk_type = MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL;
 
     /* Get current state of handshake transcript. */
-    ret = mbedtls_ssl_get_handshake_transcript( ssl, suite_info->mac,
+    ret = mbedtls_ssl_get_handshake_transcript( ssl, ciphersuite_info->mac,
                                                 transcript, sizeof( transcript ),
                                                 &transcript_len );
     if( ret != 0 )
         return( ret );
 
     ret = mbedtls_ssl_tls13_create_psk_binder( ssl,
-              mbedtls_psa_translate_md( suite_info->mac ),
+              mbedtls_psa_translate_md( ciphersuite_info->mac ),
               psk, psk_len, psk_type,
               transcript, p );
     if( ret != 0 )
@@ -1269,13 +1225,8 @@
  * opaque PskBinderEntry<32..255>;
  *
  * struct {
- *   select ( Handshake.msg_type ) {
- *     case client_hello:
- *          PskIdentity identities<7..2^16-1>;
- *          PskBinderEntry binders<33..2^16-1>;
- *     case server_hello:
- *          uint16 selected_identity;
- *   };
+ *
+ *   uint16 selected_identity;
  *
  * } PreSharedKeyExtension;
  *
@@ -1283,7 +1234,7 @@
 
 static int ssl_tls13_parse_server_psk_identity_ext( mbedtls_ssl_context *ssl,
                                                     const unsigned char *buf,
-                                                    size_t len )
+                                                    const unsigned char *end )
 {
     int ret = 0;
     size_t selected_identity;
@@ -1299,7 +1250,7 @@
      * NOTE: Ultimately, we want to offer multiple PSKs, and in this
      *       case, we need to iterate over them here.
      */
-    if( mbedtls_ssl_get_psk_to_offer( ssl, &psk, &psk_len,
+    if( mbedtls_ssl_get_psk_to_offer( ssl, NULL, &psk, &psk_len,
                                       &psk_identity, &psk_identity_len ) != 0 )
     {
         /* If we haven't offered a PSK, the server must not send
@@ -1307,12 +1258,7 @@
         return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
     }
 
-    if( len != (size_t) 2 )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad psk_identity extension in server hello message" ) );
-        return( MBEDTLS_ERR_SSL_DECODE_ERROR );
-    }
-
+    MBEDTLS_SSL_CHK_BUF_PTR( buf, end, 2 );
     selected_identity = MBEDTLS_GET_UINT16_BE( buf, 0 );
 
     /* We have offered only one PSK, so the only valid choice
@@ -1571,7 +1517,7 @@
                 }
 
                 if( ( ret = ssl_tls13_parse_server_psk_identity_ext(
-                                ssl, p, extension_data_len ) ) != 0 )
+                                ssl, p, extension_data_end ) ) != 0 )
                 {
                     MBEDTLS_SSL_DEBUG_RET(
                         1, ( "ssl_tls13_parse_server_psk_identity_ext" ), ret );