88""" 
99
1010import  argparse 
11+ from  collections  import  namedtuple 
1112import  itertools 
1213import  os 
1314import  platform 
@@ -60,12 +61,15 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
6061                from  exc_value 
6162
6263class  Inputs :
64+     # pylint: disable=too-many-instance-attributes 
6365    """Accumulate information about macros to test. 
66+ 
6467    This includes macro names as well as information about their arguments 
6568    when applicable. 
6669    """ 
6770
6871    def  __init__ (self ):
72+         self .all_declared  =  set ()
6973        # Sets of names per type 
7074        self .statuses  =  set (['PSA_SUCCESS' ])
7175        self .algorithms  =  set (['0xffffffff' ])
@@ -86,11 +90,30 @@ def __init__(self):
8690        self .table_by_prefix  =  {
8791            'ERROR' : self .statuses ,
8892            'ALG' : self .algorithms ,
89-             'CURVE ' : self .ecc_curves ,
90-             'GROUP ' : self .dh_groups ,
93+             'ECC_CURVE ' : self .ecc_curves ,
94+             'DH_GROUP ' : self .dh_groups ,
9195            'KEY_TYPE' : self .key_types ,
9296            'KEY_USAGE' : self .key_usage_flags ,
9397        }
98+         # Test functions 
99+         self .table_by_test_function  =  {
100+             # Any function ending in _algorithm also gets added to 
101+             # self.algorithms. 
102+             'key_type' : [self .key_types ],
103+             'ecc_key_types' : [self .ecc_curves ],
104+             'dh_key_types' : [self .dh_groups ],
105+             'hash_algorithm' : [self .hash_algorithms ],
106+             'mac_algorithm' : [self .mac_algorithms ],
107+             'cipher_algorithm' : [],
108+             'hmac_algorithm' : [self .mac_algorithms ],
109+             'aead_algorithm' : [self .aead_algorithms ],
110+             'key_derivation_algorithm' : [self .kdf_algorithms ],
111+             'key_agreement_algorithm' : [self .ka_algorithms ],
112+             'asymmetric_signature_algorithm' : [],
113+             'asymmetric_signature_wildcard' : [self .algorithms ],
114+             'asymmetric_encryption_algorithm' : [],
115+             'other_algorithm' : [],
116+         }
94117        # macro name -> list of argument names 
95118        self .argspecs  =  {}
96119        # argument name -> list of values 
@@ -99,8 +122,20 @@ def __init__(self):
99122            'tag_length' : ['1' , '63' ],
100123        }
101124
125+     def  get_names (self , type_word ):
126+         """Return the set of known names of values of the given type.""" 
127+         return  {
128+             'status' : self .statuses ,
129+             'algorithm' : self .algorithms ,
130+             'ecc_curve' : self .ecc_curves ,
131+             'dh_group' : self .dh_groups ,
132+             'key_type' : self .key_types ,
133+             'key_usage' : self .key_usage_flags ,
134+         }[type_word ]
135+ 
102136    def  gather_arguments (self ):
103137        """Populate the list of values for macro arguments. 
138+ 
104139        Call this after parsing all the inputs. 
105140        """ 
106141        self .arguments_for ['hash_alg' ] =  sorted (self .hash_algorithms )
@@ -118,6 +153,7 @@ def _format_arguments(name, arguments):
118153
119154    def  distribute_arguments (self , name ):
120155        """Generate macro calls with each tested argument set. 
156+ 
121157        If name is a macro without arguments, just yield "name". 
122158        If name is a macro with arguments, yield a series of 
123159        "name(arg1,...,argN)" where each argument takes each possible 
@@ -145,6 +181,9 @@ def distribute_arguments(self, name):
145181        except  BaseException  as  e :
146182            raise  Exception ('distribute_arguments({})' .format (name )) from  e 
147183
184+     def  generate_expressions (self , names ):
185+         return  itertools .chain (* map (self .distribute_arguments , names ))
186+ 
148187    _argument_split_re  =  re .compile (r' *, *' )
149188    @classmethod  
150189    def  _argument_split (cls , arguments ):
@@ -154,7 +193,7 @@ def _argument_split(cls, arguments):
154193    # Groups: 1=macro name, 2=type, 3=argument list (optional). 
155194    _header_line_re  =  \
156195        re .compile (r'#define +'  + 
157-                    r'(PSA_((?:KEY_ )?[A-Z]+)_\w+)'  + 
196+                    r'(PSA_((?:(?:DH|ECC|KEY)_ )?[A-Z]+)_\w+)'  + 
158197                   r'(?:\(([^\n()]*)\))?' )
159198    # Regex of macro names to exclude. 
160199    _excluded_name_re  =  re .compile (r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z' )
@@ -167,10 +206,6 @@ def _argument_split(cls, arguments):
167206        # Auxiliary macro whose name doesn't fit the usual patterns for 
168207        # auxiliary macros. 
169208        'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH_CASE' ,
170-         # PSA_ALG_ECDH and PSA_ALG_FFDH are excluded for now as the script 
171-         # currently doesn't support them. 
172-         'PSA_ALG_ECDH' ,
173-         'PSA_ALG_FFDH' ,
174209        # Deprecated aliases. 
175210        'PSA_ERROR_UNKNOWN_ERROR' ,
176211        'PSA_ERROR_OCCUPIED_SLOT' ,
@@ -184,6 +219,7 @@ def parse_header_line(self, line):
184219        if  not  m :
185220            return 
186221        name  =  m .group (1 )
222+         self .all_declared .add (name )
187223        if  re .search (self ._excluded_name_re , name ) or  \
188224           name  in  self ._excluded_names :
189225            return 
@@ -200,26 +236,34 @@ def parse_header(self, filename):
200236            for  line  in  lines :
201237                self .parse_header_line (line )
202238
239+     _macro_identifier_re  =  r'[A-Z]\w+' 
240+     def  generate_undeclared_names (self , expr ):
241+         for  name  in  re .findall (self ._macro_identifier_re , expr ):
242+             if  name  not  in self .all_declared :
243+                 yield  name 
244+ 
245+     def  accept_test_case_line (self , function , argument ):
246+         #pylint: disable=unused-argument 
247+         undeclared  =  list (self .generate_undeclared_names (argument ))
248+         if  undeclared :
249+             raise  Exception ('Undeclared names in test case' , undeclared )
250+         return  True 
251+ 
203252    def  add_test_case_line (self , function , argument ):
204253        """Parse a test case data line, looking for algorithm metadata tests.""" 
254+         sets  =  []
205255        if  function .endswith ('_algorithm' ):
206-             # As above, ECDH and FFDH algorithms are excluded for now. 
207-             # Support for them will be added in the future. 
208-             if  'ECDH'  in  argument  or  'FFDH'  in  argument :
209-                 return 
210-             self .algorithms .add (argument )
211-             if  function  ==  'hash_algorithm' :
212-                 self .hash_algorithms .add (argument )
213-             elif  function  in  ['mac_algorithm' , 'hmac_algorithm' ]:
214-                 self .mac_algorithms .add (argument )
215-             elif  function  ==  'aead_algorithm' :
216-                 self .aead_algorithms .add (argument )
217-         elif  function  ==  'key_type' :
218-             self .key_types .add (argument )
219-         elif  function  ==  'ecc_key_types' :
220-             self .ecc_curves .add (argument )
221-         elif  function  ==  'dh_key_types' :
222-             self .dh_groups .add (argument )
256+             sets .append (self .algorithms )
257+             if  function  ==  'key_agreement_algorithm'  and  \
258+                argument .startswith ('PSA_ALG_KEY_AGREEMENT(' ):
259+                 # We only want *raw* key agreement algorithms as such, so 
260+                 # exclude ones that are already chained with a KDF. 
261+                 # Keep the expression as one to test as an algorithm. 
262+                 function  =  'other_algorithm' 
263+         sets  +=  self .table_by_test_function [function ]
264+         if  self .accept_test_case_line (function , argument ):
265+             for  s  in  sets :
266+                 s .add (argument )
223267
224268    # Regex matching a *.data line containing a test function call and 
225269    # its arguments. The actual definition is partly positional, but this 
@@ -233,9 +277,9 @@ def parse_test_cases(self, filename):
233277                if  m :
234278                    self .add_test_case_line (m .group (1 ), m .group (2 ))
235279
236- def  gather_inputs (headers , test_suites ):
280+ def  gather_inputs (headers , test_suites ,  inputs_class = Inputs ):
237281    """Read the list of inputs to test psa_constant_names with.""" 
238-     inputs  =  Inputs ()
282+     inputs  =  inputs_class ()
239283    for  header  in  headers :
240284        inputs .parse_header (header )
241285    for  test_cases  in  test_suites :
@@ -252,8 +296,10 @@ def remove_file_if_exists(filename):
252296    except  OSError :
253297        pass 
254298
255- def  run_c (options , type_word , names ):
256-     """Generate and run a program to print out numerical values for names.""" 
299+ def  run_c (type_word , expressions , include_path = None , keep_c = False ):
300+     """Generate and run a program to print out numerical values for expressions.""" 
301+     if  include_path  is  None :
302+         include_path  =  []
257303    if  type_word  ==  'status' :
258304        cast_to  =  'long' 
259305        printf_format  =  '%ld' 
@@ -278,18 +324,18 @@ def run_c(options, type_word, names):
278324int main(void) 
279325{ 
280326''' )
281-         for  name  in  names :
327+         for  expr  in  expressions :
282328            c_file .write ('    printf("{}\\ n", ({}) {});\n ' 
283-                          .format (printf_format , cast_to , name ))
329+                          .format (printf_format , cast_to , expr ))
284330        c_file .write ('''    return 0; 
285331} 
286332''' )
287333        c_file .close ()
288334        cc  =  os .getenv ('CC' , 'cc' )
289335        subprocess .check_call ([cc ] + 
290-                               ['-I'  +  dir  for  dir  in  options . include ] + 
336+                               ['-I'  +  dir  for  dir  in  include_path ] + 
291337                              ['-o' , exe_name , c_name ])
292-         if  options . keep_c :
338+         if  keep_c :
293339            sys .stderr .write ('List of {} tests kept at {}\n ' 
294340                             .format (type_word , c_name ))
295341        else :
@@ -302,76 +348,101 @@ def run_c(options, type_word, names):
302348NORMALIZE_STRIP_RE  =  re .compile (r'\s+' )
303349def  normalize (expr ):
304350    """Normalize the C expression so as not to care about trivial differences. 
351+ 
305352    Currently "trivial differences" means whitespace. 
306353    """ 
307-     expr  =  re .sub (NORMALIZE_STRIP_RE , '' , expr , len (expr ))
308-     return  expr .strip ().split ('\n ' )
309- 
310- def  do_test (options , inputs , type_word , names ):
311-     """Test psa_constant_names for the specified type. 
312-     Run program on names. 
313-     Use inputs to figure out what arguments to pass to macros that 
314-     take arguments. 
315-     """ 
316-     names  =  sorted (itertools .chain (* map (inputs .distribute_arguments , names )))
317-     values  =  run_c (options , type_word , names )
318-     output  =  subprocess .check_output ([options .program , type_word ] +  values )
319-     outputs  =  output .decode ('ascii' ).strip ().split ('\n ' )
320-     errors  =  [(type_word , name , value , output )
321-               for  (name , value , output ) in  zip (names , values , outputs )
322-               if  normalize (name ) !=  normalize (output )]
323-     return  len (names ), errors 
324- 
325- def  report_errors (errors ):
326-     """Describe each case where the output is not as expected.""" 
327-     for  type_word , name , value , output  in  errors :
328-         print ('For {} "{}", got "{}" (value: {})' 
329-               .format (type_word , name , output , value ))
330- 
331- def  run_tests (options , inputs ):
332-     """Run psa_constant_names on all the gathered inputs. 
333-     Return a tuple (count, errors) where count is the total number of inputs 
334-     that were tested and errors is the list of cases where the output was 
335-     not as expected. 
354+     return  re .sub (NORMALIZE_STRIP_RE , '' , expr )
355+ 
356+ def  collect_values (inputs , type_word , include_path = None , keep_c = False ):
357+     """Generate expressions using known macro names and calculate their values. 
358+ 
359+     Return a list of pairs of (expr, value) where expr is an expression and 
360+     value is a string representation of its integer value. 
336361    """ 
337-     count  =  0 
338-     errors  =  []
339-     for  type_word , names  in  [('status' , inputs .statuses ),
340-                              ('algorithm' , inputs .algorithms ),
341-                              ('ecc_curve' , inputs .ecc_curves ),
342-                              ('dh_group' , inputs .dh_groups ),
343-                              ('key_type' , inputs .key_types ),
344-                              ('key_usage' , inputs .key_usage_flags )]:
345-         c , e  =  do_test (options , inputs , type_word , names )
346-         count  +=  c 
347-         errors  +=  e 
348-     return  count , errors 
362+     names  =  inputs .get_names (type_word )
363+     expressions  =  sorted (inputs .generate_expressions (names ))
364+     values  =  run_c (type_word , expressions ,
365+                    include_path = include_path , keep_c = keep_c )
366+     return  expressions , values 
367+ 
368+ class  Tests :
369+     """An object representing tests and their results.""" 
370+ 
371+     Error  =  namedtuple ('Error' ,
372+                        ['type' , 'expression' , 'value' , 'output' ])
373+ 
374+     def  __init__ (self , options ):
375+         self .options  =  options 
376+         self .count  =  0 
377+         self .errors  =  []
378+ 
379+     def  run_one (self , inputs , type_word ):
380+         """Test psa_constant_names for the specified type. 
381+ 
382+         Run the program on the names for this type. 
383+         Use the inputs to figure out what arguments to pass to macros that 
384+         take arguments. 
385+         """ 
386+         expressions , values  =  collect_values (inputs , type_word ,
387+                                              include_path = self .options .include ,
388+                                              keep_c = self .options .keep_c )
389+         output  =  subprocess .check_output ([self .options .program , type_word ] + 
390+                                          values )
391+         outputs  =  output .decode ('ascii' ).strip ().split ('\n ' )
392+         self .count  +=  len (expressions )
393+         for  expr , value , output  in  zip (expressions , values , outputs ):
394+             if  normalize (expr ) !=  normalize (output ):
395+                 self .errors .append (self .Error (type = type_word ,
396+                                               expression = expr ,
397+                                               value = value ,
398+                                               output = output ))
399+ 
400+     def  run_all (self , inputs ):
401+         """Run psa_constant_names on all the gathered inputs.""" 
402+         for  type_word  in  ['status' , 'algorithm' , 'ecc_curve' , 'dh_group' ,
403+                           'key_type' , 'key_usage' ]:
404+             self .run_one (inputs , type_word )
405+ 
406+     def  report (self , out ):
407+         """Describe each case where the output is not as expected. 
408+ 
409+         Write the errors to ``out``. 
410+         Also write a total. 
411+         """ 
412+         for  error  in  self .errors :
413+             out .write ('For {} "{}", got "{}" (value: {})\n ' 
414+                       .format (error .type , error .expression ,
415+                               error .output , error .value ))
416+         out .write ('{} test cases' .format (self .count ))
417+         if  self .errors :
418+             out .write (', {} FAIL\n ' .format (len (self .errors )))
419+         else :
420+             out .write (' PASS\n ' )
421+ 
422+ HEADERS  =  ['psa/crypto.h' , 'psa/crypto_extra.h' , 'psa/crypto_values.h' ]
423+ TEST_SUITES  =  ['tests/suites/test_suite_psa_crypto_metadata.data' ]
349424
350425def  main ():
351426    parser  =  argparse .ArgumentParser (description = globals ()['__doc__' ])
352427    parser .add_argument ('--include' , '-I' ,
353428                        action = 'append' , default = ['include' ],
354429                        help = 'Directory for header files' )
355-     parser .add_argument ('--program' ,
356-                         default = 'programs/psa/psa_constant_names' ,
357-                         help = 'Program to test' )
358430    parser .add_argument ('--keep-c' ,
359431                        action = 'store_true' , dest = 'keep_c' , default = False ,
360432                        help = 'Keep the intermediate C file' )
361433    parser .add_argument ('--no-keep-c' ,
362434                        action = 'store_false' , dest = 'keep_c' ,
363435                        help = 'Don\' t keep the intermediate C file (default)' )
436+     parser .add_argument ('--program' ,
437+                         default = 'programs/psa/psa_constant_names' ,
438+                         help = 'Program to test' )
364439    options  =  parser .parse_args ()
365-     headers  =  [os .path .join (options .include [0 ], 'psa' , h )
366-                for  h  in  ['crypto.h' , 'crypto_extra.h' , 'crypto_values.h' ]]
367-     test_suites  =  ['tests/suites/test_suite_psa_crypto_metadata.data' ]
368-     inputs  =  gather_inputs (headers , test_suites )
369-     count , errors  =  run_tests (options , inputs )
370-     report_errors (errors )
371-     if  errors  ==  []:
372-         print ('{} test cases PASS' .format (count ))
373-     else :
374-         print ('{} test cases, {} FAIL' .format (count , len (errors )))
440+     headers  =  [os .path .join (options .include [0 ], h ) for  h  in  HEADERS ]
441+     inputs  =  gather_inputs (headers , TEST_SUITES )
442+     tests  =  Tests (options )
443+     tests .run_all (inputs )
444+     tests .report (sys .stdout )
445+     if  tests .errors :
375446        exit (1 )
376447
377448if  __name__  ==  '__main__' :
0 commit comments