diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..26d3352
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..d843f34
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>
diff --git a/F2LLM/arguments.py b/F2LLM/arguments.py
index b967c8f..8bad43c 100644
--- a/F2LLM/arguments.py
+++ b/F2LLM/arguments.py
@@ -27,6 +27,12 @@ class Args:
log_interval: int = 20
checkpointing_steps: int = 100
validation_steps: int = 100
+ # LoRA-specific arguments
+ use_lora: bool = False
+ lora_r: int = 8
+ lora_alpha: int = 16
+ lora_dropout: float = 0.05
+ lora_target_modules: str = "all-linear" # Comma-separated list or "all-linear"
# just placeholder, for logging purpose
num_processes: int=0
diff --git a/F2LLM/config_lora_example.json b/F2LLM/config_lora_example.json
new file mode 100644
index 0000000..afef649
--- /dev/null
+++ b/F2LLM/config_lora_example.json
@@ -0,0 +1,25 @@
+{
+ "model_path": "models/qwen3-0.6b",
+ "experiment_id": "f2llm_lora_example",
+ "output_dir": "output",
+ "tb_dir": "tb_logs",
+ "cache_dir": "cache",
+ "train_data_path": "data_tokenized_qwen",
+ "train_batch_size": 4,
+ "max_seq_length": 1024,
+ "learning_rate": 1e-4,
+ "min_lr": 1e-6,
+ "weight_decay": 1e-2,
+ "warmup_steps": 100,
+ "num_hard_neg": 7,
+ "train_steps": 1000,
+ "train_epochs": 3,
+ "log_interval": 20,
+ "checkpointing_steps": 100,
+ "validation_steps": 100,
+ "use_lora": true,
+ "lora_r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_modules": "all-linear"
+}
\ No newline at end of file
diff --git a/F2LLM/docs/lora_support.md b/F2LLM/docs/lora_support.md
new file mode 100644
index 0000000..bc39dbb
--- /dev/null
+++ b/F2LLM/docs/lora_support.md
@@ -0,0 +1,157 @@
+# LoRA Support in F2LLM
+
+## Overview
+
+Low-Rank Adaptation (LoRA) is a parameter-efficient fine-tuning technique that significantly reduces the number of trainable parameters while maintaining model performance. F2LLM provides built-in support for LoRA, allowing users to fine-tune large language models efficiently without requiring full model updates.
+
+## Key Benefits
+
+- **Memory Efficiency**: Reduces training memory, mainly optimizer state and gradients, because only the adapter weights are updated
+- **Computational Efficiency**: Faster training with fewer parameters to update
+- **Storage Efficiency**: Smaller adapter files compared to full model checkpoints
+- **Modularity**: Easy to switch between different LoRA adapters for various tasks
+
+## Configuration
+
+LoRA can be enabled by setting the appropriate parameters in your configuration file or through command line arguments.
+
+### Configuration Parameters
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `use_lora` | bool | `false` | Enable or disable LoRA |
+| `lora_r` | int | `8` | The rank of the LoRA decomposition |
+| `lora_alpha` | int | `16` | Scaling factor for LoRA |
+| `lora_dropout` | float | `0.05` | Dropout rate applied to LoRA layers |
+| `lora_target_modules` | str | `"all-linear"` | Target modules to apply LoRA to |
+
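+These options map directly onto fields of the `Args` dataclass in `arguments.py`, so LoRA can also be enabled programmatically. A minimal sketch, mirroring the fields used in `test_lora.py` (all paths are placeholders):
+
+```python
+from arguments import Args
+
+args = Args(
+    model_path="models/qwen3-0.6b",          # placeholder paths; substitute your own
+    experiment_id="f2llm_lora_example",
+    output_dir="output",
+    tb_dir="tb_logs",
+    cache_dir="cache",
+    train_data_path="data_tokenized_qwen",
+    use_lora=True,                           # LoRA switch and hyperparameters
+    lora_r=8,
+    lora_alpha=16,
+    lora_dropout=0.05,
+    lora_target_modules="all-linear",
+)
+```
+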
+### Target Modules
+
+The `lora_target_modules` parameter specifies which layers to apply LoRA to:
+
+- **"all-linear"** (default): Applies LoRA to all linear projection layers including:
+ - `q_proj`: Query projections
+ - `v_proj`: Value projections
+ - `k_proj`: Key projections
+ - `o_proj`: Output projections
+ - `gate_proj`: Gate projections (in feed-forward networks)
+ - `up_proj`: Up projections (in feed-forward networks)
+ - `down_proj`: Down projections (in feed-forward networks)
+  - `lm_head`: Language model head (only wrapped if the loaded checkpoint exposes one; the embedding backbone loaded via `AutoModel` typically does not)
+
+- **Custom list**: Comma-separated module names (e.g., `"q_proj,v_proj"`), parsed as shown in the sketch below
+
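+Either form is resolved to an explicit module list before the adapter is created. A condensed sketch of the resolution logic in `model.py` (the helper name here is illustrative; in the code it lives inline in `_apply_lora()`):
+
+```python
+def resolve_target_modules(spec: str) -> list[str]:
+    """Expand "all-linear" or split a comma-separated module list."""
+    if spec == "all-linear":
+        return [
+            "q_proj", "v_proj", "k_proj", "o_proj",
+            "gate_proj", "up_proj", "down_proj",
+            "lm_head",
+        ]
+    return [module.strip() for module in spec.split(",")]
+
+resolve_target_modules("q_proj, v_proj")  # -> ["q_proj", "v_proj"]
+```
+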
+## Example Configuration
+
+```json
+{
+ "model_path": "models/qwen3-0.6b",
+ "experiment_id": "f2llm_lora_example",
+ "output_dir": "output",
+ "tb_dir": "tb_logs",
+ "cache_dir": "cache",
+ "train_data_path": "data_tokenized_qwen",
+ "train_batch_size": 4,
+ "max_seq_length": 1024,
+ "learning_rate": 1e-4,
+ "min_lr": 1e-6,
+ "weight_decay": 1e-2,
+ "warmup_steps": 100,
+ "num_hard_neg": 7,
+ "train_steps": 1000,
+ "train_epochs": 3,
+ "log_interval": 20,
+ "checkpointing_steps": 100,
+ "validation_steps": 100,
+ "use_lora": true,
+ "lora_r": 8,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_modules": "all-linear"
+}
+```
+
+## Implementation Details
+
+### Model Initialization
+
+When `use_lora` is set to `true`, the model automatically applies LoRA during initialization in the `F2LLM.__init__()` method:
+
+1. The base model is loaded from the specified `model_path`
+2. LoRA configuration is created with the provided parameters
+3. The PEFT (Parameter-Efficient Fine-Tuning) library applies the LoRA adapters; a condensed sketch follows below
+
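+A condensed sketch of what `_apply_lora()` does (error handling, logging, and the target-module resolution shown above are omitted):
+
+```python
+from peft import LoraConfig, TaskType, get_peft_model
+
+def apply_lora(lm, args, target_modules):
+    """Wrap the loaded backbone with LoRA adapters (condensed from model.py)."""
+    lora_config = LoraConfig(
+        task_type=TaskType.FEATURE_EXTRACTION,  # embedding model, not generation
+        r=args.lora_r,
+        lora_alpha=args.lora_alpha,
+        target_modules=target_modules,
+        lora_dropout=args.lora_dropout,
+        bias="none",                            # bias terms stay frozen
+    )
+    return get_peft_model(lm, lora_config)
+```
+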
+### Parameter Efficiency
+
+With LoRA enabled, only a fraction of the model's parameters are trainable:
+
+- **Full model parameters**: All model weights
+- **Trainable parameters**: Only the LoRA adapter weights (bias terms stay frozen because the adapters are created with `bias="none"`)
+- **Parameter savings**: Often a 90%+ reduction in trainable parameters, which also shrinks optimizer state; the snippet below shows how to verify the split
+
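+The same counts that the training code prints at startup can be queried from any `F2LLM` instance; a small helper, assuming `model` is the wrapper created in `run.py`:
+
+```python
+def report_trainable(model):
+    """Print trainable vs. total parameter counts for the wrapped LM."""
+    total = model.lm.num_parameters()
+    trainable = model.lm.num_parameters(only_trainable=True)
+    print(f"trainable: {trainable:,} / total: {total:,} ({100 * trainable / total:.2f}%)")
+```
+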
+## Usage Examples
+
+### Training with LoRA
+
+1. Create a configuration file with LoRA enabled
+2. Run the training script:
+
+```bash
+python run.py --config config_lora_example.json
+```
+
+### Loading Models with LoRA Adapters
+
+Use the `lora_utils.py` module to load models with previously trained adapters:
+
+```python
+from lora_utils import load_model_with_lora
+
+model, tokenizer = load_model_with_lora(
+ base_model_path="path/to/base/model",
+ lora_adapter_path="path/to/lora/adapter"
+)
+```
+
+### Merging LoRA Weights
+
+To permanently merge LoRA weights with the base model:
+
+```python
+from lora_utils import merge_lora_weights
+
+merged_model = merge_lora_weights(model, save_path="path/to/merged/model")
+```
+
+## Utilities
+
+### lora_utils.py
+
+This module provides several utility functions for LoRA operations:
+
+- `load_model_with_lora()`: Load a base model with optional LoRA adapter
+- `merge_lora_weights()`: Merge LoRA weights with the base model
+- `get_lora_model_info()`: Get information about a LoRA model configuration
+- `count_parameters()`: Count model parameters (trainable vs total); a combined usage example follows below
+
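+A short example combining these helpers (the adapter path is a placeholder for a checkpoint produced by training):
+
+```python
+from lora_utils import load_model_with_lora, get_lora_model_info, count_parameters
+
+model, tokenizer = load_model_with_lora(
+    base_model_path="models/qwen3-0.6b",            # placeholder paths
+    lora_adapter_path="output/f2llm_lora_example",
+)
+
+print(get_lora_model_info(model))                    # r, alpha, dropout, target modules per adapter
+print(count_parameters(model, only_trainable=True),  # adapter parameters only
+      count_parameters(model))                       # all parameters
+```
+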
+## Best Practices
+
+1. **Start with default parameters**: Use r=8, alpha=16, dropout=0.05 as a starting point
+2. **Adjust r value**: Higher r values (16, 32) may improve performance but increase memory
+3. **Tune alpha**: An alpha/r ratio of about 2 is often effective (e.g., r=8, alpha=16)
+4. **Monitor parameter count**: Check the trainable vs total parameter ratio during initialization
+5. **Use appropriate target modules**: "all-linear" covers most important layers, but task-specific modules might be more efficient
+
+## Troubleshooting
+
+### Common Issues
+
+- **PEFT library not found**: Install with `pip install peft` (a quick environment check is sketched below)
+- **Memory issues**: Reduce LoRA rank (`lora_r`) to further decrease memory usage
+- **Performance degradation**: Try increasing `lora_r` or `lora_alpha` values
+
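+A quick check of the `peft` dependency before launching a long run:
+
+```python
+try:
+    import peft
+    print("peft", peft.__version__)
+except ImportError as exc:
+    raise SystemExit("peft is not installed; run `pip install peft`") from exc
+```
+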
+### Performance Considerations
+
+- Lower ranks (r=4, 8) use less memory but may underperform
+- Higher ranks (r=32, 64) approach full fine-tuning performance but use more memory
+- The alpha/r ratio is often kept around 2 for good performance
\ No newline at end of file
diff --git a/F2LLM/lora_utils.py b/F2LLM/lora_utils.py
new file mode 100644
index 0000000..ecbd8af
--- /dev/null
+++ b/F2LLM/lora_utils.py
@@ -0,0 +1,123 @@
+"""
+Utilities for LoRA (Low-Rank Adaptation) support in F2LLM.
+This module provides functions for loading models with LoRA adapters, merging adapter weights into the base model, and inspecting LoRA configurations.
+"""
+
+from transformers import AutoModel, AutoTokenizer
+from peft import PeftModel, LoraConfig, get_peft_model, TaskType
+import torch
+
+
+def load_model_with_lora(base_model_path, lora_adapter_path=None, **lora_kwargs):
+ """
+ Load a base model with optional LoRA adapter.
+
+ Args:
+ base_model_path (str): Path to the base model
+ lora_adapter_path (str, optional): Path to the LoRA adapter
+        **lora_kwargs: Optional LoRA settings (lora_r, lora_alpha, lora_dropout, target_modules) used to create a new adapter when no adapter path is given
+
+ Returns:
+ tuple: (model, tokenizer)
+ """
+ # Load the base model
+ model = AutoModel.from_pretrained(
+ base_model_path,
+ trust_remote_code=True,
+ torch_dtype=torch.bfloat16,
+ attn_implementation='flash_attention_2'
+ )
+ model.config.use_cache = False
+
+ tokenizer = AutoTokenizer.from_pretrained(base_model_path)
+
+ # Apply LoRA if adapter path is provided
+ if lora_adapter_path:
+ model = PeftModel.from_pretrained(model, lora_adapter_path)
+ print(f"Loaded LoRA adapter from {lora_adapter_path}")
+ elif lora_kwargs: # Apply new LoRA if configuration is provided
+ target_modules = lora_kwargs.get("target_modules", "all-linear")
+ if target_modules == "all-linear":
+ target_modules = [
+ "q_proj", "v_proj", "k_proj", "o_proj",
+ "gate_proj", "up_proj", "down_proj",
+                "lm_head"  # harmless no-op if the loaded backbone has no LM head
+ ]
+ elif isinstance(target_modules, str):
+ target_modules = [module.strip() for module in target_modules.split(",")]
+
+ lora_config = LoraConfig(
+ task_type=TaskType.FEATURE_EXTRACTION,
+ r=lora_kwargs.get("lora_r", 8),
+ lora_alpha=lora_kwargs.get("lora_alpha", 16),
+ target_modules=target_modules,
+ lora_dropout=lora_kwargs.get("lora_dropout", 0.05),
+ bias="none",
+ )
+
+ model = get_peft_model(model, lora_config)
+ print(f"Applied LoRA with config: {lora_config}")
+
+ return model, tokenizer
+
+
+def merge_lora_weights(model, save_path=None):
+ """
+ Merge LoRA weights with the base model.
+
+ Args:
+ model: PEFT model with LoRA
+ save_path (str, optional): Path to save the merged model
+
+ Returns:
+ Merged model
+ """
+ if hasattr(model, 'merge_and_unload'):
+ merged_model = model.merge_and_unload()
+ if save_path:
+ merged_model.save_pretrained(save_path)
+ return merged_model
+ else:
+ raise ValueError("Model does not support merging. Make sure it's a PEFT model.")
+
+
+def get_lora_model_info(model):
+ """
+ Get information about a LoRA model.
+
+ Args:
+ model: PEFT model with LoRA
+
+ Returns:
+ dict: Information about the model's LoRA configuration
+ """
+ if hasattr(model, 'peft_config'):
+ info = {}
+ for adapter_name, config in model.peft_config.items():
+ info[adapter_name] = {
+ 'r': config.r,
+ 'alpha': config.lora_alpha,
+ 'dropout': config.lora_dropout,
+ 'target_modules': config.target_modules,
+ 'bias': config.bias,
+ }
+ return info
+ else:
+ return {"message": "Model does not have LoRA configuration"}
+
+
+def count_parameters(model, only_trainable=False):
+ """
+ Count the number of parameters in the model.
+
+ Args:
+ model: PyTorch model
+ only_trainable (bool): Whether to count only trainable parameters
+
+ Returns:
+ int: Number of parameters
+ """
+ if only_trainable:
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+ else:
+ return sum(p.numel() for p in model.parameters())
\ No newline at end of file
diff --git a/F2LLM/model.py b/F2LLM/model.py
index d33ade7..ed283dd 100644
--- a/F2LLM/model.py
+++ b/F2LLM/model.py
@@ -12,11 +12,53 @@ def __init__(self,
self.args = args
self.dtype = torch.bfloat16
self.device = None # set after accelerator.prepare
+
+ # Load base model
self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2')
self.lm.config.use_cache = False
+
+ # Apply LoRA if enabled
+ if args and args.use_lora:
+ self._apply_lora()
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.max_seq_length = max_seq_length
+
+    def _apply_lora(self):
+ """Apply LoRA to the model if enabled."""
+ try:
+ from peft import LoraConfig, get_peft_model, TaskType
+ except ImportError:
+ raise ImportError(
+ "To use LoRA, please install the `peft` library: `pip install peft`"
+ )
+
+ # Process target modules
+ if self.args.lora_target_modules == "all-linear":
+ # For decoder-only models, common target modules are linear layers
+ target_modules = [
+ "q_proj", "v_proj", "k_proj", "o_proj",
+ "gate_proj", "up_proj", "down_proj",
+                "lm_head"  # harmless no-op for backbones without an LM head (e.g. plain AutoModel)
+ ]
+ else:
+ target_modules = [module.strip() for module in self.args.lora_target_modules.split(",")]
+
+ lora_config = LoraConfig(
+ task_type=TaskType.FEATURE_EXTRACTION, # Feature extraction for embedding models
+ r=self.args.lora_r,
+ lora_alpha=self.args.lora_alpha,
+ target_modules=target_modules,
+ lora_dropout=self.args.lora_dropout,
+ bias="none",
+ modules_to_save=[], # We don't need to save any additional modules
+ )
+
+ self.lm = get_peft_model(self.lm, lora_config)
+ print(f"LoRA applied with config: r={self.args.lora_r}, alpha={self.args.lora_alpha}, dropout={self.args.lora_dropout}")
+ print(f"Trainable parameters after LoRA: {self.lm.num_parameters(only_trainable=True)}")
+ print(f"Total parameters: {self.lm.num_parameters()}")
+
def set_device(self):
self.device = self.lm.device
diff --git a/F2LLM/requirements.txt b/F2LLM/requirements.txt
index 82fb447..d5deb83 100644
--- a/F2LLM/requirements.txt
+++ b/F2LLM/requirements.txt
@@ -5,3 +5,4 @@ flash-attn
torch
transformers
tensorboard
+peft
diff --git a/F2LLM/run.py b/F2LLM/run.py
index e40b707..aea8b59 100644
--- a/F2LLM/run.py
+++ b/F2LLM/run.py
@@ -124,10 +124,20 @@ def __iter__(self):
# set seed again to make sure that different models share the same seed
set_seed(0)
-optimizer = AdamW(model.lm.parameters(),
- weight_decay=args.weight_decay,
- lr=args.learning_rate,
- betas=(0.9, 0.98))
+# Determine which parameters the optimizer should update based on LoRA usage
+if args.use_lora:
+    # Only the LoRA adapter parameters require gradients; the frozen base weights are skipped
+    optimizer = AdamW([p for p in model.lm.parameters() if p.requires_grad],
+                      weight_decay=args.weight_decay,
+                      lr=args.learning_rate,
+                      betas=(0.9, 0.98))
+    print(f"Using LoRA - optimizing {model.lm.num_parameters(only_trainable=True)} trainable parameters out of {model.lm.num_parameters()}")
+else:
+    # Optimize all model parameters
+    optimizer = AdamW(model.lm.parameters(),
+                      weight_decay=args.weight_decay,
+                      lr=args.learning_rate,
+                      betas=(0.9, 0.98))
lr_scheduler = get_scheduler("cosine",
optimizer=optimizer,
diff --git a/F2LLM/test_lora.py b/F2LLM/test_lora.py
new file mode 100644
index 0000000..6155063
--- /dev/null
+++ b/F2LLM/test_lora.py
@@ -0,0 +1,118 @@
+"""
+test to verify LoRA functionality in F2LLM
+"""
+import torch
+from arguments import Args
+from model import F2LLM
+import tempfile
+import os
+
+def test_lora_functionality():
+ """Test that LoRA can be applied to the model correctly."""
+
+ # Create a mock args object with LoRA enabled
+ args = Args(
+        model_path="models/qwen3-0.6b",  # same small model as the example config; its linear layers match the default target list
+ experiment_id="test_lora",
+ output_dir="test_output",
+ tb_dir="test_tb",
+ cache_dir="test_cache",
+ train_data_path="dummy_path",
+ use_lora=True,
+ lora_r=8,
+ lora_alpha=16,
+ lora_dropout=0.05,
+ lora_target_modules="all-linear"
+ )
+
+ try:
+ print("Testing LoRA functionality...")
+
+ # Create model with LoRA
+ model = F2LLM(
+ model_path=args.model_path,
+ max_seq_length=512,
+ args=args
+ )
+
+ # Check that model has LoRA applied
+ total_params = model.lm.num_parameters()
+ trainable_params = model.lm.num_parameters(only_trainable=True)
+
+ print(f"Total parameters: {total_params}")
+ print(f"Trainable parameters: {trainable_params}")
+ print(f"Percentage of trainable parameters: {trainable_params/total_params*100:.2f}%")
+
+ # With LoRA, we expect significantly fewer trainable parameters
+ assert trainable_params < total_params * 0.1, \
+ f"Expected fewer trainable parameters with LoRA. Total: {total_params}, Trainable: {trainable_params}"
+
+ print("LoRA functionality test passed!")
+ return True
+
+ except ImportError as e:
+ print(f"PEFT library not available: {e}")
+ print("Please install PEFT: pip install peft")
+ return False
+ except Exception as e:
+ print(f"Error during LoRA test: {e}")
+ return False
+
+
+def test_non_lora_functionality():
+ """Test that the model still works without LoRA."""
+
+ # Create a mock args object with LoRA disabled
+ args = Args(
+        model_path="models/qwen3-0.6b",  # same small model as the example config
+ experiment_id="test_no_lora",
+ output_dir="test_output",
+ tb_dir="test_tb",
+ cache_dir="test_cache",
+ train_data_path="dummy_path",
+ use_lora=False
+ )
+
+ try:
+ print("Testing non-LoRA functionality...")
+
+ # Create model without LoRA
+ model = F2LLM(
+ model_path=args.model_path,
+ max_seq_length=512,
+ args=args
+ )
+
+ # Check that model parameters are as expected (all trainable)
+ total_params = model.lm.num_parameters()
+ trainable_params = model.lm.num_parameters(only_trainable=True)
+
+ print(f"Total parameters: {total_params}")
+ print(f"Trainable parameters: {trainable_params}")
+
+        # Without LoRA, all parameters should be trainable
+        assert trainable_params == total_params, \
+            f"Expected all parameters to be trainable without LoRA. Total: {total_params}, Trainable: {trainable_params}"
+
+ print("Non-LoRA functionality test passed!")
+ return True
+
+ except Exception as e:
+ print(f"Error during non-LoRA test: {e}")
+ return False
+
+
+if __name__ == "__main__":
+ print("Running LoRA functionality tests...")
+
+ # Test LoRA functionality
+ lora_test_passed = test_lora_functionality()
+
+ # Test non-LoRA functionality
+ no_lora_test_passed = test_non_lora_functionality()
+
+ if lora_test_passed and no_lora_test_passed:
+ print("\nAll tests passed!")
+ else:
+ print("\nSome tests failed!")
+ exit(1)
diff --git a/F2LLM/utils.py b/F2LLM/utils.py
index b167d3c..a839626 100644
--- a/F2LLM/utils.py
+++ b/F2LLM/utils.py
@@ -21,13 +21,31 @@ def save_checkpoint(args, accelerator, model, output_dir, lr_scheduler):
if accelerator.is_main_process:
model.tokenizer.save_pretrained(output_dir)
+
unwrapped_model = accelerator.unwrap_model(model.lm)
- unwrapped_model.save_pretrained(
- output_dir,
- is_main_process=accelerator.is_main_process,
- save_function=accelerator.save,
- state_dict=accelerator.get_state_dict(model.lm), # this is required for zero 3
- )
+
+ # Handle LoRA-specific saving
+ if args.use_lora:
+        # For a PEFT-wrapped model, save_pretrained writes only the adapter weights and adapter config
+ unwrapped_model.save_pretrained(
+ output_dir,
+ is_main_process=accelerator.is_main_process,
+ save_function=accelerator.save,
+ state_dict=accelerator.get_state_dict(model.lm), # this is required for zero 3
+ )
+        # Also save the base model config so the adapter checkpoint records its backbone settings
+ if accelerator.is_main_process:
+ from transformers import AutoConfig
+ config = AutoConfig.from_pretrained(args.model_path)
+ config.save_pretrained(output_dir)
+ else:
+ unwrapped_model.save_pretrained(
+ output_dir,
+ is_main_process=accelerator.is_main_process,
+ save_function=accelerator.save,
+ state_dict=accelerator.get_state_dict(model.lm), # this is required for zero 3
+ )
+
accelerator.wait_for_everyone()