Skip to content

Commit 3436457

Browse files
committed
fix: vision embedding
1 parent c5d5f13 commit 3436457

3 files changed

Lines changed: 119 additions & 1 deletion

File tree

veadk/configs/model_configs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
DEFAULT_MODEL_AGENT_API_BASE,
2424
DEFAULT_MODEL_AGENT_NAME,
2525
DEFAULT_MODEL_AGENT_PROVIDER,
26+
DEFAULT_MODEL_EMBEDDING_NAME,
2627
)
2728

2829

@@ -46,7 +47,7 @@ def api_key(self) -> str:
4647
class EmbeddingModelConfig(BaseSettings):
4748
model_config = SettingsConfigDict(env_prefix="MODEL_EMBEDDING_")
4849

49-
name: str = "doubao-embedding-text-240715"
50+
name: str = DEFAULT_MODEL_EMBEDDING_NAME
5051
"""Model name for embedding."""
5152

5253
dim: int = 2560

veadk/consts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,5 @@
7575

7676
DEFAULT_NACOS_GROUP = "VEADK_GROUP"
7777
DEFAULT_NACOS_INSTANCE_NAME = "veadk"
78+
79+
DEFAULT_MODEL_EMBEDDING_NAME = "doubao-embedding-vision-250615"

veadk/models/ark_embedding.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from typing import Any, Dict, Optional, List
2+
3+
import httpx
4+
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
5+
from llama_index.core.callbacks.base import CallbackManager
6+
# from llama_index.embeddings.openai import OpenAIEmbedding
7+
8+
9+
class ArkEmbedding(BaseEmbedding):
    """
    Embedding model wrapper for the Ark embedding service.

    NOTE: the embedding methods of this class are not implemented yet; each
    of them raises ``NotImplementedError`` instead of silently returning
    ``None``, so a missing implementation is reported at the call site.

    Args:
        model_name (str):
            Model for embedding.
        embed_batch_size (int):
            Batch size forwarded to the base embedding class.
        dimensions (int):
            The number of dimensions for the embedding API.
        additional_kwargs (Dict[str, Any]):
            Additional kwargs for the embedding API.
        api_key (str):
            The API key (if any) to use for the embedding API.
        api_base (str):
            The base URL for the embedding API.
        api_version (str):
            The version for the embedding API.
        max_retries (int):
            The maximum number of retries for the embedding API.
        timeout (float):
            The timeout for the embedding API.
        reuse_client (bool):
            Whether to reuse the client for the embedding API.
        callback_manager (CallbackManager):
            The callback manager for the embedding API.
        default_headers (Dict[str, str]):
            The default headers for the embedding API.
        http_client (httpx.Client):
            Optional synchronous HTTP client.
        async_http_client (httpx.AsyncClient):
            Optional asynchronous HTTP client.
        num_workers (int):
            Optional worker count forwarded to the base embedding class.

    Example:
        ```python
        from veadk.models.ark_embedding import ArkEmbedding

        embedding = ArkEmbedding(
            model_name="my-model-name",
            api_base="http://localhost:1234/v1",
            api_key="fake",
            embed_batch_size=10,
        )
        ```

    """

    def __init__(
        self,
        model_name: str,
        embed_batch_size: int = 10,
        dimensions: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: str = "fake",
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        max_retries: int = 10,
        timeout: float = 60.0,
        reuse_client: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        default_headers: Optional[Dict[str, str]] = None,
        http_client: Optional[httpx.Client] = None,
        async_http_client: Optional[httpx.AsyncClient] = None,
        num_workers: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Forward all configuration to the parent embedding class.

        Raises:
            ValueError: If ``model`` is supplied in ``kwargs``; callers must
                use ``model_name`` instead.
        """
        # ensure model is not passed in kwargs, will cause error in parent class
        if "model" in kwargs:
            raise ValueError(
                "Use `model_name` instead of `model` to initialize ArkEmbedding"
            )

        # NOTE(review): this class declares no fields of its own, so whether
        # api_key / api_base / timeout / etc. are retained depends entirely on
        # what BaseEmbedding declares -- confirm before relying on them.
        super().__init__(
            model_name=model_name,
            embed_batch_size=embed_batch_size,
            dimensions=dimensions,
            callback_manager=callback_manager,
            additional_kwargs=additional_kwargs,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            max_retries=max_retries,
            reuse_client=reuse_client,
            timeout=timeout,
            default_headers=default_headers,
            http_client=http_client,
            async_http_client=async_http_client,
            num_workers=num_workers,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier of this embedding class."""
        return "ArkEmbedding"

    def _get_query_embedding(self, query: str) -> Embedding:
        """Embed a single query string (not implemented yet)."""
        raise NotImplementedError(
            "ArkEmbedding._get_query_embedding is not implemented"
        )

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Asynchronously embed a single query string (not implemented yet)."""
        raise NotImplementedError(
            "ArkEmbedding._aget_query_embedding is not implemented"
        )

    def _get_text_embedding(self, text: str) -> Embedding:
        """Embed a single text string (not implemented yet)."""
        raise NotImplementedError(
            "ArkEmbedding._get_text_embedding is not implemented"
        )

    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Embed a batch of text strings (not implemented yet)."""
        raise NotImplementedError(
            "ArkEmbedding._get_text_embeddings is not implemented"
        )

0 commit comments

Comments
 (0)