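Combine ECDH and FFDH TLS 1.3 key share generation into one function

Merge mbedtls_ssl_tls13_generate_and_write_ecdh_key_exchange() and
mbedtls_ssl_tls13_generate_and_write_dhe_key_exchange() into a single
mbedtls_ssl_tls13_generate_and_write_dh_key_exchange().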

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index 821a54c..42cabf5 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -1513,7 +1513,36 @@
     return 0;
 }
 
-int mbedtls_ssl_tls13_generate_and_write_ecdh_key_exchange(
+#if defined(PSA_WANT_ALG_FFDH)
+/* Map an RFC 7919 FFDHE TLS group ID to a PSA DH key pair type and size. */
+static psa_status_t mbedtls_psa_parse_tls_ffdh_group(
+    uint16_t tls_ffdh_grp_reg_id, size_t *bits, psa_key_type_t *key_type)
+{
+    switch (tls_ffdh_grp_reg_id) {
+        case MBEDTLS_SSL_IANA_TLS_GROUP_FFDHE2048:
+            *bits = 2048;
+            break;
+        case MBEDTLS_SSL_IANA_TLS_GROUP_FFDHE3072:
+            *bits = 3072;
+            break;
+        case MBEDTLS_SSL_IANA_TLS_GROUP_FFDHE4096:
+            *bits = 4096;
+            break;
+        case MBEDTLS_SSL_IANA_TLS_GROUP_FFDHE6144:
+            *bits = 6144;
+            break;
+        case MBEDTLS_SSL_IANA_TLS_GROUP_FFDHE8192:
+            *bits = 8192;
+            break;
+        default:
+            return PSA_ERROR_NOT_SUPPORTED; /* *bits and *key_type untouched */
+    }
+    *key_type = PSA_KEY_TYPE_DH_KEY_PAIR(PSA_DH_FAMILY_RFC7919);
+    return PSA_SUCCESS;
+}
+#endif /* PSA_WANT_ALG_FFDH */
+
+int mbedtls_ssl_tls13_generate_and_write_dh_key_exchange(
     mbedtls_ssl_context *ssl,
     uint16_t named_group,
     unsigned char *buf,
@@ -1525,26 +1554,57 @@
     psa_key_attributes_t key_attributes;
     size_t own_pubkey_len;
     mbedtls_ssl_handshake_params *handshake = ssl->handshake;
-    psa_ecc_family_t ec_psa_family = 0;
-    size_t ec_bits = 0;
+    size_t bits = 0;
+    psa_key_type_t key_type = 0;
+    size_t buf_size = (size_t) (end - buf);
 
-    MBEDTLS_SSL_DEBUG_MSG(1, ("Perform PSA-based ECDH computation."));
+    /* A single path serves ECDH (EC groups) and FFDH (RFC 7919 groups). */
+    MBEDTLS_SSL_DEBUG_MSG(1, ("Perform PSA-based ECDH/FFDH computation."));
 
     /* Convert EC's TLS ID to PSA key type. */
+#if defined(PSA_WANT_ALG_ECDH)
+    psa_ecc_family_t ec_psa_family = 0;
     if (mbedtls_ssl_get_psa_curve_info_from_tls_id(
-            named_group, &ec_psa_family, &ec_bits) == PSA_ERROR_NOT_SUPPORTED) {
+            named_group, &ec_psa_family, &bits) == PSA_SUCCESS) {
+        key_type = PSA_KEY_TYPE_ECC_KEY_PAIR(ec_psa_family);
+    }
+#endif
+#if defined(PSA_WANT_ALG_FFDH)
+    /* Only consult the FFDHE groups if named_group was not an EC group. */
+    if (key_type == 0 &&
+        mbedtls_psa_parse_tls_ffdh_group(named_group, &bits, &key_type) == PSA_SUCCESS) {
+        /* Require room for a full group-sized public value; export exactly that many bytes. */
+        if (buf_size < PSA_BITS_TO_BYTES(bits)) {
+            return MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL;
+        }
+        buf_size = PSA_BITS_TO_BYTES(bits);
+    }
+#endif
+
+    if (key_type == 0) {
         return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
     }
-    handshake->ecdh_psa_type = PSA_KEY_TYPE_ECC_KEY_PAIR(ec_psa_family);
-    ssl->handshake->ecdh_bits = ec_bits;
+
+    handshake->ecdh_psa_type = key_type;
+    handshake->ecdh_bits = bits;
 
     key_attributes = psa_key_attributes_init();
     psa_set_key_usage_flags(&key_attributes, PSA_KEY_USAGE_DERIVE);
-    psa_set_key_algorithm(&key_attributes, PSA_ALG_ECDH);
+
+    if (PSA_KEY_TYPE_IS_ECC(key_type)) {
+#if defined(PSA_WANT_ALG_ECDH)
+        psa_set_key_algorithm(&key_attributes, PSA_ALG_ECDH);
+#endif
+    } else {
+#if defined(PSA_WANT_ALG_FFDH)
+        psa_set_key_algorithm(&key_attributes, PSA_ALG_FFDH);
+#endif
+    }
+
     psa_set_key_type(&key_attributes, handshake->ecdh_psa_type);
     psa_set_key_bits(&key_attributes, handshake->ecdh_bits);
 
-    /* Generate ECDH private key. */
+    /* Generate ECDH/FFDH private key. */
     status = psa_generate_key(&key_attributes,
                               &handshake->ecdh_psa_privkey);
     if (status != PSA_SUCCESS) {
@@ -1554,10 +1614,11 @@
 
     }
 
-    /* Export the public part of the ECDH private key from PSA. */
+    /* Export the public part of the ECDH/FFDH private key from PSA. */
     status = psa_export_public_key(handshake->ecdh_psa_privkey,
-                                   buf, (size_t) (end - buf),
+                                   buf, buf_size,
                                    &own_pubkey_len);
+
     if (status != PSA_SUCCESS) {
         ret = PSA_TO_MBEDTLS_ERR(status);
         MBEDTLS_SSL_DEBUG_RET(1, "psa_export_public_key", ret);
@@ -1571,92 +1632,6 @@
 }
 #endif /* PSA_WANT_ALG_ECDH || PSA_WANT_ALG_FFDH */
 
-#if defined(PSA_WANT_ALG_FFDH)
-static psa_key_type_t mbedtls_psa_parse_tls_ffdh_group(
-    uint16_t tls_ecc_grp_reg_id, size_t *bits)
-{
-    switch (tls_ecc_grp_reg_id) {
-        case MBEDTLS_SSL_IANA_TLS_GROUP_FFDHE2048:
-            *bits = 2048;
-            return PSA_KEY_TYPE_DH_KEY_PAIR(PSA_DH_FAMILY_RFC7919);
-        case MBEDTLS_SSL_IANA_TLS_GROUP_FFDHE3072:
-            *bits = 3072;
-            return PSA_KEY_TYPE_DH_KEY_PAIR(PSA_DH_FAMILY_RFC7919);
-        case MBEDTLS_SSL_IANA_TLS_GROUP_FFDHE4096:
-            *bits = 4096;
-            return PSA_KEY_TYPE_DH_KEY_PAIR(PSA_DH_FAMILY_RFC7919);
-        case MBEDTLS_SSL_IANA_TLS_GROUP_FFDHE6144:
-            *bits = 6144;
-            return PSA_KEY_TYPE_DH_KEY_PAIR(PSA_DH_FAMILY_RFC7919);
-        case MBEDTLS_SSL_IANA_TLS_GROUP_FFDHE8192:
-            *bits = 8192;
-            return PSA_KEY_TYPE_DH_KEY_PAIR(PSA_DH_FAMILY_RFC7919);
-        default:
-            return 0;
-    }
-}
-
-int mbedtls_ssl_tls13_generate_and_write_dhe_key_exchange(
-    mbedtls_ssl_context *ssl,
-    uint16_t named_group,
-    unsigned char *buf,
-    unsigned char *end,
-    size_t *out_len)
-{
-    psa_status_t status = PSA_ERROR_GENERIC_ERROR;
-    int ret = MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE;
-    psa_key_attributes_t key_attributes;
-    size_t own_pubkey_len;
-    mbedtls_ssl_handshake_params *handshake = ssl->handshake;
-    size_t ffdh_bits = 0;
-
-    MBEDTLS_SSL_DEBUG_MSG(1, ("Perform PSA-based DHE computation."));
-
-    /* Convert DHE group to PSA key type. */
-    if ((handshake->ecdh_psa_type =
-             mbedtls_psa_parse_tls_ffdh_group(named_group, &ffdh_bits)) == 0) {
-        return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
-    }
-
-    if ((size_t) (end - buf) < PSA_BITS_TO_BYTES(ffdh_bits)) {
-        ret = MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL;
-        return ret;
-    }
-
-    ssl->handshake->ecdh_bits = ffdh_bits;
-
-    key_attributes = psa_key_attributes_init();
-    psa_set_key_usage_flags(&key_attributes, PSA_KEY_USAGE_DERIVE);
-    psa_set_key_algorithm(&key_attributes, PSA_ALG_FFDH);
-    psa_set_key_type(&key_attributes, handshake->ecdh_psa_type);
-    psa_set_key_bits(&key_attributes, handshake->ecdh_bits);
-
-    /* Generate FFDH private key. */
-    status = psa_generate_key(&key_attributes,
-                              &handshake->ecdh_psa_privkey);
-    if (status != PSA_SUCCESS) {
-        ret = PSA_TO_MBEDTLS_ERR(status);
-        MBEDTLS_SSL_DEBUG_RET(1, "psa_generate_key", ret);
-        return ret;
-
-    }
-
-    /* Export the public part of the FFDH private key from PSA. */
-    status = psa_export_public_key(handshake->ecdh_psa_privkey,
-                                   buf, PSA_BITS_TO_BYTES(ffdh_bits),
-                                   &own_pubkey_len);
-    if (status != PSA_SUCCESS) {
-        ret = PSA_TO_MBEDTLS_ERR(status);
-        MBEDTLS_SSL_DEBUG_RET(1, "psa_export_public_key", ret);
-        return ret;
-    }
-
-    *out_len = own_pubkey_len;
-
-    return 0;
-}
-#endif /* PSA_WANT_ALG_FFDH */
-
 /* RFC 8446 section 4.2
  *
  * If an implementation receives an extension which it recognizes and which is