Implement cipher_auth_{en,de}crypt_ext()

Work in progress: next step is to test it!

Extract the part that is common with the non-ext version into a new internal
function. (We can't just call the non-ext version for that, as it's going to be
deprecated.)

Currently the NIST_KW part is somewhat duplicated between the ext
and non-ext versions, but that's OK because it will soon be removed from the
non-ext version.

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/library/cipher.c b/library/cipher.c
index 853eeec..20e1e7b 100644
--- a/library/cipher.c
+++ b/library/cipher.c
@@ -1288,23 +1288,16 @@
 
 #if defined(MBEDTLS_CIPHER_MODE_AEAD)
 /*
- * Packet-oriented encryption for AEAD modes
+ * Packet-oriented encryption for AEAD modes: internal function shared by
+ * mbedtls_cipher_auth_encrypt() and mbedtls_cipher_auth_encrypt_ext().
  */
-int mbedtls_cipher_auth_encrypt( mbedtls_cipher_context_t *ctx,
+static int mbedtls_cipher_aead_encrypt( mbedtls_cipher_context_t *ctx,
                          const unsigned char *iv, size_t iv_len,
                          const unsigned char *ad, size_t ad_len,
                          const unsigned char *input, size_t ilen,
                          unsigned char *output, size_t *olen,
                          unsigned char *tag, size_t tag_len )
 {
-    CIPHER_VALIDATE_RET( ctx != NULL );
-    CIPHER_VALIDATE_RET( iv != NULL );
-    CIPHER_VALIDATE_RET( ad_len == 0 || ad != NULL );
-    CIPHER_VALIDATE_RET( ilen == 0 || input != NULL );
-    CIPHER_VALIDATE_RET( output != NULL );
-    CIPHER_VALIDATE_RET( olen != NULL );
-    CIPHER_VALIDATE_RET( tag_len == 0 || tag != NULL );
-
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( ctx->psa_enabled == 1 )
     {
@@ -1370,44 +1363,21 @@
                                 ilen, iv, ad, ad_len, input, output, tag ) );
     }
 #endif /* MBEDTLS_CHACHAPOLY_C */
-#if defined(MBEDTLS_NIST_KW_C)
-   if( MBEDTLS_MODE_KW == ctx->cipher_info->mode ||
-       MBEDTLS_MODE_KWP == ctx->cipher_info->mode )
-    {
-        mbedtls_nist_kw_mode_t mode = ( MBEDTLS_MODE_KW == ctx->cipher_info->mode ) ?
-                                        MBEDTLS_KW_MODE_KW : MBEDTLS_KW_MODE_KWP;
-
-        /* There is no iv, tag or ad associated with KW and KWP, these length should be 0 */
-        if( iv_len != 0 || tag_len != 0 || ad_len != 0 )
-        {
-            return( MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA );
-        }
-
-        return( mbedtls_nist_kw_wrap( ctx->cipher_ctx, mode, input, ilen, output, olen, SIZE_MAX ) );
-    }
-#endif /* MBEDTLS_NIST_KW_C */
 
     return( MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE );
 }
 
 /*
- * Packet-oriented decryption for AEAD modes
+ * Packet-oriented decryption for AEAD modes: internal function shared by
+ * mbedtls_cipher_auth_decrypt() and mbedtls_cipher_auth_decrypt_ext().
  */
