Add server side end-of-early-data handler
Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 40d51d8..e69b091 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -2779,6 +2779,34 @@
return ret;
}
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+ if (ssl->early_data_status == MBEDTLS_SSL_EARLY_DATA_STATUS_ACCEPTED) {
+ /* TODO: compute early transform here?
+ *
+ * RFC 8446, section A.2
+ *                    | Send Finished
+ *                    | K_send = application
+ *           +--------+--------+
+ * No 0-RTT |                 | 0-RTT
+ *           |                 |
+ *           |                 | K_recv = early data
+ *           |    +------> WAIT_EOED -+
+ *
+ * In this diagram the early transform is set after the server Finished.
+ * Doing that would break our key computation, so the early transform is
+ * computed at the end of ClientHello processing instead. For the time
+ * being, the benefit of moving that computation here is unclear.
+ */
+ MBEDTLS_SSL_DEBUG_MSG(
+ 1, ("Switch to early keys for inbound traffic. "
+ "( K_recv = early data )"));
+ mbedtls_ssl_set_inbound_transform(
+ ssl, ssl->handshake->transform_earlydata);
+ mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_END_OF_EARLY_DATA);
+ return 0;
+ }
+#endif /* MBEDTLS_SSL_EARLY_DATA */
+
MBEDTLS_SSL_DEBUG_MSG(1, ("Switch to handshake keys for inbound traffic"));
mbedtls_ssl_set_inbound_transform(ssl, ssl->handshake->transform_handshake);
@@ -2818,6 +2846,98 @@
return 0;
}
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+/*
+ * Handler for MBEDTLS_SSL_END_OF_EARLY_DATA( WAIT_EOED ): reads one record,
+ * handles it as EndOfEarlyData or as early data, and returns 0 or an error.
+ * State machine excerpt from RFC 8446 section A.2:
+ *
+ *                |
+ *   +------> WAIT_EOED -+
+ *   |       Recv |      | Recv EndOfEarlyData
+ *   | early data |      | K_recv = handshake
+ *   +------------+      |
+ *                       |
+ * WAIT_FLIGHT2 <--------+
+ *       |
+ */
+MBEDTLS_CHECK_RETURN_CRITICAL
+static int ssl_tls13_process_wait_eoed(mbedtls_ssl_context *ssl)
+{
+ int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+ mbedtls_ssl_handshake_params *handshake = ssl->handshake;
+
+ MBEDTLS_SSL_DEBUG_MSG(2, ("=> ssl_tls13_process_wait_eoed"));
+
+ if ((ret = mbedtls_ssl_read_record(ssl, 0)) != 0) {
+ MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_read_record", ret);
+ return ret;
+ }
+
+ /* RFC 8446 section 4.5: struct {} EndOfEarlyData;
+ * On this message: switch inbound keys to the handshake transform, move
+ * to MBEDTLS_SSL_WAIT_FLIGHT2, then call the handshake checksum helper.
+ */
+ if (ssl->in_msgtype == MBEDTLS_SSL_MSG_HANDSHAKE &&
+ ssl->in_msg[0] == MBEDTLS_SSL_HS_END_OF_EARLY_DATA) {
+ MBEDTLS_SSL_DEBUG_MSG(
+ 1, ("Switch to handshake keys for inbound traffic"
+ "( K_recv = handshake )"));
+ mbedtls_ssl_set_inbound_transform(ssl, handshake->transform_handshake);
+ mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_WAIT_FLIGHT2);
+
+ ret = mbedtls_ssl_add_hs_hdr_to_checksum(
+ ssl, MBEDTLS_SSL_HS_END_OF_EARLY_DATA, 0);
+ if (0 != ret) {
+ MBEDTLS_SSL_DEBUG_RET(
+ 1, ("mbedtls_ssl_add_hs_hdr_to_checksum"), ret);
+ }
+
+ goto cleanup;
+
+ }
+
+ /* RFC 8446 section 2.3 figure 4: 0-RTT data is sent via application data
+ * messages. Any other record type is rejected below: ret stays
+ * MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE and cleanup pends a fatal alert.
+ */
+ ret = MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE;
+ if (ssl->in_msgtype != MBEDTLS_SSL_MSG_APPLICATION_DATA) {
+ MBEDTLS_SSL_DEBUG_MSG(
+ 2, ("Unexpected message type %d", ssl->in_msgtype));
+ goto cleanup;
+ }
+
+ /*
+ * Output early data: for the time being, the received data is printed
+ * via a debug message as a NUL-terminated string.
+ * NOTE(review): the terminator is written at in_msg[in_msglen], one byte
+ * past the payload -- confirm the buffer always has room for it.
+ * TODO: Remove it when `mbedtls_ssl_read_early_data` is ready.
+ */
+ ssl->in_msg[ssl->in_msglen] = 0;
+ MBEDTLS_SSL_DEBUG_MSG(3, ("\n%s", ssl->in_msg));
+
+ /* RFC 8446 section 4.6.1
+ *
+ * A server receiving more than max_early_data_size bytes of 0-RTT data
+ * SHOULD terminate the connection with an "unexpected_message" alert.
+ *
+ * TODO: Add received data size check here.
+ */
+
+ ret = 0;
+
+cleanup:
+ if (ret == MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE) {
+ MBEDTLS_SSL_PEND_FATAL_ALERT(MBEDTLS_SSL_ALERT_MSG_UNEXPECTED_MESSAGE,
+ MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE);
+ }
+ MBEDTLS_SSL_DEBUG_MSG(2, ("<= ssl_tls13_process_wait_eoed"));
+ return ret;
+}
+#endif /* MBEDTLS_SSL_EARLY_DATA */
+
/*
* Handler for MBEDTLS_SSL_CLIENT_FINISHED
*/
@@ -3262,6 +3382,12 @@
ret = ssl_tls13_process_wait_flight2(ssl);
break;
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+ case MBEDTLS_SSL_END_OF_EARLY_DATA:
+ ret = ssl_tls13_process_wait_eoed(ssl);
+ break;
+#endif /* MBEDTLS_SSL_EARLY_DATA */
+
case MBEDTLS_SSL_CLIENT_FINISHED:
ret = ssl_tls13_process_client_finished(ssl);
break;