blob: c9e6ec337a9b6c77918fc5e60fe848ecb5a48429 [file] [log] [blame]
Gilles Peskinee7c44552021-01-25 21:40:45 +01001"""Collect macro definitions from header files.
2"""
3
4# Copyright The Mbed TLS Contributors
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the "License"); you may
8# not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11# http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
15# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18
Gilles Peskine22fcf1b2021-03-10 01:02:39 +010019import itertools
Gilles Peskinee7c44552021-01-25 21:40:45 +010020import re
Gilles Peskine22fcf1b2021-03-10 01:02:39 +010021from typing import Dict, Iterable, Iterator, List, Set
22
23
class PSAMacroEnumerator:
    """Information about constructors of various PSA Crypto types.

    This includes macro names as well as information about their arguments
    when applicable.

    This class only provides ways to enumerate expressions that evaluate to
    values of the covered types. Derived classes are expected to populate
    the set of known constructors of each kind, as well as populate
    `self.arguments_for` for arguments that are not of a kind that is
    enumerated here.
    """

    def __init__(self) -> None:
        """Set up an empty set of known constructor macros.
        """
        self.statuses = set() #type: Set[str]
        self.algorithms = set() #type: Set[str]
        self.ecc_curves = set() #type: Set[str]
        self.dh_groups = set() #type: Set[str]
        self.key_types = set() #type: Set[str]
        self.key_usage_flags = set() #type: Set[str]
        self.hash_algorithms = set() #type: Set[str]
        self.mac_algorithms = set() #type: Set[str]
        self.ka_algorithms = set() #type: Set[str]
        self.kdf_algorithms = set() #type: Set[str]
        self.aead_algorithms = set() #type: Set[str]
        # macro name -> list of argument names
        self.argspecs = {} #type: Dict[str, List[str]]
        # argument name -> list of values
        # The length arguments start empty; derived classes must fill them
        # before distribute_arguments() is used on macros taking them.
        self.arguments_for = {
            'mac_length': [],
            'min_mac_length': [],
            'tag_length': [],
            'min_tag_length': [],
        } #type: Dict[str, List[str]]

    def gather_arguments(self) -> None:
        """Populate the list of values for macro arguments.

        Call this after parsing all the inputs.
        """
        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
        self.arguments_for['curve'] = sorted(self.ecc_curves)
        self.arguments_for['group'] = sorted(self.dh_groups)

    @staticmethod
    def _format_arguments(name: str, arguments: Iterable[str]) -> str:
        """Format a macro call with arguments."""
        return name + '(' + ', '.join(arguments) + ')'

    _argument_split_re = re.compile(r' *, *')
    @classmethod
    def _argument_split(cls, arguments: str) -> List[str]:
        """Split a comma-separated argument list, dropping spaces around commas."""
        return cls._argument_split_re.split(arguments)

    def distribute_arguments(self, name: str) -> Iterator[str]:
        """Generate macro calls with each tested argument set.

        If name is a macro without arguments, just yield "name".
        If name is a macro with arguments, yield a series of
        "name(arg1,...,argN)" where each argument takes each possible
        value at least once.

        The first call uses the first value of every argument. Each later
        call changes exactly one argument away from that baseline.

        Raises Exception (chained to the original error) if the macro's
        arguments cannot be enumerated, e.g. an unknown argument name or
        an empty list of values.
        """
        try:
            if name not in self.argspecs:
                yield name
                return
            argspec = self.argspecs[name]
            if not argspec:
                yield name + '()'
                return
            argument_lists = [self.arguments_for[arg] for arg in argspec]
            arguments = [values[0] for values in argument_lists]
            yield self._format_arguments(name, arguments)
            for i, values in enumerate(argument_lists):
                for value in values[1:]:
                    arguments[i] = value
                    yield self._format_arguments(name, arguments)
                # Restore argument i to its own baseline value before
                # varying the next argument.
                arguments[i] = values[0]
        except Exception as e:
            # Deliberately not BaseException: GeneratorExit must propagate
            # so that closing a partially consumed generator works, and
            # KeyboardInterrupt must not be disguised as an ordinary error.
            raise Exception('distribute_arguments({})'.format(name)) from e

    def generate_expressions(self, names: Iterable[str]) -> Iterator[str]:
        """Generate expressions covering values constructed from the given names.

        `names` can be any iterable collection of macro names.

        For example:
        * ``generate_expressions(['PSA_ALG_CMAC', 'PSA_ALG_HMAC'])``
          generates ``'PSA_ALG_CMAC'`` as well as ``'PSA_ALG_HMAC(h)'`` for
          every known hash algorithm ``h``.
        * ``macros.generate_expressions(macros.key_types)`` generates all
          key types.
        """
        return itertools.chain.from_iterable(
            map(self.distribute_arguments, names))
