|
19 | 19 | from fms_acceleration_odm import OnlineMixingDataset |
20 | 20 | from fms_acceleration_odm.odm.reward import Reward |
21 | 21 |
|
22 | | -model_name = "ibm-granite/granite-4.0-h-1b" |
| 22 | +model_name = "ibm-granite/granite-4.0-350m" |
23 | 23 | output_dir = "./odm_custom_use" |
24 | 24 | max_steps = 125 |
25 | 25 | batch_size = 4 |
26 | 26 | log_file = os.path.join(output_dir, "loss.jsonl") |
27 | 27 |
|
28 | 28 | # odm related |
29 | 29 | step_idx = 0 |
30 | | -update_interval = 1 # every step |
| 30 | +update_interval = 10 # every 10 steps |
31 | 31 |
|
32 | 32 | # model |
33 | | -model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16) |
| 33 | +model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16) |
34 | 34 |
|
35 | 35 | # tokenizer |
36 | 36 | tokenizer = AutoTokenizer.from_pretrained(model_name) |
@@ -102,7 +102,7 @@ def collate_fn(batch, tokenizer): |
102 | 102 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=None) |
103 | 103 |
|
104 | 104 | # distributed setup |
105 | | -dataloader_config = DataLoaderConfiguration(split_batches=True, dispatch_batches=True) |
| 105 | +dataloader_config = DataLoaderConfiguration(dispatch_batches=False) |
106 | 106 | accelerator = Accelerator(dataloader_config=dataloader_config) |
107 | 107 | model, dataloader = accelerator.prepare(model, dataloader) |
108 | 108 |
|
@@ -141,7 +141,7 @@ class State: |
141 | 141 | if step_idx % update_interval == 0: |
142 | 142 | with torch.no_grad(): |
143 | 143 | model.eval() |
144 | | - dataloader.dataset.update_sampling_weights(model, accelerator, state) |
| 144 | + dataset.update_sampling_weights(model, accelerator, state) |
145 | 145 | model.train() |
146 | 146 | if step_idx > max_steps: |
147 | 147 | break |
|
0 commit comments