@@ -29,18 +29,16 @@ def is_single_gpu(self) -> bool:
2929 def used_devices (self ) -> list [str ]:
3030 return [self .first_device ] + self .other_devices
3131
@property
def has_cpu_device(self) -> bool:
    """Return True if any entry in ``used_devices`` is a CPU device.

    Every entry is parsed with ``torch.device`` before the check, so an
    invalid device string raises regardless of its position in the list.
    """
    device_types = {torch.device(name).type for name in self.used_devices}
    return "cpu" in device_types
36+
3237
3338def instantiate_model_with_devices (
3439 cfg : "Extract" , device_config : ModelDevices , is_verbose : bool , ** kwargs
3540) -> PreTrainedModel :
3641 first_device = device_config .first_device
37- if cfg .int8 :
38- # Required by `bitsandbytes`
39- torch_dtype = torch .float16
40- elif device_config == "cpu" :
41- torch_dtype = torch .float32
42- else :
43- torch_dtype = "auto"
4442
4543 # TODO: Maybe we should ensure the device map is the same
4644 # for all the extract processes? This is because the device map
@@ -51,8 +49,7 @@ def instantiate_model_with_devices(
5149 if device_config .is_single_gpu
5250 else create_device_map (
5351 model_str = cfg .model ,
54- use_8bit = cfg .int8 ,
55- torch_dtype = torch_dtype ,
52+ load_in_8bit = cfg .int8 ,
5653 model_devices = device_config ,
5754 verbose = is_verbose ,
5855 )
@@ -67,23 +64,24 @@ def instantiate_model_with_devices(
6764 cfg .model ,
6865 device_map = device_map ,
6966 load_in_8bit = cfg .int8 ,
70- torch_dtype = torch_dtype ,
67+ is_cpu = device_config . has_cpu_device ,
7168 ** kwargs ,
7269 )
7370 return model
7471
7572
7673def create_device_map (
7774 model_str : str ,
78- use_8bit : float ,
79- torch_dtype : dtype | str ,
75+ load_in_8bit : bool ,
8076 model_devices : ModelDevices ,
8177 verbose : bool ,
8278) -> dict [str , str ]:
8379 """Creates a device map for a model running on multiple GPUs."""
8480 with init_empty_weights ():
8581 # Need to first instantiate an empty model to get the layer class
86- model = instantiate_model (model_str = model_str , torch_dtype = torch_dtype )
82+ model = instantiate_model (
83+ model_str = model_str , load_in_8bit = load_in_8bit , is_cpu = False
84+ )
8785
8886 # e.g. {"cuda:0": 16000, "cuda:1": 16000}
8987 max_memory_all_devices : dict [str , int ] = get_available_memory_for_devices ()
@@ -97,7 +95,7 @@ def create_device_map(
9795 max_memory_used_devices [model_devices .first_device ] = (
9896 max_memory_used_devices [model_devices .first_device ] * 0.6
9997 )
100- if use_8bit :
98+ if load_in_8bit :
10199 print ("Using 8bit" )
102100 # If 8bit, multiply the memory by 2
103101 # This is because we instantiated our empty model in (probably) float16
@@ -107,7 +105,7 @@ def create_device_map(
107105 device : max_memory_used_devices [device ] * 2
108106 for device in max_memory_used_devices
109107 }
110- if use_8bit
108+ if load_in_8bit
111109 else max_memory_used_devices
112110 )
113111
0 commit comments