Merge pull request #3532 from AndrzejKurek/fi-hmac-drbg-fixes

Fi-related hmac_drbg fixes
diff --git a/configs/baremetal.h b/configs/baremetal.h
index 5294351..c93f53a 100644
--- a/configs/baremetal.h
+++ b/configs/baremetal.h
@@ -146,6 +146,9 @@
 
 #define MBEDTLS_DEPRECATED_REMOVED
 
+/* Fault Injection Countermeasures */
+#define MBEDTLS_FI_COUNTERMEASURES
+
 #if defined(MBEDTLS_USER_CONFIG_FILE)
 #include MBEDTLS_USER_CONFIG_FILE
 #endif
diff --git a/include/mbedtls/aes.h b/include/mbedtls/aes.h
index 6990be0..cb7d726 100644
--- a/include/mbedtls/aes.h
+++ b/include/mbedtls/aes.h
@@ -87,6 +87,9 @@
 {
     int nr;                     /*!< The number of rounds. */
     uint32_t *rk;               /*!< AES round keys. */
+#if defined(MBEDTLS_AES_SCA_COUNTERMEASURES)
+    uint32_t frk[8];            /*!< Fake AES round keys. */
+#endif
 #if defined(MBEDTLS_AES_ONLY_128_BIT_KEY_LENGTH) && !defined(MBEDTLS_PADLOCK_C)
     uint32_t buf[44];           /*!< Unaligned data buffer */
 #else /* MBEDTLS_AES_ONLY_128_BIT_KEY_LENGTH */
diff --git a/include/mbedtls/config.h b/include/mbedtls/config.h
index 4ee5920..9b88597 100644
--- a/include/mbedtls/config.h
+++ b/include/mbedtls/config.h
@@ -655,6 +655,16 @@
 //#define MBEDTLS_AES_SCA_COUNTERMEASURES
 
 /**
+ * \def MBEDTLS_FI_COUNTERMEASURES
+ *
+ * Add countermeasures against a possible FI attack.
+ *
+ * Uncommenting this macro increases code size and reduces performance:
+ * it performs double calls and double result checks of some crypto functions.
+ */
+//#define MBEDTLS_FI_COUNTERMEASURES
+
+/**
  * \def MBEDTLS_CAMELLIA_SMALL_MEMORY
  *
  * Use less ROM for the Camellia implementation (saves about 768 bytes).
diff --git a/include/mbedtls/sha256.h b/include/mbedtls/sha256.h
index 6ef2245..42aa988 100644
--- a/include/mbedtls/sha256.h
+++ b/include/mbedtls/sha256.h
@@ -118,7 +118,7 @@
  *                 and have a hash operation started.
  * \param input    The buffer holding the data. This must be a readable
  *                 buffer of length \p ilen Bytes.
- * \param ilen     The length of the input data in Bytes.
+ * \param ilen     The length of the input data in Bytes. At most UINT32_MAX.
  *
  * \return         \c 0 on success.
  * \return         A negative error code on failure.
diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h
index e14f58f..ee231a5 100644
--- a/include/mbedtls/ssl.h
+++ b/include/mbedtls/ssl.h
@@ -1460,6 +1460,10 @@
      *  after an initial handshake. */
     unsigned char own_cid[ MBEDTLS_SSL_CID_IN_LEN_MAX ];
 #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */
+#if defined(MBEDTLS_FI_COUNTERMEASURES)
+    unsigned char *out_msg_dup;     /*!< out msg ptr duplication  */
+    size_t out_msglen_dup;          /*!< out msg size duplication */
+#endif
 };
 
 #if defined(MBEDTLS_SSL_HW_RECORD_ACCEL)
diff --git a/include/tinycrypt/ecc.h b/include/tinycrypt/ecc.h
index 57aa508..4c20729 100644
--- a/include/tinycrypt/ecc.h
+++ b/include/tinycrypt/ecc.h
@@ -85,7 +85,7 @@
 /* Return values for functions, chosen with large Hamming distances between
  * them (especially to SUCESS) to mitigate the impact of fault injection
  * attacks flipping a low number of bits. */
-#define UECC_SUCCESS            0
+#define UECC_SUCCESS            0x00FFAAAA
 #define UECC_FAILURE            0x75555555
 #define UECC_FAULT_DETECTED     0x7aaaaaaa
 
