Move is_sign and mac_size checking back to PSA core scope

It makes sense to do the length checking in the core rather than expect
each driver to deal with it themselves. This puts the onus on the core to
dictate which algorithm/key combinations are valid before calling a driver.

Additionally, this commit also updates the psa_mac_sign_finish function
to better deal with output buffer sanitation, as per the review comments
on #4247.

Signed-off-by: Steven Cooreman <steven.cooreman@silabs.com>
diff --git a/include/psa/crypto_builtin_composites.h b/include/psa/crypto_builtin_composites.h
index f968c16..1d11b00 100644
--- a/include/psa/crypto_builtin_composites.h
+++ b/include/psa/crypto_builtin_composites.h
@@ -62,8 +62,6 @@
 typedef struct
 {
     psa_algorithm_t alg;
-    unsigned int is_sign : 1;
-    uint8_t mac_size;
     union
     {
         unsigned dummy; /* Make the union non-empty even with no supported algorithms. */
@@ -76,7 +74,7 @@
     } ctx;
 } mbedtls_psa_mac_operation_t;
 
-#define MBEDTLS_PSA_MAC_OPERATION_INIT {0, 0, 0, 0, {0}}
+#define MBEDTLS_PSA_MAC_OPERATION_INIT {0, {0}}
 
 /*
  * BEYOND THIS POINT, TEST DRIVER DECLARATIONS ONLY.
diff --git a/include/psa/crypto_struct.h b/include/psa/crypto_struct.h
index fc7e778..47012fd 100644
--- a/include/psa/crypto_struct.h
+++ b/include/psa/crypto_struct.h
@@ -137,10 +137,12 @@
      * ID value zero means the context is not valid or not assigned to
      * any driver (i.e. none of the driver contexts are active). */
     unsigned int id;
+    uint8_t mac_size;
+    unsigned int is_sign : 1;
     psa_driver_mac_context_t ctx;
 };
 
-#define PSA_MAC_OPERATION_INIT {0, {0}}
+#define PSA_MAC_OPERATION_INIT {0, 0, 0, {0}}
 static inline struct psa_mac_operation_s psa_mac_operation_init( void )
 {
     const struct psa_mac_operation_s v = PSA_MAC_OPERATION_INIT;
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 1d33f6b..4b769e9 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -2240,6 +2240,8 @@
         return( PSA_SUCCESS );
 
     psa_status_t status = psa_driver_wrapper_mac_abort( operation );
+    operation->mac_size = 0;
+    operation->is_sign = 0;
     operation->id = 0;
 
     return( status );
@@ -2253,7 +2255,6 @@
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_slot_t *slot;
-    size_t mac_size;
 
     /* A context must be freshly initialized before it can be set up. */
     if( operation->id != 0 )
@@ -2279,12 +2280,15 @@
     if( status != PSA_SUCCESS )
         goto exit;
 
+    operation->is_sign = is_sign;
+
     /* Get the output length for the algorithm and key combination. None of the
      * currently supported algorithms have an output length dependent on actual
      * key size, so setting it to a bogus value is currently OK. */
-    mac_size = PSA_MAC_LENGTH( psa_get_key_type( &attributes ), 0, alg );
+    operation->mac_size = PSA_MAC_LENGTH(
+                            psa_get_key_type( &attributes ), 0, alg );
 
-    if( mac_size < 4 )
+    if( operation->mac_size < 4 )
     {
         /* A very short MAC is too short for security since it can be
          * brute-forced. Ancient protocols with 32-bit MACs do exist,
@@ -2294,8 +2298,9 @@
         goto exit;
     }
 
-    if( mac_size > PSA_MAC_LENGTH( psa_get_key_type( &attributes ), 0,
-                                   PSA_ALG_FULL_LENGTH_MAC( alg ) ) )
+    if( operation->mac_size > PSA_MAC_LENGTH( psa_get_key_type( &attributes ),
+                                              0,
+                                              PSA_ALG_FULL_LENGTH_MAC( alg ) ) )
     {
         /* It's impossible to "truncate" to a larger length than the full length
          * of the algorithm. */
@@ -2372,26 +2377,45 @@
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED;
 
+    /* Set the output length and content to a safe default, such that in
+     * case the caller misses an error check, the output would be an
+     * unachievable MAC. */
+    *mac_length = mac_size;
+
     if( operation->id == 0 )
         return( PSA_ERROR_BAD_STATE );
 
-    /* Fill the output buffer with something that isn't a valid mac
-     * (barring an attack on the mac and deliberately-crafted input),
-     * in case the caller doesn't check the return status properly. */
-    *mac_length = mac_size;
-    /* If mac_size is 0 then mac may be NULL and then the
-     * call to memset would have undefined behavior. */
-    if( mac_size != 0 )
-        memset( mac, '!', mac_size );
+    if( ! operation->is_sign )
+        return( PSA_ERROR_BAD_STATE );
+
+    /* Sanity checks on output buffer length. */
+    if( mac_size == 0 || mac_size < operation->mac_size )
+        return( PSA_ERROR_BUFFER_TOO_SMALL );
 
     status = psa_driver_wrapper_mac_sign_finish( operation,
-                                                 mac, mac_size, mac_length );
+                                                 mac, operation->mac_size,
+                                                 mac_length );
+
+    if( status == PSA_SUCCESS )
+    {
+        /* Set the excess room in the output buffer to an invalid value, to
+         * avoid potentially leaking a longer MAC. */
+        if( mac_size > operation->mac_size )
+            memset( &mac[operation->mac_size],
+                    '!',
+                    mac_size - operation->mac_size );
+    }
+    else
+    {
+        /* Set the output length and content to a safe default, such that in
+         * case the caller misses an error check, the output would be an
+         * unachievable MAC. */
+        *mac_length = mac_size;
+        memset( mac, '!', mac_size );
+    }
 
     abort_status = psa_mac_abort( operation );
 
-    if( status != PSA_SUCCESS && mac_size > 0 )
-        memset( mac, '!', mac_size );
-
     return( status == PSA_SUCCESS ? abort_status : status );
 }
 
