feat(embeddings): KG-104 Control Embedding Dimensions & Batch Support#1296

Closed
mltheuser wants to merge 10 commits into JetBrains:develop from mltheuser:mltheuser/kg-104

Conversation

@mltheuser
Contributor

@mltheuser mltheuser commented Dec 22, 2025

feat: Control Embedding Dimensions & Batch Support

Context

Tickets: KG-104, KG-538

Hey team, this PR aims to tackle two requested improvements to koog's embedding capabilities: Dimension Control and Batch Processing.

Review wanted

I will stop working on this PR for a bit to give the team time to review the approach and provide feedback. The current state has the generic interface changes I intend to make and an example implementation of these patterns for the Google provider. Looking forward to your thoughts!

Design Philosophy

The implementation heavily follows Koog's established patterns—specifically mimicking the LLMParams architecture used for completions. This ensures the API feels familiar and keeps our provider-specific configurations consistent.

Implementation Details

Capability-based Model Declaration

I’ve added a specific capability so we can validate which models support controlling the dimension size at runtime and let callers know what's supported.

public val GeminiEmbedding001: LLModel = LLModel(
    provider = LLMProvider.Google,
    id = "gemini-embedding-001",
    // Explicitly declaring dimension support
    capabilities = listOf(LLMCapability.Embed, LLMCapability.Embedding.Dimensions), 
    // ...
)

Extensible EmbeddingParams

Just like LLMParams, we now have EmbeddingParams. It starts with a base class for common needs (dimensions) and allows providers to extend it for niche features (like Google’s taskType).

// Base class
public open class EmbeddingParams(
    public val dimensions: Int? = null,
)

// Provider-specific extension
public class GoogleEmbeddingParams(
    dimensions: Int? = null,
    public val taskType: GoogleEmbeddingTaskType? = null,
    public val title: String? = null,
) : EmbeddingParams(dimensions)

Smart, Backward-Compatible Interface

I’ve updated LLMEmbeddingProvider to support batching without breaking existing implementations.

I included a default "polyfill" for embedBatch. If a provider doesn't have a native batch endpoint yet, the interface automatically parallelizes the requests using coroutines (though this can run into provider rate limits). This lets us roll out the feature interface immediately while updating individual providers incrementally.

interface LLMEmbeddingProvider {
    suspend fun embed(
        text: String,
        model: LLModel,
        params: EmbeddingParams = EmbeddingParams() 
    ): List<Double>
    
    // Default implementation handles concurrency for us
    suspend fun embedBatch(
        texts: List<String>,
        model: LLModel,
        params: EmbeddingParams = EmbeddingParams()
    ): List<List<Double>> = coroutineScope {
        texts.map { async { embed(it, model, params) } }.awaitAll()
    }
}

Polymorphic Parameter Conversion

To keep things type-safe inside the providers, I used the established pattern for converting base params to provider-specific params:

internal fun EmbeddingParams.toGoogleEmbeddingParams(): GoogleEmbeddingParams =
    (this as? GoogleEmbeddingParams) ?: GoogleEmbeddingParams(dimensions = dimensions)
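To make the fallback behavior concrete, here is a minimal, self-contained sketch of the conversion (class shapes mirror the snippets above; taskType is simplified to a String here instead of the GoogleEmbeddingTaskType enum):

```kotlin
// Stand-ins mirroring the PR's classes (see the snippets above).
open class EmbeddingParams(val dimensions: Int? = null)

class GoogleEmbeddingParams(
    dimensions: Int? = null,
    val taskType: String? = null, // simplified; an enum in the actual PR
) : EmbeddingParams(dimensions)

// Provider-specific params pass through unchanged; base params are upgraded.
fun EmbeddingParams.toGoogleEmbeddingParams(): GoogleEmbeddingParams =
    (this as? GoogleEmbeddingParams) ?: GoogleEmbeddingParams(dimensions = dimensions)

fun main() {
    val base = EmbeddingParams(dimensions = 256)
    println(base.toGoogleEmbeddingParams().dimensions) // 256: dimensions preserved

    val google = GoogleEmbeddingParams(dimensions = 128, taskType = "RETRIEVAL_QUERY")
    println(google.toGoogleEmbeddingParams() === google) // true: same instance
}
```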

Testing (Draft)

  • Added unit tests for the new parameter mapping.
  • Verified that embedBatch defaults correctly to parallel execution for providers without overrides.
  • Tested manually against Google Gemini to ensure dimensions are actually being truncated/resized by the API.

…LMEmbeddingProvider

- Add LLMCapability.Embedding.Dimensions for models supporting variable output dimensions
- Add EmbeddingParams base class in prompt-model with dimensions parameter
- Update LLMEmbeddingProvider interface:
  - Add EmbeddingParams parameter to embed() method
  - Add embedBatch() with default parallel polyfill implementation
- Apply minimal signature updates to all provider clients (OpenAI, Mistral, Ollama, Bedrock)
  to enable compilation (full implementation pending)

Part of KG-104 (dimension control) and KG-538 (batch embedding)
Google Provider:
- Add GoogleEmbeddingParams with taskType and title support
- Add toGoogleEmbeddingParams() extension for polymorphic conversion
- Update GoogleEmbeddingRequest DTO with outputDimensionality, taskType, title
- Update GoogleLLMClient.embed() to use EmbeddingParams with capability validation
- Add Embedding.Dimensions capability to GeminiEmbedding001 model

Unit Tests:
- Add GoogleEmbeddingParamsTest for validation and conversion
- Update LLMEmbedderTest with embedBatch tests (provider-agnostic)

Integration Tests:
- Add integration_testEmbedWithDimensions (tests 256-dim output)
- Add integration_testEmbedBatch (tests 3-text batch)
- Add dimensionCapableEmbeddingModels() stream
- Temporarily limit embeddingModels() to Google until other providers migrate

Part of KG-104 (dimension control) and KG-538 (batch embedding)
@mltheuser mltheuser marked this pull request as draft December 22, 2025 13:33
@mltheuser mltheuser changed the title Mltheuser/kg 104 KG-104 Control Embedding Dimensions & Batch Support Dec 31, 2025
@mltheuser
Contributor Author

@eugenekarpenko , @kpavlov , @sdubov would be nice if one of you could do an initial review for this.

Collaborator

@sdubov sdubov left a comment


@mltheuser thank you for the contribution! It looks like a very cool extension of the embedding model capabilities. I looked through the code and added some comments about different parts of it. I would also like to offer some general proposals for improvement, if you don't mind:

  1. Would it be possible to split the updates related to "dimensions" and "batch" to two different PRs? Currently, we use the "squash and merge" policy for GitHub when merging PRs. Since there are different changes, it would be better to have them as separate commits.
  2. My general feeling is that we need to support the embedding capability for models independently of the "dimensions" parameter. So, if a model has the embedding capability, it should still be available. I would propose keeping EmbeddingParams common for all models with embeddings and adding dimensions to the Google parameters only, for now. This will allow you to keep the list of models that support embeddings the same and extend it when a model gains dimension support.

implementation(kotlin("test"))
implementation(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client"))
// TODO: Re-enable after OpenAI migration
// implementation(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client"))
Collaborator


Would it be possible to add a placeholder for OpenAI for now and skip dimensions (or handle it somehow when the dimension capability is explicitly specified with OpenAI models), instead of disabling embedding for the OpenAI clients?

Contributor Author


By making dimensions a provider-specific parameter, there is no need to migrate everything in one go anymore. Re-enabled.

BedrockModels.Embeddings.AmazonTitanEmbedText,
OpenAIModels.Embeddings.TextEmbedding3Large,
MistralAIModels.Embeddings.MistralEmbed,
// BedrockModels.Embeddings.AmazonTitanEmbedText,
Collaborator


Same comment here. It seems we might lose some base embedding tests if we just exclude these models for now. WDYT?

Contributor Author


Re-enabled. By the way, this was never meant to be merged as-is: I was intending to migrate all providers in one go and just commented them out to get early feedback on a working version.

* The conversion function [toGoogleEmbeddingParams] handles both cases transparently.
*/
@Serializable
public enum class GoogleEmbeddingTaskType(public val apiValue: String) {
Collaborator


Could you please extract this class into a separate file? It seems like a separate entity to me, and it would be easier to find when located separately from the parameters. WDYT?

Also, could you please add documentation for the apiValue parameter? It would make the code clearer to understand and read.

Contributor Author


Done. Also cleaned up the enum and its description.
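For illustration, a hypothetical sketch of what the extracted, documented enum could look like; the KDoc for apiValue is what the review asked for, and the entries shown are illustrative task types from the Google Embedding API:

```kotlin
/**
 * Intended use case for an embedding request, passed to the Google API.
 *
 * @property apiValue The exact string the Google API expects for this
 *   task type on the wire.
 */
enum class GoogleEmbeddingTaskType(val apiValue: String) {
    /** Embedding used as a query for retrieval/search. */
    RETRIEVAL_QUERY("RETRIEVAL_QUERY"),

    /** Embedding of a document stored in a retrieval corpus. */
    RETRIEVAL_DOCUMENT("RETRIEVAL_DOCUMENT"),

    /** Embedding used for similarity comparison between texts. */
    SEMANTIC_SIMILARITY("SEMANTIC_SIMILARITY"),
}

fun main() {
    println(GoogleEmbeddingTaskType.RETRIEVAL_DOCUMENT.apiValue) // RETRIEVAL_DOCUMENT
}
```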

parts = listOf(GooglePart.Text(text))
)
),
outputDimensionality = googleParams.dimensions,
Collaborator


What do you think about passing the parameters to the request instance as a single parameter instead of as separate values? The request can then process the parameters internally. The current request instance does not implement any API that defines parameters.

Contributor Author


I like that idea. I have created a factory method in GoogleEmbeddingRequest for this purpose.
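The factory itself isn't shown in this thread, so here is a hypothetical sketch of what it could look like. The from() name and the DTO fields outputDimensionality/taskType/title come from the commit descriptions; everything else (and the String taskType) is illustrative:

```kotlin
// User-facing params (simplified: taskType as a String instead of the enum).
class GoogleEmbeddingParams(
    val dimensions: Int? = null,
    val taskType: String? = null,
    val title: String? = null,
)

// Wire-format DTO: the request maps params to API field names internally,
// so client code never assembles DTO fields by hand.
data class GoogleEmbeddingRequest(
    val text: String,
    val outputDimensionality: Int?,
    val taskType: String?,
    val title: String?,
) {
    companion object {
        fun from(text: String, params: GoogleEmbeddingParams): GoogleEmbeddingRequest =
            GoogleEmbeddingRequest(
                text = text,
                outputDimensionality = params.dimensions,
                taskType = params.taskType,
                title = params.title,
            )
    }
}

fun main() {
    val request = GoogleEmbeddingRequest.from("hello", GoogleEmbeddingParams(dimensions = 256))
    println(request.outputDimensionality) // 256
}
```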

override suspend fun embed(
text: String,
model: LLModel,
params: ai.koog.prompt.params.EmbeddingParams
Collaborator


(Here and in other places above) Could you please use an import for the EmbeddingParams class instead of a FQN here? It helps make the code shorter and easier to read:

import ai.koog.prompt.params.EmbeddingParams
...
    override suspend fun embed(
        text: String,
        model: LLModel,
        params: EmbeddingParams
...

Contributor Author


Ah thanks, I missed that one! Fixed.

public val dimensions: Int? = null,
) {
init {
dimensions?.let {
Collaborator


This is a strange requirement. The class parameter allows passing null for dimensions, but you prevent null with validation in the constructor. I think it is a bit redundant. Please let me know if I missed something.

Collaborator


I think you misread it. This validation is only triggered when the dimensions param is not null, because of the dimensions?.let {}.
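A self-contained sketch of the pattern in question, showing that null (meaning "use the model's default size") passes while invalid non-null values are rejected eagerly:

```kotlin
// require() runs only when dimensions is non-null, because of the ?.let {}.
open class EmbeddingParams(val dimensions: Int? = null) {
    init {
        dimensions?.let {
            require(it > 0) { "dimensions must be > 0, but was $it" }
        }
    }
}

fun main() {
    EmbeddingParams()                 // null dimensions: accepted
    EmbeddingParams(dimensions = 256) // positive dimensions: accepted
    val rejected = runCatching { EmbeddingParams(dimensions = -1) }.isFailure
    println(rejected) // true: invalid value rejected at construction time
}
```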

@mltheuser
Contributor Author

Thanks a lot @sdubov! I will work through your comments soon.

Collaborator

@EugeneTheDev EugeneTheDev left a comment


Generally looks good, but I have a few questions. Also, why did you disable the embedding tests for some providers (unit and integration)? We should not break existing functionality, so let's keep these tests in place to check that.


val results = mockClient.embedBatch(texts, model, EmbeddingParams())

assertEquals(3, results.size)
Collaborator


Looks like it can be simplified to a single assert:

assertEquals(
    listOf(
        listOf(...),
        ...,
    ),
    results,
)

override suspend fun embed(
text: String,
model: LLModel,
params: ai.koog.prompt.params.EmbeddingParams
Collaborator


Here and in other clients that don't support EmbeddingParams yet (I reckon that's all clients except Google), let's also add an additional check and throw if params.dimensions is not null. This will give users clearer feedback that the feature is not supported yet.

Contributor Author


Given the feedback, I decided to make dimensions a provider-specific field of EmbeddingParams. You are right that this reduces coupling and avoids confusion, at the cost of a little more code (given that all our providers support custom dimensions for at least one of their models).

@mltheuser
Contributor Author

Thanks for the feedback @sdubov and @EugeneTheDev!

I've refactored the PR based on your suggestions:

I've removed the batch embedding feature entirely from this PR. It now focuses solely on dimension control, which should make review much easier. I'll follow up with batch embedding in a separate PR.

Exactly as @sdubov suggested! The dimensions parameter is now provider specific, while the base EmbeddingParams stays simple and common to all providers. This way we can add new features incrementally without any breaking changes.

About the disabled tests: Sorry for the confusion! Those were temporarily commented out just to get the code compiling while I was iterating - this was meant for early feedback, not a merge-ready state. I would have migrated all providers before requesting an actual merge. I've now restored them properly.

My plan is: get this foundational dimension control merged (currently just Google), then incrementally add the other providers (OpenAI, Bedrock, etc.) in follow-up PRs, and finally tackle batch embedding as the last piece.

Let me know if this approach works better!

@mltheuser mltheuser marked this pull request as ready for review February 9, 2026 16:50
@mltheuser
Contributor Author

Is one of you able to edit the PR title in github? It seems I can't since the PR is owned by the koog repo. Would be great if you could prepend "feat" or similar.

@aozherelyeva aozherelyeva linked an issue Feb 10, 2026 that may be closed by this pull request
@mltheuser
Contributor Author

@sdubov @EugeneTheDev please check back when you find the time.

Collaborator

@sdubov sdubov left a comment


Looks good to me in general. The major comment is about the EmbeddingParams type. @mltheuser, could you please address these questions? As well as the other failed jobs, including the conventional PRs job (which validates the commit message structure).

*/
@JvmStatic
fun dimensionCapableEmbeddingModels(): Stream<LLModel> {
return embeddingModels().filter { it.capabilities.contains(LLMCapability.Embedding.Dimensions) }
Collaborator


There is a helper method LLModel#supports() that makes a safe check for model capabilities. I think you can re-use it here.
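To illustrate the suggestion, a minimal stand-in for the helper; the real LLModel, LLMCapability, and supports() live in koog, and the real capability is the nested LLMCapability.Embedding.Dimensions (flattened here for brevity):

```kotlin
// Simplified stand-ins for the koog types discussed above.
enum class LLMCapability { Embed, EmbeddingDimensions }

data class LLModel(val id: String, val capabilities: List<LLMCapability>)

// supports() gives call sites a single safe, readable capability check.
fun LLModel.supports(capability: LLMCapability): Boolean = capability in capabilities

fun main() {
    val model = LLModel(
        id = "gemini-embedding-001",
        capabilities = listOf(LLMCapability.Embed, LLMCapability.EmbeddingDimensions),
    )
    println(model.supports(LLMCapability.EmbeddingDimensions)) // true
    println(model.supports(LLMCapability.Embed))               // true
}
```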

Contributor Author


Done.

}

private fun validateEmbeddingRequest(model: LLModel, params: GoogleEmbeddingParams) {
require(model.capabilities.contains(LLMCapability.Embed)) {
Collaborator


You should be able to replace model.capabilities.contains(LLMCapability.Embed) with model.supports(LLMCapability.Embed) for consistency.

Contributor Author


Good catch! That works, thanks.

* @property taskType Specifies the intended use case for the embeddings.
* @property title Document title (only valid with taskType=RETRIEVAL_DOCUMENT).
*/
public class GoogleEmbeddingParams(
Collaborator


It seems that this class needs to be serializable as well.

Contributor Author


GoogleEmbeddingParams itself is never serialized or deserialized. Its fields are manually extracted and placed into GoogleEmbeddingRequest (which is @Serializable). See GoogleEmbeddingRequest.from().

GoogleEmbeddingRequest is a wire format model (what gets sent to the Google API), while GoogleEmbeddingParams is a user-facing model (what developers pass into the SDK) that might contain field names that need to be supported across multiple providers for a more unified api. These two concerns should be decoupled — the naming, types, and structure serve different audiences.

Marking it @Serializable would be misleading, since direct serialization would produce incorrect output (wrong field names, wrong structure). The class intentionally isn't serializable.

dimensions?.let {
require(it > 0) { "dimensions must be > 0, but was $it" }
}
title?.let {
Collaborator


This check still looks strange to me. You check the title in order to require another parameter to have a specific value. Why would anyone call the constructor with any taskType but GoogleEmbeddingTaskType.RETRIEVAL_DOCUMENT, given this check? Do we expect to receive and deserialize these parameters from an LLM?

Contributor Author


The parameters are given by the user, not an LLM. The init block says: "if title is provided, then taskType must be RETRIEVAL_DOCUMENT". This is a constraint from the Google Embedding API. We could also remove the check and let the HTTP request fail, but I consider an early sanity check the better approach. I have also played around with some alternative ways to write it but have always come back to the field?.let { doSomeCheck } pattern. We could also replace it with if (field != null) { doSomeCheck } if you prefer? The checks in init {} should stay, though.
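A self-contained sketch of the cross-field rule being defended here: if title is set, taskType must be RETRIEVAL_DOCUMENT (a Google Embedding API constraint), checked eagerly instead of letting the HTTP request fail later. The enum is trimmed to two values for brevity:

```kotlin
enum class GoogleEmbeddingTaskType { RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY }

class GoogleEmbeddingParams(
    val taskType: GoogleEmbeddingTaskType? = null,
    val title: String? = null,
) {
    init {
        // The check fires only when title is provided.
        title?.let {
            require(taskType == GoogleEmbeddingTaskType.RETRIEVAL_DOCUMENT) {
                "title is only valid when taskType is RETRIEVAL_DOCUMENT"
            }
        }
    }
}

fun main() {
    GoogleEmbeddingParams(GoogleEmbeddingTaskType.RETRIEVAL_DOCUMENT, "My doc") // ok
    val rejected = runCatching { GoogleEmbeddingParams(title = "My doc") }.isFailure
    println(rejected) // true: title without the right taskType fails fast
}
```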

* @see LLMParams for the equivalent pattern used in completion/chat models.
*/
@Serializable
public open class EmbeddingParams {
Collaborator


Sorry, @mltheuser, could you please clarify, why is it implemented as an open class instead of an interface? I see several issue with current approach:

  1. The current implementation of the hashCode and equals methods breaks comparison logic, as every instance has the same hash code.
  2. Why would I need to make an instance of a base EmbeddingParams type?

Contributor Author


You are completely right. The main motivation behind this was to keep in line with the LLMParams class I used for reference. Then again, now that the dimensions field has been removed from this class, it becomes quite apparent that this is not the right pattern for the job. I will refactor it to be an interface as you suggested.
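A hedged sketch of what the agreed refactor could look like: the base type becomes a stateless marker interface (so there is no broken equals/hashCode on the base class), and providers supply data classes with structural equality:

```kotlin
// Marker interface: no state, no equals/hashCode pitfalls on the base type.
interface EmbeddingParams

// Provider-specific params as a data class: structural equality for free.
data class GoogleEmbeddingParams(
    val dimensions: Int? = null,
    val title: String? = null,
) : EmbeddingParams

fun main() {
    println(GoogleEmbeddingParams(256) == GoogleEmbeddingParams(256)) // true
    println(GoogleEmbeddingParams(256) == GoogleEmbeddingParams(512)) // false
}
```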

@mltheuser mltheuser changed the title KG-104 Control Embedding Dimensions & Batch Support feat(embedding): KG-104 Control Embedding Dimensions & Batch Support Feb 22, 2026
@mltheuser mltheuser changed the title feat(embedding): KG-104 Control Embedding Dimensions & Batch Support feat(embeddings): KG-104 Control Embedding Dimensions & Batch Support Mar 1, 2026
@mltheuser
Contributor Author

@sdubov , @EugeneTheDev addressed your comments.

@aozherelyeva
Contributor

Hi @mltheuser, sorry for the delay. I will now close this one as outdated.
Please rebase onto develop and open a new PR if the change is still relevant. Sorry for the long processing 😞



Development

Successfully merging this pull request may close these issues.

Support dimensions for embedding models

4 participants