masked-aes: fix CI problems

Signed-off-by: Shelly Liberman <shelly.liberman@arm.com>
diff --git a/library/aes.c b/library/aes.c
index ea6a69d..f01c738 100644
--- a/library/aes.c
+++ b/library/aes.c
@@ -805,7 +805,7 @@
                 RK[5]  = RK[1] ^ RK[4];
                 RK[6]  = RK[2] ^ RK[5];
                 RK[7]  = RK[3] ^ RK[6];
-            }                
+            }
             break;
 #if !defined(MBEDTLS_AES_ONLY_128_BIT_KEY_LENGTH)
         case 12:
@@ -1197,7 +1197,7 @@
   volatile int flow_control = 0;
   unsigned int i = 0;
 
-  mbedtls_platform_memcpy(rk_masked, rk, AES_128_EXPANDED_KEY_SIZE_IN_WORDS*4);
+  mbedtls_platform_memcpy(rk_masked, rk, MBEDTLS_AES_128_EXPANDED_KEY_SIZE_IN_WORDS*4);
 
 
   //Randomly generate the masks: m1 m2 m3 m4 m m'
@@ -1213,7 +1213,7 @@
 
   //Calculate the masked Sbox
   if (calcSboxMasked(mask, sbox_masked) == 0){
-      flow_control++;  
+      flow_control++;
   }
 
 #define MASK_INIT_CONTROL 19
@@ -1386,7 +1386,7 @@
     uint8_t round_ctrl_table[( 14 + AES_SCA_CM_ROUNDS + 2 )];
 
 #if defined MBEDTLS_AES_128_BIT_MASKED
-    uint32_t rk_masked[AES_128_EXPANDED_KEY_SIZE_IN_WORDS] = {0};
+    uint32_t rk_masked[MBEDTLS_AES_128_EXPANDED_KEY_SIZE_IN_WORDS] = {0};
     static uint8_t sbox_masked[256] = {0};
     uint32_t mask[10] = {0};
 #endif
@@ -1495,7 +1495,7 @@
             aes_data_ptr->xy_values[6 - offset],
             aes_data_ptr->xy_values[7 - offset] );
         flow_control++;
-#endif        
+#endif
         tindex++;
 
     } while( stop_mark == 0 );
@@ -1510,7 +1510,7 @@
                                      aes_data_ptr->rk_ptr, sbox_masked ) == 0)
             flow_control++;
         //Cleanup the masked key
-        mbedtls_platform_memset(rk_masked, 0, sizeof(rk_masked));              
+        mbedtls_platform_memset(rk_masked, 0, sizeof(rk_masked));
 #else
         aes_fround_final( aes_data_ptr->rk_ptr,
             &aes_data_ptr->xy_values[0],
diff --git a/library/ccm.c b/library/ccm.c
index e54a995..aa15af2 100644
--- a/library/ccm.c
+++ b/library/ccm.c
@@ -114,6 +114,41 @@
     mbedtls_platform_zeroize( ctx, sizeof( mbedtls_ccm_context ) );
 }
 
+/* Durstenfeld's version of the Fisher-Yates shuffle: fill table[0..size-1]
+ * with a random permutation of 0..size-1 (size must not exceed 256, as the
+ * entries are stored as unsigned char). */
+static void mbedtls_generate_permutation( unsigned char *table, size_t size )
+{
+    size_t i, j;
+
+    for( i = 0; i < size; i++ )
+    {
+        table[i] = (unsigned char) i;
+    }
+
+    /* Count down from size rather than size - 1 so that size == 0 and
+     * size == 1 need no special case (no unsigned wrap-around). */
+    for( i = size; i > 1; i-- )
+    {
+        unsigned char tmp;
+        j = mbedtls_platform_random_in_range( (uint32_t) i );
+        tmp = table[i - 1];
+        table[i - 1] = table[j];
+        table[j] = tmp;
+    }
+}
+
+/* Fill table[0..size-1] with random mask bytes. */
+static void mbedtls_generate_masks( unsigned char *table, size_t size )
+{
+    size_t i;
+
+    for( i = 0; i < size; i++ )
+    {
+        table[i] = (unsigned char) mbedtls_platform_random_in_range( 256 );
+    }
+}
+
 /*
  * Macros for common operations.
  * Results in smaller compiled code than static inline functions.
@@ -122,30 +157,55 @@
 /*
  * Update the CBC-MAC state in y using a block in b
  * (Always using b as the source helps the compiler optimise a bit better.)
+ * b must already be XOR-masked with mask_table (done by the callers).
  */
 #define UPDATE_CBC_MAC                                                      \
     for( i = 0; i < 16; i++ )                                               \
