Update HRR test case generation code without changing classes
Change-Id: I38f620213bf5349d33ecad080538294633f85566
Signed-off-by: XiaokangQian <xiaokang.qian@arm.com>
diff --git a/tests/scripts/generate_tls13_compat_tests.py b/tests/scripts/generate_tls13_compat_tests.py
index eea8462..6399345 100755
--- a/tests/scripts/generate_tls13_compat_tests.py
+++ b/tests/scripts/generate_tls13_compat_tests.py
@@ -70,26 +70,17 @@
'x448': 0x1e,
}
-HRR_CIPHER_SUITE_VALUE = {
- "TLS_AES_256_GCM_SHA384": 0x1302,
-}
-
-HRR_SIG_ALG_VALUE = {
- "ecdsa_secp384r1_sha384": 0x0503,
-}
class TLSProgram(metaclass=abc.ABCMeta):
"""
Base class for generate server/client command.
"""
# pylint: disable=too-many-arguments
- def __init__(self, ciphersuite=None, signature_algorithm=None, named_group=None, peer_named_group=None,
- is_hrr=False, cert_sig_alg=None, compat_mode=True):
+ def __init__(self, ciphersuite=None, signature_algorithm=None, named_group=None,
+ cert_sig_alg=None, compat_mode=True):
self._ciphers = []
self._sig_algs = []
self._named_groups = []
- self._peer_named_group = peer_named_group
- self._is_hrr = is_hrr
self._cert_sig_algs = []
if ciphersuite:
self.add_ciphersuites(ciphersuite)
@@ -155,16 +146,15 @@
ret += ['-cert {cert} -key {key}'.format(cert=cert, key=key)]
ret += ['-accept $SRV_PORT']
- if not self._is_hrr:
- if self._ciphers:
- ciphersuites = ':'.join(self._ciphers)
- ret += ["-ciphersuites {ciphersuites}".format(ciphersuites=ciphersuites)]
+ if self._ciphers:
+ ciphersuites = ':'.join(self._ciphers)
+ ret += ["-ciphersuites {ciphersuites}".format(ciphersuites=ciphersuites)]
- if self._sig_algs:
- signature_algorithms = set(self._sig_algs + self._cert_sig_algs)
- signature_algorithms = ':'.join(signature_algorithms)
- ret += ["-sigalgs {signature_algorithms}".format(
- signature_algorithms=signature_algorithms)]
+ if self._sig_algs:
+ signature_algorithms = set(self._sig_algs + self._cert_sig_algs)
+ signature_algorithms = ':'.join(signature_algorithms)
+ ret += ["-sigalgs {signature_algorithms}".format(
+ signature_algorithms=signature_algorithms)]
if self._named_groups:
named_groups = ':'.join(
@@ -250,21 +240,18 @@
if i not in priority_string_list:
yield i
- if self._is_hrr:
- priority_string_list.extend(['CIPHER-ALL', 'SIGN-ALL', 'MAC-ALL'])
+ if self._ciphers:
+ priority_string_list.extend(update_priority_string_list(
+ self._ciphers, self.CIPHER_SUITE))
else:
- if self._ciphers:
- priority_string_list.extend(update_priority_string_list(
- self._ciphers, self.CIPHER_SUITE))
- else:
- priority_string_list.append('CIPHER-ALL')
+ priority_string_list.extend(['CIPHER-ALL', 'MAC-ALL'])
- if self._sig_algs:
- signature_algorithms = set(self._sig_algs + self._cert_sig_algs)
- priority_string_list.extend(update_priority_string_list(
- signature_algorithms, self.SIGNATURE_ALGORITHM))
- else:
- priority_string_list.append('SIGN-ALL')
+ if self._sig_algs:
+ signature_algorithms = set(self._sig_algs + self._cert_sig_algs)
+ priority_string_list.extend(update_priority_string_list(
+ signature_algorithms, self.SIGNATURE_ALGORITHM))
+ else:
+ priority_string_list.append('SIGN-ALL')
if self._named_groups:
@@ -308,28 +295,16 @@
ret += ['ca_file={cafile}'.format(
cafile=CERTIFICATES[self._cert_sig_algs[0]].cafile)]
- if not self._is_hrr:
- if self._ciphers:
- ciphers = ','.join(
- map(lambda cipher: self.CIPHER_SUITE[cipher], self._ciphers))
- ret += ["force_ciphersuite={ciphers}".format(ciphers=ciphers)]
+ if self._ciphers:
+ ciphers = ','.join(
+ map(lambda cipher: self.CIPHER_SUITE[cipher], self._ciphers))
+ ret += ["force_ciphersuite={ciphers}".format(ciphers=ciphers)]
- if self._sig_algs + self._cert_sig_algs:
- ret += ['sig_algs={sig_algs}'.format(
- sig_algs=','.join(set(self._sig_algs + self._cert_sig_algs)))]
+ if self._sig_algs + self._cert_sig_algs:
+ ret += ['sig_algs={sig_algs}'.format(
+ sig_algs=','.join(set(self._sig_algs + self._cert_sig_algs)))]
if self._named_groups:
- if self._is_hrr:
- # pylint: disable=pointless-string-statement
- """
- TODO: Use _cert_sig_algs to select EC groups in certificate verification.
- """
- for sig_alg in self._sig_algs:
- if sig_alg in ('ecdsa_secp256r1_sha256',
- 'ecdsa_secp384r1_sha384',
- 'ecdsa_secp521r1_sha512'):
- self.add_named_groups(sig_alg.split('_')[1])
- self.add_named_groups(self._peer_named_group)
named_groups = ','.join(self._named_groups)
ret += ["curves={named_groups}".format(named_groups=named_groups)]
@@ -369,14 +344,6 @@
check_strings.append("Verifying peer X.509 certificate... ok")
return ['-c "{}"'.format(i) for i in check_strings]
- # pylint: disable=C0330
- def post_hrr_checks(self):
- check_strings = ["NamedGroup: {group}".format(group=self._named_groups[0]),
- "NamedGroup: {group}".format(group=self._named_groups[-1]),
- "<= ssl_tls13_process_server_hello ( HelloRetryRequest )",
- "Verifying peer X.509 certificate... ok", ]
- return ['-c "{}"'.format(i) for i in check_strings]
-
SERVER_CLASSES = {'OpenSSL': OpenSSLServ, 'GnuTLS': GnuTLSServ}
CLIENT_CLASSES = {'mbedTLS': MbedTLSCli}
@@ -407,27 +374,39 @@
cmd = prefix.join(cmd)
return '\n'.join(server_object.pre_checks() + client_object.pre_checks() + [cmd])
-# pylint: disable=too-many-arguments,C0330
-def generate_compat_hrr_test(server=None, client=None, cipher=None, sig_alg=None,
- client_named_group=None, server_named_group=None):
+
+def generate_hrr_compat_test(server=None, client=None, cipher=None, sig_alg=None,
+ server_named_group=None):
"""
Generate Hello Retry Request test case with `ssl-opt.sh` format.
"""
+ # Pick a client named_group that differs from server_named_group, so HRR is expected
+ client_named_group = sorted(NAMED_GROUP_IANA_VALUE.keys() - {server_named_group})[0]
name = 'TLS 1.3 {client[0]}->{server[0]}: HRR {c_named_group} -> {s_named_group}'.format(
client=client, server=server, c_named_group=client_named_group,
s_named_group=server_named_group)
- server_object = SERVER_CLASSES[server](cipher, sig_alg, server_named_group,
- client_named_group, True, sig_alg)
- client_object = CLIENT_CLASSES[client](cipher, sig_alg, client_named_group,
- server_named_group, True, sig_alg)
+ server_object = SERVER_CLASSES[server](ciphersuite=cipher,
+ named_group=server_named_group,
+ cert_sig_alg=sig_alg)
+
+ client_object = CLIENT_CLASSES[client](ciphersuite=cipher,
+ named_group=client_named_group,
+ cert_sig_alg=sig_alg)
+ # From here on, the client's named_group order is client_named_group, server_named_group.
+ client_object.add_named_groups(server_named_group)
cmd = ['run_test "{}"'.format(name), '"{}"'.format(
server_object.cmd()), '"{}"'.format(client_object.cmd()), '0']
cmd += server_object.post_checks()
- cmd += client_object.post_hrr_checks()
+ cmd += client_object.post_checks()
+ cmd += ['-c "received HelloRetryRequest message"']
+ cmd += ['-c "selected_group ( {:d} )"'.format(
+ NAMED_GROUP_IANA_VALUE[server_named_group])]
prefix = ' \\\n' + (' '*9)
cmd = prefix.join(cmd)
- return '\n'.join(server_object.pre_checks() + client_object.pre_checks() + [cmd])
+ return '\n'.join(server_object.pre_checks() +
+ client_object.pre_checks() +
+ [cmd])
SSL_OUTPUT_HEADER = '''#!/bin/sh
@@ -452,15 +431,19 @@
# Purpose
#
# List TLS1.3 compat test cases. They are generated by
-# `generate_tls13_compat_tests.py {parameter} -o {filename}`.
+# `{cmd}`.
#
# PLEASE DO NOT EDIT THIS FILE. IF NEEDED, PLEASE MODIFY `generate_tls13_compat_tests.py`
# AND REGENERATE THIS FILE.
#
'''
+def cycle_zip(*iters, max_len=None):
+ max_len = max((len(i) for i in iters), default=0) if max_len is None else max_len
+ cycle_iters = zip(*[itertools.cycle(i) for i in iters])
+ for _, c in zip(range(max_len), cycle_iters):
+ yield c
-# pylint: disable=too-many-branches
def main():
"""
Main function of this program
@@ -516,21 +499,23 @@
CLIENT_CLASSES.keys()):
yield generate_compat_test(cipher=cipher, sig_alg=sig_alg, named_group=named_group,
server=server, client=client)
- for cipher, sig_alg, client_named_group, server_named_group, server, client in \
- itertools.product(HRR_CIPHER_SUITE_VALUE.keys(), HRR_SIG_ALG_VALUE.keys(),
- NAMED_GROUP_IANA_VALUE.keys(), NAMED_GROUP_IANA_VALUE.keys(),
- SERVER_CLASSES.keys(), CLIENT_CLASSES.keys()):
- if client_named_group != server_named_group:
- yield generate_compat_hrr_test(cipher=cipher, sig_alg=sig_alg,
- client_named_group=client_named_group,
- server_named_group=server_named_group,
- server=server, client=client)
+
+ # Generate Hello Retry Request compat test cases
+ for combine_fields, server, client in \
+ itertools.product(list(cycle_zip(CIPHER_SUITE_IANA_VALUE.keys(),
+ SIG_ALG_IANA_VALUE.keys(),
+ NAMED_GROUP_IANA_VALUE.keys())),
+ SERVER_CLASSES.keys(),
+ CLIENT_CLASSES.keys()):
+ cipher, sig_alg, named_group = combine_fields
+ yield generate_hrr_compat_test(cipher=cipher, server_named_group=named_group,
+ sig_alg=sig_alg, server=server, client=client)
if args.generate_all_tls13_compat_tests:
if args.output:
with open(args.output, 'w', encoding="utf-8") as f:
f.write(SSL_OUTPUT_HEADER.format(
- filename=os.path.basename(args.output), parameter='-a'))
+ filename=os.path.basename(args.output), cmd=' '.join(sys.argv)))
f.write('\n\n'.join(get_all_test_cases()))
f.write('\n')
else:
@@ -551,10 +536,8 @@
print(*CLIENT_CLASSES.keys())
return 0
- if args.generate_all_tls13_compat_tests:
- print(generate_compat_test(server=args.server, client=args.client, sig_alg=args.sig_alg,
- cipher=args.cipher, named_group=args.named_group))
-
+ print(generate_compat_test(server=args.server, client=args.client, sig_alg=args.sig_alg,
+ cipher=args.cipher, named_group=args.named_group))
return 0