Merge pull request #8858 from waleed-elmelegy-arm/add_alpn_to_session

Add ALPN information in session tickets
diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h
index 39bea79..57d7bc6 100644
--- a/include/mbedtls/ssl.h
+++ b/include/mbedtls/ssl.h
@@ -1304,6 +1304,11 @@
     char *MBEDTLS_PRIVATE(hostname);             /*!< host name binded with tickets */
 #endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION && MBEDTLS_SSL_CLI_C */
 
+#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C)
+    char *ticket_alpn;                      /*!< ALPN negotiated in the session
+                                                 during which the ticket was generated. */
+#endif
+
 #if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_CLI_C)
     /*! Time in milliseconds when the last ticket was received. */
     mbedtls_ms_time_t MBEDTLS_PRIVATE(ticket_reception_time);
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 2ec898b..a8807f6 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -2852,6 +2852,13 @@
                                      const char *hostname);
 #endif
 
+#if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_EARLY_DATA) && \
+    defined(MBEDTLS_SSL_ALPN)
+MBEDTLS_CHECK_RETURN_CRITICAL
+int mbedtls_ssl_session_set_ticket_alpn(mbedtls_ssl_session *session,
+                                        const char *alpn);
+#endif
+
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3) && defined(MBEDTLS_SSL_SESSION_TICKETS)
 
 #define MBEDTLS_SSL_TLS1_3_MAX_ALLOWED_TICKET_LIFETIME (604800)
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 681ccab..ac8e0ea 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -3750,9 +3750,16 @@
     size_t hostname_len = (session->hostname == NULL) ?
                           0 : strlen(session->hostname) + 1;
 #endif
+
+#if defined(MBEDTLS_SSL_SRV_C) && \
+    defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
+    const size_t alpn_len = (session->ticket_alpn == NULL) ?
+                            0 : strlen(session->ticket_alpn) + 1;
+#endif
     size_t needed =   4  /* ticket_age_add */
                     + 1  /* ticket_flags */
                     + 1; /* resumption_key length */
