SSL asynchronous decryption (server side): tests

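Test SSL asynchronous private key operations for the case of a
decryption operation on a server. Add an async_operations option to
ssl_server2 to select which operations (d=decrypt, s=sign) are
performed through the asynchronous callbacks.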
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index 2a4c833..2f3908d 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -108,6 +108,7 @@
 #define DFL_KEY_FILE            ""
 #define DFL_CRT_FILE2           ""
 #define DFL_KEY_FILE2           ""
+#define DFL_ASYNC_OPERATIONS    "-"     /* "-": no asynchronous operations */
 #define DFL_ASYNC_PRIVATE_DELAY1 ( -1 )
 #define DFL_ASYNC_PRIVATE_DELAY2 ( -1 )
 #define DFL_ASYNC_PRIVATE_ERROR  ( 0 )
@@ -200,6 +201,7 @@
 
 #if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
 #define USAGE_SSL_ASYNC \
+    "    async_operations=%%c...   d=decrypt, s=sign (default: -=off)\n" \
     "    async_private_delay1=%%d  Asynchronous delay for key_file or preloaded key\n" \
     "    async_private_delay2=%%d  Asynchronous delay for key_file2\n" \
     "                              default: -1 (not asynchronous)\n" \
@@ -421,6 +423,7 @@
     const char *key_file;       /* the file with the server key             */
     const char *crt_file2;      /* the file with the 2nd server certificate */
     const char *key_file2;      /* the file with the 2nd server key         */
+    const char *async_operations; /* async operations: d=decrypt, s=sign, "-"=none */
     int async_private_delay1;   /* number of times f_async_resume needs to be called for key 1, or -1 for no async */
     int async_private_delay2;   /* number of times f_async_resume needs to be called for key 2, or -1 for no async */
     int async_private_error;    /* inject error in async private callback */
@@ -892,21 +895,45 @@
     ++ctx->slots_used;
 }
 
+/* Largest input an asynchronous operation can carry: 512 bytes is the
+ * ciphertext size of a 4096-bit RSA key and exceeds any hash length. */
+#define SSL_ASYNC_INPUT_MAX_SIZE 512
+
+/* Which private key operation an asynchronous context performs. It is
+ * recorded explicitly because md_alg cannot tell the two apart: a
+ * signature request may also use MBEDTLS_MD_NONE (e.g. the MD5+SHA1
+ * RSA signature of TLS 1.0/1.1). */
+typedef enum
+{
+    ASYNC_OP_SIGN,
+    ASYNC_OP_DECRYPT
+} ssl_async_operation_type_t;
+
+/* State of one pending asynchronous operation, allocated by
+ * ssl_async_start(). */
 typedef struct
 {
     size_t slot;
+    ssl_async_operation_type_t operation_type;
     mbedtls_md_type_t md_alg;
-    unsigned char hash[MBEDTLS_MD_MAX_SIZE];
-    size_t hash_len;
+    unsigned char input[SSL_ASYNC_INPUT_MAX_SIZE];
+    size_t input_len;
     unsigned delay;
 } ssl_async_operation_context_t;
 
-int ssl_async_sign( void *connection_ctx_arg,
-                    void **p_operation_ctx,
-                    mbedtls_x509_crt *cert,
-                    mbedtls_md_type_t md_alg,
-                    const unsigned char *hash,
-                    size_t hash_len )
+/* Common part of the sign and decrypt start callbacks: find a key slot
+ * matching cert, then save a copy of the input together with the
+ * operation type for ssl_async_resume(). Returns
+ * MBEDTLS_ERR_SSL_HW_ACCEL_FALLTHROUGH if no slot matches, or another
+ * error code on failure. */
+static int ssl_async_start( void *connection_ctx_arg,
+                            void **p_operation_ctx,
+                            mbedtls_x509_crt *cert,
+                            ssl_async_operation_type_t op_type,
+                            const char *op_name,
+                            mbedtls_md_type_t md_alg,
+                            const unsigned char *input,
+                            size_t input_len )
 {
     ssl_async_key_context_t *key_ctx = connection_ctx_arg;
     size_t slot;
@@ -914,7 +941,7 @@
     {
         char dn[100];
         mbedtls_x509_dn_gets( dn, sizeof( dn ), &cert->subject );
-        mbedtls_printf( "Async sign callback: looking for DN=%s\n", dn );
+        mbedtls_printf( "Async %s callback: looking for DN=%s\n", op_name, dn );
     }
     for( slot = 0; slot < key_ctx->slots_used; slot++ )
     {
@@ -923,25 +950,27 @@
     }
     if( slot == key_ctx->slots_used )
     {
-        mbedtls_printf( "Async sign callback: no key matches this certificate.\n" );
+        mbedtls_printf( "Async %s callback: no key matches this certificate.\n",
+                        op_name );
         return( MBEDTLS_ERR_SSL_HW_ACCEL_FALLTHROUGH );
     }
-    mbedtls_printf( "Async sign callback: using key slot %zd, delay=%u.\n",
-                    slot, key_ctx->slots[slot].delay );
+    mbedtls_printf( "Async %s callback: using key slot %zd, delay=%u.\n",
+                    op_name, slot, key_ctx->slots[slot].delay );
     if( key_ctx->inject_error == SSL_ASYNC_INJECT_ERROR_START )
     {
-        mbedtls_printf( "Async sign callback: injected error\n" );
+        mbedtls_printf( "Async %s callback: injected error\n", op_name );
         return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE );
     }
-    if( hash_len > MBEDTLS_MD_MAX_SIZE )
+    if( input_len > SSL_ASYNC_INPUT_MAX_SIZE )
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
     ctx = mbedtls_calloc( 1, sizeof( *ctx ) );
     if( ctx == NULL )
         return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
     ctx->slot = slot;
+    ctx->operation_type = op_type;
     ctx->md_alg = md_alg;
-    memcpy( ctx->hash, hash, hash_len );
-    ctx->hash_len = hash_len;
+    memcpy( ctx->input, input, input_len );
+    ctx->input_len = input_len;
     ctx->delay = key_ctx->slots[slot].delay;
     *p_operation_ctx = ctx;
     if( ctx->delay == 0 )
@@ -950,16 +979,48 @@
         return( MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS );
 }
 
-int ssl_async_resume( void *connection_ctx_arg,
-                      void *operation_ctx_arg,
-                      unsigned char *output,
-                      size_t *output_len,
-                      size_t output_size )
+/* Asynchronous sign callback (mbedtls_ssl_async_sign_t), installed
+ * when the async_operations option contains 's'. */
+static int ssl_async_sign( void *connection_ctx_arg,
+                           void **p_operation_ctx,
+                           mbedtls_x509_crt *cert,
+                           mbedtls_md_type_t md_alg,
+                           const unsigned char *hash,
+                           size_t hash_len )
+{
+    return( ssl_async_start( connection_ctx_arg, p_operation_ctx, cert,
+                             ASYNC_OP_SIGN, "sign", md_alg,
+                             hash, hash_len ) );
+}
+
+/* Asynchronous decrypt callback (mbedtls_ssl_async_decrypt_t), installed
+ * when the async_operations option contains 'd'. Decryption uses no
+ * hash, so md_alg is recorded as MBEDTLS_MD_NONE. */
+static int ssl_async_decrypt( void *connection_ctx_arg,
+                              void **p_operation_ctx,
+                              mbedtls_x509_crt *cert,
+                              const unsigned char *input,
+                              size_t input_len )
+{
+    return( ssl_async_start( connection_ctx_arg, p_operation_ctx, cert,
+                             ASYNC_OP_DECRYPT, "decrypt", MBEDTLS_MD_NONE,
+                             input, input_len ) );
+}
+
+/* Asynchronous resume callback. Unless SSL_ASYNC_INJECT_ERROR_RESUME is
+ * set, it either reports the operation as still in progress (see
+ * ctx->delay) or runs the recorded operation and frees ctx. */
+static int ssl_async_resume( void *connection_ctx_arg,
+                             void *operation_ctx_arg,
+                             unsigned char *output,
+                             size_t *output_len,
+                             size_t output_size )
 {
     ssl_async_operation_context_t *ctx = operation_ctx_arg;
     ssl_async_key_context_t *connection_ctx = connection_ctx_arg;
     ssl_async_key_slot_t *key_slot = &connection_ctx->slots[ctx->slot];
     int ret;
+    const char *op_name;
     if( connection_ctx->inject_error == SSL_ASYNC_INJECT_ERROR_RESUME )
     {
         mbedtls_printf( "Async resume callback: injected error\n" );
@@ -972,25 +1033,45 @@
                         ctx->slot, ctx->delay );
         return( MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS );
     }
-    (void) output_size; /* mbedtls_pk_size lacks this parameter */
-    ret = mbedtls_pk_sign( key_slot->pk,
-                           ctx->md_alg,
-                           ctx->hash, ctx->hash_len,
-                           output, output_len,
-                           connection_ctx->f_rng, connection_ctx->p_rng );
+    /* Dispatch on the recorded operation type, not on md_alg: a signature
+     * may also be requested with md_alg == MBEDTLS_MD_NONE. */
+    if( ctx->operation_type == ASYNC_OP_DECRYPT )
+    {
+        op_name = "decrypt";
+        ret = mbedtls_pk_decrypt( key_slot->pk,
+                                  ctx->input, ctx->input_len,
+                                  output, output_len, output_size,
+                                  connection_ctx->f_rng, connection_ctx->p_rng );
+    }
+    else
+    {
+        /* mbedtls_pk_sign() takes no output size, so output_size is only
+         * used for decryption. */
+        op_name = "sign";
+        ret = mbedtls_pk_sign( key_slot->pk,
+                               ctx->md_alg,
+                               ctx->input, ctx->input_len,
+                               output, output_len,
+                               connection_ctx->f_rng, connection_ctx->p_rng );
+    }
     if( connection_ctx->inject_error == SSL_ASYNC_INJECT_ERROR_PK )
     {
-        mbedtls_printf( "Async resume callback: done but injected error\n" );
+        mbedtls_printf( "Async resume callback: %s done but injected error\n",
+                        op_name );
+        /* The operation has finished, so release ctx here just like the
+         * normal completion path below does. */
+        mbedtls_free( ctx );
         return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE );
     }
-    mbedtls_printf( "Async resume (slot %zd): done, status=%d.\n",
-                    ctx->slot, ret );
+    mbedtls_printf( "Async resume (slot %zd): %s done, status=%d.\n",
+                    ctx->slot, op_name, ret );
     mbedtls_free( ctx );
     return( ret );
 }
 
-void ssl_async_cancel( void *connection_ctx_arg,
-                       void *operation_ctx_arg )
+/* Asynchronous cancel callback (see mbedtls_ssl_conf_async_private_cb). */
+static void ssl_async_cancel( void *connection_ctx_arg,
+                              void *operation_ctx_arg )
 {
     ssl_async_operation_context_t *ctx = operation_ctx_arg;
     (void) connection_ctx_arg;
@@ -1142,6 +1223,7 @@
     opt.key_file            = DFL_KEY_FILE;
     opt.crt_file2           = DFL_CRT_FILE2;
     opt.key_file2           = DFL_KEY_FILE2;
+    opt.async_operations    = DFL_ASYNC_OPERATIONS;
     opt.async_private_delay1 = DFL_ASYNC_PRIVATE_DELAY1;
     opt.async_private_delay2 = DFL_ASYNC_PRIVATE_DELAY2;
     opt.async_private_error = DFL_ASYNC_PRIVATE_ERROR;
@@ -1232,6 +1314,8 @@
         else if( strcmp( p, "dhm_file" ) == 0 )
             opt.dhm_file = q;
 #if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+        else if( strcmp( p, "async_operations" ) == 0 )
+            opt.async_operations = q;
         else if( strcmp( p, "async_private_delay1" ) == 0 )
             opt.async_private_delay1 = atoi( q );
         else if( strcmp( p, "async_private_delay2" ) == 0 )
@@ -2152,16 +2236,35 @@
     }
 
 #if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
-    if( opt.async_private_delay1 >= 0 || opt.async_private_delay2 >= 0 )
+    if( opt.async_operations[0] != '-' )
     {
+        /* Register only the callbacks for the operations listed in
+         * async_operations; the others are passed as NULL. */
+        mbedtls_ssl_async_sign_t *sign = NULL;
+        mbedtls_ssl_async_decrypt_t *decrypt = NULL;
+        const char *async_op; /* not 'p', to avoid shadowing */
+        for( async_op = opt.async_operations; *async_op; async_op++ )
+        {
+            switch( *async_op )
+            {
+            case 'd':
+                decrypt = ssl_async_decrypt;
+                break;
+            case 's':
+                sign = ssl_async_sign;
+                break;
+            default: /* other characters are ignored */
+                break;
+            }
+        }
         ssl_async_keys.inject_error = ( opt.async_private_error < 0 ?
                                         - opt.async_private_error :
                                         opt.async_private_error );
         ssl_async_keys.f_rng = mbedtls_ctr_drbg_random;
         ssl_async_keys.p_rng = &ctr_drbg;
         mbedtls_ssl_conf_async_private_cb( &conf,
-                                           ssl_async_sign,
-                                           NULL,
+                                           sign,
+                                           decrypt,
                                            ssl_async_resume,
                                            ssl_async_cancel,
                                            &ssl_async_keys );