Add per-function override for AES
diff --git a/ChangeLog b/ChangeLog
index 7ef5184..10b6820 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -4,9 +4,9 @@
Features
* Support for DTLS 1.0 and 1.2 (RFC 6347).
- * Ability to override xxx_process() function from a md/sha module with
- custom implementation (eg hardware accelerated), complementing the ability
- to override the whole module.
+ * Ability to override core functions from MDx, SHAx, AES and DES modules
+ with custom implementation (eg hardware accelerated), complementing the
+ ability to override the whole module.
API Changes
* All public identifiers moved to the mbedtls_* or MBEDTLS_* namespace.
diff --git a/include/mbedtls/aes.h b/include/mbedtls/aes.h
index ffd6d3b..1ddde82 100644
--- a/include/mbedtls/aes.h
+++ b/include/mbedtls/aes.h
@@ -246,6 +246,32 @@
unsigned char *output );
#endif /* MBEDTLS_CIPHER_MODE_CTR */
+/**
+ * \brief Internal AES block encryption function
+ * (Only exposed to allow overriding it,
+ * see MBEDTLS_AES_ENCRYPT_ALT)
+ *
+ * \param ctx AES context
+ * \param input Plaintext block
+ * \param output Output (ciphertext) block
+ */
+void mbedtls_aes_encrypt( mbedtls_aes_context *ctx,
+ const unsigned char input[16],
+ unsigned char output[16] );
+
+/**
+ * \brief Internal AES block decryption function
+ * (Only exposed to allow overriding it,
+ * see MBEDTLS_AES_DECRYPT_ALT)
+ *
+ * \param ctx AES context
+ * \param input Ciphertext block
+ * \param output Output (plaintext) block
+ */
+void mbedtls_aes_decrypt( mbedtls_aes_context *ctx,
+ const unsigned char input[16],
+ unsigned char output[16] );
+
#ifdef __cplusplus
}
#endif
diff --git a/include/mbedtls/config.h b/include/mbedtls/config.h
index db40dbf..fc86056 100644
--- a/include/mbedtls/config.h
+++ b/include/mbedtls/config.h
@@ -241,6 +241,10 @@
* of mbedtls_sha1_context, so your implementation of mbedtls_sha1_process must be compatible
* with this definition.
*
+ * Note: if you use the AES_xxx_ALT macros, then is is recommended to also set
+ * MBEDTLS_AES_ROM_TABLES in order to help the linker garbage-collect the AES
+ * tables.
+ *
* Uncomment a macro to enable alternate implementation of the corresponding
* function.
*/
@@ -254,6 +258,10 @@
//#define MBEDTLS_DES_SETKEY_ALT
//#define MBEDTLS_DES_CRYPT_ECB_ALT
//#define MBEDTLS_DES3_CRYPT_ECB_ALT
+//#define MBEDTLS_AES_SETKEY_ENC_ALT
+//#define MBEDTLS_AES_SETKEY_DEC_ALT
+//#define MBEDTLS_AES_ENCRYPT_ALT
+//#define MBEDTLS_AES_DECRYPT_ALT
/**
* \def MBEDTLS_AES_ROM_TABLES
diff --git a/library/aes.c b/library/aes.c
index 9780e8f..3983ae3 100644
--- a/library/aes.c
+++ b/library/aes.c
@@ -481,6 +481,7 @@
/*
* AES key schedule (encryption)
*/
+#if !defined(MBEDTLS_AES_SETKEY_ENC_ALT)
int mbedtls_aes_setkey_enc( mbedtls_aes_context *ctx, const unsigned char *key,
unsigned int keysize )
{
@@ -589,10 +590,12 @@
return( 0 );
}
+#endif /* !MBEDTLS_AES_SETKEY_ENC_ALT */
/*
* AES key schedule (decryption)
*/
+#if !defined(MBEDTLS_AES_SETKEY_DEC_ALT)
int mbedtls_aes_setkey_dec( mbedtls_aes_context *ctx, const unsigned char *key,
unsigned int keysize )
{
@@ -656,6 +659,7 @@
return( ret );
}
+#endif /* !MBEDTLS_AES_SETKEY_DEC_ALT */
#define AES_FROUND(X0,X1,X2,X3,Y0,Y1,Y2,Y3) \
{ \
@@ -704,6 +708,120 @@
}
/*
+ * AES-ECB block encryption
+ */
+#if !defined(MBEDTLS_AES_ENCRYPT_ALT)
+void mbedtls_aes_encrypt( mbedtls_aes_context *ctx,
+ const unsigned char input[16],
+ unsigned char output[16] )
+{
+ int i;
+ uint32_t *RK, X0, X1, X2, X3, Y0, Y1, Y2, Y3;
+
+ RK = ctx->rk;
+
+ GET_UINT32_LE( X0, input, 0 ); X0 ^= *RK++;
+ GET_UINT32_LE( X1, input, 4 ); X1 ^= *RK++;
+ GET_UINT32_LE( X2, input, 8 ); X2 ^= *RK++;
+ GET_UINT32_LE( X3, input, 12 ); X3 ^= *RK++;
+
+ for( i = ( ctx->nr >> 1 ) - 1; i > 0; i-- )
+ {
+ AES_FROUND( Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
+ AES_FROUND( X0, X1, X2, X3, Y0, Y1, Y2, Y3 );
+ }
+
+ AES_FROUND( Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
+
+ X0 = *RK++ ^ \
+ ( (uint32_t) FSb[ ( Y0 ) & 0xFF ] ) ^
+ ( (uint32_t) FSb[ ( Y1 >> 8 ) & 0xFF ] << 8 ) ^
+ ( (uint32_t) FSb[ ( Y2 >> 16 ) & 0xFF ] << 16 ) ^
+ ( (uint32_t) FSb[ ( Y3 >> 24 ) & 0xFF ] << 24 );
+
+ X1 = *RK++ ^ \
+ ( (uint32_t) FSb[ ( Y1 ) & 0xFF ] ) ^
+ ( (uint32_t) FSb[ ( Y2 >> 8 ) & 0xFF ] << 8 ) ^
+ ( (uint32_t) FSb[ ( Y3 >> 16 ) & 0xFF ] << 16 ) ^
+ ( (uint32_t) FSb[ ( Y0 >> 24 ) & 0xFF ] << 24 );
+
+ X2 = *RK++ ^ \
+ ( (uint32_t) FSb[ ( Y2 ) & 0xFF ] ) ^
+ ( (uint32_t) FSb[ ( Y3 >> 8 ) & 0xFF ] << 8 ) ^
+ ( (uint32_t) FSb[ ( Y0 >> 16 ) & 0xFF ] << 16 ) ^
+ ( (uint32_t) FSb[ ( Y1 >> 24 ) & 0xFF ] << 24 );
+
+ X3 = *RK++ ^ \
+ ( (uint32_t) FSb[ ( Y3 ) & 0xFF ] ) ^
+ ( (uint32_t) FSb[ ( Y0 >> 8 ) & 0xFF ] << 8 ) ^
+ ( (uint32_t) FSb[ ( Y1 >> 16 ) & 0xFF ] << 16 ) ^
+ ( (uint32_t) FSb[ ( Y2 >> 24 ) & 0xFF ] << 24 );
+
+ PUT_UINT32_LE( X0, output, 0 );
+ PUT_UINT32_LE( X1, output, 4 );
+ PUT_UINT32_LE( X2, output, 8 );
+ PUT_UINT32_LE( X3, output, 12 );
+}
+#endif /* !MBEDTLS_AES_ENCRYPT_ALT */
+
+/*
+ * AES-ECB block decryption
+ */
+#if !defined(MBEDTLS_AES_DECRYPT_ALT)
+void mbedtls_aes_decrypt( mbedtls_aes_context *ctx,
+ const unsigned char input[16],
+ unsigned char output[16] )
+{
+ int i;
+ uint32_t *RK, X0, X1, X2, X3, Y0, Y1, Y2, Y3;
+
+ RK = ctx->rk;
+
+ GET_UINT32_LE( X0, input, 0 ); X0 ^= *RK++;
+ GET_UINT32_LE( X1, input, 4 ); X1 ^= *RK++;
+ GET_UINT32_LE( X2, input, 8 ); X2 ^= *RK++;
+ GET_UINT32_LE( X3, input, 12 ); X3 ^= *RK++;
+
+ for( i = ( ctx->nr >> 1 ) - 1; i > 0; i-- )
+ {
+ AES_RROUND( Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
+ AES_RROUND( X0, X1, X2, X3, Y0, Y1, Y2, Y3 );
+ }
+
+ AES_RROUND( Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
+
+ X0 = *RK++ ^ \
+ ( (uint32_t) RSb[ ( Y0 ) & 0xFF ] ) ^
+ ( (uint32_t) RSb[ ( Y3 >> 8 ) & 0xFF ] << 8 ) ^
+ ( (uint32_t) RSb[ ( Y2 >> 16 ) & 0xFF ] << 16 ) ^
+ ( (uint32_t) RSb[ ( Y1 >> 24 ) & 0xFF ] << 24 );
+
+ X1 = *RK++ ^ \
+ ( (uint32_t) RSb[ ( Y1 ) & 0xFF ] ) ^
+ ( (uint32_t) RSb[ ( Y0 >> 8 ) & 0xFF ] << 8 ) ^
+ ( (uint32_t) RSb[ ( Y3 >> 16 ) & 0xFF ] << 16 ) ^
+ ( (uint32_t) RSb[ ( Y2 >> 24 ) & 0xFF ] << 24 );
+
+ X2 = *RK++ ^ \
+ ( (uint32_t) RSb[ ( Y2 ) & 0xFF ] ) ^
+ ( (uint32_t) RSb[ ( Y1 >> 8 ) & 0xFF ] << 8 ) ^
+ ( (uint32_t) RSb[ ( Y0 >> 16 ) & 0xFF ] << 16 ) ^
+ ( (uint32_t) RSb[ ( Y3 >> 24 ) & 0xFF ] << 24 );
+
+ X3 = *RK++ ^ \
+ ( (uint32_t) RSb[ ( Y3 ) & 0xFF ] ) ^
+ ( (uint32_t) RSb[ ( Y2 >> 8 ) & 0xFF ] << 8 ) ^
+ ( (uint32_t) RSb[ ( Y1 >> 16 ) & 0xFF ] << 16 ) ^
+ ( (uint32_t) RSb[ ( Y0 >> 24 ) & 0xFF ] << 24 );
+
+ PUT_UINT32_LE( X0, output, 0 );
+ PUT_UINT32_LE( X1, output, 4 );
+ PUT_UINT32_LE( X2, output, 8 );
+ PUT_UINT32_LE( X3, output, 12 );
+}
+#endif /* !MBEDTLS_AES_DECRYPT_ALT */
+
+/*
* AES-ECB block encryption/decryption
*/
int mbedtls_aes_crypt_ecb( mbedtls_aes_context *ctx,
@@ -711,9 +829,6 @@
const unsigned char input[16],
unsigned char output[16] )
{
- int i;
- uint32_t *RK, X0, X1, X2, X3, Y0, Y1, Y2, Y3;
-
#if defined(MBEDTLS_AESNI_C) && defined(MBEDTLS_HAVE_X86_64)
if( mbedtls_aesni_supports( MBEDTLS_AESNI_AES ) )
return( mbedtls_aesni_crypt_ecb( ctx, mode, input, output ) );
@@ -731,86 +846,10 @@
}
#endif
- RK = ctx->rk;
-
- GET_UINT32_LE( X0, input, 0 ); X0 ^= *RK++;
- GET_UINT32_LE( X1, input, 4 ); X1 ^= *RK++;
- GET_UINT32_LE( X2, input, 8 ); X2 ^= *RK++;
- GET_UINT32_LE( X3, input, 12 ); X3 ^= *RK++;
-
- if( mode == MBEDTLS_AES_DECRYPT )
- {
- for( i = ( ctx->nr >> 1 ) - 1; i > 0; i-- )
- {
- AES_RROUND( Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
- AES_RROUND( X0, X1, X2, X3, Y0, Y1, Y2, Y3 );
- }
-
- AES_RROUND( Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
-
- X0 = *RK++ ^ \
- ( (uint32_t) RSb[ ( Y0 ) & 0xFF ] ) ^
- ( (uint32_t) RSb[ ( Y3 >> 8 ) & 0xFF ] << 8 ) ^
- ( (uint32_t) RSb[ ( Y2 >> 16 ) & 0xFF ] << 16 ) ^
- ( (uint32_t) RSb[ ( Y1 >> 24 ) & 0xFF ] << 24 );
-
- X1 = *RK++ ^ \
- ( (uint32_t) RSb[ ( Y1 ) & 0xFF ] ) ^
- ( (uint32_t) RSb[ ( Y0 >> 8 ) & 0xFF ] << 8 ) ^
- ( (uint32_t) RSb[ ( Y3 >> 16 ) & 0xFF ] << 16 ) ^
- ( (uint32_t) RSb[ ( Y2 >> 24 ) & 0xFF ] << 24 );
-
- X2 = *RK++ ^ \
- ( (uint32_t) RSb[ ( Y2 ) & 0xFF ] ) ^
- ( (uint32_t) RSb[ ( Y1 >> 8 ) & 0xFF ] << 8 ) ^
- ( (uint32_t) RSb[ ( Y0 >> 16 ) & 0xFF ] << 16 ) ^
- ( (uint32_t) RSb[ ( Y3 >> 24 ) & 0xFF ] << 24 );
-
- X3 = *RK++ ^ \
- ( (uint32_t) RSb[ ( Y3 ) & 0xFF ] ) ^
- ( (uint32_t) RSb[ ( Y2 >> 8 ) & 0xFF ] << 8 ) ^
- ( (uint32_t) RSb[ ( Y1 >> 16 ) & 0xFF ] << 16 ) ^
- ( (uint32_t) RSb[ ( Y0 >> 24 ) & 0xFF ] << 24 );
- }
- else /* MBEDTLS_AES_ENCRYPT */
- {
- for( i = ( ctx->nr >> 1 ) - 1; i > 0; i-- )
- {
- AES_FROUND( Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
- AES_FROUND( X0, X1, X2, X3, Y0, Y1, Y2, Y3 );
- }
-
- AES_FROUND( Y0, Y1, Y2, Y3, X0, X1, X2, X3 );
-
- X0 = *RK++ ^ \
- ( (uint32_t) FSb[ ( Y0 ) & 0xFF ] ) ^
- ( (uint32_t) FSb[ ( Y1 >> 8 ) & 0xFF ] << 8 ) ^
- ( (uint32_t) FSb[ ( Y2 >> 16 ) & 0xFF ] << 16 ) ^
- ( (uint32_t) FSb[ ( Y3 >> 24 ) & 0xFF ] << 24 );
-
- X1 = *RK++ ^ \
- ( (uint32_t) FSb[ ( Y1 ) & 0xFF ] ) ^
- ( (uint32_t) FSb[ ( Y2 >> 8 ) & 0xFF ] << 8 ) ^
- ( (uint32_t) FSb[ ( Y3 >> 16 ) & 0xFF ] << 16 ) ^
- ( (uint32_t) FSb[ ( Y0 >> 24 ) & 0xFF ] << 24 );
-
- X2 = *RK++ ^ \
- ( (uint32_t) FSb[ ( Y2 ) & 0xFF ] ) ^
- ( (uint32_t) FSb[ ( Y3 >> 8 ) & 0xFF ] << 8 ) ^
- ( (uint32_t) FSb[ ( Y0 >> 16 ) & 0xFF ] << 16 ) ^
- ( (uint32_t) FSb[ ( Y1 >> 24 ) & 0xFF ] << 24 );
-
- X3 = *RK++ ^ \
- ( (uint32_t) FSb[ ( Y3 ) & 0xFF ] ) ^
- ( (uint32_t) FSb[ ( Y0 >> 8 ) & 0xFF ] << 8 ) ^
- ( (uint32_t) FSb[ ( Y1 >> 16 ) & 0xFF ] << 16 ) ^
- ( (uint32_t) FSb[ ( Y2 >> 24 ) & 0xFF ] << 24 );
- }
-
- PUT_UINT32_LE( X0, output, 0 );
- PUT_UINT32_LE( X1, output, 4 );
- PUT_UINT32_LE( X2, output, 8 );
- PUT_UINT32_LE( X3, output, 12 );
+ if( mode == MBEDTLS_AES_ENCRYPT )
+ mbedtls_aes_encrypt( ctx, input, output );
+ else
+ mbedtls_aes_decrypt( ctx, input, output );
return( 0 );
}