Skip to content

Commit dec1a10

Browse files
authored
[Fix] Ci windows 2 (#3330)
* Fix * Fix
1 parent 438169a commit dec1a10

2 files changed

Lines changed: 41 additions & 47 deletions

File tree

.bazelrc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ build --workspace_status_command=./tools/workspace_status.sh
2828
build --noenable_bzlmod
2929

3030
# Use the following C++ standard
31-
build --cxxopt -std=c++17
32-
build:windows --cxxopt=/std:c++17
33-
31+
build --cxxopt=-std=c++17
3432
build:windows --cxxopt=/std:c++17
3533

3634
# Common options for --config=ci
@@ -46,6 +44,7 @@ build:ci --verbose_failures
4644
build:ci --test_output=errors
4745

4846
# Windows CI options
47+
build:windows_ci --config=windows
4948
build:windows_ci --curses=no
5049
build:windows_ci --color=no
5150
build:windows_ci --noshow_progress

tensorflow/lite/micro/tools/generate_cc_arrays.py

Lines changed: 39 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -31,32 +31,29 @@ def generate_file(out_fname, array_name, array_type, array_contents, size):
3131
"""Write an array of values to a CC or header file."""
3232
os.makedirs(os.path.dirname(out_fname), exist_ok=True)
3333
if out_fname.endswith('.cc'):
34-
out_cc_file = open(out_fname, 'w')
35-
out_cc_file.write('#include <cstdint>\n\n')
36-
out_cc_file.write('#include "{}"\n\n'.format(
37-
out_fname.split('genfiles/')[-1].replace('.cc', '.h')))
38-
out_cc_file.write('alignas(16) const {} {}[] = {{'.format(
39-
array_type, array_name))
40-
out_cc_file.write(array_contents)
41-
out_cc_file.write('};\n')
42-
out_cc_file.close()
34+
with open(out_fname, 'w') as out_cc_file:
35+
out_cc_file.write('#include <cstdint>\n\n')
36+
# Header include path logic, maintaining compatibility with genfiles/ structure.
37+
header_path = out_fname.split('genfiles/')[-1].replace('.cc', '.h')
38+
out_cc_file.write('#include "{}"\n\n'.format(header_path))
39+
out_cc_file.write('alignas(16) const {} {}[] = {{'.format(
40+
array_type, array_name))
41+
out_cc_file.write(array_contents)
42+
out_cc_file.write('};\n')
4343
elif out_fname.endswith('.h'):
44-
out_hdr_file = open(out_fname, 'w')
45-
out_hdr_file.write('#include <cstdint>\n\n')
46-
out_hdr_file.write('constexpr unsigned int {}_size = {};\n'.format(
47-
array_name, str(size)))
48-
out_hdr_file.write('extern const {} {}[];\n'.format(
49-
array_type, array_name))
50-
out_hdr_file.close()
44+
with open(out_fname, 'w') as out_hdr_file:
45+
out_hdr_file.write('#include <cstdint>\n\n')
46+
out_hdr_file.write('constexpr unsigned int {}_size = {};\n'.format(
47+
array_name, str(size)))
48+
out_hdr_file.write('extern const {} {}[];\n'.format(
49+
array_type, array_name))
5150
else:
5251
raise ValueError('generated file must be end with .cc or .h')
5352

5453

5554
def bytes_to_hexstring(buffer):
5655
"""Convert a byte array to a hex string."""
57-
hex_values = [hex(buffer[i]) for i in range(len(buffer))]
58-
out_string = ','.join(hex_values)
59-
return out_string
56+
return ','.join([hex(b) for b in buffer])
6057

6158

6259
def generate_array(input_fname):
@@ -92,31 +89,31 @@ def generate_array(input_fname):
9289
data_1d = data.flatten()
9390
out_string = ','.join([str(x) for x in data_1d])
9491
return [len(data_1d), out_string]
95-
9692
else:
9793
raise ValueError('input file must be .tflite, .bmp, .wav or .csv')
9894

9995

100-
def get_array_name(input_fname):
101-
# Normalize potential relative path to remove additional dot.
102-
abs_fname = os.path.abspath(input_fname)
103-
base_array_name = 'g_' + abs_fname.split('.')[-2].split('/')[-1]
96+
def get_array_name_and_type(input_fname):
97+
"""Return the array name and type for a given input file."""
98+
# Use os.path.basename to correctly handle both Unix and Windows paths.
99+
base_fname = os.path.basename(input_fname)
100+
# Original logic extracted the filename part between the last two dots.
101+
name_parts = base_fname.split('.')
102+
name_no_ext = name_parts[-2] if len(name_parts) >= 2 else base_fname
103+
base_array_name = 'g_' + name_no_ext
104+
104105
if input_fname.endswith('.tflite'):
105106
return [base_array_name + '_model_data', 'unsigned char']
106107
elif input_fname.endswith('.bmp'):
107108
return [base_array_name + '_image_data', 'unsigned char']
108109
elif input_fname.endswith('.wav'):
109110
return [base_array_name + '_audio_data', 'int16_t']
110-
elif input_fname.endswith('_int32.csv'):
111-
return [base_array_name + '_test_data', 'int32_t']
112-
elif input_fname.endswith('_int16.csv'):
113-
return [base_array_name + '_test_data', 'int16_t']
114-
elif input_fname.endswith('_int8.csv'):
115-
return [base_array_name + '_test_data', 'int8_t']
116-
elif input_fname.endswith('_float.csv'):
117-
return [base_array_name + '_test_data', 'float']
118-
elif input_fname.endswith('npy'):
119-
return [base_array_name + '_test_data', 'float']
111+
elif input_fname.endswith(('_int32.csv', '_int16.csv', '_int8.csv', '_float.csv', '.csv', '.npy')):
112+
return [base_array_name + '_test_data', 'int32_t' if '_int32.csv' in input_fname else
113+
'int16_t' if '_int16.csv' in input_fname else
114+
'int8_t' if '_int8.csv' in input_fname else 'float']
115+
else:
116+
return [base_array_name + '_data', 'unsigned char']
120117

121118

122119
def main():
@@ -135,7 +132,7 @@ def main():
135132
if args.output.endswith('.cc') or args.output.endswith('.h'):
136133
assert len(args.inputs) == 1
137134
size, cc_array = generate_array(args.inputs[0])
138-
generated_array_name, array_type = get_array_name(args.inputs[0])
135+
generated_array_name, array_type = get_array_name_and_type(args.inputs[0])
139136
generate_file(args.output, generated_array_name, array_type, cc_array,
140137
size)
141138
else:
@@ -144,15 +141,13 @@ def main():
144141
output_base_fname = os.path.join(args.output,
145142
os.path.splitext(input_file)[0])
146143
if input_file.endswith('.tflite'):
147-
output_base_fname = output_base_fname + '_model_data'
144+
output_base_fname += '_model_data'
148145
elif input_file.endswith('.bmp'):
149-
output_base_fname = output_base_fname + '_image_data'
146+
output_base_fname += '_image_data'
150147
elif input_file.endswith('.wav'):
151-
output_base_fname = output_base_fname + '_audio_data'
152-
elif input_file.endswith('.csv'):
153-
output_base_fname = output_base_fname + '_test_data'
154-
elif input_file.endswith('.npy'):
155-
output_base_fname = output_base_fname + '_test_data'
148+
output_base_fname += '_audio_data'
149+
elif input_file.endswith(('.csv', '.npy')):
150+
output_base_fname += '_test_data'
156151
else:
157152
raise ValueError(
158153
'input file must be .tflite, .bmp, .wav , .npy or .csv')
@@ -162,12 +157,12 @@ def main():
162157
print(output_cc_fname)
163158
output_hdr_fname = output_base_fname + '.h'
164159
size, cc_array = generate_array(input_file)
165-
generated_array_name, array_type = get_array_name(input_file)
160+
generated_array_name, array_type = get_array_name_and_type(input_file)
166161
generate_file(output_cc_fname, generated_array_name, array_type,
167162
cc_array, size)
168163
generate_file(output_hdr_fname, generated_array_name, array_type,
169164
cc_array, size)
170165

171166

172167
if __name__ == '__main__':
173-
main()
168+
main()

0 commit comments

Comments
 (0)