
Commit df1c0ff

load in 8bit correctly

1 parent 0182a64 · commit df1c0ff

2 files changed: 43 additions & 29 deletions

elk/utils/hf_utils.py (30 additions & 14 deletions)
@@ -20,15 +20,11 @@
 _AUTOREGRESSIVE_SUFFIXES = ["ConditionalGeneration"] + _DECODER_ONLY_SUFFIXES
 
 
-def instantiate_model(
+def determine_dtypes(
     model_str: str,
-    device: str | torch.device = "cpu",
-    **kwargs,
-) -> PreTrainedModel:
-    """Instantiate a model string with the appropriate `Auto` class."""
-    device = torch.device(device)
-    kwargs["device_map"] = {"": device}
-
+    is_cpu: bool,
+    load_in_8bit: bool,
+) -> torch.dtype | str:
     with prevent_name_conflicts():
         model_cfg = AutoConfig.from_pretrained(model_str)
 
@@ -37,27 +33,47 @@ def instantiate_model(
         fp32_weights = model_cfg.torch_dtype in (None, torch.float32)
 
         # Required by `bitsandbytes` to load in 8-bit.
-        if kwargs.get("load_in_8bit"):
+        if load_in_8bit:
             # Sanity check: we probably shouldn't be loading in 8-bit if the checkpoint
             # is in fp32. `bitsandbytes` only supports mixed fp16/int8 inference, and
             # we can't guarantee that there won't be overflow if we downcast to fp16.
             if fp32_weights:
                 raise ValueError("Cannot load in 8-bit if weights are fp32")
 
-            kwargs["torch_dtype"] = torch.float16
+            torch_dtype = torch.float16
 
         # CPUs generally don't support anything other than fp32.
-        elif device.type == "cpu":
-            kwargs["torch_dtype"] = torch.float32
+        elif is_cpu:
+            torch_dtype = torch.float32
 
         # If the model is fp32 but bf16 is available, convert to bf16.
         # Usually models with fp32 weights were actually trained in bf16, and
         # converting them doesn't hurt performance.
         elif fp32_weights and torch.cuda.is_bf16_supported():
-            kwargs["torch_dtype"] = torch.bfloat16
+            torch_dtype = torch.bfloat16
             print("Weights seem to be fp32, but bf16 is available. Loading in bf16.")
         else:
-            kwargs["torch_dtype"] = "auto"
+            torch_dtype = "auto"
+        return torch_dtype
+
+
+def instantiate_model(
+    model_str: str,
+    load_in_8bit: bool,
+    is_cpu: bool,
+    **kwargs,
+) -> PreTrainedModel:
+    """Instantiate a model string with the appropriate `Auto` class."""
+
+    with prevent_name_conflicts():
+        model_cfg = AutoConfig.from_pretrained(model_str)
+        # If a torch_dtype was not specified, try to infer it.
+        if "torch_dtype" not in kwargs:
+            kwargs["torch_dtype"] = determine_dtypes(
+                model_str=model_str, is_cpu=is_cpu, load_in_8bit=load_in_8bit
+            )
+        # Add load_in_8bit to kwargs
+        kwargs["load_in_8bit"] = load_in_8bit
 
         archs = model_cfg.architectures
         if not isinstance(archs, list):
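
For context, a minimal usage sketch of the refactored API (not part of the commit; the model name is illustrative, and it assumes the elk package and its dependencies are importable):

import torch

from elk.utils.hf_utils import determine_dtypes, instantiate_model

# dtype selection is now a standalone, testable step. A checkpoint whose
# config leaves torch_dtype unset counts as fp32, so on CPU it loads in fp32.
dtype = determine_dtypes("gpt2", is_cpu=True, load_in_8bit=False)
assert dtype == torch.float32

# instantiate_model now takes the flags explicitly instead of fishing them
# out of **kwargs; callers supply device_map themselves when they need one.
model = instantiate_model("gpt2", load_in_8bit=False, is_cpu=True)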

elk/utils/multi_gpu.py (13 additions & 15 deletions)
@@ -29,18 +29,16 @@ def is_single_gpu(self) -> bool:
     def used_devices(self) -> list[str]:
         return [self.first_device] + self.other_devices
 
+    @property
+    def has_cpu_device(self) -> bool:
+        devices = [torch.device(device) for device in self.used_devices]
+        return any(device.type == "cpu" for device in devices)
+
 
 def instantiate_model_with_devices(
     cfg: "Extract", device_config: ModelDevices, is_verbose: bool, **kwargs
 ) -> PreTrainedModel:
     first_device = device_config.first_device
-    if cfg.int8:
-        # Required by `bitsandbytes`
-        torch_dtype = torch.float16
-    elif device_config == "cpu":
-        torch_dtype = torch.float32
-    else:
-        torch_dtype = "auto"
 
     # TODO: Maybe we should ensure the device map is the same
     # for all the extract processes? This is because the device map
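
The new property is easy to sanity-check in isolation; a sketch of the same logic outside the class (the device list is made up):

import torch

# Same check as ModelDevices.has_cpu_device: does any device string the
# model is mapped to resolve to a CPU device?
used_devices = ["cuda:0", "cpu"]  # hypothetical device list
print(any(torch.device(d).type == "cpu" for d in used_devices))  # True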
@@ -51,8 +49,7 @@ def instantiate_model_with_devices(
         if device_config.is_single_gpu
         else create_device_map(
             model_str=cfg.model,
-            use_8bit=cfg.int8,
-            torch_dtype=torch_dtype,
+            load_in_8bit=cfg.int8,
             model_devices=device_config,
             verbose=is_verbose,
         )
@@ -67,23 +64,24 @@ def instantiate_model_with_devices(
         cfg.model,
         device_map=device_map,
         load_in_8bit=cfg.int8,
-        torch_dtype=torch_dtype,
+        is_cpu=device_config.has_cpu_device,
         **kwargs,
     )
     return model
 
 
 def create_device_map(
     model_str: str,
-    use_8bit: float,
-    torch_dtype: dtype | str,
+    load_in_8bit: bool,
     model_devices: ModelDevices,
     verbose: bool,
 ) -> dict[str, str]:
     """Creates a device map for a model running on multiple GPUs."""
     with init_empty_weights():
         # Need to first instantiate an empty model to get the layer class
-        model = instantiate_model(model_str=model_str, torch_dtype=torch_dtype)
+        model = instantiate_model(
+            model_str=model_str, load_in_8bit=load_in_8bit, is_cpu=False
+        )
 
     # e.g. {"cuda:0": 16000, "cuda:1": 16000}
     max_memory_all_devices: dict[str, int] = get_available_memory_for_devices()
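
The init_empty_weights trick above can be seen standalone; a sketch (model name illustrative) of instantiating a weightless meta-model just to inspect its layer classes, which is all create_device_map needs before planning the map:

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Parameters land on the "meta" device, so no real memory is allocated.
with init_empty_weights():
    config = AutoConfig.from_pretrained("gpt2")
    empty_model = AutoModelForCausalLM.from_config(config)

print(type(empty_model.transformer.h[0]).__name__)  # GPT2Block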
@@ -97,7 +95,7 @@ def create_device_map(
     max_memory_used_devices[model_devices.first_device] = (
         max_memory_used_devices[model_devices.first_device] * 0.6
     )
-    if use_8bit:
+    if load_in_8bit:
         print("Using 8bit")
         # If 8bit, multiply the memory by 2
         # This is because we instantiated our empty model in (probably) float16
@@ -107,7 +105,7 @@ def create_device_map(
             device: max_memory_used_devices[device] * 2
             for device in max_memory_used_devices
         }
-        if use_8bit
+        if load_in_8bit
         else max_memory_used_devices
     )
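
The budget arithmetic can also be checked standalone; a sketch with made-up per-device budgets (units are whatever get_available_memory_for_devices returns):

load_in_8bit = True
max_memory_used_devices = {"cuda:0": 16_000, "cuda:1": 16_000}  # assumed budgets

# Reserve headroom on the first device for inputs and activations.
max_memory_used_devices["cuda:0"] *= 0.6

# The empty model was metered in fp16, and int8 weights take half the
# bytes, so each device's cap is doubled when loading in 8-bit.
max_memory = (
    {device: mem * 2 for device, mem in max_memory_used_devices.items()}
    if load_in_8bit
    else max_memory_used_devices
)
print(max_memory)  # {'cuda:0': 19200.0, 'cuda:1': 32000}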
