Merge pull request #8534 from paul-elliott-arm/fix_mutex_abstraction

Make mutex abstraction and tests thread safe
diff --git a/include/mbedtls/threading.h b/include/mbedtls/threading.h
index ed16a23..b504233 100644
--- a/include/mbedtls/threading.h
+++ b/include/mbedtls/threading.h
@@ -28,10 +28,14 @@
 #include <pthread.h>
 typedef struct mbedtls_threading_mutex_t {
     pthread_mutex_t MBEDTLS_PRIVATE(mutex);
-    /* is_valid is 0 after a failed init or a free, and nonzero after a
-     * successful init. This field is not considered part of the public
-     * API of Mbed TLS and may change without notice. */
-    char MBEDTLS_PRIVATE(is_valid);
+
+    /* WARNING - state should only be accessed when holding the mutex lock in
+     * tests/src/threading_helpers.c, otherwise corruption can occur.
+     * state will be 0 after a failed init or a free, and nonzero after a
+     * successful init. This field is for testing only and thus not considered
+     * part of the public API of Mbed TLS and may change without notice. */
+    char MBEDTLS_PRIVATE(state);
+
 } mbedtls_threading_mutex_t;
 #endif
 
diff --git a/library/threading.c b/library/threading.c
index 52fe8fc..873b507 100644
--- a/library/threading.c
+++ b/library/threading.c
@@ -56,28 +56,27 @@
         return;
     }
 
