I found that the current version of LongLM cannot load Gemma 1 or Gemma 2 models successfully. I wrote a minimal test to help reproduce the issue:
# transformers version 4.38.2
# this example is tested with 4 RTX3090s, 24GB memory each
import warnings
warnings.filterwarnings("ignore")
import torch
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
import SelfExtend
window_size = 1024
group_size = 32
model_name = '/tmp/gemma-2b-it/'
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
SelfExtend.apply(model, group_size, window_size, enable_flash_attention=False)
prompt = "How are you?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
start_time = time.time()
tokens = model.generate(input_ids, max_new_tokens=4096)
answer = tokenizer.decode(tokens[0].tolist()[input_ids.shape[1]:], skip_special_tokens=True)
print(answer)

The checkpoint shards themselves load fine, but the script then fails at SelfExtend.apply with the error message below:
$ python3 test.py
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.07it/s]
Traceback (most recent call last):
  File "/var/lib/condor/execute/slot1/dir_2652801/test.py", line 22, in <module>
    SelfExtend.apply(model, group_size, window_size, enable_flash_attention=False)
  File "/var/lib/condor/execute/slot1/dir_2652801/SelfExtend.py", line 160, in apply
    raise Exception(f"Failed to modify the attention method of {arch_name}")
Exception: Failed to modify the attention method of GemmaForCausalLM
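Since the exception says no attention module could be modified, it helps to first see which attention classes the loaded checkpoint actually contains. A small inspection snippet like the following (using only the standard named_modules() API; the checkpoint path mirrors the repro script above) prints them:

import torch
from transformers import AutoModelForCausalLM

# Load the same checkpoint as in the repro script and list every module
# whose class name looks like an attention block. Depending on the
# transformers version and the attn_implementation chosen at load time,
# this prints e.g. GemmaAttention or GemmaSdpaAttention.
model = AutoModelForCausalLM.from_pretrained('/tmp/gemma-2b-it/', torch_dtype=torch.bfloat16)
for name, module in model.named_modules():
    if "Attention" in type(module).__name__:
        print(name, type(module).__name__)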
I traced the failure to the module-matching check at line 24 of SelfExtend.py: when no attention module matches, instance stays False and apply() raises the exception above.
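If that check matches modules by exact class name, it would explain the failure on recent transformers releases: when torch >= 2.1 is available, models load with the SDPA attention implementation by default, so Gemma's attention modules are instantiated as GemmaSdpaAttention instead of GemmaAttention, and a strict name match never fires. A minimal sketch of that failure mode (the matching logic below is an illustrative assumption, not the actual LongLM source):

# Hypothetical sketch, not the real SelfExtend.py: a strict class-name match
# finds nothing when transformers instantiates GemmaSdpaAttention, so
# instance stays False and apply() raises.
instance = False
for module in model.modules():
    if module.__class__.__name__ == "GemmaAttention":  # assumed match target
        instance = True  # the real code would also swap in the patched forward
if not instance:
    raise Exception("Failed to modify the attention method of GemmaForCausalLM")

If that is indeed the cause, forcing the eager implementation at load time should restore the plain GemmaAttention class and might serve as a workaround:

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # keep plain GemmaAttention modules
)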
Below is a conda env export dump with the full package details of my Python environment:
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2024.7.2=h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.4=h6a678d5_1
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- libuuid=1.41.5=h5eee18b_0
- ncurses=6.4=h6a678d5_0
- openssl=3.0.14=h5eee18b_0
- pip=24.0=py310h06a4308_0
- python=3.10.14=h955ad1f_1
- readline=8.2=h5eee18b_0
- setuptools=69.5.1=py310h06a4308_0
- sqlite=3.45.3=h5eee18b_0
- tk=8.6.14=h39e8969_0
- wheel=0.43.0=py310h06a4308_0
- xz=5.4.6=h5eee18b_1
- zlib=1.2.13=h5eee18b_1
- pip:
- accelerate==0.33.0
- aiohttp==3.9.5
- aiosignal==1.3.1
- annotated-types==0.7.0
- anyio==4.4.0
- async-timeout==4.0.3
- attrs==23.2.0
- certifi==2024.7.4
- charset-normalizer==3.3.2
- click==8.1.7
- cloudpickle==3.0.0
- cmake==3.30.1
- datasets==2.20.0
- dill==0.3.8
- diskcache==5.6.3
- distro==1.9.0
- dnspython==2.6.1
- einops==0.8.0
- email-validator==2.2.0
- exceptiongroup==1.2.2
- fastapi==0.111.1
- fastapi-cli==0.0.4
- filelock==3.15.4
- flash-attn==2.6.3
- frozenlist==1.4.1
- fsspec==2024.5.0
- h11==0.14.0
- httpcore==1.0.5
- httptools==0.6.1
- httpx==0.27.0
- huggingface-hub==0.24.2
- idna==3.7
- interegular==0.3.3
- jinja2==3.1.4
- jsonschema==4.23.0
- jsonschema-specifications==2023.12.1
- lark==1.1.9
- llvmlite==0.43.0
- lm-format-enforcer==0.10.3
- markdown-it-py==3.0.0
- markupsafe==2.1.5
- mdurl==0.1.2
- mpmath==1.3.0
- msgpack==1.0.8
- multidict==6.0.5
- multiprocess==0.70.16
- nest-asyncio==1.6.0
- networkx==3.3
- ninja==1.11.1.1
- numba==0.60.0
- numpy==1.26.4
- nvidia-cublas-cu12==12.1.3.1
- nvidia-cuda-cupti-cu12==12.1.105
- nvidia-cuda-nvrtc-cu12==12.1.105
- nvidia-cuda-runtime-cu12==12.1.105
- nvidia-cudnn-cu12==8.9.2.26
- nvidia-cufft-cu12==11.0.2.54
- nvidia-curand-cu12==10.3.2.106
- nvidia-cusolver-cu12==11.4.5.107
- nvidia-cusparse-cu12==12.1.0.106
- nvidia-ml-py==12.555.43
- nvidia-nccl-cu12==2.20.5
- nvidia-nvjitlink-cu12==12.5.82
- nvidia-nvtx-cu12==12.1.105
- openai==1.37.1
- outlines==0.0.46
- packaging==24.1
- pandas==2.2.2
- pillow==10.4.0
- prometheus-client==0.20.0
- prometheus-fastapi-instrumentator==7.0.0
- protobuf==5.27.2
- psutil==6.0.0
- py-cpuinfo==9.0.0
- pyairports==2.1.1
- pyarrow==17.0.0
- pyarrow-hotfix==0.6
- pycountry==24.6.1
- pydantic==2.8.2
- pydantic-core==2.20.1
- pygments==2.18.0
- python-dateutil==2.9.0.post0
- python-dotenv==1.0.1
- python-multipart==0.0.9
- pytz==2024.1
- pyyaml==6.0.1
- pyzmq==26.0.3
- ray==2.33.0
- referencing==0.35.1
- regex==2024.7.24
- requests==2.32.3
- rich==13.7.1
- rpds-py==0.19.1
- safetensors==0.4.3
- sentencepiece==0.2.0
- shellingham==1.5.4
- six==1.16.0
- sniffio==1.3.1
- starlette==0.37.2
- sympy==1.13.1
- tiktoken==0.7.0
- tokenizers==0.19.1
- torch==2.3.1
- torchvision==0.18.1
- tqdm==4.66.4
- transformers==4.43.3
- triton==2.3.1
- typer==0.12.3
- typing-extensions==4.12.2
- tzdata==2024.1
- urllib3==2.2.2
- uvicorn==0.30.3
- uvloop==0.19.0
- vllm==0.5.3.post1
- vllm-flash-attn==2.5.9.post1
- watchfiles==0.22.0
- websockets==12.0
- xformers==0.0.27
- xxhash==3.4.1
- yarl==1.9.4