-int mbedtls_cipher_auth_decrypt( mbedtls_cipher_context_t *ctx,
+static int mbedtls_cipher_aead_decrypt( mbedtls_cipher_context_t *ctx,
                          const unsigned char *iv, size_t iv_len,
                          const unsigned char *ad, size_t ad_len,
                          const unsigned char *input, size_t ilen,
                          unsigned char *output, size_t *olen,
                          const unsigned char *tag, size_t tag_len )
 {
-    CIPHER_VALIDATE_RET( ctx != NULL );
-    CIPHER_VALIDATE_RET( iv != NULL );
-    CIPHER_VALIDATE_RET( ad_len == 0 || ad != NULL );
-    CIPHER_VALIDATE_RET( ilen == 0 || input != NULL );
-    CIPHER_VALIDATE_RET( output != NULL );
-    CIPHER_VALIDATE_RET( olen != NULL );
-    CIPHER_VALIDATE_RET( tag_len == 0 || tag != NULL );
-
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( ctx->psa_enabled == 1 )
     {
@@ -1495,6 +1465,68 @@
         return( ret );
     }
 #endif /* MBEDTLS_CHACHAPOLY_C */
+
+    return( MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE );
+}
+
+/*
+ * Packet-oriented encryption for AEAD modes: public function.
+ */
+int mbedtls_cipher_auth_encrypt( mbedtls_cipher_context_t *ctx,
+                         const unsigned char *iv, size_t iv_len,
+                         const unsigned char *ad, size_t ad_len,
+                         const unsigned char *input, size_t ilen,
+                         unsigned char *output, size_t *olen,
+                         unsigned char *tag, size_t tag_len )
+{
+    CIPHER_VALIDATE_RET( ctx != NULL );
+    CIPHER_VALIDATE_RET( iv != NULL );
+    CIPHER_VALIDATE_RET( ad_len == 0 || ad != NULL );
+    CIPHER_VALIDATE_RET( ilen == 0 || input != NULL );
+    CIPHER_VALIDATE_RET( output != NULL );
+    CIPHER_VALIDATE_RET( olen != NULL );
+    CIPHER_VALIDATE_RET( tag_len == 0 || tag != NULL );
+
+#if defined(MBEDTLS_NIST_KW_C)
+    if( MBEDTLS_MODE_KW == ctx->cipher_info->mode ||
+        MBEDTLS_MODE_KWP == ctx->cipher_info->mode )
+    {
+        mbedtls_nist_kw_mode_t mode = ( MBEDTLS_MODE_KW == ctx->cipher_info->mode ) ?
+                                        MBEDTLS_KW_MODE_KW : MBEDTLS_KW_MODE_KWP;
+
+        /* There is no iv, tag or ad associated with KW and KWP, these length should be 0 */
+        if( iv_len != 0 || tag_len != 0 || ad_len != 0 )
+        {
+            return( MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA );
+        }
+
+        return( mbedtls_nist_kw_wrap( ctx->cipher_ctx, mode, input, ilen, output, olen, SIZE_MAX ) );
+    }
+#endif /* MBEDTLS_NIST_KW_C */
+
+    return( mbedtls_cipher_aead_encrypt( ctx, iv, iv_len, ad, ad_len,
+                                         input, ilen, output, olen,
+                                         tag, tag_len ) );
+}
+
+/*
+ * Packet-oriented decryption for AEAD modes: public function.
+ */
+int mbedtls_cipher_auth_decrypt( mbedtls_cipher_context_t *ctx,
+                         const unsigned char *iv, size_t iv_len,
+                         const unsigned char *ad, size_t ad_len,
+                         const unsigned char *input, size_t ilen,
+                         unsigned char *output, size_t *olen,
+                         const unsigned char *tag, size_t tag_len )
+{
+    CIPHER_VALIDATE_RET( ctx != NULL );
+    CIPHER_VALIDATE_RET( iv != NULL );
+    CIPHER_VALIDATE_RET( ad_len == 0 || ad != NULL );
+    CIPHER_VALIDATE_RET( ilen == 0 || input != NULL );
+    CIPHER_VALIDATE_RET( output != NULL );
+    CIPHER_VALIDATE_RET( olen != NULL );
+    CIPHER_VALIDATE_RET( tag_len == 0 || tag != NULL );
+
 #if defined(MBEDTLS_NIST_KW_C)
     if( MBEDTLS_MODE_KW == ctx->cipher_info->mode ||
         MBEDTLS_MODE_KWP == ctx->cipher_info->mode )
@@ -1512,8 +1544,108 @@
     }
 #endif /* MBEDTLS_NIST_KW_C */
 
-    return( MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE );
+    return( mbedtls_cipher_aead_decrypt( ctx, iv, iv_len, ad, ad_len,
+                                         input, ilen, output, olen,
+                                         tag, tag_len ) );
 }
 #endif /* MBEDTLS_CIPHER_MODE_AEAD */
 
