Add a few checks for context validity.
diff --git a/library/ecdh.c b/library/ecdh.c
index c04d34c..e30ae3a 100644
--- a/library/ecdh.c
+++ b/library/ecdh.c
@@ -118,6 +118,9 @@
int ret;
size_t grp_len, pt_len;
+ if( ctx == NULL || ctx->grp.pbits == 0 )
+ return( POLARSSL_ERR_ECP_BAD_INPUT_DATA );
+
if( ( ret = ecdh_gen_public( &ctx->grp, &ctx->d, &ctx->Q, f_rng, p_rng ) )
!= 0 )
return( ret );
@@ -149,6 +152,8 @@
{
int ret;
+ ecdh_init( ctx );
+
if( ( ret = ecp_tls_read_group( &ctx->grp, buf, end - *buf ) ) != 0 )
return( ret );
@@ -169,6 +174,9 @@
{
int ret;
+ if( ctx == NULL || ctx->grp.pbits == 0 )
+ return( POLARSSL_ERR_ECP_BAD_INPUT_DATA );
+
if( ( ret = ecdh_gen_public( &ctx->grp, &ctx->d, &ctx->Q, f_rng, p_rng ) )
!= 0 )
return( ret );
@@ -183,6 +191,9 @@
int ecdh_read_public( ecdh_context *ctx,
const unsigned char *buf, size_t blen )
{
+ if( ctx == NULL )
+ return( POLARSSL_ERR_ECP_BAD_INPUT_DATA );
+
return ecp_tls_read_point( &ctx->grp, &ctx->Qp, &buf, blen );
}
@@ -194,6 +205,9 @@
{
int ret;
+ if( ctx == NULL )
+ return( POLARSSL_ERR_ECP_BAD_INPUT_DATA );
+
if( ( ret = ecdh_compute_shared( &ctx->grp, &ctx->z, &ctx->Qp, &ctx->d ) )
!= 0 )
return( ret );