Implement server-side ECDHE-PSK key exchange using PSA crypto

Signed-off-by: Neil Armstrong <narmstrong@baylibre.com>
diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c
index 486632e..1a4571c 100644
--- a/library/ssl_tls12_server.c
+++ b/library/ssl_tls12_server.c
@@ -3068,7 +3068,8 @@
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
         if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_RSA ||
-            ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA )
+            ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA ||
+            ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_PSK )
         {
             psa_status_t status = PSA_ERROR_GENERIC_ERROR;
             psa_key_attributes_t key_attributes;
@@ -4037,6 +4038,123 @@
     }
     else
 #endif /* MBEDTLS_KEY_EXCHANGE_DHE_PSK_ENABLED */
+#if defined(MBEDTLS_USE_PSA_CRYPTO) &&                           \
+        defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED)
+    if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_PSK )
+    {
+        /*
+         * PSA-based parsing of an ECDHE-PSK ClientKeyExchange: a PSK
+         * identity followed by the client's ECDH public point, which is
+         * prefixed by a one-byte length. The shared secret Z is computed
+         * with the server's PSA private key (handshake->ecdh_psa_privkey);
+         * that key is destroyed on every exit path of this branch. The
+         * premaster secret is then assembled in handshake->premaster as
+         * (RFC 4279 section 2, with other_secret = Z per RFC 5489):
+         *
+         *     uint16 len(Z) || Z || uint16 len(psk) || psk
+         */
+        psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+        psa_status_t destruction_status = PSA_ERROR_CORRUPTION_DETECTED;
+        uint8_t ecpoint_len;
+
+        mbedtls_ssl_handshake_params *handshake = ssl->handshake;
+
+        if( ( ret = ssl_parse_client_psk_identity( ssl, &p, end ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, ( "ssl_parse_client_psk_identity" ), ret );
+            psa_destroy_key( handshake->ecdh_psa_privkey );
+            handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+            return( ret );
+        }
+
+        /* Keep a copy of the peer's public key. The one-byte length
+         * prefix must itself be present before it is read. */
+        if( p >= end )
+        {
+            psa_destroy_key( handshake->ecdh_psa_privkey );
+            handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+            return( MBEDTLS_ERR_SSL_DECODE_ERROR );
+        }
+
+        ecpoint_len = *(p++);
+
+        /* The point must fit in the bytes remaining in the message. */
+        if( (size_t)( end - p ) < ecpoint_len )
+        {
+            psa_destroy_key( handshake->ecdh_psa_privkey );
+            handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+            return( MBEDTLS_ERR_SSL_DECODE_ERROR );
+        }
+
+        if( ecpoint_len > sizeof( handshake->ecdh_psa_peerkey ) )
+        {
+            psa_destroy_key( handshake->ecdh_psa_privkey );
+            handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+            return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+        }
+
+        memcpy( handshake->ecdh_psa_peerkey, p, ecpoint_len );
+        handshake->ecdh_psa_peerkey_len = ecpoint_len;
+        p += ecpoint_len;
+
+        /* Assemble the premaster secret in place in handshake->premaster. */
+        unsigned char *psm = handshake->premaster;
+        unsigned char *psm_end = psm + sizeof( handshake->premaster );
+        size_t zlen;
+
+        /* Compute the ECDH shared secret Z after the 2-byte length slot. */
+        status = psa_raw_key_agreement( PSA_ALG_ECDH,
+                                        handshake->ecdh_psa_privkey,
+                                        handshake->ecdh_psa_peerkey,
+                                        handshake->ecdh_psa_peerkey_len,
+                                        psm + 2,
+                                        psm_end - ( psm + 2 ),
+                                        &zlen );
+
+        /* Destroy the private key whatever the key-agreement outcome. */
+        destruction_status = psa_destroy_key( handshake->ecdh_psa_privkey );
+        handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+
+        if( status != PSA_SUCCESS || destruction_status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+
+        /* uint16 len(Z), then skip over Z, which is already in place. */
+        MBEDTLS_PUT_UINT16_BE( zlen, psm, 0 );
+        psm += 2 + zlen;
+
+        /* opaque psk<0..2^16-1>; */
+        if( psm_end - psm < 2 )
+            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+        const unsigned char *psk = NULL;
+        size_t psk_len = 0;
+
+        if( mbedtls_ssl_get_psk( ssl, &psk, &psk_len )
+                == MBEDTLS_ERR_SSL_PRIVATE_KEY_REQUIRED )
+        {
+            /*
+             * Not expected: the existence of a PSK is presumed to have been
+             * checked earlier (likely while parsing the PSK identity above).
+             * NOTE(review): confirm this also holds when only an opaque
+             * (PSA) PSK is configured.
+             */
+            return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+        }
+
+        MBEDTLS_PUT_UINT16_BE( psk_len, psm, 0 );
+        psm += 2;
+
+        if( psm_end < psm || (size_t)( psm_end - psm ) < psk_len )
+            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+        memcpy( psm, psk, psk_len );
+        psm += psk_len;
+
+        handshake->pmslen = psm - handshake->premaster;
+    }
+    else
+#endif /* MBEDTLS_USE_PSA_CRYPTO &&
+            MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED */
 #if defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED)
     if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_PSK )
     {