Factor tls_prf_sha256 and tls_prf_sha384 into a common tls_prf_generic()
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 09c72a7..cddeb74 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -336,40 +336,54 @@
 #endif /* POLARSSL_SSL_PROTO_TLS1) || POLARSSL_SSL_PROTO_TLS1_1 */
 
 #if defined(POLARSSL_SSL_PROTO_TLS1_2)
-#if defined(POLARSSL_SHA256_C)
-static int tls_prf_sha256( const unsigned char *secret, size_t slen,
-                           const char *label,
-                           const unsigned char *random, size_t rlen,
-                           unsigned char *dstbuf, size_t dlen )
+#if defined(POLARSSL_SHA256_C) || defined(POLARSSL_SHA512_C)
+/*
+ * TLS 1.2 PRF: P_<hash>(secret, label + random)[0..dlen], where <hash> is
+ * the digest selected by md_type (RFC 5246, section 5).
+ *
+ * Returns 0 on success, POLARSSL_ERR_SSL_INTERNAL_ERROR if md_type is not
+ * available, or POLARSSL_ERR_SSL_BAD_INPUT_DATA if the digest output, the
+ * label and the random do not fit together in the 128-byte tmp buffer.
+ *
+ * Compiled only alongside its callers, tls_prf_sha256() and
+ * tls_prf_sha384(), so that no unused static function is left behind.
+ */
+static int tls_prf_generic( md_type_t md_type,
+                            const unsigned char *secret, size_t slen,
+                            const char *label,
+                            const unsigned char *random, size_t rlen,
+                            unsigned char *dstbuf, size_t dlen )
 {
     size_t nb;
-    size_t i, j, k;
+    size_t i, j, k, md_len;
     unsigned char tmp[128];
-    unsigned char h_i[32];
+    unsigned char h_i[POLARSSL_MD_MAX_SIZE];
     const md_info_t *md_info;
 
-    if( sizeof( tmp ) < 32 + strlen( label ) + rlen )
+    if( ( md_info = md_info_from_type( md_type ) ) == NULL )
+        return( POLARSSL_ERR_SSL_INTERNAL_ERROR );
+
+    md_len = md_get_size( md_info );
+
+    if( sizeof( tmp ) < md_len + strlen( label ) + rlen )
         return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
 
     nb = strlen( label );
-    memcpy( tmp + 32, label, nb );
-    memcpy( tmp + 32 + nb, random, rlen );
+    memcpy( tmp + md_len, label, nb );
+    memcpy( tmp + md_len + nb, random, rlen );
     nb += rlen;
 
     /*
      * Compute P_<hash>(secret, label + random)[0..dlen]
      */
-    if( ( md_info = md_info_from_type( POLARSSL_MD_SHA256 ) ) == NULL )
-        return( POLARSSL_ERR_SSL_INTERNAL_ERROR );
+    md_hmac( md_info, secret, slen, tmp + md_len, nb, tmp );
 
-    md_hmac( md_info, secret, slen, tmp + 32, nb, tmp );
-
-    for( i = 0; i < dlen; i += 32 )
+    for( i = 0; i < dlen; i += md_len )
     {
-        md_hmac( md_info, secret, slen, tmp, 32 + nb, h_i );
-        md_hmac( md_info, secret, slen, tmp, 32, tmp );
+        md_hmac( md_info, secret, slen, tmp, md_len + nb, h_i );
+        md_hmac( md_info, secret, slen, tmp, md_len, tmp );
 
-        k = ( i + 32 > dlen ) ? dlen % 32 : 32;
+        k = ( i + md_len > dlen ) ? dlen % md_len : md_len;
 
         for( j = 0; j < k; j++ )
             dstbuf[i + j] = h_i[j];
@@ -380,6 +394,17 @@
 
     return( 0 );
 }
+#endif /* POLARSSL_SHA256_C || POLARSSL_SHA512_C */
+
+#if defined(POLARSSL_SHA256_C)
+static int tls_prf_sha256( const unsigned char *secret, size_t slen,
+                           const char *label,
+                           const unsigned char *random, size_t rlen,
+                           unsigned char *dstbuf, size_t dlen )
+{
+    return( tls_prf_generic( POLARSSL_MD_SHA256, secret, slen,
+                             label, random, rlen, dstbuf, dlen ) );
+}
 #endif /* POLARSSL_SHA256_C */
 
 #if defined(POLARSSL_SHA512_C)
@@ -388,43 +413,8 @@
                            const unsigned char *random, size_t rlen,
                            unsigned char *dstbuf, size_t dlen )
 {
-    size_t nb;
-    size_t i, j, k;
-    unsigned char tmp[128];
-    unsigned char h_i[48];
-    const md_info_t *md_info;
-
-    if( sizeof( tmp ) < 48 + strlen( label ) + rlen )
-        return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
-
-    nb = strlen( label );
-    memcpy( tmp + 48, label, nb );
-    memcpy( tmp + 48 + nb, random, rlen );
-    nb += rlen;
-
-    /*
-     * Compute P_<hash>(secret, label + random)[0..dlen]
-     */
-    if( ( md_info = md_info_from_type( POLARSSL_MD_SHA384 ) ) == NULL )
-        return( POLARSSL_ERR_SSL_INTERNAL_ERROR );
-
-    md_hmac( md_info, secret, slen, tmp + 48, nb, tmp );
-
-    for( i = 0; i < dlen; i += 48 )
-    {
-        md_hmac( md_info, secret, slen, tmp, 48 + nb, h_i );
-        md_hmac( md_info, secret, slen, tmp, 48, tmp );
-
-        k = ( i + 48 > dlen ) ? dlen % 48 : 48;
-
-        for( j = 0; j < k; j++ )
-            dstbuf[i + j] = h_i[j];
-    }
-
-    polarssl_zeroize( tmp, sizeof( tmp ) );
-    polarssl_zeroize( h_i, sizeof( h_i ) );
-
-    return( 0 );
+    return( tls_prf_generic( POLARSSL_MD_SHA384, secret, slen,
+                             label, random, rlen, dstbuf, dlen ) );
 }
 #endif /* POLARSSL_SHA512_C */
 #endif /* POLARSSL_SSL_PROTO_TLS1_2 */