Add cert_sig_algs option to the TLS 1.3 compat test generation script

Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/tests/scripts/generate_tls13_compat_tests.py b/tests/scripts/generate_tls13_compat_tests.py
index 2e6ff72..445c702 100755
--- a/tests/scripts/generate_tls13_compat_tests.py
+++ b/tests/scripts/generate_tls13_compat_tests.py
@@ -28,7 +28,6 @@
 import argparse
 import itertools
 from collections import namedtuple
-# pylint: disable=useless-super-delegation
 
 # define certificates configuration entry
 Certificate = namedtuple("Certificate", ['cafile', 'certfile', 'keyfile'])
@@ -71,18 +70,27 @@
     'x448': 0x1e,
 }
 
+
 class TLSProgram(metaclass=abc.ABCMeta):
     """
     Base class for generate server/client command.
     """
-
-    def __init__(self, ciphersuite, signature_algorithm, named_group, compat_mode=True):
+    # compat_mode keeps its original 4th position so positional callers still work.
+    # pylint: disable=too-many-arguments
+    def __init__(self, ciphersuite=None, signature_algorithm=None, named_group=None,
+                 compat_mode=True, cert_sig_alg=None):
         self._ciphers = []
         self._sig_algs = []
         self._named_groups = []
-        self.add_ciphersuites(ciphersuite)
-        self.add_named_groups(named_group)
-        self.add_signature_algorithms(signature_algorithm)
+        self._cert_sig_algs = []
+        if ciphersuite:
+            self.add_ciphersuites(ciphersuite)
+        if named_group:
+            self.add_named_groups(named_group)
+        if signature_algorithm:
+            self.add_signature_algorithms(signature_algorithm)
+        if cert_sig_alg:
+            self.add_cert_signature_algorithms(cert_sig_alg)
         self._compat_mode = compat_mode
 
     # add_ciphersuites should not override by sub class
@@ -95,18 +102,24 @@
         self._sig_algs.extend(
             [sig_alg for sig_alg in signature_algorithms if sig_alg not in self._sig_algs])
 
-    # add_signature_algorithms should not override by sub class
+    # add_named_groups should not override by sub class
     def add_named_groups(self, *named_groups):
         self._named_groups.extend(
             [named_group for named_group in named_groups if named_group not in self._named_groups])
 
+    # add_cert_signature_algorithms should not override by sub class
+    def add_cert_signature_algorithms(self, *signature_algorithms):
+        self._cert_sig_algs.extend(
+            [sig_alg for sig_alg in signature_algorithms if sig_alg not in self._cert_sig_algs])
+
     @abc.abstractmethod
     def pre_checks(self):
         return []
 
     @abc.abstractmethod
     def cmd(self):
-        pass
+        if not self._cert_sig_algs:
+            self._cert_sig_algs = list(CERTIFICATES.keys())
 
     @abc.abstractmethod
     def post_checks(self):
@@ -127,18 +140,26 @@
     }
 
     def cmd(self):
+        super().cmd()
         ret = ['$O_NEXT_SRV_NO_CERT']
-        for _, cert, key in map(lambda sig_alg: CERTIFICATES[sig_alg], self._sig_algs):
+        for _, cert, key in map(lambda sig_alg: CERTIFICATES[sig_alg], self._cert_sig_algs):
             ret += ['-cert {cert} -key {key}'.format(cert=cert, key=key)]
         ret += ['-accept $SRV_PORT']
-        ciphersuites = ','.join(self._ciphers)
-        signature_algorithms = ','.join(self._sig_algs)
-        named_groups = ','.join(
-            map(lambda named_group: self.NAMED_GROUP[named_group], self._named_groups))
-        ret += ["-ciphersuites {ciphersuites}".format(ciphersuites=ciphersuites),
-                "-sigalgs {signature_algorithms}".format(
-                    signature_algorithms=signature_algorithms),
-                "-groups {named_groups}".format(named_groups=named_groups)]
+
+        if self._ciphers:
+            ciphersuites = ':'.join(self._ciphers)
+            ret += ["-ciphersuites {ciphersuites}".format(ciphersuites=ciphersuites)]
+
+        if self._sig_algs:
+            signature_algorithms = ':'.join(self._sig_algs)
+            ret += ["-sigalgs {signature_algorithms}".format(
+                signature_algorithms=signature_algorithms)]
+
+        if self._named_groups:
+            named_groups = ':'.join(
+                map(lambda named_group: self.NAMED_GROUP[named_group], self._named_groups))
+            ret += ["-groups {named_groups}".format(named_groups=named_groups)]
+
         ret += ['-msg -tls1_3 -num_tickets 0 -no_resume_ephemeral -no_cache']
         if not self._compat_mode:
             ret += ['-no_middlebox']
@@ -202,10 +223,11 @@
         return ['-c "HTTP/1.0 200 OK"']
 
     def cmd(self):
+        super().cmd()
         ret = ['$G_NEXT_SRV_NO_CERT', '--http',
                '--disable-client-cert', '--debug=4']
 
-        for _, cert, key in map(lambda sig_alg: CERTIFICATES[sig_alg], self._sig_algs):
+        for _, cert, key in map(lambda sig_alg: CERTIFICATES[sig_alg], self._cert_sig_algs):
             ret += ['--x509certfile {cert} --x509keyfile {key}'.format(
                 cert=cert, key=key)]
 
@@ -216,16 +238,32 @@
                 for i in map_table[item]:
                     if i not in priority_string_list:
                         yield i
-        priority_string_list.extend(update_priority_string_list(
-            self._sig_algs, self.SIGNATURE_ALGORITHM))
-        priority_string_list.extend(
-            update_priority_string_list(self._ciphers, self.CIPHER_SUITE))
-        priority_string_list.extend(update_priority_string_list(
-            self._named_groups, self.NAMED_GROUP))
-        priority_string_list = ['NONE'] + sorted(priority_string_list) + ['VERS-TLS1.3']
+
+        # Empty list -> GnuTLS *-ALL keyword; the leading 'NONE' would enable nothing.
+        if self._ciphers:
+            priority_string_list.extend(update_priority_string_list(
+                self._ciphers, self.CIPHER_SUITE))
+        else:
+            priority_string_list.append('CIPHER-ALL')
+
+        if self._sig_algs:
+            priority_string_list.extend(update_priority_string_list(
+                self._sig_algs, self.SIGNATURE_ALGORITHM))
+        else:
+            priority_string_list.append('SIGN-ALL')
+
+        if self._named_groups:
+            priority_string_list.extend(update_priority_string_list(
+                self._named_groups, self.NAMED_GROUP))
+        else:
+            priority_string_list.append('GROUP-ALL')
+
+        priority_string_list = ['NONE'] + sorted(priority_string_list) + ['VERS-TLS1.3']
 
         priority_string = ':+'.join(priority_string_list)
         priority_string += ':%NO_TICKETS'
+
         if not self._compat_mode:
-            priority_string += [':%DISABLE_TLS13_COMPAT_MODE']
+            # priority_string is a str: append the modifier as a string, not a list.
+            priority_string += ':%DISABLE_TLS13_COMPAT_MODE'
 
@@ -248,11 +286,12 @@
         'TLS_AES_128_CCM_8_SHA256': 'TLS1-3-AES-128-CCM-8-SHA256'}
 
     def cmd(self):
+        super().cmd()
         ret = ['$P_CLI']
         ret += ['server_addr=127.0.0.1', 'server_port=$SRV_PORT',
                 'debug_level=4', 'force_version=tls13']
         ret += ['ca_file={cafile}'.format(
-            cafile=CERTIFICATES[self._sig_algs[0]].cafile)]
+            cafile=CERTIFICATES[self._cert_sig_algs[0]].cafile)]
 
         if self._ciphers:
             ciphers = ','.join(
@@ -262,11 +301,6 @@
         if self._sig_algs:
             ret += ['sig_algs={sig_algs}'.format(
                 sig_algs=','.join(self._sig_algs))]
-            for sig_alg in self._sig_algs:
-                if sig_alg in ('ecdsa_secp256r1_sha256',
-                               'ecdsa_secp384r1_sha384',
-                               'ecdsa_secp521r1_sha512'):
-                    self.add_named_groups(sig_alg.split('_')[1])
 
         if self._named_groups:
             named_groups = ','.join(self._named_groups)
@@ -283,19 +317,29 @@
         if self._compat_mode:
             ret += ['requires_config_enabled MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE']
 
-        if 'rsa_pss_rsae_sha256' in self._sig_algs:
+        if 'rsa_pss_rsae_sha256' in self._sig_algs + self._cert_sig_algs:
             ret.append(
                 'requires_config_enabled MBEDTLS_X509_RSASSA_PSS_SUPPORT')
         return ret
 
     def post_checks(self):
-        check_strings = ["ECDH curve: {group}".format(group=self._named_groups[0]),
-                         "server hello, chosen ciphersuite: ( {:04x} ) - {}".format(
-                             CIPHER_SUITE_IANA_VALUE[self._ciphers[0]],
-                             self.CIPHER_SUITE[self._ciphers[0]]),
-                         "Certificate Verify: Signature algorithm ( {:04x} )".format(
-                             SIG_ALG_IANA_VALUE[self._sig_algs[0]]),
-                         "Verifying peer X.509 certificate... ok", ]
+        check_strings = []
+        if self._ciphers:
+            check_strings.append(
+                "server hello, chosen ciphersuite: ( {:04x} ) - {}".format(
+                    CIPHER_SUITE_IANA_VALUE[self._ciphers[0]],
+                    self.CIPHER_SUITE[self._ciphers[0]]))
+        if self._sig_algs:
+            check_strings.append(
+                "Certificate Verify: Signature algorithm ( {:04x} )".format(
+                    SIG_ALG_IANA_VALUE[self._sig_algs[0]]))
+
+        for named_group in self._named_groups:
+            check_strings += ['NamedGroup: {named_group} ( {iana_value:x} )'.format(
+                                named_group=named_group,
+                                iana_value=NAMED_GROUP_IANA_VALUE[named_group])]
+
+        check_strings.append("Verifying peer X.509 certificate... ok")
         return ['-c "{}"'.format(i) for i in check_strings]
 
 
@@ -309,13 +353,21 @@
     """
     name = 'TLS 1.3 {client[0]}->{server[0]}: {cipher},{named_group},{sig_alg}'.format(
         client=client, server=server, cipher=cipher, sig_alg=sig_alg, named_group=named_group)
-    server_object = SERVER_CLASSES[server](cipher, sig_alg, named_group)
-    client_object = CLIENT_CLASSES[client](cipher, sig_alg, named_group)
+
+    server_object = SERVER_CLASSES[server](ciphersuite=cipher,
+                                           named_group=named_group,
+                                           signature_algorithm=sig_alg,
+                                           cert_sig_alg=sig_alg)
+    client_object = CLIENT_CLASSES[client](ciphersuite=cipher,
+                                           named_group=named_group,
+                                           signature_algorithm=sig_alg,
+                                           cert_sig_alg=sig_alg)
 
     cmd = ['run_test "{}"'.format(name), '"{}"'.format(
         server_object.cmd()), '"{}"'.format(client_object.cmd()), '0']
     cmd += server_object.post_checks()
     cmd += client_object.post_checks()
+    cmd += ['-C "received HelloRetryRequest message"']
     prefix = ' \\\n' + (' '*9)
     cmd = prefix.join(cmd)
     return '\n'.join(server_object.pre_checks() + client_object.pre_checks() + [cmd])
@@ -343,7 +395,7 @@
 # Purpose
 #
 # List TLS1.3 compat test cases. They are generated by
-# `generate_tls13_compat_tests.py -a`.
+# `{cmd}`.
 #
 # PLEASE DO NOT EDIT THIS FILE. IF NEEDED, PLEASE MODIFY `generate_tls13_compat_tests.py`
 # AND REGENERATE THIS FILE.
@@ -397,22 +449,25 @@
     args = parser.parse_args()
 
     def get_all_test_cases():
+        # One test per (cipher, sig_alg, named_group, server, client) combination.
         for cipher, sig_alg, named_group, server, client in \
-            itertools.product(CIPHER_SUITE_IANA_VALUE.keys(), SIG_ALG_IANA_VALUE.keys(),
-                              NAMED_GROUP_IANA_VALUE.keys(), SERVER_CLASSES.keys(),
+            itertools.product(CIPHER_SUITE_IANA_VALUE.keys(),
+                              SIG_ALG_IANA_VALUE.keys(),
+                              NAMED_GROUP_IANA_VALUE.keys(),
+                              SERVER_CLASSES.keys(),
                               CLIENT_CLASSES.keys()):
             yield generate_compat_test(cipher=cipher, sig_alg=sig_alg, named_group=named_group,
                                        server=server, client=client)
 
     if args.generate_all_tls13_compat_tests:
         if args.output:
             with open(args.output, 'w', encoding="utf-8") as f:
                 f.write(SSL_OUTPUT_HEADER.format(
-                    filename=os.path.basename(args.output)))
+                    filename=os.path.basename(args.output), cmd=' '.join(sys.argv)))
                 f.write('\n\n'.join(get_all_test_cases()))
                 f.write('\n')
         else:
-            print('\n'.join(get_all_test_cases()))
+            print('\n\n'.join(get_all_test_cases()))
         return 0
 
     if args.list_ciphers or args.list_sig_algs or args.list_named_groups \