diff --git a/library/aes.c b/library/aes.c
index e9e7544..e7a888f 100644
--- a/library/aes.c
+++ b/library/aes.c
@@ -675,6 +675,18 @@
 }
 #endif /* MBEDTLS_CIPHER_MODE_XTS */
 
+#if defined(MBEDTLS_AES_SCA_COUNTERMEASURES)
+static void mbedtls_generate_fake_key( unsigned int keybits, mbedtls_aes_context *ctx )
+{   /* Fills the first keybits / 32 words of ctx->frk with mbedtls_platform_random_uint32() output. */
+    unsigned int qword; /* Counts 32-bit words (keybits >> 5 == keybits / 32), not 64-bit quadwords. */
+
+    for( qword = keybits >> 5; qword > 0; qword-- ) /* Walks from frk[keybits/32 - 1] down to frk[0]. */
+    {
+        ctx->frk[ qword - 1 ] = mbedtls_platform_random_uint32(); /* NOTE(review): frk holds 8 words; keybits > 256 would write past it - confirm every caller validates keybits before this call. */
+    }
+}
+#endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
+
 /*
  * AES key schedule (encryption)
  */
@@ -719,6 +731,9 @@
     else
 #endif
     ctx->rk = RK = ctx->buf;
+#if defined(MBEDTLS_AES_SCA_COUNTERMEASURES)
+    mbedtls_generate_fake_key( keybits, ctx );
+#endif
 
 #if defined(MBEDTLS_AESNI_C) && defined(MBEDTLS_HAVE_X86_64)
     if( mbedtls_aesni_has_support( MBEDTLS_AESNI_AES ) )
@@ -858,6 +873,9 @@
     else
 #endif
     ctx->rk = RK = ctx->buf;
+#if defined(MBEDTLS_AES_SCA_COUNTERMEASURES)
+    mbedtls_generate_fake_key( keybits, ctx );
+#endif
 
     /* Also checks keybits */
     if( ( ret = mbedtls_aes_setkey_enc( &cty, key, keybits ) ) != 0 )
@@ -1071,7 +1089,8 @@
     uint8_t round_ctrl_table[( 14 + AES_SCA_CM_ROUNDS + 2 )];
 
     aes_data_real.rk_ptr = ctx->rk;
-    aes_data_fake.rk_ptr = ctx->rk;
+    aes_data_fake.rk_ptr = ctx->frk;
+
     aes_data_table[0] = &aes_data_real;
     aes_data_table[1] = &aes_data_fake;
 
@@ -1351,7 +1370,8 @@
     uint8_t round_ctrl_table[( 14 + AES_SCA_CM_ROUNDS + 2 )];
 
     aes_data_real.rk_ptr = ctx->rk;
-    aes_data_fake.rk_ptr = ctx->rk;
+    aes_data_fake.rk_ptr = ctx->frk;
+
     aes_data_table[0] = &aes_data_real;
     aes_data_table[1] = &aes_data_fake;
 
diff --git a/library/pk.c b/library/pk.c
index b92eb14..fea7576 100644
--- a/library/pk.c
+++ b/library/pk.c
@@ -548,6 +548,7 @@
     return( (size_t) ( NUM_ECC_BYTES * 8 ) );
 }
 
