Port ALPN support for tls13 client from tls13-prototype
Summary:
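Port the TLS 1.3 client ALPN implementation (writing the ALPN extension in the
ClientHello and validating the protocol selected by the server) from
[tls13-prototype](https://github.com/hannestschofenig/mbedtls/blob/tls13-prototype/library/ssl_tls13_client.c#L1124).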
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Signed-off-by: lhuang04 <lhuang04@fb.com>
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index f556c0f..99c12cd 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -113,6 +113,120 @@
return( 0 );
}
+#if defined(MBEDTLS_SSL_ALPN)
+/*
+ * Write the ALPN (application_layer_protocol_negotiation) extension of the
+ * ClientHello into buf, advertising every protocol in ssl->conf->alpn_list.
+ *
+ * Wire format (RFC 7301, Section 3.1):
+ *
+ *     opaque ProtocolName<1..2^8-1>;
+ *
+ *     struct {
+ *         ProtocolName protocol_name_list<2..2^16-1>
+ *     } ProtocolNameList;
+ *
+ * Output layout: ext_type (2) | ext_len (2) | list_len (2) | entries...,
+ * where each entry is a one-byte name length followed by the name bytes.
+ *
+ * \param ssl   SSL context; only ssl->conf->alpn_list is read.
+ * \param buf   Start of the output buffer.
+ * \param end   One past the last writable byte of the output buffer.
+ * \param olen  Set to the number of bytes written (0 when the extension
+ *              is not written).
+ *
+ * \return 0 on success (including when nothing is written), or the error
+ *         returned by MBEDTLS_SSL_CHK_BUF_PTR when buf is too small.
+ */
+static int ssl_tls13_write_alpn_ext( mbedtls_ssl_context *ssl,
+                                     unsigned char *buf,
+                                     const unsigned char *end,
+                                     size_t *olen )
+{
+    unsigned char *p = buf;
+    size_t alpnlen = 0;
+    const char **cur;
+
+    *olen = 0;
+
+    if( ssl->conf->alpn_list == NULL )
+        return( 0 );
+
+    /* Each entry costs one length byte plus the name itself. */
+    for( cur = ssl->conf->alpn_list; *cur != NULL; cur++ )
+        alpnlen += strlen( *cur ) + 1;
+
+    /*
+     * protocol_name_list must be at least 2 bytes long, so an empty
+     * configured list cannot be encoded validly: omit the extension rather
+     * than sending a zero-length list the server would have to reject.
+     */
+    if( alpnlen == 0 )
+        return( 0 );
+
+    MBEDTLS_SSL_DEBUG_MSG( 3, ( "client hello, adding alpn extension" ) );
+
+    /* 6 = ext_type (2) + ext_len (2) + list_len (2) */
+    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 6 + alpnlen );
+
+    MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_ALPN, p, 0 );
+    /* Skip ext_len and list_len for now; they are filled in at the end. */
+    p += 6;
+
+    for( cur = ssl->conf->alpn_list; *cur != NULL; cur++ )
+    {
+        /*
+         * mbedtls_ssl_conf_alpn_protocols() limits each protocol name to at
+         * most 255 bytes, so its length fits in the single length byte.
+         */
+        size_t name_len = strlen( *cur );
+
+        *p++ = (unsigned char) name_len;
+        memcpy( p, *cur, name_len );
+        p += name_len;
+    }
+
+    *olen = p - buf;
+
+    /* list_len = olen - 2 (ext_type) - 2 (ext_len) - 2 (list_len) */
+    MBEDTLS_PUT_UINT16_BE( *olen - 6, buf, 4 );
+
+    /* ext_len = olen - 2 (ext_type) - 2 (ext_len) */
+    MBEDTLS_PUT_UINT16_BE( *olen - 4, buf, 2 );
+
+    return( 0 );
+}
+
+/*
+ * Parse the ALPN extension sent back by the server and record the protocol
+ * it selected in ssl->alpn_chosen.
+ *
+ * Wire format (RFC 7301, Section 3.1):
+ *
+ *     opaque ProtocolName<1..2^8-1>;
+ *
+ *     struct {
+ *         ProtocolName protocol_name_list<2..2^16-1>
+ *     } ProtocolNameList;
+ *
+ * In the server's reply the ProtocolNameList MUST contain exactly one
+ * ProtocolName, and that name must be one the client offered.
+ *
+ * \param ssl  SSL context; ssl->conf->alpn_list is read and
+ *             ssl->alpn_chosen is set on success.
+ * \param buf  Start of the extension data.
+ * \param len  Length of the extension data in bytes.
+ *
+ * \return 0 on success; MBEDTLS_ERR_SSL_BAD_INPUT_DATA if the client did
+ *         not offer ALPN, the list is malformed, or the selected protocol
+ *         was not offered; or the error returned by
+ *         MBEDTLS_SSL_CHK_BUF_READ_PTR when the data is too short.
+ */
+static int ssl_tls13_parse_alpn_ext( mbedtls_ssl_context *ssl,
+                                     const unsigned char *buf, size_t len )
+{
+    size_t list_len, name_len;
+    const unsigned char *p = buf;
+    const unsigned char *end = buf + len;
+
+    /* If we didn't send it, the server shouldn't send it */
+    if( ssl->conf->alpn_list == NULL )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    /* Min length is 2 ( list_len ) + 1 ( name_len ) + 1 ( name ) */
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 4 );
+
+    list_len = MBEDTLS_GET_UINT16_BE( p, 0 );
+    p += 2;
+
+    /* The list must occupy the rest of the extension data exactly. */
+    if( list_len != (size_t)( end - p ) )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    /*
+     * Exactly one ProtocolName: its length byte plus its bytes must fill the
+     * whole list. This also guarantees that the name_len bytes starting at p
+     * lie inside the buffer before they are compared below. list_len >= 2
+     * here because of the 4-byte minimum, so name_len is at least 1.
+     */
+    name_len = *p++;
+    if( name_len != list_len - 1 )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    /* Check that the server chosen protocol was in our list and save it */
+    for( const char **alpn = ssl->conf->alpn_list; *alpn != NULL; alpn++ )
+    {
+        if( name_len == strlen( *alpn ) &&
+            memcmp( p, *alpn, name_len ) == 0 )
+        {
+            ssl->alpn_chosen = *alpn;
+            return( 0 );
+        }
+    }
+
+    return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+}
+#endif /* MBEDTLS_SSL_ALPN */
+
#if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
static int ssl_tls13_reset_key_share( mbedtls_ssl_context *ssl )
@@ -753,6 +867,13 @@
return( ret );
p += output_len;
+#if defined(MBEDTLS_SSL_ALPN)
+    /*
+     * Capture the result: without this, `ret` still holds the previous
+     * extension's status and a buffer-too-small error here would silently
+     * produce a ClientHello without the ALPN extension.
+     */
+    ret = ssl_tls13_write_alpn_ext( ssl, p, end, &output_len );
+    if( ret != 0 )
+        return( ret );
+    p += output_len;
+#endif /* MBEDTLS_SSL_ALPN */
+
#if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
/*
@@ -1622,6 +1743,17 @@
MBEDTLS_SSL_DEBUG_MSG( 3, ( "found extensions supported groups" ) );
break;
+#if defined(MBEDTLS_SSL_ALPN)
+            case MBEDTLS_TLS_EXT_ALPN:
+                /* Validate the server's ALPN choice and record it. */
+                MBEDTLS_SSL_DEBUG_MSG( 3, ( "found alpn extension" ) );
+                ret = ssl_tls13_parse_alpn_ext( ssl, p,
+                                                (size_t)extension_data_len );
+                if( ret != 0 )
+                    return( ret );
+                break;
+#endif /* MBEDTLS_SSL_ALPN */
default:
MBEDTLS_SSL_DEBUG_MSG(
3, ( "unsupported extension found: %u ", extension_type) );