Merge branch 'development-restricted' into mac_buffer_protection

Signed-off-by: tom-daubney-arm <74920390+tom-daubney-arm@users.noreply.github.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 5b7a838..7afd7fc 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -2705,35 +2705,48 @@
 }
 
 psa_status_t psa_mac_update(psa_mac_operation_t *operation,
-                            const uint8_t *input,
+                            const uint8_t *input_external,
                             size_t input_length)
 {
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    LOCAL_INPUT_DECLARE(input_external, input);
+
     if (operation->id == 0) {
-        return PSA_ERROR_BAD_STATE;
+        status = PSA_ERROR_BAD_STATE;
+        return status;
     }
 
     /* Don't require hash implementations to behave correctly on a
      * zero-length input, which may have an invalid pointer. */
     if (input_length == 0) {
-        return PSA_SUCCESS;
+        status = PSA_SUCCESS;
+        return status;
     }
 
-    psa_status_t status = psa_driver_wrapper_mac_update(operation,
-                                                        input, input_length);
+    LOCAL_INPUT_ALLOC(input_external, input_length, input);
+    status = psa_driver_wrapper_mac_update(operation, input, input_length);
+
     if (status != PSA_SUCCESS) {
         psa_mac_abort(operation);
     }
 
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
+    LOCAL_INPUT_FREE(input_external, input);
+
     return status;
 }
 
 psa_status_t psa_mac_sign_finish(psa_mac_operation_t *operation,
-                                 uint8_t *mac,
+                                 uint8_t *mac_external,
                                  size_t mac_size,
                                  size_t *mac_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED;
+    LOCAL_OUTPUT_DECLARE(mac_external, mac);
+    LOCAL_OUTPUT_ALLOC(mac_external, mac_size, mac);
 
     if (operation->id == 0) {
         status = PSA_ERROR_BAD_STATE;
@@ -2757,6 +2770,7 @@
         goto exit;
     }
 
+
     status = psa_driver_wrapper_mac_sign_finish(operation,
                                                 mac, operation->mac_size,
                                                 mac_length);
@@ -2773,19 +2787,23 @@
         operation->mac_size = 0;
     }
 
-    psa_wipe_tag_output_buffer(mac, status, mac_size, *mac_length);
+    if (mac != NULL) {
+        psa_wipe_tag_output_buffer(mac, status, mac_size, *mac_length);
+    }
 
     abort_status = psa_mac_abort(operation);
+    LOCAL_OUTPUT_FREE(mac_external, mac);
 
     return status == PSA_SUCCESS ? abort_status : status;
 }
 
 psa_status_t psa_mac_verify_finish(psa_mac_operation_t *operation,
-                                   const uint8_t *mac,
+                                   const uint8_t *mac_external,
                                    size_t mac_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED;
+    LOCAL_INPUT_DECLARE(mac_external, mac);
 
     if (operation->id == 0) {
         status = PSA_ERROR_BAD_STATE;
@@ -2802,11 +2820,13 @@
         goto exit;
     }
 
+    LOCAL_INPUT_ALLOC(mac_external, mac_length, mac);
     status = psa_driver_wrapper_mac_verify_finish(operation,
                                                   mac, mac_length);
 
 exit:
     abort_status = psa_mac_abort(operation);
+    LOCAL_INPUT_FREE(mac_external, mac);
 
     return status == PSA_SUCCESS ? abort_status : status;
 }
