55import re
66
77from function import Function
8- from function import read_functions
8+ from function import read
99
1010select_re = re .compile ('LAPACK_(\w)_SELECT(\d)' )
1111
1212
1313def is_scalar (name , cty , f ):
14- return (
14+ return ( \
1515 'c_char' in cty or
1616 name in [
1717 'abnrm' ,
@@ -20,7 +20,6 @@ def is_scalar(name, cty, f):
2020 'anorm' ,
2121 'bbnrm' ,
2222 'colcnd' ,
23- 'dif' ,
2423 'ihi' ,
2524 'il' ,
2625 'ilo' ,
@@ -45,29 +44,80 @@ def is_scalar(name, cty, f):
4544 'tryrac' ,
4645 'vu' ,
4746 ] or
48- name == 'q' and 'lapack_int' in cty or
49- not (
47+ name in [
48+ 'alpha' ,
49+ ] and (
50+ 'larfg' in f .name
51+ ) or
52+ name in [
53+ 'dif' ,
54+ ] and not (
55+ 'tgsen' in f .name or
56+ 'tgsna' in f .name
57+ ) or
58+ name in [
59+ 'p' ,
60+ ] and not (
61+ 'tgevc' in f .name
62+ ) or
63+ name in [
64+ 'q'
65+ ] and (
66+ 'lapack_int' in cty
67+ ) or
68+ name in [
69+ 'vl' ,
70+ 'vr' ,
71+ ] and not (
5072 'geev' in f .name or
73+ 'ggev' in f .name or
74+ 'hsein' in f .name or
75+ 'tgevc' in f .name or
5176 'tgsna' in f .name or
77+ 'trevc' in f .name or
5278 'trsna' in f .name
53- ) and name in [
54- 'vl' ,
55- 'vr' ,
56- ] or
57- not ('tgevc' in f .name ) and name in [
58- 'p' ,
59- ] or
60- name .startswith ('alpha' ) or
61- name .startswith ('beta' ) or
79+ ) or
80+ name .startswith ('k' ) and not (
81+ 'lapmr' in f .name or
82+ 'lapmt' in f .name
83+ ) or
6284 name .startswith ('inc' ) or
63- name .startswith ('k' ) or
6485 name .startswith ('ld' ) or
6586 name .startswith ('tol' ) or
6687 name .startswith ('vers' )
6788 )
6889
6990
70- def translate_argument (name , cty , f ):
91+ def translate_name (name ):
92+ return name .lower ()
93+
94+
95+ def translate_base_type (cty ):
96+ cty = cty .replace ('__BindgenComplex<f32>' , 'lapack_complex_float' )
97+ cty = cty .replace ('__BindgenComplex<f64>' , 'lapack_complex_double' )
98+ cty = cty .replace ('lapack_float_return' , 'c_float' )
99+ cty = cty .replace ('f32' , 'c_float' )
100+ cty = cty .replace ('f64' , 'c_double' )
101+
102+ if 'c_char' in cty :
103+ return 'u8'
104+ elif 'c_int' in cty :
105+ return 'i32'
106+ elif 'c_float' in cty :
107+ return 'f32'
108+ elif 'c_double' in cty :
109+ return 'f64'
110+ elif 'lapack_complex_float' in cty :
111+ return 'c32'
112+ elif 'lapack_complex_double' in cty :
113+ return 'c64'
114+ elif 'size_t' in cty :
115+ return 'size_t'
116+
117+ assert False , 'cannot translate `{}`' .format (cty )
118+
119+
120+ def translate_signature_type (name , cty , f ):
71121 m = select_re .match (cty )
72122 if m is not None :
73123 if m .group (1 ) == 'S' :
@@ -79,7 +129,7 @@ def translate_argument(name, cty, f):
79129 elif m .group (1 ) == 'Z' :
80130 return 'Select{}C64' .format (m .group (2 ))
81131
82- base = translate_type_base (cty )
132+ base = translate_base_type (cty )
83133 if '*const' in cty :
84134 if is_scalar (name , cty , f ):
85135 return base
@@ -94,30 +144,6 @@ def translate_argument(name, cty, f):
94144 return base
95145
96146
97- def translate_type_base (cty ):
98- cty = cty .replace ('__BindgenComplex<f32>' , 'lapack_complex_float' )
99- cty = cty .replace ('__BindgenComplex<f64>' , 'lapack_complex_double' )
100- cty = cty .replace ('f32' , 'c_float' )
101- cty = cty .replace ('f64' , 'c_double' )
102-
103- if 'c_char' in cty :
104- return 'u8'
105- elif 'c_int' in cty :
106- return 'i32'
107- elif 'c_float' in cty :
108- return 'f32'
109- elif 'c_double' in cty :
110- return 'f64'
111- elif 'lapack_complex_float' in cty :
112- return 'c32'
113- elif 'lapack_complex_double' in cty :
114- return 'c64'
115- elif 'size_t' in cty :
116- return 'libc::c_ulong'
117-
118- assert False , 'cannot translate `{}`' .format (cty )
119-
120-
121147def translate_body_argument (name , rty ):
122148 if rty .startswith ('Select' ):
123149 return 'transmute({})' .format (name )
@@ -154,66 +180,56 @@ def translate_body_argument(name, rty):
154180 elif rty .startswith ('&mut [c' ):
155181 return '{}.as_mut_ptr() as *mut _' .format (name )
156182
157- elif rty . startswith ( 'libc::' ) :
158- return '&{}' . format ( name )
183+ elif rty == 'size_t' :
184+ return name
159185
160186 assert False , 'cannot translate `{}: {}`' .format (name , rty )
161187
162188
163- def translate_return_type (cty ):
164- cty = cty .replace ('lapack_float_return' , 'c_float' )
165- cty = cty .replace ('f64' , 'c_double' )
166-
167- if cty == 'c_int' :
168- return 'i32'
169- elif cty == 'c_float' :
170- return 'f32'
171- elif cty == 'c_double' :
172- return 'f64'
173-
174- assert False , 'cannot translate `{}`' .format (cty )
175-
176-
177- def format_header (f ):
178- args = format_header_arguments (f )
189+ def format_signature (f ):
190+ args = format_signature_arguments (f )
179191 if f .ret is None :
180192 return 'pub unsafe fn {}({})' .format (f .name , args )
181193 else :
182194 return 'pub unsafe fn {}({}) -> {}' .format (f .name , args ,
183- translate_return_type (f .ret ))
184-
185-
186- def format_body (f ):
187- return 'ffi::{}_({})' .format (f .name , format_body_arguments (f ))
195+ translate_base_type (f .ret ))
188196
189197
190- def format_header_arguments (f ):
198+ def format_signature_arguments (f ):
191199 s = []
192- for arg in f .args :
193- s .append ('{}: {}' .format (arg [0 ], translate_argument (* arg , f = f )))
200+ for name , cty in f .args :
201+ name = translate_name (name )
202+ s .append ('{}: {}' .format (name , translate_signature_type (name , cty , f )))
194203 return ', ' .join (s )
195204
196205
206+ def format_body (f ):
207+ return 'ffi::{}_({})' .format (f .name , format_body_arguments (f ))
208+
209+
197210def format_body_arguments (f ):
198211 s = []
199- for arg in f .args :
200- rty = translate_argument (* arg , f = f )
201- s .append (translate_body_argument (arg [0 ], rty ))
212+ for name , cty in f .args :
213+ name = translate_name (name )
214+ rty = translate_signature_type (name , cty , f )
215+ s .append (translate_body_argument (name , rty ))
202216 return ', ' .join (s )
203217
204218
205- def prepare (code ):
219+ def process (code ):
206220 lines = filter (lambda line : not re .match (r'^\s*//.*' , line ),
207221 code .split ('\n ' ))
208222 lines = re .sub (r'\s+' , ' ' , '' .join (lines )).strip ().split (';' )
209223 lines = filter (lambda line : not re .match (r'^\s*$' , line ), lines )
210224 return [Function .parse (line ) for line in lines ]
211225
212226
213- def do (functions ):
227+ def write (functions ):
214228 for f in functions :
229+ if f .name in ['lsame' ]:
230+ continue
215231 print ('\n #[inline]' )
216- print (format_header (f ) + ' {' )
232+ print (format_signature (f ) + ' {' )
217233 print (' ' + format_body (f ) + '\n }' )
218234
219235
@@ -222,4 +238,4 @@ def do(functions):
222238 parser .add_argument ('--sys' , default = 'lapack-sys' )
223239 arguments = parser .parse_args ()
224240 path = os .path .join (arguments .sys , 'src' , 'lapack.rs' )
225- do ( prepare ( read_functions (path )))
241+ write ( process ( read (path )))
0 commit comments