-        y[i] ^= b[i];                                                       \
+    {                                                                       \
+        y[perm_table[i]] ^= b[perm_table[i]];                               \
+        y[perm_table[i]] ^= mask_table[perm_table[i]]; /* cancel b's mask */ \
+    }                                                                       \
                                                                             \
     if( ( ret = mbedtls_cipher_update( &ctx->cipher_ctx, y, 16, y, &olen ) ) != 0 ) \
         return( ret );
 
 /*
+ * Copy src to dst starting at a random offset, while masking the whole dst buffer.
+ */
+#define COPY_MASK( dst, src, mask, len_src, len_dst )                   \
+    do                                                                  \
+    {                                                                   \
+        unsigned copy_idx, copy_off = mbedtls_platform_random_in_range( 256 ); \
+        for( i = 0; i < (len_src); i++ )   /* start at a random index */ \
+        {                                                               \
+            copy_idx = ( i + copy_off ) % (len_src);                    \
+            (dst)[copy_idx] = (src)[copy_idx] ^ (mask)[copy_idx];       \
+        }                                                               \
+        for( ; i < (len_dst); i++ )        /* mask the tail of dst */   \
+            (dst)[i] ^= (mask)[i];                                      \
+    } while( 0 )
+/* Masked CTR: regenerates perm_table/mask_table[0..len-1] on every use.
  * Encrypt or decrypt a partial block with CTR
  * Warning: using b for temporary storage! src and dst must not be b!
  * This avoids allocating one more 16 bytes buffer while allowing src == dst.
  */
-#define CTR_CRYPT( dst, src, len  )                                            \
+#define CTR_CRYPT( dst, src, len )                                      \
     do                                                                  \
     {                                                                   \
+        mbedtls_generate_permutation( perm_table, (len) );              \
+        mbedtls_generate_masks( mask_table, (len) );                    \
         if( ( ret = mbedtls_cipher_update( &ctx->cipher_ctx, ctr,       \
                                            16, b, &olen ) ) != 0 )      \
         {                                                               \
             return( ret );                                              \
         }                                                               \
                                                                         \
-        for( i = 0; i < (len); i++ )                                    \
-            (dst)[i] = (src)[i] ^ b[i];                                 \
+        for( i = 0; i < (len); i++ )                                    \
+        {                                                               \
+            (dst)[perm_table[i]] = (src)[perm_table[i]] ^ mask_table[perm_table[i]]; \
+            (dst)[perm_table[i]] ^= b[perm_table[i]];                   \
+            (dst)[perm_table[i]] ^= mask_table[perm_table[i]];          \
+        }                                                               \
     } while( 0 )
 
 /*
@@ -164,6 +224,8 @@
     unsigned char b[16];
     unsigned char y[16];
     unsigned char ctr[16];
+    unsigned char perm_table[16];
+    unsigned char mask_table[16];
     const unsigned char *src;
     unsigned char *dst;
 
@@ -184,6 +246,10 @@
     if( add_len > 0xFF00 )
         return( MBEDTLS_ERR_CCM_BAD_INPUT );
 
+    mbedtls_platform_zeroize( b, 16 );
+    mbedtls_platform_zeroize( y, 16 );
+    mbedtls_platform_zeroize( ctr, 16 );
+
     q = (uint_fast8_t) (16 - 1 - iv_len);
 
     /*
@@ -198,15 +264,16 @@
      * 5 .. 3   (t - 2) / 2
      * 2 .. 0   q - 1
      */