@@ -2405,8 +2429,19 @@
     if( operation->id == 0 )
         return( PSA_ERROR_BAD_STATE );
 
+    if( operation->is_sign )
+        return( PSA_ERROR_BAD_STATE );
+
+    if( operation->mac_size != mac_length )
+    {
+        status = PSA_ERROR_INVALID_SIGNATURE;
+        goto cleanup;
+    }
+
     status = psa_driver_wrapper_mac_verify_finish( operation,
                                                    mac, mac_length );
+
+cleanup:
     abort_status = psa_mac_abort( operation );
 
     return( status == PSA_SUCCESS ? abort_status : status );
@@ -3199,6 +3234,9 @@
     psa_set_key_bits( &attributes, PSA_BYTES_TO_BITS( hmac_key_length ) );
     psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_SIGN_HASH );
 
+    operation->is_sign = 1;
+    operation->mac_size = PSA_HASH_LENGTH( hash_alg );
+
     status = psa_driver_wrapper_mac_sign_setup( operation,
                                                 &attributes,
                                                 hmac_key, hmac_key_length,
diff --git a/library/psa_crypto_mac.c b/library/psa_crypto_mac.c
index 3d7f70b..6753ded 100644
--- a/library/psa_crypto_mac.c
+++ b/library/psa_crypto_mac.c
@@ -229,11 +229,10 @@
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
 
-    operation->alg = PSA_ALG_FULL_LENGTH_MAC( alg );
-    operation->is_sign = 0;
+    operation->alg = alg;
 
 #if defined(BUILTIN_ALG_CMAC)
-    if( operation->alg == PSA_ALG_CMAC )
+    if( PSA_ALG_FULL_LENGTH_MAC( operation->alg ) == PSA_ALG_CMAC )
     {
         mbedtls_cipher_init( &operation->ctx.cmac );
         status = PSA_SUCCESS;
@@ -269,7 +268,7 @@
     }
     else
 #if defined(BUILTIN_ALG_CMAC)
-    if( operation->alg == PSA_ALG_CMAC )
+    if( PSA_ALG_FULL_LENGTH_MAC( operation->alg ) == PSA_ALG_CMAC )
     {
         mbedtls_cipher_free( &operation->ctx.cmac );
     }
@@ -289,7 +288,6 @@
     }
 
     operation->alg = 0;
