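Refactor TLS 1.3 compat test generator script

Turn TLSProgram into a plain (non-abc) base class with inheritable default
hooks, and have TLSProgram.cmd() return an argument list seeded from
pre_cmd(). Program-specific prefixes now live in each subclass's pre_cmd():
server certificate/key options for servers, CA file options for clients.
Move force_version=tls13 into MbedTLSBase and drop the redundant
post_checks/pre_cmd/hrr_post_checks overrides. Stop passing the server
address/port arguments explicitly in the generated commands. The run_test
generators now join the returned argument lists themselves.
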
Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/tests/scripts/generate_tls13_compat_tests.py b/tests/scripts/generate_tls13_compat_tests.py
index c3bfe1b..7819a87 100755
--- a/tests/scripts/generate_tls13_compat_tests.py
+++ b/tests/scripts/generate_tls13_compat_tests.py
@@ -24,7 +24,6 @@
import sys
import os
-import abc
import argparse
import itertools
from collections import namedtuple
@@ -71,10 +70,11 @@
}
-class TLSProgram(metaclass=abc.ABCMeta):
+class TLSProgram:
"""
Base class for generate server/client command.
"""
+
# pylint: disable=too-many-arguments
def __init__(self, ciphersuite=None, signature_algorithm=None, named_group=None,
cert_sig_alg=None, compat_mode=True):
@@ -112,24 +112,25 @@
self._cert_sig_algs.extend(
[sig_alg for sig_alg in signature_algorithms if sig_alg not in self._cert_sig_algs])
- @abc.abstractmethod
+ # pylint: disable=no-self-use
def pre_checks(self):
return []
- @abc.abstractmethod
+ # pylint: disable=no-self-use
def cmd(self):
if not self._cert_sig_algs:
self._cert_sig_algs = list(CERTIFICATES.keys())
+ return self.pre_cmd()
- @abc.abstractmethod
+ # pylint: disable=no-self-use
def post_checks(self):
return []
- @abc.abstractmethod
+ # pylint: disable=no-self-use
def pre_cmd(self):
- return []
+ return ['false']
- @abc.abstractmethod
+ # pylint: disable=unused-argument,no-self-use
def hrr_post_checks(self, named_group):
return []
@@ -148,10 +149,7 @@
}
def cmd(self):
- super().cmd()
- ret = []
- for _, cert, key in map(lambda sig_alg: CERTIFICATES[sig_alg], self._cert_sig_algs):
- ret += ['-cert {cert} -key {key}'.format(cert=cert, key=key)]
+ ret = super().cmd()
if self._ciphers:
ciphersuites = ':'.join(self._ciphers)
@@ -177,15 +175,6 @@
def pre_checks(self):
return ["requires_openssl_tls1_3"]
- def post_checks(self):
- return []
-
- def pre_cmd(self):
- return []
-
- def hrr_post_checks(self, named_group):
- return []
-
class OpenSSLServ(OpenSSLBase):
"""
@@ -193,18 +182,28 @@
"""
def cmd(self):
- ret = self.pre_cmd() + super().cmd()
- ret += ['-accept $SRV_PORT']
-
+ ret = super().cmd()
ret += ['-num_tickets 0 -no_resume_ephemeral -no_cache']
-
- return ' '.join(ret)
+ return ret
def post_checks(self):
return ['-c "HTTP/1.0 200 ok"']
def pre_cmd(self):
- return ['$O_NEXT_SRV_NO_CERT']
+ ret = ['$O_NEXT_SRV_NO_CERT']
+ for _, cert, key in map(lambda sig_alg: CERTIFICATES[sig_alg], self._cert_sig_algs):
+ ret += ['-cert {cert} -key {key}'.format(cert=cert, key=key)]
+ return ret
+
+
+class OpenSSLCli(OpenSSLBase):
+ """
+ Generate test commands for OpenSSL client.
+ """
+
+ def pre_cmd(self):
+ return ['$O_NEXT_CLI_NO_CERT',
+ '-CAfile {cafile}'.format(cafile=CERTIFICATES[self._cert_sig_algs[0]].cafile)]
class GnuTLSBase(TLSProgram):
@@ -253,22 +252,8 @@
"requires_gnutls_next_no_ticket",
"requires_gnutls_next_disable_tls13_compat", ]
- def post_checks(self):
- return ['-c "HTTP/1.0 200 OK"']
-
- def hrr_post_checks(self, named_group):
- return []
-
- def pre_cmd(self):
- return []
-
def cmd(self):
- super().cmd()
- ret = []
-
- for _, cert, key in map(lambda sig_alg: CERTIFICATES[sig_alg], self._cert_sig_algs):
- ret += ['--x509certfile {cert} --x509keyfile {key}'.format(
- cert=cert, key=key)]
+ ret = super().cmd()
priority_string_list = []
@@ -316,14 +301,26 @@
Generate test commands for GnuTLS server.
"""
- def cmd(self):
- ret = self.pre_cmd() + super().cmd()
+ def pre_cmd(self):
+ ret = ['$G_NEXT_SRV_NO_CERT', '--http', '--disable-client-cert', '--debug=4']
- ret = ' '.join(ret)
+ for _, cert, key in map(lambda sig_alg: CERTIFICATES[sig_alg], self._cert_sig_algs):
+ ret += ['--x509certfile {cert} --x509keyfile {key}'.format(
+ cert=cert, key=key)]
return ret
+ def post_checks(self):
+ return ['-c "HTTP/1.0 200 OK"']
+
+
+class GnuTLSCli(GnuTLSBase):
+ """
+ Generate test commands for GnuTLS client.
+ """
+
def pre_cmd(self):
- return ['$G_NEXT_SRV_NO_CERT'] + ['--http', '--disable-client-cert', '--debug=4']
+ return ['$G_NEXT_CLI_NO_CERT', '--debug=4', '--single-key-share',
+ '--x509cafile {cafile}'.format(cafile=CERTIFICATES[self._cert_sig_algs[0]].cafile)]
class MbedTLSBase(TLSProgram):
@@ -339,10 +336,9 @@
'TLS_AES_128_CCM_8_SHA256': 'TLS1-3-AES-128-CCM-8-SHA256'}
def cmd(self):
- super().cmd()
- ret = ['server_addr=127.0.0.1', 'server_port=$SRV_PORT', 'debug_level=4']
- ret += ['ca_file={cafile}'.format(
- cafile=CERTIFICATES[self._cert_sig_algs[0]].cafile)]
+ ret = super().cmd()
+ ret += ['debug_level=4']
+
if self._ciphers:
ciphers = ','.join(
@@ -356,7 +352,7 @@
if self._named_groups:
named_groups = ','.join(self._named_groups)
ret += ["curves={named_groups}".format(named_groups=named_groups)]
-
+ ret += ['force_version=tls13']
return ret
def pre_checks(self):
@@ -371,15 +367,6 @@
'requires_config_enabled MBEDTLS_X509_RSASSA_PSS_SUPPORT')
return ret
- def post_checks(self):
- return []
-
- def pre_cmd(self):
- return []
-
- def hrr_post_checks(self, named_group):
- return []
-
class MbedTLSServ(MbedTLSBase):
"""
@@ -387,13 +374,8 @@
"""
def cmd(self):
- ret = self.pre_cmd() + super().cmd()
- ret += ['force_version=tls13']
- for _, cert, key in map(lambda sig_alg: CERTIFICATES[sig_alg], self._cert_sig_algs):
- ret += ['crt_file={cert} key_file={key}'.format(cert=cert, key=key)]
-
+ ret = super().cmd()
ret += ['tls13_kex_modes=ephemeral cookies=0 tickets=0']
- ret = ' '.join(ret)
return ret
def pre_checks(self):
@@ -420,64 +402,23 @@
return ['-s "{}"'.format(i) for i in check_strings]
def pre_cmd(self):
- return ['$P_SRV_NO_CERT']
+ ret = ['$P_SRV']
+ for _, cert, key in map(lambda sig_alg: CERTIFICATES[sig_alg], self._cert_sig_algs):
+ ret += ['crt_file={cert} key_file={key}'.format(cert=cert, key=key)]
+ return ret
def hrr_post_checks(self, named_group):
return ['-s "HRR selected_group: {:s}"'.format(named_group)]
-class OpenSSLCli(OpenSSLBase):
- """
- Generate test commands for OpenSSL client.
- """
-
- def cmd(self):
- ret = self.pre_cmd() + super().cmd()
-
- ret += ['-CAfile {cafile}'.format(
- cafile=CERTIFICATES[self._cert_sig_algs[0]].cafile)]
-
- return ' '.join(ret)
-
- def post_checks(self):
- return ['-s "HTTP/1.0 200 OK"']
-
- def pre_cmd(self):
- return ['$O_NEXT_CLI_NO_CERT']
-
-
-class GnuTLSCli(GnuTLSBase):
- """
- Generate test commands for GnuTLS client.
- """
-
- def cmd(self):
- ret = self.pre_cmd() + super().cmd()
- ret += ['--x509cafile {cafile}'.format(
- cafile=CERTIFICATES[self._cert_sig_algs[0]].cafile)]
-
- ret = ' '.join(ret)
- return ret
-
- def pre_cmd(self):
- ret = ['$G_NEXT_CLI_NO_CERT']
- ret += ['--debug=4', 'localhost', '-p $SRV_PORT', '--single-key-share']
- return ret
-
-
class MbedTLSCli(MbedTLSBase):
"""
Generate test commands for mbedTLS client.
"""
- def cmd(self):
- ret = self.pre_cmd() + super().cmd()
-
- ret = ' '.join(ret)
- return ret
-
def pre_cmd(self):
- return ['$P_CLI']
+ return ['$P_CLI',
+ 'ca_file={cafile}'.format(cafile=CERTIFICATES[self._cert_sig_algs[0]].cafile)]
def pre_checks(self):
return ['requires_config_enabled MBEDTLS_SSL_CLI_C'] + super().pre_checks()
@@ -528,8 +469,10 @@
signature_algorithm=sig_alg,
cert_sig_alg=sig_alg)
- cmd = ['run_test "{}"'.format(name), '"{}"'.format(
- server_object.cmd()), '"{}"'.format(client_object.cmd()), '0']
+ cmd = ['run_test "{}"'.format(name),
+ '"{}"'.format(' '.join(server_object.cmd())),
+ '"{}"'.format(' '.join(client_object.cmd())),
+ '0']
cmd += server_object.post_checks()
cmd += client_object.post_checks()
cmd += ['-C "received HelloRetryRequest message"']
@@ -554,8 +497,10 @@
cert_sig_alg=cert_sig_alg)
client_object.add_named_groups(server_named_group)
- cmd = ['run_test "{}"'.format(name), '"{}"'.format(
- server_object.cmd()), '"{}"'.format(client_object.cmd()), '0']
+ cmd = ['run_test "{}"'.format(name),
+ '"{}"'.format(' '.join(server_object.cmd())),
+ '"{}"'.format(' '.join(client_object.cmd())),
+ '0']
cmd += server_object.post_checks()
cmd += client_object.post_checks()
cmd += server_object.hrr_post_checks(server_named_group)
@@ -660,6 +605,7 @@
SERVER_CLASSES.keys(),
NAMED_GROUP_IANA_VALUE.keys(),
NAMED_GROUP_IANA_VALUE.keys()):
+
if (client == 'mbedTLS' or server == 'mbedTLS') and \
client_named_group != server_named_group:
yield generate_hrr_compat_test(client=client, server=server,