-    /* A nonzero value of is_valid indicates a successfully initialized
-     * mutex. This is a workaround for not being able to return an error
-     * code for this function. The lock/unlock functions return an error
-     * if is_valid is nonzero. The Mbed TLS unit test code uses this field
-     * to distinguish more states of the mutex; see
-     * tests/src/threading_helpers for details. */
-    mutex->is_valid = pthread_mutex_init(&mutex->mutex, NULL) == 0;
+    /* One problem here is that calling lock on a pthread mutex without first
+     * having initialised it is undefined behaviour. Obviously we cannot check
+     * this here in a thread-safe manner without a significant performance
+     * hit, so state transitions are checked in tests only via the state
+     * variable. Please make sure any new mutex that gets added is exercised in
+     * tests; see tests/src/threading_helpers.c for more details. */
+    (void) pthread_mutex_init(&mutex->mutex, NULL);
 }
 
 static void threading_mutex_free_pthread(mbedtls_threading_mutex_t *mutex)
 {
-    if (mutex == NULL || !mutex->is_valid) {
+    if (mutex == NULL) {
         return;
     }
 
     (void) pthread_mutex_destroy(&mutex->mutex);
-    mutex->is_valid = 0;
 }
 
 static int threading_mutex_lock_pthread(mbedtls_threading_mutex_t *mutex)
 {
-    if (mutex == NULL || !mutex->is_valid) {
+    if (mutex == NULL) {
         return MBEDTLS_ERR_THREADING_BAD_INPUT_DATA;
     }
 
@@ -90,7 +89,7 @@
 
 static int threading_mutex_unlock_pthread(mbedtls_threading_mutex_t *mutex)
 {
-    if (mutex == NULL || !mutex->is_valid) {
+    if (mutex == NULL) {
         return MBEDTLS_ERR_THREADING_BAD_INPUT_DATA;
     }
 
diff --git a/programs/ssl/ssl_test_lib.c b/programs/ssl/ssl_test_lib.c
index 6e0c615..b49dd67 100644
--- a/programs/ssl/ssl_test_lib.c
+++ b/programs/ssl/ssl_test_lib.c
@@ -435,6 +435,9 @@
 
 void test_hooks_free(void)
 {
+#if defined(MBEDTLS_TEST_MUTEX_USAGE)
+    mbedtls_test_mutex_usage_end();
+#endif
 }
 
 #endif /* MBEDTLS_TEST_HOOKS */
diff --git a/tests/include/test/helpers.h b/tests/include/test/helpers.h
index cf791e2..7c962a2 100644
--- a/tests/include/test/helpers.h
+++ b/tests/include/test/helpers.h
@@ -255,10 +255,18 @@
 #endif
 
 #if defined(MBEDTLS_TEST_MUTEX_USAGE)
-/** Permanently activate the mutex usage verification framework. See
- * threading_helpers.c for information. */
+/**
+ *  Activate the mutex usage verification framework. See threading_helpers.c for
+ *  information.
+ */
 void mbedtls_test_mutex_usage_init(void);
 
+/**
+ *  Deactivate the mutex usage verification framework. See threading_helpers.c
+ *  for information.
+ */
+void mbedtls_test_mutex_usage_end(void);
+
 /** Call this function after executing a test case to check for mutex usage
  * errors. */
 void mbedtls_test_mutex_usage_check(void);
diff --git a/tests/src/threading_helpers.c b/tests/src/threading_helpers.c
index 6f405b0..5fbf65b 100644
--- a/tests/src/threading_helpers.c
+++ b/tests/src/threading_helpers.c
@@ -58,15 +58,15 @@
  * indicate the exact location of the problematic call. To locate the error,
  * use a debugger and set a breakpoint on mbedtls_test_mutex_usage_error().
  */
-enum value_of_mutex_is_valid_field {
-    /* Potential values for the is_valid field of mbedtls_threading_mutex_t.
+enum value_of_mutex_state_field {
+    /* Potential values for the state field of mbedtls_threading_mutex_t.
      * Note that MUTEX_FREED must be 0 and MUTEX_IDLE must be 1 for
      * compatibility with threading_mutex_init_pthread() and
      * threading_mutex_free_pthread(). MUTEX_LOCKED could be any nonzero
      * value. */
-    MUTEX_FREED = 0, //!< Set by threading_mutex_free_pthread
-    MUTEX_IDLE = 1, //!< Set by threading_mutex_init_pthread and by our unlock
-    MUTEX_LOCKED = 2, //!< Set by our lock
+    MUTEX_FREED = 0, //!< Set by mbedtls_test_wrap_mutex_free
+    MUTEX_IDLE = 1, //!< Set by mbedtls_test_wrap_mutex_init and by mbedtls_test_wrap_mutex_unlock
+    MUTEX_LOCKED = 2, //!< Set by mbedtls_test_wrap_mutex_lock
 };
 
 typedef struct {
@@ -77,10 +77,30 @@
 } mutex_functions_t;
 static mutex_functions_t mutex_functions;
 
-/** The total number of calls to mbedtls_mutex_init(), minus the total number
- * of calls to mbedtls_mutex_free().
+/**
+ *  The mutex used to guard live_mutexes below and access to the state field
+ *  in every mbedtls_threading_mutex_t.
+ *  Note that we are not reporting any errors when locking and unlocking this
+ *  mutex. This is for a couple of reasons:
  *
- * Reset to 0 after each test case.
+ *  1. We have no real way of reporting any errors with this mutex - we cannot
+ *  report it back to the caller, as the failure was not that of the mutex
+ *  passed in. We could fail the test, but again this would indicate a problem
+ *  with the test code that did not exist.
+ *
+ *  2. Any failure to lock is unlikely to be intermittent, and will thus not
+ *  give false test results - the overall effect would be to disable the
+ *  mutex usage checks. This is not a situation that is likely to happen with
+ *  normal testing and we still have TSan to fall back on should this happen.
+ */
+mbedtls_threading_mutex_t mbedtls_test_mutex_mutex;
+
+/**
+ *  The total number of calls to mbedtls_mutex_init(), minus the total number
+ *  of calls to mbedtls_mutex_free().
+ *
+ *  Do not read or write without holding mbedtls_test_mutex_mutex (above). Reset
+ *  to 0 after each test case.
  */
 static int live_mutexes;
 
@@ -88,6 +108,7 @@
                                            const char *msg)
 {
     (void) mutex;
+
     if (mbedtls_test_info.mutex_usage_error == NULL) {
         mbedtls_test_info.mutex_usage_error = msg;
     }
@@ -101,76 +122,92 @@
 static void mbedtls_test_wrap_mutex_init(mbedtls_threading_mutex_t *mutex)
 {
     mutex_functions.init(mutex);
-    if (mutex->is_valid) {
+
+    if (mutex_functions.lock(&mbedtls_test_mutex_mutex) == 0) {
+        mutex->state = MUTEX_IDLE;
         ++live_mutexes;
+
+        mutex_functions.unlock(&mbedtls_test_mutex_mutex);
     }
 }
 
 static void mbedtls_test_wrap_mutex_free(mbedtls_threading_mutex_t *mutex)
 {
-    switch (mutex->is_valid) {
-        case MUTEX_FREED:
-            mbedtls_test_mutex_usage_error(mutex, "free without init or double free");
-            break;
-        case MUTEX_IDLE:
-            /* Do nothing. The underlying free function will reset is_valid
-             * to 0. */
-            break;
-        case MUTEX_LOCKED:
-            mbedtls_test_mutex_usage_error(mutex, "free without unlock");
-            break;
-        default:
-            mbedtls_test_mutex_usage_error(mutex, "corrupted state");
-            break;
-    }
-    if (mutex->is_valid) {
-        --live_mutexes;
+    if (mutex_functions.lock(&mbedtls_test_mutex_mutex) == 0) {
+
+        switch (mutex->state) {
+            case MUTEX_FREED:
+                mbedtls_test_mutex_usage_error(mutex, "free without init or double free");
+                break;
+            case MUTEX_IDLE:
+                mutex->state = MUTEX_FREED;
+                --live_mutexes;
+                break;
+            case MUTEX_LOCKED:
+                mbedtls_test_mutex_usage_error(mutex, "free without unlock");
+                break;
+            default:
+                mbedtls_test_mutex_usage_error(mutex, "corrupted state");
+                break;
+        }
+
+        mutex_functions.unlock(&mbedtls_test_mutex_mutex);
     }
     mutex_functions.free(mutex);
 }
 
 static int mbedtls_test_wrap_mutex_lock(mbedtls_threading_mutex_t *mutex)
 {
+    /* Lock the passed-in mutex first, so that the state can only be changed
+     * while holding both the passed-in mutex and the internal mutex -
+     * otherwise we create a race condition. */
     int ret = mutex_functions.lock(mutex);
-    switch (mutex->is_valid) {
-        case MUTEX_FREED:
-            mbedtls_test_mutex_usage_error(mutex, "lock without init");
-            break;
-        case MUTEX_IDLE:
-            if (ret == 0) {
-                mutex->is_valid = 2;
-            }
-            break;
-        case MUTEX_LOCKED:
-            mbedtls_test_mutex_usage_error(mutex, "double lock");
-            break;
-        default:
-            mbedtls_test_mutex_usage_error(mutex, "corrupted state");
-            break;
+    if (mutex_functions.lock(&mbedtls_test_mutex_mutex) == 0) {
+        switch (mutex->state) {
+            case MUTEX_FREED:
+                mbedtls_test_mutex_usage_error(mutex, "lock without init");
+                break;
+            case MUTEX_IDLE:
+                if (ret == 0) {
+                    mutex->state = MUTEX_LOCKED;
+                }
+                break;
+            case MUTEX_LOCKED:
+                mbedtls_test_mutex_usage_error(mutex, "double lock");
+                break;
+            default:
+                mbedtls_test_mutex_usage_error(mutex, "corrupted state");
+                break;
+        }
+
+        mutex_functions.unlock(&mbedtls_test_mutex_mutex);
     }
     return ret;
 }
 
 static int mbedtls_test_wrap_mutex_unlock(mbedtls_threading_mutex_t *mutex)
 {
-    int ret = mutex_functions.unlock(mutex);
-    switch (mutex->is_valid) {
-        case MUTEX_FREED:
-            mbedtls_test_mutex_usage_error(mutex, "unlock without init");
-            break;
-        case MUTEX_IDLE:
-            mbedtls_test_mutex_usage_error(mutex, "unlock without lock");
-            break;
-        case MUTEX_LOCKED:
-            if (ret == 0) {
-                mutex->is_valid = MUTEX_IDLE;
-            }
-            break;
-        default:
-            mbedtls_test_mutex_usage_error(mutex, "corrupted state");
-            break;
+    /* Lock the internal mutex first and change state, so that the state can
+     * only be changed while holding both the passed-in mutex and the internal
+     * mutex - otherwise we create a race condition. */
+    if (mutex_functions.lock(&mbedtls_test_mutex_mutex) == 0) {
+        switch (mutex->state) {
+            case MUTEX_FREED:
+                mbedtls_test_mutex_usage_error(mutex, "unlock without init");
+                break;
+            case MUTEX_IDLE:
+                mbedtls_test_mutex_usage_error(mutex, "unlock without lock");
+                break;
+            case MUTEX_LOCKED:
+                mutex->state = MUTEX_IDLE;
+                break;
+            default:
+                mbedtls_test_mutex_usage_error(mutex, "corrupted state");
+                break;
+        }
+        mutex_functions.unlock(&mbedtls_test_mutex_mutex);
     }
-    return ret;
+    return mutex_functions.unlock(mutex);
 }
 
 void mbedtls_test_mutex_usage_init(void)
@@ -183,6 +220,8 @@
     mbedtls_mutex_free = &mbedtls_test_wrap_mutex_free;
     mbedtls_mutex_lock = &mbedtls_test_wrap_mutex_lock;
     mbedtls_mutex_unlock = &mbedtls_test_wrap_mutex_unlock;
+
+    mutex_functions.init(&mbedtls_test_mutex_mutex);
 }
 
 void mbedtls_test_mutex_usage_check(void)
@@ -207,4 +246,14 @@
     mbedtls_test_info.mutex_usage_error = NULL;
 }
 
+void mbedtls_test_mutex_usage_end(void)
+{
+    mbedtls_mutex_init = mutex_functions.init;
+    mbedtls_mutex_free = mutex_functions.free;
+    mbedtls_mutex_lock = mutex_functions.lock;
+    mbedtls_mutex_unlock = mutex_functions.unlock;
+
+    mutex_functions.free(&mbedtls_test_mutex_mutex);
+}
+
 #endif /* MBEDTLS_TEST_MUTEX_USAGE */
diff --git a/tests/suites/host_test.function b/tests/suites/host_test.function
index d8ff49e..cc28697 100644
--- a/tests/suites/host_test.function
+++ b/tests/suites/host_test.function
@@ -772,6 +772,10 @@
     mbedtls_fprintf(stdout, " (%u / %u tests (%u skipped))\n",
                     total_tests - total_errors, total_tests, total_skipped);
 
+#if defined(MBEDTLS_TEST_MUTEX_USAGE)
+    mbedtls_test_mutex_usage_end();
+#endif
+
 #if defined(MBEDTLS_MEMORY_BUFFER_ALLOC_C) && \
     !defined(TEST_SUITE_MEMORY_BUFFER_ALLOC)
 #if defined(MBEDTLS_MEMORY_DEBUG)