-    b[0] = 0;
-    b[0] |= ( add_len > 0 ) << 6;
-    b[0] |= ( ( tag_len - 2 ) / 2 ) << 3;
-    b[0] |= q - 1;
+    mbedtls_generate_masks( mask_table, 16 );
+    mbedtls_generate_permutation( perm_table, 16 );
+    b[0] = (unsigned char) ( ( ( add_len > 0 ) << 6 ) |
+                           ( ( ( tag_len - 2 ) / 2 ) << 3 ) |
+                           ( q - 1 ) ) ^ mask_table[0];
 
-    mbedtls_platform_memcpy( b + 1, iv, iv_len );
-
+    for( i = 0; i < iv_len; i++ )
+        b[i+1] = iv[i] ^ mask_table[i+1];
     for( i = 0, len_left = length; i < q; i++, len_left >>= 8 )
-        b[15-i] = (unsigned char)( len_left & 0xFF );
+        b[15-i] = (unsigned char)( ( len_left & 0xFF ) ) ^ mask_table[15-i];
 
     if( len_left > 0 )
         return( MBEDTLS_ERR_CCM_BAD_INPUT );
@@ -226,12 +293,16 @@
         len_left = add_len;
         src = add;
 
+        mbedtls_generate_masks( mask_table, 16 );
+        mbedtls_generate_permutation( perm_table, 16 );
         mbedtls_platform_memset( b, 0, 16 );
-        b[0] = (unsigned char)( ( add_len >> 8 ) & 0xFF );
-        b[1] = (unsigned char)( ( add_len      ) & 0xFF );
+        b[0] = (unsigned char)( ( ( add_len >> 8 ) & 0xFF ) ^ mask_table[0] );
+        b[1] = (unsigned char)( ( ( add_len      ) & 0xFF ) ^ mask_table[1] );
 
         use_len = len_left < 16 - 2 ? len_left : 16 - 2;
-        mbedtls_platform_memcpy( b + 2, src, use_len );
+
+        COPY_MASK( b+2, src, mask_table+2, use_len, 14 );
+
         len_left -= use_len;
         src += use_len;
 
@@ -239,10 +310,12 @@
 
         while( len_left > 0 )
         {
+            mbedtls_generate_masks( mask_table, 16 );
+            mbedtls_generate_permutation( perm_table, 16 );
             use_len = len_left > 16 ? 16 : len_left;
 
             mbedtls_platform_memset( b, 0, 16 );
-            mbedtls_platform_memcpy( b, src, use_len );
+            COPY_MASK( b, src, mask_table, use_len, 16);
             UPDATE_CBC_MAC;
 
             len_left -= use_len;
@@ -281,8 +354,10 @@
 
         if( mode == CCM_ENCRYPT )
         {
+            mbedtls_generate_masks( mask_table, 16 );
+            mbedtls_generate_permutation( perm_table, 16 );
             mbedtls_platform_memset( b, 0, 16 );
-            mbedtls_platform_memcpy( b, src, use_len );
+            COPY_MASK( b, src, mask_table, use_len, 16 );
             UPDATE_CBC_MAC;
         }
 
@@ -290,8 +365,10 @@
 
         if( mode == CCM_DECRYPT )
         {
+            mbedtls_generate_masks( mask_table, 16 );
+            mbedtls_generate_permutation( perm_table, 16 );
             mbedtls_platform_memset( b, 0, 16 );
-            mbedtls_platform_memcpy( b, dst, use_len );
+            COPY_MASK( b, dst, mask_table, use_len, 16 );
             UPDATE_CBC_MAC;
         }
 
@@ -317,6 +394,10 @@
     CTR_CRYPT( y, y, 16 );
     mbedtls_platform_memcpy( tag, y, tag_len );
 
+    mbedtls_platform_zeroize( b, 16 );
+    mbedtls_platform_zeroize( y, 16 );
+    mbedtls_platform_zeroize( ctr, 16 );
+
     return( ret );
 }
 
