
Commit b7f2418

Merge pull request #18 from Gleghorn-Lab/cleaning_up
Cleaning up
2 parents 453da5d + 95c20b8 commit b7f2418

5 files changed

Lines changed: 306 additions & 33 deletions


README.md

Lines changed: 71 additions & 19 deletions
@@ -42,34 +42,83 @@ Another limitation of traditional pLM training lies in MLM itself, which results

### Quick Start

+Many popular HPC platforms are missing the Python headers (`Python.h`), which breaks `torch.compile`. To fix this, run the following:
+
+**Debian/Ubuntu:**
+
+```bash
+sudo apt-get update
+sudo apt-get install -y python3.12-dev build-essential
+```
+
+If python3.12-dev is not found: `sudo apt-get install -y python3-dev build-essential`
+
+**Fedora/RHEL:**
+
+```bash
+sudo dnf groupinstall -y "Development Tools"
+sudo dnf install -y python3-devel
+```
+
+**openSUSE:**
+
+```bash
+sudo zypper install -y python3-devel gcc gcc-c++ make
+```
+
+**Arch:**
+
+```bash
+sudo pacman -Sy --noconfirm base-devel python
+```
+
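Whichever distribution you are on, you can confirm the headers are actually visible to the interpreter that will run training. A minimal check using only the standard library (not part of the diff above):

```python
import os
import sysconfig

# The include directory reported by the active interpreter is where
# torch.compile's helper extensions expect to find Python.h.
include_dir = sysconfig.get_paths()["include"]
header = os.path.join(include_dir, "Python.h")
print(f"include dir: {include_dir}")
print(f"Python.h present: {os.path.exists(header)}")
```

If this prints `False`, install the headers for your distribution as shown above and re-run.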
```bash
git clone https://github.com/Synthyra/SpeedrunningPLMs.git
cd SpeedrunningPLMs
-pip install huggingface_hub
-python data/download_omgprot50.py # Add --num_chunks 100 to download less data for smaller runs
```

-### ARM64 Systems (GH200)
+We offer a `docker` or Python `venv` option for running the code.
+
+#### Python venv
+
+```bash
+chmod +x setup_plm.sh
+./setup_plm.sh
+source ~/plm_venv/bin/activate
+```
+
+**Download data**
+
+```bash
+python data/download_data.py --data_name uniref50 --num_chunks 10
+```
+
+`--data_name` can be uniref50, omg_prot50, or og_prot90, which have varying numbers of chunks. Each chunk is ~100 million ESM2 tokens.
+`--num_chunks 500` will download everything.
+
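The chunk counts translate directly into dataset size. A quick back-of-the-envelope sketch (the ~100M tokens/chunk figure is from the note above; exact chunk counts vary by dataset):

```python
# Approximate total tokens as a function of --num_chunks,
# given ~100M ESM2 tokens per chunk.
TOKENS_PER_CHUNK = 100_000_000

for num_chunks in (10, 100, 500):
    total = num_chunks * TOKENS_PER_CHUNK
    print(f"{num_chunks:>3} chunks ~= {total / 1e9:.0f}B tokens")
```

With `--num_chunks 10` as in the example command, that is roughly 1B tokens.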
+**Train model**

```bash
-pip install -r requirements.txt -U
-pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128 -U
torchrun --standalone --nproc_per_node=NUM_GPUS_ON_YOUR_SYSTEM train.py
```

-### Docker Installation (Non-ARM64 Systems)
+or
+
+```bash
+python -m train
+```
+
+#### Docker

```bash
-git clone https://github.com/Synthyra/SpeedrunningPLMs.git
-cd SpeedrunningPLMs
sudo docker build -t speedrun_plm .
sudo docker run --gpus all --shm-size=128g -v ${PWD}:/workspace speedrun_plm \
torchrun --standalone --nproc_per_node=NUM_GPUS_ON_YOUR_SYSTEM train.py \
-  --token YOUR_HUGGINGFACE_TOKEN \
-  --wandb_token YOUR_WANDB_TOKEN
+  --token YOUR_HUGGINGFACE_TOKEN \
+  --wandb_token YOUR_WANDB_TOKEN \
+  --yaml_path example_yamls/default.yaml
+# --token: a write token is required to save to the Hugging Face hub
+# --wandb_token: optional
+# --yaml_path: point to your experiment yaml
```

-> **Note for ARM64 (GH200) Systems**: The Docker image currently experiences compatibility issues on ARM64 systems due to Triton version conflicts that break `torch.compile`. If you have a solution for this issue, please open an issue or pull request.
+**Note for ARM64 (for example GH200) systems**:
+
+The Docker image currently experiences compatibility issues on ARM64 systems due to Triton version conflicts that break `torch.compile`. If you have a solution for this issue, please open an issue or pull request.

## Running Experiments

@@ -161,18 +210,21 @@ Batch sizes of 8×64×1024 (524,288) or 4×64×1024 (262,144) tokens have demonstrated

Our optimized trainer and dataloader incorporate prefetching and multiple workers per GPU to accelerate data handling, with masking performed at the data loading stage. This results in improved throughput, particularly beneficial for systems with slower disk I/O.

-**Training Throughput** (Default model: 133M parameters, 24 blocks, UNet + Value embeddings, 768 hidden size):
+**Training Throughput**
+
+(Default model: 133M parameters, 24 blocks, UNet + Value embeddings, 768 hidden size):

-| Hardware | Tokens/Second |
-|----------|---------------|
-| 1×H100 | 275,900 |
-| 1×GH200 | 1,011,800 |
-| 4×A100 80GB PCIe Gen4 | 340,700 |
-| 8×H100 SXM5 | 2,149,500 |
+| Hardware | Vendor | Cost/Hour | Tokens/Second |
+|----------|--------|-----------|---------------|
+| 1 × H100 80GB SXM5, 26 vCPUs | Lambda Labs | $3.29 | 275,900 |
+| 1 × H200 142GB NVLink, 16 vCPUs | Nebius | $3.64 | 327,680 |
+| 1 × GH200 96GB ARM64, 64 vCPUs | Lambda Labs | $1.49 | 1,011,800 |
+| 4 × A100 80GB PCIe Gen4, 96 vCPUs | Azure | $18.36 | 340,700 |
+| 8 × H100 80GB SXM5, 208 vCPUs | Lambda Labs | $23.92 | 2,149,500 |
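Throughput and hourly price together determine the metric that actually matters: dollars per token. A small sketch collapsing the table above into cost per billion tokens (numbers copied from the table; cloud pricing is of course subject to change):

```python
# (hardware, $/hour, tokens/second) copied from the throughput table.
rows = [
    ("1 x H100 80GB SXM5", 3.29, 275_900),
    ("1 x H200 142GB NVLink", 3.64, 327_680),
    ("1 x GH200 96GB ARM64", 1.49, 1_011_800),
    ("4 x A100 80GB PCIe Gen4", 18.36, 340_700),
    ("8 x H100 80GB SXM5", 23.92, 2_149_500),
]
for name, usd_per_hour, tokens_per_second in rows:
    # $/hour divided by tokens/hour, scaled to one billion tokens
    usd_per_billion = usd_per_hour / (tokens_per_second * 3600) * 1e9
    print(f"{name}: ${usd_per_billion:.2f} per billion tokens")
```

By this measure the GH200 is the clear outlier at roughly $0.41 per billion tokens, consistent with its outsized throughput in the table.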

### Cost Analysis

-Based on current performance metrics, training ESM2-150M equivalent (2M token batch size, 500K steps) would require approximately 129 hours at $3,091 using 8×H100 systems (Lambda pricing as of June 2025). This represents a significant improvement over the estimated $46,000 cost for ESM2-150M training via AWS in 2022.
+Based on current performance metrics, training an ESM2-150M equivalent with the old optimizer and architecture (2M token batch size, 500K steps) would require approximately 129 hours at $3,091 using 8×H100 systems (Lambda pricing as of June 2025). This represents a significant improvement over the estimated $46,000 cost for ESM2-150M training via AWS in 2022. With our improvements to the architecture, data, and optimizer, this cost drops dramatically further.
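The 129-hour / $3,091 figure follows directly from the batch size, the step count, and the 8×H100 row of the throughput table:

```python
# Reproduce the ESM2-150M-equivalent training cost quoted above.
tokens_per_step = 2_000_000             # 2M-token batch size
steps = 500_000                         # 500K steps
total_tokens = tokens_per_step * steps  # 1e12 tokens

tokens_per_second = 2_149_500           # 8 x H100 SXM5 (throughput table)
usd_per_hour = 23.92                    # Lambda pricing, June 2025

hours = total_tokens / tokens_per_second / 3600
cost = hours * usd_per_hour
print(f"~{hours:.0f} hours, ~${cost:,.0f}")  # ~129 hours, ~$3,091
```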

Memory and disk I/O remain primary bottlenecks on some systems, as evidenced by the GH200's superior performance. Further optimizations to data loading and prefetching may yield additional improvements.

data/download_data.py

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@ def get(fname, data_name):
    local_dir = os.path.join(os.path.dirname(__file__), data_name)
    if not os.path.exists(os.path.join(local_dir, fname)):
        try:
+            print(f"Downloading {fname} from Synthyra/{data_name}_packed")
            hf_hub_download(repo_id=f"Synthyra/{data_name}_packed", filename=fname, repo_type="dataset", local_dir=local_dir)
        except Exception as e:
            print(f"Error downloading {fname}: {e}")
@@ -17,7 +18,7 @@ def get(fname, data_name):

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download data from huggingface")
-    parser.add_argument("-d", "--data_name", type=str, default="uniref50", help="Name of the dataset")
+    parser.add_argument("-d", "--data_name", type=str, default="uniref50", help="Name of the dataset, uniref50, omg_prot50, or og_prot90")
    parser.add_argument("-n", "--num_chunks", type=int, default=100, help="Number of chunks to download")
    # each chunk is 100M tokens
    args = parser.parse_args()
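For illustration, the `get()` helper above could be driven by a loop like the following. The chunk filename pattern here is an assumption (inferred from globs such as `data/uniref50/uniref50_train_*.bin` elsewhere in the repo), and the `fetch` parameter stands in for `get()` so the sketch stays self-contained:

```python
from typing import Callable

def download_chunks(data_name: str, num_chunks: int,
                    fetch: Callable[[str, str], None]) -> list[str]:
    # Build hypothetical chunk file names and hand each one to a
    # fetch function (e.g. the get() helper above).
    names = [f"{data_name}_train_{i}.bin" for i in range(num_chunks)]
    for name in names:
        fetch(name, data_name)
    return names

# Dry run with a no-op fetch:
print(download_chunks("uniref50", 3, lambda fname, dname: None))
# ['uniref50_train_0.bin', 'uniref50_train_1.bin', 'uniref50_train_2.bin']
```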

example_yamls/default.yaml

Lines changed: 0 additions & 3 deletions
@@ -28,9 +28,6 @@ token_dropout: false
bfloat16: true

# Data Configuration
-input_bin: "data/uniref50/uniref50_train_*.bin"
-input_valid_bin: "data/uniref50/uniref50_valid_*.bin"
-input_test_bin: "data/uniref50/uniref50_test_*.bin"
mlm: false # Masked Language Modeling
mask_rate: 0.2
mask_rate_schedule: true

setup_plm.sh

Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
+#!/bin/bash
+
+# chmod +x setup_plm.sh
+# ./setup_plm.sh
+
+# Strict mode for safer scripting
+set -euo pipefail
+
+echo "Setting up Python virtual environment for PLM training..."
+
+# Configurable variables (can be overridden via environment)
+: "${VENV_DIR:=$HOME/plm_venv}"
+: "${PYTORCH_CUDA_URL:=https://download.pytorch.org/whl/cu128}"
+
+# Remove any existing venv at the target path to ensure a clean setup
+echo "Removing existing venv at $VENV_DIR (if any)..."
+if [ -n "${VIRTUAL_ENV:-}" ] && [ "${VIRTUAL_ENV}" = "${VENV_DIR}" ]; then
+    deactivate || true
+fi
+rm -rf "$VENV_DIR"
+
+# Create a fresh virtual environment
+python3 -m venv "$VENV_DIR"
+
+# Activate virtual environment
+source "$VENV_DIR/bin/activate"
+
+# Update pip and setuptools
+echo "Upgrading pip and setuptools..."
+pip install --upgrade pip setuptools wheel
+
+# Install torch and torchvision (CUDA wheel index can be overridden)
+echo "Installing torch and torchvision from: $PYTORCH_CUDA_URL"
+pip install --force-reinstall torch torchvision --index-url "$PYTORCH_CUDA_URL"
+
+# Install project requirements
+echo "Installing requirements..."
+pip install -r requirements.txt
+
+# Ensure ninja is available for Triton/Inductor builds
+# (the check runs in the if-condition so `set -e` does not abort on a miss)
+if ! python - >/dev/null 2>&1 <<'PY'
+import importlib.util
+raise SystemExit(0 if importlib.util.find_spec('ninja') else 1)
+PY
+then
+    echo "Installing ninja..."
+    pip install --upgrade ninja
+fi
+
+# Check for system build deps (Python.h, gcc) and optionally install if permitted
+echo "Checking system build dependencies..."
+PY_VER=$(python - <<'PY'
+import sys
+print(f"{sys.version_info.major}.{sys.version_info.minor}")
+PY
+)
+PY_INCLUDE_DIR=$(python - <<'PY'
+import sysconfig
+print(sysconfig.get_paths()["include"])
+PY
+)
+
+if ! command -v gcc >/dev/null 2>&1; then
+    echo "Warning: gcc is not installed. torch.compile may fail to build extensions."
+    echo "Install a compiler toolchain (e.g., build-essential on Debian/Ubuntu)."
+fi
+
+if [ ! -f "$PY_INCLUDE_DIR/Python.h" ]; then
+    echo "Warning: Python.h not found at: $PY_INCLUDE_DIR"
+    echo "torch.compile may fail to build small helper extensions."
+    if [ "${INSTALL_SYSTEM_DEPS:-0}" = "1" ]; then
+        echo "Attempting to install Python development headers (requires sudo)..."
+        if command -v apt-get >/dev/null 2>&1; then
+            sudo -n apt-get update || true
+            sudo -n apt-get install -y "python${PY_VER}-dev" python3-dev build-essential || true
+        elif command -v dnf >/dev/null 2>&1; then
+            sudo -n dnf groupinstall -y "Development Tools" || true
+            sudo -n dnf install -y python3-devel || true
+        elif command -v yum >/dev/null 2>&1; then
+            sudo -n yum groupinstall -y "Development Tools" || true
+            sudo -n yum install -y python3-devel || true
+        elif command -v zypper >/dev/null 2>&1; then
+            sudo -n zypper install -y python3-devel gcc gcc-c++ make || true
+        elif command -v pacman >/dev/null 2>&1; then
+            sudo -n pacman -Sy --noconfirm base-devel python || true
+        fi
+    else
+        echo "To install headers:"
+        echo "- Debian/Ubuntu: sudo apt-get install -y python3-dev python${PY_VER}-dev build-essential"
+        echo "- Fedora/RHEL: sudo dnf install -y python3-devel @development-tools"
+        echo '- CentOS: sudo yum groupinstall -y "Development Tools" && sudo yum install -y python3-devel'
+        echo "- openSUSE: sudo zypper install -y python3-devel gcc gcc-c++ make"
+        echo "- Arch: sudo pacman -Sy --noconfirm base-devel python"
+        echo "Then re-run this script. You can also set INSTALL_SYSTEM_DEPS=1 to let the script attempt installation."
+    fi
+fi
+
+# Detect CUDA toolkit (if present) to help the dynamic linker
+CUDA_HOME=""
+if [ -d "/usr/local/cuda" ]; then
+    CUDA_HOME="/usr/local/cuda"
+else
+    # Pick the highest versioned CUDA directory if multiple exist
+    latest_cuda_dir=$(ls -d /usr/local/cuda-12* 2>/dev/null | sort -V | tail -n1 || true)
+    if [ -n "${latest_cuda_dir}" ] && [ -d "${latest_cuda_dir}" ]; then
+        CUDA_HOME="${latest_cuda_dir}"
+    fi
+fi
+
+# Locate torch's bundled shared libs directory
+TORCH_LIB_DIR=$(python - <<'PY'
+import os, torch
+print(os.path.join(os.path.dirname(torch.__file__), 'lib'))
+PY
+)
+
+# Export runtime library paths for this session
+if [ -d "$TORCH_LIB_DIR" ]; then
+    export LD_LIBRARY_PATH="$TORCH_LIB_DIR:${LD_LIBRARY_PATH:-}"
+fi
+if [ -n "$CUDA_HOME" ] && [ -d "$CUDA_HOME/lib64" ]; then
+    export CUDA_HOME
+    export LD_LIBRARY_PATH="$CUDA_HOME/lib64:${LD_LIBRARY_PATH:-}"
+    export PATH="$CUDA_HOME/bin:$PATH"
+fi
+
+# Persist environment exports inside the venv activate script (idempotent)
+ACTIVATE_FILE="$VENV_DIR/bin/activate"
+MARKER="# === PLM_SETUP CUDA/Torch dynamic libs ==="
+
+# Remove any previously inserted PLM setup block (handles the old end marker too)
+if grep -q "$MARKER" "$ACTIVATE_FILE"; then
+    awk -v start="$MARKER" -v end1="# ============================================" -v end2="# === END PLM_SETUP ===" '
+        BEGIN{skip=0}
+        $0 ~ start {skip=1; next}
+        skip==1 && ($0 ~ end1 || $0 ~ end2) {skip=0; next}
+        skip==0 {print}
+    ' "$ACTIVATE_FILE" > "$ACTIVATE_FILE.tmp" && mv "$ACTIVATE_FILE.tmp" "$ACTIVATE_FILE"
+fi
+
+# Append the exports to the activate script
+cat >> "$ACTIVATE_FILE" <<'EOF'
+# === PLM_SETUP CUDA/Torch dynamic libs ===
+# Add torch's bundled libs to runtime path for torch.compile/triton
+export LD_LIBRARY_PATH="$(python -c 'import os, torch, sys; sys.stdout.write(os.path.join(os.path.dirname(torch.__file__), "lib"))'):${LD_LIBRARY_PATH:-}"
+# Optionally add CUDA toolkit if present
+if [ -d /usr/local/cuda ]; then export CUDA_HOME=/usr/local/cuda; fi
+if [ -z "${CUDA_HOME:-}" ]; then latest=$(ls -d /usr/local/cuda-12* 2>/dev/null | sort -V | tail -n1 || true); if [ -n "$latest" ]; then export CUDA_HOME="$latest"; fi; fi
+if [ -n "${CUDA_HOME:-}" ] && [ -d "$CUDA_HOME/lib64" ]; then export LD_LIBRARY_PATH="$CUDA_HOME/lib64:$LD_LIBRARY_PATH"; export PATH="$CUDA_HOME/bin:$PATH"; fi
+# === END PLM_SETUP ===
+EOF
+
+# List installed packages for verification
+echo -e "\nInstalled packages:"
+pip list
+
+# Quick diagnostics
+echo -e "\nDiagnostics:"
+python - <<'PY'
+import os, torch, sysconfig
+print('torch_version:', torch.__version__)
+print('torch_cuda_version:', torch.version.cuda)
+print('cuda_is_available:', torch.cuda.is_available())
+print('torch_lib_dir:', os.path.join(os.path.dirname(torch.__file__), 'lib'))
+inc = sysconfig.get_paths().get('include')
+print('python_include_dir:', inc)
+print('python_h_exists:', os.path.exists(os.path.join(inc or '', 'Python.h')))
+try:
+    import triton  # noqa: F401
+    print('triton_import: ok')
+except Exception as e:
+    print('triton_import: fail ->', e)
+if torch.cuda.is_available() and inc and os.path.exists(os.path.join(inc, 'Python.h')):
+    try:
+        f = torch.compile(lambda t: t + 1)
+        x = torch.randn(16, device='cuda')
+        y = f(x)
+        print('torch.compile_smoke: ok (y_cuda:', y.is_cuda, ')')
+    except Exception as e:
+        print('torch.compile_smoke: fail ->', e)
+else:
+    reason = []
+    if not torch.cuda.is_available():
+        reason.append('no CUDA device')
+    if not (inc and os.path.exists(os.path.join(inc, 'Python.h'))):
+        reason.append('no Python.h')
+    print('torch.compile_smoke: skipped (' + ', '.join(reason) + ')')
+PY
+
+# Instructions for future use
+echo -e "\n======================="
+echo "Setup complete!"
+echo "======================="
+echo "To activate this environment in the future, run:"
+echo "  source \"$VENV_DIR/bin/activate\""
+echo ""
+echo "To deactivate the environment, simply run:"
+echo "  deactivate"
+echo ""
+echo "Your virtual environment is located at: $VENV_DIR"
+echo "======================="
