Merge pull request #3638 from ARMmbed/better-cf-padding-checks

  Better constant-flow idioms for TLS-CBC padding checks
diff --git a/library/ssl_msg.c b/library/ssl_msg.c
index 2ea3580..981d94e 100644
--- a/library/ssl_msg.c
+++ b/library/ssl_msg.c
@@ -1045,21 +1045,86 @@
 
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_TLS_CBC)
 /*
- * Constant-flow conditional memcpy:
- *  - if c1 == c2, equivalent to memcpy(dst, src, len),
- *  - otherwise, a no-op,
- * but with execution flow independent of the values of c1 and c2.
+ * Turn a bit into a mask:
+ * - if bit == 1, return the all-bits 1 mask, aka (size_t) -1
+ * - if bit == 0, return the all-bits 0 mask, aka 0
  *
- * Use only bit operations to avoid branches that could be used by some
- * compilers on some platforms to translate comparison operators.
+ * This function can be used to write constant-time code by replacing branches
+ * with bit operations using masks.
+ *
+ * This function is implemented without using comparison operators, as those
+ * might be translated to branches by some compilers on some platforms.
  */
-static void mbedtls_ssl_cf_memcpy_if_eq( unsigned char *dst,
-                                         const unsigned char *src,
-                                         size_t len,
-                                         size_t c1, size_t c2 )
+static size_t mbedtls_ssl_cf_mask_from_bit( size_t bit )
 {
-    /* diff = 0 if c1 == c2, non-zero otherwise */
-    const size_t diff = c1 ^ c2;
+    /* MSVC has a warning about unary minus on unsigned integer types,
+     * but this is well-defined and precisely what we want to do here. */
+#if defined(_MSC_VER)
+#pragma warning( push )
+#pragma warning( disable : 4146 )
+#endif
+    return -bit;
+#if defined(_MSC_VER)
+#pragma warning( pop )
+#endif
+}
+
+/*
+ * Constant-flow mask generation for "less than" comparison:
+ * - if x < y,  return all bits 1, that is (size_t) -1
+ * - otherwise, return all bits 0, that is 0
+ * The result is exact for every pair of size_t values, however far apart.
+ *
+ * This function can be used to write constant-time code by replacing branches
+ * with bit operations using masks. It is implemented without comparison
+ * operators, as those might be translated to branches by some compilers on
+ * some platforms.
+ */
+static size_t mbedtls_ssl_cf_mask_lt( size_t x, size_t y )
+{
+    /* Borrow of x - y: unlike the MSB of x - y alone, exact for all x, y */
+    const size_t borrow = ( ~x & y ) | ( ~( x ^ y ) & ( x - y ) );
+
+    /* lt = (x < y) ? 1 : 0, taken from the top bit of borrow */
+    const size_t lt = borrow >> ( sizeof( borrow ) * 8 - 1 );
+
+    /* mask = (x < y) ? 0xff... : 0x00... */
+    const size_t mask = mbedtls_ssl_cf_mask_from_bit( lt );
+
+    return( mask );
+}
+
+/*
+ * Constant-flow mask generation for "greater or equal" comparison:
+ * - if x >= y, return all bits 1, that is (size_t) -1
+ * - otherwise, return all bits 0, that is 0
+ *
+ * This function can be used to write constant-time code by replacing branches
+ * with bit operations using masks.
+ *
+ * It is the bitwise complement of mbedtls_ssl_cf_mask_lt( x, y ), so it
+ * contains no comparison operators that a compiler might turn into branches.
+ */
+static size_t mbedtls_ssl_cf_mask_ge( size_t x, size_t y )
+{
+    return( ~mbedtls_ssl_cf_mask_lt( x, y ) );
+}
+
+/*
+ * Constant-flow boolean "equal" comparison:
+ * return x == y
+ *
+ * This function can be used to write constant-time code by replacing branches
+ * with bit operations - it can be used in conjunction with
+ * mbedtls_ssl_cf_mask_from_bit().
+ *
+ * This function is implemented without using comparison operators, as those
+ * might be translated to branches by some compilers on some platforms.
+ */
+static size_t mbedtls_ssl_cf_bool_eq( size_t x, size_t y )
+{
+    /* diff = 0 if x == y, non-zero otherwise */
+    const size_t diff = x ^ y;
 
     /* MSVC has a warning about unary minus on unsigned integer types,
      * but this is well-defined and precisely what we want to do here. */
@@ -1068,22 +1133,40 @@
 #pragma warning( disable : 4146 )
 #endif
 
-    /* diff_msb's most significant bit is equal to c1 != c2 */
+    /* diff_msb's most significant bit is equal to x != y */
     const size_t diff_msb = ( diff | -diff );
 
-    /* diff1 = c1 != c2 */
-    const size_t diff1 = diff_msb >> ( sizeof( diff_msb ) * 8 - 1 );
-
-    /* mask = c1 != c2 ? 0xff : 0x00 */
-    const unsigned char mask = (unsigned char) -diff1;
-
 #if defined(_MSC_VER)
 #pragma warning( pop )
 #endif
 
-    /* dst[i] = c1 != c2 ? dst[i] : src[i] */
+    /* diff1 = (x != y) ? 1 : 0 */
+    const size_t diff1 = diff_msb >> ( sizeof( diff_msb ) * 8 - 1 );
+
+    return( 1 ^ diff1 );
+}
+
+/*
+ * Constant-flow conditional memcpy:
+ *  - if c1 == c2, equivalent to memcpy(dst, src, len),
+ *  - otherwise, a no-op,
+ * but with execution flow independent of the values of c1 and c2.
+ *
+ * This function is implemented without using comparison operators, as those
+ * might be translated to branches by some compilers on some platforms.
+ */
+static void mbedtls_ssl_cf_memcpy_if_eq( unsigned char *dst,
+                                         const unsigned char *src,
+                                         size_t len,
+                                         size_t c1, size_t c2 )
+{
+    /* mask = c1 == c2 ? 0xff : 0x00 */
+    const size_t equal = mbedtls_ssl_cf_bool_eq( c1, c2 );
+    const unsigned char mask = (unsigned char) mbedtls_ssl_cf_mask_from_bit( equal );
+
+    /* dst[i] = c1 == c2 ? src[i] : dst[i] */
     for( size_t i = 0; i < len; i++ )
-        dst[i] = ( dst[i] & mask ) | ( src[i] & ~mask );
+        dst[i] = ( src[i] & mask ) | ( dst[i] & ~mask );
 }
 
 /*
@@ -1528,8 +1611,11 @@
 
         if( auth_done == 1 )
         {
-            correct *= ( rec->data_len >= padlen + 1 );
-            padlen  *= ( rec->data_len >= padlen + 1 );
+            const size_t mask = mbedtls_ssl_cf_mask_ge(
+                                rec->data_len,
+                                padlen + 1 );
+            correct &= mask;
+            padlen  &= mask;
         }
         else
         {
@@ -1543,8 +1629,11 @@
             }
 #endif
 
-            correct *= ( rec->data_len >= transform->maclen + padlen + 1 );
-            padlen  *= ( rec->data_len >= transform->maclen + padlen + 1 );
+            const size_t mask = mbedtls_ssl_cf_mask_ge(
+                                rec->data_len,
+                                transform->maclen + padlen + 1 );
+            correct &= mask;
+            padlen  &= mask;
         }
 
         padlen++;
@@ -1555,6 +1644,10 @@
 #if defined(MBEDTLS_SSL_PROTO_SSL3)
         if( transform->minor_ver == MBEDTLS_SSL_MINOR_VERSION_0 )
         {
+            /* This is the SSL 3.0 path. We don't have to worry about Lucky
+             * 13 here: a strictly worse padding attack (part of POODLE) is
+             * built into the protocol itself, so this code need not be
+             * constant-time; in particular, branches are OK. */
             if( padlen > transform->ivlen )
             {
 #if defined(MBEDTLS_SSL_DEBUG_ALL)
@@ -1578,7 +1671,6 @@
              * `min(256,plaintext_len)` reads (but take into account
              * only the last `padlen` bytes for the padding check). */
             size_t pad_count = 0;
-            size_t real_count = 0;
             volatile unsigned char* const check = data;
 
             /* Index of first padding byte; it has been ensured above
@@ -1590,16 +1682,21 @@
 
             for( idx = start_idx; idx < rec->data_len; idx++ )
             {
-                real_count |= ( idx >= padding_idx );
-                pad_count += real_count * ( check[idx] == padlen - 1 );
+                /* pad_count += (idx >= padding_idx) &&
+                 *              (check[idx] == padlen - 1);
+                 */
+                const size_t mask = mbedtls_ssl_cf_mask_ge( idx, padding_idx );
+                const size_t equal = mbedtls_ssl_cf_bool_eq( check[idx],
+                                                             padlen - 1 );
+                pad_count += mask & equal;
             }
-            correct &= ( pad_count == padlen );
+            correct &= mbedtls_ssl_cf_bool_eq( pad_count, padlen );
 
 #if defined(MBEDTLS_SSL_DEBUG_ALL)
             if( padlen > 0 && correct == 0 )
                 MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad padding byte detected" ) );
 #endif
-            padlen &= correct * 0x1FF;
+            padlen &= mbedtls_ssl_cf_mask_from_bit( correct );
         }
         else
 #endif /* MBEDTLS_SSL_PROTO_TLS1 || MBEDTLS_SSL_PROTO_TLS1_1 || \