-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathmodel_factory.py
More file actions
73 lines (61 loc) · 2.47 KB
/
model_factory.py
File metadata and controls
73 lines (61 loc) · 2.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from enum import Enum
from typing import Union, Dict, Type
from model_helper.llama_2_helper import Llama7BHelper
from model_helper.llama_3_1_helper import Llama3_1_8BHelper
from model_helper.qwen_helper import QwenHelper
class ModelType(Enum):
    """Enumeration of supported model types.

    Each member's value is the lowercase string identifier for that model.
    """

    LLAMA_7B = "llama_7b"
    LLAMA_3_1_8B = "llama_3_1_8b"
    QWEN_7B = "qwen_7b"

    @classmethod
    def from_string(cls, model_name: str) -> 'ModelType':
        """Create a ModelType enum value from its string identifier.

        Matching is case-insensitive and ignores surrounding whitespace.
        A value that is already a ModelType member is returned unchanged.

        Args:
            model_name: Model identifier, e.g. "llama_7b".

        Returns:
            The ModelType member whose value matches model_name.

        Raises:
            ValueError: If model_name is not a string or does not match
                the value of any member.
        """
        if isinstance(model_name, cls):
            return model_name
        if not isinstance(model_name, str):
            # Non-string input (e.g. None) has no .lower(); report it with the
            # same ValueError as any other unknown name.
            raise ValueError(f"Unsupported model type: {model_name}")
        try:
            return cls(model_name.strip().lower())
        except ValueError:
            # The Enum lookup error adds nothing beyond this message, so
            # suppress it from the traceback.
            raise ValueError(f"Unsupported model type: {model_name}") from None
class ModelFactory:
    """Factory class for creating and managing different types of model instances.

    Helper classes are looked up in a class-level registry keyed by model
    type; further classes can be added with register_model().
    """

    # Maps each model type to the helper class instantiated for it.
    # This dict is class-level state: register_model() mutates it in place.
    _model_registry: Dict[ModelType, Type] = {
        ModelType.LLAMA_7B: Llama7BHelper,
        ModelType.LLAMA_3_1_8B: Llama3_1_8BHelper,
        ModelType.QWEN_7B: QwenHelper,
    }

    @classmethod
    def register_model(cls, model_type: ModelType, model_class: Type) -> None:
        """Register a new model type, replacing any existing entry.

        Args:
            model_type: Model type enum value
            model_class: Model class
        """
        cls._model_registry[model_type] = model_class

    @classmethod
    def create_model(
        cls,
        model_type: Union[ModelType, str],
        use_local: bool = True,
        local_path: str = "./explanation/models_hf",
        token: Union[str, None] = None,
        **kwargs
    ) -> Union[Llama7BHelper, Llama3_1_8BHelper, QwenHelper]:
        """Create a model instance.

        Args:
            model_type: Type of model to create, either a ModelType member or
                its string identifier (as returned by get_supported_models())
            use_local: Whether to use locally cached model
            local_path: Local model path
            token: HuggingFace token (only needed when use_local=False)
            **kwargs: Additional parameters passed to model constructor

        Returns:
            Created model instance

        Raises:
            ValueError: If model type is not supported
        """
        # Check the registry first so any key registered directly keeps
        # working; only unknown strings are resolved to a ModelType.
        if model_type not in cls._model_registry and isinstance(model_type, str):
            model_type = ModelType.from_string(model_type)

        try:
            model_class = cls._model_registry[model_type]
        except KeyError:
            raise ValueError(f"Unsupported model type: {model_type}") from None

        return model_class(use_local=use_local, local_path=local_path, token=token, **kwargs)

    @classmethod
    def get_supported_models(cls) -> list[str]:
        """Get names of all supported model types.

        Returns:
            The value of every model type currently in the registry.
        """
        return [model_type.value for model_type in cls._model_registry]