Use core API in `ecp_mod_koblitz()`

Signed-off-by: Gabor Mezei <gabor.mezei@arm.com>
diff --git a/library/ecp_curves.c b/library/ecp_curves.c
index 1640107..029b515 100644
--- a/library/ecp_curves.c
+++ b/library/ecp_curves.c
@@ -25,6 +25,8 @@
 #include "mbedtls/platform_util.h"
 #include "mbedtls/error.h"
 
+#include "mbedtls/platform.h"
+
 #include "bn_mul.h"
 #include "bignum_core.h"
 #include "ecp_invasive.h"
@@ -5526,60 +5528,69 @@
  */
 #define P_KOBLITZ_MAX   (256 / 8 / sizeof(mbedtls_mpi_uint))      // Max limbs in P
 #define P_KOBLITZ_R     (8 / sizeof(mbedtls_mpi_uint))            // Limbs in R
-static inline int ecp_mod_koblitz(mbedtls_mpi *N, mbedtls_mpi_uint *Rp, size_t p_limbs,
-                                  size_t adjust, size_t shift, mbedtls_mpi_uint mask)
-{
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    mbedtls_mpi M, R;
-    mbedtls_mpi_uint Mp[P_KOBLITZ_MAX + P_KOBLITZ_R + 1];
 
-    if (N->n < p_limbs) {
-        return 0;
+/* Set X = A0 + R * A1 in place, twice: A0 is X below the bit size of P, A1
+ * is the limbs of X above it. R holds P_KOBLITZ_R limbs. Returns 0, or
+ * MBEDTLS_ERR_ECP_ALLOC_FAILED if a temporary buffer cannot be allocated. */
+static inline int ecp_mod_koblitz(mbedtls_mpi_uint *X, size_t X_limbs,
+                                  mbedtls_mpi_uint *R, size_t P_limbs,
+                                  size_t adjust, size_t shift,
+                                  mbedtls_mpi_uint mask)
+{
+    int ret = 0;
+    size_t R_limbs = P_KOBLITZ_R;
+
+    /* No A1 part to fold; this also keeps A1_limbs below from wrapping */
+    if (X_limbs <= P_limbs - adjust) {
+        return 0;
+    }
+    mbedtls_mpi_uint *M = mbedtls_calloc(X_limbs + R_limbs + adjust, ciL);
+    if (M == NULL) {
+        return MBEDTLS_ERR_ECP_ALLOC_FAILED;
     }
 
-    /* Init R */
-    R.s = 1;
-    R.p = Rp;
-    R.n = P_KOBLITZ_R;
-
-    /* Common setup for M */
-    M.s = 1;
-    M.p = Mp;
+    size_t A1_limbs = X_limbs - (P_limbs - adjust);
+    if (A1_limbs > P_limbs + adjust) {
+        A1_limbs = P_limbs + adjust;
+    }
+    mbedtls_mpi_uint *A1 = mbedtls_calloc(A1_limbs, ciL);
+    if (A1 == NULL) {
+        ret = MBEDTLS_ERR_ECP_ALLOC_FAILED;
+        goto cleanup;
+    }
 
     for (size_t pass = 0; pass < 2; pass++) {
-        /* M = A1 */
-        M.n = N->n - (p_limbs - adjust);
-        if (M.n > p_limbs + adjust) {
-            M.n = p_limbs + adjust;
-        }
-        memset(Mp, 0, sizeof(Mp));
-        memcpy(Mp, N->p + p_limbs - adjust, M.n * sizeof(mbedtls_mpi_uint));
+        /* A1 = high limbs of X, right-shifted by `shift` bits if needed */
+        memcpy(A1, X + P_limbs - adjust, A1_limbs * ciL);
         if (shift != 0) {
-            MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&M, shift));
+            mbedtls_mpi_core_shift_r(A1, A1_limbs, shift);
         }
-        M.n += R.n; /* Make room for multiplication by R */
 
-        /* N = A0 */
+        /* X = A0: mask the top limb of A0, then clear every limb above it */
         if (mask != 0) {
-            N->p[p_limbs - 1] &= mask;
-        }
-        for (size_t i = p_limbs; i < N->n; i++) {
-            N->p[i] = 0;
+            X[P_limbs - 1] &= mask;
         }
+        for (size_t i = P_limbs; i < X_limbs; i++) {
+            X[i] = 0;
+        }
 
-        /* N = A0 + R * A1 */
-        MBEDTLS_MPI_CHK(mbedtls_mpi_mul_mpi(&M, &M, &R));
-        MBEDTLS_MPI_CHK(mbedtls_mpi_add_abs(N, N, &M));
+        /* X = A0 + R * A1, added over all of X so that no carry is lost */
+        mbedtls_mpi_core_mul(M, A1, A1_limbs, R, R_limbs);
+        (void) mbedtls_mpi_core_add(X, X, M, X_limbs);
     }
 
 cleanup:
+    mbedtls_free(M);
+    mbedtls_free(A1);
     return ret;
 }
+
 #endif /* MBEDTLS_ECP_DP_SECP192K1_ENABLED) ||
           MBEDTLS_ECP_DP_SECP224K1_ENABLED) ||
           MBEDTLS_ECP_DP_SECP256K1_ENABLED) */
 
 #if defined(MBEDTLS_ECP_DP_SECP192K1_ENABLED)
+
 /*
  * Fast quasi-reduction modulo p192k1 = 2^192 - R,
  * with R = 2^32 + 2^12 + 2^8 + 2^7 + 2^6 + 2^3 + 1 = 0x0100001119
@@ -5597,9 +5608,10 @@
                                   0x00)
     };
 
-    return ecp_mod_koblitz(N, Rp, 192 / 8 / sizeof(mbedtls_mpi_uint), 0, 0,
-                           0);
+    return ecp_mod_koblitz(N->p, N->n, Rp,
+                           192 / 8 / sizeof(mbedtls_mpi_uint), 0, 0, 0);
 }
+
 #endif /* MBEDTLS_ECP_DP_SECP192K1_ENABLED */
 
 #if defined(MBEDTLS_ECP_DP_SECP224K1_ENABLED)
@@ -5622,10 +5634,10 @@
     };
 
 #if defined(MBEDTLS_HAVE_INT64)
-    return ecp_mod_koblitz(N, Rp, 4, 1, 32, 0xFFFFFFFF);
+    return ecp_mod_koblitz(N->p, N->n, Rp, 4, 1, 32, 0xFFFFFFFF);
 #else
-    return ecp_mod_koblitz(N, Rp, 224 / 8 / sizeof(mbedtls_mpi_uint), 0, 0,
-                           0);
+    return ecp_mod_koblitz(N->p, N->n, Rp,
+                           224 / 8 / sizeof(mbedtls_mpi_uint), 0, 0, 0);
 #endif
 }
 
@@ -5649,8 +5661,8 @@
         MBEDTLS_BYTES_TO_T_UINT_8(0xD1, 0x03, 0x00, 0x00, 0x01, 0x00, 0x00,
                                   0x00)
     };
-    return ecp_mod_koblitz(N, Rp, 256 / 8 / sizeof(mbedtls_mpi_uint), 0, 0,
-                           0);
+    return ecp_mod_koblitz(N->p, N->n, Rp,
+                           256 / 8 / sizeof(mbedtls_mpi_uint), 0, 0, 0);
 }
 #endif /* MBEDTLS_ECP_DP_SECP256K1_ENABLED */