+/* This function compares public keys of two keypairs */
 static int uecc_eckey_check_pair( const void *pub, const void *prv )
 {
     const mbedtls_uecc_keypair *uecc_pub =
@@ -621,13 +622,12 @@
 static int asn1_write_mpibuf( unsigned char **p, unsigned char *start,
                               size_t n_len )
 {
-    size_t len = 0;
+    size_t len = n_len;
     int ret = MBEDTLS_ERR_PLATFORM_FAULT_DETECTED;
 
-    if( (size_t)( *p - start ) < n_len )
+    if( (size_t)( *p - start ) < len )
         return( MBEDTLS_ERR_ASN1_BUF_TOO_SMALL );
 
-    len = n_len;
     *p -= len;
     ret = mbedtls_platform_memmove( *p, start, len );
     if( ret != 0 )
@@ -659,6 +659,10 @@
         len += 1;
     }
 
+    /* Ensure that there is still space for len and ASN1_INTEGER */
+    if( ( *p - start ) < 2 )
+        return( MBEDTLS_ERR_ASN1_BUF_TOO_SMALL );
+
     /* The ASN.1 length encoding is just a single Byte containing the length,
      * as we assume that the total buffer length is smaller than 128 Bytes. */
     *--(*p) = len;
@@ -674,7 +678,7 @@
  *
  * [in/out] sig: the signature pre- and post-transcoding
  * [in/out] sig_len: signature length pre- and post-transcoding
- * [int] buf_len: the available size the in/out buffer
+ * [in] buf_len: the available size of the in/out buffer
  *
  * Warning: buf_len must be smaller than 128 Bytes.
  */
@@ -689,6 +693,9 @@
     MBEDTLS_ASN1_CHK_ADD( len, asn1_write_mpibuf( &p, sig + rs_len, rs_len ) );
     MBEDTLS_ASN1_CHK_ADD( len, asn1_write_mpibuf( &p, sig, rs_len ) );
 
+    if( p - sig < 2 )
+        return( MBEDTLS_ERR_ASN1_BUF_TOO_SMALL );
+
     /* The ASN.1 length encoding is just a single Byte containing the length,
      * as we assume that the total buffer length is smaller than 128 Bytes. */
     *--p = len;
diff --git a/library/pkparse.c b/library/pkparse.c
index 688082b..411fbaa 100644
--- a/library/pkparse.c
+++ b/library/pkparse.c
@@ -573,8 +573,10 @@
     mbedtls_uecc_keypair *uecc_keypair = (mbedtls_uecc_keypair *) pk_context;
     int ret;
 
-    ret = uecc_public_key_read_binary( uecc_keypair,
-                                       (const unsigned char *) *p, end - *p );
+    if( ( ret = uecc_public_key_read_binary( uecc_keypair,
+                                             (const unsigned char *) *p, end - *p ) )
+            != 0 )
+        return ret;
 
     /*
      * We know uecc_public_key_read_binary consumed all bytes or failed
@@ -1062,7 +1064,7 @@
                                   size_t keylen )
 {
     int ret;
-    int version, pubkey_done;
+    int version, pubkey_done = 0;
     size_t len;
     mbedtls_asn1_buf params;
     unsigned char *p = (unsigned char *) key;
@@ -1104,7 +1106,6 @@
 
     p += len;
 
-    pubkey_done = 0;
     if( p != end )
     {
         /*
diff --git a/library/sha256.c b/library/sha256.c
index 493e88e..5214591 100644
--- a/library/sha256.c
+++ b/library/sha256.c
@@ -35,6 +35,7 @@
 #include "mbedtls/sha256.h"
 #include "mbedtls/platform_util.h"
 #include "mbedtls/platform.h"
+#include <stdint.h>
 
 #include <string.h>
 
@@ -188,7 +189,7 @@
 {
     uint32_t temp1, temp2, W[64];
     uint32_t A[8];
-    uint32_t flow_ctrl = 0;
+    volatile uint32_t flow_ctrl = 0;
     unsigned int i;
 
     SHA256_VALIDATE_RET( ctx != NULL );
@@ -214,11 +215,6 @@
         }
     }
 
-    if( flow_ctrl != 16 )
-    {
-        return MBEDTLS_ERR_PLATFORM_FAULT_DETECTED;
-    }
-
     for( i = 0; i < 64; i++ )
     {
         if( i >= 16 )
@@ -317,19 +313,22 @@
     SHA256_VALIDATE_RET( ctx != NULL );
     SHA256_VALIDATE_RET( ilen == 0 || input != NULL );
 
-    if( ilen == 0 )
+    /* ilen_dup is used instead of ilen, to have it volatile for FI protection */
+    if( ilen_dup == 0 )
         return( 0 );
 
+    if( ilen_dup > UINT32_MAX )
+        return( MBEDTLS_ERR_SHA256_BAD_INPUT_DATA );
+
     left = ctx->total[0] & 0x3F;
     fill = 64 - left;
 
-    ctx->total[0] += (uint32_t) ilen;
-    ctx->total[0] &= 0xFFFFFFFF;
+    ctx->total[0] += (uint32_t) ilen_dup;
 
-    if( ctx->total[0] < (uint32_t) ilen )
+    if( ctx->total[0] < (uint32_t) ilen_dup )
         ctx->total[1]++;
 
-    if( left && ilen >= fill )
+    if( left && ilen_dup >= fill )
     {
         mbedtls_platform_memcpy( (void *) (ctx->buffer + left), input, fill );
 
@@ -337,27 +336,27 @@
             return( ret );
 
         input += fill;
-        ilen  -= fill;
+        ilen_dup  -= fill;
         left = 0;
     }
 
-    while( ilen >= 64 )
+    while( ilen_dup >= 64 )
     {
         if( ( ret = mbedtls_internal_sha256_process( ctx, input ) ) != 0 )
             return( ret );
 
         input += 64;
-        ilen  -= 64;
+        ilen_dup  -= 64;
     }
 
-    if( ilen > 0 )
-        mbedtls_platform_memcpy( (void *) (ctx->buffer + left), input, ilen );
+    if( ilen_dup > 0 )
+        mbedtls_platform_memcpy( (void *) (ctx->buffer + left), input, ilen_dup );
 
-    /* Re-check ilen to protect from a FI attack */
-    if( ilen < 64 )
+    /* Re-check ilen_dup to protect from a FI attack */
+    if( ilen_dup < 64 )
     {
         /* Re-check that the calculated offsets are correct */
-        ilen_change = ilen_dup - ilen;
+        ilen_change = ilen - ilen_dup;
         if( ( input_dup + ilen_change ) == input )
         {
             return( 0 );
@@ -387,7 +386,7 @@
     uint32_t used;
     uint32_t high, low;
     uint32_t offset = 0;
-    uint32_t flow_ctrl = 0;
+    volatile uint32_t flow_ctrl = 0;
 
     SHA256_VALIDATE_RET( ctx != NULL );
     SHA256_VALIDATE_RET( (unsigned char *)output != NULL );
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index bbe94cb..1ee7af0 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -2562,7 +2562,6 @@
     /* Not using more secure mbedtls_platform_memcpy as cid is public */
     memcpy( rec->cid, transform->out_cid, transform->out_cid_len );
     MBEDTLS_SSL_DEBUG_BUF( 3, "CID", rec->cid, rec->cid_len );
-
     if( rec->cid_len != 0 )
     {
         int ret = MBEDTLS_ERR_PLATFORM_FAULT_DETECTED;
@@ -11221,8 +11220,6 @@
 {
     int ret = mbedtls_ssl_get_max_out_record_payload( ssl );
     const size_t max_len = (size_t) ret;
-    volatile const unsigned char *buf_dup = buf;
-    volatile size_t len_dup = len;
 
     if( ret < 0 )
     {
@@ -11245,7 +11242,6 @@
 #if defined(MBEDTLS_SSL_PROTO_TLS)
         {
             len = max_len;
-            len_dup = len;
         }
 #endif
     }
@@ -11271,22 +11267,44 @@
          * copy the data into the internal buffers and setup the data structure
          * to keep track of partial writes
          */
-        ssl->out_msglen  = len;
+        ssl->out_msglen = len;
         ssl->out_msgtype = MBEDTLS_SSL_MSG_APPLICATION_DATA;
         mbedtls_platform_memcpy( ssl->out_msg, buf, len );
 
-        if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 )
-        {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret );
-            return( ret );
+#if defined(MBEDTLS_FI_COUNTERMEASURES) && !defined(MBEDTLS_SSL_CBC_RECORD_SPLITTING)
+        /*
+         * Buffer pointer and size duplication cannot be supported with MBEDTLS_SSL_CBC_RECORD_SPLITTING.
+         * After splitting, the pointers and data size will not be the same as initially provided by the user.
+         */
+        /* Secure against buffer substitution */
+        if( buf == ssl->out_msg_dup &&
+            ssl->out_msglen == ssl->out_msglen_dup &&
+            ssl->out_msg_dup[0] == ssl->out_msg[0] )
+        {/*write record only if data was copied from correct user pointer */
+#endif
+            if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 )
+            {
+                MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret );
+                return( ret );
+            }
+
+#if defined(MBEDTLS_FI_COUNTERMEASURES) && !defined(MBEDTLS_SSL_CBC_RECORD_SPLITTING)
         }
+        else
+        {
+            return( MBEDTLS_ERR_PLATFORM_FAULT_DETECTED );
+        }
+#endif
     }
-    /* Secure against buffer substitution */
-    if( buf_dup == buf && len_dup == len )
+    if ( ret == 0 )
     {
         return( (int) len );
     }
-    return( MBEDTLS_ERR_PLATFORM_FAULT_DETECTED );
+    else
+    {
+        return( MBEDTLS_ERR_PLATFORM_FAULT_DETECTED );
+    }
+
 }
 
 /*
@@ -11334,10 +11352,15 @@
  */
 int mbedtls_ssl_write( mbedtls_ssl_context *ssl, const unsigned char *buf, size_t len )
 {
-    int ret;
+    int ret = MBEDTLS_ERR_PLATFORM_FAULT_DETECTED;
+#if defined(MBEDTLS_FI_COUNTERMEASURES) && !defined(MBEDTLS_SSL_CBC_RECORD_SPLITTING)
+    /*
+     * Buffer pointer and size duplication cannot be supported with MBEDTLS_SSL_CBC_RECORD_SPLITTING.
+     * After splitting, the pointers and data size will not be the same as initially provided by the user.
+     */
     volatile const unsigned char *buf_dup = buf;
     volatile size_t len_dup = len;
-
+#endif
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write" ) );
 
     if( ssl == NULL || ssl->conf == NULL )
@@ -11363,17 +11386,19 @@
 #if defined(MBEDTLS_SSL_CBC_RECORD_SPLITTING)
     ret = ssl_write_split( ssl, buf, len );
 #else
+#if defined(MBEDTLS_FI_COUNTERMEASURES)
+    /* Store the user's buffer pointer and length in the context so that ssl_write_real() can check them before writing the record. */
+    ssl->out_msg_dup = (unsigned char*)buf_dup;
+    ssl->out_msglen_dup = len_dup;
+#endif //MBEDTLS_FI_COUNTERMEASURES
     ret = ssl_write_real( ssl, buf, len );
 #endif
 
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= write" ) );
 
-    /* Secure against buffer substitution */
-    if( buf_dup == buf && len_dup == len )
-    {
-        return( ret );
-    }
-    return( MBEDTLS_ERR_PLATFORM_FAULT_DETECTED );
+
+    return( ret );
+
 }
 
 /*
diff --git a/library/version_features.c b/library/version_features.c
index d60758c..38a7cee 100644
--- a/library/version_features.c
+++ b/library/version_features.c
@@ -273,6 +273,9 @@
 #if defined(MBEDTLS_AES_SCA_COUNTERMEASURES)
     "MBEDTLS_AES_SCA_COUNTERMEASURES",
 #endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
+#if defined(MBEDTLS_FI_COUNTERMEASURES)
+    "MBEDTLS_FI_COUNTERMEASURES",
+#endif /* MBEDTLS_FI_COUNTERMEASURES */
 #if defined(MBEDTLS_CAMELLIA_SMALL_MEMORY)
     "MBEDTLS_CAMELLIA_SMALL_MEMORY",
 #endif /* MBEDTLS_CAMELLIA_SMALL_MEMORY */
