Merge pull request #781 from mpg/cipher-auth-crypt-restricted

Fix buffer overflow with NIST-KW in cipher layer
diff --git a/ChangeLog.d/cipher-auth-crypt-nist-kw.txt b/ChangeLog.d/cipher-auth-crypt-nist-kw.txt
new file mode 100644
index 0000000..63519a1
--- /dev/null
+++ b/ChangeLog.d/cipher-auth-crypt-nist-kw.txt
@@ -0,0 +1,22 @@
+API changes
+   * The functions mbedtls_cipher_auth_encrypt() and
+     mbedtls_cipher_auth_decrypt() no longer accept NIST_KW contexts,
+     as they have no way to check if the output buffer is large enough.
+     Please use mbedtls_cipher_auth_encrypt_ext() and
+     mbedtls_cipher_auth_decrypt_ext() instead.
+
+Security
+   * The functions mbedtls_cipher_auth_encrypt() and
+     mbedtls_cipher_auth_decrypt() would write past the minimum documented
+     size of the output buffer when used with NIST_KW. As a result, code using
+     those functions as documented with NIST_KW could have a buffer overflow
+     of up to 15 bytes, with consequences ranging up to arbitrary code
+     execution depending on the location of the output buffer.
+
+New deprecations
+   * The functions mbedtls_cipher_auth_encrypt() and
+     mbedtls_cipher_auth_decrypt() are deprecated in favour of the new
+     functions mbedtls_cipher_auth_encrypt_ext() and
+     mbedtls_cipher_auth_decrypt_ext(). Please note that with AEAD ciphers,
+     these new functions always append the tag to the ciphertext, and include
+     the tag in the ciphertext length.
diff --git a/include/mbedtls/cipher.h b/include/mbedtls/cipher.h
index 8827e0b..1cafa6e 100644
--- a/include/mbedtls/cipher.h
+++ b/include/mbedtls/cipher.h
@@ -857,30 +857,52 @@
                   unsigned char *output, size_t *olen );
 
 #if defined(MBEDTLS_CIPHER_MODE_AEAD)
+#if ! defined(MBEDTLS_DEPRECATED_REMOVED)
+#if defined(MBEDTLS_DEPRECATED_WARNING)
+#define MBEDTLS_DEPRECATED    __attribute__((deprecated))
+#else
+#define MBEDTLS_DEPRECATED
+#endif /* MBEDTLS_DEPRECATED_WARNING */
 /**
- * \brief               The generic autenticated encryption (AEAD) function.
+ * \brief               The generic authenticated encryption (AEAD) function.
+ *
+ * \deprecated          Superseded by mbedtls_cipher_auth_encrypt_ext().
+ *
+ * \note                This function only supports AEAD algorithms, not key
+ *                      wrapping algorithms such as NIST_KW; for this, see
+ *                      mbedtls_cipher_auth_encrypt_ext().
  *
  * \param ctx           The generic cipher context. This must be initialized and
- *                      bound to a key.
- * \param iv            The IV to use, or NONCE_COUNTER for CTR-mode ciphers.
- *                      This must be a readable buffer of at least \p iv_len
- *                      Bytes.
- * \param iv_len        The IV length for ciphers with variable-size IV.
- *                      This parameter is discarded by ciphers with fixed-size IV.
+ *                      bound to a key associated with an AEAD algorithm.
+ * \param iv            The nonce to use. This must be a readable buffer of
+ *                      at least \p iv_len Bytes and must not be \c NULL.
+ * \param iv_len        The length of the nonce. This must satisfy the
+ *                      constraints imposed by the AEAD cipher used.
  * \param ad            The additional data to authenticate. This must be a
- *                      readable buffer of at least \p ad_len Bytes.
+ *                      readable buffer of at least \p ad_len Bytes, and may
+ *                      be \c NULL if \p ad_len is \c 0.
  * \param ad_len        The length of \p ad.
  * \param input         The buffer holding the input data. This must be a
- *                      readable buffer of at least \p ilen Bytes.
+ *                      readable buffer of at least \p ilen Bytes, and may be
+ *                      \c NULL if \p ilen is \c 0.
  * \param ilen          The length of the input data.
- * \param output        The buffer for the output data. This must be able to
- *                      hold at least \p ilen Bytes.
- * \param olen          The length of the output data, to be updated with the
- *                      actual number of Bytes written. This must not be
- *                      \c NULL.
+ * \param output        The buffer for the output data. This must be a
+ *                      writable buffer of at least \p ilen Bytes, and must
+ *                      not be \c NULL.
+ * \param olen          This will be filled with the actual number of Bytes
+ *                      written to the \p output buffer. This must point to a
+ *                      writable object of type \c size_t.
  * \param tag           The buffer for the authentication tag. This must be a
- *                      writable buffer of at least \p tag_len Bytes.
- * \param tag_len       The desired length of the authentication tag.
+ *                      writable buffer of at least \p tag_len Bytes. See note
+ *                      below regarding restrictions with PSA-based contexts.
+ * \param tag_len       The desired length of the authentication tag. This
+ *                      must match the constraints imposed by the AEAD cipher
+ *                      used, and in particular must not be \c 0.
+ *
+ * \note                If the context is based on PSA (that is, it was set up
+ *                      with mbedtls_cipher_setup_psa()), then it is required
+ *                      that \c tag == output + ilen. That is, the tag must be
+ *                      appended to the ciphertext as recommended by RFC 5116.
  *
  * \return              \c 0 on success.
  * \return              #MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA on
@@ -892,36 +914,53 @@
                          const unsigned char *ad, size_t ad_len,
                          const unsigned char *input, size_t ilen,
                          unsigned char *output, size_t *olen,
-                         unsigned char *tag, size_t tag_len );
+                         unsigned char *tag, size_t tag_len )
+                         MBEDTLS_DEPRECATED;
 
 /**
- * \brief               The generic autenticated decryption (AEAD) function.
+ * \brief               The generic authenticated decryption (AEAD) function.
+ *
+ * \deprecated          Superseded by mbedtls_cipher_auth_decrypt_ext().
+ *
+ * \note                This function only supports AEAD algorithms, not key
+ *                      wrapping algorithms such as NIST_KW; for this, see
+ *                      mbedtls_cipher_auth_decrypt_ext().
  *
  * \note                If the data is not authentic, then the output buffer
  *                      is zeroed out to prevent the unauthentic plaintext being
  *                      used, making this interface safer.
  *
  * \param ctx           The generic cipher context. This must be initialized and
- *                      and bound to a key.
- * \param iv            The IV to use, or NONCE_COUNTER for CTR-mode ciphers.
- *                      This must be a readable buffer of at least \p iv_len
- *                      Bytes.
- * \param iv_len        The IV length for ciphers with variable-size IV.
- *                      This parameter is discarded by ciphers with fixed-size IV.
- * \param ad            The additional data to be authenticated. This must be a
- *                      readable buffer of at least \p ad_len Bytes.
+ *                      bound to a key associated with an AEAD algorithm.
+ * \param iv            The nonce to use. This must be a readable buffer of
+ *                      at least \p iv_len Bytes and must not be \c NULL.
+ * \param iv_len        The length of the nonce. This must satisfy the
+ *                      constraints imposed by the AEAD cipher used.
+ * \param ad            The additional data to authenticate. This must be a
+ *                      readable buffer of at least \p ad_len Bytes, and may
+ *                      be \c NULL if \p ad_len is \c 0.
  * \param ad_len        The length of \p ad.
  * \param input         The buffer holding the input data. This must be a
- *                      readable buffer of at least \p ilen Bytes.
+ *                      readable buffer of at least \p ilen Bytes, and may be
+ *                      \c NULL if \p ilen is \c 0.
  * \param ilen          The length of the input data.
- * \param output        The buffer for the output data.
- *                      This must be able to hold at least \p ilen Bytes.
- * \param olen          The length of the output data, to be updated with the
- *                      actual number of Bytes written. This must not be
- *                      \c NULL.
- * \param tag           The buffer holding the authentication tag. This must be
- *                      a readable buffer of at least \p tag_len Bytes.
- * \param tag_len       The length of the authentication tag.
+ * \param output        The buffer for the output data. This must be a
+ *                      writable buffer of at least \p ilen Bytes, and must
+ *                      not be \c NULL.
+ * \param olen          This will be filled with the actual number of Bytes
+ *                      written to the \p output buffer. This must point to a
+ *                      writable object of type \c size_t.
+ * \param tag           The buffer holding the authentication tag. This must be
+ *                      readable buffer of at least \p tag_len Bytes. See note
+ *                      below regarding restrictions with PSA-based contexts.
+ * \param tag_len       The length of the authentication tag. This must match
+ *                      the constraints imposed by the AEAD cipher used, and in
+ *                      particular must not be \c 0.
+ *
+ * \note                If the context is based on PSA (that is, it was set up
+ *                      with mbedtls_cipher_setup_psa()), then it is required
+ *                      that \c tag == input + ilen. That is, the tag must be
+ *                      appended to the ciphertext as recommended by RFC 5116.
  *
  * \return              \c 0 on success.
  * \return              #MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA on
@@ -934,9 +973,120 @@
                          const unsigned char *ad, size_t ad_len,
                          const unsigned char *input, size_t ilen,
                          unsigned char *output, size_t *olen,
-                         const unsigned char *tag, size_t tag_len );
+                         const unsigned char *tag, size_t tag_len )
+                         MBEDTLS_DEPRECATED;
+#undef MBEDTLS_DEPRECATED
+#endif /* MBEDTLS_DEPRECATED_REMOVED */
 #endif /* MBEDTLS_CIPHER_MODE_AEAD */
 
