Move tls_prf_sha256/384 and their helpers to the end of ssl_tls.c

Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 16cd454..bb56583 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -392,250 +392,6 @@
                                    size_t len );
 #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
 
-#if defined(MBEDTLS_SSL_PROTO_TLS1_2)
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-
-static psa_status_t setup_psa_key_derivation( psa_key_derivation_operation_t* derivation,
-                                              mbedtls_svc_key_id_t key,
-                                              psa_algorithm_t alg,
-                                              const unsigned char* seed, size_t seed_length,
-                                              const unsigned char* label, size_t label_length,
-                                              size_t capacity )
-{
-    psa_status_t status;
-
-    status = psa_key_derivation_setup( derivation, alg );
-    if( status != PSA_SUCCESS )
-        return( status );
-
-    if( PSA_ALG_IS_TLS12_PRF( alg ) || PSA_ALG_IS_TLS12_PSK_TO_MS( alg ) )
-    {
-        status = psa_key_derivation_input_bytes( derivation,
-                                                 PSA_KEY_DERIVATION_INPUT_SEED,
-                                                 seed, seed_length );
-        if( status != PSA_SUCCESS )
-            return( status );
-
-        if( mbedtls_svc_key_id_is_null( key ) )
-        {
-            status = psa_key_derivation_input_bytes(
-                derivation, PSA_KEY_DERIVATION_INPUT_SECRET,
-                NULL, 0 );
-        }
-        else
-        {
-            status = psa_key_derivation_input_key(
-                derivation, PSA_KEY_DERIVATION_INPUT_SECRET, key );
-        }
-        if( status != PSA_SUCCESS )
-            return( status );
-
-        status = psa_key_derivation_input_bytes( derivation,
-                                                 PSA_KEY_DERIVATION_INPUT_LABEL,
-                                                 label, label_length );
-        if( status != PSA_SUCCESS )
-            return( status );
-    }
-    else
-    {
-        return( PSA_ERROR_NOT_SUPPORTED );
-    }
-
-    status = psa_key_derivation_set_capacity( derivation, capacity );
-    if( status != PSA_SUCCESS )
-        return( status );
-
-    return( PSA_SUCCESS );
-}
-
-static int tls_prf_generic( mbedtls_md_type_t md_type,
-                            const unsigned char *secret, size_t slen,
-                            const char *label,
-                            const unsigned char *random, size_t rlen,
-                            unsigned char *dstbuf, size_t dlen )
-{
-    psa_status_t status;
-    psa_algorithm_t alg;
-    mbedtls_svc_key_id_t master_key = MBEDTLS_SVC_KEY_ID_INIT;
-    psa_key_derivation_operation_t derivation =
-        PSA_KEY_DERIVATION_OPERATION_INIT;
-
-    if( md_type == MBEDTLS_MD_SHA384 )
-        alg = PSA_ALG_TLS12_PRF(PSA_ALG_SHA_384);
-    else
-        alg = PSA_ALG_TLS12_PRF(PSA_ALG_SHA_256);
-
-    /* Normally a "secret" should be long enough to be impossible to
-     * find by brute force, and in particular should not be empty. But
-     * this PRF is also used to derive an IV, in particular in EAP-TLS,
-     * and for this use case it makes sense to have a 0-length "secret".
-     * Since the key API doesn't allow importing a key of length 0,
-     * keep master_key=0, which setup_psa_key_derivation() understands
-     * to mean a 0-length "secret" input. */
-    if( slen != 0 )
-    {
-        psa_key_attributes_t key_attributes = psa_key_attributes_init();
-        psa_set_key_usage_flags( &key_attributes, PSA_KEY_USAGE_DERIVE );
-        psa_set_key_algorithm( &key_attributes, alg );
-        psa_set_key_type( &key_attributes, PSA_KEY_TYPE_DERIVE );
-
-        status = psa_import_key( &key_attributes, secret, slen, &master_key );
-        if( status != PSA_SUCCESS )
-            return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
-    }
-
-    status = setup_psa_key_derivation( &derivation,
-                                       master_key, alg,
-                                       random, rlen,
-                                       (unsigned char const *) label,
-                                       (size_t) strlen( label ),
-                                       dlen );
-    if( status != PSA_SUCCESS )
-    {
-        psa_key_derivation_abort( &derivation );
-        psa_destroy_key( master_key );
-        return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
-    }
-
-    status = psa_key_derivation_output_bytes( &derivation, dstbuf, dlen );
-    if( status != PSA_SUCCESS )
-    {
-        psa_key_derivation_abort( &derivation );
-        psa_destroy_key( master_key );
-        return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
-    }
-
-    status = psa_key_derivation_abort( &derivation );
-    if( status != PSA_SUCCESS )
-    {
-        psa_destroy_key( master_key );
-        return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
-    }
-
-    if( ! mbedtls_svc_key_id_is_null( master_key ) )
-        status = psa_destroy_key( master_key );
-    if( status != PSA_SUCCESS )
-        return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
-
-    return( 0 );
-}
-
-#else /* MBEDTLS_USE_PSA_CRYPTO */
-
-static int tls_prf_generic( mbedtls_md_type_t md_type,
-                            const unsigned char *secret, size_t slen,
-                            const char *label,
-                            const unsigned char *random, size_t rlen,
-                            unsigned char *dstbuf, size_t dlen )
-{
-    size_t nb;
-    size_t i, j, k, md_len;
-    unsigned char *tmp;
-    size_t tmp_len = 0;
-    unsigned char h_i[MBEDTLS_MD_MAX_SIZE];
-    const mbedtls_md_info_t *md_info;
-    mbedtls_md_context_t md_ctx;
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-
-    mbedtls_md_init( &md_ctx );
-
-    if( ( md_info = mbedtls_md_info_from_type( md_type ) ) == NULL )
-        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
-
-    md_len = mbedtls_md_get_size( md_info );
-
-    tmp_len = md_len + strlen( label ) + rlen;
-    tmp = mbedtls_calloc( 1, tmp_len );
-    if( tmp == NULL )
-    {
-        ret = MBEDTLS_ERR_SSL_ALLOC_FAILED;
-        goto exit;
-    }
-
-    nb = strlen( label );
-    memcpy( tmp + md_len, label, nb );
-    memcpy( tmp + md_len + nb, random, rlen );
-    nb += rlen;
-
-    /*
-     * Compute P_<hash>(secret, label + random)[0..dlen]
-     */
-    if ( ( ret = mbedtls_md_setup( &md_ctx, md_info, 1 ) ) != 0 )
-        goto exit;
-
-    ret = mbedtls_md_hmac_starts( &md_ctx, secret, slen );
-    if( ret != 0 )
-        goto exit;
-    ret = mbedtls_md_hmac_update( &md_ctx, tmp + md_len, nb );
-    if( ret != 0 )
-        goto exit;
-    ret = mbedtls_md_hmac_finish( &md_ctx, tmp );
-    if( ret != 0 )
-        goto exit;
-
-    for( i = 0; i < dlen; i += md_len )
-    {
-        ret = mbedtls_md_hmac_reset ( &md_ctx );
-        if( ret != 0 )
-            goto exit;
-        ret = mbedtls_md_hmac_update( &md_ctx, tmp, md_len + nb );
-        if( ret != 0 )
-            goto exit;
-        ret = mbedtls_md_hmac_finish( &md_ctx, h_i );
-        if( ret != 0 )
-            goto exit;
-
-        ret = mbedtls_md_hmac_reset ( &md_ctx );
-        if( ret != 0 )
-            goto exit;
-        ret = mbedtls_md_hmac_update( &md_ctx, tmp, md_len );
-        if( ret != 0 )
-            goto exit;
-        ret = mbedtls_md_hmac_finish( &md_ctx, tmp );
-        if( ret != 0 )
-            goto exit;
-
-        k = ( i + md_len > dlen ) ? dlen % md_len : md_len;
-
-        for( j = 0; j < k; j++ )
-            dstbuf[i + j]  = h_i[j];
-    }
-
-exit:
-    mbedtls_md_free( &md_ctx );
-
-    mbedtls_platform_zeroize( tmp, tmp_len );
-    mbedtls_platform_zeroize( h_i, sizeof( h_i ) );
-
-    mbedtls_free( tmp );
-
-    return( ret );
-}
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
-#if defined(MBEDTLS_SHA256_C)
-static int tls_prf_sha256( const unsigned char *secret, size_t slen,
-                           const char *label,
-                           const unsigned char *random, size_t rlen,
-                           unsigned char *dstbuf, size_t dlen )
-{
-    return( tls_prf_generic( MBEDTLS_MD_SHA256, secret, slen,
-                             label, random, rlen, dstbuf, dlen ) );
-}
-#endif /* MBEDTLS_SHA256_C */
-
-#if defined(MBEDTLS_SHA384_C)
-static int tls_prf_sha384( const unsigned char *secret, size_t slen,
-                           const char *label,
-                           const unsigned char *random, size_t rlen,
-                           unsigned char *dstbuf, size_t dlen )
-{
-    return( tls_prf_generic( MBEDTLS_MD_SHA384, secret, slen,
-                             label, random, rlen, dstbuf, dlen ) );
-}
-#endif /* MBEDTLS_SHA384_C */
-
-#endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
-
 static void ssl_update_checksum_start( mbedtls_ssl_context *, const unsigned char *, size_t );
 
 #if defined(MBEDTLS_SHA256_C)
