Add server side end-of-early-data handler

Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 40d51d8..e69b091 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -2779,6 +2779,34 @@
         return ret;
     }
 
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+    if (ssl->early_data_status == MBEDTLS_SSL_EARLY_DATA_STATUS_ACCEPTED) {
+        /* TODO: compute early transform here?
+         *
+         * RFC 8446, section A.2
+         *                      | Send Finished
+         *                      | K_send = application
+         *             +--------+--------+
+         *    No 0-RTT |                 | 0-RTT
+         *             |                 |
+         *             |                 | K_recv = early data
+         *             |    +------> WAIT_EOED -+
+         *
+         * In this section the early transform is set after the server
+         * Finished, but doing that breaks our key computation, so the early
+         * transform is computed at the end of ClientHello processing instead.
+         * For the time being, the benefit of moving it here is unclear.
+         */
+        MBEDTLS_SSL_DEBUG_MSG(
+            1, ("Switch to early keys for inbound traffic. "
+                "( K_recv = early data )"));
+        mbedtls_ssl_set_inbound_transform(
+            ssl, ssl->handshake->transform_earlydata);
+        mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_END_OF_EARLY_DATA);
+        return 0;
+    }
+#endif /* MBEDTLS_SSL_EARLY_DATA */
+
     MBEDTLS_SSL_DEBUG_MSG(1, ("Switch to handshake keys for inbound traffic"));
     mbedtls_ssl_set_inbound_transform(ssl, ssl->handshake->transform_handshake);
 
@@ -2818,6 +2846,123 @@
     return 0;
 }
 
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+/*
+ * Handler for MBEDTLS_SSL_END_OF_EARLY_DATA( WAIT_EOED )
+ *
+ * RFC 8446 section A.2
+ *
+ *                 |
+ *    +------> WAIT_EOED -+
+ *    |       Recv |      | Recv EndOfEarlyData
+ *    | early data |      | K_recv = handshake
+ *    +------------+      |
+ *                        |
+ *  WAIT_FLIGHT2 <--------+
+ *                                 |
+ *
+ * Reads one record per call:
+ *  - application data (0-RTT data): printed at debug level 3 and the
+ *    handshake state is left unchanged, so this handler runs again;
+ *  - an EndOfEarlyData handshake message with an empty body: added to the
+ *    transcript, inbound traffic switches to the handshake keys and the
+ *    state moves to MBEDTLS_SSL_WAIT_FLIGHT2;
+ *  - a non-empty EndOfEarlyData: "decode_error" alert;
+ *  - anything else: "unexpected_message" alert.
+ *
+ * Returns 0 on success, otherwise an MBEDTLS_ERR_XXX error code.
+ */
+MBEDTLS_CHECK_RETURN_CRITICAL
+static int ssl_tls13_process_wait_eoed(mbedtls_ssl_context *ssl)
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    mbedtls_ssl_handshake_params *handshake = ssl->handshake;
+
+    MBEDTLS_SSL_DEBUG_MSG(2, ("=> ssl_tls13_process_wait_eoed"));
+
+    if ((ret = mbedtls_ssl_read_record(ssl, 0)) != 0) {
+        MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_read_record", ret);
+        return ret;
+    }
+
+    /* RFC 8446 section 4.5
+     *
+     * struct {} EndOfEarlyData;
+     */
+    if (ssl->in_msgtype == MBEDTLS_SSL_MSG_HANDSHAKE        &&
+        ssl->in_msg[0]  == MBEDTLS_SSL_HS_END_OF_EARLY_DATA) {
+        /* The message body is empty, so only the handshake header may be
+         * present. Anything longer is malformed and must not be accepted
+         * (it would also desynchronize the transcript hash below). */
+        if (ssl->in_hslen != mbedtls_ssl_hs_hdr_len(ssl)) {
+            MBEDTLS_SSL_DEBUG_MSG(1, ("bad end_of_early_data message"));
+            MBEDTLS_SSL_PEND_FATAL_ALERT(MBEDTLS_SSL_ALERT_MSG_DECODE_ERROR,
+                                         MBEDTLS_ERR_SSL_DECODE_ERROR);
+            ret = MBEDTLS_ERR_SSL_DECODE_ERROR;
+            goto cleanup;
+        }
+
+        /* Update the transcript before switching keys so that, on failure,
+         * neither the inbound transform nor the state is changed. */
+        ret = mbedtls_ssl_add_hs_hdr_to_checksum(
+            ssl, MBEDTLS_SSL_HS_END_OF_EARLY_DATA, 0);
+        if (ret != 0) {
+            MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_add_hs_hdr_to_checksum", ret);
+            goto cleanup;
+        }
+
+        MBEDTLS_SSL_DEBUG_MSG(
+            1, ("Switch to handshake keys for inbound traffic "
+                "( K_recv = handshake )"));
+        mbedtls_ssl_set_inbound_transform(ssl, handshake->transform_handshake);
+        mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_WAIT_FLIGHT2);
+        goto cleanup;
+    }
+
+    /* RFC 8446 section 2.3 figure 4
+     *
+     * 0-RTT data is sent via application data messages.
+     */
+    if (ssl->in_msgtype != MBEDTLS_SSL_MSG_APPLICATION_DATA) {
+        MBEDTLS_SSL_DEBUG_MSG(
+            2, ("Unexpected message type %d", ssl->in_msgtype));
+        ret = MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE;
+        goto cleanup;
+    }
+
+    /*
+     * Output early data
+     *
+     * For the time being, the received data is printed via a debug message,
+     * NUL-terminated in place so that it can be formatted with %s.
+     * NOTE(review): this writes one byte past in_msglen and assumes the
+     * input buffer has room for it -- confirm.
+     *
+     * TODO: Remove it when `mbedtls_ssl_read_early_data` is ready.
+     */
+    ssl->in_msg[ssl->in_msglen] = 0;
+    MBEDTLS_SSL_DEBUG_MSG(3, ("\n%s", ssl->in_msg));
+
+    /* RFC 8446 section 4.6.1
+     *
+     * A server receiving more than max_early_data_size bytes of 0-RTT data
+     * SHOULD terminate the connection with an "unexpected_message" alert.
+     *
+     * TODO: Add received data size check here.
+     */
+
+    ret = 0;
+
+cleanup:
+    if (ret == MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE) {
+        MBEDTLS_SSL_PEND_FATAL_ALERT(MBEDTLS_SSL_ALERT_MSG_UNEXPECTED_MESSAGE,
+                                     MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE);
+    }
+    MBEDTLS_SSL_DEBUG_MSG(2, ("<= ssl_tls13_process_wait_eoed"));
+    return ret;
+}
+#endif /* MBEDTLS_SSL_EARLY_DATA */
+
 /*
  * Handler for MBEDTLS_SSL_CLIENT_FINISHED
  */
@@ -3262,6 +3382,12 @@
             ret = ssl_tls13_process_wait_flight2(ssl);
             break;
 
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+        case MBEDTLS_SSL_END_OF_EARLY_DATA:
+            ret = ssl_tls13_process_wait_eoed(ssl);
+            break;
+#endif /* MBEDTLS_SSL_EARLY_DATA */
+
         case MBEDTLS_SSL_CLIENT_FINISHED:
             ret = ssl_tls13_process_client_finished(ssl);
             break;