Merge pull request #8901 from paul-elliott-arm/make_psa_global_data_safe

Make PSA global_data thread safe
diff --git a/include/mbedtls/threading.h b/include/mbedtls/threading.h
index b4e0502..d50d04e 100644
--- a/include/mbedtls/threading.h
+++ b/include/mbedtls/threading.h
@@ -112,6 +112,20 @@
  * psa_key_slot_state_transition(), psa_register_read(), psa_unregister_read(),
  * psa_key_slot_has_readers() and psa_wipe_key_slot(). */
 extern mbedtls_threading_mutex_t mbedtls_threading_key_slot_mutex;
+
+/*
+ * A mutex used to make the non-rng PSA global_data struct members thread safe.
+ *
+ * This mutex must be held when reading or writing to any of the PSA global_data
+ * structure members, other than the rng_state or rng struct. */
+extern mbedtls_threading_mutex_t mbedtls_threading_psa_globaldata_mutex;
+
+/*
+ * A mutex used to make the PSA global_data rng data thread safe.
+ *
+ * This mutex must be held when reading or writing to the PSA
+ * global_data rng_state or rng struct members. */
+extern mbedtls_threading_mutex_t mbedtls_threading_psa_rngdata_mutex;
 #endif
 
 #endif /* MBEDTLS_THREADING_C */
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index ec9d115..a0a002a 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -71,6 +71,7 @@
 #include "mbedtls/sha256.h"
 #include "mbedtls/sha512.h"
 #include "mbedtls/psa_util.h"
+#include "mbedtls/threading.h"
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_HKDF) ||          \
     defined(MBEDTLS_PSA_BUILTIN_ALG_HKDF_EXTRACT) ||  \
@@ -92,30 +93,93 @@
 #define RNG_INITIALIZED 1
 #define RNG_SEEDED 2
 
+/* IDs for PSA crypto subsystems. Numbering starts at 1 to catch potentially
+ * uninitialized variables passed as arguments. */
+typedef enum {
+    PSA_CRYPTO_SUBSYSTEM_DRIVER_WRAPPERS = 1,
+    PSA_CRYPTO_SUBSYSTEM_KEY_SLOTS,
+    PSA_CRYPTO_SUBSYSTEM_RNG, /* no flag bit: tracked via global_data.rng_state */
+    PSA_CRYPTO_SUBSYSTEM_TRANSACTION,
+} mbedtls_psa_crypto_subsystem;
+
+/* Initialization flags for global_data::initialized (one bit per subsystem). */
+#define PSA_CRYPTO_SUBSYSTEM_DRIVER_WRAPPERS_INITIALIZED    0x01
+#define PSA_CRYPTO_SUBSYSTEM_KEY_SLOTS_INITIALIZED          0x02
+#define PSA_CRYPTO_SUBSYSTEM_TRANSACTION_INITIALIZED        0x04
+/* Value of global_data.initialized once every flag above is set. */
+#define PSA_CRYPTO_SUBSYSTEM_ALL_INITIALISED                ( \
+        PSA_CRYPTO_SUBSYSTEM_DRIVER_WRAPPERS_INITIALIZED | \
+        PSA_CRYPTO_SUBSYSTEM_KEY_SLOTS_INITIALIZED | \
+        PSA_CRYPTO_SUBSYSTEM_TRANSACTION_INITIALIZED)
+
 typedef struct {
     uint8_t initialized;
     uint8_t rng_state;
-    uint8_t drivers_initialized;
     mbedtls_psa_random_context_t rng;
 } psa_global_data_t;
 
 static psa_global_data_t global_data;
 