+#if defined(MBEDTLS_CIPHER_MODE_AEAD) || defined(MBEDTLS_NIST_KW_C)
+/**
+ * \brief               The authenticated encryption (AEAD/NIST_KW) function.
+ *
+ * \note                For AEAD modes, the tag will be appended to the
+ *                      ciphertext, as recommended by RFC 5116.
+ *                      (NIST_KW doesn't have a separate tag.)
+ *
+ * \param ctx           The generic cipher context. This must be initialized and
+ *                      bound to a key, with an AEAD algorithm or NIST_KW.
+ * \param iv            The nonce to use. This must be a readable buffer of
+ *                      at least \p iv_len Bytes and may be \c NULL if \p
+ *                      iv_len is \c 0.
+ * \param iv_len        The length of the nonce. For AEAD ciphers, this must
+ *                      satisfy the constraints imposed by the cipher used.
+ *                      For NIST_KW, this must be \c 0.
+ * \param ad            The additional data to authenticate. This must be a
+ *                      readable buffer of at least \p ad_len Bytes, and may
+ *                      be \c NULL if \p ad_len is \c 0.
+ * \param ad_len        The length of \p ad. For NIST_KW, this must be \c 0.
+ * \param input         The buffer holding the input data. This must be a
+ *                      readable buffer of at least \p ilen Bytes, and may be
+ *                      \c NULL if \p ilen is \c 0.
+ * \param ilen          The length of the input data.
+ * \param output        The buffer for the output data. This must be a
+ *                      writable buffer of at least \p output_len Bytes, and
+ *                      must not be \c NULL.
+ * \param output_len    The length of the \p output buffer in Bytes. For AEAD
+ *                      ciphers, this must be at least \p ilen + \p tag_len.
+ *                      For NIST_KW, this must be at least \p ilen + 8
+ *                      (rounded up to a multiple of 8 if KWP is used);
+ *                      \p ilen + 15 is always a safe value.
+ * \param olen          This will be filled with the actual number of Bytes
+ *                      written to the \p output buffer. This must point to a
+ *                      writable object of type \c size_t.
+ * \param tag_len       The desired length of the authentication tag. For AEAD
+ *                      ciphers, this must match the constraints imposed by
+ *                      the cipher used, and in particular must not be \c 0.
+ *                      For NIST_KW, this must be \c 0.
+ *
+ * \return              \c 0 on success.
+ * \return              #MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA on
+ *                      parameter-verification failure.
+ * \return              A cipher-specific error code on failure.
+ */
+int mbedtls_cipher_auth_encrypt_ext( mbedtls_cipher_context_t *ctx,
+                         const unsigned char *iv, size_t iv_len,
+                         const unsigned char *ad, size_t ad_len,
+                         const unsigned char *input, size_t ilen,
+                         unsigned char *output, size_t output_len,
+                         size_t *olen, size_t tag_len );
+
+/**
+ * \brief               The authenticated decryption (AEAD/NIST_KW) function.
+ *
+ * \note                If the data is not authentic, then the output buffer
+ *                      is zeroed out to prevent the unauthentic plaintext being
+ *                      used, making this interface safer.
+ *
+ * \note                For AEAD modes, the tag must be appended to the
+ *                      ciphertext, as recommended by RFC 5116.
+ *                      (NIST_KW doesn't have a separate tag.)
+ *
+ * \param ctx           The generic cipher context. This must be initialized and
+ *                      bound to a key, with an AEAD algorithm or NIST_KW.
+ * \param iv            The nonce to use. This must be a readable buffer of
+ *                      at least \p iv_len Bytes and may be \c NULL if \p
+ *                      iv_len is \c 0.
+ * \param iv_len        The length of the nonce. For AEAD ciphers, this must
+ *                      satisfy the constraints imposed by the cipher used.
+ *                      For NIST_KW, this must be \c 0.
+ * \param ad            The additional data to authenticate. This must be a
+ *                      readable buffer of at least \p ad_len Bytes, and may
+ *                      be \c NULL if \p ad_len is \c 0.
+ * \param ad_len        The length of \p ad. For NIST_KW, this must be \c 0.
+ * \param input         The buffer holding the input data. This must be a
+ *                      readable buffer of at least \p ilen Bytes, and may be
+ *                      \c NULL if \p ilen is \c 0.
+ * \param ilen          The length of the input data. For AEAD ciphers this
+ *                      must be at least \p tag_len. For NIST_KW this must be
+ *                      at least \c 8.
+ * \param output        The buffer for the output data. This must be a
+ *                      writable buffer of at least \p output_len Bytes, and
+ *                      may be \c NULL if \p output_len is \c 0.
+ * \param output_len    The length of the \p output buffer in Bytes. For AEAD
+ *                      ciphers, this must be at least \p ilen - \p tag_len.
+ *                      For NIST_KW, this must be at least \p ilen - 8.
+ * \param olen          This will be filled with the actual number of Bytes
+ *                      written to the \p output buffer. This must point to a
+ *                      writable object of type \c size_t.
+ * \param tag_len       The actual length of the authentication tag. For AEAD
+ *                      ciphers, this must match the constraints imposed by
+ *                      the cipher used, and in particular must not be \c 0.
+ *                      For NIST_KW, this must be \c 0.
+ *
+ * \return              \c 0 on success.
+ * \return              #MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA on
+ *                      parameter-verification failure.
+ * \return              #MBEDTLS_ERR_CIPHER_AUTH_FAILED if data is not authentic.
+ * \return              A cipher-specific error code on failure.
+ */
+int mbedtls_cipher_auth_decrypt_ext( mbedtls_cipher_context_t *ctx,
+                         const unsigned char *iv, size_t iv_len,
+                         const unsigned char *ad, size_t ad_len,
+                         const unsigned char *input, size_t ilen,
+                         unsigned char *output, size_t output_len,
+                         size_t *olen, size_t tag_len );
+#endif /* MBEDTLS_CIPHER_MODE_AEAD || MBEDTLS_NIST_KW_C */
 #ifdef __cplusplus
 }
 #endif
