Add psk_key_exchange_modes parser
Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 7d99433..fc5ceeb 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -45,6 +45,90 @@
#include "ssl_tls13_keys.h"
#include "ssl_debug_helpers.h"
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+/* From RFC 8446:
+ *
+ *   enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode;
+ *   struct {
+ *       PskKeyExchangeMode ke_modes<1..255>;
+ *   } PskKeyExchangeModes;
+ *
+ * Parse the body of a ClientHello "psk_key_exchange_modes" extension and
+ * store the offered modes in ssl->handshake->tls13_kex_modes as a bit mask
+ * of MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK and
+ * MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_EPHEMERAL.
+ *
+ * Mode values this implementation does not know (including GREASE values,
+ * RFC 8701) are skipped rather than rejected, so the list may be longer
+ * than the number of known modes and the stored mask may be 0 when no
+ * known mode was offered.
+ *
+ * \param ssl  SSL context; only ssl->handshake->tls13_kex_modes is written.
+ * \param buf  Start of the extension data.
+ * \param end  End of the extension data.
+ *
+ * \return 0 on success.
+ * \return #MBEDTLS_ERR_SSL_DECODE_ERROR if the mode list is empty or does
+ *         not end exactly at \p end.
+ * \return The error produced by MBEDTLS_SSL_CHK_BUF_READ_PTR if the
+ *         extension is shorter than the declared list length.
+ */
+static int ssl_tls13_parse_key_exchange_modes_ext( mbedtls_ssl_context *ssl,
+                                                   const unsigned char *buf,
+                                                   const unsigned char *end )
+{
+    const unsigned char *p = buf;
+    size_t ke_modes_len;
+    int ke_modes = 0;
+
+    /* Read PSK mode list length (1 byte). */
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 1 );
+    ke_modes_len = *p++;
+
+    /* ke_modes<1..255>: a zero-length list is not a valid encoding. */
+    if( ke_modes_len == 0 )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "empty PSK key exchange mode list" ) );
+        return( MBEDTLS_ERR_SSL_DECODE_ERROR );
+    }
+
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, ke_modes_len );
+
+    /* The list must fill the extension exactly; bytes after it indicate a
+     * length mismatch and must not be silently ignored. */
+    if( p + ke_modes_len != end )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 1,
+            ( "PSK key exchange mode list length mismatch" ) );
+        return( MBEDTLS_ERR_SSL_DECODE_ERROR );
+    }
+
+    for( ; p < end; p++ )
+    {
+        switch( *p )
+        {
+            case MBEDTLS_SSL_TLS1_3_PSK_MODE_PURE:
+                ke_modes |= MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK;
+                MBEDTLS_SSL_DEBUG_MSG( 3, ( "Found PSK KEX MODE" ) );
+                break;
+            case MBEDTLS_SSL_TLS1_3_PSK_MODE_ECDHE:
+                ke_modes |= MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_EPHEMERAL;
+                MBEDTLS_SSL_DEBUG_MSG( 3, ( "Found PSK_EPHEMERAL KEX MODE" ) );
+                break;
+            default:
+                /* RFC 8701: servers must ignore unknown values offered in a
+                 * ClientHello (e.g. GREASE) and negotiate with the rest. */
+                MBEDTLS_SSL_DEBUG_MSG( 3, ( "Ignoring unknown PSK KEX MODE %u",
+                                            (unsigned) *p ) );
+                break;
+        }
+    }
+
+    ssl->handshake->tls13_kex_modes = ke_modes;
+    return( 0 );
+}
+#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
+
/* From RFC 8446:
* struct {
* ProtocolVersion versions<2..254>;
@@ -754,6 +801,28 @@
ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_SUPPORTED_VERSIONS;
break;
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+            case MBEDTLS_TLS_EXT_PSK_KEY_EXCHANGE_MODES:
+                MBEDTLS_SSL_DEBUG_MSG( 3,
+                    ( "found psk key exchange modes extension" ) );
+
+                /* Hand [p, extension_data_end) to the extension parser and,
+                 * only if it succeeds, flag the extension as received. */
+                ret = ssl_tls13_parse_key_exchange_modes_ext(
+                          ssl, p, extension_data_end );
+                if( ret == 0 )
+                {
+                    ssl->handshake->extensions_present |=
+                        MBEDTLS_SSL_EXT_PSK_KEY_EXCHANGE_MODES;
+                    break;
+                }
+
+                /* Parse failure: log it and propagate the error code. */
+                MBEDTLS_SSL_DEBUG_RET( 1,
+                    "ssl_tls13_parse_key_exchange_modes_ext", ret );
+                return( ret );
+#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
+
#if defined(MBEDTLS_SSL_ALPN)
case MBEDTLS_TLS_EXT_ALPN:
MBEDTLS_SSL_DEBUG_MSG( 3, ( "found alpn extension" ) );