Address review comments: split out invalid-argument test cases

Signed-off-by: Przemyslaw Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py
index 45d940e..4c8143f 100755
--- a/tests/scripts/generate_psa_tests.py
+++ b/tests/scripts/generate_psa_tests.py
@@ -133,7 +133,7 @@
         return constructors
 
 
-def test_case_for_key_type_not_supported_invalid_arg(
+def test_case_for_key_type_not_supported(
         verb: str, key_type: str, bits: int,
         dependencies: List[str],
         *args: str,
@@ -148,20 +148,37 @@
     adverb = 'not' if dependencies else 'never'
     if param_descr:
         adverb = param_descr + ' ' + adverb
-    if (verb == "generate") and ("PUBLIC" in short_key_type):
-        tc.set_description('PSA {} {} {}-bit invalid argument'
-                           .format(verb, short_key_type, bits))
-        tc.set_function(verb + '_invalid_arg')
-    else:
-        tc.set_description('PSA {} {} {}-bit {} supported'
-                           .format(verb, short_key_type, bits, adverb))
-        tc.set_function(verb + '_not_supported')
+    tc.set_description('PSA {} {} {}-bit {} supported'
+                       .format(verb, short_key_type, bits, adverb))
+    tc.set_dependencies(dependencies)
+    tc.set_function(verb + '_not_supported')
+    tc.set_arguments([key_type] + list(args))
+    return tc
+
+def test_case_for_key_type_invalid_argument(
+        verb: str, key_type: str, bits: int,
+        dependencies: List[str],
+        *args: str,
+        param_descr: str = ''
+) -> test_case.TestCase:
+    """Return one test case exercising a key creation method
+    for an invalid argument when the key type is a public key.
+    """
+    hack_dependencies_not_implemented(dependencies)
+    tc = test_case.TestCase()
+    short_key_type = re.sub(r'PSA_(KEY_TYPE|ECC_FAMILY)_', r'', key_type)
+    adverb = 'not' if dependencies else 'never'
+    if param_descr:
+        adverb = param_descr + ' ' + adverb
+    tc.set_description('PSA {} {} {}-bit invalid argument'
+                       .format(verb, short_key_type, bits))
+    tc.set_function(verb + '_invalid_argument')
     tc.set_dependencies(dependencies)
     tc.set_arguments([key_type] + list(args))
     return tc
 
 class NotSupported:
-    """Generate test cases for when something is not supported."""
+    """Generate test cases for when something is not supported or an argument is invalid."""
 
     def __init__(self, info: Information) -> None:
         self.constructors = info.constructors
@@ -176,11 +193,13 @@
             param: Optional[int] = None,
             param_descr: str = '',
     ) -> Iterator[test_case.TestCase]:
-        """Return test cases exercising key creation when the given type is unsupported.
+        """Return test cases exercising key creation when the given type is unsupported
+        or an argument is invalid.
 
         If param is present and not None, emit test cases conditioned on this
         parameter not being supported. If it is absent or None, emit test cases
-        conditioned on the base type not being supported.
+        conditioned on the base type not being supported. If the key type is a
+        public key, emit an invalid-argument test case for key generation instead.
         """
         if kt.name in self.ALWAYS_SUPPORTED:
             # Don't generate test cases for key types that are always supported.
@@ -197,7 +216,7 @@
         else:
             generate_dependencies = import_dependencies
         for bits in kt.sizes_to_test():
-            yield test_case_for_key_type_not_supported_invalid_arg(
+            yield test_case_for_key_type_not_supported(
                 'import', kt.expression, bits,
                 finish_family_dependencies(import_dependencies, bits),
                 test_case.hex_string(kt.key_material(bits)),
@@ -208,12 +227,20 @@
                 # supported or not depending on implementation capabilities,
                 # only generate the test case once.
                 continue
-            yield test_case_for_key_type_not_supported_invalid_arg(
-                'generate', kt.expression, bits,
-                finish_family_dependencies(generate_dependencies, bits),
-                str(bits),
-                param_descr=param_descr,
-            )
+            if kt.name.endswith('_PUBLIC_KEY'):
+                yield test_case_for_key_type_invalid_argument(
+                    'generate', kt.expression, bits,
+                    finish_family_dependencies(generate_dependencies, bits),
+                    str(bits),
+                    param_descr=param_descr,
+                )
+            else:
+                yield test_case_for_key_type_not_supported(
+                    'generate', kt.expression, bits,
+                    finish_family_dependencies(generate_dependencies, bits),
+                    str(bits),
+                    param_descr=param_descr,
+                )
             # To be added: derive
 
     ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
@@ -234,7 +261,6 @@
                 yield from self.test_cases_for_key_type_not_supported(
                     kt, 0, param_descr='curve')
 
-
 class StorageKey(psa_storage.Key):
     """Representation of a key for storage format testing."""