Add ssl_set_hs_own_cert()
diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h
index 695b233..069c60e 100644
--- a/include/mbedtls/ssl.h
+++ b/include/mbedtls/ssl.h
@@ -1707,6 +1707,22 @@
 
 #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
 /**
+ * \brief          Set own certificate and key for the current handshake
+ *
+ * \note           Same as \c mbedtls_ssl_set_own_cert() but for use within
+ *                 the SNI callback.
+ *
+ * \param ssl      SSL context
+ * \param own_cert own public certificate chain
+ * \param pk_key   own private key
+ *
+ * \return         0 on success or MBEDTLS_ERR_SSL_MALLOC_FAILED
+ */
+int mbedtls_ssl_set_hs_own_cert( mbedtls_ssl_context *ssl,
+                                 mbedtls_x509_crt *own_cert,
+                                 mbedtls_pk_context *pk_key );
+
+/**
  * \brief          Set server side ServerName TLS extension callback
  *                 (optional, server-side only).
  *
@@ -1716,8 +1732,8 @@
  *                 following parameters: (void *parameter, mbedtls_ssl_context *ssl,
  *                 const unsigned char *hostname, size_t len). If a suitable
  *                 certificate is found, the callback should set the
- *                 certificate and key to use with mbedtls_ssl_set_own_cert() (and
- *                 possibly adjust the CA chain as well) and return 0. The
+ *                 certificate and key to use with mbedtls_ssl_set_hs_own_cert() (and
+ *                 possibly adjust the CA chain as well TODO: broken) and return 0. The
  *                 callback should return -1 to abort the handshake at this
  *                 point.
  *
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index 0ff3c18..ea015ae 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -389,25 +389,6 @@
 #endif /* MBEDTLS_SSL_DTLS_HELLO_VERIFY */
 
 #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
-/*
- * Wrapper around f_sni, allowing use of mbedtls_ssl_set_own_cert() but
- * making it act on ssl->handshake->sni_key_cert instead.
- */
-static int ssl_sni_wrapper( mbedtls_ssl_context *ssl,
-                            const unsigned char* name, size_t len )
-{
-    int ret;
-    mbedtls_ssl_key_cert *key_cert_ori = ssl->conf->key_cert;
-
-    ssl->conf->key_cert = NULL;
-    ret = ssl->conf->f_sni( ssl->conf->p_sni, ssl, name, len );
-    ssl->handshake->sni_key_cert = ssl->conf->key_cert;
-
-    ssl->conf->key_cert = key_cert_ori;
-
-    return( ret );
-}
-
 static int ssl_parse_servername_ext( mbedtls_ssl_context *ssl,
                                      const unsigned char *buf,
                                      size_t len )
@@ -437,7 +418,8 @@
 
         if( p[0] == MBEDTLS_TLS_EXT_SERVERNAME_HOSTNAME )
         {
-            ret = ssl_sni_wrapper( ssl, p + 3, hostname_len );
+            ret = ssl->conf->f_sni( ssl->conf->p_sni,
+                                    ssl, p + 3, hostname_len );
             if( ret != 0 )
             {
                 MBEDTLS_SSL_DEBUG_RET( 1, "ssl_sni_wrapper", ret );
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 1a75def..6f6e74e 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -5352,6 +5352,16 @@
 }
 #endif /* MBEDTLS_X509_CRT_PARSE_C */
 
+#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
+int mbedtls_ssl_set_hs_own_cert( mbedtls_ssl_context *ssl,
+                                 mbedtls_x509_crt *own_cert,
+                                 mbedtls_pk_context *pk_key )
+{
+    return( ssl_append_key_cert( &ssl->handshake->sni_key_cert,
+                                 own_cert, pk_key ) );
+}
+#endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION */
+
 #if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
 int mbedtls_ssl_set_psk( mbedtls_ssl_config *conf,
                 const unsigned char *psk, size_t psk_len,
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index f4c206e..8cc3ac1 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -547,7 +547,7 @@
         if( name_len == strlen( cur->name ) &&
             memcmp( name, cur->name, name_len ) == 0 )
         {
-            return( mbedtls_ssl_set_own_cert( ssl, cur->cert, cur->key ) );
+            return( mbedtls_ssl_set_hs_own_cert( ssl, cur->cert, cur->key ) );
         }
 
         cur = cur->next;