@@ -31,32 +31,29 @@ def generate_file(out_fname, array_name, array_type, array_contents, size):
3131 """Write an array of values to a CC or header file."""
3232 os .makedirs (os .path .dirname (out_fname ), exist_ok = True )
3333 if out_fname .endswith ('.cc' ):
34- out_cc_file = open (out_fname , 'w' )
35- out_cc_file .write ('#include <cstdint>\n \n ' )
36- out_cc_file . write ( '# include "{}" \n \n ' . format (
37- out_fname .split ('genfiles/' )[- 1 ].replace ('.cc' , '.h' )) )
38- out_cc_file .write ('alignas(16) const {} {}[] = {{ ' .format (
39- array_type , array_name ))
40- out_cc_file . write ( array_contents )
41- out_cc_file .write ('}; \n ' )
42- out_cc_file .close ( )
34+ with open (out_fname , 'w' ) as out_cc_file :
35+ out_cc_file .write ('#include <cstdint>\n \n ' )
36+ # Header include path logic, maintaining compatibility with genfiles/ structure.
37+ header_path = out_fname .split ('genfiles/' )[- 1 ].replace ('.cc' , '.h' )
38+ out_cc_file .write ('#include "{}" \n \n ' .format (header_path ))
39+ out_cc_file . write ( 'alignas(16) const {} {}[] = {{' . format (
40+ array_type , array_name ) )
41+ out_cc_file .write (array_contents )
42+ out_cc_file .write ( '}; \n ' )
4343 elif out_fname .endswith ('.h' ):
44- out_hdr_file = open (out_fname , 'w' )
45- out_hdr_file .write ('#include <cstdint>\n \n ' )
46- out_hdr_file .write ('constexpr unsigned int {}_size = {};\n ' .format (
47- array_name , str (size )))
48- out_hdr_file .write ('extern const {} {}[];\n ' .format (
49- array_type , array_name ))
50- out_hdr_file .close ()
44+ with open (out_fname , 'w' ) as out_hdr_file :
45+ out_hdr_file .write ('#include <cstdint>\n \n ' )
46+ out_hdr_file .write ('constexpr unsigned int {}_size = {};\n ' .format (
47+ array_name , str (size )))
48+ out_hdr_file .write ('extern const {} {}[];\n ' .format (
49+ array_type , array_name ))
5150 else :
5251 raise ValueError ('generated file must be end with .cc or .h' )
5352
5453
5554def bytes_to_hexstring (buffer ):
5655 """Convert a byte array to a hex string."""
57- hex_values = [hex (buffer [i ]) for i in range (len (buffer ))]
58- out_string = ',' .join (hex_values )
59- return out_string
56+ return ',' .join ([hex (b ) for b in buffer ])
6057
6158
6259def generate_array (input_fname ):
@@ -92,31 +89,31 @@ def generate_array(input_fname):
9289 data_1d = data .flatten ()
9390 out_string = ',' .join ([str (x ) for x in data_1d ])
9491 return [len (data_1d ), out_string ]
95-
9692 else :
9793 raise ValueError ('input file must be .tflite, .bmp, .wav or .csv' )
9894
9995
100- def get_array_name (input_fname ):
101- # Normalize potential relative path to remove additional dot.
102- abs_fname = os .path .abspath (input_fname )
103- base_array_name = 'g_' + abs_fname .split ('.' )[- 2 ].split ('/' )[- 1 ]
96+ def get_array_name_and_type (input_fname ):
97+ """Return the array name and type for a given input file."""
98+ # Use os.path.basename to correctly handle both Unix and Windows paths.
99+ base_fname = os .path .basename (input_fname )
100+ # Original logic extracted the filename part between the last two dots.
101+ name_parts = base_fname .split ('.' )
102+ name_no_ext = name_parts [- 2 ] if len (name_parts ) >= 2 else base_fname
103+ base_array_name = 'g_' + name_no_ext
104+
104105 if input_fname .endswith ('.tflite' ):
105106 return [base_array_name + '_model_data' , 'unsigned char' ]
106107 elif input_fname .endswith ('.bmp' ):
107108 return [base_array_name + '_image_data' , 'unsigned char' ]
108109 elif input_fname .endswith ('.wav' ):
109110 return [base_array_name + '_audio_data' , 'int16_t' ]
110- elif input_fname .endswith ('_int32.csv' ):
111- return [base_array_name + '_test_data' , 'int32_t' ]
112- elif input_fname .endswith ('_int16.csv' ):
113- return [base_array_name + '_test_data' , 'int16_t' ]
114- elif input_fname .endswith ('_int8.csv' ):
115- return [base_array_name + '_test_data' , 'int8_t' ]
116- elif input_fname .endswith ('_float.csv' ):
117- return [base_array_name + '_test_data' , 'float' ]
118- elif input_fname .endswith ('npy' ):
119- return [base_array_name + '_test_data' , 'float' ]
111+ elif input_fname .endswith (('_int32.csv' , '_int16.csv' , '_int8.csv' , '_float.csv' , '.csv' , '.npy' )):
112+ return [base_array_name + '_test_data' , 'int32_t' if '_int32.csv' in input_fname else
113+ 'int16_t' if '_int16.csv' in input_fname else
114+ 'int8_t' if '_int8.csv' in input_fname else 'float' ]
115+ else :
116+ return [base_array_name + '_data' , 'unsigned char' ]
120117
121118
122119def main ():
@@ -135,7 +132,7 @@ def main():
135132 if args .output .endswith ('.cc' ) or args .output .endswith ('.h' ):
136133 assert len (args .inputs ) == 1
137134 size , cc_array = generate_array (args .inputs [0 ])
138- generated_array_name , array_type = get_array_name (args .inputs [0 ])
135+ generated_array_name , array_type = get_array_name_and_type (args .inputs [0 ])
139136 generate_file (args .output , generated_array_name , array_type , cc_array ,
140137 size )
141138 else :
@@ -144,15 +141,13 @@ def main():
144141 output_base_fname = os .path .join (args .output ,
145142 os .path .splitext (input_file )[0 ])
146143 if input_file .endswith ('.tflite' ):
147- output_base_fname = output_base_fname + '_model_data'
144+ output_base_fname += '_model_data'
148145 elif input_file .endswith ('.bmp' ):
149- output_base_fname = output_base_fname + '_image_data'
146+ output_base_fname += '_image_data'
150147 elif input_file .endswith ('.wav' ):
151- output_base_fname = output_base_fname + '_audio_data'
152- elif input_file .endswith ('.csv' ):
153- output_base_fname = output_base_fname + '_test_data'
154- elif input_file .endswith ('.npy' ):
155- output_base_fname = output_base_fname + '_test_data'
148+ output_base_fname += '_audio_data'
149+ elif input_file .endswith (('.csv' , '.npy' )):
150+ output_base_fname += '_test_data'
156151 else :
157152 raise ValueError (
158153 'input file must be .tflite, .bmp, .wav , .npy or .csv' )
@@ -162,12 +157,12 @@ def main():
162157 print (output_cc_fname )
163158 output_hdr_fname = output_base_fname + '.h'
164159 size , cc_array = generate_array (input_file )
165- generated_array_name , array_type = get_array_name (input_file )
160+ generated_array_name , array_type = get_array_name_and_type (input_file )
166161 generate_file (output_cc_fname , generated_array_name , array_type ,
167162 cc_array , size )
168163 generate_file (output_hdr_fname , generated_array_name , array_type ,
169164 cc_array , size )
170165
171166
172167if __name__ == '__main__' :
173- main ()
168+ main ()
0 commit comments