diff --git a/library/cipher.c b/library/cipher.c
index 853eeec..457f8f6 100644
--- a/library/cipher.c
+++ b/library/cipher.c
@@ -1288,23 +1288,16 @@
 
 #if defined(MBEDTLS_CIPHER_MODE_AEAD)
 /*
- * Packet-oriented encryption for AEAD modes
+ * AEAD packet encryption shared by mbedtls_cipher_auth_encrypt() and
+ * mbedtls_cipher_auth_encrypt_ext(); parameter checks live in those callers.
  */
-int mbedtls_cipher_auth_encrypt( mbedtls_cipher_context_t *ctx,
+static int mbedtls_cipher_aead_encrypt( mbedtls_cipher_context_t *ctx,
                          const unsigned char *iv, size_t iv_len,
                          const unsigned char *ad, size_t ad_len,
                          const unsigned char *input, size_t ilen,
                          unsigned char *output, size_t *olen,
                          unsigned char *tag, size_t tag_len )
 {
-    CIPHER_VALIDATE_RET( ctx != NULL );
-    CIPHER_VALIDATE_RET( iv != NULL );
-    CIPHER_VALIDATE_RET( ad_len == 0 || ad != NULL );
-    CIPHER_VALIDATE_RET( ilen == 0 || input != NULL );
-    CIPHER_VALIDATE_RET( output != NULL );
-    CIPHER_VALIDATE_RET( olen != NULL );
-    CIPHER_VALIDATE_RET( tag_len == 0 || tag != NULL );
-
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( ctx->psa_enabled == 1 )
     {
@@ -1320,7 +1313,7 @@
 
         /* PSA Crypto API always writes the authentication tag
          * at the end of the encrypted message. */
-        if( tag != output + ilen )
+        if( output == NULL || tag != output + ilen )
             return( MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE );
 
         status = psa_aead_encrypt( cipher_psa->slot,
@@ -1370,44 +1363,21 @@
                                 ilen, iv, ad, ad_len, input, output, tag ) );
     }
 #endif /* MBEDTLS_CHACHAPOLY_C */
-#if defined(MBEDTLS_NIST_KW_C)
-   if( MBEDTLS_MODE_KW == ctx->cipher_info->mode ||
-       MBEDTLS_MODE_KWP == ctx->cipher_info->mode )
-    {
-        mbedtls_nist_kw_mode_t mode = ( MBEDTLS_MODE_KW == ctx->cipher_info->mode ) ?
-                                        MBEDTLS_KW_MODE_KW : MBEDTLS_KW_MODE_KWP;
-
-        /* There is no iv, tag or ad associated with KW and KWP, these length should be 0 */
-        if( iv_len != 0 || tag_len != 0 || ad_len != 0 )
-        {
-            return( MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA );
-        }
-
-        return( mbedtls_nist_kw_wrap( ctx->cipher_ctx, mode, input, ilen, output, olen, SIZE_MAX ) );
-    }
-#endif /* MBEDTLS_NIST_KW_C */
 
     return( MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE );
 }
 
 /*
- * Packet-oriented decryption for AEAD modes
+ * AEAD packet decryption shared by mbedtls_cipher_auth_decrypt() and
+ * mbedtls_cipher_auth_decrypt_ext(); parameter checks live in those callers.
  */