diff --git a/library/x509.c b/library/x509.c
index 093a315..65f2ec6 100644
--- a/library/x509.c
+++ b/library/x509.c
@@ -176,7 +176,7 @@
         return( MBEDTLS_ERR_X509_INVALID_ALG +
                 MBEDTLS_ERR_ASN1_UNEXPECTED_TAG );
 
-    p = (unsigned char *) alg->p;
+    p = alg->p;
     end = p + alg->len;
 
     if( p >= end )
diff --git a/programs/ssl/query_config.c b/programs/ssl/query_config.c
index 8093c0d..8db6d22 100644
--- a/programs/ssl/query_config.c
+++ b/programs/ssl/query_config.c
@@ -770,6 +770,14 @@
     }
 #endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
 
+#if defined(MBEDTLS_FI_COUNTERMEASURES)
+    if( strcmp( "MBEDTLS_FI_COUNTERMEASURES", config ) == 0 )
+    {
+        MACRO_EXPANSION_TO_STR( MBEDTLS_FI_COUNTERMEASURES );
+        return( 0 );
+    }
+#endif /* MBEDTLS_FI_COUNTERMEASURES */
+
 #if defined(MBEDTLS_CAMELLIA_SMALL_MEMORY)
     if( strcmp( "MBEDTLS_CAMELLIA_SMALL_MEMORY", config ) == 0 )
     {