Skip to content

Commit d58133d

Browse files
authored: Merge pull request #53 from RETR0-OS/fix-training-class
Fix training class
2 parents 5fa8eb9 + 7dbbe74 commit d58133d

10 files changed

Lines changed: 205 additions & 101 deletions

File tree

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,5 @@ ModelForge.egg-info/
2323
unsloth_compiled_cache/*
2424
# MkDocs build directory
2525
site/
26+
.claude/
27+
*/.claude/*

ModelForge/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ async def lifespan(app: FastAPI):
5454
app = FastAPI(
5555
title="ModelForge",
5656
description="Modular fine-tuning platform with support for multiple providers and strategies",
57-
version="2.0.0",
57+
version="v2",
5858
lifespan=lifespan,
5959
)
6060

ModelForge/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def main():
6060
print(" |_| |_|\\___/ \\__,_|\\___|_|_| \\___/|_| \\__, |\\___| ")
6161
print(" __/ | ")
6262
print(" |___/ ")
63-
print("\n ModelForge v2.0 - Modular Fine-Tuning Platform")
63+
print("\n ModelForge v2.0 - No-code Fine-Tuning Platform")
6464
print("=" * 80 + "\n")
6565

6666
# Check HuggingFace login

ModelForge/services/training_service.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,6 @@ def train_model(
218218
max_seq_length=config.get("max_seq_length", 2048),
219219
)
220220
tokenizer.eos_token = tokenizer.eos_token or tokenizer.sep_token
221-
# Store eos_token in config for use by training strategies
222-
config["eos_token"] = tokenizer.eos_token
223221
else:
224222
model = provider.load_model(
225223
model_id=config["model_name"],
@@ -228,8 +226,6 @@ def train_model(
228226
)
229227
tokenizer = provider.load_tokenizer(config["model_name"])
230228
tokenizer.eos_token = tokenizer.eos_token or tokenizer.sep_token
231-
# Store eos_token in config for use by training strategies
232-
config["eos_token"] = tokenizer.eos_token
233229

234230
# Auto-detect and correct precision settings to prevent Unsloth errors
235231
config = self._auto_detect_precision_settings(model, config)
@@ -306,8 +302,6 @@ def train_model(
306302
logger.info(f"Calculated max_steps: {total_steps} (epochs={num_epochs}, examples={num_examples}, effective_batch={effective_batch_size})")
307303

308304
tokenizer.eos_token = tokenizer.eos_token or tokenizer.sep_token
309-
# Store eos_token in config for use by training strategies
310-
config["eos_token"] = tokenizer.eos_token
311305

312306
# Get metrics function
313307
metrics_fn = MetricsCalculator.get_metrics_fn_for_task(

ModelForge/strategies/qlora_strategy.py

Lines changed: 62 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
77

88
# Import unsloth first to prevent EOS token corruption
9-
# This must come before TRL imports to ensure proper tokenizer initialization
9+
# This must come before transformers imports to ensure proper tokenizer initialization
1010
try:
1111
import unsloth
1212
except ImportError:
1313
pass
1414

15-
from trl import SFTTrainer, SFTConfig
15+
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
1616

1717
from ..logging_config import logger
1818

@@ -87,21 +87,26 @@ def prepare_model(self, model: Any, config: Dict) -> Any:
8787

8888
def prepare_dataset(self, dataset: Any, tokenizer: Any, config: Dict) -> Any:
8989
"""
90-
Prepare dataset for QLoRA by consolidating all fields into a single 'text' field.
90+
Prepare dataset for QLoRA by tokenizing text and creating labels.
9191
9292
Args:
9393
dataset: Pre-formatted dataset with task-specific fields
94-
tokenizer: Tokenizer instance (for EOS token)
95-
config: Configuration dictionary (contains task type)
94+
tokenizer: Tokenizer instance
95+
config: Configuration dictionary (contains task type, max_seq_length)
9696
9797
Returns:
98-
Dataset with consolidated 'text' field
98+
Dataset with tokenized fields: input_ids, attention_mask, labels
9999
"""
100100
logger.info(f"Preparing dataset for QLoRA: {len(dataset)} examples")
101101

102102
# Get EOS token with SEP fallback
103103
eos_token = tokenizer.eos_token or tokenizer.sep_token or ""
104104
task = config.get("task", "text-generation")
105+
max_seq_length = config.get("max_seq_length", 2048)
106+
107+
# Handle max_seq_length = -1 (use model's maximum)
108+
if max_seq_length == -1:
109+
max_seq_length = 2048 # Fallback default
105110

106111
def create_text_field(example):
107112
"""Consolidate all fields into a single 'text' field with EOS token."""
@@ -139,10 +144,40 @@ def create_text_field(example):
139144

140145
return {"text": text}
141146

142-
# Apply transformation and remove original columns
143-
dataset = dataset.map(create_text_field, remove_columns=dataset.column_names, num_proc=1)
147+
# Step 1: Create text field
148+
dataset = dataset.map(
149+
create_text_field,
150+
remove_columns=dataset.column_names,
151+
num_proc=1
152+
)
153+
154+
# Step 2: Tokenize text
155+
def tokenize_function(examples):
156+
"""Tokenize text and create labels for causal LM."""
157+
# Tokenize with truncation and padding
158+
tokenized = tokenizer(
159+
examples["text"],
160+
truncation=True,
161+
max_length=max_seq_length,
162+
padding="max_length", # Pad to max_length for consistency
163+
return_tensors=None, # Return lists, not tensors (datasets handles this)
164+
)
165+
166+
# For causal LM: labels = input_ids
167+
# The model will shift internally for next-token prediction
168+
tokenized["labels"] = tokenized["input_ids"].copy()
169+
170+
return tokenized
171+
172+
# Apply tokenization
173+
dataset = dataset.map(
174+
tokenize_function,
175+
batched=True,
176+
remove_columns=["text"], # Remove text field, keep only tokenized
177+
num_proc=1,
178+
)
144179

145-
logger.info(f"Dataset prepared with consolidated 'text' field: {len(dataset)} examples")
180+
logger.info(f"Dataset tokenized: {len(dataset)} examples with max_length={max_seq_length}")
146181
return dataset
147182

148183
def create_trainer(
@@ -155,23 +190,23 @@ def create_trainer(
155190
callbacks: list = None,
156191
) -> Any:
157192
"""
158-
Create SFTTrainer with QLoRA-specific optimizations.
193+
Create Trainer with QLoRA-specific optimizations.
159194
160195
Args:
161196
model: Prepared model with QLoRA
162-
train_dataset: Training dataset
163-
eval_dataset: Evaluation dataset
197+
train_dataset: Tokenized training dataset
198+
eval_dataset: Tokenized evaluation dataset
164199
tokenizer: Tokenizer instance
165200
config: Training configuration
166201
callbacks: Training callbacks
167202
168203
Returns:
169-
SFTTrainer instance
204+
Trainer instance
170205
"""
171-
logger.info("Creating SFTTrainer with QLoRA optimizations")
206+
logger.info("Creating Trainer with QLoRA optimizations")
172207

173208
# QLoRA-optimized training arguments
174-
training_args = SFTConfig(
209+
training_args = TrainingArguments(
175210
output_dir=config.get("output_dir", "./checkpoints"),
176211
num_train_epochs=config.get("num_train_epochs", 1),
177212
# QLoRA can use larger batch sizes due to memory efficiency
@@ -194,32 +229,33 @@ def create_trainer(
194229
lr_scheduler_type=config.get("lr_scheduler_type", "cosine"),
195230
report_to="tensorboard",
196231
logging_dir=config.get("logging_dir", "./training_logs"),
197-
max_seq_length=config.get("max_seq_length", None),
198-
packing=config.get("packing", False),
199232
# Gradient checkpointing for memory efficiency
200233
gradient_checkpointing=config.get("gradient_checkpointing", True),
201234
gradient_checkpointing_kwargs={"use_reentrant": False},
202235
# Evaluation settings
203-
evaluation_strategy="steps" if eval_dataset else "no",
236+
eval_strategy="steps" if eval_dataset else "no",
204237
eval_steps=config.get("eval_steps", 100),
205238
save_strategy="steps",
206239
load_best_model_at_end=True if eval_dataset else False,
207240
metric_for_best_model="eval_loss" if eval_dataset else None,
208-
# Use tokenizer's EOS token instead of corrupted placeholder
209-
eos_token=config.get("eos_token"),
210-
# Disable completion_only_loss to avoid conflicts
211-
completion_only_loss=False,
212241
# Disable distributed training for Unsloth (required when using device_map='auto')
213242
ddp_find_unused_parameters=False,
243+
use_cache=False,
214244
)
215245

216-
# Create trainer (dataset has been formatted to 'text' field in prepare_dataset)
217-
trainer = SFTTrainer(
246+
# Create data collator for causal language modeling
247+
data_collator = DataCollatorForLanguageModeling(
248+
tokenizer=tokenizer,
249+
mlm=False, # Causal LM
250+
)
251+
252+
# Create standard Trainer
253+
trainer = Trainer(
218254
model=model,
255+
args=training_args,
219256
train_dataset=train_dataset,
220257
eval_dataset=eval_dataset,
221-
args=training_args,
222-
processing_class=tokenizer,
258+
data_collator=data_collator,
223259
callbacks=callbacks or [],
224260
)
225261

0 commit comments

Comments (0)