-int mbedtls_cipher_auth_decrypt( mbedtls_cipher_context_t *ctx,
+static int mbedtls_cipher_aead_decrypt( mbedtls_cipher_context_t *ctx,
                          const unsigned char *iv, size_t iv_len,
                          const unsigned char *ad, size_t ad_len,
                          const unsigned char *input, size_t ilen,
                          unsigned char *output, size_t *olen,
                          const unsigned char *tag, size_t tag_len )
 {
-    CIPHER_VALIDATE_RET( ctx != NULL );
-    CIPHER_VALIDATE_RET( iv != NULL );
-    CIPHER_VALIDATE_RET( ad_len == 0 || ad != NULL );
-    CIPHER_VALIDATE_RET( ilen == 0 || input != NULL );
-    CIPHER_VALIDATE_RET( output != NULL );
-    CIPHER_VALIDATE_RET( olen != NULL );
-    CIPHER_VALIDATE_RET( tag_len == 0 || tag != NULL );
-
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( ctx->psa_enabled == 1 )
     {
@@ -1423,7 +1393,7 @@
 
         /* PSA Crypto API always writes the authentication tag
          * at the end of the encrypted message. */
-        if( tag != input + ilen )
+        if( input == NULL || tag != input + ilen )
             return( MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE );
 
         status = psa_aead_decrypt( cipher_psa->slot,
@@ -1495,25 +1465,169 @@
         return( ret );
     }
 #endif /* MBEDTLS_CHACHAPOLY_C */
+
+    return( MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE );
+}
+
+#if !defined(MBEDTLS_DEPRECATED_REMOVED)
+/*
+ * Packet-oriented encryption for AEAD modes: deprecated public function.
+ */
+int mbedtls_cipher_auth_encrypt( mbedtls_cipher_context_t *ctx,
+                         const unsigned char *iv, size_t iv_len,
+                         const unsigned char *ad, size_t ad_len,
+                         const unsigned char *input, size_t ilen,
+                         unsigned char *output, size_t *olen,
+                         unsigned char *tag, size_t tag_len )
+{
+    CIPHER_VALIDATE_RET( ctx != NULL );
+    CIPHER_VALIDATE_RET( iv_len == 0 || iv != NULL );
+    CIPHER_VALIDATE_RET( ad_len == 0 || ad != NULL );
+    CIPHER_VALIDATE_RET( ilen == 0 || input != NULL );
+    CIPHER_VALIDATE_RET( ilen == 0 || output != NULL );
+    CIPHER_VALIDATE_RET( olen != NULL );
+    CIPHER_VALIDATE_RET( tag_len == 0 || tag != NULL );
+
+    return( mbedtls_cipher_aead_encrypt( ctx, iv, iv_len, ad, ad_len,
+                                         input, ilen, output, olen,
+                                         tag, tag_len ) );
+}
+
+/*
+ * Packet-oriented decryption for AEAD modes: deprecated public function.
+ */
+int mbedtls_cipher_auth_decrypt( mbedtls_cipher_context_t *ctx,
+                         const unsigned char *iv, size_t iv_len,
+                         const unsigned char *ad, size_t ad_len,
+                         const unsigned char *input, size_t ilen,
+                         unsigned char *output, size_t *olen,
+                         const unsigned char *tag, size_t tag_len )
+{
+    CIPHER_VALIDATE_RET( ctx != NULL );
+    CIPHER_VALIDATE_RET( iv_len == 0 || iv != NULL );
+    CIPHER_VALIDATE_RET( ad_len == 0 || ad != NULL );
+    CIPHER_VALIDATE_RET( ilen == 0 || input != NULL );
+    CIPHER_VALIDATE_RET( ilen == 0 || output != NULL );
+    CIPHER_VALIDATE_RET( olen != NULL );
+    CIPHER_VALIDATE_RET( tag_len == 0 || tag != NULL );
+
+    return( mbedtls_cipher_aead_decrypt( ctx, iv, iv_len, ad, ad_len,
+                                         input, ilen, output, olen,
+                                         tag, tag_len ) );
+}
+#endif /* !MBEDTLS_DEPRECATED_REMOVED */
+#endif /* MBEDTLS_CIPHER_MODE_AEAD */
+
+#if defined(MBEDTLS_CIPHER_MODE_AEAD) || defined(MBEDTLS_NIST_KW_C)
+/*
+ * Public AEAD/NIST_KW encryption; for AEAD the tag is appended to the output.
+ */
+int mbedtls_cipher_auth_encrypt_ext( mbedtls_cipher_context_t *ctx,
+                         const unsigned char *iv, size_t iv_len,
+                         const unsigned char *ad, size_t ad_len,
+                         const unsigned char *input, size_t ilen,
+                         unsigned char *output, size_t output_len,
+                         size_t *olen, size_t tag_len )
+{
+    CIPHER_VALIDATE_RET( ctx != NULL );
+    CIPHER_VALIDATE_RET( iv_len == 0 || iv != NULL );
+    CIPHER_VALIDATE_RET( ad_len == 0 || ad != NULL );
+    CIPHER_VALIDATE_RET( ilen == 0 || input != NULL );
+    CIPHER_VALIDATE_RET( output != NULL );
+    CIPHER_VALIDATE_RET( olen != NULL );
+
 #if defined(MBEDTLS_NIST_KW_C)
-    if( MBEDTLS_MODE_KW == ctx->cipher_info->mode ||
-        MBEDTLS_MODE_KWP == ctx->cipher_info->mode )
+    if(
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        ctx->psa_enabled == 0 &&
+#endif
+        ( MBEDTLS_MODE_KW == ctx->cipher_info->mode ||
+          MBEDTLS_MODE_KWP == ctx->cipher_info->mode ) )
     {
         mbedtls_nist_kw_mode_t mode = ( MBEDTLS_MODE_KW == ctx->cipher_info->mode ) ?
                                         MBEDTLS_KW_MODE_KW : MBEDTLS_KW_MODE_KWP;
 
-        /* There is no iv, tag or ad associated with KW and KWP, these length should be 0 */
+        /* There is no iv, tag or ad associated with KW and KWP,
+         * so these lengths must be 0 as documented. */
         if( iv_len != 0 || tag_len != 0 || ad_len != 0 )
-        {
             return( MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA );
-        }
 
-        return( mbedtls_nist_kw_unwrap( ctx->cipher_ctx, mode, input, ilen, output, olen, SIZE_MAX ) );
+        (void) iv;
+        (void) ad;
+
+        return( mbedtls_nist_kw_wrap( ctx->cipher_ctx, mode, input, ilen,
+                                      output, olen, output_len ) );
     }
 #endif /* MBEDTLS_NIST_KW_C */
 
+#if defined(MBEDTLS_CIPHER_MODE_AEAD)
+    /* AEAD case: overflow-safe check that ciphertext + tag fit in output */
+    if( output_len < tag_len || output_len - tag_len < ilen )
+        return( MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA );
+    int ret = mbedtls_cipher_aead_encrypt( ctx, iv, iv_len, ad, ad_len,
+                                       input, ilen, output, olen,
+                                       output + ilen, tag_len );
+    if( ret == 0 ) /* only count the appended tag on success */
+        *olen += tag_len;
+    return( ret );
+#else
     return( MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE );
-}
 #endif /* MBEDTLS_CIPHER_MODE_AEAD */
+}
+
+/*
+ * Public AEAD/NIST_KW decryption; for AEAD the tag is read from the input end.
+ */
+int mbedtls_cipher_auth_decrypt_ext( mbedtls_cipher_context_t *ctx,
+                         const unsigned char *iv, size_t iv_len,
+                         const unsigned char *ad, size_t ad_len,
+                         const unsigned char *input, size_t ilen,
+                         unsigned char *output, size_t output_len,
+                         size_t *olen, size_t tag_len )
+{
+    CIPHER_VALIDATE_RET( ctx != NULL );
+    CIPHER_VALIDATE_RET( iv_len == 0 || iv != NULL );
+    CIPHER_VALIDATE_RET( ad_len == 0 || ad != NULL );
+    CIPHER_VALIDATE_RET( ilen == 0 || input != NULL );
+    CIPHER_VALIDATE_RET( output_len == 0 || output != NULL );
+    CIPHER_VALIDATE_RET( olen != NULL );
+
+#if defined(MBEDTLS_NIST_KW_C)
+    if(
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        ctx->psa_enabled == 0 &&
+#endif
+        ( MBEDTLS_MODE_KW == ctx->cipher_info->mode ||
+          MBEDTLS_MODE_KWP == ctx->cipher_info->mode ) )
+    {
+        mbedtls_nist_kw_mode_t mode = ( MBEDTLS_MODE_KW == ctx->cipher_info->mode ) ?
+                                        MBEDTLS_KW_MODE_KW : MBEDTLS_KW_MODE_KWP;
+
+        /* There is no iv, tag or ad associated with KW and KWP,
+         * so these lengths must be 0 as documented. */
+        if( iv_len != 0 || tag_len != 0 || ad_len != 0 )
+            return( MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA );
+
+        (void) iv;
+        (void) ad;
+
+        return( mbedtls_nist_kw_unwrap( ctx->cipher_ctx, mode, input, ilen,
+                                        output, olen, output_len ) );
+    }
+#endif /* MBEDTLS_NIST_KW_C */
+
+#if defined(MBEDTLS_CIPHER_MODE_AEAD)
+    /* AEAD case: check length before passing on to shared function */
+    if( ilen < tag_len || output_len < ilen - tag_len )
+        return( MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA );
+
+    return( mbedtls_cipher_aead_decrypt( ctx, iv, iv_len, ad, ad_len,
+                                         input, ilen - tag_len, output, olen,
+                                         input + ilen - tag_len, tag_len ) );
+#else
+    return( MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE );
+#endif /* MBEDTLS_CIPHER_MODE_AEAD */
+}
+#endif /* MBEDTLS_CIPHER_MODE_AEAD || MBEDTLS_NIST_KW_C */
 
 #endif /* MBEDTLS_CIPHER_C */
diff --git a/library/ssl_msg.c b/library/ssl_msg.c
index 0718d5a..72f09bb 100644
--- a/library/ssl_msg.c
+++ b/library/ssl_msg.c
@@ -850,20 +850,21 @@
          * Encrypt and authenticate
          */
 
-        if( ( ret = mbedtls_cipher_auth_encrypt( &transform->cipher_ctx_enc,
+        if( ( ret = mbedtls_cipher_auth_encrypt_ext( &transform->cipher_ctx_enc,
                    iv, transform->ivlen,
-                   add_data, add_data_len,       /* add data     */
-                   data, rec->data_len,          /* source       */
-                   data, &rec->data_len,         /* destination  */
-                   data + rec->data_len, transform->taglen ) ) != 0 )
+                   add_data, add_data_len,
+                   data, rec->data_len,                     /* src */
+                   data, rec->buf_len - (data - rec->buf),  /* dst */
+                   &rec->data_len,                          /* olen, incl. tag */
+                   transform->taglen ) ) != 0 )
         {
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_auth_encrypt", ret );
             return( ret );
         }
         MBEDTLS_SSL_DEBUG_BUF( 4, "after encrypt: tag",
-                               data + rec->data_len, transform->taglen );
+                               data + rec->data_len - transform->taglen,
+                               transform->taglen );
         /* Account for authentication tag. */