diff --git a/library/platform_util.c b/library/platform_util.c
index 15309aa..ecfdb84 100644
--- a/library/platform_util.c
+++ b/library/platform_util.c
@@ -48,6 +48,12 @@
 #include "mbedtls/entropy_poll.h"
 #endif
 
+#if defined(MBEDTLS_PLATFORM_FAULT_CALLBACKS)
+#include "platform_fault.h"
+#else /* !MBEDTLS_PLATFORM_FAULT_CALLBACKS: fault reporting is a no-op */
+static void mbedtls_platform_fault( void ) {}
+#endif /* MBEDTLS_PLATFORM_FAULT_CALLBACKS */
+
 #include <stddef.h>
 #include <string.h>
 
@@ -119,43 +125,45 @@
 
 void *mbedtls_platform_memset( void *ptr, int value, size_t num )
 {
-    size_t i, start_offset;
+    size_t i, start_offset = 0;
     volatile size_t flow_counter = 0;
     volatile char *b = ptr;
     char rnd_data;
-
-    start_offset = (size_t) mbedtls_platform_random_in_range( (uint32_t) num );
-    rnd_data = (char) mbedtls_platform_random_in_range( 256 );
-
-    /* Perform a memset operations with random data and start from a random
-     * location */
-    for( i = start_offset; i < num; ++i )
+    if( num > 0 )
     {
-        b[i] = rnd_data;
-        flow_counter++;
-    }
+        start_offset = (size_t) mbedtls_platform_random_in_range( (uint32_t) num );
 
-    /* Start from a random location with target data */
-    for( i = start_offset; i < num; ++i )
-    {
-        b[i] = value;
-        flow_counter++;
-    }
+        rnd_data = (char) mbedtls_platform_random_in_range( 256 );
 
-    /* Second memset operation with random data */
-    for( i = 0; i < start_offset; ++i )
-    {
-        b[i] = rnd_data;
-        flow_counter++;
-    }
+        /* Perform a memset operations with random data and start from a random
+         * location */
+        for( i = start_offset; i < num; ++i )
+        {
+            b[i] = rnd_data;
+            flow_counter++;
+        }
 
-    /* Finish memset operation with correct data */
-    for( i = 0; i < start_offset; ++i )
-    {
-        b[i] = value;
-        flow_counter++;
-    }
+        /* Start from a random location with target data */
+        for( i = start_offset; i < num; ++i )
+        {
+            b[i] = value;
+            flow_counter++;
+        }
 
+        /* Second memset operation with random data */
+        for( i = 0; i < start_offset; ++i )
+        {
+            b[i] = rnd_data;
+            flow_counter++;
+        }
+
+        /* Finish memset operation with correct data */
+        for( i = 0; i < start_offset; ++i )
+        {
+            b[i] = value;
+            flow_counter++;
+        }
+    }
     /* check the correct number of iterations */
     if( flow_counter == 2 * num )
     {
@@ -165,6 +173,7 @@
             return ptr;
         }
     }
+    mbedtls_platform_fault();
     return NULL;
 }
 
@@ -204,6 +213,7 @@
             return dst;
         }
     }
+    mbedtls_platform_fault();
     return NULL;
 }
 
@@ -245,22 +255,25 @@
 
     /* Start from a random location and check the correct number of iterations */
     size_t i, flow_counter = 0;
-    size_t start_offset = (size_t) mbedtls_platform_random_in_range( (uint32_t) num );
-
-    for( i = start_offset; i < num; i++ )
+    size_t start_offset = 0;
+    if( num > 0 )
     {
-        unsigned char x = A[i], y = B[i];
-        flow_counter++;
-        diff |= x ^ y;
-    }
+        start_offset = (size_t) mbedtls_platform_random_in_range( (uint32_t) num );
 
-    for( i = 0; i < start_offset; i++ )
-    {
-        unsigned char x = A[i], y = B[i];
-        flow_counter++;
-        diff |= x ^ y;
-    }
+        for( i = start_offset; i < num; i++ )
+        {
+            unsigned char x = A[i], y = B[i];
+            flow_counter++;
+            diff |= x ^ y;
+        }
 
+        for( i = 0; i < start_offset; i++ )
+        {
+            unsigned char x = A[i], y = B[i];
+            flow_counter++;
+            diff |= x ^ y;
+        }
+    }
     /* Return 0 only when diff is 0 and flow_counter is equal to num */
     return( (int) diff | (int) ( flow_counter ^ num ) );
 }
@@ -340,18 +353,15 @@

 uint32_t mbedtls_platform_random_in_range( uint32_t num )
 {
-    uint32_t result;
-
-    if( num <= 1 )
-    {
-        result = 0;
-    }
-    else
-    {
-        result = mbedtls_platform_random_uint32() % num;
-    }
-
-    return( result );
+    /* Returns a value in [0, num) for num > 0. num == 0 would make the
+     * modulo below a division by zero (undefined behaviour in C), so return
+     * 0 for it, as the previous implementation did. */
+    if( num == 0 )
+    {
+        return( 0 );
+    }
+
+    return( mbedtls_platform_random_uint32() % num );
 }
 
 void mbedtls_platform_random_delay( void )
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index b0dabf2..a33760f 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -933,8 +933,8 @@
 {
     size_t nb;
     size_t i, j, k, md_len;
-    unsigned char tmp[128];
-    unsigned char h_i[MBEDTLS_MD_MAX_SIZE];
+    unsigned char tmp[128] = {0};
+    unsigned char h_i[MBEDTLS_MD_MAX_SIZE] = {0};
     mbedtls_md_handle_t md_info;
     mbedtls_md_context_t md_ctx;
     int ret;
@@ -12709,6 +12709,7 @@
 #endif
         mbedtls_platform_zeroize( ssl->out_buf, out_buf_len );
         mbedtls_free( ssl->out_buf );
+        ssl->out_buf = NULL;
     }
 
     if( ssl->in_buf != NULL )
@@ -12720,6 +12721,7 @@
 #endif
         mbedtls_platform_zeroize( ssl->in_buf, in_buf_len );
         mbedtls_free( ssl->in_buf );
+        ssl->in_buf = NULL;
     }
 
 #if defined(MBEDTLS_ZLIB_SUPPORT)
diff --git a/library/version_features.c b/library/version_features.c
index b8f1d26..beccd3f 100644
--- a/library/version_features.c
+++ b/library/version_features.c
@@ -273,12 +273,12 @@
 #if defined(MBEDTLS_AES_SCA_COUNTERMEASURES)
     "MBEDTLS_AES_SCA_COUNTERMEASURES",
 #endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
-#if defined(MBEDTLS_FI_COUNTERMEASURES)
-    "MBEDTLS_FI_COUNTERMEASURES",
-#endif /* MBEDTLS_FI_COUNTERMEASURES */
 #if defined(MBEDTLS_AES_128_BIT_MASKED)
     "MBEDTLS_AES_128_BIT_MASKED",
 #endif /* MBEDTLS_AES_128_BIT_MASKED */
+#if defined(MBEDTLS_FI_COUNTERMEASURES)
+    "MBEDTLS_FI_COUNTERMEASURES",
+#endif /* MBEDTLS_FI_COUNTERMEASURES */
 #if defined(MBEDTLS_CAMELLIA_SMALL_MEMORY)
     "MBEDTLS_CAMELLIA_SMALL_MEMORY",
 #endif /* MBEDTLS_CAMELLIA_SMALL_MEMORY */
@@ -732,6 +732,9 @@
 #if defined(MBEDTLS_MEMORY_BUFFER_ALLOC_C)
     "MBEDTLS_MEMORY_BUFFER_ALLOC_C",
 #endif /* MBEDTLS_MEMORY_BUFFER_ALLOC_C */
+#if defined(MBEDTLS_PLATFORM_FAULT_CALLBACKS)
+    "MBEDTLS_PLATFORM_FAULT_CALLBACKS",
+#endif /* MBEDTLS_PLATFORM_FAULT_CALLBACKS */
 #if defined(MBEDTLS_NET_C)
     "MBEDTLS_NET_C",
 #endif /* MBEDTLS_NET_C */