Add psk code to tls13 client side

Change-Id: I222b2c9d393889448e5e6ad06638536b54edb703
Signed-off-by: XiaokangQian <xiaokang.qian@arm.com>
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index 183b6ee..62f00fa 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -595,6 +595,320 @@
     return( 0 );
 }
 
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+/*
+ * Write the psk_key_exchange_modes ClientHello extension:
+ *
+ * enum { psk_ke( 0 ), psk_dhe_ke( 1 ), ( 255 ) } PskKeyExchangeMode;
+ *
+ * struct {
+ *     PskKeyExchangeMode ke_modes<1..255>;
+ * } PskKeyExchangeModes;
+ */
+static int ssl_tls13_write_psk_key_exchange_modes_ext( mbedtls_ssl_context *ssl,
+                                                       unsigned char *buf,
+                                                       unsigned char *end,
+                                                       size_t *out_len )
+{
+    const unsigned char *psk;
+    size_t psk_len;
+    const unsigned char *psk_identity;
+    size_t psk_identity_len;
+    unsigned char *p = buf;
+    int num_modes = 0;
+
+    ((void) num_modes );
+    *out_len = 0;
+    /* Skip writing the extension (leaving *out_len at 0) if no PSK key
+     * exchange mode is enabled in the config or there is no PSK to offer.
+     */
+    if( !mbedtls_ssl_conf_tls13_some_psk_enabled( ssl ) ||
+         mbedtls_ssl_get_psk_to_offer( ssl, &psk, &psk_len,
+                                      &psk_identity, &psk_identity_len ) != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 3, ( "skip psk_key_exchange_modes extension" ) );
+        return( 0 );
+    }
+
+    /* Require 7 bytes (type, length, list length and two modes),
+     * otherwise fail, even if the extension might be shorter.
+     */
+    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 7 );
+    MBEDTLS_SSL_DEBUG_MSG(
+            3, ( "client hello, adding psk_key_exchange_modes extension" ) );
+
+    /* Extension Type */
+    MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_PSK_KEY_EXCHANGE_MODES, p, 0 );
+
+    /* Skip the type just written plus the extension length (2 bytes) and
+     * PSK mode list length (1 byte); both lengths are filled in below.
+     */
+    p += 5;
+
+    if( mbedtls_ssl_conf_tls13_psk_enabled( ssl ) )
+    {
+        *p++ = MBEDTLS_SSL_TLS1_3_PSK_MODE_PURE;
+        num_modes++;
+
+        MBEDTLS_SSL_DEBUG_MSG( 4, ( "Adding pure PSK key exchange mode" ) );
+    }
+
+    if( mbedtls_ssl_conf_tls13_psk_ephemeral_enabled( ssl ) )
+    {
+        *p++ = MBEDTLS_SSL_TLS1_3_PSK_MODE_ECDHE;
+        num_modes++;
+
+        MBEDTLS_SSL_DEBUG_MSG( 4, ( "Adding PSK-ECDHE key exchange mode" ) );
+    }
+
+    /* Add extension length: PSK mode list length byte + actual
+     * PSK mode list length
+     */
+    MBEDTLS_PUT_UINT16_BE( num_modes + 1, buf, 2 );
+    /* Add PSK mode list length */
+    buf[4] = num_modes;
+
+    *out_len = p - buf;
+    ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_PSK_KEY_EXCHANGE_MODES;
+    return ( 0 );
+}
+#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
+
+/*
+ * pre_shared_key extension structure (RFC 8446, section 4.2.11):
+ *
+ * struct {
+ *   opaque identity<1..2^16-1>;
+ *   uint32 obfuscated_ticket_age;
+ * } PskIdentity;
+ *
+ * opaque PskBinderEntry<32..255>;
+ *
+ * struct {
+ *   select ( Handshake.msg_type ) {
+ *
+ *     case client_hello:
+ *       PskIdentity identities<7..2^16-1>;
+ *       PskBinderEntry binders<33..2^16-1>;
+ *
+ *     case server_hello:
+ *       uint16 selected_identity;
+ *   };
+ *
+ * } PreSharedKeyExtension;
+ *
+ *
+ * ..._ext_without_binders() writes everything up to the PSK binder list,
+ *              returning the binder list length in `binders_len`.
+ * ..._ext_binders() writes the PSK binder list.
+ */
+
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+
+/*
+ * Write the pre_shared_key extension, leaving the binder list unwritten.
+ *
+ * On success, *out_len is the size of the whole extension (binder list
+ * included) and *binders_len the size reserved for the binder list; both
+ * are 0 if there is no PSK to offer. Returns 0 or an MBEDTLS_ERR_SSL_xxx.
+ */
+int mbedtls_ssl_tls13_write_pre_shared_key_ext_without_binders(
+    mbedtls_ssl_context *ssl,
+    unsigned char *buf, unsigned char *end,
+    size_t *out_len, size_t *binders_len )
+{
+    unsigned char *p = buf;
+    const unsigned char *psk;
+    size_t psk_len;
+    const unsigned char *psk_identity;
+    size_t psk_identity_len;
+    const mbedtls_ssl_ciphersuite_t *suite_info = NULL;
+    const int *ciphersuites;
+    size_t hash_len;
+    size_t identities_len, l_binders_len;
+    uint32_t obfuscated_ticket_age = 0;
+    psa_algorithm_t psa_hash_alg;
+
+    *out_len = 0;
+    *binders_len = 0;
+
+    /* Check if we have a PSK to offer. Only one is offered for now:
+     * - If a ticket has been configured, offer the corresponding PSK.
+     * - If no ticket has been configured but an external PSK has been
+     *   configured, offer that.
+     * - Otherwise, skip the PSK extension.
+     */
+
+    if( mbedtls_ssl_get_psk_to_offer( ssl, &psk, &psk_len,
+                                      &psk_identity, &psk_identity_len ) != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 3, ( "skip pre_shared_key extensions" ) );
+        return( 0 );
+    }
+
+    /* Use the first configured ciphersuite with available suite info. */
+    ciphersuites = ssl->conf->ciphersuite_list;
+    for ( int i = 0; ciphersuites[i] != 0; i++ )
+    {
+        suite_info = mbedtls_ssl_ciphersuite_from_id( ciphersuites[i] );
+
+        if( suite_info == NULL )
+            continue;
+
+        /* In this implementation we only add one pre-shared-key extension. */
+        ssl->session_negotiate->ciphersuite = ciphersuites[i];
+        ssl->handshake->ciphersuite_info = suite_info;
+        break;
+    }
+
+    /* Without a ciphersuite there is no hash to size the binder with. */
+    if( suite_info == NULL )
+        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+    psa_hash_alg = mbedtls_psa_translate_md( suite_info->mac );
+    hash_len = PSA_HASH_LENGTH( psa_hash_alg );
+    /* PSA_HASH_LENGTH() yields 0 for an unrecognized hash algorithm. */
+    if( hash_len == 0 )
+        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+
+    /* Check if we have space to write the extension, binder included.
+     * - extension_type         (2 bytes)
+     * - extension_data_len     (2 bytes)
+     * - identities_len         (2 bytes)
+     * - identity_len           (2 bytes)
+     * - identity               (psk_identity_len bytes)
+     * - obfuscated_ticket_age  (4 bytes)
+     * - binders_len            (2 bytes)
+     * - binder_len             (1 byte)
+     * - binder                 (hash_len bytes)
+     */
+
+    identities_len = 6 + psk_identity_len;
+    l_binders_len = 1 + hash_len;
+    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 4 + 2 + identities_len + 2 + l_binders_len );
+
+    MBEDTLS_SSL_DEBUG_MSG( 3,
+                 ( "client hello, adding pre_shared_key extension, "
+                   "omitting PSK binder list" ) );
+
+    /* Extension header */
+    MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_PRE_SHARED_KEY, p, 0 );
+    MBEDTLS_PUT_UINT16_BE( 2 + identities_len + 2 + l_binders_len , p, 2 );
+
+    MBEDTLS_PUT_UINT16_BE( identities_len, p, 4 );
+    MBEDTLS_PUT_UINT16_BE( psk_identity_len, p, 6 );
+    p += 8;
+    memcpy( p, psk_identity, psk_identity_len );
+    p += psk_identity_len;
+
+    /* add obfuscated ticket age */
+    MBEDTLS_PUT_UINT32_BE( obfuscated_ticket_age, p, 0 );
+    p += 4;
+
+    *out_len = ( p - buf ) + l_binders_len + 2;
+    *binders_len = l_binders_len + 2;
+
+    ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_PRE_SHARED_KEY;
+
+    return( 0 );
+}
+
+/*
+ * Write the binder list of the pre_shared_key extension into [buf, end).
+ *
+ * The range must be exactly 3 + hash-length bytes: list length (2 bytes),
+ * binder length (1 byte) and the binder. Returns 0 or a negative error code.
+ */
+int mbedtls_ssl_tls13_write_pre_shared_key_ext_binders(
+    mbedtls_ssl_context *ssl,
+    unsigned char *buf, unsigned char *end )
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    unsigned char *p = buf;
+    const mbedtls_ssl_ciphersuite_t *suite_info = NULL;
+    const int *ciphersuites;
+    size_t hash_len;
+    const unsigned char *psk;
+    size_t psk_len;
+    const unsigned char *psk_identity;
+    size_t psk_identity_len;
+    int psk_type;
+    unsigned char transcript[MBEDTLS_MD_MAX_SIZE];
+    size_t transcript_len;
+    psa_algorithm_t psa_hash_alg;
+
+    /* Check if we have a PSK to offer. Only one is offered for now:
+     * - If a ticket has been configured, offer the corresponding PSK.
+     * - If no ticket has been configured but an external PSK has been
+     *   configured, offer that.
+     * - Otherwise there is nothing to bind, which is an error here.
+     */
+
+    if( mbedtls_ssl_get_psk_to_offer( ssl, &psk, &psk_len,
+                                      &psk_identity, &psk_identity_len ) != 0 )
+    {
+        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+    }
+
+    /* Use the first configured ciphersuite with available suite info. */
+    ciphersuites = ssl->conf->ciphersuite_list;
+    for ( int i = 0; ciphersuites[i] != 0; i++ )
+    {
+        suite_info = mbedtls_ssl_ciphersuite_from_id( ciphersuites[i] );
+
+        if( suite_info == NULL )
+            continue;
+
+        /* In this implementation we only add one pre-shared-key extension. */
+        ssl->session_negotiate->ciphersuite = ciphersuites[i];
+        ssl->handshake->ciphersuite_info = suite_info;
+        break;
+    }
+
+    /* suite_info is dereferenced below, so a missing ciphersuite is fatal. */
+    if( suite_info == NULL )
+        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+
+    psa_hash_alg = mbedtls_psa_translate_md( suite_info->mac );
+    hash_len = PSA_HASH_LENGTH( psa_hash_alg );
+    /* PSA_HASH_LENGTH() yields 0 for an unrecognized hash algorithm. */
+    if( hash_len == 0 || (size_t)( end - buf ) != 3 + hash_len )
+        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+
+    MBEDTLS_SSL_DEBUG_MSG( 3, ( "client hello, adding PSK binder list" ) );
+
+    /* 2 bytes length field for array of psk binders */
+    MBEDTLS_PUT_UINT16_BE( hash_len + 1, p, 0 );
+    p += 2;
+
+    /* 1 bytes length field for next psk binder */
+    *p++ = MBEDTLS_BYTE_0( hash_len );
+
+    if( ssl->handshake->resume == 1 )
+        psk_type = MBEDTLS_SSL_TLS1_3_PSK_RESUMPTION;
+    else
+        psk_type = MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL;
+
+    /* Get current state of handshake transcript. */
+    ret = mbedtls_ssl_get_handshake_transcript( ssl, suite_info->mac,
+                                                transcript, sizeof( transcript ),
+                                                &transcript_len );
+    if( ret != 0 )
+        return( ret );
+
+    ret = mbedtls_ssl_tls13_create_psk_binder( ssl,
+              psa_hash_alg,
+              psk, psk_len, psk_type,
+              transcript, p );
+    if( ret != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_tls13_create_psk_binder", ret );
+        return( ret );
+    }
+
+    return( 0 );
+}
+#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
+
 int mbedtls_ssl_tls13_write_client_hello_exts( mbedtls_ssl_context *ssl,
                                                unsigned char *buf,
                                                unsigned char *end,
@@ -631,6 +945,22 @@
         p += ext_len;
     }
 
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+    /* For PSK-based key exchange we need the pre_shared_key extension
+     * and the psk_key_exchange_modes extension.
+     *
+     * The pre_shared_key extension MUST be the last extension in the
+     * ClientHello. Servers MUST check that it is the last extension and
+     * otherwise fail the handshake with an "illegal_parameter" alert.
+     *
+     * Add the psk_key_exchange_modes extension.
+     */
+    ret = ssl_tls13_write_psk_key_exchange_modes_ext( ssl, p, end, &ext_len );
+    if( ret != 0 )
+        return( ret );
+    p += ext_len;
+#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
+
     *out_len = p - buf;
 
     return( 0 );
