Add ALPN information in session tickets
Signed-off-by: Waleed Elmelegy <waleed.elmelegy@arm.com>
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 681ccab..d7d26ab 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -3735,7 +3735,25 @@
#if defined(MBEDTLS_SSL_PROTO_TLS1_3)
/* Serialization of TLS 1.3 sessions:
*
- * For more detail, see the description of ssl_session_save().
+ * struct {
+ * opaque hostname<0..2^16-1>;
+ * uint64 ticket_reception_time;
+ * uint32 ticket_lifetime;
+ * opaque ticket<1..2^16-1>;
+ * } ClientOnlyData;
+ *
+ * struct {
+ * uint32 ticket_age_add;
+ * uint8 ticket_flags;
+ * opaque resumption_key<0..255>;
+ * uint32 max_early_data_size;
+ * uint16 record_size_limit;
+ * select ( endpoint ) {
+ * case client: ClientOnlyData;
+ * case server: uint64 ticket_creation_time;
+ * };
+ * } serialized_session_tls13;
+ * Server data also carries opaque alpn<0..255> when early data and ALPN are enabled.
*/
#if defined(MBEDTLS_SSL_SESSION_TICKETS)
MBEDTLS_CHECK_RETURN_CRITICAL
@@ -3750,9 +3768,16 @@
size_t hostname_len = (session->hostname == NULL) ?
0 : strlen(session->hostname) + 1;
#endif
+
+#if defined(MBEDTLS_SSL_SRV_C) && \
+ defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
+ const uint8_t alpn_len = (session->alpn == NULL) ?
+ 0 : (uint8_t) strlen(session->alpn) + 1;
+#endif
size_t needed = 4 /* ticket_age_add */
+ 1 /* ticket_flags */
+ 1; /* resumption_key length */
+
*olen = 0;
if (session->resumption_key_len > MBEDTLS_SSL_TLS1_3_TICKET_RESUMPTION_KEY_LEN) {
@@ -3771,6 +3796,15 @@
needed += 8; /* ticket_creation_time or ticket_reception_time */
#endif
+#if defined(MBEDTLS_SSL_SRV_C)
+ if (session->endpoint == MBEDTLS_SSL_IS_SERVER) {
+#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
+ needed += 1 /* alpn_len */
+ + alpn_len; /* alpn */
+#endif
+ }
+#endif /* MBEDTLS_SSL_SRV_C */
+
#if defined(MBEDTLS_SSL_CLI_C)
if (session->endpoint == MBEDTLS_SSL_IS_CLIENT) {
#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
@@ -3813,13 +3847,24 @@
p += 2;
#endif /* MBEDTLS_SSL_RECORD_SIZE_LIMIT */
-#if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_SRV_C)
+#if defined(MBEDTLS_SSL_SRV_C)
if (session->endpoint == MBEDTLS_SSL_IS_SERVER) {
+#if defined(MBEDTLS_HAVE_TIME)
MBEDTLS_PUT_UINT64_BE((uint64_t) session->ticket_creation_time, p, 0);
p += 8;
- }
#endif /* MBEDTLS_HAVE_TIME */
+#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
+ *p++ = alpn_len;
+ if (alpn_len > 0) {
+ /* save chosen alpn */
+ memcpy(p, session->alpn, alpn_len);
+ p += alpn_len;
+ }
+#endif /* MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */
+ }
+#endif /* MBEDTLS_SSL_SRV_C */
+
#if defined(MBEDTLS_SSL_CLI_C)
if (session->endpoint == MBEDTLS_SSL_IS_CLIENT) {
#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
@@ -3894,16 +3939,39 @@
p += 2;
#endif /* MBEDTLS_SSL_RECORD_SIZE_LIMIT */
-#if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_SRV_C)
+#if defined(MBEDTLS_SSL_SRV_C)
if (session->endpoint == MBEDTLS_SSL_IS_SERVER) {
+#if defined(MBEDTLS_HAVE_TIME)
if (end - p < 8) {
return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
}
session->ticket_creation_time = MBEDTLS_GET_UINT64_BE(p, 0);
p += 8;
- }
#endif /* MBEDTLS_HAVE_TIME */
+#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
+ uint8_t alpn_len;
+
+ if (end - p < 1) {
+ return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
+ }
+ alpn_len = *p++;
+
+ if (end - p < alpn_len) {
+ return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
+ }
+ if (alpn_len > 0) {
+ session->alpn = mbedtls_calloc(alpn_len, sizeof(char));
+ if (session->alpn == NULL) {
+ return MBEDTLS_ERR_SSL_ALLOC_FAILED;
+ }
+ memcpy(session->alpn, p, alpn_len);
+ p += alpn_len;
+ }
+#endif /* MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */
+ }
+#endif /* MBEDTLS_SSL_SRV_C */
+
#if defined(MBEDTLS_SSL_CLI_C)
if (session->endpoint == MBEDTLS_SSL_IS_CLIENT) {
#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
@@ -4849,6 +4917,10 @@
defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
mbedtls_free(session->hostname);
#endif
+#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) && \
+ defined(MBEDTLS_SSL_SRV_C)
+ mbedtls_free(session->alpn);
+#endif
mbedtls_free(session->ticket);
#endif
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 887c5c6..291d645 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -467,7 +467,17 @@
#if defined(MBEDTLS_SSL_EARLY_DATA)
dst->max_early_data_size = src->max_early_data_size;
-#endif
+
+#if defined(MBEDTLS_SSL_ALPN)
+ if (src->alpn != NULL) {
+ dst->alpn = mbedtls_calloc(strlen(src->alpn) + 1, sizeof(char));
+ if (dst->alpn == NULL) {
+ return MBEDTLS_ERR_SSL_ALLOC_FAILED;
+ }
+ memcpy(dst->alpn, src->alpn, strlen(src->alpn) + 1);
+ }
+#endif /* MBEDTLS_SSL_ALPN */
+#endif /* MBEDTLS_SSL_EARLY_DATA */
return 0;
}
@@ -3137,6 +3147,16 @@
MBEDTLS_SSL_PRINT_TICKET_FLAGS(4, session->ticket_flags);
+#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
+ if (ssl->alpn_chosen != NULL) {
+ session->alpn = mbedtls_calloc(strlen(ssl->alpn_chosen) + 1, sizeof(char));
+ if (session->alpn == NULL) {
+ return MBEDTLS_ERR_SSL_ALLOC_FAILED;
+ }
+ memcpy(session->alpn, ssl->alpn_chosen, strlen(ssl->alpn_chosen) + 1);
+ }
+#endif
+
/* Generate ticket_age_add */
if ((ret = ssl->conf->f_rng(ssl->conf->p_rng,
(unsigned char *) &session->ticket_age_add,