feat(embeddings): KG-104 Control Embedding Dimensions & Batch Support #1296
mltheuser wants to merge 10 commits into JetBrains:develop from
Conversation
…LMEmbeddingProvider
- Add LLMCapability.Embedding.Dimensions for models supporting variable output dimensions
- Add EmbeddingParams base class in prompt-model with dimensions parameter
- Update LLMEmbeddingProvider interface:
  - Add EmbeddingParams parameter to embed() method
  - Add embedBatch() with default parallel polyfill implementation
- Apply minimal signature updates to all provider clients (OpenAI, Mistral, Ollama, Bedrock) to enable compilation (full implementation pending)

Part of KG-104 (dimension control) and KG-538 (batch embedding)
Google Provider:
- Add GoogleEmbeddingParams with taskType and title support
- Add toGoogleEmbeddingParams() extension for polymorphic conversion
- Update GoogleEmbeddingRequest DTO with outputDimensionality, taskType, title
- Update GoogleLLMClient.embed() to use EmbeddingParams with capability validation
- Add Embedding.Dimensions capability to GeminiEmbedding001 model

Unit Tests:
- Add GoogleEmbeddingParamsTest for validation and conversion
- Update LLMEmbedderTest with embedBatch tests (provider-agnostic)

Integration Tests:
- Add integration_testEmbedWithDimensions (tests 256-dim output)
- Add integration_testEmbedBatch (tests 3-text batch)
- Add dimensionCapableEmbeddingModels() stream
- Temporarily limit embeddingModels() to Google until other providers migrate

Part of KG-104 (dimension control) and KG-538 (batch embedding)
@eugenekarpenko, @kpavlov, @sdubov would be nice if one of you could do an initial review for this.
sdubov left a comment
@mltheuser thank you for the contribution! It looks like a very cool extension of the embedding models capability. I looked through the code and added some comments about different parts of the code. I would also like to make general proposals for improvements, if you do not mind:
- Would it be possible to split the updates related to "dimensions" and "batch" into two different PRs? Currently, we use the "squash and merge" policy on GitHub when merging PRs. Since these are different changes, it would be better to have them as separate commits.
- My general feeling is that we need to support the embedding capability for models independently from the "dimensions" parameter. So, if a model has the embedding capability, it should still be available. I would propose to have EmbeddingParameters common to all models with embeddings, and add dimensions to the Google parameters only for now. This will allow you to keep the list of models that support embeddings the same and extend them when a model is updated with dimensions.
implementation(kotlin("test"))
implementation(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client"))
// TODO: Re-enable after OpenAI migration
// implementation(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client"))
Would it be possible to add a placeholder for OpenAI for now and skip dimensions (or handle it somehow when a dimension capability is explicitly specified with OpenAI models) instead of disabling the embedding for OpenAI clients?

By making dimensions a provider-specific parameter, there is no need to migrate everything in one go anymore. Re-enabled.
BedrockModels.Embeddings.AmazonTitanEmbedText,
OpenAIModels.Embeddings.TextEmbedding3Large,
MistralAIModels.Embeddings.MistralEmbed,
// BedrockModels.Embeddings.AmazonTitanEmbedText,
Same comment here. It seems we might lose some base embedding tests if we just exclude these models for now. WDYT?

Re-enabled. By the way, this was never meant to be merged like this; I was intending to migrate all providers in one go and just commented them out to get early feedback on a working version.
* The conversion function [toGoogleEmbeddingParams] handles both cases transparently.
*/
@Serializable
public enum class GoogleEmbeddingTaskType(public val apiValue: String) {
Could you please extract this class into a separate file? It seems like a separate entity to me that would be easier to work with when it is located separately from the parameters. WDYT?
Also, could you please add documentation for the apiValue parameter? It would make the code clearer to understand and read.

Done. Also cleaned up the enum and its description.
parts = listOf(GooglePart.Text(text))
)
),
outputDimensionality = googleParams.dimensions,
What do you think about passing the parameters as a single argument to the request instance instead of as separate parameters? The request can process the parameters internally in that case. The current request instance does not implement any API that defines parameters.

I like that idea. I have created a factory method in GoogleEmbeddingRequest for this purpose.
override suspend fun embed(
    text: String,
    model: LLModel,
    params: ai.koog.prompt.params.EmbeddingParams
(Here and in other places above) Could you please use an import for the EmbeddingParams class symbol instead of a FQN here? It helps make the code shorter and easier to read:
import ai.koog.prompt.params.EmbeddingParams
...
override suspend fun embed(
text: String,
model: LLModel,
params: EmbeddingParams
...
Ah thanks, I missed that one! Fixed.
public val dimensions: Int? = null,
) {
    init {
        dimensions?.let {
This is a strange requirement. The class parameter allows passing null for dimensions, but you prevent null by validation in the constructor. I think it is a bit redundant. Please let me know if I missed something.

I think you misread it. This validation is only triggered when the dimensions param is not null, because of the dimensions?.let {}.
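A minimal runnable illustration of the semantics being discussed (class name simplified from the PR's EmbeddingParams):

```kotlin
// Sketch: the check inside `init` runs only when `dimensions` is non-null,
// because `?.let` short-circuits on null. Simplified from the PR's class.
class DimensionedParams(val dimensions: Int? = null) {
    init {
        dimensions?.let {
            require(it > 0) { "dimensions must be > 0, but was $it" }
        }
    }
}

val defaults = DimensionedParams()                 // null: validation skipped
val explicit = DimensionedParams(256)              // non-null and valid
val invalid = runCatching { DimensionedParams(0) } // non-null and invalid: throws
println(invalid.isFailure)  // true
```

So null is explicitly allowed; only a provided value is validated, which is what the reply above points out.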
Thanks a lot @sdubov! I will work through your comments soon.
EugeneTheDev left a comment
Generally looks good, but I have a few questions. And why did you disable embedding tests for some providers (unit and integration)? We should not break existing functionality, so let's keep these tests to check that
val results = mockClient.embedBatch(texts, model, EmbeddingParams())
assertEquals(3, results.size)
Looks like it can be simplified to a single assert:

assertEquals(
    listOf(
        listOf(...),
        ...,
    ),
    results,
)

override suspend fun embed(
    text: String,
    model: LLModel,
    params: ai.koog.prompt.params.EmbeddingParams
Here and in other clients that don't support EmbeddingParams yet (I reckon it's all clients except Google), let's also add an additional check and throw if param.dimensions is not null. This will provide clearer feedback to users that this feature is not supported yet.

Given the feedback I decided to make dimensions a provider-specific field of EmbeddingParams. You are right that this reduces coupling and avoids confusion, at the cost of a little more code (given that all our providers support custom dimensions for at least one of their models).
Force-pushed from 8e32fec to 04d4348
Thanks for the feedback @sdubov and @EugeneTheDev! I've refactored the PR based on your suggestions:

I've removed the batch embedding feature entirely from this PR. It now focuses solely on dimension control, which should make review much easier. I'll follow up with batch embedding in a separate PR.

Exactly as @sdubov suggested! The dimensions parameter is now provider-specific, while the base EmbeddingParams stays simple and common to all providers. This way we can add new features incrementally without any breaking changes.

About the disabled tests: Sorry for the confusion! Those were temporarily commented out just to get the code compiling while I was iterating; this was meant for early feedback, not a merge-ready state. I would have migrated all providers before requesting an actual merge. I've now restored them properly.

My plan is: get this foundational dimension control merged (currently just Google), then incrementally add the other providers (OpenAI, Bedrock, etc.) in follow-up PRs, and finally tackle batch embedding as the last piece.

Let me know if this approach works better!
Is one of you able to edit the PR title in GitHub? It seems I can't since the PR is owned by the koog repo. Would be great if you could prepend "feat" or similar.
@sdubov @EugeneTheDev please check back when you find the time.
Looks good to me in general. The major comment is about the EmbeddingParams type. @mltheuser, could you please address these questions? As well as some other failed jobs, including the conventional PRs job (which validates the commit message structure).
*/
@JvmStatic
fun dimensionCapableEmbeddingModels(): Stream<LLModel> {
    return embeddingModels().filter { it.capabilities.contains(LLMCapability.Embedding.Dimensions) }
There is a helper method LLModel#supports() that makes a safe check for model capabilities. I think you can re-use it here.
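The suggested refactor might look roughly like this sketch, using hypothetical minimal stand-ins for koog's LLModel and LLMCapability types:

```kotlin
// Hypothetical minimal stand-ins for koog's LLModel / LLMCapability,
// only to illustrate the supports() helper the reviewer suggests.
enum class Capability { Embed, EmbeddingDimensions }

data class Model(val name: String, val capabilities: Set<Capability>) {
    // Safe capability check, analogous to LLModel#supports().
    fun supports(capability: Capability): Boolean = capability in capabilities
}

// After the refactor, the filter reads through the helper instead of
// touching the capabilities set directly.
fun dimensionCapableEmbeddingModels(models: List<Model>): List<Model> =
    models.filter { it.supports(Capability.EmbeddingDimensions) }

val catalogue = listOf(
    Model("gemini-embedding-001", setOf(Capability.Embed, Capability.EmbeddingDimensions)),
    Model("mistral-embed", setOf(Capability.Embed)),
)
println(dimensionCapableEmbeddingModels(catalogue).map { it.name })  // [gemini-embedding-001]
```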
}

private fun validateEmbeddingRequest(model: LLModel, params: GoogleEmbeddingParams) {
    require(model.capabilities.contains(LLMCapability.Embed)) {
You should be able to replace model.capabilities.contains(LLMCapability.Embed) with model.supports(LLMCapability.Embed) for consistency.

Good catch! That works, thanks.
* @property taskType Specifies the intended use case for the embeddings.
* @property title Document title (only valid with taskType=RETRIEVAL_DOCUMENT).
*/
public class GoogleEmbeddingParams(
It seems that this class needs to be serializable as well.

GoogleEmbeddingParams itself is never serialized or deserialized. Its fields are manually extracted and placed into GoogleEmbeddingRequest (which is @Serializable). See GoogleEmbeddingRequest.from().
GoogleEmbeddingRequest is a wire-format model (what gets sent to the Google API), while GoogleEmbeddingParams is a user-facing model (what developers pass into the SDK) that might contain field names that need to be supported across multiple providers for a more unified API. These two concerns should be decoupled: the naming, types, and structure serve different audiences.
Marking it @Serializable would be misleading, since direct serialization would produce incorrect output (wrong field names, wrong structure). The class intentionally isn't serializable.
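The decoupling described here could be sketched as follows; the names EmbedParams and EmbedRequest are simplified hypothetical stand-ins for GoogleEmbeddingParams and GoogleEmbeddingRequest, and the @Serializable annotation is omitted to keep the sketch dependency-free:

```kotlin
// User-facing params (what developers pass into the SDK).
// Intentionally not serializable: it is not the wire format.
class EmbedParams(
    val dimensions: Int? = null,
    val taskType: String? = null,
    val title: String? = null,
)

// Wire-format DTO (what gets sent to the API). In the PR the real class
// carries @Serializable; omitted here so the sketch runs standalone.
data class EmbedRequest(
    val content: String,
    val outputDimensionality: Int?,
    val taskType: String?,
    val title: String?,
) {
    companion object {
        // Factory mapping user-facing fields onto wire-format fields,
        // in the spirit of the PR's GoogleEmbeddingRequest.from().
        fun from(text: String, params: EmbedParams): EmbedRequest = EmbedRequest(
            content = text,
            outputDimensionality = params.dimensions, // field renamed for the wire
            taskType = params.taskType,
            title = params.title,
        )
    }
}

val request = EmbedRequest.from("hello", EmbedParams(dimensions = 256))
println(request.outputDimensionality)  // 256
```

The rename from dimensions to outputDimensionality is exactly the kind of divergence that makes serializing the params class directly produce the wrong payload.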
dimensions?.let {
    require(it > 0) { "dimensions must be > 0, but was $it" }
}
title?.let {
This check still looks strange to me. You check the title to require another parameter to be a defined value. Why would anyone call the constructor with any taskType other than GoogleEmbeddingTaskType.RETRIEVAL_DOCUMENT given this check? Do we expect to receive and deserialize these parameters from an LLM?

The parameters are given by the user, not an LLM. The init block says: "if title is provided, then taskType must be RETRIEVAL_DOCUMENT". This is a constraint from the Google Embedding API. We could also remove the check and let the HTTP request fail, but I consider an early sanity check the better approach. I have also played around with some alternative ways to write it but have always come back to the field?.let { doSomeCheck } pattern. We could also replace it with if (field != null) { doSomeCheck } if you prefer. The checks in init {} should stay though.
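The constraint under discussion boils down to something like this sketch (enum trimmed to two values and the class reduced to the two relevant fields; the real class has more):

```kotlin
enum class TaskType { RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY }

// Early sanity check mirroring the Google Embedding API rule:
// a title is only meaningful for RETRIEVAL_DOCUMENT embeddings.
class Params(
    val taskType: TaskType? = null,
    val title: String? = null,
) {
    init {
        title?.let {
            require(taskType == TaskType.RETRIEVAL_DOCUMENT) {
                "title is only supported when taskType is RETRIEVAL_DOCUMENT"
            }
        }
    }
}

// Valid: title together with RETRIEVAL_DOCUMENT
val ok = Params(TaskType.RETRIEVAL_DOCUMENT, "My Doc")
// Invalid: title with a query task type fails fast, before any HTTP call
val bad = runCatching { Params(TaskType.RETRIEVAL_QUERY, "My Doc") }
println(bad.isFailure)  // true
```

Failing in the constructor surfaces the misuse at the call site instead of as a remote API error later.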
* @see LLMParams for the equivalent pattern used in completion/chat models.
*/
@Serializable
public open class EmbeddingParams {
Sorry, @mltheuser, could you please clarify why it is implemented as an open class instead of an interface? I see several issues with the current approach:
- The current implementation of the hashCode and equals methods breaks the comparison logic, as every instance has the same hash code.
- Why would I need to make an instance of the base EmbeddingParams type?

You are completely right. The main motivation behind this was to keep in line with the LLMParams class I used for reference. Then again, now that the dimensions field has been removed from this class, it becomes quite apparent that this is not the right pattern for the job. I will refactor it to be an interface like you suggested.
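The agreed refactor might look roughly like the following sketch (hypothetical names, not the merged code):

```kotlin
// Base type as an interface: no instances of the base type, and no shared
// equals/hashCode that would collapse all instances together.
interface EmbeddingParameters

// Marker object for "no provider-specific parameters".
object NoEmbeddingParameters : EmbeddingParameters

// A provider-specific implementation gets data-class equality for free.
data class GoogleParameters(
    val dimensions: Int? = null,
    val title: String? = null,
) : EmbeddingParameters

val a = GoogleParameters(dimensions = 256)
val b = GoogleParameters(dimensions = 256)
println(a == b)  // true: per-value equality now works
```

This addresses both review points: callers can no longer instantiate the base type, and equality is defined by each concrete implementation rather than by a shared open class.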
Force-pushed from 5ece339 to c6babfe
@sdubov, @EugeneTheDev addressed your comments.
Hi @mltheuser, sorry for the delay. I will now close this one as outdated.
feat: Control Embedding Dimensions & Batch Support
Context
Tickets: KG-104, KG-538
Hey team, this PR aims to tackle two requested improvements to koog's embedding capabilities: Dimension Control and Batch Processing.
Review wanted
I will stop working on this PR for a bit to give the team time to review the approach and provide feedback. The current state has the generic interface changes I intend to make and an example implementation of these patterns for the Google provider. Looking forward to your thoughts!

Design Philosophy

The implementation heavily follows Koog's established patterns, specifically mimicking the LLMParams architecture used for completions. This ensures the API feels familiar and keeps our provider-specific configurations consistent.

Implementation Details
Capability-based Model Declaration
I’ve added a specific capability so we can validate which models support controlling the dimension size at runtime and let callers know what's supported.
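As a hedged sketch of the idea (hypothetical minimal types, not koog's actual API):

```kotlin
// Hypothetical sketch: a model advertises the dimensions capability, and the
// client validates a requested dimension at runtime before calling the API.
enum class ModelCapability { Embed, EmbeddingDimensions }

data class EmbeddingModel(val id: String, val capabilities: Set<ModelCapability>)

fun validateDimensionRequest(model: EmbeddingModel, dimensions: Int?) {
    if (dimensions != null) {
        require(ModelCapability.EmbeddingDimensions in model.capabilities) {
            "Model ${model.id} does not support custom embedding dimensions"
        }
    }
}

val gemini = EmbeddingModel(
    "gemini-embedding-001",
    setOf(ModelCapability.Embed, ModelCapability.EmbeddingDimensions),
)
val titan = EmbeddingModel("amazon-titan-embed-text", setOf(ModelCapability.Embed))

validateDimensionRequest(gemini, 256)  // ok: capability declared on the model
println(runCatching { validateDimensionRequest(titan, 256) }.isFailure)  // true
```

Callers asking for a custom dimension on a model that never declared the capability fail fast with a clear message, while omitting dimensions keeps every embedding model usable.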
Extensible EmbeddingParams

Just like LLMParams, we now have EmbeddingParams. It starts with a base class for common needs (dimensions) and allows providers to extend it for niche features (like Google’s taskType).

Smart, Backward-Compatible Interface

I’ve updated LLMEmbeddingProvider to support batching without breaking existing implementations. I included a default "polyfill" for embedBatch. If a provider doesn't have a native batch endpoint yet, the interface automatically handles parallelizing the requests using coroutines (though it is possible to run into rate limits). This lets us roll out the feature interface immediately while updating individual providers incrementally.

Polymorphic Parameter Conversion
To keep things type-safe inside the providers, I used the established pattern for converting base params to provider-specific params:
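The code block that originally followed this sentence did not survive extraction; a hedged reconstruction of the pattern (pass provider-specific params through unchanged, lift plain base params into provider defaults) could be:

```kotlin
// Simplified stand-ins for the base and provider-specific parameter types.
open class EmbeddingParams
class GoogleEmbeddingParams(val taskType: String? = null) : EmbeddingParams()

// The conversion pattern: keep provider params as-is, and turn a plain base
// instance into provider defaults, so provider code is always fully typed.
fun EmbeddingParams.toGoogleEmbeddingParams(): GoogleEmbeddingParams = when (this) {
    is GoogleEmbeddingParams -> this
    else -> GoogleEmbeddingParams()
}

val specific: EmbeddingParams = GoogleEmbeddingParams(taskType = "RETRIEVAL_QUERY")
println(specific.toGoogleEmbeddingParams().taskType)        // RETRIEVAL_QUERY
println(EmbeddingParams().toGoogleEmbeddingParams().taskType)  // null
```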
Testing (Draft)

embedBatch defaults correctly to parallel execution for providers without overrides.
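The default parallel polyfill described above could be sketched like this (simplified interface and hypothetical names; real providers would override embedBatch with a native batch endpoint):

```kotlin
import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.runBlocking

// Simplified sketch of the interface described in the PR: embed() per text,
// plus a default embedBatch() polyfill that fans the calls out in parallel.
interface EmbeddingProvider {
    suspend fun embed(text: String): List<Double>

    // Default polyfill: one concurrent embed() call per input text. With no
    // native batch endpoint this can run into provider rate limits.
    suspend fun embedBatch(texts: List<String>): List<List<Double>> = coroutineScope {
        texts.map { async { embed(it) } }.map { it.await() }
    }
}

// Toy provider: "embeds" a string as its length, so the polyfill is observable.
class FakeProvider : EmbeddingProvider {
    override suspend fun embed(text: String): List<Double> = listOf(text.length.toDouble())
}

val batch = runBlocking { FakeProvider().embedBatch(listOf("a", "bb", "ccc")) }
println(batch)  // [[1.0], [2.0], [3.0]]
```

Because the default lives on the interface, existing implementations keep compiling unchanged, and each provider can later override embedBatch with its native batch API.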