ctr_drbg: add alternative PSA implementation when AES_C is not defined

Signed-off-by: Valerio Setti <valerio.setti@nordicsemi.no>
diff --git a/include/mbedtls/ctr_drbg.h b/include/mbedtls/ctr_drbg.h
index d1f19e6..c00756d 100644
--- a/include/mbedtls/ctr_drbg.h
+++ b/include/mbedtls/ctr_drbg.h
@@ -32,7 +32,14 @@
 
 #include "mbedtls/build_info.h"
 
+/* In case AES_C is defined then it is the primary option for backward
+ * compatibility purposes. If that's not available, PSA is used instead */
+#if defined(MBEDTLS_AES_C)
 #include "mbedtls/aes.h"
+#else
+#include "psa/crypto.h"
+#endif
+
 #include "entropy.h"
 
 #if defined(MBEDTLS_THREADING_C)
@@ -150,6 +157,13 @@
 #define MBEDTLS_CTR_DRBG_ENTROPY_NONCE_LEN (MBEDTLS_CTR_DRBG_ENTROPY_LEN + 1) / 2
 #endif
 
+#if !defined(MBEDTLS_AES_C)
+typedef struct mbedtls_ctr_drbg_psa_context {
+    mbedtls_svc_key_id_t key_id;
+    psa_cipher_operation_t operation;
+} mbedtls_ctr_drbg_psa_context;
+#endif
+
 /**
  * \brief          The CTR_DRBG context structure.
  */
@@ -175,7 +189,11 @@
                                                   * This is the maximum number of requests
                                                   * that can be made between reseedings. */
 
+#if defined(MBEDTLS_AES_C)
     mbedtls_aes_context MBEDTLS_PRIVATE(aes_ctx);        /*!< The AES context. */
+#else
+    mbedtls_ctr_drbg_psa_context MBEDTLS_PRIVATE(psa_ctx); /*!< The PSA context. */
+#endif
 
     /*
      * Callbacks (Entropy)
diff --git a/library/ctr_drbg.c b/library/ctr_drbg.c
index cf3816e..da34f95 100644
--- a/library/ctr_drbg.c
+++ b/library/ctr_drbg.c
@@ -24,15 +24,60 @@
 #include <stdio.h>
 #endif
 
+/* Using error translation functions from PSA to MbedTLS */
+#if !defined(MBEDTLS_AES_C)
+#include "psa_util_internal.h"
+#endif
+
 #include "mbedtls/platform.h"
 
+#if !defined(MBEDTLS_AES_C)
+static psa_status_t ctr_drbg_setup_psa_context(mbedtls_ctr_drbg_psa_context *psa_ctx,
+                                               unsigned char *key, size_t key_len)
+{
+    psa_key_attributes_t key_attr = PSA_KEY_ATTRIBUTES_INIT;
+    psa_status_t status;
+
+    psa_set_key_usage_flags(&key_attr, PSA_KEY_USAGE_ENCRYPT);
+    psa_set_key_algorithm(&key_attr, PSA_ALG_ECB_NO_PADDING);
+    psa_set_key_type(&key_attr, PSA_KEY_TYPE_AES);
+    status = psa_import_key(&key_attr, key, key_len, &psa_ctx->key_id);
+    if (status != PSA_SUCCESS) {
+        goto exit;
+    }
+
+    status = psa_cipher_encrypt_setup(&psa_ctx->operation, psa_ctx->key_id, PSA_ALG_ECB_NO_PADDING);
+    if (status != PSA_SUCCESS) {
+        goto exit;
+    }
+
+exit:
+    psa_reset_key_attributes(&key_attr);
+    return status;
+}
+
+static void ctr_drbg_destroy_psa_contex(mbedtls_ctr_drbg_psa_context *psa_ctx)
+{
+    psa_cipher_abort(&psa_ctx->operation);
+    psa_destroy_key(psa_ctx->key_id);
+
+    psa_ctx->operation = psa_cipher_operation_init();
+    psa_ctx->key_id = MBEDTLS_SVC_KEY_ID_INIT;
+}
+#endif
+
 /*
  * CTR_DRBG context initialization
  */
 void mbedtls_ctr_drbg_init(mbedtls_ctr_drbg_context *ctx)
 {
     memset(ctx, 0, sizeof(mbedtls_ctr_drbg_context));
+#if defined(MBEDTLS_AES_C)
     mbedtls_aes_init(&ctx->aes_ctx);
+#else
+    ctx->psa_ctx.key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    ctx->psa_ctx.operation = psa_cipher_operation_init();
+#endif
     /* Indicate that the entropy nonce length is not set explicitly.
      * See mbedtls_ctr_drbg_set_nonce_len(). */
     ctx->reseed_counter = -1;
@@ -56,7 +101,11 @@
         mbedtls_mutex_free(&ctx->mutex);
     }
 #endif
+#if defined(MBEDTLS_AES_C)
     mbedtls_aes_free(&ctx->aes_ctx);
+#else
+    ctr_drbg_destroy_psa_contex(&ctx->psa_ctx);
+#endif
     mbedtls_platform_zeroize(ctx, sizeof(mbedtls_ctr_drbg_context));
     ctx->reseed_interval = MBEDTLS_CTR_DRBG_RESEED_INTERVAL;
     ctx->reseed_counter = -1;
@@ -117,8 +166,17 @@
     unsigned char key[MBEDTLS_CTR_DRBG_KEYSIZE];
     unsigned char chain[MBEDTLS_CTR_DRBG_BLOCKSIZE];
     unsigned char *p, *iv;
-    mbedtls_aes_context aes_ctx;
     int ret = 0;
+#if defined(MBEDTLS_AES_C)
+    mbedtls_aes_context aes_ctx;
+#else
+    psa_status_t status;
+    size_t tmp_len;
+    mbedtls_ctr_drbg_psa_context psa_ctx;
+
+    psa_ctx.key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    psa_ctx.operation = psa_cipher_operation_init();
+#endif
 
     int i, j;
     size_t buf_len, use_len;
@@ -129,7 +187,6 @@
 
     memset(buf, 0, MBEDTLS_CTR_DRBG_MAX_SEED_INPUT +
            MBEDTLS_CTR_DRBG_BLOCKSIZE + 16);
-    mbedtls_aes_init(&aes_ctx);
 
     /*
      * Construct IV (16 bytes) and S in buffer
@@ -151,10 +208,20 @@
         key[i] = i;
     }
 
+#if defined(MBEDTLS_AES_C)
+    mbedtls_aes_init(&aes_ctx);
+
     if ((ret = mbedtls_aes_setkey_enc(&aes_ctx, key,
                                       MBEDTLS_CTR_DRBG_KEYBITS)) != 0) {
         goto exit;
     }
+#else
+    status = ctr_drbg_setup_psa_context(&psa_ctx, key, sizeof(key));
+    if (status != PSA_SUCCESS) {
+        ret = psa_generic_status_to_mbedtls(status);
+        goto exit;
+    }
+#endif
 
     /*
      * Reduce data to MBEDTLS_CTR_DRBG_SEEDLEN bytes of data
@@ -170,10 +237,19 @@
             use_len -= (use_len >= MBEDTLS_CTR_DRBG_BLOCKSIZE) ?
                        MBEDTLS_CTR_DRBG_BLOCKSIZE : use_len;
 
+#if defined(MBEDTLS_AES_C)
             if ((ret = mbedtls_aes_crypt_ecb(&aes_ctx, MBEDTLS_AES_ENCRYPT,
                                              chain, chain)) != 0) {
                 goto exit;
             }
+#else
+            status = psa_cipher_update(&psa_ctx.operation, chain, MBEDTLS_CTR_DRBG_BLOCKSIZE,
+                                       chain, MBEDTLS_CTR_DRBG_BLOCKSIZE, &tmp_len);
+            if (status != PSA_SUCCESS) {
+                ret = psa_generic_status_to_mbedtls(status);
+                goto exit;
+            }
+#endif
         }
 
         memcpy(tmp + j, chain, MBEDTLS_CTR_DRBG_BLOCKSIZE);
@@ -187,23 +263,46 @@
     /*
      * Do final encryption with reduced data
      */
+#if defined(MBEDTLS_AES_C)
     if ((ret = mbedtls_aes_setkey_enc(&aes_ctx, tmp,
                                       MBEDTLS_CTR_DRBG_KEYBITS)) != 0) {
         goto exit;
     }
+#else
+    ctr_drbg_destroy_psa_contex(&psa_ctx);
+
+    status = ctr_drbg_setup_psa_context(&psa_ctx, tmp, MBEDTLS_CTR_DRBG_KEYSIZE);
+    if (status != PSA_SUCCESS) {
+        ret = psa_generic_status_to_mbedtls(status);
+        goto exit;
+    }
+#endif
     iv = tmp + MBEDTLS_CTR_DRBG_KEYSIZE;
     p = output;
 
     for (j = 0; j < MBEDTLS_CTR_DRBG_SEEDLEN; j += MBEDTLS_CTR_DRBG_BLOCKSIZE) {
+#if defined(MBEDTLS_AES_C)
         if ((ret = mbedtls_aes_crypt_ecb(&aes_ctx, MBEDTLS_AES_ENCRYPT,
                                          iv, iv)) != 0) {
             goto exit;
         }
+#else
+        status = psa_cipher_update(&psa_ctx.operation, iv, MBEDTLS_CTR_DRBG_BLOCKSIZE,
+                                   iv, MBEDTLS_CTR_DRBG_BLOCKSIZE, &tmp_len);
+        if (status != PSA_SUCCESS) {
+            ret = psa_generic_status_to_mbedtls(status);
+            goto exit;
+        }
+#endif
         memcpy(p, iv, MBEDTLS_CTR_DRBG_BLOCKSIZE);
         p += MBEDTLS_CTR_DRBG_BLOCKSIZE;
     }
 exit:
+#if defined(MBEDTLS_AES_C)
     mbedtls_aes_free(&aes_ctx);
+#else
+    ctr_drbg_destroy_psa_contex(&psa_ctx);
+#endif
     /*
      * tidy up the stack
      */
@@ -236,6 +335,10 @@
     unsigned char *p = tmp;
     int i, j;
     int ret = 0;
+#if !defined(MBEDTLS_AES_C)
+    psa_status_t status;
+    size_t tmp_len;
+#endif
 
     memset(tmp, 0, MBEDTLS_CTR_DRBG_SEEDLEN);
 
@@ -252,10 +355,19 @@
         /*
          * Crypt counter block
          */
+#if defined(MBEDTLS_AES_C)
         if ((ret = mbedtls_aes_crypt_ecb(&ctx->aes_ctx, MBEDTLS_AES_ENCRYPT,
                                          ctx->counter, p)) != 0) {
             goto exit;
         }
+#else
+        status = psa_cipher_update(&ctx->psa_ctx.operation, ctx->counter, sizeof(ctx->counter),
+                                   p, MBEDTLS_CTR_DRBG_BLOCKSIZE, &tmp_len);
+        if (status != PSA_SUCCESS) {
+            ret = psa_generic_status_to_mbedtls(status);
+            goto exit;
+        }
+#endif
 
         p += MBEDTLS_CTR_DRBG_BLOCKSIZE;
     }
@@ -267,10 +379,20 @@
     /*
      * Update key and counter
      */
+#if defined(MBEDTLS_AES_C)
     if ((ret = mbedtls_aes_setkey_enc(&ctx->aes_ctx, tmp,
                                       MBEDTLS_CTR_DRBG_KEYBITS)) != 0) {
         goto exit;
     }
+#else
+    ctr_drbg_destroy_psa_contex(&ctx->psa_ctx);
+
+    status = ctr_drbg_setup_psa_context(&ctx->psa_ctx, tmp, MBEDTLS_CTR_DRBG_KEYSIZE);
+    if (status != PSA_SUCCESS) {
+        ret = psa_generic_status_to_mbedtls(status);
+        goto exit;
+    }
+#endif
     memcpy(ctx->counter, tmp + MBEDTLS_CTR_DRBG_KEYSIZE,
            MBEDTLS_CTR_DRBG_BLOCKSIZE);
 
@@ -447,10 +569,20 @@
                  good_nonce_len(ctx->entropy_len));
 
     /* Initialize with an empty key. */
+#if defined(MBEDTLS_AES_C)
     if ((ret = mbedtls_aes_setkey_enc(&ctx->aes_ctx, key,
                                       MBEDTLS_CTR_DRBG_KEYBITS)) != 0) {
         return ret;
     }
+#else
+    psa_status_t status;
+
+    status = ctr_drbg_setup_psa_context(&ctx->psa_ctx, key, MBEDTLS_CTR_DRBG_KEYSIZE);
+    if (status != PSA_SUCCESS) {
+        ret = psa_generic_status_to_mbedtls(status);
+        return status;
+    }
+#endif
 
     /* Do the initial seeding. */
     if ((ret = mbedtls_ctr_drbg_reseed_internal(ctx, custom, len,
@@ -531,10 +663,22 @@
         /*
          * Crypt counter block
          */
+#if defined(MBEDTLS_AES_C)
         if ((ret = mbedtls_aes_crypt_ecb(&ctx->aes_ctx, MBEDTLS_AES_ENCRYPT,
                                          ctx->counter, tmp)) != 0) {
             goto exit;
         }
+#else
+        psa_status_t status;
+        size_t tmp_len;
+
+        status = psa_cipher_update(&ctx->psa_ctx.operation, ctx->counter, sizeof(ctx->counter),
+                                   tmp, MBEDTLS_CTR_DRBG_BLOCKSIZE, &tmp_len);
+        if (status != PSA_SUCCESS) {
+            ret = psa_generic_status_to_mbedtls(status);
+            goto exit;
+        }
+#endif
 
         use_len = (output_len > MBEDTLS_CTR_DRBG_BLOCKSIZE)
             ? MBEDTLS_CTR_DRBG_BLOCKSIZE : output_len;