Skip to content

Commit a79c4fc

Browse files
committed
GPU: Add python script to convert C++ parameter header to csv file
1 parent dddf092 commit a79c4fc

File tree

1 file changed

+153
-0
lines changed

1 file changed

+153
-0
lines changed
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
#!/usr/bin/env python3
2+
3+
import sys
4+
import os
5+
import re
6+
7+
EXCLUDE_PATTERNS = [
8+
r"^GPUCA_LB_GPUTPCCompressionGatherKernels_.*",
9+
r"^GPUCA_LB_GPUTPCNNClusterizerKernels_.*",
10+
r"^GPUCA_LB_GPUTPCDecompressionUtilKernels_.*",
11+
r"^GPUCA_LB_GPUTrackingRefitKernel_.*",
12+
r"^GPUCA_LB_GPUTRDTrackerKernels_.*",
13+
r"^GPUCA_LB_GPUMemClean16",
14+
r"^GPUCA_LB_GPUitoa",
15+
]
16+
17+
def fail(msg):
18+
print(f"Error: {msg}", file=sys.stderr)
19+
sys.exit(1)
20+
21+
def matches_any_pattern(text: str, patterns):
22+
for pattern in patterns:
23+
if re.search(pattern, text):
24+
return True
25+
return False
26+
27+
def parse_header(header_path):
28+
defines = {}
29+
define_re = re.compile(r"#define\s+(\w+)\s+(.+)")
30+
31+
with open(header_path, "r") as f:
32+
for line in f:
33+
line = line.strip()
34+
if not line.startswith("#define"):
35+
continue
36+
37+
m = define_re.match(line)
38+
if not m:
39+
continue
40+
41+
raw_name, value = m.groups()
42+
value = value.strip()
43+
44+
matched_excluded = matches_any_pattern(raw_name, EXCLUDE_PATTERNS)
45+
if matched_excluded:
46+
continue
47+
48+
# Section + stripped name
49+
if raw_name.startswith("GPUCA_LB_"):
50+
section = "LB"
51+
name = raw_name[len("GPUCA_LB_"):]
52+
elif raw_name.startswith("GPUCA_PAR_"):
53+
section = "PAR"
54+
name = raw_name[len("GPUCA_PAR_"):]
55+
else:
56+
continue
57+
58+
# Format value EXACTLY as requested
59+
if re.match(r"^\d+\s*,\s*\d+$", value):
60+
nums = [int(x.strip()) for x in value.split(",")]
61+
formatted = f'"[{nums[0]}, {nums[1]}]"' # exactly one pair of quotes
62+
elif re.match(r"^\d+$", value):
63+
formatted = value
64+
else:
65+
formatted = f'"""{value}"""' # exactly triple quotes
66+
67+
defines[(section, name)] = {
68+
"value": formatted,
69+
"matched": False,
70+
"raw": raw_name,
71+
}
72+
73+
return defines
74+
75+
def process_csv(csv_path, defines):
76+
output_lines = []
77+
current_section = None
78+
79+
with open(csv_path, "r") as f:
80+
lines = f.readlines()
81+
82+
if not lines:
83+
return []
84+
85+
# First row
86+
first = lines[0].rstrip("\n").split(",")
87+
output_lines.append(f"{first[0]},NEW")
88+
89+
for line in lines[1:]:
90+
stripped = line.rstrip("\n")
91+
92+
# Completely empty line → keep empty
93+
if stripped.strip() == "":
94+
output_lines.append("")
95+
continue
96+
97+
parts = stripped.split(",")
98+
key = parts[0].strip()
99+
100+
# Section handling
101+
if key.endswith(":"):
102+
section_name = key[:-1]
103+
if section_name in ("LB", "PAR"):
104+
current_section = section_name
105+
else:
106+
current_section = None
107+
108+
output_lines.append(f"{key},")
109+
continue
110+
111+
# Empty first column
112+
if key == "":
113+
output_lines.append("")
114+
continue
115+
116+
match_key = (current_section, key)
117+
118+
if match_key in defines:
119+
defines[match_key]["matched"] = True
120+
output_lines.append(f"{key},{defines[match_key]['value']}")
121+
else:
122+
# Ensure empty second column
123+
output_lines.append(f"{key},")
124+
125+
return output_lines
126+
127+
def validate_all_matched(defines):
128+
unmatched = [d["raw"] for d in defines.values() if not d["matched"]]
129+
if unmatched:
130+
fail("Unmatched defines: " + ", ".join(unmatched))
131+
132+
def main():
133+
if len(sys.argv) != 3:
134+
fail("Usage: script.py <input.csv> <header.h>")
135+
136+
csv_path = sys.argv[1]
137+
header_path = sys.argv[2]
138+
139+
if not os.path.isfile(csv_path):
140+
fail(f"CSV file does not exist: {csv_path}")
141+
142+
if not os.path.isfile(header_path):
143+
fail(f"Header file does not exist: {header_path}")
144+
145+
defines = parse_header(header_path)
146+
output_lines = process_csv(csv_path, defines)
147+
validate_all_matched(defines)
148+
149+
for line in output_lines:
150+
print(line)
151+
152+
if __name__ == "__main__":
153+
main()

0 commit comments

Comments
 (0)