Add param validation for mbedtls_aes_crypt_xts()
diff --git a/include/mbedtls/aes.h b/include/mbedtls/aes.h
index 0f8934f..1bfa434 100644
--- a/include/mbedtls/aes.h
+++ b/include/mbedtls/aes.h
@@ -325,6 +325,7 @@
  *             returns #MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH.
  *
  * \param ctx          The AES XTS context to use for AES XTS operations.
+ *                     It must be initialized and bound to a key.
  * \param mode         The AES operation: #MBEDTLS_AES_ENCRYPT or
  *                     #MBEDTLS_AES_DECRYPT.
  * \param length       The length of a data unit in bytes. This can be any
diff --git a/library/aes.c b/library/aes.c
index 2da86c7..c15022b 100644
--- a/library/aes.c
+++ b/library/aes.c
@@ -1182,6 +1182,12 @@
     unsigned char prev_tweak[16];
     unsigned char tmp[16];
 
+    AES_VALIDATE_RET( ctx != NULL );
+    AES_VALIDATE_RET( mode == MBEDTLS_AES_ENCRYPT ||
+                      mode == MBEDTLS_AES_DECRYPT );
+    AES_VALIDATE_RET( input != NULL );
+    AES_VALIDATE_RET( output != NULL );
+
     /* Data units must be at least 16 bytes long. */
     if( length < 16 )
         return MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH;
diff --git a/tests/suites/test_suite_aes.function b/tests/suites/test_suite_aes.function
index d21a41d..bcffe37 100644
--- a/tests/suites/test_suite_aes.function
+++ b/tests/suites/test_suite_aes.function
@@ -194,8 +194,8 @@
 void aes_crypt_xts_size( int size, int retval )
 {
     mbedtls_aes_xts_context ctx;
-    const unsigned char *src = NULL;
-    unsigned char *output = NULL;
+    const unsigned char src[16] = { 0 };
+    unsigned char output[16];
     unsigned char data_unit[16];
     size_t length = size;
 
@@ -203,10 +203,8 @@
     memset( data_unit, 0x00, sizeof( data_unit ) );
 
 
-    /* Note that this function will most likely crash on failure, as NULL
-     * parameters will be used. In the passing case, the length check in
-     * mbedtls_aes_crypt_xts() will prevent any accesses to parameters by
-     * exiting the function early. */
+    /* Valid pointers are passed for builds with MBEDTLS_CHECK_PARAMS, as
+     * otherwise we wouldn't get to the size check we're interested in. */
     TEST_ASSERT( mbedtls_aes_crypt_xts( &ctx, MBEDTLS_AES_ENCRYPT, length, data_unit, src, output ) == retval );
 }
 /* END_CASE */
@@ -445,6 +443,29 @@
                                                    MBEDTLS_AES_ENCRYPT, 16,
                                                    out, in, NULL ) );
 #endif /* MBEDTLS_CIPHER_MODE_CBC */
+
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+    TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+                            mbedtls_aes_crypt_xts( NULL,
+                                                   MBEDTLS_AES_ENCRYPT, 16,
+                                                   in, in, out ) );
+    TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+                            mbedtls_aes_crypt_xts( &xts_ctx,
+                                                   42, 16,
+                                                   in, in, out ) );
+    TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+                            mbedtls_aes_crypt_xts( &xts_ctx,
+                                                   MBEDTLS_AES_ENCRYPT, 16,
+                                                   NULL, in, out ) );
+    TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+                            mbedtls_aes_crypt_xts( &xts_ctx,
+                                                   MBEDTLS_AES_ENCRYPT, 16,
+                                                   in, NULL, out ) );
+    TEST_INVALID_PARAM_RET( MBEDTLS_ERR_AES_BAD_INPUT_DATA,
+                            mbedtls_aes_crypt_xts( &xts_ctx,
+                                                   MBEDTLS_AES_ENCRYPT, 16,
+                                                   in, in, NULL ) );
+#endif /* MBEDTLS_CIPHER_MODE_XTS */
 }
 /* END_CASE */
 
@@ -452,6 +473,9 @@
 void aes_misc_params( )
 {
     mbedtls_aes_context aes_ctx;
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+    mbedtls_aes_xts_context xts_ctx;
+#endif
     const unsigned char in[16] = { 0 };
     unsigned char out[16];
 
@@ -463,13 +487,25 @@
 
 #if defined(MBEDTLS_CIPHER_MODE_CBC)
     TEST_ASSERT( mbedtls_aes_crypt_cbc( &aes_ctx, MBEDTLS_AES_ENCRYPT,
-                                        15, out, in, out )
+                                        15,
+                                        out, in, out )
                  == MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH );
     TEST_ASSERT( mbedtls_aes_crypt_cbc( &aes_ctx, MBEDTLS_AES_ENCRYPT,
-                                        17, out, in, out )
+                                        17,
+                                        out, in, out )
                  == MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH );
 #endif
 
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+    TEST_ASSERT( mbedtls_aes_crypt_xts( &xts_ctx, MBEDTLS_AES_ENCRYPT,
+                                        15,
+                                        in, in, out )
+                 == MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH );
+    TEST_ASSERT( mbedtls_aes_crypt_xts( &xts_ctx, MBEDTLS_AES_ENCRYPT,
+                                        (1 << 24) + 1,
+                                        in, in, out )
+                 == MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH );
+#endif
 }
 /* END_CASE */