@@ -7965,4 +7721,251 @@
 }
 #endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION */
 
+
+#if defined(MBEDTLS_SSL_PROTO_TLS1_2)
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+
+/*
+ * Configure a PSA key derivation operation for the TLS 1.2 PRF or the
+ * TLS 1.2 PSK-to-MS algorithm.
+ *
+ * Inputs are supplied in the order seed, secret, label; a null key id stands
+ * for a zero-length secret. The output capacity is set last.
+ *
+ * Returns PSA_SUCCESS, PSA_ERROR_NOT_SUPPORTED for any other algorithm
+ * (reported after the setup call has succeeded), or the status of the first
+ * failing PSA call. On failure the operation is left as-is, without being
+ * aborted.
+ */
+static psa_status_t setup_psa_key_derivation( psa_key_derivation_operation_t* derivation,
+                                              mbedtls_svc_key_id_t key,
+                                              psa_algorithm_t alg,
+                                              const unsigned char* seed, size_t seed_length,
+                                              const unsigned char* label, size_t label_length,
+                                              size_t capacity )
+{
+    psa_status_t rc = psa_key_derivation_setup( derivation, alg );
+    if( rc != PSA_SUCCESS )
+        return( rc );
+
+    /* Reject anything that is neither TLS 1.2 PRF nor TLS 1.2 PSK-to-MS. */
+    if( ! PSA_ALG_IS_TLS12_PRF( alg ) && ! PSA_ALG_IS_TLS12_PSK_TO_MS( alg ) )
+        return( PSA_ERROR_NOT_SUPPORTED );
+
+    rc = psa_key_derivation_input_bytes( derivation,
+                                         PSA_KEY_DERIVATION_INPUT_SEED,
+                                         seed, seed_length );
+    if( rc != PSA_SUCCESS )
+        return( rc );
+
+    /* A null key id means the secret is empty: feed zero bytes instead. */
+    rc = mbedtls_svc_key_id_is_null( key ) ?
+        psa_key_derivation_input_bytes( derivation,
+                                        PSA_KEY_DERIVATION_INPUT_SECRET,
+                                        NULL, 0 ) :
+        psa_key_derivation_input_key( derivation,
+                                      PSA_KEY_DERIVATION_INPUT_SECRET, key );
+    if( rc != PSA_SUCCESS )
+        return( rc );
+
+    rc = psa_key_derivation_input_bytes( derivation,
+                                         PSA_KEY_DERIVATION_INPUT_LABEL,
+                                         label, label_length );
+    if( rc != PSA_SUCCESS )
+        return( rc );
+
+    /* The capacity call's status, success or not, is the final result. */
+    return( psa_key_derivation_set_capacity( derivation, capacity ) );
+}
+
+/*
+ * TLS 1.2 PRF computed through the PSA key derivation API.
+ *
+ * md_type selects the hash: MBEDTLS_MD_SHA384 gives SHA-384, every other
+ * value gives SHA-256. label must be NUL-terminated (its length is taken
+ * with strlen()). On success dlen bytes are written to dstbuf.
+ *
+ * Returns 0 on success, or MBEDTLS_ERR_SSL_HW_ACCEL_FAILED when any PSA
+ * call fails. The derivation operation and the imported key are released
+ * on every path that got as far as creating them.
+ */
+static int tls_prf_generic( mbedtls_md_type_t md_type,
+                            const unsigned char *secret, size_t slen,
+                            const char *label,
+                            const unsigned char *random, size_t rlen,
+                            unsigned char *dstbuf, size_t dlen )
+{
+    psa_key_derivation_operation_t op = PSA_KEY_DERIVATION_OPERATION_INIT;
+    mbedtls_svc_key_id_t secret_key = MBEDTLS_SVC_KEY_ID_INIT;
+    psa_status_t rc;
+    const psa_algorithm_t prf_alg = ( md_type == MBEDTLS_MD_SHA384 ) ?
+        PSA_ALG_TLS12_PRF(PSA_ALG_SHA_384) : PSA_ALG_TLS12_PRF(PSA_ALG_SHA_256);
+
+    /* An empty secret is legitimate here: the PRF is also used to derive
+     * an IV (e.g. in EAP-TLS). The key API cannot import a zero-length key,
+     * so in that case secret_key stays null and setup_psa_key_derivation()
+     * feeds a zero-length secret instead. */
+    if( slen != 0 )
+    {
+        psa_key_attributes_t attrs = psa_key_attributes_init();
+
+        psa_set_key_usage_flags( &attrs, PSA_KEY_USAGE_DERIVE );
+        psa_set_key_algorithm( &attrs, prf_alg );
+        psa_set_key_type( &attrs, PSA_KEY_TYPE_DERIVE );
+
+        rc = psa_import_key( &attrs, secret, slen, &secret_key );
+        if( rc != PSA_SUCCESS )
+            return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+    }
+
+    rc = setup_psa_key_derivation( &op, secret_key, prf_alg,
+                                   random, rlen,
+                                   (unsigned char const *) label,
+                                   (size_t) strlen( label ), dlen );
+
+    /* Only draw the output if every input was accepted. */
+    if( rc == PSA_SUCCESS )
+        rc = psa_key_derivation_output_bytes( &op, dstbuf, dlen );
+    if( rc != PSA_SUCCESS )
+    {
+        psa_key_derivation_abort( &op );
+        psa_destroy_key( secret_key );
+        return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+    }
+
+    /* A failed abort still requires the key to be destroyed. */
+    if( psa_key_derivation_abort( &op ) != PSA_SUCCESS )
+    {
+        psa_destroy_key( secret_key );
+        return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+    }
+
+    /* Nothing was imported for an empty secret, so nothing to destroy. */
+    if( mbedtls_svc_key_id_is_null( secret_key ) )
+        return( 0 );
+
+    if( psa_destroy_key( secret_key ) != PSA_SUCCESS )
+        return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+
+    return( 0 );
+}
+
+#else /* MBEDTLS_USE_PSA_CRYPTO */
+
+/*
+ * TLS 1.2 PRF (P_<hash> expansion) built on the legacy MD/HMAC API.
+ * HMAC is keyed with secret/slen; the seed is label (a NUL-terminated string,
+ * used without its terminator) followed by random/rlen; on success dlen
+ * bytes are written to dstbuf.
+ * Returns 0 on success, MBEDTLS_ERR_SSL_INTERNAL_ERROR if md_type has no MD
+ * info, MBEDTLS_ERR_SSL_ALLOC_FAILED if the work buffer cannot be allocated,
+ * or the error code of the first failing mbedtls_md_xxx() call.
+ */
+static int tls_prf_generic( mbedtls_md_type_t md_type,
+                            const unsigned char *secret, size_t slen,
+                            const char *label,
+                            const unsigned char *random, size_t rlen,
+                            unsigned char *dstbuf, size_t dlen )
+{
+    size_t nb, i, j, k, md_len;
+    size_t tmp_len = 0;
+    unsigned char *tmp = NULL;
+    unsigned char h_i[MBEDTLS_MD_MAX_SIZE];
+    const mbedtls_md_info_t *md_info;
+    mbedtls_md_context_t md_ctx;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+
+    mbedtls_md_init( &md_ctx );
+    if( ( md_info = mbedtls_md_info_from_type( md_type ) ) == NULL )
+        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+
+    md_len = mbedtls_md_get_size( md_info );
+    nb = strlen( label );
+
+    /* Layout of tmp: A(i) [md_len bytes] || label || random */
+    tmp_len = md_len + nb + rlen;
+    tmp = mbedtls_calloc( 1, tmp_len );
+    if( tmp == NULL )
+    {
+        ret = MBEDTLS_ERR_SSL_ALLOC_FAILED;
+        goto exit;
+    }
+
+    memcpy( tmp + md_len, label, nb );
+    memcpy( tmp + md_len + nb, random, rlen );
+    nb += rlen;
+
+    /* Compute P_<hash>(secret, label + random)[0..dlen] */
+    if( ( ret = mbedtls_md_setup( &md_ctx, md_info, 1 ) ) != 0 )
+        goto exit;
+
+    /* A(1) = HMAC(secret, label + random), stored at the start of tmp */
+    if( ( ret = mbedtls_md_hmac_starts( &md_ctx, secret, slen ) ) != 0 )
+        goto exit;
+    if( ( ret = mbedtls_md_hmac_update( &md_ctx, tmp + md_len, nb ) ) != 0 )
+        goto exit;
+    if( ( ret = mbedtls_md_hmac_finish( &md_ctx, tmp ) ) != 0 )
+        goto exit;
+
+    for( i = 0; i < dlen; i += md_len )
+    {
+        /* Output block: h_i = HMAC(secret, A(i) + label + random) */
+        if( ( ret = mbedtls_md_hmac_reset( &md_ctx ) ) != 0 )
+            goto exit;
+        if( ( ret = mbedtls_md_hmac_update( &md_ctx, tmp, md_len + nb ) ) != 0 )
+            goto exit;
+        if( ( ret = mbedtls_md_hmac_finish( &md_ctx, h_i ) ) != 0 )
+            goto exit;
+
+        /* Chaining value: A(i+1) = HMAC(secret, A(i)) */
+        if( ( ret = mbedtls_md_hmac_reset( &md_ctx ) ) != 0 )
+            goto exit;
+        if( ( ret = mbedtls_md_hmac_update( &md_ctx, tmp, md_len ) ) != 0 )
+            goto exit;
+        if( ( ret = mbedtls_md_hmac_finish( &md_ctx, tmp ) ) != 0 )
+            goto exit;
+
+        /* The last block may be truncated to the bytes still needed. */
+        k = ( i + md_len > dlen ) ? dlen % md_len : md_len;
+        for( j = 0; j < k; j++ )
+            dstbuf[i + j]  = h_i[j];
+    }
+
+exit:
+    mbedtls_md_free( &md_ctx );
+    /* tmp is NULL if the allocation failed; never zeroize through NULL. */
+    if( tmp != NULL )
+        mbedtls_platform_zeroize( tmp, tmp_len );
+    mbedtls_platform_zeroize( h_i, sizeof( h_i ) );
+    mbedtls_free( tmp );
+
+    return( ret );
+}
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
+#if defined(MBEDTLS_SHA256_C)
+/* TLS 1.2 PRF with SHA-256: forwards all arguments to tls_prf_generic(). */
+static int tls_prf_sha256( const unsigned char *secret, size_t slen,
+                           const char *label, const unsigned char *random,
+                           size_t rlen, unsigned char *dstbuf, size_t dlen )
+{
+    return( tls_prf_generic( MBEDTLS_MD_SHA256, secret, slen, label,
+                             random, rlen, dstbuf, dlen ) );
+}
+#endif /* MBEDTLS_SHA256_C */
+
+#if defined(MBEDTLS_SHA384_C)
+/* TLS 1.2 PRF with SHA-384: forwards all arguments to tls_prf_generic(). */
+static int tls_prf_sha384( const unsigned char *secret, size_t slen,
+                           const char *label, const unsigned char *random,
+                           size_t rlen, unsigned char *dstbuf, size_t dlen )
+{
+    return( tls_prf_generic( MBEDTLS_MD_SHA384, secret, slen, label,
+                             random, rlen, dstbuf, dlen ) );
+}
+#endif /* MBEDTLS_SHA384_C */
+
+#endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
+
 #endif /* MBEDTLS_SSL_TLS_C */