Update after feedback on #3492

* Updated wording
* Split out buffer allocation to a convenience function
* Moved variable declarations to beginning of their code block

Signed-off-by: Steven Cooreman <steven.cooreman@silabs.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 4a3877c..34d4895 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -605,8 +605,8 @@
     pk.pk_ctx = rsa;
 
     /* PSA Crypto API defines the format of an RSA key as a DER-encoded
-     * representation of respectively the non-encrypted PKCS#1 RSAPrivateKey
-     * or the RFC3279 RSAPublicKey for a private key or a public key. */
+     * representation of the non-encrypted PKCS#1 RSAPrivateKey for a
+     * private key and of the RFC3279 RSAPublicKey for a public key. */
     if( PSA_KEY_TYPE_IS_KEY_PAIR( type ) )
         ret = mbedtls_pk_write_key_der( &pk, data, data_size );
     else
@@ -670,8 +670,10 @@
     slot->attr.bits = (psa_key_bits_t) PSA_BYTES_TO_BITS(
         mbedtls_rsa_get_len( &rsa ) );
 
-    /* Re-export the data to PSA export format, which in case of RSA is the
-     * smallest representation we can parse. */
+    /* Re-export the data to PSA export format, such that we can store export
+     * representation in the key slot. Export representation in case of RSA is
+     * the smallest representation that's allowed as input, so a straight-up
+     * allocation of the same size as the input buffer will be large enough. */
     output = mbedtls_calloc( 1, data_length );
 
     if( output == NULL )
@@ -680,11 +682,6 @@
         goto exit;
     }
 
-    /* PSA Crypto API defines the format of an RSA key as a DER-encoded
-     * representation of respectively the non-encrypted PKCS#1 RSAPrivateKey
-     * or the RFC3279 RSAPublicKey for a private key or a public key. That
-     * means we have no other choice then to run an import to verify the key
-     * size. */
     status = psa_export_rsa_key( slot->attr.type,
                                  &rsa,
                                  output,
@@ -905,6 +902,32 @@
     return( slot->attr.bits );
 }
 
+/** Try to allocate a buffer to an empty key slot.
+ *
+ * \param[in,out] slot          Key slot to attach buffer to.
+ * \param[in] buffer_length     Requested size of the buffer.
+ *
+ * \retval #PSA_SUCCESS
+ *         The buffer has been successfully allocated.
+ * \retval #PSA_ERROR_INSUFFICIENT_MEMORY
+ *         Not enough memory was available for allocation.
+ * \retval #PSA_ERROR_ALREADY_EXISTS
+ *         Trying to allocate a buffer to a non-empty key slot.
+ */
+static psa_status_t psa_allocate_buffer_to_slot( psa_key_slot_t *slot,
+                                                 size_t buffer_length )
+{
+    if( slot->data.key.data != NULL )
+        return PSA_ERROR_ALREADY_EXISTS;
+
+    slot->data.key.data = mbedtls_calloc( 1, buffer_length );
+    if( slot->data.key.data == NULL )
+        return PSA_ERROR_INSUFFICIENT_MEMORY;
+
+    slot->data.key.bytes = buffer_length;
+    return PSA_SUCCESS;
+}
+
 /** Import key data into a slot. `slot->attr.type` must have been set
  * previously. This function assumes that the slot does not contain
  * any key material yet. On failure, the slot content is unchanged. */
@@ -918,14 +941,14 @@
     if( data_length == 0 )
         return( PSA_ERROR_NOT_SUPPORTED );
 
-    /* Ensure that the bytes-to-bit conversion never overflows. */
-    if( data_length > SIZE_MAX / 8 )
-        return( PSA_ERROR_NOT_SUPPORTED );
-
     if( key_type_is_raw_bytes( slot->attr.type ) )
     {
         size_t bit_size = PSA_BYTES_TO_BITS( data_length );
 
+        /* Ensure that the bytes-to-bits conversion hasn't overflown. */
+        if( data_length > SIZE_MAX / 8 )
+            return( PSA_ERROR_NOT_SUPPORTED );
+
         /* Enforce a size limit, and in particular ensure that the bit
          * size fits in its representation type. */
         if( bit_size > PSA_MAX_KEY_BITS )
@@ -936,12 +959,9 @@
             return status;
 
         /* Allocate memory for the key */
-        slot->data.key.data = mbedtls_calloc( 1, data_length );
-        if( slot->data.key.data == NULL )
-        {
-            return( PSA_ERROR_INSUFFICIENT_MEMORY );
-        }
-        slot->data.key.bytes = data_length;
+        status = psa_allocate_buffer_to_slot( slot, data_length );
+        if( status != PSA_SUCCESS )
+            return status;
 
         /* copy key into allocated buffer */
         memcpy(slot->data.key.data, data, data_length);
@@ -1135,6 +1155,10 @@
 /** Wipe key data from a slot. Preserve metadata such as the policy. */
 static psa_status_t psa_remove_key_data_from_memory( psa_key_slot_t *slot )
 {
+    /* Check whether key is already clean */
+    if( slot->data.key.data == NULL )
+        return PSA_SUCCESS;
+
 #if defined(MBEDTLS_PSA_CRYPTO_SE_C)
     if( psa_key_slot_is_external( slot ) )
     {
@@ -1958,11 +1982,12 @@
         {
             mbedtls_rsa_context rsa;
             mbedtls_rsa_init( &rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+            mbedtls_mpi actual, required;
 
             psa_status_t status = psa_load_rsa_representation( slot, &rsa );
             if( status != PSA_SUCCESS )
                 return status;
-            mbedtls_mpi actual, required;
+
             int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
             mbedtls_mpi_init( &actual );
             mbedtls_mpi_init( &required );
@@ -3808,11 +3833,11 @@
     {
         mbedtls_rsa_context rsa;
         mbedtls_rsa_init( &rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
         status = psa_load_rsa_representation( slot, &rsa );
         if( status != PSA_SUCCESS )
             return status;
-        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
         if( output_size < mbedtls_rsa_get_len( &rsa ) )
         {
             mbedtls_rsa_free( &rsa );
@@ -3898,11 +3923,11 @@
     {
         mbedtls_rsa_context rsa;
         mbedtls_rsa_init( &rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
         status = psa_load_rsa_representation( slot, &rsa );
         if( status != PSA_SUCCESS )
             return status;
-        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
         if( input_length != mbedtls_rsa_get_len( &rsa ) )
         {
@@ -5773,13 +5798,9 @@
             return( status );
 
         /* Allocate memory for the key */
-        slot->data.key.bytes = PSA_BITS_TO_BYTES( bits );
-        slot->data.key.data = mbedtls_calloc( 1, slot->data.key.bytes );
-        if( slot->data.key.data == NULL )
-        {
-            slot->data.key.bytes = 0;
-            return( PSA_ERROR_INSUFFICIENT_MEMORY );
-        }
+        status = psa_allocate_buffer_to_slot( slot, PSA_BITS_TO_BYTES( bits ) );
+        if( status != PSA_SUCCESS )
+            return status;
 
         status = psa_generate_random( slot->data.key.data,
                                       slot->data.key.bytes );
@@ -5825,11 +5846,11 @@
         /* Make sure to always have an export representation available */
         size_t bytes = PSA_KEY_EXPORT_RSA_KEY_PAIR_MAX_SIZE( bits );
 
-        slot->data.key.data = mbedtls_calloc( 1, bytes );
-        if( slot->data.key.data == NULL )
+        status = psa_allocate_buffer_to_slot( slot, bytes );
+        if( status != PSA_SUCCESS )
         {
             mbedtls_rsa_free( &rsa );
-            return( PSA_ERROR_INSUFFICIENT_MEMORY );
+            return status;
         }
 
         status = psa_export_rsa_key( type,
@@ -5874,14 +5895,14 @@
 
         /* Make sure to always have an export representation available */
         size_t bytes = PSA_BITS_TO_BYTES( bits );
-        slot->data.key.data = mbedtls_calloc( 1, bytes );
-        if( slot->data.key.data == NULL )
+        psa_status_t status = psa_allocate_buffer_to_slot( slot, bytes );
+        if( status != PSA_SUCCESS )
         {
             mbedtls_ecp_keypair_free( &ecp );
-            return( PSA_ERROR_INSUFFICIENT_MEMORY );
+            return status;
         }
-        slot->data.key.bytes = bytes;
-        psa_status_t status = mbedtls_to_psa_error(
+
+        status = mbedtls_to_psa_error(
             mbedtls_ecp_write_key( &ecp, slot->data.key.data, bytes ) );
 
         mbedtls_ecp_keypair_free( &ecp );