LongLM isn't compatible with gemma-2-27b-it or gemma-2b-it #46

@uebian

I found that the current version of LongLM cannot be applied to Gemma 1 or Gemma 2 models. I wrote a minimal test to reproduce the issue:

# transformers version 4.38.2
# this example is tested with 4 RTX3090s, 24GB memory each
import warnings
warnings.filterwarnings("ignore")

import torch 
import json
import time
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

import SelfExtend 

window_size = 1024
group_size = 32

model_name = '/tmp/gemma-2b-it/'
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
SelfExtend.apply(model, group_size, window_size, enable_flash_attention=False)
prompt = "How are you?"
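# move the prompt to the model's first device (the model is sharded with device_map="auto")
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)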

start_time = time.time()
tokens = model.generate(input_ids, max_new_tokens=4096)
answer = tokenizer.decode(tokens[0].tolist()[input_ids.shape[1]:], skip_special_tokens=True)
print(answer)

The checkpoint loads, but the call to SelfExtend.apply then fails with the error message below:

$ python3 test.py 
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.07it/s]
Traceback (most recent call last):
  File "/var/lib/condor/execute/slot1/dir_2652801/test.py", line 22, in <module>
    SelfExtend.apply(model, group_size, window_size, enable_flash_attention=False)
  File "/var/lib/condor/execute/slot1/dir_2652801/SelfExtend.py", line 160, in apply
    raise Exception(f"Failed to modify the attention method of {arch_name}")
Exception: Failed to modify the attention method of GemmaForCausalLM

I traced the failure to the duplicate check at L24 of SelfExtend.py. When it fails, instance = False.

Below is the conda env export output for my Python environment, including package versions:

channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.7.2=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.14=h5eee18b_0
  - pip=24.0=py310h06a4308_0
  - python=3.10.14=h955ad1f_1
  - readline=8.2=h5eee18b_0
  - setuptools=69.5.1=py310h06a4308_0
  - sqlite=3.45.3=h5eee18b_0
  - tk=8.6.14=h39e8969_0
  - wheel=0.43.0=py310h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - zlib=1.2.13=h5eee18b_1
  - pip:
      - accelerate==0.33.0
      - aiohttp==3.9.5
      - aiosignal==1.3.1
      - annotated-types==0.7.0
      - anyio==4.4.0
      - async-timeout==4.0.3
      - attrs==23.2.0
      - certifi==2024.7.4
      - charset-normalizer==3.3.2
      - click==8.1.7
      - cloudpickle==3.0.0
      - cmake==3.30.1
      - datasets==2.20.0
      - dill==0.3.8
      - diskcache==5.6.3
      - distro==1.9.0
      - dnspython==2.6.1
      - einops==0.8.0
      - email-validator==2.2.0
      - exceptiongroup==1.2.2
      - fastapi==0.111.1
      - fastapi-cli==0.0.4
      - filelock==3.15.4
      - flash-attn==2.6.3
      - frozenlist==1.4.1
      - fsspec==2024.5.0
      - h11==0.14.0
      - httpcore==1.0.5
      - httptools==0.6.1
      - httpx==0.27.0
      - huggingface-hub==0.24.2
      - idna==3.7
      - interegular==0.3.3
      - jinja2==3.1.4
      - jsonschema==4.23.0
      - jsonschema-specifications==2023.12.1
      - lark==1.1.9
      - llvmlite==0.43.0
      - lm-format-enforcer==0.10.3
      - markdown-it-py==3.0.0
      - markupsafe==2.1.5
      - mdurl==0.1.2
      - mpmath==1.3.0
      - msgpack==1.0.8
      - multidict==6.0.5
      - multiprocess==0.70.16
      - nest-asyncio==1.6.0
      - networkx==3.3
      - ninja==1.11.1.1
      - numba==0.60.0
      - numpy==1.26.4
      - nvidia-cublas-cu12==12.1.3.1
      - nvidia-cuda-cupti-cu12==12.1.105
      - nvidia-cuda-nvrtc-cu12==12.1.105
      - nvidia-cuda-runtime-cu12==12.1.105
      - nvidia-cudnn-cu12==8.9.2.26
      - nvidia-cufft-cu12==11.0.2.54
      - nvidia-curand-cu12==10.3.2.106
      - nvidia-cusolver-cu12==11.4.5.107
      - nvidia-cusparse-cu12==12.1.0.106
      - nvidia-ml-py==12.555.43
      - nvidia-nccl-cu12==2.20.5
      - nvidia-nvjitlink-cu12==12.5.82
      - nvidia-nvtx-cu12==12.1.105
      - openai==1.37.1
      - outlines==0.0.46
      - packaging==24.1
      - pandas==2.2.2
      - pillow==10.4.0
      - prometheus-client==0.20.0
      - prometheus-fastapi-instrumentator==7.0.0
      - protobuf==5.27.2
      - psutil==6.0.0
      - py-cpuinfo==9.0.0
      - pyairports==2.1.1
      - pyarrow==17.0.0
      - pyarrow-hotfix==0.6
      - pycountry==24.6.1
      - pydantic==2.8.2
      - pydantic-core==2.20.1
      - pygments==2.18.0
      - python-dateutil==2.9.0.post0
      - python-dotenv==1.0.1
      - python-multipart==0.0.9
      - pytz==2024.1
      - pyyaml==6.0.1
      - pyzmq==26.0.3
      - ray==2.33.0
      - referencing==0.35.1
      - regex==2024.7.24
      - requests==2.32.3
      - rich==13.7.1
      - rpds-py==0.19.1
      - safetensors==0.4.3
      - sentencepiece==0.2.0
      - shellingham==1.5.4
      - six==1.16.0
      - sniffio==1.3.1
      - starlette==0.37.2
      - sympy==1.13.1
      - tiktoken==0.7.0
      - tokenizers==0.19.1
      - torch==2.3.1
      - torchvision==0.18.1
      - tqdm==4.66.4
      - transformers==4.43.3
      - triton==2.3.1
      - typer==0.12.3
      - typing-extensions==4.12.2
      - tzdata==2024.1
      - urllib3==2.2.2
      - uvicorn==0.30.3
      - uvloop==0.19.0
      - vllm==0.5.3.post1
      - vllm-flash-attn==2.5.9.post1
      - watchfiles==0.22.0
      - websockets==12.0
      - xformers==0.0.27
      - xxhash==3.4.1
      - yarl==1.9.4
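
The comment at the top of the test script mentions transformers 4.38.2, while this dump lists transformers==4.43.3. To confirm which versions the interpreter running the test actually sees, a small check like the one below may help. It is a sketch using only the standard library; `installed_versions` is a hypothetical helper name, and the package list is just the subset of this dump that seems most relevant.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    relevant = ["transformers", "torch", "accelerate", "flash-attn"]
    for name, ver in installed_versions(relevant).items():
        print(f"{name}: {ver or 'not installed'}")
```

Running it in the same environment as test.py prints one line per package, so a mismatch between the script comment and the installed transformers version is easy to spot.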