+#if defined(MBEDTLS_CIPHER_MODE_AEAD) || defined(MBEDTLS_NIST_KW_C)
+/*
+ * Packet-oriented encryption for AEAD/NIST_KW: public function.
+ */
+int mbedtls_cipher_auth_encrypt_ext( mbedtls_cipher_context_t *ctx,
+                         const unsigned char *iv, size_t iv_len,
+                         const unsigned char *ad, size_t ad_len,
+                         const unsigned char *input, size_t ilen,
+                         unsigned char *output, size_t output_len,
+                         size_t *olen, size_t tag_len )
+{
+    CIPHER_VALIDATE_RET( ctx != NULL );
+    CIPHER_VALIDATE_RET( iv != NULL );
+    CIPHER_VALIDATE_RET( ad_len == 0 || ad != NULL );
+    CIPHER_VALIDATE_RET( ilen == 0 || input != NULL );
+    CIPHER_VALIDATE_RET( output != NULL );
+    CIPHER_VALIDATE_RET( olen != NULL );
+
+#if defined(MBEDTLS_NIST_KW_C)
+    if( MBEDTLS_MODE_KW == ctx->cipher_info->mode ||
+        MBEDTLS_MODE_KWP == ctx->cipher_info->mode )
+    {
+        mbedtls_nist_kw_mode_t mode = ( MBEDTLS_MODE_KW == ctx->cipher_info->mode ) ?
+                                        MBEDTLS_KW_MODE_KW : MBEDTLS_KW_MODE_KWP;
+
+        /* There is no iv, tag or ad associated with KW and KWP,
+         * so these lengths must be 0, as documented. */
+        if( iv_len != 0 || tag_len != 0 || ad_len != 0 )
+            return( MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA );
+
+        return( mbedtls_nist_kw_wrap( ctx->cipher_ctx, mode, input, ilen,
+                                      output, olen, output_len ) );
+    }
+#endif /* MBEDTLS_NIST_KW_C */
+
+#if defined(MBEDTLS_CIPHER_MODE_AEAD)
+    /* AEAD case: check length before passing on to shared function */
+    if( output_len < ilen + tag_len )
+        return( MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA );
+
+    int ret = mbedtls_cipher_aead_encrypt( ctx, iv, iv_len, ad, ad_len,
+                                       input, ilen, output, olen,
+                                       output + ilen, tag_len );
+    *olen += tag_len;
+    return( ret );
+#else
+    return( MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE );
+#endif /* MBEDTLS_CIPHER_MODE_AEAD */
+}
+
+/*
+ * Packet-oriented decryption for AEAD/NIST_KW: public function.
+ */
+int mbedtls_cipher_auth_decrypt_ext( mbedtls_cipher_context_t *ctx,
+                         const unsigned char *iv, size_t iv_len,
+                         const unsigned char *ad, size_t ad_len,
+                         const unsigned char *input, size_t ilen,
+                         unsigned char *output, size_t output_len,
+                         size_t *olen, size_t tag_len )
+{
+    CIPHER_VALIDATE_RET( ctx != NULL );
+    CIPHER_VALIDATE_RET( iv != NULL );
+    CIPHER_VALIDATE_RET( ad_len == 0 || ad != NULL );
+    CIPHER_VALIDATE_RET( ilen == 0 || input != NULL );
+    CIPHER_VALIDATE_RET( output_len == 0 || output != NULL );
+    CIPHER_VALIDATE_RET( olen != NULL );
+
+#if defined(MBEDTLS_NIST_KW_C)
+    if( MBEDTLS_MODE_KW == ctx->cipher_info->mode ||
+        MBEDTLS_MODE_KWP == ctx->cipher_info->mode )
+    {
+        mbedtls_nist_kw_mode_t mode = ( MBEDTLS_MODE_KW == ctx->cipher_info->mode ) ?
+                                        MBEDTLS_KW_MODE_KW : MBEDTLS_KW_MODE_KWP;
+
+        /* There is no iv, tag or ad associated with KW and KWP,
+         * so these lengths must be 0, as documented. */
+        if( iv_len != 0 || tag_len != 0 || ad_len != 0 )
+            return( MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA );
+
+        return( mbedtls_nist_kw_unwrap( ctx->cipher_ctx, mode, input, ilen,
+                                        output, olen, output_len ) );
+    }
+#endif /* MBEDTLS_NIST_KW_C */
+
+#if defined(MBEDTLS_CIPHER_MODE_AEAD)
+    /* AEAD case: check length before passing on to shared function */
+    if( ilen < tag_len || output_len < ilen - tag_len )
+        return( MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA );
+
+    return( mbedtls_cipher_aead_decrypt( ctx, iv, iv_len, ad, ad_len,
+                                         input, ilen - tag_len, output, olen,
+                                         input + ilen - tag_len, tag_len ) );
+#else
+    return( MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE );
+#endif /* MBEDTLS_CIPHER_MODE_AEAD */
+}
+#endif /* MBEDTLS_CIPHER_MODE_AEAD || MBEDTLS_NIST_KW_C */
+
 #endif /* MBEDTLS_CIPHER_C */