Skip to content

Commit 08affe2

Browse files
committed
feat: add ImageToVideo and TextToVideo model credential classes with validation and encryption
1 parent 81f89da commit 08affe2

4 files changed

Lines changed: 412 additions & 44 deletions

File tree

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# coding=utf-8
2+
3+
from typing import Dict, Any
4+
5+
from django.utils.translation import gettext_lazy as _, gettext
6+
7+
from common import forms
8+
from common.exception.app_exception import AppApiException
9+
from common.forms import BaseForm, PasswordInputField
10+
from models_provider.base_model_provider import BaseModelCredential, ValidCode
11+
from common.utils.logger import maxkb_logger
12+
13+
14+
15+
class ImageToVideoModelCredential(BaseForm, BaseModelCredential):
16+
"""
17+
Credential class for the Qwen Image-to-Video model.
18+
Provides validation and encryption for the model credentials.
19+
"""
20+
21+
base_url = forms.TextInputField(_("Base Url"), required=True, default_value="https://generativelanguage.googleapis.com")
22+
api_key = PasswordInputField("API Key", required=True)
23+
24+
def is_valid(
25+
self,
26+
model_type: str,
27+
model_name: str,
28+
model_credential: Dict[str, Any],
29+
model_params: Dict[str, Any],
30+
provider,
31+
raise_exception: bool = False,
32+
) -> bool:
33+
"""
34+
Validate the model credentials.
35+
36+
:param model_type: Type of the model (e.g., 'TEXT_TO_Video').
37+
:param model_name: Name of the model.
38+
:param model_credential: Dictionary containing the model credentials.
39+
:param model_params: Parameters for the model.
40+
:param provider: Model provider instance.
41+
:param raise_exception: Whether to raise an exception on validation failure.
42+
:return: Boolean indicating whether the credentials are valid.
43+
"""
44+
model_type_list = provider.get_model_type_list()
45+
if not any(mt.get("value") == model_type for mt in model_type_list):
46+
raise AppApiException(
47+
ValidCode.valid_error.value,
48+
gettext("{model_type} Model type is not supported").format(model_type=model_type),
49+
)
50+
51+
required_keys = ["api_key", "base_url"]
52+
for key in required_keys:
53+
if key not in model_credential:
54+
if raise_exception:
55+
raise AppApiException(ValidCode.valid_error.value, gettext("{key} is required").format(key=key))
56+
return False
57+
58+
try:
59+
model = provider.get_model(model_type, model_name, model_credential, **model_params)
60+
res = model.check_auth()
61+
except Exception as e:
62+
maxkb_logger.error(f"Exception: {e}", exc_info=True)
63+
if isinstance(e, AppApiException):
64+
raise e
65+
if raise_exception:
66+
raise AppApiException(
67+
ValidCode.valid_error.value,
68+
gettext("Verification failed, please check whether the parameters are correct: {error}").format(
69+
error=str(e)
70+
),
71+
)
72+
return False
73+
74+
return True
75+
76+
def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
77+
"""
78+
Encrypt sensitive fields in the model dictionary.
79+
80+
:param model: Dictionary containing model details.
81+
:return: Dictionary with encrypted sensitive fields.
82+
"""
83+
return {**model, "api_key": super().encryption(model.get("api_key", ""))}
84+
85+
def get_model_params_setting_form(self, model_name: str):
86+
"""
87+
Get the parameter setting form for the specified model.
88+
89+
:param model_name: Name of the model.
90+
:return: Parameter setting form.
91+
"""
92+
pass
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# coding=utf-8
2+
3+
from typing import Dict, Any
4+
5+
from django.utils.translation import gettext_lazy as _, gettext
6+
7+
from common import forms
8+
from common.exception.app_exception import AppApiException
9+
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
10+
from common.forms.switch_field import SwitchField
11+
from models_provider.base_model_provider import BaseModelCredential, ValidCode
12+
from common.utils.logger import maxkb_logger
13+
14+
15+
16+
class TextToVideoModelCredential(BaseForm, BaseModelCredential):
17+
"""
18+
Credential class for the Qwen Text-to-Video model.
19+
Provides validation and encryption for the model credentials.
20+
"""
21+
base_url = forms.TextInputField(_("Base Url"), required=True, default_value="https://generativelanguage.googleapis.com")
22+
api_key = PasswordInputField('API Key', required=True)
23+
24+
def is_valid(
25+
self,
26+
model_type: str,
27+
model_name: str,
28+
model_credential: Dict[str, Any],
29+
model_params: Dict[str, Any],
30+
provider,
31+
raise_exception: bool = False
32+
) -> bool:
33+
"""
34+
Validate the model credentials.
35+
36+
:param model_type: Type of the model (e.g., 'TEXT_TO_Video').
37+
:param model_name: Name of the model.
38+
:param model_credential: Dictionary containing the model credentials.
39+
:param model_params: Parameters for the model.
40+
:param provider: Model provider instance.
41+
:param raise_exception: Whether to raise an exception on validation failure.
42+
:return: Boolean indicating whether the credentials are valid.
43+
"""
44+
model_type_list = provider.get_model_type_list()
45+
if not any(mt.get('value') == model_type for mt in model_type_list):
46+
raise AppApiException(
47+
ValidCode.valid_error.value,
48+
gettext('{model_type} Model type is not supported').format(model_type=model_type)
49+
)
50+
51+
required_keys = ['api_key', 'base_url']
52+
for key in required_keys:
53+
if key not in model_credential:
54+
if raise_exception:
55+
raise AppApiException(
56+
ValidCode.valid_error.value,
57+
gettext('{key} is required').format(key=key)
58+
)
59+
return False
60+
61+
try:
62+
model = provider.get_model(model_type, model_name, model_credential, **model_params)
63+
res = model.check_auth()
64+
except Exception as e:
65+
maxkb_logger.error(f'Exception: {e}', exc_info=True)
66+
if isinstance(e, AppApiException):
67+
raise e
68+
if raise_exception:
69+
raise AppApiException(
70+
ValidCode.valid_error.value,
71+
gettext(
72+
'Verification failed, please check whether the parameters are correct: {error}'
73+
).format(error=str(e))
74+
)
75+
return False
76+
77+
return True
78+
79+
def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
80+
"""
81+
Encrypt sensitive fields in the model dictionary.
82+
83+
:param model: Dictionary containing model details.
84+
:return: Dictionary with encrypted sensitive fields.
85+
"""
86+
return {
87+
**model,
88+
'api_key': super().encryption(model.get('api_key', ''))
89+
}
90+
91+
def get_model_params_setting_form(self, model_name: str):
92+
"""
93+
Get the parameter setting form for the specified model.
94+
95+
:param model_name: Name of the model.
96+
:return: Parameter setting form.
97+
"""
98+
pass
Lines changed: 90 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
11
#!/usr/bin/env python
22
# -*- coding: UTF-8 -*-
33
"""
4-
@Project :MaxKB
4+
@Project :MaxKB
55
@File :gemini_model_provider.py
66
@Author :Brian Yang
7-
@Date :5/13/24 7:47 AM
7+
@Date :5/13/24 7:47 AM
88
"""
9+
910
import os
1011

1112
from common.utils.common import get_file_content
12-
from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
13-
ModelInfoManage
13+
from models_provider.base_model_provider import (
14+
IModelProvider,
15+
ModelProvideInfo,
16+
ModelInfo,
17+
ModelTypeConst,
18+
ModelInfoManage,
19+
)
1420
from models_provider.impl.gemini_model_provider.credential.embedding import GeminiEmbeddingCredential
1521
from models_provider.impl.gemini_model_provider.credential.image import GeminiImageModelCredential
22+
from models_provider.impl.gemini_model_provider.credential.itv import ImageToVideoModelCredential
1623
from models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential
1724
from models_provider.impl.gemini_model_provider.credential.stt import GeminiSTTModelCredential
1825
from models_provider.impl.gemini_model_provider.credential.tti import GeminiTextToImageModelCredential
26+
from models_provider.impl.gemini_model_provider.credential.ttv import TextToVideoModelCredential
1927
from models_provider.impl.gemini_model_provider.model.embedding import GeminiEmbeddingModel
2028
from models_provider.impl.gemini_model_provider.model.image import GeminiImage
2129
from models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
@@ -24,64 +32,93 @@
2432
from django.utils.translation import gettext as _
2533

2634
from models_provider.impl.gemini_model_provider.model.tti import GeminiTextToImage
35+
from models_provider.impl.gemini_model_provider.model.ttv import GenerationVideoModel
2736

2837
gemini_llm_model_credential = GeminiLLMModelCredential()
2938
gemini_image_model_credential = GeminiImageModelCredential()
3039
gemini_stt_model_credential = GeminiSTTModelCredential()
3140
gemini_embedding_model_credential = GeminiEmbeddingCredential()
3241
gemini_tti_model_credential = GeminiTextToImageModelCredential()
42+
gemini_itv_model_credential = ImageToVideoModelCredential()
43+
gemini_ttv_model_credential = TextToVideoModelCredential()
3344

3445
model_info_list = [
35-
ModelInfo('gemini-1.0-pro', _('Latest Gemini 1.0 Pro model, updated with Google update'),
36-
ModelTypeConst.LLM,
37-
gemini_llm_model_credential,
38-
GeminiChatModel),
39-
ModelInfo('gemini-1.0-pro-vision', _('Latest Gemini 1.0 Pro Vision model, updated with Google update'),
40-
ModelTypeConst.LLM,
41-
gemini_llm_model_credential,
42-
GeminiChatModel),
46+
ModelInfo(
47+
"gemini-1.0-pro",
48+
_("Latest Gemini 1.0 Pro model, updated with Google update"),
49+
ModelTypeConst.LLM,
50+
gemini_llm_model_credential,
51+
GeminiChatModel,
52+
),
53+
ModelInfo(
54+
"gemini-1.0-pro-vision",
55+
_("Latest Gemini 1.0 Pro Vision model, updated with Google update"),
56+
ModelTypeConst.LLM,
57+
gemini_llm_model_credential,
58+
GeminiChatModel,
59+
),
4360
]
4461

4562
model_image_info_list = [
46-
ModelInfo('gemini-1.5-flash', _('Latest Gemini 1.5 Flash model, updated with Google updates'),
47-
ModelTypeConst.IMAGE,
48-
gemini_image_model_credential,
49-
GeminiImage),
50-
ModelInfo('gemini-1.5-pro', _('Latest Gemini 1.5 Flash model, updated with Google updates'),
51-
ModelTypeConst.IMAGE,
52-
gemini_image_model_credential,
53-
GeminiImage),
63+
ModelInfo(
64+
"gemini-1.5-flash",
65+
_("Latest Gemini 1.5 Flash model, updated with Google updates"),
66+
ModelTypeConst.IMAGE,
67+
gemini_image_model_credential,
68+
GeminiImage,
69+
),
70+
ModelInfo(
71+
"gemini-1.5-pro",
72+
_("Latest Gemini 1.5 Flash model, updated with Google updates"),
73+
ModelTypeConst.IMAGE,
74+
gemini_image_model_credential,
75+
GeminiImage,
76+
),
5477
]
5578

5679
model_stt_info_list = [
57-
ModelInfo('gemini-1.5-flash', _('Latest Gemini 1.5 Flash model, updated with Google updates'),
58-
ModelTypeConst.STT,
59-
gemini_stt_model_credential,
60-
GeminiSpeechToText),
61-
ModelInfo('gemini-1.5-pro', _('Latest Gemini 1.5 Flash model, updated with Google updates'),
62-
ModelTypeConst.STT,
63-
gemini_stt_model_credential,
64-
GeminiSpeechToText),
80+
ModelInfo(
81+
"gemini-1.5-flash",
82+
_("Latest Gemini 1.5 Flash model, updated with Google updates"),
83+
ModelTypeConst.STT,
84+
gemini_stt_model_credential,
85+
GeminiSpeechToText,
86+
),
87+
ModelInfo(
88+
"gemini-1.5-pro",
89+
_("Latest Gemini 1.5 Flash model, updated with Google updates"),
90+
ModelTypeConst.STT,
91+
gemini_stt_model_credential,
92+
GeminiSpeechToText,
93+
),
6594
]
6695

6796
model_embedding_info_list = [
68-
ModelInfo('models/embedding-001', '',
69-
ModelTypeConst.EMBEDDING,
70-
gemini_embedding_model_credential,
71-
GeminiEmbeddingModel),
72-
ModelInfo('models/text-embedding-004', '',
73-
ModelTypeConst.EMBEDDING,
74-
gemini_embedding_model_credential,
75-
GeminiEmbeddingModel),
97+
ModelInfo(
98+
"models/embedding-001", "", ModelTypeConst.EMBEDDING, gemini_embedding_model_credential, GeminiEmbeddingModel
99+
),
100+
ModelInfo(
101+
"models/text-embedding-004",
102+
"",
103+
ModelTypeConst.EMBEDDING,
104+
gemini_embedding_model_credential,
105+
GeminiEmbeddingModel,
106+
),
76107
]
77108

78109
model_tti_info_list = [
79-
ModelInfo('gemini-3.1-flash-image-preview', "",
80-
ModelTypeConst.TTI,
81-
gemini_tti_model_credential,
82-
GeminiTextToImage)
110+
ModelInfo("gemini-3.1-flash-image-preview", "", ModelTypeConst.TTI, gemini_tti_model_credential, GeminiTextToImage)
111+
]
112+
113+
ttv_model_info_list = [
114+
ModelInfo("veo-3.1-generate-preview", "", ModelTypeConst.TTV, gemini_ttv_model_credential, GenerationVideoModel)
83115
]
84116

117+
itv_model_info_list = [
118+
ModelInfo("veo-3.1-generate-preview", "", ModelTypeConst.ITV, gemini_itv_model_credential, GenerationVideoModel)
119+
]
120+
121+
85122
model_info_manage = (
86123
ModelInfoManage.builder()
87124
.append_model_info_list(model_info_list)
@@ -94,16 +131,25 @@
94131
.append_default_model_info(model_stt_info_list[0])
95132
.append_default_model_info(model_embedding_info_list[0])
96133
.append_default_model_info(model_tti_info_list[0])
134+
.append_model_info_list(ttv_model_info_list)
135+
.append_default_model_info(ttv_model_info_list[0])
136+
.append_model_info_list(itv_model_info_list)
137+
.append_default_model_info(itv_model_info_list[0])
97138
.build()
98139
)
99140

100141

101142
class GeminiModelProvider(IModelProvider):
102-
103143
def get_model_info_manage(self):
104144
return model_info_manage
105145

106146
def get_model_provide_info(self):
107-
return ModelProvideInfo(provider='model_gemini_provider', name='Gemini', icon=get_file_content(
108-
os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'gemini_model_provider', 'icon',
109-
'gemini_icon_svg')))
147+
return ModelProvideInfo(
148+
provider="model_gemini_provider",
149+
name="Gemini",
150+
icon=get_file_content(
151+
os.path.join(
152+
PROJECT_DIR, "apps", "models_provider", "impl", "gemini_model_provider", "icon", "gemini_icon_svg"
153+
)
154+
),
155+
)

0 commit comments

Comments
 (0)