Merge pull request #1071 from gilles-peskine-arm/ssl_decrypt_stream_short_buffer

Fix buffer overread in mbedtls_ssl_decrypt_buf with stream cipher
diff --git a/ChangeLog.d/add-new-pkcs5-pbe2-ext-fun.txt b/ChangeLog.d/add-new-pkcs5-pbe2-ext-fun.txt
index a1fded3..f2e7a4a 100644
--- a/ChangeLog.d/add-new-pkcs5-pbe2-ext-fun.txt
+++ b/ChangeLog.d/add-new-pkcs5-pbe2-ext-fun.txt
@@ -1,6 +1,7 @@
 Security
-   * Developers using mbedtls_pkcs5_pbes2() should review the size of the output
-     buffer passed to this function, and note that the output after decryption
-     may include CBC padding. Consider moving to the new function
-     mbedtls_pkcs5_pbes2_ext() which checks for overflow of the output buffer
-     and reports the actual length of the output.
+   * Developers using mbedtls_pkcs5_pbes2() or mbedtls_pkcs12_pbe() should
+     review the size of the output buffer passed to this function, and note
+     that the output after decryption may include CBC padding. Consider moving
+     to the new functions mbedtls_pkcs5_pbes2_ext() or mbedtls_pkcs12_pbe_ext()
+     which check for overflow of the output buffer and report the actual
+     length of the output.
diff --git a/include/mbedtls/pkcs12.h b/include/mbedtls/pkcs12.h
index 8d003b5..363f812 100644
--- a/include/mbedtls/pkcs12.h
+++ b/include/mbedtls/pkcs12.h
@@ -79,7 +79,7 @@
  * \param pwd        Latin1-encoded password used. This may only be \c NULL when
  *                   \p pwdlen is 0. No null terminator should be used.
  * \param pwdlen     length of the password (may be 0)
- * \param input      the input data
+ * \param data       the input data
  * \param len        data length
  * \param output     Output buffer.
  *                   On success, it contains the encrypted or decrypted data,
@@ -96,9 +96,60 @@
 int mbedtls_pkcs12_pbe(mbedtls_asn1_buf *pbe_params, int mode,
                        mbedtls_cipher_type_t cipher_type, mbedtls_md_type_t md_type,
                        const unsigned char *pwd,  size_t pwdlen,
-                       const unsigned char *input, size_t len,
+                       const unsigned char *data, size_t len,
                        unsigned char *output);
 
+#if defined(MBEDTLS_CIPHER_PADDING_PKCS7)
+
+/**
+ * \brief            PKCS12 Password Based function (encryption / decryption)
+ *                   for cipher-based and mbedtls_md-based PBE's
+ *
+ *
+ * \warning          When decrypting:
+ *                   - This function validates the CBC padding and returns
+ *                     #MBEDTLS_ERR_PKCS12_PASSWORD_MISMATCH if the padding is
+ *                     invalid. Note that this can help active adversaries
+ *                     attempting to brute-force the password. Note also that
+ *                     there is no guarantee that an invalid password will be
+ *                     detected (the chances of a valid padding with a random
+ *                     password are about 1/255).
+ *
+ * \param pbe_params an ASN1 buffer containing the pkcs-12 PbeParams structure
+ * \param mode       either #MBEDTLS_PKCS12_PBE_ENCRYPT or
+ *                   #MBEDTLS_PKCS12_PBE_DECRYPT
+ * \param cipher_type the cipher used
+ * \param md_type    the mbedtls_md used
+ * \param pwd        Latin1-encoded password used. This may only be \c NULL when
+ *                   \p pwdlen is 0. No null terminator should be used.
+ * \param pwdlen     length of the password (may be 0)
+ * \param data       the input data
+ * \param len        data length
+ * \param output     Output buffer.
+ *                   On success, it contains the encrypted or decrypted data,
+ *                   possibly followed by the CBC padding.
+ *                   On failure, the content is indeterminate.
+ *                   For decryption, there must be enough room for \p len
+ *                   bytes.
+ *                   For encryption, there must be enough room for
+ *                   \p len + 1 bytes, rounded up to the block size of
+ *                   the block cipher identified by \p pbe_params.
+ * \param output_size size of output buffer.
+ *                    This must be big enough to accommodate the output plus
+ *                    padding data.
+ * \param output_len On success, length of actual data written to the output buffer.
+ *
+ * \return           0 if successful, or a MBEDTLS_ERR_XXX code
+ */
+int mbedtls_pkcs12_pbe_ext(mbedtls_asn1_buf *pbe_params, int mode,
+                           mbedtls_cipher_type_t cipher_type, mbedtls_md_type_t md_type,
+                           const unsigned char *pwd,  size_t pwdlen,
+                           const unsigned char *data, size_t len,
+                           unsigned char *output, size_t output_size,
+                           size_t *output_len);
+
+#endif /* MBEDTLS_CIPHER_PADDING_PKCS7 */
+
 #endif /* MBEDTLS_ASN1_PARSE_C */
 
 /**
diff --git a/library/ccm.c b/library/ccm.c
index cd689c8..7b29775 100644
--- a/library/ccm.c
+++ b/library/ccm.c
@@ -33,6 +33,7 @@
 #include "mbedtls/ccm.h"
 #include "mbedtls/platform_util.h"
 #include "mbedtls/error.h"
+#include "mbedtls/constant_time.h"
 
 #include <string.h>
 
@@ -533,13 +534,8 @@
                                     const unsigned char *tag2,
                                     size_t tag_len)
 {
-    unsigned char i;
-    int diff;
-
     /* Check tag in "constant-time" */
-    for (diff = 0, i = 0; i < tag_len; i++) {
-        diff |= tag1[i] ^ tag2[i];
-    }
+    int diff = mbedtls_ct_memcmp(tag1, tag2, tag_len);
 
     if (diff != 0) {
         return MBEDTLS_ERR_CCM_AUTH_FAILED;
diff --git a/library/chachapoly.c b/library/chachapoly.c
index 0124d75..aebc646 100644
--- a/library/chachapoly.c
+++ b/library/chachapoly.c
@@ -25,6 +25,7 @@
 #include "mbedtls/chachapoly.h"
 #include "mbedtls/platform_util.h"
 #include "mbedtls/error.h"
+#include "mbedtls/constant_time.h"
 
 #include <string.h>
 
@@ -310,7 +311,6 @@
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char check_tag[16];
-    size_t i;
     int diff;
 
     if ((ret = chachapoly_crypt_and_tag(ctx,
@@ -320,9 +320,7 @@
     }
 
     /* Check tag in "constant-time" */
-    for (diff = 0, i = 0; i < sizeof(check_tag); i++) {
-        diff |= tag[i] ^ check_tag[i];
-    }
+    diff = mbedtls_ct_memcmp(tag, check_tag, sizeof(check_tag));
 
     if (diff != 0) {
         mbedtls_platform_zeroize(output, length);
diff --git a/library/constant_time.c b/library/constant_time.c
index 55e7f94..8b41aed 100644
--- a/library/constant_time.c
+++ b/library/constant_time.c
@@ -141,6 +141,36 @@
 #endif
 }
 
+#if defined(MBEDTLS_NIST_KW_C)
+
+int mbedtls_ct_memcmp_partial(const void *a,
+                              const void *b,
+                              size_t n,
+                              size_t skip_head,
+                              size_t skip_tail)
+{
+    unsigned int diff = 0;
+
+    volatile const unsigned char *A = (volatile const unsigned char *) a;
+    volatile const unsigned char *B = (volatile const unsigned char *) b;
+
+    size_t valid_end = n - skip_tail;
+
+    for (size_t i = 0; i < n; i++) {
+        unsigned char x = A[i], y = B[i];
+        unsigned int d = x ^ y;
+        mbedtls_ct_condition_t valid = mbedtls_ct_bool_and(mbedtls_ct_uint_ge(i, skip_head),
+                                                           mbedtls_ct_uint_lt(i, valid_end));
+        diff |= mbedtls_ct_uint_if_else_0(valid, d);
+    }
+
+    /* Since we go byte-by-byte, the only bits set will be in the bottom 8 bits, so the
+     * cast from uint to int is safe. */
+    return (int) diff;
+}
+
+#endif
+
 #if defined(MBEDTLS_PKCS1_V15) && defined(MBEDTLS_RSA_C) && !defined(MBEDTLS_RSA_ALT)
 
 void mbedtls_ct_memmove_left(void *start, size_t total, size_t offset)
diff --git a/library/constant_time_internal.h b/library/constant_time_internal.h
index 44b74ae..323edef 100644
--- a/library/constant_time_internal.h
+++ b/library/constant_time_internal.h
@@ -492,6 +492,37 @@
                          size_t n);
  */
 
+#if defined(MBEDTLS_NIST_KW_C)
+
+/** Constant-time buffer comparison without branches.
+ *
+ * Similar to mbedtls_ct_memcmp, except that the result only depends on part of
+ * the input data - differences in the head or tail are ignored. Functionally equivalent to:
+ *
+ * memcmp(a + skip_head, b + skip_head, n - skip_head - skip_tail)
+ *
+ * Time taken depends on \p n, but not on \p skip_head or \p skip_tail .
+ *
+ * Behaviour is undefined if ( \p skip_head + \p skip_tail) > \p n.
+ *
+ * \param a         Secret. Pointer to the first buffer, containing at least \p n bytes. May not be NULL.
+ * \param b         Secret. Pointer to the second buffer, containing at least \p n bytes. May not be NULL.
+ * \param n         The number of bytes to examine (total size of the buffers).
+ * \param skip_head Secret. The number of bytes to treat as non-significant at the start of the buffer.
+ *                  These bytes will still be read.
+ * \param skip_tail Secret. The number of bytes to treat as non-significant at the end of the buffer.
+ *                  These bytes will still be read.
+ *
+ * \return          Zero if the contents of the two buffers are the same, otherwise non-zero.
+ */
+int mbedtls_ct_memcmp_partial(const void *a,
+                              const void *b,
+                              size_t n,
+                              size_t skip_head,
+                              size_t skip_tail);
+
+#endif
+
 /* Include the implementation of static inline functions above. */
 #include "constant_time_impl.h"
 
diff --git a/library/gcm.c b/library/gcm.c
index 786290f..b773529 100644
--- a/library/gcm.c
+++ b/library/gcm.c
@@ -35,6 +35,7 @@
 #include "mbedtls/platform.h"
 #include "mbedtls/platform_util.h"
 #include "mbedtls/error.h"
+#include "mbedtls/constant_time.h"
 
 #include <string.h>
 
@@ -601,7 +602,6 @@
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char check_tag[16];
-    size_t i;
     int diff;
 
     if ((ret = mbedtls_gcm_crypt_and_tag(ctx, MBEDTLS_GCM_DECRYPT, length,
@@ -611,9 +611,7 @@
     }
 
     /* Check tag in "constant-time" */
-    for (diff = 0, i = 0; i < tag_len; i++) {
-        diff |= tag[i] ^ check_tag[i];
-    }
+    diff = mbedtls_ct_memcmp(tag, check_tag, tag_len);
 
     if (diff != 0) {
         mbedtls_platform_zeroize(output, length);
diff --git a/library/nist_kw.c b/library/nist_kw.c
index fbd7221..d73e82f 100644
--- a/library/nist_kw.c
+++ b/library/nist_kw.c
@@ -35,6 +35,7 @@
 #include "mbedtls/platform_util.h"
 #include "mbedtls/error.h"
 #include "mbedtls/constant_time.h"
+#include "constant_time_internal.h"
 
 #include <stdint.h>
 #include <string.h>
@@ -333,9 +334,9 @@
                            unsigned char *output, size_t *out_len, size_t out_size)
 {
     int ret = 0;
-    size_t i, olen;
+    size_t olen;
     unsigned char A[KW_SEMIBLOCK_LENGTH];
-    unsigned char diff, bad_padding = 0;
+    int diff;
 
     *out_len = 0;
     if (out_size < in_len - KW_SEMIBLOCK_LENGTH) {
@@ -420,19 +421,15 @@
          * larger than 8, because of the type wrap around.
          */
         padlen = in_len - KW_SEMIBLOCK_LENGTH - Plen;
-        if (padlen > 7) {
-            padlen &= 7;
-            ret = MBEDTLS_ERR_CIPHER_AUTH_FAILED;
-        }
+        ret = -((int) mbedtls_ct_uint_if(mbedtls_ct_uint_gt(padlen, 7),
+                                         -MBEDTLS_ERR_CIPHER_AUTH_FAILED, -ret));
+        padlen &= 7;
 
         /* Check padding in "constant-time" */
-        for (diff = 0, i = 0; i < KW_SEMIBLOCK_LENGTH; i++) {
-            if (i >= KW_SEMIBLOCK_LENGTH - padlen) {
-                diff |= output[*out_len - KW_SEMIBLOCK_LENGTH + i];
-            } else {
-                bad_padding |= output[*out_len - KW_SEMIBLOCK_LENGTH + i];
-            }
-        }
+        const uint8_t zero[KW_SEMIBLOCK_LENGTH] = { 0 };
+        diff = mbedtls_ct_memcmp_partial(
+            &output[*out_len - KW_SEMIBLOCK_LENGTH], zero,
+            KW_SEMIBLOCK_LENGTH, KW_SEMIBLOCK_LENGTH - padlen, 0);
 
         if (diff != 0) {
             ret = MBEDTLS_ERR_CIPHER_AUTH_FAILED;
@@ -454,7 +451,6 @@
         *out_len = 0;
     }
 
-    mbedtls_platform_zeroize(&bad_padding, sizeof(bad_padding));
     mbedtls_platform_zeroize(&diff, sizeof(diff));
     mbedtls_platform_zeroize(A, sizeof(A));
 
diff --git a/library/pkcs12.c b/library/pkcs12.c
index 6d52305..ad0f9e6 100644
--- a/library/pkcs12.c
+++ b/library/pkcs12.c
@@ -129,18 +129,46 @@
 
 #undef PKCS12_MAX_PWDLEN
 
+#if !defined(MBEDTLS_CIPHER_PADDING_PKCS7)
+int mbedtls_pkcs12_pbe_ext(mbedtls_asn1_buf *pbe_params, int mode,
+                           mbedtls_cipher_type_t cipher_type, mbedtls_md_type_t md_type,
+                           const unsigned char *pwd,  size_t pwdlen,
+                           const unsigned char *data, size_t len,
+                           unsigned char *output, size_t output_size,
+                           size_t *output_len);
+#endif
+
 int mbedtls_pkcs12_pbe(mbedtls_asn1_buf *pbe_params, int mode,
                        mbedtls_cipher_type_t cipher_type, mbedtls_md_type_t md_type,
                        const unsigned char *pwd,  size_t pwdlen,
                        const unsigned char *data, size_t len,
                        unsigned char *output)
 {
+    size_t output_len = 0;
+
+    /* We assume the caller of the function is providing a big enough output
+     * buffer, so we pass output_size as SIZE_MAX to bypass the size checks.
+     * However, there is no guarantee that the buffer is actually large enough.
+     */
+    return mbedtls_pkcs12_pbe_ext(pbe_params, mode, cipher_type, md_type,
+                                  pwd, pwdlen, data, len, output, SIZE_MAX,
+                                  &output_len);
+}
+
+int mbedtls_pkcs12_pbe_ext(mbedtls_asn1_buf *pbe_params, int mode,
+                           mbedtls_cipher_type_t cipher_type, mbedtls_md_type_t md_type,
+                           const unsigned char *pwd,  size_t pwdlen,
+                           const unsigned char *data, size_t len,
+                           unsigned char *output, size_t output_size,
+                           size_t *output_len)
+{
     int ret, keylen = 0;
     unsigned char key[32];
     unsigned char iv[16];
     const mbedtls_cipher_info_t *cipher_info;
     mbedtls_cipher_context_t cipher_ctx;
-    size_t olen = 0;
+    size_t finish_olen = 0;
+    unsigned int padlen = 0;
 
     if (pwd == NULL && pwdlen != 0) {
         return MBEDTLS_ERR_PKCS12_BAD_INPUT_DATA;
@@ -153,6 +181,19 @@
 
     keylen = (int) mbedtls_cipher_info_get_key_bitlen(cipher_info) / 8;
 
+    if (mode == MBEDTLS_PKCS12_PBE_DECRYPT) {
+        if (output_size < len) {
+            return MBEDTLS_ERR_ASN1_BUF_TOO_SMALL;
+        }
+    }
+
+    if (mode == MBEDTLS_PKCS12_PBE_ENCRYPT) {
+        padlen = cipher_info->block_size - (len % cipher_info->block_size);
+        if (output_size < (len + padlen)) {
+            return MBEDTLS_ERR_ASN1_BUF_TOO_SMALL;
+        }
+    }
+
     if ((ret = pkcs12_pbe_derive_key_iv(pbe_params, md_type, pwd, pwdlen,
                                         key, keylen,
                                         iv, mbedtls_cipher_info_get_iv_size(cipher_info))) != 0) {
@@ -201,14 +242,16 @@
     }
 
     if ((ret = mbedtls_cipher_update(&cipher_ctx, data, len,
-                                     output, &olen)) != 0) {
+                                     output, output_len)) != 0) {
         goto exit;
     }
 
-    if ((ret = mbedtls_cipher_finish(&cipher_ctx, output + olen, &olen)) != 0) {
+    if ((ret = mbedtls_cipher_finish(&cipher_ctx, output + (*output_len), &finish_olen)) != 0) {
         ret = MBEDTLS_ERR_PKCS12_PASSWORD_MISMATCH;
     }
 
+    *output_len += finish_olen;
+
 exit:
     mbedtls_platform_zeroize(key, sizeof(key));
     mbedtls_platform_zeroize(iv,  sizeof(iv));
diff --git a/library/rsa.c b/library/rsa.c
index d0782f5..02626b3 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -1541,7 +1541,8 @@
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     size_t ilen, i, pad_len;
-    unsigned char *p, bad, pad_done;
+    unsigned char *p, pad_done;
+    int bad;
     unsigned char buf[MBEDTLS_MPI_MAX_SIZE];
     unsigned char lhash[MBEDTLS_MD_MAX_SIZE];
     unsigned int hlen;
@@ -1608,9 +1609,8 @@
     p += hlen; /* Skip seed */
 
     /* Check lHash */
-    for (i = 0; i < hlen; i++) {
-        bad |= lhash[i] ^ *p++;
-    }
+    bad |= mbedtls_ct_memcmp(lhash, p, hlen);
+    p += hlen;
 
     /* Get zero-padding len, but always read till end of buffer
      * (minus one, for the 01 byte) */
diff --git a/tests/include/test/macros.h b/tests/include/test/macros.h
index 7edc991..3bfbe33 100644
--- a/tests/include/test/macros.h
+++ b/tests/include/test/macros.h
@@ -143,6 +143,38 @@
         }                                                   \
     } while (0)
 
+/** Allocate memory dynamically and fail the test case if this fails.
+ * The allocated memory will be filled with zeros.
+ *
+ * You must set \p pointer to \c NULL before calling this macro and
+ * put `mbedtls_free(pointer)` in the test's cleanup code.
+ *
+ * If \p item_count is zero, the resulting \p pointer will not be \c NULL.
+ *
+ * This macro expands to an instruction, not an expression.
+ * It may jump to the \c exit label.
+ *
+ * \param pointer    An lvalue where the address of the allocated buffer
+ *                   will be stored.
+ *                   This expression may be evaluated multiple times.
+ * \param item_count Number of elements to allocate.
+ *                   This expression may be evaluated multiple times.
+ *
+ * Note: if passing size 0, mbedtls_calloc may return NULL. In this case,
+ * we reattempt to allocate with the smallest possible buffer to ensure a
+ * non-NULL pointer.
+ */
+#define TEST_CALLOC_NONNULL(pointer, item_count)            \
+    do {                                                    \
+        TEST_ASSERT((pointer) == NULL);                     \
+        (pointer) = mbedtls_calloc(sizeof(*(pointer)),      \
+                                   (item_count));           \
+        if (((pointer) == NULL) && ((item_count) == 0)) {   \
+            (pointer) = mbedtls_calloc(1, 1);               \
+        }                                                   \
+        TEST_ASSERT((pointer) != NULL);                     \
+    } while (0)
+
 /* For backwards compatibility */
 #define ASSERT_ALLOC(pointer, item_count) TEST_CALLOC(pointer, item_count)
 
diff --git a/tests/suites/test_suite_constant_time.data b/tests/suites/test_suite_constant_time.data
index 785bdec..82ee869 100644
--- a/tests/suites/test_suite_constant_time.data
+++ b/tests/suites/test_suite_constant_time.data
@@ -702,3 +702,69 @@
 
 mbedtls_ct_memmove_left 16 16
 mbedtls_ct_memmove_left:16:16
+
+mbedtls_ct_memcmp_partial -1 0 0 0
+mbedtls_ct_memcmp_partial:-1:0:0:0
+
+mbedtls_ct_memcmp_partial 0 1 0 0
+mbedtls_ct_memcmp_partial:0:1:0:0
+
+mbedtls_ct_memcmp_partial 0 1 1 0
+mbedtls_ct_memcmp_partial:0:1:1:0
+
+mbedtls_ct_memcmp_partial 0 1 0 1
+mbedtls_ct_memcmp_partial:0:1:0:1
+
+mbedtls_ct_memcmp_partial -1 1 0 0
+mbedtls_ct_memcmp_partial:-1:1:0:0
+
+mbedtls_ct_memcmp_partial 0 2 0 1
+mbedtls_ct_memcmp_partial:0:2:0:1
+
+mbedtls_ct_memcmp_partial 0 2 1 0
+mbedtls_ct_memcmp_partial:0:2:1:0
+
+mbedtls_ct_memcmp_partial 0 16 4 4
+mbedtls_ct_memcmp_partial:0:16:4:4
+
+mbedtls_ct_memcmp_partial 2 16 4 4
+mbedtls_ct_memcmp_partial:2:16:4:4
+
+mbedtls_ct_memcmp_partial 3 16 4 4
+mbedtls_ct_memcmp_partial:3:16:4:4
+
+mbedtls_ct_memcmp_partial 4 16 4 4
+mbedtls_ct_memcmp_partial:4:16:4:4
+
+mbedtls_ct_memcmp_partial 7 16 4 4
+mbedtls_ct_memcmp_partial:7:16:4:4
+
+mbedtls_ct_memcmp_partial 11 16 4 4
+mbedtls_ct_memcmp_partial:11:16:4:4
+
+mbedtls_ct_memcmp_partial 12 16 4 4
+mbedtls_ct_memcmp_partial:12:16:4:4
+
+mbedtls_ct_memcmp_partial 15 16 4 4
+mbedtls_ct_memcmp_partial:15:16:4:4
+
+mbedtls_ct_memcmp_partial 15 16 4 0
+mbedtls_ct_memcmp_partial:15:16:4:0
+
+mbedtls_ct_memcmp_partial 15 16 0 4
+mbedtls_ct_memcmp_partial:15:16:0:4
+
+mbedtls_ct_memcmp_partial 0 16 0 0
+mbedtls_ct_memcmp_partial:0:16:0:0
+
+mbedtls_ct_memcmp_partial 15 16 0 0
+mbedtls_ct_memcmp_partial:15:16:0:0
+
+mbedtls_ct_memcmp_partial -1 16 0 0
+mbedtls_ct_memcmp_partial:-1:16:0:0
+
+mbedtls_ct_memcmp_partial -1 16 12 4
+mbedtls_ct_memcmp_partial:-1:16:12:4
+
+mbedtls_ct_memcmp_partial -1 16 8 8
+mbedtls_ct_memcmp_partial:-1:16:8:8
diff --git a/tests/suites/test_suite_constant_time.function b/tests/suites/test_suite_constant_time.function
index 5dbba06..087ccdf 100644
--- a/tests/suites/test_suite_constant_time.function
+++ b/tests/suites/test_suite_constant_time.function
@@ -259,6 +259,55 @@
 }
 /* END_CASE */
 
+/* BEGIN_CASE depends_on:MBEDTLS_NIST_KW_C */
+
+/**
+ * Generate two arrays of the given size, and test mbedtls_ct_memcmp_partial
+ * over them. The arrays will be identical, except that one byte may be specified
+ * to be different.
+ *
+ * \p diff      Index of byte that differs (if out of range, the arrays will match).
+ * \p size      Size of arrays to compare
+ * \p skip_head Leading bytes to skip, as per mbedtls_ct_memcmp_partial
+ * \p skip_tail Trailing bytes to skip, as per mbedtls_ct_memcmp_partial
+ */
+void mbedtls_ct_memcmp_partial(int diff, int size, int skip_head, int skip_tail)
+{
+    uint8_t *a = NULL, *b = NULL;
+
+    TEST_CALLOC_NONNULL(a, size);
+    TEST_CALLOC_NONNULL(b, size);
+
+    TEST_ASSERT((skip_head + skip_tail) <= size);
+
+    /* Construct data that matches, except for specified byte (if in range). */
+    for (int i = 0; i < size; i++) {
+        a[i] = i & 0xff;
+        b[i] = a[i];
+        if (i == diff) {
+            // modify the specified byte
+            b[i] ^= 1;
+        }
+    }
+
+    int reference = memcmp(a + skip_head, b + skip_head, size - skip_head - skip_tail);
+
+    TEST_CF_SECRET(a, size);
+    TEST_CF_SECRET(b, size);
+
+    int actual = mbedtls_ct_memcmp_partial(a, b, size, skip_head, skip_tail);
+
+    TEST_CF_PUBLIC(a, size);
+    TEST_CF_PUBLIC(b, size);
+    TEST_CF_PUBLIC(&actual, sizeof(actual));
+
+    TEST_EQUAL(!!reference, !!actual);
+exit:
+    mbedtls_free(a);
+    mbedtls_free(b);
+}
+/* END_CASE */
+
 /* BEGIN_CASE */
 void mbedtls_ct_memcpy_if(int eq, int size, int offset)
 {
diff --git a/tests/suites/test_suite_pkcs12.data b/tests/suites/test_suite_pkcs12.data
index 074c6c8..c4e4d77 100644
--- a/tests/suites/test_suite_pkcs12.data
+++ b/tests/suites/test_suite_pkcs12.data
@@ -36,28 +36,36 @@
 
 PBE Encrypt, pad = 7 (OK)
 depends_on:MBEDTLS_MD_CAN_SHA1:MBEDTLS_DES_C:MBEDTLS_CIPHER_MODE_CBC:MBEDTLS_CIPHER_PADDING_PKCS7
-pkcs12_pbe_encrypt:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"300E0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"AAAAAAAAAAAAAAAAAA":0:"5F2C15056A36F3A78856E9E662DD27CB"
+pkcs12_pbe_encrypt:MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"AAAAAAAAAAAAAAAAAA":16:0:"5F2C15056A36F3A78856E9E662DD27CB"
 
 PBE Encrypt, pad = 8 (OK)
 depends_on:MBEDTLS_MD_CAN_SHA1:MBEDTLS_DES_C:MBEDTLS_CIPHER_MODE_CBC:MBEDTLS_CIPHER_PADDING_PKCS7
-pkcs12_pbe_encrypt:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"300E0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"AAAAAAAAAAAAAAAA":0:"5F2C15056A36F3A70F70A3D4EC4004A8"
+pkcs12_pbe_encrypt:MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"AAAAAAAAAAAAAAAA":16:0:"5F2C15056A36F3A70F70A3D4EC4004A8"
+
+PBE Encrypt, pad = 8 (Invalid output size)
+depends_on:MBEDTLS_MD_CAN_SHA1:MBEDTLS_DES_C:MBEDTLS_CIPHER_MODE_CBC:MBEDTLS_CIPHER_PADDING_PKCS7
+pkcs12_pbe_encrypt:MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"AAAAAAAAAAAAAAAA":15:MBEDTLS_ERR_ASN1_BUF_TOO_SMALL:"5F2C15056A36F3A70F70A3D4EC4004A8"
 
 PBE Encrypt, pad = 8 (PKCS7 padding disabled)
 depends_on:MBEDTLS_MD_CAN_SHA1:MBEDTLS_DES_C:MBEDTLS_CIPHER_MODE_CBC:!MBEDTLS_CIPHER_PADDING_PKCS7
-pkcs12_pbe_encrypt:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"300E0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"AAAAAAAAAAAAAAAA":MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE:""
+pkcs12_pbe_encrypt:MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"AAAAAAAAAAAAAAAA":0:MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE:""
 
 PBE Decrypt, pad = 7 (OK)
 depends_on:MBEDTLS_MD_CAN_SHA1:MBEDTLS_DES_C:MBEDTLS_CIPHER_MODE_CBC:MBEDTLS_CIPHER_PADDING_PKCS7
-pkcs12_pbe_decrypt:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"300E0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"5F2C15056A36F3A78856E9E662DD27CB":0:"AAAAAAAAAAAAAAAAAA"
+pkcs12_pbe_decrypt:MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"5F2C15056A36F3A78856E9E662DD27CB":16:0:"AAAAAAAAAAAAAAAAAA"
+
+PBE Decrypt, pad = 8 (Invalid output size)
+depends_on:MBEDTLS_MD_CAN_SHA1:MBEDTLS_DES_C:MBEDTLS_CIPHER_MODE_CBC:MBEDTLS_CIPHER_PADDING_PKCS7
+pkcs12_pbe_decrypt:MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"5F2C15056A36F3A70F70A3D4EC4004A8":15:MBEDTLS_ERR_ASN1_BUF_TOO_SMALL:"AAAAAAAAAAAAAAAA"
 
 PBE Decrypt, pad = 8 (OK)
 depends_on:MBEDTLS_MD_CAN_SHA1:MBEDTLS_DES_C:MBEDTLS_CIPHER_MODE_CBC:MBEDTLS_CIPHER_PADDING_PKCS7
-pkcs12_pbe_decrypt:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"300E0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"5F2C15056A36F3A70F70A3D4EC4004A8":0:"AAAAAAAAAAAAAAAA"
+pkcs12_pbe_decrypt:MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"5F2C15056A36F3A70F70A3D4EC4004A8":16:0:"AAAAAAAAAAAAAAAA"
 
 PBE Decrypt, (Invalid padding & PKCS7 padding disabled)
 depends_on:MBEDTLS_MD_CAN_SHA1:MBEDTLS_DES_C:MBEDTLS_CIPHER_MODE_CBC:!MBEDTLS_CIPHER_PADDING_PKCS7
-pkcs12_pbe_decrypt:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"300E0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"5F2C15056A36F3A79F2B90F1428110E2":0:"AAAAAAAAAAAAAAAAAA07070707070708"
+pkcs12_pbe_decrypt:MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"5F2C15056A36F3A79F2B90F1428110E2":16:0:"AAAAAAAAAAAAAAAAAA07070707070708"
 
 PBE Decrypt, (Invalid padding & PKCS7 padding enabled)
 depends_on:MBEDTLS_MD_CAN_SHA1:MBEDTLS_DES_C:MBEDTLS_CIPHER_MODE_CBC:MBEDTLS_CIPHER_PADDING_PKCS7
-pkcs12_pbe_decrypt:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"300E0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"5F2C15056A36F3A79F2B90F1428110E2":MBEDTLS_ERR_PKCS12_PASSWORD_MISMATCH:"AAAAAAAAAAAAAAAAAA07070707070708"
+pkcs12_pbe_decrypt:MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE:MBEDTLS_CIPHER_DES_EDE3_CBC:MBEDTLS_MD_SHA1:"0409CCCCCCCCCCCCCCCCCC02010A":"BBBBBBBBBBBBBBBBBB":"5F2C15056A36F3A79F2B90F1428110E2":16:MBEDTLS_ERR_PKCS12_PASSWORD_MISMATCH:"AAAAAAAAAAAAAAAAAA07070707070708"
diff --git a/tests/suites/test_suite_pkcs12.function b/tests/suites/test_suite_pkcs12.function
index b557098..3e8ff5b 100644
--- a/tests/suites/test_suite_pkcs12.function
+++ b/tests/suites/test_suite_pkcs12.function
@@ -70,33 +70,52 @@
 /* END_CASE */
 
 /* BEGIN_CASE depends_on:MBEDTLS_ASN1_PARSE_C */
-void pkcs12_pbe_encrypt(int cipher, int md, data_t *params_hex, data_t *pw,
-                        data_t *data, int ref_ret, data_t *ref_out)
+void pkcs12_pbe_encrypt(int params_tag, int cipher, int md, data_t *params_hex, data_t *pw,
+                        data_t *data, int outsize, int ref_ret, data_t *ref_out)
 {
     int my_ret;
     mbedtls_asn1_buf pbe_params;
     unsigned char *my_out = NULL;
     mbedtls_cipher_type_t cipher_alg = (mbedtls_cipher_type_t) cipher;
     mbedtls_md_type_t md_alg = (mbedtls_md_type_t) md;
-    size_t block_size;
+#if defined(MBEDTLS_CIPHER_PADDING_PKCS7)
+    size_t my_out_len = 0;
+#endif
 
     MD_PSA_INIT();
 
-    block_size = mbedtls_cipher_info_get_block_size(mbedtls_cipher_info_from_type(cipher_alg));
-    TEST_CALLOC(my_out, ((data->len/block_size) + 1) * block_size);
+    TEST_CALLOC(my_out, outsize);
 
-    pbe_params.tag = params_hex->x[0];
-    pbe_params.len = params_hex->x[1];
-    pbe_params.p = params_hex->x + 2;
+    pbe_params.tag = params_tag;
+    pbe_params.len = params_hex->len;
+    pbe_params.p = params_hex->x;
 
-    my_ret = mbedtls_pkcs12_pbe(&pbe_params, MBEDTLS_PKCS12_PBE_ENCRYPT, cipher_alg,
-                                md_alg, pw->x, pw->len, data->x, data->len, my_out);
-    TEST_EQUAL(my_ret, ref_ret);
+    if (ref_ret != MBEDTLS_ERR_ASN1_BUF_TOO_SMALL) {
+        my_ret = mbedtls_pkcs12_pbe(&pbe_params, MBEDTLS_PKCS12_PBE_ENCRYPT, cipher_alg,
+                                    md_alg, pw->x, pw->len, data->x, data->len, my_out);
+        TEST_EQUAL(my_ret, ref_ret);
+    }
     if (ref_ret == 0) {
         ASSERT_COMPARE(my_out, ref_out->len,
                        ref_out->x, ref_out->len);
     }
 
+#if defined(MBEDTLS_CIPHER_PADDING_PKCS7)
+
+    pbe_params.tag = params_tag;
+    pbe_params.len = params_hex->len;
+    pbe_params.p = params_hex->x;
+
+    my_ret = mbedtls_pkcs12_pbe_ext(&pbe_params, MBEDTLS_PKCS12_PBE_ENCRYPT, cipher_alg,
+                                    md_alg, pw->x, pw->len, data->x, data->len, my_out,
+                                    outsize, &my_out_len);
+    TEST_EQUAL(my_ret, ref_ret);
+    if (ref_ret == 0) {
+        ASSERT_COMPARE(my_out, my_out_len,
+                       ref_out->x, ref_out->len);
+    }
+#endif
+
 exit:
     mbedtls_free(my_out);
     MD_PSA_DONE();
@@ -104,31 +123,53 @@
 /* END_CASE */
 
 /* BEGIN_CASE depends_on:MBEDTLS_ASN1_PARSE_C */
-void pkcs12_pbe_decrypt(int cipher, int md, data_t *params_hex, data_t *pw,
-                        data_t *data, int ref_ret, data_t *ref_out)
+void pkcs12_pbe_decrypt(int params_tag, int cipher, int md, data_t *params_hex, data_t *pw,
+                        data_t *data, int outsize, int ref_ret, data_t *ref_out)
 {
     int my_ret;
     mbedtls_asn1_buf pbe_params;
     unsigned char *my_out = NULL;
     mbedtls_cipher_type_t cipher_alg = (mbedtls_cipher_type_t) cipher;
     mbedtls_md_type_t md_alg = (mbedtls_md_type_t) md;
+#if defined(MBEDTLS_CIPHER_PADDING_PKCS7)
+    size_t my_out_len = 0;
+#endif
 
     MD_PSA_INIT();
 
-    TEST_CALLOC(my_out, data->len);
+    TEST_CALLOC(my_out, outsize);
 
-    pbe_params.tag = params_hex->x[0];
-    pbe_params.len = params_hex->x[1];
-    pbe_params.p = params_hex->x + 2;
+    pbe_params.tag = params_tag;
+    pbe_params.len = params_hex->len;
+    pbe_params.p = params_hex->x;
 
-    my_ret = mbedtls_pkcs12_pbe(&pbe_params, MBEDTLS_PKCS12_PBE_DECRYPT, cipher_alg,
-                                md_alg, pw->x, pw->len, data->x, data->len, my_out);
-    TEST_EQUAL(my_ret, ref_ret);
+    if (ref_ret != MBEDTLS_ERR_ASN1_BUF_TOO_SMALL) {
+        my_ret = mbedtls_pkcs12_pbe(&pbe_params, MBEDTLS_PKCS12_PBE_DECRYPT, cipher_alg,
+                                    md_alg, pw->x, pw->len, data->x, data->len, my_out);
+        TEST_EQUAL(my_ret, ref_ret);
+    }
+
     if (ref_ret == 0) {
         ASSERT_COMPARE(my_out, ref_out->len,
                        ref_out->x, ref_out->len);
     }
 
+#if defined(MBEDTLS_CIPHER_PADDING_PKCS7)
+
+    pbe_params.tag = params_tag;
+    pbe_params.len = params_hex->len;
+    pbe_params.p = params_hex->x;
+
+    my_ret = mbedtls_pkcs12_pbe_ext(&pbe_params, MBEDTLS_PKCS12_PBE_DECRYPT, cipher_alg,
+                                    md_alg, pw->x, pw->len, data->x, data->len, my_out,
+                                    outsize, &my_out_len);
+    TEST_EQUAL(my_ret, ref_ret);
+    if (ref_ret == 0) {
+        ASSERT_COMPARE(my_out, my_out_len,
+                       ref_out->x, ref_out->len);
+    }
+#endif
+
 exit:
     mbedtls_free(my_out);
     MD_PSA_DONE();