66from peft import LoraConfig , get_peft_model , TaskType , prepare_model_for_kbit_training
77
88# Import unsloth first to prevent EOS token corruption
9- # This must come before TRL imports to ensure proper tokenizer initialization
9+ # This must come before transformers imports to ensure proper tokenizer initialization
1010try :
1111 import unsloth
1212except ImportError :
1313 pass
1414
15- from trl import SFTTrainer , SFTConfig
15+ from transformers import Trainer , TrainingArguments , DataCollatorForLanguageModeling
1616
1717from ..logging_config import logger
1818
@@ -87,21 +87,26 @@ def prepare_model(self, model: Any, config: Dict) -> Any:
8787
8888 def prepare_dataset (self , dataset : Any , tokenizer : Any , config : Dict ) -> Any :
8989 """
90- Prepare dataset for QLoRA by consolidating all fields into a single 'text' field .
90+ Prepare dataset for QLoRA by tokenizing text and creating labels .
9191
9292 Args:
9393 dataset: Pre-formatted dataset with task-specific fields
94- tokenizer: Tokenizer instance (for EOS token)
95- config: Configuration dictionary (contains task type)
94+ tokenizer: Tokenizer instance
95+ config: Configuration dictionary (contains task type, max_seq_length )
9696
9797 Returns:
98- Dataset with consolidated 'text' field
98+ Dataset with tokenized fields: input_ids, attention_mask, labels
9999 """
100100 logger .info (f"Preparing dataset for QLoRA: { len (dataset )} examples" )
101101
102102 # Get EOS token with SEP fallback
103103 eos_token = tokenizer .eos_token or tokenizer .sep_token or ""
104104 task = config .get ("task" , "text-generation" )
105+ max_seq_length = config .get ("max_seq_length" , 2048 )
106+
107+ # Handle max_seq_length = -1 (sentinel for "use model's maximum"); the model limit is not queried here, so fall back to a fixed default
108+ if max_seq_length == - 1 :
109+ max_seq_length = 2048 # Fallback default
105110
106111 def create_text_field (example ):
107112 """Consolidate all fields into a single 'text' field with EOS token."""
@@ -139,10 +144,40 @@ def create_text_field(example):
139144
140145 return {"text" : text }
141146
142- # Apply transformation and remove original columns
143- dataset = dataset .map (create_text_field , remove_columns = dataset .column_names , num_proc = 1 )
147+ # Step 1: Create text field
148+ dataset = dataset .map (
149+ create_text_field ,
150+ remove_columns = dataset .column_names ,
151+ num_proc = 1
152+ )
153+
154+ # Step 2: Tokenize text
155+ def tokenize_function (examples ):
156+ """Tokenize text and create labels for causal LM."""
157+ # Tokenize with truncation and padding
158+ tokenized = tokenizer (
159+ examples ["text" ],
160+ truncation = True ,
161+ max_length = max_seq_length ,
162+ padding = "max_length" , # Pad to max_length for consistency
163+ return_tensors = None , # Return lists, not tensors (datasets handles this)
164+ )
165+
166+ # For causal LM: labels = input_ids
167+ # The model will shift internally for next-token prediction
168+ tokenized ["labels" ] = tokenized ["input_ids" ].copy ()
169+
170+ return tokenized
171+
172+ # Apply tokenization
173+ dataset = dataset .map (
174+ tokenize_function ,
175+ batched = True ,
176+ remove_columns = ["text" ], # Remove text field, keep only tokenized
177+ num_proc = 1 ,
178+ )
144179
145- logger .info (f"Dataset prepared with consolidated 'text' field : { len (dataset )} examples" )
180+ logger .info (f"Dataset tokenized : { len (dataset )} examples with max_length= { max_seq_length } " )
146181 return dataset
147182
148183 def create_trainer (
@@ -155,23 +190,23 @@ def create_trainer(
155190 callbacks : list = None ,
156191 ) -> Any :
157192 """
158- Create SFTTrainer with QLoRA-specific optimizations.
193+ Create Trainer with QLoRA-specific optimizations.
159194
160195 Args:
161196 model: Prepared model with QLoRA
162- train_dataset: Training dataset
163- eval_dataset: Evaluation dataset
197+ train_dataset: Tokenized training dataset
198+ eval_dataset: Tokenized evaluation dataset
164199 tokenizer: Tokenizer instance
165200 config: Training configuration
166201 callbacks: Training callbacks
167202
168203 Returns:
169- SFTTrainer instance
204+ Trainer instance
170205 """
171- logger .info ("Creating SFTTrainer with QLoRA optimizations" )
206+ logger .info ("Creating Trainer with QLoRA optimizations" )
172207
173208 # QLoRA-optimized training arguments
174- training_args = SFTConfig (
209+ training_args = TrainingArguments (
175210 output_dir = config .get ("output_dir" , "./checkpoints" ),
176211 num_train_epochs = config .get ("num_train_epochs" , 1 ),
177212 # QLoRA can use larger batch sizes due to memory efficiency
@@ -194,32 +229,33 @@ def create_trainer(
194229 lr_scheduler_type = config .get ("lr_scheduler_type" , "cosine" ),
195230 report_to = "tensorboard" ,
196231 logging_dir = config .get ("logging_dir" , "./training_logs" ),
197- max_seq_length = config .get ("max_seq_length" , None ),
198- packing = config .get ("packing" , False ),
199232 # Gradient checkpointing for memory efficiency
200233 gradient_checkpointing = config .get ("gradient_checkpointing" , True ),
201234 gradient_checkpointing_kwargs = {"use_reentrant" : False },
202235 # Evaluation settings
203- evaluation_strategy = "steps" if eval_dataset else "no" ,
236+ eval_strategy = "steps" if eval_dataset else "no" ,
204237 eval_steps = config .get ("eval_steps" , 100 ),
205238 save_strategy = "steps" ,
206239 load_best_model_at_end = True if eval_dataset else False ,
207240 metric_for_best_model = "eval_loss" if eval_dataset else None ,
208- # Use tokenizer's EOS token instead of corrupted placeholder
209- eos_token = config .get ("eos_token" ),
210- # Disable completion_only_loss to avoid conflicts
211- completion_only_loss = False ,
212241 # Tell DDP not to search for unused parameters (this flag does not by itself disable distributed training); kept for Unsloth when using device_map='auto'
213242 ddp_find_unused_parameters = False ,
243+ use_cache = False ,
214244 )
215245
216- # Create trainer (dataset has been formatted to 'text' field in prepare_dataset)
217- trainer = SFTTrainer (
246+ # Create data collator for causal language modeling
247+ data_collator = DataCollatorForLanguageModeling (
248+ tokenizer = tokenizer ,
249+ mlm = False , # Causal LM
250+ )
251+
252+ # Create standard Trainer
253+ trainer = Trainer (
218254 model = model ,
255+ args = training_args ,
219256 train_dataset = train_dataset ,
220257 eval_dataset = eval_dataset ,
221- args = training_args ,
222- processing_class = tokenizer ,
258+ data_collator = data_collator ,
223259 callbacks = callbacks or [],
224260 )
225261
0 commit comments