@@ -929,6 +1259,98 @@
     return( 0 );
 }
 
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+/*
+ * Parse the pre_shared_key extension of a ServerHello.
+ *
+ * struct {
+ *   opaque identity<1..2^16-1>;
+ *   uint32 obfuscated_ticket_age;
+ * } PskIdentity;
+ *
+ * opaque PskBinderEntry<32..255>;
+ *
+ * struct {
+ *   select ( Handshake.msg_type ) {
+ *     case client_hello:
+ *          PskIdentity identities<7..2^16-1>;
+ *          PskBinderEntry binders<33..2^16-1>;
+ *     case server_hello:
+ *          uint16 selected_identity;
+ *   };
+ *
+ * } PreSharedKeyExtension;
+ */
+
+static int ssl_tls13_parse_server_psk_identity_ext( mbedtls_ssl_context *ssl,
+                                                    const unsigned char *buf,
+                                                    size_t len )
+{
+    int ret = 0;
+    size_t selected_identity;
+
+    const unsigned char *psk;
+    size_t psk_len;
+    const unsigned char *psk_identity;
+    size_t psk_identity_len;
+
+    /* Check which PSK we've offered.
+     *
+     * NOTE: Ultimately, we want to offer multiple PSKs, and in this
+     *       case, we need to iterate over them here.
+     */
+    if( mbedtls_ssl_get_psk_to_offer( ssl, &psk, &psk_len,
+                                      &psk_identity, &psk_identity_len ) != 0 )
+    {
+        /* If we haven't offered a PSK, the server must not send
+         * a PSK identity extension. */
+        return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+    }
+
+    if( len != (size_t) 2 )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad psk_identity extension in server hello message" ) );
+        return( MBEDTLS_ERR_SSL_DECODE_ERROR );
+    }
+
+    selected_identity = MBEDTLS_GET_UINT16_BE( buf, 0 );
+
+    /* We have offered only one PSK, so the only valid choice
+     * for the server is PSK index 0.
+     *
+     * This will change once we support multiple PSKs. */
+    if( selected_identity > 0 )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "Server's chosen PSK identity out of range" ) );
+
+        if( ( ret = mbedtls_ssl_send_alert_message( ssl,
+                        MBEDTLS_SSL_ALERT_LEVEL_FATAL,
+                        MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER ) ) != 0 )
+        {
+            return( ret );
+        }
+
+        return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
+    }
+
+    /* Set the chosen PSK
+     *
+     * TODO: We don't have to do this in case we offered 0-RTT and the
+     *       server accepted it, because in this case we've already
+     *       set the handshake PSK. */
+    ret = mbedtls_ssl_set_hs_psk( ssl, psk, psk_len );
+    if( ret != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_set_hs_psk", ret );
+        return( ret );
+    }
+
+    ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_PRE_SHARED_KEY;
+    return( 0 );
+}
+
+#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
+
 /* Parse ServerHello message and configure context
  *
  * struct {
@@ -1139,12 +1561,24 @@
                     goto cleanup;
                 break;
 
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
             case MBEDTLS_TLS_EXT_PRE_SHARED_KEY:
-                MBEDTLS_SSL_DEBUG_MSG( 3, ( "found pre_shared_key extension." ) );
-                MBEDTLS_SSL_DEBUG_MSG( 3, ( "pre_shared_key:Not supported yet" ) );
+                MBEDTLS_SSL_DEBUG_MSG( 3, ( "found pre_shared_key extension" ) );
+                if( is_hrr )
+                {
+                    fatal_alert = MBEDTLS_SSL_ALERT_MSG_UNSUPPORTED_EXT;
+                    goto cleanup;
+                }
 
-                fatal_alert = MBEDTLS_SSL_ALERT_MSG_UNSUPPORTED_EXT;
-                goto cleanup;
+                if( ( ret = ssl_tls13_parse_server_psk_identity_ext(
+                                ssl, p, extension_data_len ) ) != 0 )
+                {
+                    MBEDTLS_SSL_DEBUG_RET(
+                        1, ( "ssl_tls13_parse_server_psk_identity_ext" ), ret );
+                    return( ret );
+                }
+                break;
+#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
 
             case MBEDTLS_TLS_EXT_KEY_SHARE:
                 MBEDTLS_SSL_DEBUG_MSG( 3, ( "found key_shares extension" ) );