-        rec->data_len += transform->taglen;
         post_avail -= transform->taglen;
 
         /*
@@ -1422,12 +1423,11 @@
         /*
          * Decrypt and authenticate
          */
-        if( ( ret = mbedtls_cipher_auth_decrypt( &transform->cipher_ctx_dec,
+        if( ( ret = mbedtls_cipher_auth_decrypt_ext( &transform->cipher_ctx_dec,
                   iv, transform->ivlen,
                   add_data, add_data_len,
-                  data, rec->data_len,
-                  data, &olen,
-                  data + rec->data_len,
+                  data, rec->data_len + transform->taglen,          /* src + tag */
+                  data, rec->buf_len - (data - rec->buf), &olen,    /* dst */
                   transform->taglen ) ) != 0 )
         {
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_auth_decrypt", ret );
diff --git a/library/ssl_ticket.c b/library/ssl_ticket.c
index e3e8023..626d137 100644
--- a/library/ssl_ticket.c
+++ b/library/ssl_ticket.c
@@ -209,7 +209,6 @@
     unsigned char *iv = start + TICKET_KEY_NAME_BYTES;
     unsigned char *state_len_bytes = iv + TICKET_IV_BYTES;
     unsigned char *state = state_len_bytes + TICKET_CRYPT_LEN_BYTES;
-    unsigned char *tag;
     size_t clear_len, ciph_len;
 
     *tlen = 0;
@@ -250,23 +249,23 @@
     state_len_bytes[1] = ( clear_len      ) & 0xff;
 
     /* Encrypt and authenticate */
-    tag = state + clear_len;
-    if( ( ret = mbedtls_cipher_auth_encrypt( &key->ctx,
+    if( ( ret = mbedtls_cipher_auth_encrypt_ext( &key->ctx,
                     iv, TICKET_IV_BYTES,
                     /* Additional data: key name, IV and length */
                     key_name, TICKET_ADD_DATA_LEN,
-                    state, clear_len, state, &ciph_len,
-                    tag, TICKET_AUTH_TAG_BYTES ) ) != 0 )
+                    state, clear_len,
+                    state, end - state, &ciph_len,
+                    TICKET_AUTH_TAG_BYTES ) ) != 0 )
     {
         goto cleanup;
     }
-    if( ciph_len != clear_len )
+    if( ciph_len != clear_len + TICKET_AUTH_TAG_BYTES ) /* olen counts tag */
     {
         ret = MBEDTLS_ERR_SSL_INTERNAL_ERROR;
         goto cleanup;
     }
 
-    *tlen = TICKET_MIN_LEN + ciph_len;
+    *tlen = TICKET_MIN_LEN + ciph_len - TICKET_AUTH_TAG_BYTES;
 
 cleanup:
 #if defined(MBEDTLS_THREADING_C)
@@ -308,7 +307,6 @@
     unsigned char *iv = buf + TICKET_KEY_NAME_BYTES;
     unsigned char *enc_len_p = iv + TICKET_IV_BYTES;
     unsigned char *ticket = enc_len_p + TICKET_CRYPT_LEN_BYTES;
-    unsigned char *tag;
     size_t enc_len, clear_len;
 
     if( ctx == NULL || ctx->f_rng == NULL )
@@ -326,7 +324,6 @@
         goto cleanup;
 
     enc_len = ( enc_len_p[0] << 8 ) | enc_len_p[1];
-    tag = ticket + enc_len;
 
     if( len != TICKET_MIN_LEN + enc_len )
     {
@@ -344,13 +341,13 @@
     }
 
     /* Decrypt and authenticate */
-    if( ( ret = mbedtls_cipher_auth_decrypt( &key->ctx,
+    if( ( ret = mbedtls_cipher_auth_decrypt_ext( &key->ctx,
                     iv, TICKET_IV_BYTES,
                     /* Additional data: key name, IV and length */
                     key_name, TICKET_ADD_DATA_LEN,
-                    ticket, enc_len,
-                    ticket, &clear_len,
-                    tag, TICKET_AUTH_TAG_BYTES ) ) != 0 )
+                    ticket, enc_len + TICKET_AUTH_TAG_BYTES,
+                    ticket, enc_len, &clear_len,
+                    TICKET_AUTH_TAG_BYTES ) ) != 0 )
     {
         if( ret == MBEDTLS_ERR_CIPHER_AUTH_FAILED )
             ret = MBEDTLS_ERR_SSL_INVALID_MAC;
diff --git a/tests/suites/test_suite_cipher.function b/tests/suites/test_suite_cipher.function
index ea1e9ad..1d98f3d 100644
--- a/tests/suites/test_suite_cipher.function
+++ b/tests/suites/test_suite_cipher.function
@@ -13,6 +13,65 @@
 #include "test/psa_crypto_helpers.h"
 #endif
 
+#if defined(MBEDTLS_CIPHER_MODE_AEAD) || defined(MBEDTLS_NIST_KW_C)
+#define MBEDTLS_CIPHER_AUTH_CRYPT /* AEAD or NIST_KW: gates the auth_crypt tests */
+#endif /* MBEDTLS_CIPHER_MODE_AEAD || MBEDTLS_NIST_KW_C */
+
+#if defined(MBEDTLS_CIPHER_AUTH_CRYPT)
+/* Helper for resetting key/direction: free and re-initialise ctx, set it up
+ * for cipher_id (via mbedtls_cipher_setup_psa() with tag_len when use_psa == 1
+ * and MBEDTLS_USE_PSA_CRYPTO is enabled, otherwise mbedtls_cipher_setup()),
+ * then load key for direction. Returns 1 on success, 0 if a TEST_ASSERT failed.
+ *
+ * The documentation doesn't say whether mbedtls_cipher_setkey() may be called
+ * twice; it only works by accident with the default software implementation
+ * and doesn't work with the PSA wrappers, so always start from a fresh context.
+ */
+static int cipher_reset_key( mbedtls_cipher_context_t *ctx, int cipher_id,
+        int use_psa, size_t tag_len, const data_t *key, int direction )
+{
+    mbedtls_cipher_free( ctx );
+    mbedtls_cipher_init( ctx );
+
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
+    (void) use_psa;
+    (void) tag_len;
+#else
+    if( use_psa == 1 )
+    {
+        TEST_ASSERT( 0 == mbedtls_cipher_setup_psa( ctx,
+                              mbedtls_cipher_info_from_type( cipher_id ),
+                              tag_len ) );
+    }
+    else
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+    {
+        TEST_ASSERT( 0 == mbedtls_cipher_setup( ctx,
+                              mbedtls_cipher_info_from_type( cipher_id ) ) );
+    }
+
+    TEST_ASSERT( 0 == mbedtls_cipher_setkey( ctx, key->x, 8 * key->len,
+                                             direction ) );
+    return( 1 );
+
+exit:
+    return( 0 );
+}
+
+/*
+ * Check if a buffer is all-0 bytes (buf is not read when size == 0):
+ * return   1 if it is, including when size == 0,
+ *          0 if it isn't.
+ */
+static int buffer_is_all_zero( const uint8_t *buf, size_t size )
+{
+    for( size_t i = 0; i < size; i++ )
+        if( buf[i] != 0 )
+            return( 0 );
+    return( 1 );
+}
+#endif /* MBEDTLS_CIPHER_AUTH_CRYPT */
+
 /* END_HEADER */
 
 /* BEGIN_DEPENDENCIES
@@ -485,6 +544,108 @@
                                      NULL, valid_size ) );
 #endif /* defined(MBEDTLS_CIPHER_MODE_AEAD) */
 
+#if defined(MBEDTLS_CIPHER_MODE_AEAD) || defined(MBEDTLS_NIST_KW_C)
+    /* mbedtls_cipher_auth_encrypt_ext */
+    TEST_INVALID_PARAM_RET(
+        MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA,
+        mbedtls_cipher_auth_encrypt_ext( NULL,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size, &size_t_var,
+                                         valid_size ) );
+    TEST_INVALID_PARAM_RET(
+        MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA,
+        mbedtls_cipher_auth_encrypt_ext( &valid_ctx,
+                                         NULL, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size, &size_t_var,
+                                         valid_size ) );
+    TEST_INVALID_PARAM_RET(
+        MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA,
+        mbedtls_cipher_auth_encrypt_ext( &valid_ctx,
+                                         valid_buffer, valid_size,
+                                         NULL, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size, &size_t_var,
+                                         valid_size ) );
+    TEST_INVALID_PARAM_RET(
+        MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA,
+        mbedtls_cipher_auth_encrypt_ext( &valid_ctx,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         NULL, valid_size,
+                                         valid_buffer, valid_size, &size_t_var,
+                                         valid_size ) );
+    TEST_INVALID_PARAM_RET(
+        MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA,
+        mbedtls_cipher_auth_encrypt_ext( &valid_ctx,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         NULL, valid_size, &size_t_var,
+                                         valid_size ) );
+    TEST_INVALID_PARAM_RET(
+        MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA,
+        mbedtls_cipher_auth_encrypt_ext( &valid_ctx,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size, NULL,
+                                         valid_size ) );
+
+    /* mbedtls_cipher_auth_decrypt_ext */
+    TEST_INVALID_PARAM_RET(
+        MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA,
+        mbedtls_cipher_auth_decrypt_ext( NULL,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size, &size_t_var,
+                                         valid_size ) );
+    TEST_INVALID_PARAM_RET(
+        MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA,
+        mbedtls_cipher_auth_decrypt_ext( &valid_ctx,
+                                         NULL, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size, &size_t_var,
+                                         valid_size ) );
+    TEST_INVALID_PARAM_RET(
+        MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA,
+        mbedtls_cipher_auth_decrypt_ext( &valid_ctx,
+                                         valid_buffer, valid_size,
+                                         NULL, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size, &size_t_var,
+                                         valid_size ) );
+    TEST_INVALID_PARAM_RET(
+        MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA,
+        mbedtls_cipher_auth_decrypt_ext( &valid_ctx,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         NULL, valid_size,
+                                         valid_buffer, valid_size, &size_t_var,
+                                         valid_size ) );
+    TEST_INVALID_PARAM_RET(
+        MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA,
+        mbedtls_cipher_auth_decrypt_ext( &valid_ctx,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         NULL, valid_size, &size_t_var,
+                                         valid_size ) );
+    TEST_INVALID_PARAM_RET(
+        MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA,
+        mbedtls_cipher_auth_decrypt_ext( &valid_ctx,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size,
+                                         valid_buffer, valid_size, NULL,
+                                         valid_size ) );
+#endif /* MBEDTLS_CIPHER_MODE_AEAD || MBEDTLS_NIST_KW_C */
+
     /* mbedtls_cipher_free() */
     TEST_VALID_PARAM( mbedtls_cipher_free( NULL ) );
 exit:
@@ -959,129 +1120,338 @@
 }
 /* END_CASE */
 
-/* BEGIN_CASE depends_on:MBEDTLS_CIPHER_MODE_AEAD */
+/* BEGIN_CASE depends_on:MBEDTLS_CIPHER_AUTH_CRYPT */
 void auth_crypt_tv( int cipher_id, data_t * key, data_t * iv,
                     data_t * ad, data_t * cipher, data_t * tag,
                     char * result, data_t * clear, int use_psa )
 {
-    /* Takes an AEAD ciphertext + tag and performs a pair
-     * of AEAD decryption and AEAD encryption. It checks that
+    /*
+     * Take an AEAD or NIST_KW ciphertext + tag and perform a pair
+     * of authenticated decryption and encryption. Check that
      * this results in the expected plaintext, and that
-     * decryption and encryption are inverse to one another. */
+     * decryption and encryption are inverse to one another.
+     *
+     * Do that twice:
+     * - once with legacy functions auth_decrypt/auth_encrypt
+     * - once with new functions auth_decrypt_ext/auth_encrypt_ext
+     * This allows testing both without duplicating test cases.
+     */
 
     int ret;
-    unsigned char output[300];        /* Temporary buffer for results of
-                                       * encryption and decryption. */
-    unsigned char *output_tag = NULL; /* Temporary buffer for tag in the
-                                       * encryption step. */
+    int using_nist_kw, using_nist_kw_padding;
 
     mbedtls_cipher_context_t ctx;
     size_t outlen;
 
+    unsigned char *cipher_plus_tag = NULL;
+    size_t cipher_plus_tag_len;
+    unsigned char *decrypt_buf = NULL;
+    size_t decrypt_buf_len = 0;
+    unsigned char *encrypt_buf = NULL;
+    size_t encrypt_buf_len = 0;
+
+#if !defined(MBEDTLS_DEPRECATED_WARNING) && \
+    !defined(MBEDTLS_DEPRECATED_REMOVED)
     unsigned char *tmp_tag    = NULL;
     unsigned char *tmp_cipher = NULL;
+    unsigned char *tag_buf = NULL;
+#endif /* !MBEDTLS_DEPRECATED_WARNING && !MBEDTLS_DEPRECATED_REMOVED */
+
+    /* Null pointers are documented as valid for inputs of length 0.
+     * The test framework passes non-null pointers, so set them to NULL.
+     * key, cipher and tag can't be empty. */
+    if( iv->len == 0 )
+        iv->x = NULL;
+    if( ad->len == 0 )
+        ad->x = NULL;
+    if( clear->len == 0 )
+        clear->x = NULL;
 
     mbedtls_cipher_init( &ctx );
-    memset( output, 0xFF, sizeof( output ) );
 
-    /* Prepare context */
-#if !defined(MBEDTLS_USE_PSA_CRYPTO)
-    (void) use_psa;
-#else
+    /* Initialize PSA Crypto */
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( use_psa == 1 )
-    {
         PSA_ASSERT( psa_crypto_init( ) );
-
-        /* PSA requires that the tag immediately follows the ciphertext. */
-        tmp_cipher = mbedtls_calloc( 1, cipher->len + tag->len );
-        TEST_ASSERT( tmp_cipher != NULL );
-        tmp_tag = tmp_cipher + cipher->len;
-
-        memcpy( tmp_cipher, cipher->x, cipher->len );
-        memcpy( tmp_tag, tag->x, tag->len );
-
-        TEST_ASSERT( 0 == mbedtls_cipher_setup_psa( &ctx,
-                              mbedtls_cipher_info_from_type( cipher_id ),
-                              tag->len ) );
-    }
-    else
+#else
+    (void) use_psa;
 #endif
+
+    /*
+     * Are we using NIST_KW? If so, with padding (KWP)?
+     */
+    using_nist_kw_padding = cipher_id == MBEDTLS_CIPHER_AES_128_KWP ||
+                            cipher_id == MBEDTLS_CIPHER_AES_192_KWP ||
+                            cipher_id == MBEDTLS_CIPHER_AES_256_KWP;
+    using_nist_kw = cipher_id == MBEDTLS_CIPHER_AES_128_KW ||
+                    cipher_id == MBEDTLS_CIPHER_AES_192_KW ||
+                    cipher_id == MBEDTLS_CIPHER_AES_256_KW ||
+                    using_nist_kw_padding;
+
+    /****************************************************************
+     *                                                              *
+     *  Part 1: non-deprecated API                                  *
+     *                                                              *
+     ****************************************************************/
+
+    /*
+     * Prepare context for decryption
+     */
+    if( ! cipher_reset_key( &ctx, cipher_id, use_psa, tag->len, key,
+                            MBEDTLS_DECRYPT ) )
+        goto exit;
+
+    /*
+     * prepare buffer for decryption
+     * (we need the tag appended to the ciphertext)
+     */
+    cipher_plus_tag_len = cipher->len + tag->len;
+    ASSERT_ALLOC( cipher_plus_tag, cipher_plus_tag_len );
+    memcpy( cipher_plus_tag, cipher->x, cipher->len );
+    memcpy( cipher_plus_tag + cipher->len, tag->x, tag->len );
+
+    /*
+     * Compute length of output buffer according to the documentation:
+     * input minus 8 bytes for NIST_KW, input minus the tag otherwise. */
+    if( using_nist_kw )
+        decrypt_buf_len = cipher_plus_tag_len - 8;
+    else
+        decrypt_buf_len = cipher_plus_tag_len - tag->len;
+
+
+    /*
+     * Try decrypting to a buffer that's 1B too small: this must fail with
+     * MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA (skipped if the right size is 0). */
+    if( decrypt_buf_len != 0 )
     {
-        tmp_tag = tag->x;
-        tmp_cipher = cipher->x;
-        TEST_ASSERT( 0 == mbedtls_cipher_setup( &ctx,
-                              mbedtls_cipher_info_from_type( cipher_id ) ) );
+        ASSERT_ALLOC( decrypt_buf, decrypt_buf_len - 1 );
+
+        outlen = 0;
+        ret = mbedtls_cipher_auth_decrypt_ext( &ctx, iv->x, iv->len,
+                ad->x, ad->len, cipher_plus_tag, cipher_plus_tag_len,
+                decrypt_buf, decrypt_buf_len - 1, &outlen, tag->len );
+        TEST_ASSERT( ret == MBEDTLS_ERR_CIPHER_BAD_INPUT_DATA );
+
+        mbedtls_free( decrypt_buf );
+        decrypt_buf = NULL;
     }
 
-    TEST_ASSERT( 0 == mbedtls_cipher_setkey( &ctx, key->x, 8 * key->len,
-                                             MBEDTLS_DECRYPT ) );
+    /*
+     * Authenticate and decrypt, and check result (on authentication
+     * failure, the output buffer must be all-zero). */
+    ASSERT_ALLOC( decrypt_buf, decrypt_buf_len );
 
-    /* decode buffer and check tag->x */
+    outlen = 0;
+    ret = mbedtls_cipher_auth_decrypt_ext( &ctx, iv->x, iv->len,
+            ad->x, ad->len, cipher_plus_tag, cipher_plus_tag_len,
+            decrypt_buf, decrypt_buf_len, &outlen, tag->len );
 
-    /* Sanity check that we don't use overly long inputs. */
-    TEST_ASSERT( sizeof( output ) >= cipher->len );
-
-    ret = mbedtls_cipher_auth_decrypt( &ctx, iv->x, iv->len, ad->x, ad->len,
-                               tmp_cipher, cipher->len, output, &outlen,
-                               tmp_tag, tag->len );
-
-    /* make sure the message is rejected if it should be */
     if( strcmp( result, "FAIL" ) == 0 )
     {
         TEST_ASSERT( ret == MBEDTLS_ERR_CIPHER_AUTH_FAILED );
-        goto exit;
+        TEST_ASSERT( buffer_is_all_zero( decrypt_buf, decrypt_buf_len ) );
+    }
+    else
+    {
+        TEST_ASSERT( ret == 0 );
+        ASSERT_COMPARE( decrypt_buf, outlen, clear->x, clear->len );
     }
 
-    /* otherwise, make sure it was decrypted properly */
-    TEST_ASSERT( ret == 0 );
+    /* Free this, but keep cipher_plus_tag for deprecated function with PSA */
+    mbedtls_free( decrypt_buf );
+    decrypt_buf = NULL;
 
-    TEST_ASSERT( outlen == clear->len );
-    TEST_ASSERT( memcmp( output, clear->x, clear->len ) == 0 );
+    /*
+     * Encrypt back if test data was authentic
+     */
+    if( strcmp( result, "FAIL" ) != 0 )
+    {
+        /* prepare context for encryption */
+        if( ! cipher_reset_key( &ctx, cipher_id, use_psa, tag->len, key,
+                                MBEDTLS_ENCRYPT ) )
+            goto exit;
 
-    /* then encrypt the clear->x and make sure we get the same ciphertext and tag->x */
-    mbedtls_cipher_free( &ctx );
+        /* Compute size of output buffer according to documentation:
+         * NIST_KW: input + 8, rounded up to a multiple of 8 for KWP;
+         * AEAD: input + tag. */
+        if( using_nist_kw )
+        {
+            encrypt_buf_len = clear->len + 8;
+            if( using_nist_kw_padding && encrypt_buf_len % 8 != 0 )
+                encrypt_buf_len += 8 - encrypt_buf_len % 8;
+        }
+        else
+        {
+            encrypt_buf_len = clear->len + tag->len;
+        }
+
+        /*
+         * Try encrypting with an output buffer that's 1B too small
+         * (any error code is accepted). */
+        ASSERT_ALLOC( encrypt_buf, encrypt_buf_len - 1 );
+
+        outlen = 0;
+        ret = mbedtls_cipher_auth_encrypt_ext( &ctx, iv->x, iv->len,
+                ad->x, ad->len, clear->x, clear->len,
+                encrypt_buf, encrypt_buf_len - 1, &outlen, tag->len );
+        TEST_ASSERT( ret != 0 );
+
+        mbedtls_free( encrypt_buf );
+        encrypt_buf = NULL;
+
+        /*
+         * Encrypt and check the result: the output must be the
+         * ciphertext immediately followed by the tag. */
+        ASSERT_ALLOC( encrypt_buf, encrypt_buf_len );
+
+        outlen = 0;
+        ret = mbedtls_cipher_auth_encrypt_ext( &ctx, iv->x, iv->len,
+                ad->x, ad->len, clear->x, clear->len,
+                encrypt_buf, encrypt_buf_len, &outlen, tag->len );
+        TEST_ASSERT( ret == 0 );
+
+        TEST_ASSERT( outlen == cipher->len + tag->len );
+        TEST_ASSERT( memcmp( encrypt_buf, cipher->x, cipher->len ) == 0 );
+        TEST_ASSERT( memcmp( encrypt_buf + cipher->len,
+                             tag->x, tag->len ) == 0 );
+
+        mbedtls_free( encrypt_buf );
+        encrypt_buf = NULL;
+    }
+
+    /****************************************************************
+     *                                                              *
+     *  Part 2: deprecated API                                      *
+     *                                                              *
+     ****************************************************************/
+
+#if !defined(MBEDTLS_DEPRECATED_WARNING) && \
+    !defined(MBEDTLS_DEPRECATED_REMOVED)
+
+    /*
+     * Prepare context for decryption
+     */
+    if( ! cipher_reset_key( &ctx, cipher_id, use_psa, tag->len, key,
+                            MBEDTLS_DECRYPT ) )
+        goto exit;
+
+    /*
+     * Prepare pointers for decryption
+     */
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( use_psa == 1 )
     {
-        TEST_ASSERT( 0 == mbedtls_cipher_setup_psa( &ctx,
-                              mbedtls_cipher_info_from_type( cipher_id ),
-                              tag->len ) );
+        /* PSA requires that the tag immediately follows the ciphertext.
+         * Fortunately, we already have that from testing the new API. */
+        tmp_cipher = cipher_plus_tag;
+        tmp_tag = tmp_cipher + cipher->len;
     }
     else
-#endif
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
     {
-        TEST_ASSERT( 0 == mbedtls_cipher_setup( &ctx,
-                              mbedtls_cipher_info_from_type( cipher_id ) ) );
+        tmp_cipher = cipher->x;
+        tmp_tag = tag->x;
     }
-    TEST_ASSERT( 0 == mbedtls_cipher_setkey( &ctx, key->x, 8 * key->len,
-                                             MBEDTLS_ENCRYPT ) );
 
-    memset( output, 0xFF, sizeof( output ) );
+    /*
+     * Authenticate and decrypt, and check result
+     */
+
+    ASSERT_ALLOC( decrypt_buf, cipher->len );
     outlen = 0;
+    ret = mbedtls_cipher_auth_decrypt( &ctx, iv->x, iv->len, ad->x, ad->len,
+                               tmp_cipher, cipher->len, decrypt_buf, &outlen,
+                               tmp_tag, tag->len );
 
-    /* Sanity check that we don't use overly long inputs. */
-    TEST_ASSERT( sizeof( output ) >= clear->len + tag->len );
+    if( using_nist_kw )
+    {
+        /* NIST_KW is rejected by the legacy API */
+        TEST_ASSERT( ret == MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE );
+    }
+    else if( strcmp( result, "FAIL" ) == 0 )
+    {
+        /* unauthentic message */
+        TEST_ASSERT( ret == MBEDTLS_ERR_CIPHER_AUTH_FAILED );
+        TEST_ASSERT( buffer_is_all_zero( decrypt_buf, cipher->len ) );
+    }
+    else
+    {
+        /* authentic message: is the plaintext correct? */
+        TEST_ASSERT( ret == 0 );
+        ASSERT_COMPARE(  decrypt_buf, outlen, clear->x, clear->len );
+    }
 
-    output_tag = output + clear->len;
-    ret = mbedtls_cipher_auth_encrypt( &ctx, iv->x, iv->len, ad->x, ad->len,
-                               clear->x, clear->len, output, &outlen,
-                               output_tag, tag->len );
-    TEST_ASSERT( ret == 0 );
+    mbedtls_free( decrypt_buf );
+    decrypt_buf = NULL;
+    mbedtls_free( cipher_plus_tag );
+    cipher_plus_tag = NULL;
 
-    TEST_ASSERT( outlen == cipher->len );
-    TEST_ASSERT( memcmp( output, cipher->x, cipher->len ) == 0 );
-    TEST_ASSERT( memcmp( output_tag, tag->x, tag->len ) == 0 );
+    /*
+     * Encrypt back if test data was authentic
+     */
+    if( strcmp( result, "FAIL" ) != 0 )
+    {
+        /* prepare context for encryption */
+        if( ! cipher_reset_key( &ctx, cipher_id, use_psa, tag->len, key,
+                                MBEDTLS_ENCRYPT ) )
+            goto exit;
+
+        /* prepare buffers for encryption (PSA: tag right after ciphertext) */
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        if( use_psa )
+        {
+            ASSERT_ALLOC( cipher_plus_tag, cipher->len + tag->len );
+            tmp_cipher = cipher_plus_tag;
+            tmp_tag = cipher_plus_tag + cipher->len;
+        }
+        else
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+        {
+            ASSERT_ALLOC( encrypt_buf, cipher->len );
+            ASSERT_ALLOC( tag_buf, tag->len );
+            tmp_cipher = encrypt_buf;
+            tmp_tag = tag_buf;
+        }
+
+        /*
+         * Encrypt and check the result
+         */
+        outlen = 0;
+        ret = mbedtls_cipher_auth_encrypt( &ctx, iv->x, iv->len, ad->x, ad->len,
+                                   clear->x, clear->len, tmp_cipher, &outlen,
+                                   tmp_tag, tag->len );
+
+        if( using_nist_kw )
+        {
+            TEST_ASSERT( ret == MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE );
+        }
+        else
+        {
+            TEST_ASSERT( ret == 0 );
+
+            TEST_ASSERT( outlen == cipher->len );
+            if( cipher->len != 0 ) /* buffer may be NULL when empty */
+                TEST_ASSERT( memcmp( tmp_cipher, cipher->x, cipher->len ) == 0 );
+            TEST_ASSERT( memcmp( tmp_tag, tag->x, tag->len ) == 0 );
+        }
+    }
+
+#endif /* !MBEDTLS_DEPRECATED_WARNING && !MBEDTLS_DEPRECATED_REMOVED */
 
 exit:
 
     mbedtls_cipher_free( &ctx );
+    mbedtls_free( decrypt_buf );
+    mbedtls_free( encrypt_buf );
+    mbedtls_free( cipher_plus_tag );
+#if !defined(MBEDTLS_DEPRECATED_WARNING) && \
+    !defined(MBEDTLS_DEPRECATED_REMOVED)
+    mbedtls_free( tag_buf );
+#endif /* !MBEDTLS_DEPRECATED_WARNING && !MBEDTLS_DEPRECATED_REMOVED */
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( use_psa == 1 )
-    {
-        mbedtls_free( tmp_cipher );
         PSA_DONE( );
-    }
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
 }
 /* END_CASE */