CID update to RFC 9146

The DTLS 1.2 CID specification has been published as RFC 9146. This PR updates the implementation to match the RFC content.

Signed-off-by: Hannes Tschofenig <hannes.tschofenig@arm.com>
diff --git a/library/ssl_msg.c b/library/ssl_msg.c
index dbef29b..ecf7a2b 100644
--- a/library/ssl_msg.c
+++ b/library/ssl_msg.c
@@ -388,30 +388,80 @@
 }
 #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID || MBEDTLS_SSL_PROTO_TLS1_3 */
 
-/* `add_data` must have size 13 Bytes if the CID extension is disabled,
- * and 13 + 1 + CID-length Bytes if the CID extension is enabled. */
+/* The size of the `add_data` structure depends on various
+ * factors, namely
+ *
+ * 1) CID functionality disabled
+ *
+ * additional_data =
+ *    8:                    seq_num +
+ *    1:                       type +
+ *    2:                    version +
+ *    2:  length of inner plaintext
+ *
+ * size = 13 bytes
+ *
+ * 2) CID functionality based on RFC 9146 enabled
+ *
+ * size = 8 + 1 + 1 + 1 + 2 + 2 + 6 + 2 + CID-length
+ *      = 23 + CID-length
+ *
+ * 3) CID functionality based on legacy CID version
+ *    according to draft-ietf-tls-dtls-connection-id-05
+ *  https://tools.ietf.org/html/draft-ietf-tls-dtls-connection-id-05
+ *
+ * size = 13 + 1 + CID-length
+ *
+ * More information about the CID usage:
+ *
+ * Per Section 5.3 of draft-ietf-tls-dtls-connection-id-05 the
+ * size of the additional data structure is calculated as:
+ *
+ * additional_data =
+ *    8:                    seq_num +
+ *    1:                  tls12_cid +
+ *    2:     DTLSCipherText.version +
+ *    n:                        cid +
+ *    1:                 cid_length +
+ *    2: length_of_DTLSInnerPlaintext
+ *
+ * Per RFC 9146 the size of the add_data structure is calculated as:
+ *
+ * additional_data =
+ *    8:        seq_num_placeholder +
+ *    1:                  tls12_cid +
+ *    1:                 cid_length +
+ *    1:                  tls12_cid +
+ *    2:     DTLSCiphertext.version +
+ *    2:                      epoch +
+ *    6:            sequence_number +
+ *    n:                        cid +
+ *    2: length_of_DTLSInnerPlaintext
+ *
+ */
 static void ssl_extract_add_data_from_record( unsigned char* add_data,
                                               size_t *add_data_len,
                                               mbedtls_record *rec,
                                               mbedtls_ssl_protocol_version
-                                                tls_version,
+                                              tls_version,
                                               size_t taglen )
 {
-    /* Quoting RFC 5246 (TLS 1.2):
+    /* Several types of ciphers have been defined for use with TLS and DTLS,
+     * and the MAC calculations for those ciphers differ slightly. Further
+     * variants were added when the CID functionality was added with RFC 9146.
+     * This implementation also considers the use of a legacy version of the
+     * CID specification published in draft-ietf-tls-dtls-connection-id-05,
+     * which is used in deployments.
+     *
+     * We will distinguish between the non-CID and the CID cases below.
+     *
+     * --- Non-CID cases ---
+     *
+     * Quoting RFC 5246 (TLS 1.2):
      *
      *    additional_data = seq_num + TLSCompressed.type +
      *                      TLSCompressed.version + TLSCompressed.length;
      *
-     * For the CID extension, this is extended as follows
-     * (quoting draft-ietf-tls-dtls-connection-id-05,
-     *  https://tools.ietf.org/html/draft-ietf-tls-dtls-connection-id-05):
-     *
-     *       additional_data = seq_num + DTLSPlaintext.type +
-     *                         DTLSPlaintext.version +
-     *                         cid +
-     *                         cid_length +
-     *                         length_of_DTLSInnerPlaintext;
-     *
      * For TLS 1.3, the record sequence number is dropped from the AAD
      * and encoded within the nonce of the AEAD operation instead.
      * Moreover, the additional data involves the length of the TLS
@@ -427,11 +477,72 @@
      *
      *     TLSCiphertext.length = TLSInnerPlaintext.length + taglen.
      *
-     */
+     * --- CID cases ---
+     *
+     * RFC 9146 uses a common pattern when constructing the data
+     * passed into a MAC / AEAD cipher.
+     *
+     * Data concatenation for MACs used with block ciphers with
+     * Encrypt-then-MAC Processing (with CID):
+     *
+     *  data = seq_num_placeholder +
+     *         tls12_cid +
+     *         cid_length +
+     *         tls12_cid +
+     *         DTLSCiphertext.version +
+     *         epoch +
+     *         sequence_number +
+     *         cid +
+     *         DTLSCiphertext.length +
+     *         IV +
+     *         ENC(content + padding + padding_length)
+     *
+     * Data concatenation for MACs used with block ciphers (with CID):
+     *
+     *  data =  seq_num_placeholder +
+     *          tls12_cid +
+     *          cid_length +
+     *          tls12_cid +
+     *          DTLSCiphertext.version +
+     *          epoch +
+     *          sequence_number +
+     *          cid +
+     *          length_of_DTLSInnerPlaintext +
+     *          DTLSInnerPlaintext.content +
+     *          DTLSInnerPlaintext.real_type +
+     *          DTLSInnerPlaintext.zeros
+     *
+     * AEAD ciphers use the following additional data calculation (with CIDs):
+     *
+     *     additional_data = seq_num_placeholder +
+     *                tls12_cid +
+     *                cid_length +
+     *                tls12_cid +
+     *                DTLSCiphertext.version +
+     *                epoch +
+     *                sequence_number +
+     *                cid +
+     *                length_of_DTLSInnerPlaintext
+     *
+     * Section 5.3 of draft-ietf-tls-dtls-connection-id-05 (for legacy CID use)
+     * defines the additional data calculation as follows:
+     *
+     *     additional_data = seq_num +
+     *                tls12_cid +
+     *                DTLSCipherText.version +
+     *                cid +
+     *                cid_length +
+     *                length_of_DTLSInnerPlaintext
+     */
 
     unsigned char *cur = add_data;
     size_t ad_len_field = rec->data_len;
 
+#if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID) && \
+    MBEDTLS_SSL_DTLS_CONNECTION_ID_COMPAT == 0
+    const unsigned char seq_num_placeholder[] = { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff };
+#endif
+
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
     if( tls_version == MBEDTLS_SSL_VERSION_TLS1_3 )
     {
@@ -445,25 +556,78 @@
     {
         ((void) tls_version);
         ((void) taglen);
-        memcpy( cur, rec->ctr, sizeof( rec->ctr ) );
-        cur += sizeof( rec->ctr );
+
+#if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
+
+#if MBEDTLS_SSL_DTLS_CONNECTION_ID_COMPAT == 0
+        if( rec->cid_len != 0 )
+        {
+            // seq_num_placeholder
+            memcpy( cur, seq_num_placeholder, sizeof(seq_num_placeholder) );
+            cur += sizeof( seq_num_placeholder );
+
+            // tls12_cid type
+            *cur = rec->type;
+            cur++;
+
+            // cid_length
+            *cur = rec->cid_len;
+            cur++;
+        }
+        else
+        {
+            // epoch + sequence number
+            memcpy( cur, rec->ctr, sizeof( rec->ctr ) );
+            cur += sizeof( rec->ctr );
+        }
+#endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID_COMPAT == 0 */
+#else
+        // epoch + sequence number
+        memcpy(cur, rec->ctr, sizeof(rec->ctr));
+        cur += sizeof(rec->ctr);
+#endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */
     }
 
+    // type
     *cur = rec->type;
     cur++;
 
+    // version
     memcpy( cur, rec->ver, sizeof( rec->ver ) );
     cur += sizeof( rec->ver );
 
-#if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
-    if( rec->cid_len != 0 )
+#if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID) && \
+    MBEDTLS_SSL_DTLS_CONNECTION_ID_COMPAT == 1
+
+    if (rec->cid_len != 0)
     {
-        memcpy( cur, rec->cid, rec->cid_len );
+        // CID
+        memcpy(cur, rec->cid, rec->cid_len);
         cur += rec->cid_len;
 
+        // cid_length
         *cur = rec->cid_len;
         cur++;
 
+        // length of inner plaintext
+        MBEDTLS_PUT_UINT16_BE(ad_len_field, cur, 0);
+        cur += 2;
+    }
+    else
+#elif defined(MBEDTLS_SSL_DTLS_CONNECTION_ID) && \
+    MBEDTLS_SSL_DTLS_CONNECTION_ID_COMPAT == 0
+
+    if( rec->cid_len != 0 )
+    {
+        // epoch + sequence number
+        memcpy(cur, rec->ctr, sizeof(rec->ctr));
+        cur += sizeof(rec->ctr);
+
+        // CID
+        memcpy( cur, rec->cid, rec->cid_len );
+        cur += rec->cid_len;
+
+        // length of inner plaintext
         MBEDTLS_PUT_UINT16_BE( ad_len_field, cur, 0 );
         cur += 2;
     }
@@ -538,7 +702,14 @@
     mbedtls_ssl_mode_t ssl_mode;
     int auth_done = 0;
     unsigned char * data;
-    unsigned char add_data[13 + 1 + MBEDTLS_SSL_CID_OUT_LEN_MAX ];
+    /* For an explanation of the additional data length see
+     * the description of ssl_extract_add_data_from_record().
+     */
+#if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
+    unsigned char add_data[23 + MBEDTLS_SSL_CID_OUT_LEN_MAX];
+#else
+    unsigned char add_data[13];
+#endif
     size_t add_data_len;
     size_t post_avail;
 
@@ -1021,13 +1192,7 @@
             size_t sign_mac_length = 0;
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
 
-            /*
-             * MAC(MAC_write_key, seq_num +
-             *     TLSCipherText.type +
-             *     TLSCipherText.version +
-             *     length_of( (IV +) ENC(...) ) +
-             *     IV +
-             *     ENC(content + padding + padding_length));
+            /* MAC(MAC_write_key, add_data, IV, ENC(content + padding + padding_length))
              */
 
             if( post_avail < transform->maclen)
@@ -1133,7 +1298,14 @@
     size_t padlen = 0, correct = 1;
 #endif
     unsigned char* data;
-    unsigned char add_data[13 + 1 + MBEDTLS_SSL_CID_IN_LEN_MAX ];
+    /* For an explanation of the additional data length see
+     * the description of ssl_extract_add_data_from_record().
+     */
+#if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
+    unsigned char add_data[23 + MBEDTLS_SSL_CID_IN_LEN_MAX];
+#else
+    unsigned char add_data[13];
+#endif
     size_t add_data_len;
 
 #if !defined(MBEDTLS_DEBUG_C)
@@ -3487,7 +3659,7 @@
     {
         /* Shift pointers to account for record header including CID
          * struct {
-         *   ContentType special_type = tls12_cid;
+         *   ContentType outer_type = tls12_cid;
          *   ProtocolVersion version;
          *   uint16 epoch;
          *   uint48 sequence_number;