127
Gilles Peskinee7c44552021-01-25 21:40:45 +0100128
class PSAMacroCollector:
    """Collect PSA crypto macro definitions from C header files.
    """

    def __init__(self, include_intermediate: bool = False) -> None:
        """Set up an object to collect PSA macro definitions.

        Call the read_file method of the constructed object on each header file.

        * include_intermediate: if true, include intermediate macros such as
          PSA_XXX_BASE that do not designate semantic values.
        """
        self.include_intermediate = include_intermediate
        self.statuses = set() #type: Set[str]
        self.key_types = set() #type: Set[str]
        # constructor macro name -> name of the matching IS_ tester macro
        self.key_types_from_curve = {} #type: Dict[str, str]
        self.key_types_from_group = {} #type: Dict[str, str]
        self.ecc_curves = set() #type: Set[str]
        self.dh_groups = set() #type: Set[str]
        self.algorithms = set() #type: Set[str]
        self.hash_algorithms = set() #type: Set[str]
        self.ka_algorithms = set() #type: Set[str]
        # constructor macro name -> name of the matching IS_ tester macro
        self.algorithms_from_hash = {} #type: Dict[str, str]
        self.key_usages = set() #type: Set[str]

    def is_internal_name(self, name: str) -> bool:
        """Whether this is an internal macro. Internal macros will be skipped.

        Names ending in ``_FLAG`` or ``_MASK`` are always internal. Unless
        ``include_intermediate`` was set, names ending in ``_BASE`` or
        ``_NONE`` and names containing ``_CATEGORY_`` are internal as well.
        """
        if not self.include_intermediate:
            if name.endswith(('_BASE', '_NONE')):
                return True
            if '_CATEGORY_' in name:
                return True
        return name.endswith(('_FLAG', '_MASK'))

    # "#define" followed by a macro name with either no parameters
    # or a single parameter and a non-empty expansion.
    # Grab the macro name in group 1, the parameter name if any in group 2
    # and the expansion in group 3.
    _define_directive_re = re.compile(r'\s*#\s*define\s+(\w+)' +
                                      r'(?:\s+|\((\w+)\)\s*)' +
                                      r'(.+)')
    _deprecated_definition_re = re.compile(r'\s*MBEDTLS_DEPRECATED')

    def read_line(self, line: str) -> None:
        """Parse a C header line and record the PSA identifier it defines if any.

        This function analyzes lines that start with "#define PSA_"
        (up to non-significant whitespace) and skips all non-matching lines.
        Macros with more than one parameter are not matched and are skipped.
        """
        # pylint: disable=too-many-branches
        m = re.match(self._define_directive_re, line)
        if not m:
            return
        name, parameter, expansion = m.groups()
        # Strip single-line comments so they don't confuse the checks below.
        expansion = re.sub(r'/\*.*?\*/|//.*', r' ', expansion)
        if re.match(self._deprecated_definition_re, expansion):
            # Skip deprecated values, which are assumed to be
            # backward compatibility aliases that share
            # numerical values with non-deprecated values.
            return
        if self.is_internal_name(name):
            # Macro only to build actual values
            return
        elif (name.startswith('PSA_ERROR_') or name == 'PSA_SUCCESS') \
           and not parameter:
            self.statuses.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and not parameter:
            self.key_types.add(name)
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'curve':
            # 13 == len('PSA_KEY_TYPE_'): insert 'IS_' after the prefix.
            self.key_types_from_curve[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'group':
            self.key_types_from_group[name] = name[:13] + 'IS_' + name[13:]
        elif name.startswith('PSA_ECC_FAMILY_') and not parameter:
            self.ecc_curves.add(name)
        elif name.startswith('PSA_DH_FAMILY_') and not parameter:
            self.dh_groups.add(name)
        elif name.startswith('PSA_ALG_') and not parameter:
            if name in ['PSA_ALG_ECDSA_BASE',
                        'PSA_ALG_RSA_PKCS1V15_SIGN_BASE']:
                # Ad hoc skipping of duplicate names for some numerical values
                return
            self.algorithms.add(name)
            # Ad hoc detection of hash algorithms
            if re.search(r'0x020000[0-9A-Fa-f]{2}', expansion):
                self.hash_algorithms.add(name)
            # Ad hoc detection of key agreement algorithms
            if re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion):
                self.ka_algorithms.add(name)
        elif name.startswith('PSA_ALG_') and parameter == 'hash_alg':
            # 8 == len('PSA_ALG_'): insert the tester infix after the prefix.
            if name in ['PSA_ALG_DSA', 'PSA_ALG_ECDSA']:
                # A naming irregularity
                tester = name[:8] + 'IS_RANDOMIZED_' + name[8:]
            else:
                tester = name[:8] + 'IS_' + name[8:]
            self.algorithms_from_hash[name] = tester
        elif name.startswith('PSA_KEY_USAGE_') and not parameter:
            self.key_usages.add(name)
        else:
            # Other macro without parameter
            return

    _nonascii_re = re.compile(rb'[^\x00-\x7f]+')
    _continued_line_re = re.compile(rb'\\\r?\n\Z')
    def read_file(self, header_file: Iterable[bytes]) -> None:
        """Parse a C header and record the PSA macros it defines.

        ``header_file`` is an iterable of byte strings, one per line and
        including the line terminator, such as a file opened in binary mode
        or a list of lines. A line ending with a backslash is joined with
        the following line before parsing. Non-ASCII bytes are discarded.
        """
        # iter() lets next() below work on any iterable, not just iterators.
        # For a file object iter() returns the file itself.
        lines = iter(header_file)
        for line in lines:
            m = re.search(self._continued_line_re, line)
            while m:
                # At end of input a trailing backslash-newline has nothing
                # to join with; drop it rather than leaking StopIteration.
                cont = next(lines, b'')
                line = line[:m.start(0)] + cont
                m = re.search(self._continued_line_re, line)
            line = re.sub(self._nonascii_re, rb'', line).decode('ascii')
            self.read_line(line)
239 self.read_line(line)