Split mbedtls_mpi_mod_inv() into separate functions for mont/non-mont form

Signed-off-by: Tom Cosgrove <tom.cosgrove@arm.com>
diff --git a/library/bignum_mod.c b/library/bignum_mod.c
index 216b20f..7c89b57 100644
--- a/library/bignum_mod.c
+++ b/library/bignum_mod.c
@@ -192,6 +192,54 @@
     return( 0 );
 }
 
+static int mbedtls_mpi_mod_inv_mont( mbedtls_mpi_mod_residue *X,
+                                     const mbedtls_mpi_mod_residue *A,
+                                     const mbedtls_mpi_mod_modulus *N,
+                                     mbedtls_mpi_uint *working_memory )
+{
+    /* N uses the Montgomery representation, so A is already in the form the
+     * raw inversion works on: invert it directly into X->p, passing N's RR
+     * value and the caller's scratch buffer. No conversion step is needed. */
+    mbedtls_mpi_mod_raw_inv_prime( X->p, A->p, N->p, N->limbs,
+                                   N->rep.mont.rr, working_memory );
+    return( 0 );
+}
+
+static int mbedtls_mpi_mod_inv_non_mont( mbedtls_mpi_mod_residue *X,
+                                         const mbedtls_mpi_mod_residue *A,
+                                         const mbedtls_mpi_mod_modulus *N,
+                                         mbedtls_mpi_uint *working_memory )
+{
+    /* A is not in Montgomery form: set up Nmont, a Montgomery-representation
+     * modulus over N->p, convert A into X->p, invert, then convert back. */
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+
+    mbedtls_mpi_mod_modulus Nmont;
+    mbedtls_mpi_mod_modulus_init( &Nmont );
+    /* Nmont reuses N->p; assumes modulus_free() won't free it (confirm) */
+    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_modulus_setup( &Nmont, N->p, N->limbs,
+                                         MBEDTLS_MPI_MOD_REP_MONTGOMERY ) );
+
+    /* We'll use X->p to hold the Montgomery form of the input A->p */
+    mbedtls_mpi_core_to_mont_rep( X->p, A->p, Nmont.p, Nmont.limbs,
+                                  Nmont.rep.mont.mm, Nmont.rep.mont.rr,
+                                  working_memory );
+    /* Invert in place: X->p is both the input and the output here */
+    mbedtls_mpi_mod_raw_inv_prime( X->p, X->p,
+                                   Nmont.p, Nmont.limbs,
+                                   Nmont.rep.mont.rr,
+                                   working_memory );
+
+    /* And convert the inverse held in X->p back from Montgomery form */
+
+    mbedtls_mpi_core_from_mont_rep( X->p, X->p, Nmont.p, Nmont.limbs,
+                                    Nmont.rep.mont.mm, working_memory );
+
+cleanup:
+    mbedtls_mpi_mod_modulus_free( &Nmont );
+    return( ret );
+}
+
 int mbedtls_mpi_mod_inv( mbedtls_mpi_mod_residue *X,
                          const mbedtls_mpi_mod_residue *A,
                          const mbedtls_mpi_mod_modulus *N )
@@ -203,94 +251,38 @@
     if( mbedtls_mpi_core_check_zero_ct( A->p, A->limbs ) == 0 )
         return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
 
-    /* Will we need to do Montgomery conversion? */
-    int mont_conv_needed;
-    switch( N->int_rep )
-    {
-        case MBEDTLS_MPI_MOD_REP_MONTGOMERY:
-            mont_conv_needed = 0;
-            break;
-        case MBEDTLS_MPI_MOD_REP_OPT_RED:
-            mont_conv_needed = 1;
-            break;
-        default:
-            return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
-    }
-
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-
-    /* If the input is already in Montgomery form, we have little to do but
-     * allocate working memory and call mbedtls_mpi_mod_raw_inv_prime().
-     *
-     * If it's not, we need to
-     * 1. Create a Montgomery version of the modulus;
-     * 2. Convert the input into Mont. form, using X->p to hold it;
-     * 3. (allocate and convert, same as if already in Mont. form);
-     * 4. Convert the inverted output back from Mont. form.
-     *
-     * Since the Montgomery conversion functions are in-place, we'll need to
-     * copy A into X before we start working on it (which could be avoided if
-     * there was a not-in-place function to convert to Montgomery form.
-     */
-
-    /* Montgomery version of modulus (if not already in Mont. form).
-     * We will only call setup if the input is not already in Montgomery form.
-     * We will re-use N->p from input modulus, and make use of the fact that
-     * mbedtls_mpi_mod_raw_to_mont_rep() won't free it. */
-    mbedtls_mpi_mod_modulus Nmont;
-    mbedtls_mpi_mod_modulus_init( &Nmont );
-
     size_t working_limbs =
                     mbedtls_mpi_mod_raw_inv_prime_working_limbs( N->limbs );
 
     mbedtls_mpi_uint *working_memory = mbedtls_calloc( working_limbs,
                                                      sizeof(mbedtls_mpi_uint) );
     if( working_memory == NULL )
+        return( MBEDTLS_ERR_MPI_ALLOC_FAILED );
+
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+
+    /* Dispatch on N's representation. Both helpers write the inverse of A
+     * into X->p and use working_memory as scratch space. */
+    switch( N->int_rep )
     {
-        ret = MBEDTLS_ERR_MPI_ALLOC_FAILED;
-        goto cleanup;
+        case MBEDTLS_MPI_MOD_REP_MONTGOMERY:
+            ret = mbedtls_mpi_mod_inv_mont( X, A, N, working_memory );
+            break;
+        case MBEDTLS_MPI_MOD_REP_OPT_RED:
+            ret = mbedtls_mpi_mod_inv_non_mont( X, A, N, working_memory );
+            break;
+        default:
+            ret = MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
+            break;
     }
 
-    const mbedtls_mpi_uint *to_invert;   /* Will alias A->p or X->p */
-    const mbedtls_mpi_mod_modulus *Nuse; /* Which of N and Nmont to use */
+    /* Wipe the scratch buffer, then release it through mbedtls_free() to
+     * match the mbedtls_calloc() that allocated it. */
+    mbedtls_platform_zeroize( working_memory,
+                              working_limbs * sizeof(mbedtls_mpi_uint) );
+    mbedtls_free( working_memory );
 
-    if( mont_conv_needed )
-    {
-        MBEDTLS_MPI_CHK( mbedtls_mpi_mod_modulus_setup( &Nmont, N->p, N->limbs,
-                                             MBEDTLS_MPI_MOD_REP_MONTGOMERY ) );
-
-        mbedtls_mpi_core_to_mont_rep( X->p, A->p, Nmont.p, Nmont.limbs,
-                                      Nmont.rep.mont.mm, Nmont.rep.mont.rr,
-                                      working_memory );
-        to_invert = X->p;
-        Nuse = &Nmont;
-    }
-    else
-    {
-        to_invert = A->p;
-        Nuse = N;
-    }
-
-    mbedtls_mpi_mod_raw_inv_prime( X->p, to_invert,
-                                   Nuse->p, Nuse->limbs,
-                                   Nuse->rep.mont.rr,
-                                   working_memory );
-
-    if( mont_conv_needed )
-        mbedtls_mpi_core_from_mont_rep( X->p, X->p, Nmont.p, Nmont.limbs,
-                                        Nmont.rep.mont.mm, working_memory );
-
-cleanup:
-    mbedtls_mpi_mod_modulus_free( &Nmont );
-
-    if (working_memory != NULL )
-    {
-        mbedtls_platform_zeroize( working_memory,
-                                  working_limbs * sizeof(mbedtls_mpi_uint) );
-        mbedtls_free( working_memory );
-    }
-
-    return( ret );
+    return( ret );
 }
 /* END MERGE SLOT 3 */