mbedtls_base64_decode: insist on correct padding

Correct base64 input (excluding ignored characters such as spaces) consists
of exactly 4*k, 4*k-1 or 4*k-2 digits, followed by 0, 1 or 2 equal signs
respectively.

Previously, any number of trailing equal signs up to 2 was accepted, but if
there fewer than 4*k digits-or-equals, the last partial block was counted in
`*olen` in buffer-too-small mode, but was not output despite returning 0.

Now `mbedtls_base64_decode()` insists on correct padding. This is
backward-compatible since the only plausible useful inputs that used to be
accepted were inputs with 4*k-1 or 4*k-2 digits and no trailing equal signs,
and those led to invalid (truncated) output. Furthermore the function now
always reports the exact output size in buffer-too-small mode.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/ChangeLog.d/base64_decode.txt b/ChangeLog.d/base64_decode.txt
index 93dfef1..2cd2c59 100644
--- a/ChangeLog.d/base64_decode.txt
+++ b/ChangeLog.d/base64_decode.txt
@@ -1,3 +1,8 @@
 Bugfix
-   * Fix mbedtls_base64_decode() accepting invalid inputs with 4n+1 digits
-     (the last digit was ignored).
+   * Fix mbedtls_base64_decode() on inputs that did not have the correct
+     number of trailing equal signs, or had 4*k+1 digits. They were accepted
+     as long as they had at most two trailing equal signs. They are now
+     rejected. Furthermore, before, on inputs with too few equal signs, the
+     function reported the correct size in *olen when it returned
+     MBEDTLS_ERR_BASE64_BUFFER_TOO_SMALL, but truncated the output to the
+     last multiple of 3 bytes.
diff --git a/library/base64.c b/library/base64.c
index bff9123..cc6a73d 100644
--- a/library/base64.c
+++ b/library/base64.c
@@ -14,6 +14,7 @@
 #include "mbedtls/base64.h"
 #include "base64_internal.h"
 #include "constant_time_internal.h"
+#include "mbedtls/error.h"
 
 #include <stdint.h>
 
@@ -183,55 +184,57 @@
         n++;
     }
 
-    /* In valid base64, the number of digits is always of the form
-     * 4n, 4n+2 or 4n+3. */
+    /* In valid base64, the number of digits (n-equals) is always of the form
+     * 4*k, 4*k+2 or *4k+3. Also, the number n of digits plus the number of
+     * equal signs at the end is always a multiple of 4. */
     if ((n - equals) % 4 == 1) {
         return MBEDTLS_ERR_BASE64_INVALID_CHARACTER;
     }