-    operation->is_sign = 0;
 
     return( PSA_SUCCESS );
 
@@ -306,8 +304,7 @@
                                const psa_key_attributes_t *attributes,
                                const uint8_t *key_buffer,
                                size_t key_buffer_size,
-                               psa_algorithm_t alg,
-                               int is_sign )
+                               psa_algorithm_t alg )
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
 
@@ -318,13 +315,6 @@
     status = mac_init( operation, alg );
     if( status != PSA_SUCCESS )
         return( status );
-    operation->is_sign = is_sign;
-
-    /* Get the output length for the algorithm and key combination. None of the
-     * currently supported algorithms have an output length dependent on actual
-     * key size, so setting it to a bogus value is currently OK. */
-    operation->mac_size =
-        PSA_MAC_LENGTH( psa_get_key_type( attributes ), 0, alg );
 
 #if defined(BUILTIN_ALG_CMAC)
     if( PSA_ALG_FULL_LENGTH_MAC( alg ) == PSA_ALG_CMAC )
@@ -340,7 +330,8 @@
     if( PSA_ALG_IS_HMAC( alg ) )
     {
         /* Sanity check. This shouldn't fail on a valid configuration. */
-        if( operation->mac_size > sizeof( operation->ctx.hmac.opad ) )
+        if( PSA_MAC_LENGTH( psa_get_key_type( attributes ), 0, alg ) >
+            sizeof( operation->ctx.hmac.opad ) )
         {
             status = PSA_ERROR_NOT_SUPPORTED;
             goto exit;
@@ -363,7 +354,6 @@
         status = PSA_ERROR_NOT_SUPPORTED;
     }
 
-exit:
     if( status != PSA_SUCCESS )
         mac_abort( operation );
 
@@ -401,8 +391,8 @@
     size_t key_buffer_size,
     psa_algorithm_t alg )
 {
-    return( mac_setup( operation, attributes, key_buffer, key_buffer_size, alg,
-                       1 ) );
+    return( mac_setup( operation,
+                       attributes, key_buffer, key_buffer_size, alg ) );
 }
 
 static psa_status_t mac_verify_setup(
@@ -412,8 +402,8 @@
     size_t key_buffer_size,
     psa_algorithm_t alg )
 {
-    return( mac_setup( operation, attributes, key_buffer, key_buffer_size, alg,
-                        0 ) );
+    return( mac_setup( operation,
+                       attributes, key_buffer, key_buffer_size, alg ) );
 }
 
 static psa_status_t mac_update(
@@ -425,7 +415,7 @@
         return( PSA_ERROR_BAD_STATE );
 
 #if defined(BUILTIN_ALG_CMAC)
-    if( operation->alg == PSA_ALG_CMAC )
+    if( PSA_ALG_FULL_LENGTH_MAC( operation->alg ) == PSA_ALG_CMAC )
     {
         return( mbedtls_to_psa_error(
                     mbedtls_cipher_cmac_update( &operation->ctx.cmac,
@@ -452,16 +442,13 @@
                                          uint8_t *mac,
                                          size_t mac_size )
 {
-    if( mac_size < operation->mac_size )
-        return( PSA_ERROR_BUFFER_TOO_SMALL );
-
 #if defined(BUILTIN_ALG_CMAC)
-    if( operation->alg == PSA_ALG_CMAC )
+    if( PSA_ALG_FULL_LENGTH_MAC( operation->alg ) == PSA_ALG_CMAC )
     {
         uint8_t tmp[PSA_BLOCK_CIPHER_BLOCK_MAX_SIZE];
         int ret = mbedtls_cipher_cmac_finish( &operation->ctx.cmac, tmp );
         if( ret == 0 )
-            memcpy( mac, tmp, operation->mac_size );
+            memcpy( mac, tmp, mac_size );
         mbedtls_platform_zeroize( tmp, sizeof( tmp ) );
         return( mbedtls_to_psa_error( ret ) );
     }
@@ -471,13 +458,16 @@
     if( PSA_ALG_IS_HMAC( operation->alg ) )
     {
         return( psa_hmac_finish_internal( &operation->ctx.hmac,
-                                          mac, operation->mac_size ) );
+                                          mac, mac_size ) );
     }
     else
 #endif /* BUILTIN_ALG_HMAC */
     {
         /* This shouldn't happen if `operation` was initialized by
          * a setup function. */
+        (void) operation;
+        (void) mac;
+        (void) mac_size;
         return( PSA_ERROR_BAD_STATE );
     }
 }
@@ -493,13 +483,10 @@
     if( operation->alg == 0 )
         return( PSA_ERROR_BAD_STATE );
 
-    if( ! operation->is_sign )
-        return( PSA_ERROR_BAD_STATE );
-
     status = mac_finish_internal( operation, mac, mac_size );
 
     if( status == PSA_SUCCESS )
-        *mac_length = operation->mac_size;
+        *mac_length = mac_size;
 
     return( status );
 }
@@ -515,16 +502,11 @@
     if( operation->alg == 0 )
         return( PSA_ERROR_BAD_STATE );
 
-    if( operation->is_sign )
-        return( PSA_ERROR_BAD_STATE );
+    /* Consistency check: requested MAC length fits our local buffer */
+    if( mac_length > sizeof( actual_mac ) )
+        return( PSA_ERROR_INVALID_ARGUMENT );
 
-    if( operation->mac_size != mac_length )
-    {
-        status = PSA_ERROR_INVALID_SIGNATURE;
-        goto cleanup;
-    }
-
-    status = mac_finish_internal( operation, actual_mac, sizeof( actual_mac ) );
+    status = mac_finish_internal( operation, actual_mac, mac_length );
     if( status != PSA_SUCCESS )
         goto cleanup;
 
diff --git a/library/psa_crypto_mac.h b/library/psa_crypto_mac.h
index 4635fe1..9b81e73 100644
--- a/library/psa_crypto_mac.h
+++ b/library/psa_crypto_mac.h
@@ -182,13 +182,15 @@
  *
  * \param[in,out] operation Active MAC operation.
  * \param[out] mac          Buffer where the MAC value is to be written.
- * \param mac_size          Size of the \p mac buffer in bytes.
- * \param[out] mac_length   On success, the number of bytes
- *                          that make up the MAC value. This is always
- *                          #PSA_MAC_LENGTH(\c key_type, \c key_bits, \c alg)
- *                          where \c key_type and \c key_bits are the type and
- *                          bit-size respectively of the key and \c alg is the
- *                          MAC algorithm that is calculated.
+ * \param mac_size          Output size requested for the MAC algorithm. The PSA
+ *                          core guarantees this is a valid MAC length for the
+ *                          algorithm and key combination passed to
+ *                          mbedtls_psa_mac_sign_setup(). It also guarantees the
+ *                          \p mac buffer is large enough to contain the
+ *                          requested output size.
+ * \param[out] mac_length   On success, the number of bytes output to buffer
+ *                          \p mac, which will be equal to the requested length
+ *                          \p mac_size.
  *
  * \retval #PSA_SUCCESS
  *         Success.
@@ -226,7 +228,10 @@
  *
  * \param[in,out] operation Active MAC operation.
  * \param[in] mac           Buffer containing the expected MAC value.
- * \param mac_length        Size of the \p mac buffer in bytes.
+ * \param mac_length        Length in bytes of the expected MAC value. The PSA
+ *                          core guarantees that this length is a valid MAC
+ *                          length for the algorithm and key combination passed
+ *                          to mbedtls_psa_mac_verify_setup().
  *
  * \retval #PSA_SUCCESS
  *         The expected MAC is identical to the actual MAC of the message.