+
     *olen = 0;
 
     if (session->resumption_key_len > MBEDTLS_SSL_TLS1_3_TICKET_RESUMPTION_KEY_LEN) {
@@ -3771,6 +3778,15 @@
     needed += 8; /* ticket_creation_time or ticket_reception_time */
 #endif
 
+#if defined(MBEDTLS_SSL_SRV_C)
+    if (session->endpoint == MBEDTLS_SSL_IS_SERVER) {
+#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
+        needed +=   2                         /* alpn_len */
+                  + alpn_len;                 /* alpn */
+#endif
+    }
+#endif /* MBEDTLS_SSL_SRV_C */
+
 #if defined(MBEDTLS_SSL_CLI_C)
     if (session->endpoint == MBEDTLS_SSL_IS_CLIENT) {
 #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
@@ -3813,13 +3829,26 @@
     p += 2;
 #endif /* MBEDTLS_SSL_RECORD_SIZE_LIMIT */
 
-#if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_SRV_C)
+#if defined(MBEDTLS_SSL_SRV_C)
     if (session->endpoint == MBEDTLS_SSL_IS_SERVER) {
+#if defined(MBEDTLS_HAVE_TIME)
         MBEDTLS_PUT_UINT64_BE((uint64_t) session->ticket_creation_time, p, 0);
         p += 8;
-    }
 #endif /* MBEDTLS_HAVE_TIME */
 
+#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
+        MBEDTLS_PUT_UINT16_BE(alpn_len, p, 0);
+        p += 2;
+
+        if (alpn_len > 0) {
+            /* save chosen alpn */
+            memcpy(p, session->ticket_alpn, alpn_len);
+            p += alpn_len;
+        }
+#endif /* MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */
+    }
+#endif /* MBEDTLS_SSL_SRV_C */
+
 #if defined(MBEDTLS_SSL_CLI_C)
     if (session->endpoint == MBEDTLS_SSL_IS_CLIENT) {
 #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
@@ -3894,16 +3923,41 @@
     p += 2;
 #endif /* MBEDTLS_SSL_RECORD_SIZE_LIMIT */
 
-#if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_SRV_C)
+#if  defined(MBEDTLS_SSL_SRV_C)
     if (session->endpoint == MBEDTLS_SSL_IS_SERVER) {
+#if defined(MBEDTLS_HAVE_TIME)
         if (end - p < 8) {
             return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
         }
         session->ticket_creation_time = MBEDTLS_GET_UINT64_BE(p, 0);
         p += 8;
-    }
 #endif /* MBEDTLS_HAVE_TIME */
 
+#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
+        size_t alpn_len;
+
+        if (end - p < 2) {
+            return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
+        }
+
+        alpn_len = MBEDTLS_GET_UINT16_BE(p, 0);
+        p += 2;
+
+        if (end - p < (long int) alpn_len) {
+            return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
+        }
+
+        if (alpn_len > 0) {
+            int ret = mbedtls_ssl_session_set_ticket_alpn(session, (char *) p);
+            if (ret != 0) {
+                return ret;
+            }
+            p += alpn_len;
+        }
+#endif /* MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */
+    }
+#endif /* MBEDTLS_SSL_SRV_C */
+
 #if defined(MBEDTLS_SSL_CLI_C)
     if (session->endpoint == MBEDTLS_SSL_IS_CLIENT) {
 #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
@@ -4060,6 +4114,13 @@
 #define SSL_SERIALIZED_SESSION_CONFIG_RECORD_SIZE 0
 #endif /* MBEDTLS_SSL_RECORD_SIZE_LIMIT */
 
+#if defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C) && \
+    defined(MBEDTLS_SSL_EARLY_DATA)
+#define SSL_SERIALIZED_SESSION_CONFIG_ALPN 1
+#else
+#define SSL_SERIALIZED_SESSION_CONFIG_ALPN 0
+#endif /* MBEDTLS_SSL_ALPN && MBEDTLS_SSL_SRV_C && MBEDTLS_SSL_EARLY_DATA */
+
 #define SSL_SERIALIZED_SESSION_CONFIG_TIME_BIT          0
 #define SSL_SERIALIZED_SESSION_CONFIG_CRT_BIT           1
 #define SSL_SERIALIZED_SESSION_CONFIG_CLIENT_TICKET_BIT 2
@@ -4070,6 +4131,7 @@
 #define SSL_SERIALIZED_SESSION_CONFIG_SNI_BIT           7
 #define SSL_SERIALIZED_SESSION_CONFIG_EARLY_DATA_BIT    8
 #define SSL_SERIALIZED_SESSION_CONFIG_RECORD_SIZE_BIT   9
+#define SSL_SERIALIZED_SESSION_CONFIG_ALPN_BIT          10
 
 #define SSL_SERIALIZED_SESSION_CONFIG_BITFLAG                           \
     ((uint16_t) (                                                      \
@@ -4086,7 +4148,9 @@
          (SSL_SERIALIZED_SESSION_CONFIG_EARLY_DATA << \
              SSL_SERIALIZED_SESSION_CONFIG_EARLY_DATA_BIT) | \
          (SSL_SERIALIZED_SESSION_CONFIG_RECORD_SIZE << \
-             SSL_SERIALIZED_SESSION_CONFIG_RECORD_SIZE_BIT)))
+             SSL_SERIALIZED_SESSION_CONFIG_RECORD_SIZE_BIT) | \
+         (SSL_SERIALIZED_SESSION_CONFIG_ALPN << \
+             SSL_SERIALIZED_SESSION_CONFIG_ALPN_BIT)))
 
 static const unsigned char ssl_serialized_session_header[] = {
     MBEDTLS_VERSION_MAJOR,
@@ -4161,8 +4225,12 @@
  * #endif
  *    select ( endpoint ) {
  *         case client: ClientOnlyData;
+ *         case server:
  * #if defined(MBEDTLS_HAVE_TIME)
- *         case server: uint64 ticket_creation_time;
+ *                      uint64 ticket_creation_time;
+ * #endif
+ * #if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
+ *                      opaque ticket_alpn<0..256>;
  * #endif
  *     };
  * } serialized_session_tls13;
@@ -4852,6 +4920,11 @@
     mbedtls_free(session->ticket);
 #endif
 
+#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) && \
+    defined(MBEDTLS_SSL_SRV_C)
+    mbedtls_free(session->ticket_alpn);
+#endif
+
     mbedtls_platform_zeroize(session, sizeof(mbedtls_ssl_session));
 }
 
@@ -9798,4 +9871,51 @@
           MBEDTLS_SSL_SERVER_NAME_INDICATION &&
           MBEDTLS_SSL_CLI_C */
 
+#if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_EARLY_DATA) && \
+    defined(MBEDTLS_SSL_ALPN)
+/*
+ * Record the ALPN protocol name associated with a session's ticket.
+ *
+ * session: session whose ticket_alpn field is updated.
+ * alpn:    NUL-terminated protocol name to copy, or NULL to clear the
+ *          stored value.
+ *
+ * The session keeps its own heap copy of the string. The copy is made
+ * before the previous value is released, so the session is left unchanged
+ * on failure and alpn may safely point at session->ticket_alpn itself.
+ *
+ * Returns 0 on success, MBEDTLS_ERR_SSL_BAD_INPUT_DATA if alpn is longer
+ * than MBEDTLS_SSL_MAX_ALPN_NAME_LEN, or MBEDTLS_ERR_SSL_ALLOC_FAILED if
+ * the copy cannot be allocated.
+ */
+int mbedtls_ssl_session_set_ticket_alpn(mbedtls_ssl_session *session,
+                                        const char *alpn)
+{
+    char *new_alpn = NULL;
+
+    if (alpn != NULL) {
+        const size_t alpn_len = strlen(alpn);
+
+        if (alpn_len > MBEDTLS_SSL_MAX_ALPN_NAME_LEN) {
+            return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
+        }
+
+        /* The zero-filled extra byte serves as the NUL terminator. */
+        new_alpn = mbedtls_calloc(alpn_len + 1, 1);
+        if (new_alpn == NULL) {
+            return MBEDTLS_ERR_SSL_ALLOC_FAILED;
+        }
+        memcpy(new_alpn, alpn, alpn_len);
+    }
+
+    /* Only drop the old value once the replacement is ready. */
+    if (session->ticket_alpn != NULL) {
+        mbedtls_zeroize_and_free(session->ticket_alpn,
+                                 strlen(session->ticket_alpn));
+    }
+    session->ticket_alpn = new_alpn;
+
+    return 0;
+}
+#endif /* MBEDTLS_SSL_SRV_C && MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */
 #endif /* MBEDTLS_SSL_TLS_C */
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 203aa17..2c30da8 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -467,7 +467,14 @@
 
 #if defined(MBEDTLS_SSL_EARLY_DATA)
     dst->max_early_data_size = src->max_early_data_size;
-#endif
+
+#if defined(MBEDTLS_SSL_ALPN)
+    int ret = mbedtls_ssl_session_set_ticket_alpn(dst, src->ticket_alpn);
+    if (ret != 0) {
+        return ret;
+    }
+#endif /* MBEDTLS_SSL_ALPN */
+#endif /* MBEDTLS_SSL_EARLY_DATA */
 
     return 0;
 }
@@ -3137,6 +3144,15 @@
 
     MBEDTLS_SSL_PRINT_TICKET_FLAGS(4, session->ticket_flags);
 
+#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
+    if (session->ticket_alpn == NULL) {
+        ret = mbedtls_ssl_session_set_ticket_alpn(session, ssl->alpn_chosen);
+        if (ret != 0) {
+            return ret;
+        }
+    }
+#endif
+
     /* Generate ticket_age_add */
     if ((ret = ssl->conf->f_rng(ssl->conf->p_rng,
                                 (unsigned char *) &session->ticket_age_add,
diff --git a/tests/src/test_helpers/ssl_helpers.c b/tests/src/test_helpers/ssl_helpers.c
index 56e03f1..963938f 100644
--- a/tests/src/test_helpers/ssl_helpers.c
+++ b/tests/src/test_helpers/ssl_helpers.c
@@ -1793,7 +1793,13 @@
 
 #if defined(MBEDTLS_SSL_EARLY_DATA)
     session->max_early_data_size = 0x87654321;
-#endif
+#if defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C)
+    int ret = mbedtls_ssl_session_set_ticket_alpn(session, "ALPNExample");
+    if (ret != 0) {
+        return -1;
+    }
+#endif /* MBEDTLS_SSL_ALPN && MBEDTLS_SSL_SRV_C */
+#endif /* MBEDTLS_SSL_EARLY_DATA */
 
 #if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_SRV_C)
     if (session->endpoint == MBEDTLS_SSL_IS_SERVER) {
diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function
index 7ef5805..2fe4997 100644
--- a/tests/suites/test_suite_ssl.function
+++ b/tests/suites/test_suite_ssl.function
@@ -2104,6 +2104,14 @@
 #if defined(MBEDTLS_SSL_EARLY_DATA)
         TEST_ASSERT(
             original.max_early_data_size == restored.max_early_data_size);
+#if defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C)
+        if (endpoint_type == MBEDTLS_SSL_IS_SERVER) {
+            TEST_ASSERT(original.ticket_alpn != NULL);
+            TEST_ASSERT(restored.ticket_alpn != NULL);
+            TEST_MEMORY_COMPARE(original.ticket_alpn, strlen(original.ticket_alpn),
+                                restored.ticket_alpn, strlen(restored.ticket_alpn));
+        }
+#endif
 #endif
 
 #if defined(MBEDTLS_SSL_SESSION_TICKETS) && defined(MBEDTLS_SSL_CLI_C)