-
-    if (n == 0) {
-        *olen = 0;
-        return 0;
+    if (n % 4 != 0) {
+        return MBEDTLS_ERR_BASE64_INVALID_CHARACTER;
     }
 
-    /* The following expression is to calculate the following formula without
-     * risk of integer overflow in n:
-     *     n = ( ( n * 6 ) + 7 ) >> 3;
-     */
-    n = (6 * (n >> 3)) + ((6 * (n & 0x7) + 7) >> 3);
-    n -= equals;
+    /* We've determined that the input is valid, and that it contains
+     * n digits-plus-trailing-equal-signs, which means (n - equals) digits.
+     * Now set *olen to the exact length of the output. */
+    /* Each block of 4 digits in the input map to 3 bytes of output.
+     * The last block can have one or two equal signs, in which case
+     * there are that many fewer output bytes. */
+    *olen = (n / 4) * 3 - equals;
 
-    if (dst == NULL || dlen < n) {
-        *olen = n;
+    if ((*olen != 0 && dst == NULL) || dlen < *olen) {
         return MBEDTLS_ERR_BASE64_BUFFER_TOO_SMALL;
     }
 
-    equals = 0;
     for (x = 0, p = dst; i > 0; i--, src++) {
         if (*src == '\r' || *src == '\n' || *src == ' ') {
             continue;
         }
-
-        x = x << 6;
         if (*src == '=') {
-            ++equals;
-        } else {
-            x |= mbedtls_ct_base64_dec_value(*src);
+            /* We already know from the first loop that equal signs are
+             * only at the end. */
+            break;
         }
+        x = x << 6;
+        x |= mbedtls_ct_base64_dec_value(*src);
 
         if (++accumulated_digits == 4) {
             accumulated_digits = 0;
             *p++ = MBEDTLS_BYTE_2(x);
-            if (equals <= 1) {
-                *p++ = MBEDTLS_BYTE_1(x);
-            }
-            if (equals <= 0) {
-                *p++ = MBEDTLS_BYTE_0(x);
-            }
+            *p++ = MBEDTLS_BYTE_1(x);
+            *p++ = MBEDTLS_BYTE_0(x);
         }
     }
+    if (accumulated_digits == 3) {
+        *p++ = MBEDTLS_BYTE_2(x << 6);
+        *p++ = MBEDTLS_BYTE_1(x << 6);
+    } else if (accumulated_digits == 2) {
+        *p++ = MBEDTLS_BYTE_2(x << 12);
+    }
 
-    *olen = (size_t) (p - dst);
+    if (*olen != (size_t) (p - dst)) {
+        return MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    }
 
     return 0;
 }
diff --git a/tests/suites/test_suite_base64.data b/tests/suites/test_suite_base64.data
index 4d3b5b9..547b9fd 100644
--- a/tests/suites/test_suite_base64.data
+++ b/tests/suites/test_suite_base64.data
@@ -76,31 +76,25 @@
 Base64 decode (Space inside string)
 mbedtls_base64_decode:"zm masd":"":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
-# Validate the weird behavior of mbedtls_base64_decode() on some
-# invalid inputs (number of digts + equals not a multiple of 4).
-# In the reference output, "!" characters at the end are needed to
-# pad the output buffer, but the actual output omits those. E.g. if
-# dst_string is "ab!" then mbedtls_base64_decode() reports a 3-byte
-# output when dlen < 3, but actually outputs 2 bytes if given a
-# buffer of 3 bytes or more.
-
+# The next few test cases validate systematically for short inputs that
+# we require the correct number of trailing equal signs.
 Base64 decode: 1 digit, 0 equals (bad)
 mbedtls_base64_decode:"Y":"":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
 Base64 decode: 1 digit, 1 equals (bad)
-mbedtls_base64_decode:"Y":"":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
+mbedtls_base64_decode:"Y=":"":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
 Base64 decode: 1 digit, 2 equals (bad)
 mbedtls_base64_decode:"Y==":"":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
 Base64 decode: 1 digit, 3 equals (bad)
-mbedtls_base64_decode:"Y===":"!":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
+mbedtls_base64_decode:"Y===":"":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
-Base64 decode: 2 digits, 0 equals (sloppily accepted)
-mbedtls_base64_decode:"Yw":"!!":0
+Base64 decode: 2 digits, 0 equals (bad)
+mbedtls_base64_decode:"Yw":"c":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
-Base64 decode: 2 digits, 1 equals (sloppily accepted)
-mbedtls_base64_decode:"Yw=":"!!":0
+Base64 decode: 2 digits, 1 equals (bad)
+mbedtls_base64_decode:"Yw=":"c":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
 Base64 decode: 2 digits, 2 equals (good)
 mbedtls_base64_decode:"Yw==":"c":0
@@ -108,14 +102,14 @@
 Base64 decode: 2 digits, 3 equals (bad)
 mbedtls_base64_decode:"Yw===":"c":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
-Base64 decode: 3 digits, 0 equals (sloppily accepted)
-mbedtls_base64_decode:"Y28":"!!!":0
+Base64 decode: 3 digits, 0 equals (bad)
+mbedtls_base64_decode:"Y28":"co":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
 Base64 decode: 3 digits, 1 equals (good)
 mbedtls_base64_decode:"Y28=":"co":0
 
-Base64 decode: 3 digits, 2 equals (sloppily accepted)
-mbedtls_base64_decode:"Y28==":"co":0
+Base64 decode: 3 digits, 2 equals (bad)
+mbedtls_base64_decode:"Y28==":"co":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
 Base64 decode: 3 digits, 3 equals (bad)
 mbedtls_base64_decode:"Y28===":"co":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
@@ -123,11 +117,11 @@
 Base64 decode: 4 digits, 0 equals (good)
 mbedtls_base64_decode:"Y29t":"com":0
 
-Base64 decode: 4 digits, 1 equals (sloppily accepted)
-mbedtls_base64_decode:"Y29t=":"com":0
+Base64 decode: 4 digits, 1 equals (bad)
+mbedtls_base64_decode:"Y29t=":"com":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
-Base64 decode: 4 digits, 2 equals (sloppily accepted)
-mbedtls_base64_decode:"Y29t==":"com":0
+Base64 decode: 4 digits, 2 equals (bad)
+mbedtls_base64_decode:"Y29t==":"com":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
 Base64 decode: 4 digits, 3 equals (bad)
 mbedtls_base64_decode:"Y29t===":"com":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
@@ -142,13 +136,13 @@
 mbedtls_base64_decode:"Y29tc==":"":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
 Base64 decode: 5 digits, 3 equals (bad)
-mbedtls_base64_decode:"Y29tc===":"com!":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
+mbedtls_base64_decode:"Y29tc===":"com":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
-Base64 decode: 6 digits, 0 equals (sloppily accepted)
-mbedtls_base64_decode:"Y29tcA":"com!!":0
+Base64 decode: 6 digits, 0 equals (bad)
+mbedtls_base64_decode:"Y29tcA":"comp":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
-Base64 decode: 6 digits, 1 equals (sloppily accepted)
-mbedtls_base64_decode:"Y29tcA=":"com!!":0
+Base64 decode: 6 digits, 1 equals (bad)
+mbedtls_base64_decode:"Y29tcA=":"comp":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
 Base64 decode: 6 digits, 2 equals (good)
 mbedtls_base64_decode:"Y29tcA==":"comp":0
@@ -156,14 +150,14 @@
 Base64 decode: 6 digits, 3 equals (bad)
 mbedtls_base64_decode:"Y29tcA===":"comp":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
-Base64 decode: 7 digits, 0 equals (sloppily accepted)
-mbedtls_base64_decode:"Y29tcG8":"com!!!":0
+Base64 decode: 7 digits, 0 equals (bad)
+mbedtls_base64_decode:"Y29tcG8":"compo":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
 Base64 decode: 7 digits, 1 equals (good)
 mbedtls_base64_decode:"Y29tcG8=":"compo":0
 
-Base64 decode: 7 digits, 2 equals (sloppily accepted)
-mbedtls_base64_decode:"Y29tcG8==":"compo":0
+Base64 decode: 7 digits, 2 equals (bad)
+mbedtls_base64_decode:"Y29tcG8==":"compo":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
 Base64 decode: 7 digits, 3 equals (bad)
 mbedtls_base64_decode:"Y29tcG8===":"compo":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
@@ -171,11 +165,11 @@
 Base64 decode: 8 digits, 0 equals (good)
 mbedtls_base64_decode:"Y29tcG9z":"compos":0
 
-Base64 decode: 8 digits, 1 equals (sloppily accepted)
-mbedtls_base64_decode:"Y29tcG9z=":"compos":0
+Base64 decode: 8 digits, 1 equals (bad)
+mbedtls_base64_decode:"Y29tcG9z=":"compos":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
-Base64 decode: 8 digits, 2 equals (sloppily accepted)
-mbedtls_base64_decode:"Y29tcG9z==":"compos":0
+Base64 decode: 8 digits, 2 equals (bad)
+mbedtls_base64_decode:"Y29tcG9z==":"compos":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
 
 Base64 decode: 8 digits, 3 equals (bad)
 mbedtls_base64_decode:"Y29tcG9z===":"compos":MBEDTLS_ERR_BASE64_INVALID_CHARACTER
diff --git a/tests/suites/test_suite_base64.function b/tests/suites/test_suite_base64.function
index 182be29..8c948b4 100644
--- a/tests/suites/test_suite_base64.function
+++ b/tests/suites/test_suite_base64.function
@@ -99,25 +99,12 @@
 
     TEST_CALLOC(dst, dst_size);
 
-    /* Validate broken behavior observed on Mbed TLS 3.6.3:
-     * some invalid inputs are accepted, and asking for the decoded length
-     * gives a figure that's longer than the decoded output.
-     * In the test data, trailing "!" characters in dst_string indicate
-     * padding that must be present in the output buffer length, but
-     * will not be present in the actual output when the output buffer
-     * is large enough.
-     */
-    size_t expected_dst_len = correct_dst_len;
-    while (expected_dst_len > 0 && dst_string[expected_dst_len - 1] == '!') {
-        --expected_dst_len;
-    }
-
     /* Test normal operation */
     TEST_EQUAL(mbedtls_base64_decode(dst, dst_size, &len,
                                      src, src_len),
                result);
     if (result == 0) {
-        TEST_MEMORY_COMPARE(dst_string, expected_dst_len, dst, len);
+        TEST_MEMORY_COMPARE(dst_string, correct_dst_len, dst, len);
     }
 
     /* Test an output buffer that's one byte too small */