Add param validation for mbedtls_aes_crypt_xts()
diff --git a/include/mbedtls/aes.h b/include/mbedtls/aes.h
index 0f8934f..1bfa434 100644
--- a/include/mbedtls/aes.h
+++ b/include/mbedtls/aes.h
@@ -325,6 +325,7 @@
* returns #MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH.
*
* \param ctx The AES XTS context to use for AES XTS operations.
+ * It must be initialized and bound to a key.
* \param mode The AES operation: #MBEDTLS_AES_ENCRYPT or
* #MBEDTLS_AES_DECRYPT.
* \param length The length of a data unit in bytes. This can be any
diff --git a/library/aes.c b/library/aes.c
index 2da86c7..c15022b 100644
--- a/library/aes.c
+++ b/library/aes.c
@@ -1182,6 +1182,12 @@
unsigned char prev_tweak[16];
unsigned char tmp[16];
+ AES_VALIDATE_RET( ctx != NULL );
+ AES_VALIDATE_RET( mode == MBEDTLS_AES_ENCRYPT ||
+ mode == MBEDTLS_AES_DECRYPT );
+ AES_VALIDATE_RET( input != NULL );
+ AES_VALIDATE_RET( output != NULL );
+
/* Data units must be at least 16 bytes long. */
if( length < 16 )
return MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH;
diff --git a/tests/suites/test_suite_aes.function b/tests/suites/test_suite_aes.function
index d21a41d..bcffe37 100644
--- a/tests/suites/test_suite_aes.function
+++ b/tests/suites/test_suite_aes.function
@@ -194,8 +194,8 @@
void aes_crypt_xts_size( int size, int retval )
{
mbedtls_aes_xts_context ctx;
- const unsigned char *src = NULL;
- unsigned char *output = NULL;
+ const unsigned char src[16] = { 0 };
+ unsigned char output[16];
unsigned char data_unit[16];
size_t length = size;
@@ -203,10 +203,8 @@
memset( data_unit, 0x00, sizeof( data_unit ) );
- /* Note that this function will most likely crash on failure, as NULL
- * parameters will be used. In the passing case, the length check in
- * mbedtls_aes_crypt_xts() will prevent any accesses to parameters by
- * exiting the function early. */
+ /* Valid pointers are passed for builds with MBEDTLS_CHECK_PARAMS, as
+ * otherwise we wouldn't get to the size check we're interested in. */
TEST_ASSERT( mbedtls_aes_crypt_xts( &ctx, MBEDTLS_AES_ENCRYPT, length, data_unit, src, output ) == retval );
}
/* END_CASE */
@@ -445,6 +443,29 @@
MBEDTLS_AES_ENCRYPT, 16,
out, in, NULL ) );
#endif /* MBEDTLS_CIPHER_MODE_CBC */
+
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+ TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+ mbedtls_aes_crypt_xts( NULL,
+ MBEDTLS_AES_ENCRYPT, 16,
+ in, in, out ) );
+ TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+ mbedtls_aes_crypt_xts( &xts_ctx,
+ 42, 16,
+ in, in, out ) );
+ TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+ mbedtls_aes_crypt_xts( &xts_ctx,
+ MBEDTLS_AES_ENCRYPT, 16,
+ NULL, in, out ) );
+ TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+ mbedtls_aes_crypt_xts( &xts_ctx,
+ MBEDTLS_AES_ENCRYPT, 16,
+ in, NULL, out ) );
+ TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+ mbedtls_aes_crypt_xts( &xts_ctx,
+ MBEDTLS_AES_ENCRYPT, 16,
+ in, in, NULL ) );
+#endif /* MBEDTLS_CIPHER_MODE_XTS */
}
/* END_CASE */
@@ -452,6 +473,9 @@
void aes_misc_params( )
{
mbedtls_aes_context aes_ctx;
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+ mbedtls_aes_xts_context xts_ctx;
+#endif
const unsigned char in[16] = { 0 };
unsigned char out[16];
@@ -463,13 +487,25 @@
#if defined(MBEDTLS_CIPHER_MODE_CBC)
TEST_ASSERT( mbedtls_aes_crypt_cbc( &aes_ctx, MBEDTLS_AES_ENCRYPT,
- 15, out, in, out )
+ 15,
+ out, in, out )
== MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH );
TEST_ASSERT( mbedtls_aes_crypt_cbc( &aes_ctx, MBEDTLS_AES_ENCRYPT,
- 17, out, in, out )
+ 17,
+ out, in, out )
== MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH );
#endif
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+ TEST_ASSERT( mbedtls_aes_crypt_xts( &xts_ctx, MBEDTLS_AES_ENCRYPT,
+ 15,
+ in, in, out )
+ == MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH );
+ TEST_ASSERT( mbedtls_aes_crypt_xts( &xts_ctx, MBEDTLS_AES_ENCRYPT,
+ (1 << 24) + 1,
+ in, in, out )
+ == MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH );
+#endif
}
/* END_CASE */