Merge enc/dec cipher contexts in ssl transforms

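Store the raw encryption and decryption keys in the transform and load
the key for the current direction into a single shared cipher context
before each cipher operation, instead of keeping separate encryption
and decryption contexts. This is enabled by a new config option,
MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS.

Signed-off-by: Andrzej Kurek <andrzej.kurek@arm.com>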
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 4ebfb5c..2c363fd 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -1563,7 +1563,41 @@
                                   iv_copy_len );
     }
 #endif
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+#if defined(MBEDTLS_ARC4_C)
+    /* A single context shared by both directions is re-keyed before every
+     * record. ARC4 keeps its keystream position inside the context, so the
+     * per-record mbedtls_cipher_setkey() would restart the keystream for
+     * every record (keystream reuse) and break interoperability. */
+    if( cipher_info->type == MBEDTLS_CIPHER_ARC4_128 )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "ARC4 is not supported with "
+                                    "MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS" ) );
+        return( MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE );
+    }
+#endif /* MBEDTLS_ARC4_C */
+    if( ( ret = mbedtls_cipher_setup( &transform->cipher_ctx,
+                                 cipher_info ) ) != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup", ret );
+        return( ret );
+    }
+
+    /* Keep copies of the raw keys; the key for the current direction is
+     * loaded into the shared context before each cipher operation. */
+    transform->key_enc = mbedtls_calloc( 1, cipher_info->key_bitlen >> 3 );
+    transform->key_dec = mbedtls_calloc( 1, cipher_info->key_bitlen >> 3 );
+    if( transform->key_enc == NULL || transform->key_dec == NULL )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "alloc failed" ) );
+        return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
+    }
 
+    memcpy( transform->key_enc, key1, cipher_info->key_bitlen >> 3 );
+    memcpy( transform->key_dec, key2, cipher_info->key_bitlen >> 3 );
+
+    transform->key_bitlen = cipher_info->key_bitlen;
+#else
     if( ( ret = mbedtls_cipher_setup( &transform->cipher_ctx_enc,
                                  cipher_info ) ) != 0 )
     {
@@ -1593,10 +1607,18 @@
         MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
         return( ret );
     }
-
+#endif
 #if defined(MBEDTLS_CIPHER_MODE_CBC)
     if( cipher_info->mode == MBEDTLS_MODE_CBC )
     {
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+        if( ( ret = mbedtls_cipher_set_padding_mode( &transform->cipher_ctx,
+                                             MBEDTLS_PADDING_NONE ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_set_padding_mode", ret );
+            return( ret );
+        }
+#else
         if( ( ret = mbedtls_cipher_set_padding_mode( &transform->cipher_ctx_enc,
                                              MBEDTLS_PADDING_NONE ) ) != 0 )
         {
@@ -1610,6 +1632,7 @@
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_set_padding_mode", ret );
             return( ret );
         }
+#endif
     }
 #endif /* MBEDTLS_CIPHER_MODE_CBC */
 
@@ -2554,9 +2577,11 @@
     post_avail = rec->buf_len - ( rec->data_len + rec->data_offset );
     MBEDTLS_SSL_DEBUG_BUF( 4, "before encrypt: output payload",
                            data, rec->data_len );
-
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+    mode = mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx );
+#else
     mode = mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_enc );
-
+#endif
     if( rec->data_len > MBEDTLS_SSL_OUT_CONTENT_LEN )
     {
         MBEDTLS_SSL_DEBUG_MSG( 1, ( "Record content %u too large, maximum %d",
@@ -2671,7 +2696,25 @@
         MBEDTLS_SSL_DEBUG_MSG( 3, ( "before encrypt: msglen = %d, "
                                     "including %d bytes of padding",
                                     rec->data_len, 0 ) );
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
+                                           transform->key_enc,
+                                           transform->key_bitlen,
+                                           MBEDTLS_ENCRYPT ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            return( ret );
+        }
 
+        if( ( ret = mbedtls_cipher_crypt( &transform->cipher_ctx,
+                                   transform->iv_enc, transform->ivlen,
+                                   data, rec->data_len,
+                                   data, &olen ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_crypt", ret );
+            return( ret );
+        }
+#else
         if( ( ret = mbedtls_cipher_crypt( &transform->cipher_ctx_enc,
                                    transform->iv_enc, transform->ivlen,
                                    data, rec->data_len,
@@ -2680,7 +2723,7 @@
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_crypt", ret );
             return( ret );
         }
-
+#endif
         if( rec->data_len != olen )
         {
             MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) );
@@ -2754,7 +2797,27 @@
         /*
          * Encrypt and authenticate
          */
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
+                                           transform->key_enc,
+                                           transform->key_bitlen,
+                                           MBEDTLS_ENCRYPT ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            return( ret );
+        }
 
+        if( ( ret = mbedtls_cipher_auth_encrypt( &transform->cipher_ctx,
+                   iv, transform->ivlen,
+                   add_data, add_data_len,       /* add data     */
+                   data, rec->data_len,          /* source       */
+                   data, &rec->data_len,         /* destination  */
+                   data + rec->data_len, transform->taglen ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_auth_encrypt", ret );
+            return( ret );
+        }
+#else
         if( ( ret = mbedtls_cipher_auth_encrypt( &transform->cipher_ctx_enc,
                    iv, transform->ivlen,
                    add_data, add_data_len,       /* add data     */
@@ -2765,7 +2828,7 @@
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_auth_encrypt", ret );
             return( ret );
         }
-
+#endif
         MBEDTLS_SSL_DEBUG_BUF( 4, "after encrypt: tag",
                                data + rec->data_len, transform->taglen );
 
@@ -2841,7 +2904,26 @@
                             "including %d bytes of IV and %d bytes of padding",
                             rec->data_len, transform->ivlen,
                             padlen + 1 ) );
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
+                                           transform->key_enc,
+                                           transform->key_bitlen,
+                                           MBEDTLS_ENCRYPT ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            return( ret );
+        }
 
+        if( ( ret = mbedtls_cipher_crypt( &transform->cipher_ctx,
+                                   transform->iv_enc,
+                                   transform->ivlen,
+                                   data, rec->data_len,
+                                   data, &olen ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_crypt", ret );
+            return( ret );
+        }
+#else
         if( ( ret = mbedtls_cipher_crypt( &transform->cipher_ctx_enc,
                                    transform->iv_enc,
                                    transform->ivlen,
@@ -2851,7 +2933,7 @@
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_crypt", ret );
             return( ret );
         }
-
+#endif
         if( rec->data_len != olen )
         {
             MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) );
@@ -2866,8 +2948,13 @@
             /*
              * Save IV in SSL3 and TLS1
              */
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+            mbedtls_platform_memcpy( transform->iv_enc, transform->cipher_ctx.iv,
+                    transform->ivlen );
+#else
             mbedtls_platform_memcpy( transform->iv_enc, transform->cipher_ctx_enc.iv,
                     transform->ivlen );
+#endif
         }
         else
 #endif
@@ -2968,8 +3055,11 @@
     }
 
     data = rec->buf + rec->data_offset;
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+    mode = mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx );
+#else
     mode = mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_dec );
-
+#endif
 #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
     /*
      * Match record's CID with incoming CID.
@@ -2985,6 +3075,25 @@
     if( mode == MBEDTLS_MODE_STREAM )
     {
         padlen = 0;
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
+                                           transform->key_dec,
+                                           transform->key_bitlen,
+                                           MBEDTLS_DECRYPT ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            return( ret );
+        }
+        if( ( ret = mbedtls_cipher_crypt( &transform->cipher_ctx,
+                                   transform->iv_dec,
+                                   transform->ivlen,
+                                   data, rec->data_len,
+                                   data, &olen ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_crypt", ret );
+            return( ret );
+        }
+#else
         if( ( ret = mbedtls_cipher_crypt( &transform->cipher_ctx_dec,
                                    transform->iv_dec,
                                    transform->ivlen,
@@ -2994,7 +3103,7 @@
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_crypt", ret );
             return( ret );
         }
-
+#endif
         if( rec->data_len != olen )
         {
             MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) );
@@ -3082,6 +3191,31 @@
         /*
          * Decrypt and authenticate
          */
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
+                                           transform->key_dec,
+                                           transform->key_bitlen,
+                                           MBEDTLS_DECRYPT ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            return( ret );
+        }
+        if( ( ret = mbedtls_cipher_auth_decrypt( &transform->cipher_ctx,
+                  iv, transform->ivlen,
+                  add_data, add_data_len,
+                  data, rec->data_len,
+                  data, &olen,
+                  data + rec->data_len,
+                  transform->taglen ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_auth_decrypt", ret );
+
+            if( ret == MBEDTLS_ERR_CIPHER_AUTH_FAILED )
+                return( MBEDTLS_ERR_SSL_INVALID_MAC );
+
+            return( ret );
+        }
+#else
         if( ( ret = mbedtls_cipher_auth_decrypt( &transform->cipher_ctx_dec,
                   iv, transform->ivlen,
                   add_data, add_data_len,
@@ -3097,6 +3231,8 @@
 
             return( ret );
         }
+#endif
+
         auth_done++;
 
         /* Double-check that AEAD decryption doesn't change content length. */
@@ -3239,7 +3375,23 @@
 #endif /* MBEDTLS_SSL_PROTO_TLS1_1 || MBEDTLS_SSL_PROTO_TLS1_2 */
 
         /* We still have data_len % ivlen == 0 and data_len >= ivlen here. */
-
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
+                                           transform->key_dec,
+                                           transform->key_bitlen,
+                                           MBEDTLS_DECRYPT ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            return( ret );
+        }
+        if( ( ret = mbedtls_cipher_crypt( &transform->cipher_ctx,
+                                   transform->iv_dec, transform->ivlen,
+                                   data, rec->data_len, data, &olen ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_crypt", ret );
+            return( ret );
+        }
+#else
         if( ( ret = mbedtls_cipher_crypt( &transform->cipher_ctx_dec,
                                    transform->iv_dec, transform->ivlen,
                                    data, rec->data_len, data, &olen ) ) != 0 )
@@ -3247,7 +3399,7 @@
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_crypt", ret );
             return( ret );
         }
-
+#endif
         /* Double-check that length hasn't changed during decryption. */
         if( rec->data_len != olen )
         {
@@ -3266,8 +3418,13 @@
              * of the records; in other words, IVs are maintained across
              * record decryptions.
              */
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+            mbedtls_platform_memcpy( transform->iv_dec, transform->cipher_ctx.iv,
+                    transform->ivlen );
+#else
             mbedtls_platform_memcpy( transform->iv_dec, transform->cipher_ctx_dec.iv,
                     transform->ivlen );
+#endif
         }
 #endif
 
@@ -8495,9 +8652,12 @@
 void mbedtls_ssl_transform_init( mbedtls_ssl_transform *transform )
 {
     memset( transform, 0, sizeof(mbedtls_ssl_transform) );
-
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+    mbedtls_cipher_init( &transform->cipher_ctx );
+#else
     mbedtls_cipher_init( &transform->cipher_ctx_enc );
     mbedtls_cipher_init( &transform->cipher_ctx_dec );
+#endif
 
 #if defined(MBEDTLS_SSL_SOME_MODES_USE_MAC)
     mbedtls_md_init( &transform->md_ctx_enc );
@@ -9866,8 +10026,11 @@
     if( ssl->session_out->compression != MBEDTLS_SSL_COMPRESS_NULL )
         return( MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE );
 #endif
-
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+    switch( mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx ) )
+#else
     switch( mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_enc ) )
+#endif
     {
 #if defined(MBEDTLS_GCM_C)        || \
     defined(MBEDTLS_CCM_C)        || \
@@ -9898,10 +10061,13 @@
         case MBEDTLS_MODE_CBC:
         {
             size_t block_size;
-
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+            block_size = mbedtls_cipher_get_block_size(
+                &transform->cipher_ctx );
+#else
             block_size = mbedtls_cipher_get_block_size(
                 &transform->cipher_ctx_enc );
-
+#endif
             /* Expansion due to the addition of the MAC. */
             transform_expansion += transform->maclen;
 
@@ -11371,8 +11537,13 @@
         mbedtls_ssl_ver_gt(
             mbedtls_ssl_get_minor_ver( ssl ),
             MBEDTLS_SSL_MINOR_VERSION_1 ) ||
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+        mbedtls_cipher_get_cipher_mode( &ssl->transform_out->cipher_ctx )
+                                != MBEDTLS_MODE_CBC )
+#else
         mbedtls_cipher_get_cipher_mode( &ssl->transform_out->cipher_ctx_enc )
                                 != MBEDTLS_MODE_CBC )
+#endif
     {
         return( ssl_write_real( ssl, buf, len ) );
     }
@@ -11486,10 +11657,26 @@
     deflateEnd( &transform->ctx_deflate );
     inflateEnd( &transform->ctx_inflate );
 #endif
-
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+    mbedtls_cipher_free( &transform->cipher_ctx );
+    /* The raw keys are heap buffers outside the transform structure, so
+     * wipe them explicitly before releasing them. */
+    if( transform->key_dec != NULL )
+    {
+        mbedtls_platform_zeroize( transform->key_dec,
+                                  transform->key_bitlen >> 3 );
+        mbedtls_free( transform->key_dec );
+    }
+    if( transform->key_enc != NULL )
+    {
+        mbedtls_platform_zeroize( transform->key_enc,
+                                  transform->key_bitlen >> 3 );
+        mbedtls_free( transform->key_enc );
+    }
+#else
     mbedtls_cipher_free( &transform->cipher_ctx_enc );
     mbedtls_cipher_free( &transform->cipher_ctx_dec );
-
+#endif
 #if defined(MBEDTLS_SSL_SOME_MODES_USE_MAC)
     mbedtls_md_free( &transform->md_ctx_enc );
     mbedtls_md_free( &transform->md_ctx_dec );