Skip to content

Commit 473f8f2

Browse files
committed
Merge branch 'merge_wheel' of https://github.com/leofang/cuda-python into merge_wheel
2 parents 20c5a99 + 4ba0090 commit 473f8f2

File tree

1 file changed

+22
-25
lines changed

1 file changed

+22
-25
lines changed

ci/tools/merge_cuda_core_wheels.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def run_command(cmd: List[str], cwd: Path = None, env: dict = os.environ) -> sub
4848

4949
def merge_wheels(wheels: List[Path], output_dir: Path) -> Path:
5050
"""Merge multiple wheels into a single wheel with version-specific binaries."""
51-
print("\n=== Merging wheels ===")
52-
print(f"Input wheels: {[w.name for w in wheels]}")
51+
print("\n=== Merging wheels ===", file=sys.stderr)
52+
print(f"Input wheels: {[w.name for w in wheels]}", file=sys.stderr)
5353

5454
if len(wheels) == 1:
5555
raise RuntimeError("only one wheel is provided, nothing to merge")
@@ -60,11 +60,11 @@ def merge_wheels(wheels: List[Path], output_dir: Path) -> Path:
6060
extracted_wheels = []
6161

6262
for i, wheel in enumerate(wheels):
63-
print(f"Extracting wheel {i + 1}/{len(wheels)}: {wheel.name}")
63+
print(f"Extracting wheel {i + 1}/{len(wheels)}: {wheel.name}", file=sys.stderr)
6464
# Extract wheel - wheel unpack creates the directory itself
6565
run_command(
6666
[
67-
"python",
67+
sys.executable,
6868
"-m",
6969
"wheel",
7070
"unpack",
@@ -100,18 +100,18 @@ def merge_wheels(wheels: List[Path], output_dir: Path) -> Path:
100100
cuda_version = wheels[i].name.split(".cu")[1].split(".")[0]
101101
base_dir = Path("cuda") / "core" / "experimental"
102102
# Copy from other wheels
103-
print(f" Copying {wheel_dir} to {base_wheel}")
103+
print(f" Copying {wheel_dir} to {base_wheel}", file=sys.stderr)
104104
shutil.copytree(wheel_dir / base_dir, base_wheel / base_dir / f"cu{cuda_version}")
105105

106106
# Overwrite the __init__.py in versioned dirs
107-
open(base_wheel / base_dir / f"cu{cuda_version}" / "__init__.py", "w").close()
107+
os.truncate(base_wheel / base_dir / f"cu{cuda_version}" / "__init__.py", 0)
108108

109109
# The base dir should only contain __init__.py, the include dir, and the versioned dirs
110-
files_to_remove = os.listdir(base_wheel / base_dir)
110+
files_to_remove = os.scandir(base_wheel / base_dir)
111111
for f in files_to_remove:
112-
f_abspath = base_wheel / base_dir / f
113-
if f not in ("__init__.py", "cu12", "cu13", "include"):
114-
if os.path.isdir(f_abspath):
112+
f_abspath = f.path
113+
if f.name not in ("__init__.py", "cu12", "cu13", "include"):
114+
if f.is_dir():
115115
shutil.rmtree(f_abspath)
116116
else:
117117
os.remove(f_abspath)
@@ -120,15 +120,12 @@ def merge_wheels(wheels: List[Path], output_dir: Path) -> Path:
120120
output_dir.mkdir(parents=True, exist_ok=True)
121121

122122
# Create a clean wheel name without CUDA version suffixes
123-
base_wheel_name = wheels[0].name
124-
# Remove any .cu* suffix from the wheel name
125-
if ".cu" in base_wheel_name:
126-
base_wheel_name = base_wheel_name.split(".cu")[0] + ".whl"
123+
base_wheel_name = wheels[0].with_suffix(".whl").name
127124

128-
print(f"Repacking merged wheel as: {base_wheel_name}")
125+
print(f"Repacking merged wheel as: {base_wheel_name}", file=sys.stderr)
129126
run_command(
130127
[
131-
"python",
128+
sys.executable,
132129
"-m",
133130
"wheel",
134131
"pack",
@@ -144,7 +141,7 @@ def merge_wheels(wheels: List[Path], output_dir: Path) -> Path:
144141
raise RuntimeError("Failed to create merged wheel")
145142

146143
merged_wheel = output_wheels[0]
147-
print(f"Successfully merged wheel: {merged_wheel}")
144+
print(f"Successfully merged wheel: {merged_wheel}", file=sys.stderr)
148145
return merged_wheel
149146

150147

@@ -156,32 +153,32 @@ def main():
156153

157154
args = parser.parse_args()
158155

159-
print("cuda.core Wheel Merger")
160-
print("======================")
156+
print("cuda.core Wheel Merger", file=sys.stderr)
157+
print("======================", file=sys.stderr)
161158

162159
# Convert wheel paths to Path objects and validate
163160
wheels = []
164161
for wheel_path in args.wheels:
165162
wheel = Path(wheel_path)
166163
if not wheel.exists():
167-
print(f"Error: Wheel not found: {wheel}")
164+
print(f"Error: Wheel not found: {wheel}", file=sys.stderr)
168165
sys.exit(1)
169166
if not wheel.name.endswith(".whl"):
170-
print(f"Error: Not a wheel file: {wheel}")
167+
print(f"Error: Not a wheel file: {wheel}", file=sys.stderr)
171168
sys.exit(1)
172169
wheels.append(wheel)
173170

174171
if not wheels:
175-
print("Error: No wheels provided")
172+
print("Error: No wheels provided", file=sys.stderr)
176173
sys.exit(1)
177174

178175
output_dir = Path(args.output_dir)
179176

180177
# Check that we have wheel tool available
181178
try:
182-
run_command(["python", "-m", "wheel", "--help"])
183-
except Exception:
184-
print("Error: wheel package not available. Install with: pip install wheel")
179+
import wheel
180+
except ImportError:
181+
print("Error: wheel package not available. Install with: pip install wheel", file=sys.stderr)
185182
sys.exit(1)
186183

187184
# Merge the wheels

0 commit comments

Comments
 (0)