+static uint8_t psa_get_initialized(void)
+{
+    uint8_t initialized;
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_lock(&mbedtls_threading_psa_rngdata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+    initialized = global_data.rng_state == RNG_SEEDED;
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_unlock(&mbedtls_threading_psa_rngdata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_lock(&mbedtls_threading_psa_globaldata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+    initialized =
+        (initialized && (global_data.initialized == PSA_CRYPTO_SUBSYSTEM_ALL_INITIALISED));
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_unlock(&mbedtls_threading_psa_globaldata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+    return initialized;
+}
+/* Return 1 if the driver wrappers flag is set in global_data.initialized, otherwise 0. */
+static uint8_t psa_get_drivers_initialized(void)
+{
+    uint8_t initialized;
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_lock(&mbedtls_threading_psa_globaldata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+    initialized = (global_data.initialized & PSA_CRYPTO_SUBSYSTEM_DRIVER_WRAPPERS_INITIALIZED) != 0;
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_unlock(&mbedtls_threading_psa_globaldata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+    return initialized;
+}
+
 #define GUARD_MODULE_INITIALIZED        \
-    if (global_data.initialized == 0)  \
+    if (psa_get_initialized() == 0)     \
     return PSA_ERROR_BAD_STATE;
 
 int psa_can_do_hash(psa_algorithm_t hash_alg)
 {
     (void) hash_alg;
-    return global_data.drivers_initialized;
+    return psa_get_drivers_initialized();
 }
 
 int psa_can_do_cipher(psa_key_type_t key_type, psa_algorithm_t cipher_alg)
 {
     (void) key_type;
     (void) cipher_alg;
-    return global_data.drivers_initialized;
+    return psa_get_drivers_initialized();
 }
 
 
@@ -7082,6 +7146,9 @@
 #endif
 
 /** Initialize the PSA random generator.
+ *
+ *  Note: the mbedtls_threading_psa_rngdata_mutex should be held when calling
+ *  this function if mutexes are enabled.
  */
 static void mbedtls_psa_random_init(mbedtls_psa_random_context_t *rng)
 {
@@ -7114,6 +7181,9 @@
 }
 
 /** Deinitialize the PSA random generator.
+ *
+ *  Note: the mbedtls_threading_psa_rngdata_mutex should be held when calling
+ *  this function if mutexes are enabled.
  */
 static void mbedtls_psa_random_free(mbedtls_psa_random_context_t *rng)
 {
@@ -7189,7 +7259,7 @@
 psa_status_t mbedtls_psa_inject_entropy(const uint8_t *seed,
                                         size_t seed_size)
 {
-    if (global_data.initialized) {
+    if (psa_get_initialized()) {
         return PSA_ERROR_NOT_PERMITTED;
     }
 
@@ -7431,28 +7501,77 @@
     void (* entropy_init)(mbedtls_entropy_context *ctx),
     void (* entropy_free)(mbedtls_entropy_context *ctx))
 {
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_lock(&mbedtls_threading_psa_rngdata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
     if (global_data.rng_state != RNG_NOT_INITIALIZED) {
-        return PSA_ERROR_BAD_STATE;
+        status = PSA_ERROR_BAD_STATE;
+    } else {
+        global_data.rng.entropy_init = entropy_init;
+        global_data.rng.entropy_free = entropy_free;
+        status = PSA_SUCCESS;
     }
-    global_data.rng.entropy_init = entropy_init;
-    global_data.rng.entropy_free = entropy_free;
-    return PSA_SUCCESS;
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_unlock(&mbedtls_threading_psa_rngdata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+    return status;
 }
 #endif /* !defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG) */
 
 void mbedtls_psa_crypto_free(void)
 {
-    psa_wipe_all_key_slots();
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_lock(&mbedtls_threading_psa_globaldata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+    /* Nothing to do to free transaction. */
+    if (global_data.initialized & PSA_CRYPTO_SUBSYSTEM_TRANSACTION_INITIALIZED) {
+        global_data.initialized &= ~PSA_CRYPTO_SUBSYSTEM_TRANSACTION_INITIALIZED;
+    }
+
+    if (global_data.initialized & PSA_CRYPTO_SUBSYSTEM_KEY_SLOTS_INITIALIZED) {
+        psa_wipe_all_key_slots();
+        global_data.initialized &= ~PSA_CRYPTO_SUBSYSTEM_KEY_SLOTS_INITIALIZED;
+    }
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_unlock(&mbedtls_threading_psa_globaldata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_lock(&mbedtls_threading_psa_rngdata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
     if (global_data.rng_state != RNG_NOT_INITIALIZED) {
         mbedtls_psa_random_free(&global_data.rng);
     }
-    /* Wipe all remaining data, including configuration.
-     * In particular, this sets all state indicator to the value
-     * indicating "uninitialized". */
-    mbedtls_platform_zeroize(&global_data, sizeof(global_data));
+    global_data.rng_state = RNG_NOT_INITIALIZED;
+    mbedtls_platform_zeroize(&global_data.rng, sizeof(global_data.rng));
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_unlock(&mbedtls_threading_psa_rngdata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_lock(&mbedtls_threading_psa_globaldata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
 
     /* Terminate drivers */
-    psa_driver_wrapper_free();
+    if (global_data.initialized & PSA_CRYPTO_SUBSYSTEM_DRIVER_WRAPPERS_INITIALIZED) {
+        psa_driver_wrapper_free();
+        global_data.initialized &= ~PSA_CRYPTO_SUBSYSTEM_DRIVER_WRAPPERS_INITIALIZED;
+    }
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_unlock(&mbedtls_threading_psa_globaldata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
 }
 
 #if defined(PSA_CRYPTO_STORAGE_HAS_TRANSACTIONS)
@@ -7480,57 +7599,171 @@
 }
 #endif /* PSA_CRYPTO_STORAGE_HAS_TRANSACTIONS */
 
+static psa_status_t mbedtls_psa_crypto_init_subsystem(mbedtls_psa_crypto_subsystem subsystem)
+{
+    psa_status_t status = PSA_SUCCESS;
+    uint8_t driver_wrappers_initialized = 0;
+
+    switch (subsystem) {
+        case PSA_CRYPTO_SUBSYSTEM_DRIVER_WRAPPERS:
+
+#if defined(MBEDTLS_THREADING_C)
+            PSA_THREADING_CHK_GOTO_EXIT(mbedtls_mutex_lock(&mbedtls_threading_psa_globaldata_mutex));
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+            if (!(global_data.initialized & PSA_CRYPTO_SUBSYSTEM_DRIVER_WRAPPERS_INITIALIZED)) {
+                /* Init drivers */
+                status = psa_driver_wrapper_init();
+
+                /* Drivers need shutdown regardless of startup errors. */
+                global_data.initialized |= PSA_CRYPTO_SUBSYSTEM_DRIVER_WRAPPERS_INITIALIZED;
+
+
+            }
+#if defined(MBEDTLS_THREADING_C)
+            PSA_THREADING_CHK_GOTO_EXIT(mbedtls_mutex_unlock(
+                                            &mbedtls_threading_psa_globaldata_mutex));
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+            break;
+
+        case PSA_CRYPTO_SUBSYSTEM_KEY_SLOTS:
+
+#if defined(MBEDTLS_THREADING_C)
+            PSA_THREADING_CHK_GOTO_EXIT(mbedtls_mutex_lock(&mbedtls_threading_psa_globaldata_mutex));
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+            if (!(global_data.initialized & PSA_CRYPTO_SUBSYSTEM_KEY_SLOTS_INITIALIZED)) {
+                status = psa_initialize_key_slots();
+
+                /* Need to wipe keys even if initialization fails. */
+                global_data.initialized |= PSA_CRYPTO_SUBSYSTEM_KEY_SLOTS_INITIALIZED;
+
+            }
+#if defined(MBEDTLS_THREADING_C)
+            PSA_THREADING_CHK_GOTO_EXIT(mbedtls_mutex_unlock(
+                                            &mbedtls_threading_psa_globaldata_mutex));
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+            break;
+
+        case PSA_CRYPTO_SUBSYSTEM_RNG:
+
+#if defined(MBEDTLS_THREADING_C)
+            PSA_THREADING_CHK_GOTO_EXIT(mbedtls_mutex_lock(&mbedtls_threading_psa_globaldata_mutex));
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+            driver_wrappers_initialized =
+                (global_data.initialized & PSA_CRYPTO_SUBSYSTEM_DRIVER_WRAPPERS_INITIALIZED);
+
+#if defined(MBEDTLS_THREADING_C)
+            PSA_THREADING_CHK_GOTO_EXIT(mbedtls_mutex_unlock(
+                                            &mbedtls_threading_psa_globaldata_mutex));
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+            /* Need to use separate mutex here, as initialisation can require
+             * testing of init flags, which requires locking the global data
+             * mutex. */
+#if defined(MBEDTLS_THREADING_C)
+            PSA_THREADING_CHK_GOTO_EXIT(mbedtls_mutex_lock(&mbedtls_threading_psa_rngdata_mutex));
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+            /* Initialize and seed the random generator. */
+            if (global_data.rng_state == RNG_NOT_INITIALIZED && driver_wrappers_initialized) {
+                mbedtls_psa_random_init(&global_data.rng);
+                global_data.rng_state = RNG_INITIALIZED;
+
+                status = mbedtls_psa_random_seed(&global_data.rng);
+                if (status == PSA_SUCCESS) {
+                    global_data.rng_state = RNG_SEEDED;
+                }
+            }
+
+#if defined(MBEDTLS_THREADING_C)
+            PSA_THREADING_CHK_GOTO_EXIT(mbedtls_mutex_unlock(
+                                            &mbedtls_threading_psa_rngdata_mutex));
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+            break;
+
+        case PSA_CRYPTO_SUBSYSTEM_TRANSACTION:
+
+#if defined(MBEDTLS_THREADING_C)
+            PSA_THREADING_CHK_GOTO_EXIT(mbedtls_mutex_lock(&mbedtls_threading_psa_globaldata_mutex));
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+            if (!(global_data.initialized & PSA_CRYPTO_SUBSYSTEM_TRANSACTION_INITIALIZED)) {
+#if defined(PSA_CRYPTO_STORAGE_HAS_TRANSACTIONS)
+                status = psa_crypto_load_transaction();
+                if (status == PSA_SUCCESS) {
+                    status = psa_crypto_recover_transaction(&psa_crypto_transaction);
+                    if (status == PSA_SUCCESS) {
+                        global_data.initialized |= PSA_CRYPTO_SUBSYSTEM_TRANSACTION_INITIALIZED;
+                    }
+                    status = psa_crypto_stop_transaction();
+                } else if (status == PSA_ERROR_DOES_NOT_EXIST) {
+                    /* There's no transaction to complete. It's all good. */
+                    global_data.initialized |= PSA_CRYPTO_SUBSYSTEM_TRANSACTION_INITIALIZED;
+                    status = PSA_SUCCESS;
+                }
+#else /* defined(PSA_CRYPTO_STORAGE_HAS_TRANSACTIONS) */
+                global_data.initialized |= PSA_CRYPTO_SUBSYSTEM_TRANSACTION_INITIALIZED;
+                status = PSA_SUCCESS;
+#endif /* defined(PSA_CRYPTO_STORAGE_HAS_TRANSACTIONS) */
+            }
+
+#if defined(MBEDTLS_THREADING_C)
+            PSA_THREADING_CHK_GOTO_EXIT(mbedtls_mutex_unlock(
+                                            &mbedtls_threading_psa_globaldata_mutex));
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+            break;
+
+        default:
+            status = PSA_ERROR_CORRUPTION_DETECTED;
+    }
+
+    /* Exit label only required when using threading macros. */
+#if defined(MBEDTLS_THREADING_C)
+exit:
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+    return status;
+}
+
 psa_status_t psa_crypto_init(void)
 {
     psa_status_t status;
 
-    /* Double initialization is explicitly allowed. */
-    if (global_data.initialized != 0) {
+    /* Double initialization is explicitly allowed. Early out if everything is
+     * done. */
+    if (psa_get_initialized()) {
         return PSA_SUCCESS;
     }
 
-    /* Init drivers */
-    status = psa_driver_wrapper_init();
-    if (status != PSA_SUCCESS) {
-        goto exit;
-    }
-    global_data.drivers_initialized = 1;
-
-    status = psa_initialize_key_slots();
+    status = mbedtls_psa_crypto_init_subsystem(PSA_CRYPTO_SUBSYSTEM_DRIVER_WRAPPERS);
     if (status != PSA_SUCCESS) {
         goto exit;
     }
 
-    /* Initialize and seed the random generator. */
-    mbedtls_psa_random_init(&global_data.rng);
-    global_data.rng_state = RNG_INITIALIZED;
-    status = mbedtls_psa_random_seed(&global_data.rng);
+    status = mbedtls_psa_crypto_init_subsystem(PSA_CRYPTO_SUBSYSTEM_KEY_SLOTS);
     if (status != PSA_SUCCESS) {
         goto exit;
     }
-    global_data.rng_state = RNG_SEEDED;
 
-#if defined(PSA_CRYPTO_STORAGE_HAS_TRANSACTIONS)
-    status = psa_crypto_load_transaction();
-    if (status == PSA_SUCCESS) {
-        status = psa_crypto_recover_transaction(&psa_crypto_transaction);
-        if (status != PSA_SUCCESS) {
-            goto exit;
-        }
-        status = psa_crypto_stop_transaction();
-    } else if (status == PSA_ERROR_DOES_NOT_EXIST) {
-        /* There's no transaction to complete. It's all good. */
-        status = PSA_SUCCESS;
+    status = mbedtls_psa_crypto_init_subsystem(PSA_CRYPTO_SUBSYSTEM_RNG);
+    if (status != PSA_SUCCESS) {
+        goto exit;
     }
-#endif /* PSA_CRYPTO_STORAGE_HAS_TRANSACTIONS */
 
-    /* All done. */
-    global_data.initialized = 1;
+    status = mbedtls_psa_crypto_init_subsystem(PSA_CRYPTO_SUBSYSTEM_TRANSACTION);
 
 exit:
+
     if (status != PSA_SUCCESS) {
         mbedtls_psa_crypto_free();
     }
+
     return status;
 }
 
diff --git a/library/psa_crypto_slot_management.c b/library/psa_crypto_slot_management.c
index 5dee32f..b184ed0 100644
--- a/library/psa_crypto_slot_management.c
+++ b/library/psa_crypto_slot_management.c
@@ -34,6 +34,23 @@
 
 static psa_global_data_t global_data;
 
+static uint8_t psa_get_key_slots_initialized(void)
+{
+    uint8_t initialized;
+    /* Read the key slots flag under the global data mutex when threading is enabled. */
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_lock(&mbedtls_threading_psa_globaldata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+    initialized = global_data.key_slots_initialized;
+
+#if defined(MBEDTLS_THREADING_C)
+    mbedtls_mutex_unlock(&mbedtls_threading_psa_globaldata_mutex);
+#endif /* defined(MBEDTLS_THREADING_C) */
+
+    return initialized;
+}
+
 int psa_is_valid_key_id(mbedtls_svc_key_id_t key, int vendor_ok)
 {
     psa_key_id_t key_id = MBEDTLS_SVC_KEY_ID_GET_KEY_ID(key);
@@ -136,7 +153,9 @@
 {
     /* Nothing to do: program startup and psa_wipe_all_key_slots() both
      * guarantee that the key slots are initialized to all-zero, which
-     * means that all the key slots are in a valid, empty state. */
+     * means that all the key slots are in a valid, empty state. The global
+     * data mutex is already held when calling this function, so there is no
+     * need to lock it here in order to set the flag. */
     global_data.key_slots_initialized = 1;
     return PSA_SUCCESS;
 }
@@ -151,6 +170,7 @@
         slot->state = PSA_SLOT_PENDING_DELETION;
         (void) psa_wipe_key_slot(slot);
     }
+    /* The global data mutex is already held when calling this function. */
     global_data.key_slots_initialized = 0;
 }
 
@@ -161,7 +181,7 @@
     size_t slot_idx;
     psa_key_slot_t *selected_slot, *unused_persistent_key_slot;
 
-    if (!global_data.key_slots_initialized) {
+    if (!psa_get_key_slots_initialized()) {
         status = PSA_ERROR_BAD_STATE;
         goto error;
     }
@@ -344,7 +364,7 @@
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
 
     *p_slot = NULL;
-    if (!global_data.key_slots_initialized) {
+    if (!psa_get_key_slots_initialized()) {
         return PSA_ERROR_BAD_STATE;
     }
 
diff --git a/library/threading.c b/library/threading.c
index c28290f..85db243 100644
--- a/library/threading.c
+++ b/library/threading.c
@@ -150,6 +150,8 @@
 #endif
 #if defined(MBEDTLS_PSA_CRYPTO_C)
     mbedtls_mutex_init(&mbedtls_threading_key_slot_mutex);
+    mbedtls_mutex_init(&mbedtls_threading_psa_globaldata_mutex);
+    mbedtls_mutex_init(&mbedtls_threading_psa_rngdata_mutex);
 #endif
 }
 
@@ -166,6 +168,8 @@
 #endif
 #if defined(MBEDTLS_PSA_CRYPTO_C)
     mbedtls_mutex_free(&mbedtls_threading_key_slot_mutex);
+    mbedtls_mutex_free(&mbedtls_threading_psa_globaldata_mutex);
+    mbedtls_mutex_free(&mbedtls_threading_psa_rngdata_mutex);
 #endif
 }
 #endif /* MBEDTLS_THREADING_ALT */
@@ -184,6 +188,8 @@
 #endif
 #if defined(MBEDTLS_PSA_CRYPTO_C)
 mbedtls_threading_mutex_t mbedtls_threading_key_slot_mutex MUTEX_INIT;
+mbedtls_threading_mutex_t mbedtls_threading_psa_globaldata_mutex MUTEX_INIT;
+mbedtls_threading_mutex_t mbedtls_threading_psa_rngdata_mutex MUTEX_INIT;
 #endif
 
 #endif /* MBEDTLS_THREADING_C */
diff --git a/tests/suites/test_suite_psa_crypto_init.data b/tests/suites/test_suite_psa_crypto_init.data
index 8c5b41d..147d03f 100644
--- a/tests/suites/test_suite_psa_crypto_init.data
+++ b/tests/suites/test_suite_psa_crypto_init.data
@@ -10,6 +10,9 @@
 PSA deinit twice
 deinit_without_init:1
 
+PSA threaded init checks
+psa_threaded_init:100
+
 No random without init
 validate_module_init_generate_random:0
 
diff --git a/tests/suites/test_suite_psa_crypto_init.function b/tests/suites/test_suite_psa_crypto_init.function
index 7a43432..9ff33a6 100644
--- a/tests/suites/test_suite_psa_crypto_init.function
+++ b/tests/suites/test_suite_psa_crypto_init.function
@@ -1,6 +1,7 @@
 /* BEGIN_HEADER */
 #include <stdint.h>
 
+#include "psa_crypto_core.h"
 /* Some tests in this module configure entropy sources. */
 #include "psa_crypto_invasive.h"
 
@@ -112,6 +113,59 @@
 
 #endif /* !defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG) */
 
+#if defined MBEDTLS_THREADING_PTHREAD
+
+typedef struct {
+    int do_init; /* non-zero: call psa_crypto_init() first and assert on the results */
+} thread_psa_init_ctx_t;
+/* Thread entry point: optionally initialise PSA, then exercise the mutex-protected state. */
+static void *thread_psa_init_function(void *ctx)
+{
+    thread_psa_init_ctx_t *init_context = (thread_psa_init_ctx_t *) ctx;
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    uint8_t random[10] = { 0 };
+
+    if (init_context->do_init) {
+        PSA_ASSERT(psa_crypto_init());
+    }
+
+    /* If this is a test only thread, then we can assume PSA is being started
+     * up on another thread and thus we cannot know whether the following tests
+     * will be successful or not. These checks are still useful even without
+     * checking the return codes, however, as they may show up race conditions
+     * on the flags they check under TSAN. */
+
+    /* Test querying whether the drivers are initialised. */
+    int can_do = psa_can_do_hash(PSA_ALG_NONE);
+
+    if (init_context->do_init) {
+        TEST_ASSERT(can_do == 1);
+    }
+
+#if !defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG)
+
+    /* Test getting global_data.rng_state. */
+    status = mbedtls_psa_crypto_configure_entropy_sources(NULL, NULL);
+
+    if (init_context->do_init) {
+        /* Bad state due to entropy sources already being set up in
+         * psa_crypto_init() */
+        TEST_EQUAL(status, PSA_ERROR_BAD_STATE);
+    }
+#endif /* !defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG) */
+
+    /* Test using the PSA RNG only if we know PSA is up and running. */
+    if (init_context->do_init) {
+        status = psa_generate_random(random, sizeof(random));
+
+        TEST_EQUAL(status, PSA_SUCCESS);
+    }
+
+exit:
+    return NULL;
+}
+#endif /* defined MBEDTLS_THREADING_PTHREAD */
+
 /* END_HEADER */
 
 /* BEGIN_DEPENDENCIES
@@ -154,6 +208,67 @@
 }
 /* END_CASE */
 
+/* BEGIN_CASE depends_on:MBEDTLS_THREADING_PTHREAD */
+void psa_threaded_init(int arg_thread_count)
+{
+    thread_psa_init_ctx_t init_context;
+    thread_psa_init_ctx_t init_context_2;
+    /* init_context has do_init == 1, init_context_2 has do_init == 0 (both set below). */
+    size_t thread_count = (size_t) arg_thread_count;
+    mbedtls_test_thread_t *threads = NULL;
+
+    TEST_CALLOC(threads, sizeof(mbedtls_test_thread_t) * thread_count);
+    /* NOTE(review): TEST_CALLOC may already scale by the element size - confirm the length. */
+    init_context.do_init = 1;
+
+    /* Test initialising PSA and testing certain protected globals on multiple
+     * threads. */
+    for (size_t i = 0; i < thread_count; i++) {
+        TEST_EQUAL(
+            mbedtls_test_thread_create(&threads[i],
+                                       thread_psa_init_function,
+                                       (void *) &init_context),
+            0);
+    }
+
+    for (size_t i = 0; i < thread_count; i++) {
+        TEST_EQUAL(mbedtls_test_thread_join(&threads[i]), 0);
+    }
+
+    PSA_DONE();
+
+    init_context_2.do_init = 0;
+
+    /* Test initialising PSA whilst also testing flags on other threads. */
+    for (size_t i = 0; i < thread_count; i++) {
+
+        if (i & 1) {
+
+            TEST_EQUAL(
+                mbedtls_test_thread_create(&threads[i],
+                                           thread_psa_init_function,
+                                           (void *) &init_context),
+                0);
+        } else {
+            TEST_EQUAL(
+                mbedtls_test_thread_create(&threads[i],
+                                           thread_psa_init_function,
+                                           (void *) &init_context_2),
+                0);
+        }
+    }
+
+    for (size_t i = 0; i < thread_count; i++) {
+        TEST_EQUAL(mbedtls_test_thread_join(&threads[i]), 0);
+    }
+exit:
+    /* NOTE(review): a failed TEST_EQUAL presumably jumps here with threads still unjoined. */
+    PSA_DONE();
+
+    mbedtls_free(threads);
+}
+/* END_CASE */
+
 /* BEGIN_CASE */
 void validate_module_init_generate_random(int count)
 {