Hi, I get an OOM (out-of-memory) error while fine-tuning with qwen-14b-chat and with the default model.
I launch training with:
accelerate launch --config_file configs/deepspeed_zero3.yaml --multi_gpu --num_processes=8 --main_process_port 29501 spin/run_spin.py configs/config.yaml --num_train_epochs=3 --output_dir="xxx/spin_outputs/iter0-ckpt"

System info (installed packages):
absl-py 2.1.0
accelerate 0.23.0
aiohttp 3.9.5
aioprometheus 23.12.0
aiosignal 1.3.1
annotated-types 0.7.0
anyio 4.4.0
async-timeout 4.0.3
attrs 23.2.0
bitsandbytes 0.41.2.post2
certifi 2024.6.2
charset-normalizer 3.3.2
click 8.1.7
cloudpickle 3.0.0
cmake 3.29.6
contourpy 1.2.1
cycler 0.12.1
datasets 2.14.6
deepspeed 0.12.2
dill 0.3.7
diskcache 5.6.3
dnspython 2.6.1
docstring_parser 0.16
einops 0.8.0
email_validator 2.1.1
evaluate 0.4.0
exceptiongroup 1.2.1
fastapi 0.111.0
fastapi-cli 0.0.4
filelock 3.15.1
flash_attn 2.5.9.post1
fonttools 4.53.0
frozenlist 1.4.1
fsspec 2023.10.0
grpcio 1.64.1
h11 0.14.0
hjson 3.1.0
httpcore 1.0.5
httptools 0.6.1
httpx 0.27.0
huggingface-hub 0.23.3
idna 3.7
interegular 0.3.3
Jinja2 3.1.4
joblib 1.4.2
jsonlines 4.0.0
jsonschema 4.22.0
jsonschema-specifications 2023.12.1
kiwisolver 1.4.5
lark 1.1.9
llvmlite 0.43.0
Markdown 3.6
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.9.0
mdurl 0.1.2
mpmath 1.3.0
msgpack 1.0.8
multidict 6.0.5
multiprocess 0.70.15
nest-asyncio 1.6.0
networkx 3.3
ninja 1.11.1.1
numba 0.60.0
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.5.40
nvidia-nvtx-cu12 12.1.105
opencv-python 4.10.0.84
orjson 3.10.5
outlines 0.0.34
packaging 24.1
pandas 2.2.2
peft 0.6.1
pillow 10.4.0
pip 24.0
prometheus_client 0.20.0
protobuf 3.20.2
psutil 5.9.8
py-cpuinfo 9.0.0
py4j 0.10.9.7
pyarrow 16.1.0
pydantic 2.7.4
pydantic_core 2.18.4
Pygments 2.18.0
pynvml 11.5.0
pyparsing 3.1.2
pyspark 3.5.1
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
python-multipart 0.0.9
pytz 2024.1
PyYAML 6.0.1
quantile-python 1.1
ray 2.24.0
referencing 0.35.1
regex 2024.5.15
requests 2.32.3
responses 0.18.0
rich 13.7.1
rpds-py 0.18.1
safetensors 0.4.3
scipy 1.13.1
seaborn 0.13.2
sentencepiece 0.2.0
setuptools 69.5.1
shellingham 1.5.4
shtab 1.7.1
six 1.16.0
sniffio 1.3.1
spin 0.1.0.dev0
starlette 0.37.2
sympy 1.12.1
tensorboard 2.17.0
tensorboard-data-server 0.7.2
tiktoken 0.6.0
tokenizers 0.15.2
torch 2.1.0
torchvision 0.18.1
tqdm 4.66.4
transformers 4.36.2
transformers-stream-generator 0.0.5
triton 2.1.0
trl 0.7.4
typer 0.12.3
typing_extensions 4.12.2
tyro 0.8.4
tzdata 2024.1
ujson 5.10.0
ultralytics-thop 2.0.0
urllib3 2.2.1
uvicorn 0.30.1
uvloop 0.19.0
vllm 0.3.0
watchfiles 0.22.0
websockets 12.0
Werkzeug 3.0.3
wheel 0.43.0
xformers 0.0.23.post1
xxhash 3.4.1
yarl 1.9.4
Thanks for your help in advance!