ECDH: Make the implementation use the new context

The functionality of the public API functions is moved to
`xxx_internal()` functions. The public API functions are modified to do
basic parameter validation and dispatch the call to the right
implementation.

There is no intended change in behaviour when
`MBEDTLS_ECDH_LEGACY_CONTEXT` is enabled.
diff --git a/library/ecdh.c b/library/ecdh.c
index 702ba1a..d68db8a 100644
--- a/library/ecdh.c
+++ b/library/ecdh.c
@@ -38,6 +38,10 @@
 
 #include <string.h>
 
+#if defined(MBEDTLS_ECDH_LEGACY_CONTEXT)
+typedef mbedtls_ecdh_context mbedtls_ecdh_context_mbed;
+#endif
+
 #if !defined(MBEDTLS_ECDH_GEN_PUBLIC_ALT)
 /*
  * Generate public key (restartable version)
@@ -124,38 +128,48 @@
 }
 #endif /* !MBEDTLS_ECDH_COMPUTE_SHARED_ALT */
 
-/*
- * Initialize context
- */
-void mbedtls_ecdh_init( mbedtls_ecdh_context *ctx )
+static void ecdh_init_internal( mbedtls_ecdh_context_mbed *ctx )
 {
     mbedtls_ecp_group_init( &ctx->grp );
     mbedtls_mpi_init( &ctx->d  );
     mbedtls_ecp_point_init( &ctx->Q   );
     mbedtls_ecp_point_init( &ctx->Qp  );
     mbedtls_mpi_init( &ctx->z  );
-    ctx->point_format = MBEDTLS_ECP_PF_UNCOMPRESSED;
-    mbedtls_ecp_point_init( &ctx->Vi  );
-    mbedtls_ecp_point_init( &ctx->Vf  );
-    mbedtls_mpi_init( &ctx->_d );
 
 #if defined(MBEDTLS_ECP_RESTARTABLE)
-    ctx->restart_enabled = 0;
     mbedtls_ecp_restart_init( &ctx->rs );
 #endif
 }
 
 /*
- * Setup context
+ * Initialize context
  */
-int mbedtls_ecdh_setup( mbedtls_ecdh_context *ctx, mbedtls_ecp_group_id grp_id )
+void mbedtls_ecdh_init( mbedtls_ecdh_context *ctx )
+{
+#if defined(MBEDTLS_ECDH_LEGACY_CONTEXT)
+    ecdh_init_internal( ctx );
+    mbedtls_ecp_point_init( &ctx->Vi  );
+    mbedtls_ecp_point_init( &ctx->Vf  );
+    mbedtls_mpi_init( &ctx->_d );
+#else
+    memset( ctx, 0, sizeof( mbedtls_ecdh_context ) );
+
+    ctx->var = MBEDTLS_ECDH_VARIANT_NONE;
+#endif
+    ctx->point_format = MBEDTLS_ECP_PF_UNCOMPRESSED;
+#if defined(MBEDTLS_ECP_RESTARTABLE)
+    ctx->restart_enabled = 0;
+#endif
+}
+
+static int ecdh_setup_internal( mbedtls_ecdh_context_mbed *ctx,
+                                mbedtls_ecp_group_id grp_id )
 {
     int ret;
 
     ret = mbedtls_ecp_group_load( &ctx->grp, grp_id );
     if( ret != 0 )
     {
-        mbedtls_ecdh_free( ctx );
         return( MBEDTLS_ERR_ECP_FEATURE_UNAVAILABLE );
     }
 
@@ -163,21 +177,35 @@
 }
 
 /*
- * Free context
+ * Setup context
  */
-void mbedtls_ecdh_free( mbedtls_ecdh_context *ctx )
+int mbedtls_ecdh_setup( mbedtls_ecdh_context *ctx, mbedtls_ecp_group_id grp_id )
 {
     if( ctx == NULL )
-        return;
+        return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
 
+#if defined(MBEDTLS_ECDH_LEGACY_CONTEXT)
+    return( ecdh_setup_internal( ctx, grp_id ) );
+#else
+    switch( grp_id )
+    {
+        default:
+            ctx->point_format = MBEDTLS_ECP_PF_UNCOMPRESSED;
+            ctx->var = MBEDTLS_ECDH_VARIANT_MBEDTLS_2_0;
+            ctx->grp_id = grp_id;
+            ecdh_init_internal( &ctx->ctx.mbed_ecdh );
+            return( ecdh_setup_internal( &ctx->ctx.mbed_ecdh, grp_id ) );
+    }
+#endif
+}
+
+static void ecdh_free_internal( mbedtls_ecdh_context_mbed *ctx )
+{
     mbedtls_ecp_group_free( &ctx->grp );
     mbedtls_mpi_free( &ctx->d  );
     mbedtls_ecp_point_free( &ctx->Q   );
     mbedtls_ecp_point_free( &ctx->Qp  );
     mbedtls_mpi_free( &ctx->z  );
-    mbedtls_ecp_point_free( &ctx->Vi  );
-    mbedtls_ecp_point_free( &ctx->Vf  );
-    mbedtls_mpi_free( &ctx->_d );
 
 #if defined(MBEDTLS_ECP_RESTARTABLE)
     mbedtls_ecp_restart_free( &ctx->rs );
@@ -190,21 +218,50 @@
  */
 void mbedtls_ecdh_enable_restart( mbedtls_ecdh_context *ctx )
 {
+    if( ctx == NULL )
+        return;
+
     ctx->restart_enabled = 1;
 }
 #endif
 
 /*
- * Setup and write the ServerKeyExhange parameters (RFC 4492)
- *      struct {
- *          ECParameters    curve_params;
- *          ECPoint         public;
- *      } ServerECDHParams;
+ * Free context
  */
-int mbedtls_ecdh_make_params( mbedtls_ecdh_context *ctx, size_t *olen,
-                      unsigned char *buf, size_t blen,
-                      int (*f_rng)(void *, unsigned char *, size_t),
-                      void *p_rng )
+void mbedtls_ecdh_free( mbedtls_ecdh_context *ctx )
+{
+    if( ctx == NULL )
+        return;
+
+#if defined(MBEDTLS_ECDH_LEGACY_CONTEXT)
+    mbedtls_ecp_point_free( &ctx->Vi );
+    mbedtls_ecp_point_free( &ctx->Vf );
+    mbedtls_mpi_free( &ctx->_d );
+    ecdh_free_internal( ctx );
+#else
+    switch( ctx->var )
+    {
+        case MBEDTLS_ECDH_VARIANT_MBEDTLS_2_0:
+            ecdh_free_internal( &ctx->ctx.mbed_ecdh );
+            break;
+        default:
+            break;
+    }
+
+    ctx->point_format = MBEDTLS_ECP_PF_UNCOMPRESSED;
+    ctx->var = MBEDTLS_ECDH_VARIANT_NONE;
+    ctx->grp_id = MBEDTLS_ECP_DP_NONE;
+#endif
+}
+
+static int ecdh_make_params_internal( mbedtls_ecdh_context_mbed *ctx,
+                                      size_t *olen, int point_format,
+                                      unsigned char *buf, size_t blen,
+                                      int (*f_rng)(void *,
+                                                   unsigned char *,
+                                                   size_t),
+                                      void *p_rng,
+                                      int restart_enabled )
 {
     int ret;
     size_t grp_len, pt_len;
@@ -212,12 +269,14 @@
     mbedtls_ecp_restart_ctx *rs_ctx = NULL;
 #endif
 
-    if( ctx == NULL || ctx->grp.pbits == 0 )
+    if( ctx->grp.pbits == 0 )
         return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
 
 #if defined(MBEDTLS_ECP_RESTARTABLE)
-    if( ctx->restart_enabled )
+    if( restart_enabled )
         rs_ctx = &ctx->rs;
+#else
+    (void) restart_enabled;
 #endif
 
 
@@ -231,14 +290,14 @@
         return( ret );
 #endif /* MBEDTLS_ECP_RESTARTABLE */
 
-    if( ( ret = mbedtls_ecp_tls_write_group( &ctx->grp, &grp_len, buf, blen ) )
-                != 0 )
+    if( ( ret = mbedtls_ecp_tls_write_group( &ctx->grp, &grp_len, buf,
+                                             blen ) ) != 0 )
         return( ret );
 
     buf += grp_len;
     blen -= grp_len;
 
-    if( ( ret = mbedtls_ecp_tls_write_point( &ctx->grp, &ctx->Q, ctx->point_format,
+    if( ( ret = mbedtls_ecp_tls_write_point( &ctx->grp, &ctx->Q, point_format,
                                              &pt_len, buf, blen ) ) != 0 )
         return( ret );
 
@@ -247,6 +306,54 @@
 }
 
 /*
+ * Setup and write the ServerKeyExchange parameters (RFC 4492)
+ *      struct {
+ *          ECParameters    curve_params;
+ *          ECPoint         public;
+ *      } ServerECDHParams;
+ */
+int mbedtls_ecdh_make_params( mbedtls_ecdh_context *ctx, size_t *olen,
+                              unsigned char *buf, size_t blen,
+                              int (*f_rng)(void *, unsigned char *, size_t),
+                              void *p_rng )
+{
+    int restart_enabled = 0;
+
+    if( ctx == NULL )
+        return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
+
+#if defined(MBEDTLS_ECP_RESTARTABLE)
+    restart_enabled = ctx->restart_enabled;
+#else
+    (void) restart_enabled;
+#endif
+
+#if defined(MBEDTLS_ECDH_LEGACY_CONTEXT)
+    return( ecdh_make_params_internal( ctx, olen, ctx->point_format, buf, blen,
+                                       f_rng, p_rng, restart_enabled ) );
+#else
+    switch( ctx->var )
+    {
+        case MBEDTLS_ECDH_VARIANT_MBEDTLS_2_0:
+            return( ecdh_make_params_internal( &ctx->ctx.mbed_ecdh, olen,
+                                               ctx->point_format, buf, blen,
+                                               f_rng, p_rng,
+                                               restart_enabled ) );
+        default:
+            return MBEDTLS_ERR_ECP_BAD_INPUT_DATA;
+    }
+#endif
+}
+
+static int ecdh_read_params_internal( mbedtls_ecdh_context_mbed *ctx,
+                                      const unsigned char **buf,
+                                      const unsigned char *end )
+{
+    return( mbedtls_ecp_tls_read_point( &ctx->grp, &ctx->Qp, buf,
+                                        end - *buf ) );
+}
+
+/*
  * Read the ServerKeyExhange parameters (RFC 4492)
  *      struct {
  *          ECParameters    curve_params;
@@ -254,11 +361,15 @@
  *      } ServerECDHParams;
  */
 int mbedtls_ecdh_read_params( mbedtls_ecdh_context *ctx,
-                      const unsigned char **buf, const unsigned char *end )
+                              const unsigned char **buf,
+                              const unsigned char *end )
 {
     int ret;
     mbedtls_ecp_group_id grp_id;
 
+    if( ctx == NULL )
+        return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
+
     if( ( ret = mbedtls_ecp_tls_read_group_id( &grp_id, buf, end - *buf ) )
             != 0 )
         return( ret );
@@ -266,24 +377,26 @@
     if( ( ret = mbedtls_ecdh_setup( ctx, grp_id ) ) != 0 )
         return( ret );
 
-    if( ( ret = mbedtls_ecp_tls_read_point( &ctx->grp, &ctx->Qp, buf,
-                                            end - *buf ) ) != 0 )
-        return( ret );
-
-    return( 0 );
+#if defined(MBEDTLS_ECDH_LEGACY_CONTEXT)
+    return( ecdh_read_params_internal( ctx, buf, end ) );
+#else
+    switch( ctx->var )
+    {
+        case MBEDTLS_ECDH_VARIANT_MBEDTLS_2_0:
+            return( ecdh_read_params_internal( &ctx->ctx.mbed_ecdh,
+                                               buf, end ) );
+        default:
+            return MBEDTLS_ERR_ECP_BAD_INPUT_DATA;
+    }
+#endif
 }
 
-/*
- * Get parameters from a keypair
- */
-int mbedtls_ecdh_get_params( mbedtls_ecdh_context *ctx, const mbedtls_ecp_keypair *key,
-                     mbedtls_ecdh_side side )
+static int ecdh_get_params_internal( mbedtls_ecdh_context_mbed *ctx,
+                                     const mbedtls_ecp_keypair *key,
+                                     mbedtls_ecdh_side side )
 {
     int ret;
 
-    if( ( ret = mbedtls_ecdh_setup( ctx, key->grp.id ) ) != 0 )
-        return( ret );
-
     /* If it's not our key, just import the public part as Qp */
     if( side == MBEDTLS_ECDH_THEIRS )
         return( mbedtls_ecp_copy( &ctx->Qp, &key->Q ) );
@@ -300,29 +413,61 @@
 }
 
 /*
- * Setup and export the client public value
+ * Get parameters from a keypair
  */
-int mbedtls_ecdh_make_public( mbedtls_ecdh_context *ctx, size_t *olen,
-                      unsigned char *buf, size_t blen,
-                      int (*f_rng)(void *, unsigned char *, size_t),
-                      void *p_rng )
+int mbedtls_ecdh_get_params( mbedtls_ecdh_context *ctx,
+                             const mbedtls_ecp_keypair *key,
+                             mbedtls_ecdh_side side )
+{
+    int ret;
+
+    if( ctx == NULL )
+        return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
+
+    if( ( ret = mbedtls_ecdh_setup( ctx, key->grp.id ) ) != 0 )
+        return( ret );
+
+#if defined(MBEDTLS_ECDH_LEGACY_CONTEXT)
+    return( ecdh_get_params_internal( ctx, key, side ) );
+#else
+    switch( ctx->var )
+    {
+        case MBEDTLS_ECDH_VARIANT_MBEDTLS_2_0:
+            return( ecdh_get_params_internal( &ctx->ctx.mbed_ecdh,
+                                              key, side ) );
+        default:
+            return MBEDTLS_ERR_ECP_BAD_INPUT_DATA;
+    }
+#endif
+}
+
+static int ecdh_make_public_internal( mbedtls_ecdh_context_mbed *ctx,
+                                      size_t *olen, int point_format,
+                                      unsigned char *buf, size_t blen,
+                                      int (*f_rng)(void *,
+                                                   unsigned char *,
+                                                   size_t),
+                                      void *p_rng,
+                                      int restart_enabled )
 {
     int ret;
 #if defined(MBEDTLS_ECP_RESTARTABLE)
     mbedtls_ecp_restart_ctx *rs_ctx = NULL;
 #endif
 
-    if( ctx == NULL || ctx->grp.pbits == 0 )
+    if( ctx->grp.pbits == 0 )
         return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
 
 #if defined(MBEDTLS_ECP_RESTARTABLE)
-    if( ctx->restart_enabled )
+    if( restart_enabled )
         rs_ctx = &ctx->rs;
+#else
+    (void) restart_enabled;
 #endif
 
 #if defined(MBEDTLS_ECP_RESTARTABLE)
     if( ( ret = ecdh_gen_public_restartable( &ctx->grp, &ctx->d, &ctx->Q,
-                    f_rng, p_rng, rs_ctx ) ) != 0 )
+                                             f_rng, p_rng, rs_ctx ) ) != 0 )
         return( ret );
 #else
     if( ( ret = mbedtls_ecdh_gen_public( &ctx->grp, &ctx->d, &ctx->Q,
@@ -330,23 +475,52 @@
         return( ret );
 #endif /* MBEDTLS_ECP_RESTARTABLE */
 
-    return mbedtls_ecp_tls_write_point( &ctx->grp, &ctx->Q, ctx->point_format,
-                                olen, buf, blen );
+    return mbedtls_ecp_tls_write_point( &ctx->grp, &ctx->Q, point_format, olen,
+                                        buf, blen );
 }
 
 /*
- * Parse and import the client's public value
+ * Setup and export the client public value
  */
-int mbedtls_ecdh_read_public( mbedtls_ecdh_context *ctx,
-                      const unsigned char *buf, size_t blen )
+int mbedtls_ecdh_make_public( mbedtls_ecdh_context *ctx, size_t *olen,
+                              unsigned char *buf, size_t blen,
+                              int (*f_rng)(void *, unsigned char *, size_t),
+                              void *p_rng )
 {
-    int ret;
-    const unsigned char *p = buf;
+    int restart_enabled = 0;
 
     if( ctx == NULL )
         return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
 
-    if( ( ret = mbedtls_ecp_tls_read_point( &ctx->grp, &ctx->Qp, &p, blen ) ) != 0 )
+#if defined(MBEDTLS_ECP_RESTARTABLE)
+    restart_enabled = ctx->restart_enabled;
+#endif
+
+#if defined(MBEDTLS_ECDH_LEGACY_CONTEXT)
+    return( ecdh_make_public_internal( ctx, olen, ctx->point_format, buf, blen,
+                                       f_rng, p_rng, restart_enabled ) );
+#else
+    switch( ctx->var )
+    {
+        case MBEDTLS_ECDH_VARIANT_MBEDTLS_2_0:
+            return( ecdh_make_public_internal( &ctx->ctx.mbed_ecdh, olen,
+                                               ctx->point_format, buf, blen,
+                                               f_rng, p_rng,
+                                               restart_enabled ) );
+        default:
+            return MBEDTLS_ERR_ECP_BAD_INPUT_DATA;
+    }
+#endif
+}
+
+static int ecdh_read_public_internal( mbedtls_ecdh_context_mbed *ctx,
+                                      const unsigned char *buf, size_t blen )
+{
+    int ret;
+    const unsigned char *p = buf;
+
+    if( ( ret = mbedtls_ecp_tls_read_point( &ctx->grp, &ctx->Qp, &p,
+                                            blen ) ) != 0 )
         return( ret );
 
     if( (size_t)( p - buf ) != blen )
@@ -356,12 +530,36 @@
 }
 
 /*
- * Derive and export the shared secret
+ * Parse and import the client's public value
  */
-int mbedtls_ecdh_calc_secret( mbedtls_ecdh_context *ctx, size_t *olen,
-                      unsigned char *buf, size_t blen,
-                      int (*f_rng)(void *, unsigned char *, size_t),
-                      void *p_rng )
+int mbedtls_ecdh_read_public( mbedtls_ecdh_context *ctx,
+                              const unsigned char *buf, size_t blen )
+{
+    if( ctx == NULL )
+        return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
+
+#if defined(MBEDTLS_ECDH_LEGACY_CONTEXT)
+    return( ecdh_read_public_internal( ctx, buf, blen ) );
+#else
+    switch( ctx->var )
+    {
+        case MBEDTLS_ECDH_VARIANT_MBEDTLS_2_0:
+            return( ecdh_read_public_internal( &ctx->ctx.mbed_ecdh,
+                                                       buf, blen ) );
+        default:
+            return MBEDTLS_ERR_ECP_BAD_INPUT_DATA;
+    }
+#endif
+}
+
+static int ecdh_calc_secret_internal( mbedtls_ecdh_context_mbed *ctx,
+                                      size_t *olen, unsigned char *buf,
+                                      size_t blen,
+                                      int (*f_rng)(void *,
+                                                   unsigned char *,
+                                                   size_t),
+                                      void *p_rng,
+                                      int restart_enabled )
 {
     int ret;
 #if defined(MBEDTLS_ECP_RESTARTABLE)
@@ -372,13 +570,16 @@
         return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
 
 #if defined(MBEDTLS_ECP_RESTARTABLE)
-    if( ctx->restart_enabled )
+    if( restart_enabled )
         rs_ctx = &ctx->rs;
+#else
+    (void) restart_enabled;
 #endif
 
 #if defined(MBEDTLS_ECP_RESTARTABLE)
-    if( ( ret = ecdh_compute_shared_restartable( &ctx->grp,
-                    &ctx->z, &ctx->Qp, &ctx->d, f_rng, p_rng, rs_ctx ) ) != 0 )
+    if( ( ret = ecdh_compute_shared_restartable( &ctx->grp, &ctx->z, &ctx->Qp,
+                                                 &ctx->d, f_rng, p_rng,
+                                                 rs_ctx ) ) != 0 )
     {
         return( ret );
     }
@@ -397,4 +598,37 @@
     return mbedtls_mpi_write_binary( &ctx->z, buf, *olen );
 }
 
+/*
+ * Derive and export the shared secret
+ */
+int mbedtls_ecdh_calc_secret( mbedtls_ecdh_context *ctx, size_t *olen,
+                              unsigned char *buf, size_t blen,
+                              int (*f_rng)(void *, unsigned char *, size_t),
+                              void *p_rng )
+{
+    int restart_enabled = 0;
+
+    if( ctx == NULL )
+        return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
+
+#if defined(MBEDTLS_ECP_RESTARTABLE)
+    restart_enabled = ctx->restart_enabled;
+#endif
+
+#if defined(MBEDTLS_ECDH_LEGACY_CONTEXT)
+    return( ecdh_calc_secret_internal( ctx, olen, buf, blen, f_rng, p_rng,
+                                       restart_enabled ) );
+#else
+    switch( ctx->var )
+    {
+        case MBEDTLS_ECDH_VARIANT_MBEDTLS_2_0:
+            return( ecdh_calc_secret_internal( &ctx->ctx.mbed_ecdh, olen, buf,
+                                               blen, f_rng, p_rng,
+                                               restart_enabled ) );
+        default:
+            return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
+    }
+#endif
+}
+
 #endif /* MBEDTLS_ECDH_C */