Implement md over PSA

When MBEDTLS_MD_xxx_VIA_PSA is enabled (by mbedtls/md.h), route calls to xxx
over PSA rather than through the built-in implementation.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/library/md.c b/library/md.c
index 7171057..20bfd23 100644
--- a/library/md.c
+++ b/library/md.c
@@ -52,6 +52,10 @@
 #include "mbedtls/sha256.h"
 #include "mbedtls/sha512.h"
 
+#if defined(MBEDTLS_MD_SOME_PSA)
+#include <psa/crypto.h>
+#endif
+
 #include "mbedtls/platform.h"
 
 #include <string.h>
@@ -159,6 +163,63 @@
     }
 }
 
+#if defined(MBEDTLS_MD_SOME_PSA)
+static psa_algorithm_t psa_alg_of_md(const mbedtls_md_info_t *info)
+{ /* Map info->type to a PSA hash algorithm; only types built with *_VIA_PSA are mapped. */
+    switch (info->type) {
+#if defined(MBEDTLS_MD_MD5_VIA_PSA)
+        case MBEDTLS_MD_MD5:
+            return PSA_ALG_MD5;
+#endif
+#if defined(MBEDTLS_MD_RIPEMD160_VIA_PSA)
+        case MBEDTLS_MD_RIPEMD160:
+            return PSA_ALG_RIPEMD160;
+#endif
+#if defined(MBEDTLS_MD_SHA1_VIA_PSA)
+        case MBEDTLS_MD_SHA1:
+            return PSA_ALG_SHA_1;
+#endif
+#if defined(MBEDTLS_MD_SHA224_VIA_PSA)
+        case MBEDTLS_MD_SHA224:
+            return PSA_ALG_SHA_224;
+#endif
+#if defined(MBEDTLS_MD_SHA256_VIA_PSA)
+        case MBEDTLS_MD_SHA256:
+            return PSA_ALG_SHA_256;
+#endif
+#if defined(MBEDTLS_MD_SHA384_VIA_PSA)
+        case MBEDTLS_MD_SHA384:
+            return PSA_ALG_SHA_384;
+#endif
+#if defined(MBEDTLS_MD_SHA512_VIA_PSA)
+        case MBEDTLS_MD_SHA512:
+            return PSA_ALG_SHA_512;
+#endif
+        default: /* any type whose *_VIA_PSA case is not compiled in lands here */
+            return PSA_ALG_NONE;
+    }
+}
+
+static int md_uses_psa(const mbedtls_md_info_t *info)
+{ /* Nonzero iff psa_alg_of_md() maps this hash type to something other than PSA_ALG_NONE. */
+    return psa_alg_of_md(info) != PSA_ALG_NONE;
+}
+
+static int mbedtls_md_error_from_psa(psa_status_t status)
+{ /* Translate a psa_status_t into the int error code that the md functions return. */
+    switch (status) {
+        case PSA_SUCCESS:
+            return 0;
+        case PSA_ERROR_NOT_SUPPORTED:
+            return MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE;
+        case PSA_ERROR_INSUFFICIENT_MEMORY:
+            return MBEDTLS_ERR_MD_ALLOC_FAILED;
+        default: /* every other PSA status collapses into this single generic code */
+            return MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED;
+    }
+}
+#endif /* MBEDTLS_MD_SOME_PSA */
+
 void mbedtls_md_init(mbedtls_md_context_t *ctx)
 {
     memset(ctx, 0, sizeof(mbedtls_md_context_t));
@@ -171,6 +232,11 @@
     }
 
     if (ctx->md_ctx != NULL) {
+#if defined(MBEDTLS_MD_SOME_PSA)
+        if (md_uses_psa(ctx->md_info) && ctx->md_ctx != NULL) {
+            psa_hash_abort(ctx->md_ctx);
+        } else
+#endif
         switch (ctx->md_info->type) {
 #if defined(MBEDTLS_MD5_C)
             case MBEDTLS_MD_MD5:
@@ -232,6 +298,13 @@
         return MBEDTLS_ERR_MD_BAD_INPUT_DATA;
     }
 
+#if defined(MBEDTLS_MD_SOME_PSA)
+    if (md_uses_psa(src->md_info)) {
+        psa_status_t status = psa_hash_clone(src->md_ctx, dst->md_ctx);
+        return mbedtls_md_error_from_psa(status);
+    }
+#endif
+
     switch (src->md_info->type) {
 #if defined(MBEDTLS_MD5_C)
         case MBEDTLS_MD_MD5:
@@ -294,6 +367,14 @@
     ctx->md_ctx = NULL;
     ctx->hmac_ctx = NULL;
 
+#if defined(MBEDTLS_MD_SOME_PSA)
+    if (md_uses_psa(ctx->md_info)) {
+        ctx->md_ctx = mbedtls_calloc(1, sizeof(psa_hash_operation_t));
+        if (ctx->md_ctx == NULL) {
+            return MBEDTLS_ERR_MD_ALLOC_FAILED;
+        }
+    } else
+#endif
     switch (md_info->type) {
 #if defined(MBEDTLS_MD5_C)
         case MBEDTLS_MD_MD5:
@@ -352,6 +433,15 @@
         return MBEDTLS_ERR_MD_BAD_INPUT_DATA;
     }
 
+#if defined(MBEDTLS_MD_SOME_PSA)
+    psa_algorithm_t alg = psa_alg_of_md(ctx->md_info);
+    if (alg != PSA_ALG_NONE) {
+        psa_hash_abort(ctx->md_ctx);
+        psa_status_t status = psa_hash_setup(ctx->md_ctx, alg);
+        return mbedtls_md_error_from_psa(status);
+    }
+#endif
+
     switch (ctx->md_info->type) {
 #if defined(MBEDTLS_MD5_C)
         case MBEDTLS_MD_MD5:
@@ -392,6 +482,13 @@
         return MBEDTLS_ERR_MD_BAD_INPUT_DATA;
     }
 
+#if defined(MBEDTLS_MD_SOME_PSA)
+    if (md_uses_psa(ctx->md_info)) {
+        psa_status_t status = psa_hash_update(ctx->md_ctx, input, ilen);
+        return mbedtls_md_error_from_psa(status);
+    }
+#endif
+
     switch (ctx->md_info->type) {
 #if defined(MBEDTLS_MD5_C)
         case MBEDTLS_MD_MD5:
@@ -432,6 +529,15 @@
         return MBEDTLS_ERR_MD_BAD_INPUT_DATA;
     }
 
+#if defined(MBEDTLS_MD_SOME_PSA)
+    if (md_uses_psa(ctx->md_info)) {
+        size_t size = ctx->md_info->size;
+        psa_status_t status = psa_hash_finish(ctx->md_ctx,
+                                              output, size, &size);
+        return mbedtls_md_error_from_psa(status);
+    }
+#endif
+
     switch (ctx->md_info->type) {
 #if defined(MBEDTLS_MD5_C)
         case MBEDTLS_MD_MD5:
@@ -473,6 +579,17 @@
         return MBEDTLS_ERR_MD_BAD_INPUT_DATA;
     }
 
+#if defined(MBEDTLS_MD_SOME_PSA)
+    psa_algorithm_t alg = psa_alg_of_md(md_info);
+    if (alg != PSA_ALG_NONE) {
+        size_t size = md_info->size;
+        psa_status_t status = psa_hash_compute(alg,
+                                               input, ilen,
+                                               output, size, &size);
+        return mbedtls_md_error_from_psa(status);
+    }
+#endif
+
     switch (md_info->type) {
 #if defined(MBEDTLS_MD5_C)
         case MBEDTLS_MD_MD5: