Update HRR test case generation without HRR-specific class changes

Change-Id: I38f620213bf5349d33ecad080538294633f85566
Signed-off-by: XiaokangQian <xiaokang.qian@arm.com>
diff --git a/tests/scripts/generate_tls13_compat_tests.py b/tests/scripts/generate_tls13_compat_tests.py
index eea8462..6399345 100755
--- a/tests/scripts/generate_tls13_compat_tests.py
+++ b/tests/scripts/generate_tls13_compat_tests.py
@@ -70,26 +70,17 @@
     'x448': 0x1e,
 }
 
-HRR_CIPHER_SUITE_VALUE = {
-    "TLS_AES_256_GCM_SHA384": 0x1302,
-}
-
-HRR_SIG_ALG_VALUE = {
-    "ecdsa_secp384r1_sha384": 0x0503,
-}
 
 class TLSProgram(metaclass=abc.ABCMeta):
     """
     Base class for generate server/client command.
     """
     # pylint: disable=too-many-arguments
-    def __init__(self, ciphersuite=None, signature_algorithm=None, named_group=None, peer_named_group=None,
-                 is_hrr=False, cert_sig_alg=None, compat_mode=True):
+    def __init__(self, ciphersuite=None, signature_algorithm=None, named_group=None,
+                 cert_sig_alg=None, compat_mode=True):
         self._ciphers = []
         self._sig_algs = []
         self._named_groups = []
-        self._peer_named_group = peer_named_group
-        self._is_hrr = is_hrr
         self._cert_sig_algs = []
         if ciphersuite:
             self.add_ciphersuites(ciphersuite)
@@ -155,16 +146,15 @@
             ret += ['-cert {cert} -key {key}'.format(cert=cert, key=key)]
         ret += ['-accept $SRV_PORT']
 
-        if not self._is_hrr:
-            if self._ciphers:
-                ciphersuites = ':'.join(self._ciphers)
-                ret += ["-ciphersuites {ciphersuites}".format(ciphersuites=ciphersuites)]
+        if self._ciphers:
+            ciphersuites = ':'.join(self._ciphers)
+            ret += ["-ciphersuites {ciphersuites}".format(ciphersuites=ciphersuites)]
 
-            if self._sig_algs:
-                signature_algorithms = set(self._sig_algs + self._cert_sig_algs)
-                signature_algorithms = ':'.join(signature_algorithms)
-                ret += ["-sigalgs {signature_algorithms}".format(
-                    signature_algorithms=signature_algorithms)]
+        if self._sig_algs:
+            signature_algorithms = set(self._sig_algs + self._cert_sig_algs)
+            signature_algorithms = ':'.join(signature_algorithms)
+            ret += ["-sigalgs {signature_algorithms}".format(
+                signature_algorithms=signature_algorithms)]
 
         if self._named_groups:
             named_groups = ':'.join(
@@ -250,21 +240,18 @@
                     if i not in priority_string_list:
                         yield i
 
-        if self._is_hrr:
-            priority_string_list.extend(['CIPHER-ALL', 'SIGN-ALL', 'MAC-ALL'])
+        if self._ciphers:
+            priority_string_list.extend(update_priority_string_list(
+                self._ciphers, self.CIPHER_SUITE))
         else:
-            if self._ciphers:
-                priority_string_list.extend(update_priority_string_list(
-                    self._ciphers, self.CIPHER_SUITE))
-            else:
-                priority_string_list.append('CIPHER-ALL')
+            priority_string_list.append('CIPHER-ALL')
 
-            if self._sig_algs:
-                signature_algorithms = set(self._sig_algs + self._cert_sig_algs)
-                priority_string_list.extend(update_priority_string_list(
-                    signature_algorithms, self.SIGNATURE_ALGORITHM))
-            else:
-                priority_string_list.append('SIGN-ALL')
+        if self._sig_algs:
+            signature_algorithms = set(self._sig_algs + self._cert_sig_algs)
+            priority_string_list.extend(update_priority_string_list(
+                signature_algorithms, self.SIGNATURE_ALGORITHM))
+        else:
+            priority_string_list.extend(['SIGN-ALL','MAC-ALL'])
 
 
         if self._named_groups:
@@ -308,28 +295,16 @@
         ret += ['ca_file={cafile}'.format(
             cafile=CERTIFICATES[self._cert_sig_algs[0]].cafile)]
 
-        if not self._is_hrr:
-            if self._ciphers:
-                ciphers = ','.join(
-                    map(lambda cipher: self.CIPHER_SUITE[cipher], self._ciphers))
-                ret += ["force_ciphersuite={ciphers}".format(ciphers=ciphers)]
+        if self._ciphers:
+            ciphers = ','.join(
+                map(lambda cipher: self.CIPHER_SUITE[cipher], self._ciphers))
+            ret += ["force_ciphersuite={ciphers}".format(ciphers=ciphers)]
 
-            if self._sig_algs + self._cert_sig_algs:
-                ret += ['sig_algs={sig_algs}'.format(
-                    sig_algs=','.join(set(self._sig_algs + self._cert_sig_algs)))]
+        if self._sig_algs + self._cert_sig_algs:
+            ret += ['sig_algs={sig_algs}'.format(
+                sig_algs=','.join(set(self._sig_algs + self._cert_sig_algs)))]
 
         if self._named_groups:
-            if self._is_hrr:
-                # pylint: disable=pointless-string-statement
-                """
-                TODO: Use _cert_sig_algs to select EC groups in certificate verification.
-                """
-                for sig_alg in self._sig_algs:
-                    if sig_alg in ('ecdsa_secp256r1_sha256',
-                                   'ecdsa_secp384r1_sha384',
-                                   'ecdsa_secp521r1_sha512'):
-                        self.add_named_groups(sig_alg.split('_')[1])
-                self.add_named_groups(self._peer_named_group)
             named_groups = ','.join(self._named_groups)
             ret += ["curves={named_groups}".format(named_groups=named_groups)]
 
@@ -369,14 +344,6 @@
         check_strings.append("Verifying peer X.509 certificate... ok")
         return ['-c "{}"'.format(i) for i in check_strings]
 
-    # pylint: disable=C0330
-    def post_hrr_checks(self):
-        check_strings = ["NamedGroup: {group}".format(group=self._named_groups[0]),
-                         "NamedGroup: {group}".format(group=self._named_groups[-1]),
-                        "<= ssl_tls13_process_server_hello ( HelloRetryRequest )",
-                         "Verifying peer X.509 certificate... ok", ]
-        return ['-c "{}"'.format(i) for i in check_strings]
-
 
 SERVER_CLASSES = {'OpenSSL': OpenSSLServ, 'GnuTLS': GnuTLSServ}
 CLIENT_CLASSES = {'mbedTLS': MbedTLSCli}
@@ -407,27 +374,39 @@
     cmd = prefix.join(cmd)
     return '\n'.join(server_object.pre_checks() + client_object.pre_checks() + [cmd])
 
-# pylint: disable=too-many-arguments,C0330
-def generate_compat_hrr_test(server=None, client=None, cipher=None, sig_alg=None,
-                             client_named_group=None, server_named_group=None):
+
+def generate_hrr_compat_test(server=None, client=None, cipher=None, sig_alg=None,
+                             server_named_group=None):
     """
-    Generate Hello Retry Request test case with `ssl-opt.sh` format.
+    Generate a Hello Retry Request test case in `ssl-opt.sh` format: the server
+    is configured with `server_named_group`, the client offers another group
+    first, and the client is expected to receive an HRR for the server's group.
     """
+    # Client's first choice: alphabetically first group other than the server's.
+    client_named_group = min(NAMED_GROUP_IANA_VALUE.keys() - {server_named_group})
     name = 'TLS 1.3 {client[0]}->{server[0]}: HRR {c_named_group} -> {s_named_group}'.format(
             client=client, server=server, c_named_group=client_named_group,
             s_named_group=server_named_group)
-    server_object = SERVER_CLASSES[server](cipher, sig_alg, server_named_group,
-                                           client_named_group, True, sig_alg)
-    client_object = CLIENT_CLASSES[client](cipher, sig_alg, client_named_group,
-                                           server_named_group, True, sig_alg)
+    server_object = SERVER_CLASSES[server](ciphersuite=cipher,
+                                           named_group=server_named_group,
+                                           cert_sig_alg=sig_alg)
+
+    client_object = CLIENT_CLASSES[client](ciphersuite=cipher,
+                                           named_group=client_named_group,
+                                           cert_sig_alg=sig_alg)
+    # From here on the client offers client_named_group, then server_named_group.
+    client_object.add_named_groups(server_named_group)
 
     cmd = ['run_test "{}"'.format(name), '"{}"'.format(
         server_object.cmd()), '"{}"'.format(client_object.cmd()), '0']
     cmd += server_object.post_checks()
-    cmd += client_object.post_hrr_checks()
+    cmd += client_object.post_checks()
+    cmd += ['-c "received HelloRetryRequest message"']
+    cmd += ['-c "selected_group ( {:d} )"'.format(
+        NAMED_GROUP_IANA_VALUE[server_named_group])]
     prefix = ' \\\n' + (' '*9)
     cmd = prefix.join(cmd)
     return '\n'.join(server_object.pre_checks() + client_object.pre_checks() + [cmd])
 
 
 SSL_OUTPUT_HEADER = '''#!/bin/sh
@@ -452,15 +431,19 @@
 # Purpose
 #
 # List TLS1.3 compat test cases. They are generated by
-# `generate_tls13_compat_tests.py {parameter} -o {filename}`.
+# `{cmd}`.
 #
 # PLEASE DO NOT EDIT THIS FILE. IF NEEDED, PLEASE MODIFY `generate_tls13_compat_tests.py`
 # AND REGENERATE THIS FILE.
 #
 '''
 
+def cycle_zip(*iters, max_len=None):
+    """Zip `iters`, cycling shorter ones, for `max_len` tuples (default: longest)."""
+    if max_len is None:
+        max_len = max(len(i) for i in iters)
+    return itertools.islice(zip(*[itertools.cycle(i) for i in iters]), max_len)
 
-# pylint: disable=too-many-branches
 def main():
     """
     Main function of this program
@@ -516,21 +499,23 @@
                               CLIENT_CLASSES.keys()):
             yield generate_compat_test(cipher=cipher, sig_alg=sig_alg, named_group=named_group,
                                        server=server, client=client)
-        for cipher, sig_alg, client_named_group, server_named_group, server, client in \
-            itertools.product(HRR_CIPHER_SUITE_VALUE.keys(), HRR_SIG_ALG_VALUE.keys(),
-                              NAMED_GROUP_IANA_VALUE.keys(), NAMED_GROUP_IANA_VALUE.keys(),
-                              SERVER_CLASSES.keys(), CLIENT_CLASSES.keys()):
-            if client_named_group != server_named_group:
-                yield generate_compat_hrr_test(cipher=cipher, sig_alg=sig_alg,
-                                               client_named_group=client_named_group,
-                                               server_named_group=server_named_group,
-                                               server=server, client=client)
+
+        # Generate Hello Retry Request compat test cases. The cipher column is
+        # not passed on (it can only affect how many tuples cycle_zip yields),
+        # so these cases run without forcing a ciphersuite.
+        for (_cipher, sig_alg, named_group), server, client in \
+            itertools.product(cycle_zip(CIPHER_SUITE_IANA_VALUE.keys(),
+                                        SIG_ALG_IANA_VALUE.keys(),
+                                        NAMED_GROUP_IANA_VALUE.keys()),
+                              SERVER_CLASSES.keys(), CLIENT_CLASSES.keys()):
+            yield generate_hrr_compat_test(server_named_group=named_group, sig_alg=sig_alg,
+                                           server=server, client=client)
 
     if args.generate_all_tls13_compat_tests:
         if args.output:
             with open(args.output, 'w', encoding="utf-8") as f:
                 f.write(SSL_OUTPUT_HEADER.format(
-                    filename=os.path.basename(args.output), parameter='-a'))
+                    filename=os.path.basename(args.output), cmd=' '.join(sys.argv)))
                 f.write('\n\n'.join(get_all_test_cases()))
                 f.write('\n')
         else:
@@ -551,10 +536,8 @@
             print(*CLIENT_CLASSES.keys())
         return 0
 
-    if args.generate_all_tls13_compat_tests:
-        print(generate_compat_test(server=args.server, client=args.client, sig_alg=args.sig_alg,
-                                   cipher=args.cipher, named_group=args.named_group))
-
+    print(generate_compat_test(server=args.server, client=args.client, sig_alg=args.sig_alg,
+                               cipher=args.cipher, named_group=args.named_group))
     return 0