@@ -2878,28 +2898,45 @@
 
 psa_status_t psa_mac_compute(mbedtls_svc_key_id_t key,
                              psa_algorithm_t alg,
-                             const uint8_t *input,
+                             const uint8_t *input_external,
                              size_t input_length,
-                             uint8_t *mac,
+                             uint8_t *mac_external,
                              size_t mac_size,
                              size_t *mac_length)
 {
-    return psa_mac_compute_internal(key, alg,
-                                    input, input_length,
-                                    mac, mac_size, mac_length, 1);
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    LOCAL_INPUT_DECLARE(input_external, input);
+    LOCAL_OUTPUT_DECLARE(mac_external, mac);
+
+    LOCAL_INPUT_ALLOC(input_external, input_length, input);
+    LOCAL_OUTPUT_ALLOC(mac_external, mac_size, mac);
+    status = psa_mac_compute_internal(key, alg,
+                                      input, input_length,
+                                      mac, mac_size, mac_length, 1);
+
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
+    LOCAL_INPUT_FREE(input_external, input);
+    LOCAL_OUTPUT_FREE(mac_external, mac);
+
+    return status;
 }
 
 psa_status_t psa_mac_verify(mbedtls_svc_key_id_t key,
                             psa_algorithm_t alg,
-                            const uint8_t *input,
+                            const uint8_t *input_external,
                             size_t input_length,
-                            const uint8_t *mac,
+                            const uint8_t *mac_external,
                             size_t mac_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     uint8_t actual_mac[PSA_MAC_MAX_SIZE];
     size_t actual_mac_length;
+    LOCAL_INPUT_DECLARE(input_external, input);
+    LOCAL_INPUT_DECLARE(mac_external, mac);
 
+    LOCAL_INPUT_ALLOC(input_external, input_length, input);
     status = psa_mac_compute_internal(key, alg,
                                       input, input_length,
                                       actual_mac, sizeof(actual_mac),
@@ -2912,6 +2949,8 @@
         status = PSA_ERROR_INVALID_SIGNATURE;
         goto exit;
     }
+
+    LOCAL_INPUT_ALLOC(mac_external, mac_length, mac);
     if (mbedtls_ct_memcmp(mac, actual_mac, actual_mac_length) != 0) {
         status = PSA_ERROR_INVALID_SIGNATURE;
         goto exit;
@@ -2919,6 +2958,8 @@
 
 exit:
     mbedtls_platform_zeroize(actual_mac, sizeof(actual_mac));
+    LOCAL_INPUT_FREE(input_external, input);
+    LOCAL_INPUT_FREE(mac_external, mac);
 
     return status;
 }
diff --git a/tests/scripts/generate_psa_wrappers.py b/tests/scripts/generate_psa_wrappers.py
index cdc798d..ef28cbf 100755
--- a/tests/scripts/generate_psa_wrappers.py
+++ b/tests/scripts/generate_psa_wrappers.py
@@ -164,6 +164,12 @@
                              'psa_hash_compute',
                              'psa_hash_compare'):
             return True
+        if function_name in ('psa_mac_update',
+                             'psa_mac_sign_finish',
+                             'psa_mac_verify_finish',
+                             'psa_mac_compute',
+                             'psa_mac_verify'):
+            return True
         return False
 
     def _write_function_call(self, out: typing_util.Writable,
diff --git a/tests/src/psa_test_wrappers.c b/tests/src/psa_test_wrappers.c
index 5f0a3dd..4824e38 100644
--- a/tests/src/psa_test_wrappers.c
+++ b/tests/src/psa_test_wrappers.c
@@ -816,7 +816,15 @@
     size_t arg5_mac_size,
     size_t *arg6_mac_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg2_input, arg3_input_length);
+    MBEDTLS_TEST_MEMORY_POISON(arg4_mac, arg5_mac_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_mac_compute)(arg0_key, arg1_alg, arg2_input, arg3_input_length, arg4_mac, arg5_mac_size, arg6_mac_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg2_input, arg3_input_length);
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg4_mac, arg5_mac_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -827,7 +835,13 @@
     size_t arg2_mac_size,
     size_t *arg3_mac_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg1_mac, arg2_mac_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_mac_sign_finish)(arg0_operation, arg1_mac, arg2_mac_size, arg3_mac_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg1_mac, arg2_mac_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -847,7 +861,13 @@
     const uint8_t *arg1_input,
     size_t arg2_input_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg1_input, arg2_input_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_mac_update)(arg0_operation, arg1_input, arg2_input_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg1_input, arg2_input_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -860,7 +880,15 @@
     const uint8_t *arg4_mac,
     size_t arg5_mac_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg2_input, arg3_input_length);
+    MBEDTLS_TEST_MEMORY_POISON(arg4_mac, arg5_mac_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_mac_verify)(arg0_key, arg1_alg, arg2_input, arg3_input_length, arg4_mac, arg5_mac_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg2_input, arg3_input_length);
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg4_mac, arg5_mac_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -870,7 +898,13 @@
     const uint8_t *arg1_mac,
     size_t arg2_mac_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg1_mac, arg2_mac_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_mac_verify_finish)(arg0_operation, arg1_mac, arg2_mac_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg1_mac, arg2_mac_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }