|
1 | 1 | #!/usr/bin/env python |
2 | 2 | # -*- coding: UTF-8 -*- |
3 | 3 | """ |
4 | | -@Project :MaxKB |
| 4 | +@Project :MaxKB |
5 | 5 | @File :gemini_model_provider.py |
6 | 6 | @Author :Brian Yang |
7 | | -@Date :5/13/24 7:47 AM |
| 7 | +@Date :5/13/24 7:47 AM |
8 | 8 | """ |
| 9 | + |
9 | 10 | import os |
10 | 11 |
|
11 | 12 | from common.utils.common import get_file_content |
12 | | -from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ |
13 | | - ModelInfoManage |
| 13 | +from models_provider.base_model_provider import ( |
| 14 | + IModelProvider, |
| 15 | + ModelProvideInfo, |
| 16 | + ModelInfo, |
| 17 | + ModelTypeConst, |
| 18 | + ModelInfoManage, |
| 19 | +) |
14 | 20 | from models_provider.impl.gemini_model_provider.credential.embedding import GeminiEmbeddingCredential |
15 | 21 | from models_provider.impl.gemini_model_provider.credential.image import GeminiImageModelCredential |
| 22 | +from models_provider.impl.gemini_model_provider.credential.itv import ImageToVideoModelCredential |
16 | 23 | from models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential |
17 | 24 | from models_provider.impl.gemini_model_provider.credential.stt import GeminiSTTModelCredential |
18 | 25 | from models_provider.impl.gemini_model_provider.credential.tti import GeminiTextToImageModelCredential |
| 26 | +from models_provider.impl.gemini_model_provider.credential.ttv import TextToVideoModelCredential |
19 | 27 | from models_provider.impl.gemini_model_provider.model.embedding import GeminiEmbeddingModel |
20 | 28 | from models_provider.impl.gemini_model_provider.model.image import GeminiImage |
21 | 29 | from models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel |
|
24 | 32 | from django.utils.translation import gettext as _ |
25 | 33 |
|
26 | 34 | from models_provider.impl.gemini_model_provider.model.tti import GeminiTextToImage |
| 35 | +from models_provider.impl.gemini_model_provider.model.ttv import GenerationVideoModel |
27 | 36 |
|
28 | 37 | gemini_llm_model_credential = GeminiLLMModelCredential() |
29 | 38 | gemini_image_model_credential = GeminiImageModelCredential() |
30 | 39 | gemini_stt_model_credential = GeminiSTTModelCredential() |
31 | 40 | gemini_embedding_model_credential = GeminiEmbeddingCredential() |
32 | 41 | gemini_tti_model_credential = GeminiTextToImageModelCredential() |
| 42 | +gemini_itv_model_credential = ImageToVideoModelCredential() |
| 43 | +gemini_ttv_model_credential = TextToVideoModelCredential() |
33 | 44 |
|
34 | 45 | model_info_list = [ |
35 | | - ModelInfo('gemini-1.0-pro', _('Latest Gemini 1.0 Pro model, updated with Google update'), |
36 | | - ModelTypeConst.LLM, |
37 | | - gemini_llm_model_credential, |
38 | | - GeminiChatModel), |
39 | | - ModelInfo('gemini-1.0-pro-vision', _('Latest Gemini 1.0 Pro Vision model, updated with Google update'), |
40 | | - ModelTypeConst.LLM, |
41 | | - gemini_llm_model_credential, |
42 | | - GeminiChatModel), |
| 46 | + ModelInfo( |
| 47 | + "gemini-1.0-pro", |
| 48 | + _("Latest Gemini 1.0 Pro model, updated with Google update"), |
| 49 | + ModelTypeConst.LLM, |
| 50 | + gemini_llm_model_credential, |
| 51 | + GeminiChatModel, |
| 52 | + ), |
| 53 | + ModelInfo( |
| 54 | + "gemini-1.0-pro-vision", |
| 55 | + _("Latest Gemini 1.0 Pro Vision model, updated with Google update"), |
| 56 | + ModelTypeConst.LLM, |
| 57 | + gemini_llm_model_credential, |
| 58 | + GeminiChatModel, |
| 59 | + ), |
43 | 60 | ] |
44 | 61 |
|
45 | 62 | model_image_info_list = [ |
46 | | - ModelInfo('gemini-1.5-flash', _('Latest Gemini 1.5 Flash model, updated with Google updates'), |
47 | | - ModelTypeConst.IMAGE, |
48 | | - gemini_image_model_credential, |
49 | | - GeminiImage), |
50 | | - ModelInfo('gemini-1.5-pro', _('Latest Gemini 1.5 Flash model, updated with Google updates'), |
51 | | - ModelTypeConst.IMAGE, |
52 | | - gemini_image_model_credential, |
53 | | - GeminiImage), |
| 63 | + ModelInfo( |
| 64 | + "gemini-1.5-flash", |
| 65 | + _("Latest Gemini 1.5 Flash model, updated with Google updates"), |
| 66 | + ModelTypeConst.IMAGE, |
| 67 | + gemini_image_model_credential, |
| 68 | + GeminiImage, |
| 69 | + ), |
| 70 | + ModelInfo( |
| 71 | + "gemini-1.5-pro", |
| 72 | + _("Latest Gemini 1.5 Flash model, updated with Google updates"), |
| 73 | + ModelTypeConst.IMAGE, |
| 74 | + gemini_image_model_credential, |
| 75 | + GeminiImage, |
| 76 | + ), |
54 | 77 | ] |
55 | 78 |
|
56 | 79 | model_stt_info_list = [ |
57 | | - ModelInfo('gemini-1.5-flash', _('Latest Gemini 1.5 Flash model, updated with Google updates'), |
58 | | - ModelTypeConst.STT, |
59 | | - gemini_stt_model_credential, |
60 | | - GeminiSpeechToText), |
61 | | - ModelInfo('gemini-1.5-pro', _('Latest Gemini 1.5 Flash model, updated with Google updates'), |
62 | | - ModelTypeConst.STT, |
63 | | - gemini_stt_model_credential, |
64 | | - GeminiSpeechToText), |
| 80 | + ModelInfo( |
| 81 | + "gemini-1.5-flash", |
| 82 | + _("Latest Gemini 1.5 Flash model, updated with Google updates"), |
| 83 | + ModelTypeConst.STT, |
| 84 | + gemini_stt_model_credential, |
| 85 | + GeminiSpeechToText, |
| 86 | + ), |
| 87 | + ModelInfo( |
| 88 | + "gemini-1.5-pro", |
| 89 | + _("Latest Gemini 1.5 Flash model, updated with Google updates"), |
| 90 | + ModelTypeConst.STT, |
| 91 | + gemini_stt_model_credential, |
| 92 | + GeminiSpeechToText, |
| 93 | + ), |
65 | 94 | ] |
66 | 95 |
|
67 | 96 | model_embedding_info_list = [ |
68 | | - ModelInfo('models/embedding-001', '', |
69 | | - ModelTypeConst.EMBEDDING, |
70 | | - gemini_embedding_model_credential, |
71 | | - GeminiEmbeddingModel), |
72 | | - ModelInfo('models/text-embedding-004', '', |
73 | | - ModelTypeConst.EMBEDDING, |
74 | | - gemini_embedding_model_credential, |
75 | | - GeminiEmbeddingModel), |
| 97 | + ModelInfo( |
| 98 | + "models/embedding-001", "", ModelTypeConst.EMBEDDING, gemini_embedding_model_credential, GeminiEmbeddingModel |
| 99 | + ), |
| 100 | + ModelInfo( |
| 101 | + "models/text-embedding-004", |
| 102 | + "", |
| 103 | + ModelTypeConst.EMBEDDING, |
| 104 | + gemini_embedding_model_credential, |
| 105 | + GeminiEmbeddingModel, |
| 106 | + ), |
76 | 107 | ] |
77 | 108 |
|
78 | 109 | model_tti_info_list = [ |
79 | | - ModelInfo('gemini-3.1-flash-image-preview', "", |
80 | | - ModelTypeConst.TTI, |
81 | | - gemini_tti_model_credential, |
82 | | - GeminiTextToImage) |
| 110 | + ModelInfo("gemini-3.1-flash-image-preview", "", ModelTypeConst.TTI, gemini_tti_model_credential, GeminiTextToImage) |
| 111 | +] |
| 112 | + |
| 113 | +ttv_model_info_list = [ |
| 114 | + ModelInfo("veo-3.1-generate-preview", "", ModelTypeConst.TTV, gemini_ttv_model_credential, GenerationVideoModel) |
83 | 115 | ] |
84 | 116 |
|
| 117 | +itv_model_info_list = [ |
| 118 | + ModelInfo("veo-3.1-generate-preview", "", ModelTypeConst.ITV, gemini_itv_model_credential, GenerationVideoModel) |
| 119 | +] |
| 120 | + |
| 121 | + |
85 | 122 | model_info_manage = ( |
86 | 123 | ModelInfoManage.builder() |
87 | 124 | .append_model_info_list(model_info_list) |
|
94 | 131 | .append_default_model_info(model_stt_info_list[0]) |
95 | 132 | .append_default_model_info(model_embedding_info_list[0]) |
96 | 133 | .append_default_model_info(model_tti_info_list[0]) |
| 134 | + .append_model_info_list(ttv_model_info_list) |
| 135 | + .append_default_model_info(ttv_model_info_list[0]) |
| 136 | + .append_model_info_list(itv_model_info_list) |
| 137 | + .append_default_model_info(itv_model_info_list[0]) |
97 | 138 | .build() |
98 | 139 | ) |
99 | 140 |
|
100 | 141 |
|
101 | 142 | class GeminiModelProvider(IModelProvider): |
102 | | - |
103 | 143 | def get_model_info_manage(self): |
104 | 144 | return model_info_manage |
105 | 145 |
|
106 | 146 | def get_model_provide_info(self): |
107 | | - return ModelProvideInfo(provider='model_gemini_provider', name='Gemini', icon=get_file_content( |
108 | | - os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'gemini_model_provider', 'icon', |
109 | | - 'gemini_icon_svg'))) |
| 147 | + return ModelProvideInfo( |
| 148 | + provider="model_gemini_provider", |
| 149 | + name="Gemini", |
| 150 | + icon=get_file_content( |
| 151 | + os.path.join( |
| 152 | + PROJECT_DIR, "apps", "models_provider", "impl", "gemini_model_provider", "icon", "gemini_icon_svg" |
| 153 | + ) |
| 154 | + ), |
| 155 | + ) |
0 commit comments