Merge pull request #7579 from daverodgman/safer-ct-asm

Arm assembly implementation of constant-time primitives
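
The change builds on the module's existing convention that a condition is a
full-width mask: MBEDTLS_CT_TRUE is all-ones and MBEDTLS_CT_FALSE is all-zero,
so values can be selected with AND/OR instead of branches. As orientation, a
minimal standalone sketch of that idiom (ct_mask and ct_select are hypothetical
names, and unlike the real code this makes no attempt to hide values from the
compiler via mbedtls_ct_compiler_opaque):

    #include <stdint.h>

    /* All-ones if x != 0, all-zero if x == 0, computed without branching. */
    static uint32_t ct_mask(uint32_t x)
    {
        uint32_t top = (uint32_t) (-x | -(x >> 1)) >> 31;   /* 1 iff x != 0 */
        return (uint32_t) -top;                             /* spread to all 32 bits */
    }

    /* Branchless select: if1 when mask is all-ones, if0 when it is all-zero. */
    static uint32_t ct_select(uint32_t mask, uint32_t if1, uint32_t if0)
    {
        return (if1 & mask) | (if0 & ~mask);
    }

The asm versions of mbedtls_ct_bool, mbedtls_ct_if and mbedtls_ct_uint_lt added
below compute masks of this kind directly in registers, so the compiler cannot
turn them back into branches or conditional instructions.
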
diff --git a/library/constant_time.c b/library/constant_time.c
index 12aed13..d3c69cf 100644
--- a/library/constant_time.c
+++ b/library/constant_time.c
@@ -150,8 +150,13 @@
                           const unsigned char *src2,
                           size_t len)
 {
+#if defined(MBEDTLS_CT_SIZE_64)
+    const uint64_t mask     = (uint64_t) condition;
+    const uint64_t not_mask = (uint64_t) ~mbedtls_ct_compiler_opaque(condition);
+#else
     const uint32_t mask     = (uint32_t) condition;
     const uint32_t not_mask = (uint32_t) ~mbedtls_ct_compiler_opaque(condition);
+#endif
 
     /* If src2 is NULL, set up src2 so that we read from the destination address.
      *
@@ -165,11 +170,19 @@
     /* dest[i] = c1 == c2 ? src[i] : dest[i] */
     size_t i = 0;
 #if defined(MBEDTLS_EFFICIENT_UNALIGNED_ACCESS)
+#if defined(MBEDTLS_CT_SIZE_64)
+    for (; (i + 8) <= len; i += 8) {
+        uint64_t a = mbedtls_get_unaligned_uint64(src1 + i) & mask;
+        uint64_t b = mbedtls_get_unaligned_uint64(src2 + i) & not_mask;
+        mbedtls_put_unaligned_uint64(dest + i, a | b);
+    }
+#else
     for (; (i + 4) <= len; i += 4) {
         uint32_t a = mbedtls_get_unaligned_uint32(src1 + i) & mask;
         uint32_t b = mbedtls_get_unaligned_uint32(src2 + i) & not_mask;
         mbedtls_put_unaligned_uint32(dest + i, a | b);
     }
+#endif /* defined(MBEDTLS_CT_SIZE_64) */
 #endif /* MBEDTLS_EFFICIENT_UNALIGNED_ACCESS */
     for (; i < len; i++) {
         dest[i] = (src1[i] & mask) | (src2[i] & not_mask);
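
With MBEDTLS_CT_SIZE_64 defined, the masked-copy loop above moves 8 bytes per
iteration through the unaligned 64-bit load/store helpers, and the byte-wise
tail loop still covers lengths that are not a multiple of 8. A rough standalone
equivalent of the widened loop, with memcpy standing in for
mbedtls_get_unaligned_uint64 / mbedtls_put_unaligned_uint64 (select_words is a
hypothetical name, illustration only):

    #include <stdint.h>
    #include <stddef.h>
    #include <string.h>

    /* dest[i] = a[i] where mask is all-ones, b[i] where mask is all-zero. */
    static void select_words(uint8_t *dest, const uint8_t *a, const uint8_t *b,
                             uint64_t mask, size_t len)
    {
        size_t i = 0;
        for (; i + 8 <= len; i += 8) {       /* 8 bytes per iteration */
            uint64_t wa, wb, w;
            memcpy(&wa, a + i, 8);
            memcpy(&wb, b + i, 8);
            w = (wa & mask) | (wb & ~mask);
            memcpy(dest + i, &w, 8);
        }
        for (; i < len; i++) {               /* byte-wise tail */
            dest[i] = (uint8_t) ((a[i] & mask) | (b[i] & ~mask));
        }
    }

Because the mask is all-ones or all-zero in every byte, the word-wise and
byte-wise iterations give identical output; the only effect on 64-bit targets
is fewer loop iterations per call.
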
diff --git a/library/constant_time_impl.h b/library/constant_time_impl.h
index b251a66..8da15a8 100644
--- a/library/constant_time_impl.h
+++ b/library/constant_time_impl.h
@@ -48,8 +48,14 @@
     #pragma GCC diagnostic ignored "-Wredundant-decls"
 #endif
 
-/* Disable asm under Memsan because it confuses Memsan and generates false errors */
-#if defined(MBEDTLS_TEST_CONSTANT_FLOW_MEMSAN)
+/* Disable asm under Memsan because it confuses Memsan and generates false errors.
+ *
+ * We also disable under Valgrind by default, because it's more useful
+ * for Valgrind to test the plain C implementation. MBEDTLS_TEST_CONSTANT_FLOW_ASM //no-check-names
+ * may be set to permit building asm under Valgrind.
+ */
+#if defined(MBEDTLS_TEST_CONSTANT_FLOW_MEMSAN) || \
+    (defined(MBEDTLS_TEST_CONSTANT_FLOW_VALGRIND) && !defined(MBEDTLS_TEST_CONSTANT_FLOW_ASM)) //no-check-names
 #define MBEDTLS_CT_NO_ASM
 #elif defined(__has_feature)
 #if __has_feature(memory_sanitizer)
@@ -109,6 +115,28 @@
 #endif
 }
 
+/*
+ * Selecting unified syntax is needed for gcc, and harmless on clang.
+ *
+ * This is needed because on Thumb 1, condition flags are always set, so
+ * e.g. "negs" is supported but "neg" is not (on Thumb 2, both exist).
+ *
+ * Under Thumb 1 unified syntax, only the "negs" form is accepted, and
+ * under divided syntax, only the "neg" form is accepted. clang only
+ * supports unified syntax.
+ *
+ * On Thumb 2 and Arm, both compilers are happy with the "s" suffix,
+ * although we don't actually care about setting the flags.
+ *
+ * For gcc, restore divided syntax afterwards - otherwise old versions of gcc
+ * seem to apply unified syntax globally, which breaks other asm code.
+ */
+#if !defined(__clang__)
+#define RESTORE_ASM_SYNTAX  ".syntax divided             \n\t"
+#else
+#define RESTORE_ASM_SYNTAX
+#endif
+
 /* Convert a number into a condition in constant time. */
 static inline mbedtls_ct_condition_t mbedtls_ct_bool(mbedtls_ct_uint_t x)
 {
@@ -120,6 +148,34 @@
      * Otherwise, we define a plain C fallback which (in May 2023) does not get optimised into
      * conditional instructions or branches by trunk clang, gcc, or MSVC v19.
      */
+#if defined(MBEDTLS_CT_AARCH64_ASM) && (defined(MBEDTLS_CT_SIZE_32) || defined(MBEDTLS_CT_SIZE_64))
+    mbedtls_ct_uint_t s;
+    asm volatile ("neg %x[s], %x[x]                     \n\t"
+                  "orr %x[x], %x[s], %x[x]              \n\t"
+                  "asr %x[x], %x[x], 63"
+                  :
+                  [s] "=&r" (s),
+                  [x] "+&r" (x)
+                  :
+                  :
+                  );
+    return (mbedtls_ct_condition_t) x;
+#elif defined(MBEDTLS_CT_ARM_ASM) && defined(MBEDTLS_CT_SIZE_32)
+    uint32_t s;
+    asm volatile (".syntax unified                       \n\t"
+                  "negs %[s], %[x]                       \n\t"
+                  "orrs %[x], %[x], %[s]                 \n\t"
+                  "asrs %[x], %[x], #31                  \n\t"
+                  RESTORE_ASM_SYNTAX
+                  :
+                  [s] "=&l" (s),
+                  [x] "+&l" (x)
+                  :
+                  :
+                  "cc" /* clobbers flag bits */
+                  );
+    return (mbedtls_ct_condition_t) x;
+#else
     const mbedtls_ct_uint_t xo = mbedtls_ct_compiler_opaque(x);
 #if defined(_MSC_VER)
     /* MSVC has a warning about unary minus on unsigned, but this is
@@ -127,24 +183,98 @@
 #pragma warning( push )
 #pragma warning( disable : 4146 )
 #endif
-    return (mbedtls_ct_condition_t) (((mbedtls_ct_int_t) ((-xo) | -(xo >> 1))) >>
-                                     (MBEDTLS_CT_SIZE - 1));
+    // y is negative (i.e., top bit set) iff x is non-zero
+    mbedtls_ct_int_t y = (-xo) | -(xo >> 1);
+
+    // extract only the sign bit of y so that y == 1 (if x is non-zero) or 0 (if x is zero)
+    y = (((mbedtls_ct_uint_t) y) >> (MBEDTLS_CT_SIZE - 1));
+
+    // -y has all bits set (if x is non-zero), or all bits clear (if x is zero)
+    return (mbedtls_ct_condition_t) (-y);
 #if defined(_MSC_VER)
 #pragma warning( pop )
 #endif
+#endif
 }
 
 static inline mbedtls_ct_uint_t mbedtls_ct_if(mbedtls_ct_condition_t condition,
                                               mbedtls_ct_uint_t if1,
                                               mbedtls_ct_uint_t if0)
 {
+#if defined(MBEDTLS_CT_AARCH64_ASM) && (defined(MBEDTLS_CT_SIZE_32) || defined(MBEDTLS_CT_SIZE_64))
+    asm volatile ("and %x[if1], %x[if1], %x[condition]       \n\t"
+                  "mvn %x[condition], %x[condition]          \n\t"
+                  "and %x[condition], %x[condition], %x[if0] \n\t"
+                  "orr %x[condition], %x[if1], %x[condition]"
+                  :
+                  [condition] "+&r" (condition),
+                  [if1] "+&r" (if1)
+                  :
+                  [if0] "r" (if0)
+                  :
+                  );
+    return (mbedtls_ct_uint_t) condition;
+#elif defined(MBEDTLS_CT_ARM_ASM) && defined(MBEDTLS_CT_SIZE_32)
+    asm volatile (".syntax unified                           \n\t"
+                  "ands %[if1], %[if1], %[condition]         \n\t"
+                  "mvns %[condition], %[condition]           \n\t"
+                  "ands %[condition], %[condition], %[if0]   \n\t"
+                  "orrs %[condition], %[if1], %[condition]   \n\t"
+                  RESTORE_ASM_SYNTAX
+                  :
+                  [condition] "+&l" (condition),
+                  [if1] "+&l" (if1)
+                  :
+                  [if0] "l" (if0)
+                  :
+                  "cc"
+                  );
+    return (mbedtls_ct_uint_t) condition;
+#else
     mbedtls_ct_condition_t not_cond =
         (mbedtls_ct_condition_t) (~mbedtls_ct_compiler_opaque(condition));
     return (mbedtls_ct_uint_t) ((condition & if1) | (not_cond & if0));
+#endif
 }
 
 static inline mbedtls_ct_condition_t mbedtls_ct_uint_lt(mbedtls_ct_uint_t x, mbedtls_ct_uint_t y)
 {
+#if defined(MBEDTLS_CT_AARCH64_ASM) && (defined(MBEDTLS_CT_SIZE_32) || defined(MBEDTLS_CT_SIZE_64))
+    uint64_t s1;
+    asm volatile ("eor     %x[s1], %x[y], %x[x]          \n\t"
+                  "sub     %x[x], %x[x], %x[y]           \n\t"
+                  "bic     %x[x], %x[x], %x[s1]          \n\t"
+                  "and     %x[s1], %x[s1], %x[y]         \n\t"
+                  "orr     %x[s1], %x[x], %x[s1]         \n\t"
+                  "asr     %x[x], %x[s1], 63"
+                  : [s1] "=&r" (s1), [x] "+&r" (x)
+                  : [y] "r" (y)
+                  :
+                  );
+    return (mbedtls_ct_condition_t) x;
+#elif defined(MBEDTLS_CT_ARM_ASM) && defined(MBEDTLS_CT_SIZE_32)
+    uint32_t s1;
+    asm volatile (
+        ".syntax unified                    \n\t"
+#if defined(__thumb__) && !defined(__thumb2__)
+        "movs     %[s1], %[x]               \n\t"
+        "eors     %[s1], %[s1], %[y]        \n\t"
+#else
+        "eors     %[s1], %[x], %[y]         \n\t"
+#endif
+        "subs    %[x], %[x], %[y]           \n\t"
+        "bics    %[x], %[x], %[s1]          \n\t"
+        "ands    %[y], %[s1], %[y]          \n\t"
+        "orrs    %[x], %[x], %[y]           \n\t"
+        "asrs    %[x], %[x], #31            \n\t"
+        RESTORE_ASM_SYNTAX
+        : [s1] "=&l" (s1), [x] "+&l" (x),  [y] "+&l" (y)
+        :
+        :
+        "cc"
+        );
+    return (mbedtls_ct_condition_t) x;
+#else
     /* Ensure that the compiler cannot optimise the following operations over x and y,
      * even if it knows the value of x and y.
      */
@@ -173,6 +303,7 @@
 
     // Convert to a condition (i.e., all bits set iff non-zero)
     return mbedtls_ct_bool(ret);
+#endif
 }
 
 static inline mbedtls_ct_condition_t mbedtls_ct_uint_ne(mbedtls_ct_uint_t x, mbedtls_ct_uint_t y)
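
The new asm implements two classic bit tricks: mbedtls_ct_bool spreads the sign
bit of -x | x (the C fallback uses -x | -(x >> 1) to the same effect), and
mbedtls_ct_uint_lt reads unsigned x < y from the sign bit of
((x - y) & ~(x ^ y)) | ((x ^ y) & y). A small self-contained sanity check of
those identities (hypothetical names, plain C, deliberately not constant-time -
it only confirms the arithmetic, and on MSVC it would trip the same C4146
warning about unary minus on unsigned that the library suppresses):

    #include <assert.h>
    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>

    /* All-ones iff x != 0 (mirrors the neg/orr/asr sequence above). */
    static uint64_t ct_bool64(uint64_t x)
    {
        return (uint64_t) -((x | -x) >> 63);
    }

    /* All-ones iff x < y, unsigned (mirrors the eor/sub/bic/and/orr/asr sequence). */
    static uint64_t ct_lt64(uint64_t x, uint64_t y)
    {
        uint64_t d = x ^ y;
        return (uint64_t) -((((x - y) & ~d) | (d & y)) >> 63);
    }

    int main(void)
    {
        const uint64_t v[] = { 0, 1, 2, 0x7fffffffffffffffULL,
                               0x8000000000000000ULL, UINT64_MAX };
        for (size_t i = 0; i < sizeof(v) / sizeof(v[0]); i++) {
            assert(ct_bool64(v[i]) == (v[i] != 0 ? UINT64_MAX : 0));
            for (size_t j = 0; j < sizeof(v) / sizeof(v[0]); j++) {
                assert(ct_lt64(v[i], v[j]) == (v[i] < v[j] ? UINT64_MAX : 0));
            }
        }
        printf("bit tricks match direct comparisons\n");
        return 0;
    }

Without the mbedtls_ct_compiler_opaque barrier used by the real fallback, a
compiler is free to recognise these expressions and emit conditional
instructions again - which is exactly the risk the asm implementations remove.
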
diff --git a/library/constant_time_internal.h b/library/constant_time_internal.h
index dabf720..44b74ae 100644
--- a/library/constant_time_internal.h
+++ b/library/constant_time_internal.h
@@ -85,12 +85,14 @@
 typedef uint64_t  mbedtls_ct_condition_t;
 typedef uint64_t  mbedtls_ct_uint_t;
 typedef int64_t   mbedtls_ct_int_t;
+#define MBEDTLS_CT_SIZE_64
 #define MBEDTLS_CT_TRUE  ((mbedtls_ct_condition_t) mbedtls_ct_compiler_opaque(UINT64_MAX))
 #else
 /* Pointer size <= 32-bit, and no 64-bit MPIs */
 typedef uint32_t  mbedtls_ct_condition_t;
 typedef uint32_t  mbedtls_ct_uint_t;
 typedef int32_t   mbedtls_ct_int_t;
+#define MBEDTLS_CT_SIZE_32
 #define MBEDTLS_CT_TRUE  ((mbedtls_ct_condition_t) mbedtls_ct_compiler_opaque(UINT32_MAX))
 #endif
 #define MBEDTLS_CT_FALSE ((mbedtls_ct_condition_t) mbedtls_ct_compiler_opaque(0))
diff --git a/tests/scripts/all.sh b/tests/scripts/all.sh
index 8e978ac..c3c1275 100755
--- a/tests/scripts/all.sh
+++ b/tests/scripts/all.sh
@@ -1872,6 +1872,16 @@
     export SKIP_TEST_SUITES
 }
 
+skip_all_except_given_suite () {
+    # Skip all but the given test suite
+    SKIP_TEST_SUITES=$(
+        ls -1 tests/suites/test_suite_*.function |
+        grep -v $1.function |
+        sed 's/tests.suites.test_suite_//; s/\.function$//' |
+        tr '\n' ,)
+    export SKIP_TEST_SUITES
+}
+
 component_test_memsan_constant_flow () {
     # This tests both (1) accesses to undefined memory, and (2) branches or
     # memory access depending on secret values. To distinguish between those:
@@ -1931,6 +1941,16 @@
     # details are left in Testing/<date>/DynamicAnalysis.xml
     msg "test: some suites (full minus MBEDTLS_USE_PSA_CRYPTO, valgrind + constant flow)"
     make memcheck
+
+    # Test the asm path of the constant-time module - by default, the plain C path
+    # is tested under Valgrind or Memsan. Running only the constant_time tests is fast (<1s).
+    msg "test: valgrind asm constant_time"
+    scripts/config.py --force set MBEDTLS_TEST_CONSTANT_FLOW_ASM
+    skip_all_except_given_suite test_suite_constant_time
+    cmake -D CMAKE_BUILD_TYPE:String=Release .
+    make clean
+    make
+    make memcheck
 }
 
 component_test_valgrind_constant_flow_psa () {