Implement md over PSA
When MBEDTLS_MD_xxx_VIA_PSA is enabled (by mbedtls/md.h), route calls to xxx
over PSA rather than through the built-in implementation.
Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/library/md.c b/library/md.c
index 7171057..20bfd23 100644
--- a/library/md.c
+++ b/library/md.c
@@ -52,6 +52,10 @@
#include "mbedtls/sha256.h"
#include "mbedtls/sha512.h"
+#if defined(MBEDTLS_MD_SOME_PSA)
+#include <psa/crypto.h>
+#endif
+
#include "mbedtls/platform.h"
#include <string.h>
@@ -159,6 +163,63 @@
}
}
+#if defined(MBEDTLS_MD_SOME_PSA)
+/*
+ * Translate the digest type stored in `info` to the matching PSA hash
+ * algorithm identifier.
+ *
+ * Only digests whose MBEDTLS_MD_xxx_VIA_PSA option is defined in this build
+ * are translated; for every other type the result is PSA_ALG_NONE.
+ *
+ * `info` is dereferenced unconditionally, so it must not be NULL.
+ */
+static psa_algorithm_t psa_alg_of_md(const mbedtls_md_info_t *info)
+{
+    psa_algorithm_t alg = PSA_ALG_NONE;
+
+    switch (info->type) {
+#if defined(MBEDTLS_MD_MD5_VIA_PSA)
+        case MBEDTLS_MD_MD5:
+            alg = PSA_ALG_MD5;
+            break;
+#endif
+#if defined(MBEDTLS_MD_RIPEMD160_VIA_PSA)
+        case MBEDTLS_MD_RIPEMD160:
+            alg = PSA_ALG_RIPEMD160;
+            break;
+#endif
+#if defined(MBEDTLS_MD_SHA1_VIA_PSA)
+        case MBEDTLS_MD_SHA1:
+            alg = PSA_ALG_SHA_1;
+            break;
+#endif
+#if defined(MBEDTLS_MD_SHA224_VIA_PSA)
+        case MBEDTLS_MD_SHA224:
+            alg = PSA_ALG_SHA_224;
+            break;
+#endif
+#if defined(MBEDTLS_MD_SHA256_VIA_PSA)
+        case MBEDTLS_MD_SHA256:
+            alg = PSA_ALG_SHA_256;
+            break;
+#endif
+#if defined(MBEDTLS_MD_SHA384_VIA_PSA)
+        case MBEDTLS_MD_SHA384:
+            alg = PSA_ALG_SHA_384;
+            break;
+#endif
+#if defined(MBEDTLS_MD_SHA512_VIA_PSA)
+        case MBEDTLS_MD_SHA512:
+            alg = PSA_ALG_SHA_512;
+            break;
+#endif
+        default:
+            /* No PSA mapping compiled in for this type: keep PSA_ALG_NONE. */
+            break;
+    }
+
+    return alg;
+}
+
+/*
+ * Report whether the digest described by `info` is handled through PSA.
+ *
+ * Returns 1 when psa_alg_of_md() yields a PSA algorithm for it, and 0 when
+ * it yields PSA_ALG_NONE.
+ */
+static int md_uses_psa(const mbedtls_md_info_t *info)
+{
+    if (psa_alg_of_md(info) == PSA_ALG_NONE) {
+        return 0;
+    }
+    return 1;
+}
+
+/*
+ * Convert a PSA status code into the error code returned by the md module.
+ *
+ * PSA_SUCCESS maps to 0, PSA_ERROR_NOT_SUPPORTED to
+ * MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE and PSA_ERROR_INSUFFICIENT_MEMORY to
+ * MBEDTLS_ERR_MD_ALLOC_FAILED. Any other status is reported as
+ * MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED.
+ */
+static int mbedtls_md_error_from_psa(psa_status_t status)
+{
+    if (status == PSA_SUCCESS) {
+        return 0;
+    }
+    if (status == PSA_ERROR_NOT_SUPPORTED) {
+        return MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE;
+    }
+    if (status == PSA_ERROR_INSUFFICIENT_MEMORY) {
+        return MBEDTLS_ERR_MD_ALLOC_FAILED;
+    }
+
+    /* Catch-all for every status without a dedicated md error code. */
+    return MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED;
+}
+#endif /* MBEDTLS_MD_SOME_PSA */
+
void mbedtls_md_init(mbedtls_md_context_t *ctx)
{
memset(ctx, 0, sizeof(mbedtls_md_context_t));
@@ -171,6 +232,11 @@
}
if (ctx->md_ctx != NULL) {
+#if defined(MBEDTLS_MD_SOME_PSA)
+ if (md_uses_psa(ctx->md_info) && ctx->md_ctx != NULL) {
+ psa_hash_abort(ctx->md_ctx);
+ } else
+#endif
switch (ctx->md_info->type) {
#if defined(MBEDTLS_MD5_C)
case MBEDTLS_MD_MD5:
@@ -232,6 +298,13 @@
return MBEDTLS_ERR_MD_BAD_INPUT_DATA;
}
+#if defined(MBEDTLS_MD_SOME_PSA)
+ if (md_uses_psa(src->md_info)) {
+ psa_status_t status = psa_hash_clone(src->md_ctx, dst->md_ctx);
+ return mbedtls_md_error_from_psa(status);
+ }
+#endif
+
switch (src->md_info->type) {
#if defined(MBEDTLS_MD5_C)
case MBEDTLS_MD_MD5:
@@ -294,6 +367,14 @@
ctx->md_ctx = NULL;
ctx->hmac_ctx = NULL;
+#if defined(MBEDTLS_MD_SOME_PSA)
+ if (md_uses_psa(ctx->md_info)) {
+ ctx->md_ctx = mbedtls_calloc(1, sizeof(psa_hash_operation_t));
+ if (ctx->md_ctx == NULL) {
+ return MBEDTLS_ERR_MD_ALLOC_FAILED;
+ }
+ } else
+#endif
switch (md_info->type) {
#if defined(MBEDTLS_MD5_C)
case MBEDTLS_MD_MD5:
@@ -352,6 +433,15 @@
return MBEDTLS_ERR_MD_BAD_INPUT_DATA;
}
+#if defined(MBEDTLS_MD_SOME_PSA)
+ psa_algorithm_t alg = psa_alg_of_md(ctx->md_info);
+ if (alg != PSA_ALG_NONE) {
+ psa_hash_abort(ctx->md_ctx);
+ psa_status_t status = psa_hash_setup(ctx->md_ctx, alg);
+ return mbedtls_md_error_from_psa(status);
+ }
+#endif
+
switch (ctx->md_info->type) {
#if defined(MBEDTLS_MD5_C)
case MBEDTLS_MD_MD5:
@@ -392,6 +482,13 @@
return MBEDTLS_ERR_MD_BAD_INPUT_DATA;
}
+#if defined(MBEDTLS_MD_SOME_PSA)
+ if (md_uses_psa(ctx->md_info)) {
+ psa_status_t status = psa_hash_update(ctx->md_ctx, input, ilen);
+ return mbedtls_md_error_from_psa(status);
+ }
+#endif
+
switch (ctx->md_info->type) {
#if defined(MBEDTLS_MD5_C)
case MBEDTLS_MD_MD5:
@@ -432,6 +529,15 @@
return MBEDTLS_ERR_MD_BAD_INPUT_DATA;
}
+#if defined(MBEDTLS_MD_SOME_PSA)
+ if (md_uses_psa(ctx->md_info)) {
+ size_t size = ctx->md_info->size;
+ psa_status_t status = psa_hash_finish(ctx->md_ctx,
+ output, size, &size);
+ return mbedtls_md_error_from_psa(status);
+ }
+#endif
+
switch (ctx->md_info->type) {
#if defined(MBEDTLS_MD5_C)
case MBEDTLS_MD_MD5:
@@ -473,6 +579,17 @@
return MBEDTLS_ERR_MD_BAD_INPUT_DATA;
}
+#if defined(MBEDTLS_MD_SOME_PSA)
+ psa_algorithm_t alg = psa_alg_of_md(md_info);
+ if (alg != PSA_ALG_NONE) {
+ size_t size = md_info->size;
+ psa_status_t status = psa_hash_compute(alg,
+ input, ilen,
+ output, size, &size);
+ return mbedtls_md_error_from_psa(status);
+ }
+#endif
+
switch (md_info->type) {
#if defined(MBEDTLS_MD5_C)
case MBEDTLS_MD_MD5: