Unify round two
diff --git a/library/ecjpake.c b/library/ecjpake.c
index 08d54d7..e09d742 100644
--- a/library/ecjpake.c
+++ b/library/ecjpake.c
@@ -512,9 +512,9 @@
 }
 
 /*
- * Read and process ServerECJPAKEParams (7.4.2.5)
+ * Read and process second round message (C: 7.4.2.5, S: 7.4.2.6)
  */
-int mbedtls_ecjpake_tls_read_server_params( mbedtls_ecjpake_context *ctx,
+int mbedtls_ecjpake_read_round_two( mbedtls_ecjpake_context *ctx,
                                             const unsigned char *buf,
                                             size_t len )
 {
@@ -522,28 +522,30 @@
     const unsigned char *p = buf;
     const unsigned char *end = buf + len;
     mbedtls_ecp_group grp;
-    mbedtls_ecp_point GB;
+    mbedtls_ecp_point G;
 
     mbedtls_ecp_group_init( &grp );
-    mbedtls_ecp_point_init( &GB );
+    mbedtls_ecp_point_init( &G );
 
     /*
-     * Client:  GB = X1  + X2  + X3     (7.4.2.5.1)
-     * Unified: GB = Xm1 + Xm2 + Xp1
+     * Server: GA = X3  + X4  + X1      (7.4.2.6.1)
+     * Client: GB = X1  + X2  + X3      (7.4.2.5.1)
+     * Unified: G = Xm1 + Xm2 + Xp1
      * We need that before parsing in order to check Xp as we read it
      */
-    MBEDTLS_MPI_CHK( ecjpake_ecp_add3( &ctx->grp, &GB,
+    MBEDTLS_MPI_CHK( ecjpake_ecp_add3( &ctx->grp, &G,
                                        &ctx->Xm1, &ctx->Xm2, &ctx->Xp1 ) );
 
     /*
      * struct {
-     *     ECParameters curve_params;
+     *     ECParameters curve_params;   // only client reading server msg
      *     ECJPAKEKeyKP ecjpake_key_kp;
-     * } ServerECJPAKEParams;
+     * } Client/ServerECJPAKEParams;
      */
-    MBEDTLS_MPI_CHK( mbedtls_ecp_tls_read_group( &grp, &p, len ) );
+    if( ctx->role == MBEDTLS_ECJPAKE_CLIENT )
+        MBEDTLS_MPI_CHK( mbedtls_ecp_tls_read_group( &grp, &p, len ) );
     MBEDTLS_MPI_CHK( ecjpake_kkp_read( ctx->md_info, &ctx->grp,
-                            &GB, &ctx->Xp, ID_PEER, &p, end ) );
+                            &G, &ctx->Xp, ID_PEER, &p, end ) );
 
     if( p != end )
     {
@@ -552,185 +554,92 @@
     }
 
     /*
-     * Xs already checked, only thing left to check is the group
+     * Xp already checked, only the group is left to check (client only)
      */
-    if( grp.id != ctx->grp.id )
+    if( ctx->role == MBEDTLS_ECJPAKE_CLIENT && grp.id != ctx->grp.id )
     {
         ret = MBEDTLS_ERR_ECP_FEATURE_UNAVAILABLE;
         goto cleanup;
     }
-
 cleanup:
     mbedtls_ecp_group_free( &grp );
-    mbedtls_ecp_point_free( &GB );
+    mbedtls_ecp_point_free( &G );
 
     return( ret );
 }
 
 /*
- * Generate and write ServerECJPAKEParams (7.4.2.5)
+ * Generate and write the second round message (S: 7.4.2.5, C: 7.4.2.6)
  */
-int mbedtls_ecjpake_tls_write_server_params( mbedtls_ecjpake_context *ctx,
+int mbedtls_ecjpake_write_round_two( mbedtls_ecjpake_context *ctx,
                             unsigned char *buf, size_t len, size_t *olen,
                             int (*f_rng)(void *, unsigned char *, size_t),
                             void *p_rng )
 {
     int ret;
-    mbedtls_ecp_point GB, Xs;
-    mbedtls_mpi xs;
+    mbedtls_ecp_point G;    /* C: GA, S: GB */
+    mbedtls_ecp_point Xm;   /* C: Xc, S: Xs */
+    mbedtls_mpi xm;         /* C: xc, S: xs */
     unsigned char *p = buf;
     const unsigned char *end = buf + len;
     size_t ec_len;
 
-    if( end < p )
-        return( MBEDTLS_ERR_ECP_BUFFER_TOO_SMALL );
-
-    mbedtls_ecp_point_init( &GB );
-    mbedtls_ecp_point_init( &Xs );
-    mbedtls_mpi_init( &xs );
+    mbedtls_ecp_point_init( &G );
+    mbedtls_ecp_point_init( &Xm );
+    mbedtls_mpi_init( &xm );
 
     /*
-     * First generate private/public key pair (7.4.2.5.1)
+     * First generate private/public key pair (S: 7.4.2.5.1, C: 7.4.2.6.1)
      *
-     * Server:  GB = X1 + X2 + X3
-     * Unified:
-     * xs = x4 * s mod n
-     * Xs = xs * GB
+     * Client:  GA = X1  + X3  + X4  | xc = x2  * s | Xc = xc * GA
+     * Server:  GB = X3  + X1  + X2  | xs = x4  * s | Xs = xs * GB
+     * Unified: G  = Xm1 + Xp1 + Xp2 | xm = xm2 * s | Xm = xm * G
      */
-    MBEDTLS_MPI_CHK( ecjpake_ecp_add3( &ctx->grp, &GB,
+    MBEDTLS_MPI_CHK( ecjpake_ecp_add3( &ctx->grp, &G,
                                        &ctx->Xp1, &ctx->Xp2, &ctx->Xm1 ) );
-    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &xs, &ctx->xm2, &ctx->s ) );
-    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &xs, &xs, &ctx->grp.N ) );
-    MBEDTLS_MPI_CHK( mbedtls_ecp_mul( &ctx->grp, &Xs, &xs, &GB, f_rng, p_rng ) );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &xm, &ctx->xm2, &ctx->s ) );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &xm, &xm, &ctx->grp.N ) );
+    MBEDTLS_MPI_CHK( mbedtls_ecp_mul( &ctx->grp, &Xm, &xm, &G, f_rng, p_rng ) );
 
     /*
      * Now write things out
+     *
+     * struct {
+     *     ECParameters curve_params;   // only server writing its message
+     *     ECJPAKEKeyKP ecjpake_key_kp;
+     * } Client/ServerECJPAKEParams;
      */
-    MBEDTLS_MPI_CHK( mbedtls_ecp_tls_write_group( &ctx->grp, &ec_len,
-                                                  p, end - p ) );
-    p += ec_len;
+    if( ctx->role == MBEDTLS_ECJPAKE_SERVER )
+    {
+        if( end < p )
+        {
+            ret = MBEDTLS_ERR_ECP_BUFFER_TOO_SMALL;
+            goto cleanup;
+        }
+        MBEDTLS_MPI_CHK( mbedtls_ecp_tls_write_group( &ctx->grp, &ec_len,
+                                                      p, end - p ) );
+        p += ec_len;
+    }
 
     if( end < p )
     {
         ret = MBEDTLS_ERR_ECP_BUFFER_TOO_SMALL;
         goto cleanup;
     }
-    MBEDTLS_MPI_CHK( mbedtls_ecp_tls_write_point( &ctx->grp, &Xs,
+    MBEDTLS_MPI_CHK( mbedtls_ecp_tls_write_point( &ctx->grp, &Xm,
                      MBEDTLS_ECP_PF_UNCOMPRESSED, &ec_len, p, end - p ) );
     p += ec_len;
 
     MBEDTLS_MPI_CHK( ecjpake_zkp_write( ctx->md_info, &ctx->grp,
-                                        &GB, &xs, &Xs, ID_MINE,
+                                        &G, &xm, &Xm, ID_MINE,
                                         &p, end, f_rng, p_rng ) );
 
     *olen = p - buf;
 
 cleanup:
-    mbedtls_ecp_point_free( &GB );
-    mbedtls_ecp_point_free( &Xs );
-    mbedtls_mpi_free( &xs );
-
-    return( ret );
-}
-
-/*
- * Read and process ClientECJPAKEParams (7.4.2.6)
- */
-int mbedtls_ecjpake_tls_read_client_params( mbedtls_ecjpake_context *ctx,
-                                            const unsigned char *buf,
-                                            size_t len )
-{
-    int ret;
-    const unsigned char *p = buf;
-    const unsigned char *end = buf + len;
-    mbedtls_ecp_group grp;
-    mbedtls_ecp_point GA;
-
-    mbedtls_ecp_group_init( &grp );
-    mbedtls_ecp_point_init( &GA );
-
-    /*
-     * Server: GA = X1 + X3 + X4 (7.4.2.6.1)
-     * Unified: G = Xp1 + Xm1 + Xm2
-     * We need that before parsing in order to check Xc as we read it
-     */
-    MBEDTLS_MPI_CHK( ecjpake_ecp_add3( &ctx->grp, &GA,
-                                       &ctx->Xp1, &ctx->Xm1, &ctx->Xm2 ) );
-
-    /*
-     * struct {
-     *     ECJPAKEKeyKP ecjpake_key_kp;
-     * } CLientECJPAKEParams;
-     */
-    MBEDTLS_MPI_CHK( ecjpake_kkp_read( ctx->md_info, &ctx->grp,
-                            &GA, &ctx->Xp, ID_PEER, &p, end ) );
-
-    if( p != end )
-    {
-        ret = MBEDTLS_ERR_ECP_BAD_INPUT_DATA;
-        goto cleanup;
-    }
-
-cleanup:
-    mbedtls_ecp_group_free( &grp );
-    mbedtls_ecp_point_free( &GA );
-
-    return( ret );
-}
-
-/*
- * Generate and write ClientECJPAKEParams (7.4.2.6)
- */
-int mbedtls_ecjpake_tls_write_client_params( mbedtls_ecjpake_context *ctx,
-                            unsigned char *buf, size_t len, size_t *olen,
-                            int (*f_rng)(void *, unsigned char *, size_t),
-                            void *p_rng )
-{
-    int ret;
-    mbedtls_ecp_point GA, Xc;
-    mbedtls_mpi xc;
-    unsigned char *p = buf;
-    const unsigned char *end = buf + len;
-    size_t ec_len;
-
-    if( end < p )
-        return( MBEDTLS_ERR_ECP_BUFFER_TOO_SMALL );
-
-    mbedtls_ecp_point_init( &GA );
-    mbedtls_ecp_point_init( &Xc );
-    mbedtls_mpi_init( &xc );
-
-    /*
-     * First generate private/public key pair (7.4.2.6.1)
-     *
-     * Client:  GA = X1 + X3 + X4
-     * Unified: G  = Xm1 + Xp1 + Xp2
-     * xc = x2 * s mod n
-     * Xc = xc * GA
-     */
-    MBEDTLS_MPI_CHK( ecjpake_ecp_add3( &ctx->grp, &GA,
-                                       &ctx->Xm1, &ctx->Xp1, &ctx->Xp2 ) );
-    MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mpi( &xc, &ctx->xm2, &ctx->s ) );
-    MBEDTLS_MPI_CHK( mbedtls_mpi_mod_mpi( &xc, &xc, &ctx->grp.N ) );
-    MBEDTLS_MPI_CHK( mbedtls_ecp_mul( &ctx->grp, &Xc, &xc, &GA, f_rng, p_rng ) );
-
-    /*
-     * Now write things out
-     */
-    MBEDTLS_MPI_CHK( mbedtls_ecp_tls_write_point( &ctx->grp, &Xc,
-                     MBEDTLS_ECP_PF_UNCOMPRESSED, &ec_len, p, end - p ) );
-    p += ec_len;
-
-    MBEDTLS_MPI_CHK( ecjpake_zkp_write( ctx->md_info, &ctx->grp,
-                                        &GA, &xc, &Xc, ID_MINE,
-                                        &p, end, f_rng, p_rng ) );
-
-    *olen = p - buf;
-
-cleanup:
-    mbedtls_ecp_point_free( &GA );
-    mbedtls_ecp_point_free( &Xc );
-    mbedtls_mpi_free( &xc );
+    mbedtls_ecp_point_free( &G );
+    mbedtls_ecp_point_free( &Xm );
+    mbedtls_mpi_free( &xm );
 
     return( ret );
 }
@@ -1032,18 +941,18 @@
 
     TEST_ASSERT( mbedtls_ecjpake_read_round_one( &cli, buf, len ) == 0 );
 
-    TEST_ASSERT( mbedtls_ecjpake_tls_write_server_params( &srv,
+    TEST_ASSERT( mbedtls_ecjpake_write_round_two( &srv,
                  buf, sizeof( buf ), &len, ecjpake_lgc, NULL ) == 0 );
 
-    TEST_ASSERT( mbedtls_ecjpake_tls_read_server_params( &cli, buf, len ) == 0 );
+    TEST_ASSERT( mbedtls_ecjpake_read_round_two( &cli, buf, len ) == 0 );
 
     TEST_ASSERT( mbedtls_ecjpake_tls_derive_pms( &cli,
                  pms, sizeof( pms ), &pmslen, ecjpake_lgc, NULL ) == 0 );
 
-    TEST_ASSERT( mbedtls_ecjpake_tls_write_client_params( &cli,
+    TEST_ASSERT( mbedtls_ecjpake_write_round_two( &cli,
                  buf, sizeof( buf ), &len, ecjpake_lgc, NULL ) == 0 );
 
-    TEST_ASSERT( mbedtls_ecjpake_tls_read_client_params( &srv, buf, len ) == 0 );
+    TEST_ASSERT( mbedtls_ecjpake_read_round_two( &srv, buf, len ) == 0 );
 
     TEST_ASSERT( mbedtls_ecjpake_tls_derive_pms( &srv,
                  buf, sizeof( buf ), &len, ecjpake_lgc, NULL ) == 0 );
@@ -1077,12 +986,12 @@
                                     ecjpake_test_srv_ext,
                             sizeof( ecjpake_test_srv_ext ) ) == 0 );
 
-    TEST_ASSERT( mbedtls_ecjpake_tls_read_server_params( &cli,
+    TEST_ASSERT( mbedtls_ecjpake_read_round_two( &cli,
                                     ecjpake_test_srv_kx,
                             sizeof( ecjpake_test_srv_kx ) ) == 0 );
 
     /* Server reads client key exchange */
-    TEST_ASSERT( mbedtls_ecjpake_tls_read_client_params( &srv,
+    TEST_ASSERT( mbedtls_ecjpake_read_round_two( &srv,
                                     ecjpake_test_cli_kx,
                             sizeof( ecjpake_test_cli_kx ) ) == 0 );