diff --git a/agentscope-distribution/agentscope-all/pom.xml b/agentscope-distribution/agentscope-all/pom.xml
index e899a3f80..db8127d6e 100644
--- a/agentscope-distribution/agentscope-all/pom.xml
+++ b/agentscope-distribution/agentscope-all/pom.xml
@@ -88,6 +88,13 @@
true
+
+ io.agentscope
+ agentscope-extensions-training
+ compile
+ true
+
+
io.agentscopeagentscope-extensions-mem0
diff --git a/agentscope-distribution/agentscope-bom/pom.xml b/agentscope-distribution/agentscope-bom/pom.xml
index 5ba08db8a..428fa60ba 100644
--- a/agentscope-distribution/agentscope-bom/pom.xml
+++ b/agentscope-distribution/agentscope-bom/pom.xml
@@ -241,6 +241,13 @@
${project.version}
+
+
+ io.agentscope
+ agentscope-extensions-training
+ ${project.version}
+
+
io.agentscope
diff --git a/agentscope-extensions/agentscope-extensions-training/pom.xml b/agentscope-extensions/agentscope-extensions-training/pom.xml
new file mode 100644
index 000000000..34fd2463c
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/pom.xml
@@ -0,0 +1,110 @@
+
+
+
+
+ 4.0.0
+
+ io.agentscope
+ agentscope-extensions
+ ${revision}
+ ../pom.xml
+
+ agentscope-extensions-training
+
+ AgentScope Java - Extensions - Training
+ AgentScope Extensions - Training Data Collection and Export
+
+
+
+
+ io.agentscope
+ agentscope-core
+
+
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+
+
+
+
+ org.apache.parquet
+ parquet-avro
+ 1.13.1
+ true
+
+
+ org.apache.hadoop
+ hadoop-client
+ 3.3.6
+ true
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+ log4j
+ log4j
+
+
+
+
+
+
+ org.slf4j
+ slf4j-api
+
+
+
+
+ io.projectreactor
+ reactor-core
+
+
+
+
+ org.junit.jupiter
+ junit-jupiter
+ test
+
+
+ io.projectreactor
+ reactor-test
+ test
+
+
+ org.mockito
+ mockito-core
+ test
+
+
+ org.mockito
+ mockito-junit-jupiter
+ test
+
+
+ com.squareup.okhttp3
+ mockwebserver
+ test
+
+
+
+
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/TrinityClient.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/TrinityClient.java
new file mode 100644
index 000000000..a4f6011f2
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/TrinityClient.java
@@ -0,0 +1,195 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.backend;
+
+import io.agentscope.core.training.backend.dto.CommitRequest;
+import io.agentscope.core.training.backend.dto.FeedbackRequest;
+import io.agentscope.core.training.backend.dto.StatusResponse;
+import io.agentscope.core.util.JsonUtils;
+import java.time.Duration;
+import okhttp3.MediaType;
+import okhttp3.OkHttpClient;
+import okhttp3.Request;
+import okhttp3.RequestBody;
+import okhttp3.Response;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import reactor.core.publisher.Mono;
+
+/**
+ * Trinity Backend Client
+ *
+ *
Simplified Trinity client that only encapsulates Feedback and Commit APIs.
+ *
+ *
Chat API is no longer called through this client, but uses AgentScope's {@link io.agentscope.core.model.OpenAIChatModel} directly,
+ * because Trinity's Chat API is fully compatible with OpenAI format.
+ *
+ *
This client is only responsible for Trinity-specific training APIs:
+ *
+ *
Feedback API - Submit reward feedback
+ *
Commit API - Trigger training commit
+ *
+ */
+public class TrinityClient {
+ private static final Logger logger = LoggerFactory.getLogger(TrinityClient.class);
+ private static final MediaType JSON = MediaType.get("application/json; charset=utf-8");
+
+ private final String baseUrl;
+ private final OkHttpClient httpClient;
+
+ private TrinityClient(Builder builder) {
+ this.baseUrl = builder.endpoint;
+ this.httpClient =
+ new OkHttpClient.Builder()
+ .connectTimeout(builder.timeout)
+ .readTimeout(builder.timeout)
+ .writeTimeout(builder.timeout)
+ .build();
+ }
+
+ /**
+ * Submit Feedback (reward feedback)
+ *
+ * @param request Feedback request (containing msg_ids and reward)
+ * @return Mono<Void> that completes when feedback is submitted
+ */
+ public Mono feedback(FeedbackRequest request) {
+ return Mono.fromCallable(
+ () -> {
+ logger.debug(
+ "Submitting feedback: msgIds={}, reward={}, taskId={},"
+ + " runId={}",
+ request.getMsgIds(),
+ request.getReward(),
+ request.getTaskId(),
+ request.getRunId());
+
+ String jsonBody = JsonUtils.getJsonCodec().toJson(request);
+ String endpoint = baseUrl + "/feedback";
+
+ // Print actual JSON sent for debugging
+ logger.info("Sending feedback to {}: {}", endpoint, jsonBody);
+
+ Request httpRequest =
+ new Request.Builder()
+ .url(endpoint)
+ .post(RequestBody.create(jsonBody, JSON))
+ .build();
+
+ try (Response response = httpClient.newCall(httpRequest).execute()) {
+ if (!response.isSuccessful()) {
+ throw new RuntimeException(
+ "Feedback API failed: " + response.code());
+ }
+
+ String responseBody = response.body().string();
+ StatusResponse statusResponse =
+ JsonUtils.getJsonCodec()
+ .fromJson(responseBody, StatusResponse.class);
+
+ logger.info(
+ "Feedback submitted: msgIds={}, reward={}, taskId={},"
+ + " runId={}",
+ request.getMsgIds(),
+ request.getReward(),
+ request.getTaskId(),
+ request.getRunId());
+
+ return statusResponse;
+ }
+ })
+ .doOnError(e -> logger.error("Failed to submit feedback: {}", e.getMessage()))
+ .then();
+ }
+
+ /**
+ * Submit Commit to trigger training
+ *
+ * @param request Commit request (containing task_id and run_id)
+ * @return Mono<Void> that completes when commit is successful
+ */
+ public Mono commit(CommitRequest request) {
+ return Mono.fromCallable(
+ () -> {
+ logger.debug(
+ "Triggering commit: taskId={}, runId={}",
+ request.getTaskId(),
+ request.getRunId());
+
+ String jsonBody = JsonUtils.getJsonCodec().toJson(request);
+ String endpoint = baseUrl + "/commit";
+
+ Request httpRequest =
+ new Request.Builder()
+ .url(endpoint)
+ .post(RequestBody.create(jsonBody, JSON))
+ .build();
+
+ try (Response response = httpClient.newCall(httpRequest).execute()) {
+ if (!response.isSuccessful()) {
+ throw new RuntimeException(
+ "Commit API failed: " + response.code());
+ }
+
+ String responseBody = response.body().string();
+ StatusResponse statusResponse =
+ JsonUtils.getJsonCodec()
+ .fromJson(responseBody, StatusResponse.class);
+
+ logger.info(
+ "Training committed: taskId={}, runId={}, timeThreshold={}",
+ request.getTaskId(),
+ request.getRunId(),
+ request.getTimeThreshold());
+
+ return statusResponse;
+ }
+ })
+ .doOnError(e -> logger.error("Failed to commit: {}", e.getMessage()))
+ .then();
+ }
+
+ public String getEndpoint() {
+ return baseUrl;
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static class Builder {
+ private String endpoint;
+ private Duration timeout = Duration.ofSeconds(300);
+
+ public Builder endpoint(String endpoint) {
+ this.endpoint = endpoint;
+ return this;
+ }
+
+ public Builder timeout(Duration timeout) {
+ this.timeout = timeout;
+ return this;
+ }
+
+ public TrinityClient build() {
+ if (endpoint == null || endpoint.isEmpty()) {
+ throw new IllegalArgumentException("endpoint is required");
+ }
+ return new TrinityClient(this);
+ }
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/TrinityModelAdapter.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/TrinityModelAdapter.java
new file mode 100644
index 000000000..70ce34428
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/TrinityModelAdapter.java
@@ -0,0 +1,185 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.backend;
+
+import io.agentscope.core.message.Msg;
+import io.agentscope.core.model.ChatModelBase;
+import io.agentscope.core.model.ChatResponse;
+import io.agentscope.core.model.GenerateOptions;
+import io.agentscope.core.model.OpenAIChatModel;
+import io.agentscope.core.model.ToolSchema;
+import io.agentscope.core.training.runner.RunExecutionContext;
+import java.util.List;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import reactor.core.publisher.Flux;
+
+/**
+ * Trinity Model Adapter
+ *
+ *
Lightweight adapter using composition pattern to wrap OpenAIChatModel for:
+ *
+ *
Routing model calls to Trinity Chat API
+ *
Automatically collecting msg_ids and associating them with RunExecutionContext
+ *
+ *
+ *
Trinity's Chat API is fully compatible with OpenAI format, so OpenAIChatModel can be used directly.
+ * This adapter simply intercepts responses and collects msg_id (response.id) for subsequent feedback.
+ *
+ *
Automated design:
+ *
+ *
Internal component, created and managed by TrainingRouter
+ *
Automatically collects msg_ids to RunExecutionContext
Called by TrainingRouter for automatic msg_ids collection.
+ *
+ * @param context Task execution context
+ * @return this
+ */
+ public Builder executionContext(RunExecutionContext context) {
+ this.executionContext = context;
+ return this;
+ }
+
+ /**
+ * Build TrinityModelAdapter
+ *
+ * @return TrinityModelAdapter instance
+ */
+ public TrinityModelAdapter build() {
+ // Log: Print actual baseUrl used
+ logger.info(
+ "Creating TrinityModelAdapter with baseUrl={}, modelName={}",
+ baseUrl,
+ modelName);
+
+ // Create delegate object using OpenAIChatModel.builder()
+ OpenAIChatModel openAIModel =
+ OpenAIChatModel.builder()
+ .baseUrl(baseUrl)
+ .modelName(modelName)
+ .apiKey(apiKey)
+ .stream(false) // Trinity doesn't support streaming, force disable
+ .build();
+
+ logger.debug("TrinityModelAdapter created successfully");
+
+ return new TrinityModelAdapter(openAIModel, executionContext);
+ }
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/dto/CommitRequest.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/dto/CommitRequest.java
new file mode 100644
index 000000000..9e54702f7
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/dto/CommitRequest.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.backend.dto;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+/**
+ * Trinity Commit API Request
+ */
+public class CommitRequest {
+ @JsonProperty("task_id")
+ private String taskId;
+
+ @JsonProperty("run_id")
+ private String runId;
+
+ @JsonProperty("time_threshold")
+ private Long timeThreshold;
+
+ private CommitRequest(Builder builder) {
+ this.taskId = builder.taskId;
+ this.runId = builder.runId;
+ this.timeThreshold = builder.timeThreshold;
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ // Getters
+ public String getTaskId() {
+ return taskId;
+ }
+
+ public String getRunId() {
+ return runId;
+ }
+
+ public Long getTimeThreshold() {
+ return timeThreshold;
+ }
+
+ public static class Builder {
+ private String taskId;
+ private String runId;
+ private Long timeThreshold;
+
+ public Builder taskId(String taskId) {
+ this.taskId = taskId;
+ return this;
+ }
+
+ public Builder runId(String runId) {
+ this.runId = runId;
+ return this;
+ }
+
+ public Builder timeThreshold(Long timeThreshold) {
+ this.timeThreshold = timeThreshold;
+ return this;
+ }
+
+ public CommitRequest build() {
+ return new CommitRequest(this);
+ }
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/dto/FeedbackRequest.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/dto/FeedbackRequest.java
new file mode 100644
index 000000000..4cfda71e8
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/dto/FeedbackRequest.java
@@ -0,0 +1,96 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.backend.dto;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import java.util.List;
+
+/**
+ * Trinity Feedback API Request
+ */
+public class FeedbackRequest {
+ @JsonProperty("msg_ids")
+ private List msgIds;
+
+ @JsonProperty("reward")
+ private Double reward;
+
+ @JsonProperty("task_id")
+ private String taskId;
+
+ @JsonProperty("run_id")
+ private String runId;
+
+ private FeedbackRequest(Builder builder) {
+ this.msgIds = builder.msgIds;
+ this.reward = builder.reward;
+ this.taskId = builder.taskId;
+ this.runId = builder.runId;
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ // Getters
+ public List getMsgIds() {
+ return msgIds;
+ }
+
+ public Double getReward() {
+ return reward;
+ }
+
+ public String getTaskId() {
+ return taskId;
+ }
+
+ public String getRunId() {
+ return runId;
+ }
+
+ public static class Builder {
+ private List msgIds;
+ private Double reward;
+ private String taskId;
+ private String runId;
+
+ public Builder msgIds(List msgIds) {
+ this.msgIds = msgIds;
+ return this;
+ }
+
+ public Builder reward(Double reward) {
+ this.reward = reward;
+ return this;
+ }
+
+ public Builder taskId(String taskId) {
+ this.taskId = taskId;
+ return this;
+ }
+
+ public Builder runId(String runId) {
+ this.runId = runId;
+ return this;
+ }
+
+ public FeedbackRequest build() {
+ return new FeedbackRequest(this);
+ }
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/dto/StatusResponse.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/dto/StatusResponse.java
new file mode 100644
index 000000000..e210a4c0a
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/backend/dto/StatusResponse.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.agentscope.core.training.backend.dto;
+
+/**
+ * Trinity Common Status Response
+ */
+public class StatusResponse {
+ private String status;
+ private String message;
+
+ public StatusResponse() {}
+
+ public StatusResponse(String status, String message) {
+ this.status = status;
+ this.message = message;
+ }
+
+ public String getStatus() {
+ return status;
+ }
+
+ public void setStatus(String status) {
+ this.status = status;
+ }
+
+ public String getMessage() {
+ return message;
+ }
+
+ public void setMessage(String message) {
+ this.message = message;
+ }
+
+ public boolean isSuccess() {
+ return "success".equalsIgnoreCase(status);
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/reward/RewardCalculator.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/reward/RewardCalculator.java
new file mode 100644
index 000000000..46b9a49d0
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/reward/RewardCalculator.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.reward;
+
+import io.agentscope.core.agent.Agent;
+
+/**
+ * Reward Calculator Interface
+ *
+ *
Calculate reward value based on shadow Agent's execution results
+ */
+public interface RewardCalculator {
+
+ /**
+ * Calculate reward value based on execution results
+ *
+ * @param agent Shadow Agent
+ * @return Reward value (typically between -1 and 1)
+ */
+ double calculate(Agent agent);
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/AgentCloner.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/AgentCloner.java
new file mode 100644
index 000000000..c3ec29454
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/AgentCloner.java
@@ -0,0 +1,146 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import io.agentscope.core.ReActAgent;
+import io.agentscope.core.agent.Agent;
+import io.agentscope.core.memory.InMemoryMemory;
+import io.agentscope.core.model.ExecutionConfig;
+import io.agentscope.core.model.Model;
+import io.agentscope.core.model.StructuredOutputReminder;
+import io.agentscope.core.plan.PlanNotebook;
+import io.agentscope.core.tool.ToolExecutionContext;
+import io.agentscope.core.tool.Toolkit;
+import java.lang.reflect.Field;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Agent Cloning Utility Class
+ *
+ *
Uses reflection to extract Agent configuration, then rebuilds new Agent instance with Builder
+ *
+ *
Design philosophy:
+ *
+ *
No shared state (create new memory)
+ *
Only replace model field
+ *
Keep all other configurations identical
+ *
+ */
+public class AgentCloner {
+ private static final Logger logger = LoggerFactory.getLogger(AgentCloner.class);
+
+ /**
+ * Clone Agent and replace model
+ *
+ * @param original Original Agent
+ * @param newModel New model (e.g. TrinityModelWrapper)
+ * @return Cloned Agent instance
+ */
+ public static Agent cloneWithModel(Agent original, Model newModel) {
+ if (original instanceof ReActAgent) {
+ return cloneReActAgent((ReActAgent) original, newModel);
+ }
+
+ // TODO: Support other Agent types
+ throw new UnsupportedOperationException(
+ "Agent cloning not implemented for: "
+ + original.getClass().getSimpleName()
+ + ". Please implement cloning logic for this agent type.");
+ }
+
+ /**
+ * Clone ReActAgent
+ */
+ private static ReActAgent cloneReActAgent(ReActAgent original, Model newModel) {
+ try {
+ logger.debug("Cloning ReActAgent: {}", original.getName());
+
+ // Extract configuration fields
+ String sysPrompt = extractField(original, "sysPrompt");
+ Toolkit toolkit = extractField(original, "toolkit");
+ Integer maxIters = extractField(original, "maxIters");
+ ExecutionConfig modelExecutionConfig = extractField(original, "modelExecutionConfig");
+ ExecutionConfig toolExecutionConfig = extractField(original, "toolExecutionConfig");
+ StructuredOutputReminder structuredOutputReminder =
+ extractField(original, "structuredOutputReminder");
+ PlanNotebook planNotebook = extractField(original, "planNotebook");
+ ToolExecutionContext toolExecutionContext =
+ extractField(original, "toolExecutionContext");
+
+ // Rebuild Agent using Builder
+ ReActAgent.Builder builder =
+ ReActAgent.builder()
+ .name(original.getName() + "-shadow")
+ .description(original.getDescription())
+ .sysPrompt(sysPrompt)
+ .model(newModel) // Replace model
+ .toolkit(toolkit)
+ .memory(new InMemoryMemory()); // New memory, no shared state
+
+ // Set optional fields
+ if (maxIters != null) {
+ builder.maxIters(maxIters);
+ }
+ if (modelExecutionConfig != null) {
+ builder.modelExecutionConfig(modelExecutionConfig);
+ }
+ if (toolExecutionConfig != null) {
+ builder.toolExecutionConfig(toolExecutionConfig);
+ }
+ if (structuredOutputReminder != null) {
+ builder.structuredOutputReminder(structuredOutputReminder);
+ }
+ if (planNotebook != null) {
+ builder.planNotebook(planNotebook);
+ }
+ if (toolExecutionContext != null) {
+ builder.toolExecutionContext(toolExecutionContext);
+ }
+
+ ReActAgent shadowAgent = builder.build();
+
+ logger.debug(
+ "Successfully cloned ReActAgent: {} -> {}",
+ original.getName(),
+ shadowAgent.getName());
+
+ return shadowAgent;
+
+ } catch (Exception e) {
+ throw new RuntimeException("Failed to clone ReActAgent: " + original.getName(), e);
+ }
+ }
+
+ /**
+ * Extract private final fields using reflection
+ */
+ @SuppressWarnings("unchecked")
+ private static T extractField(Object obj, String fieldName) {
+ try {
+ Field field = obj.getClass().getDeclaredField(fieldName);
+ field.setAccessible(true);
+ return (T) field.get(obj);
+ } catch (NoSuchFieldException e) {
+ logger.warn("Field not found: {} in {}", fieldName, obj.getClass().getSimpleName());
+ return null;
+ } catch (Exception e) {
+ throw new RuntimeException(
+ "Failed to extract field: " + fieldName + " from " + obj.getClass(), e);
+ }
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/RunExecutionContext.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/RunExecutionContext.java
new file mode 100644
index 000000000..651dffa60
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/RunExecutionContext.java
@@ -0,0 +1,284 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import io.agentscope.core.message.Msg;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CopyOnWriteArrayList;
+
+/**
+ * Run Execution Context
+ *
+ *
Encapsulates the complete execution information for a single Run.
+ *
+ *
Core Concepts:
+ *
+ *
Task: Represents a user's original request (business-level unique identifier)
+ *
Run: A specific execution of the same Task (experiment/training level)
+ *
RunExecutionContext: Encapsulates the execution context for one Run
Run ID: The Nth execution of this Task (incrementing from "0")
+ *
msg_ids: All message IDs from LLM interactions in this Run (returned from Trinity API)
+ *
msgs: All messages from Agent memory in this Run (raw data for reward calculation)
+ *
Timing info: Start time and duration of this Run
+ *
+ *
+ *
Lifecycle:
+ *
+ * 1. TrainingRouter creates RunExecutionContext (corresponds to one Run)
+ * 2. Passed to TrinityModelAdapter (automatically collects msg_ids)
+ * 3. Shadow Agent execution completes, all data collected
+ * 4. TrainingRouter calculates reward
+ * 5. TrainingRouter submits Feedback using (taskId, runId, msg_ids, reward)
+ * 6. Registered to TaskExecutionRegistry (optional, for later queries)
+ *
+ *
+ *
Thread Safety:
+ *
+ *
Uses {@link CopyOnWriteArrayList} to store msg_ids
+ *
Supports concurrent additions (although currently executed synchronously)
+ *
+ *
+ * @see TaskIdGenerator Automatically generates Task ID
+ * @see RunRegistry Manages multiple Runs for the same Task
+ * @see TaskExecutionRegistry Stores execution contexts for all Runs
+ */
+public class RunExecutionContext {
+
+ /** Task ID - Unique identifier for user request (can be specified by user or auto-generated) */
+ private final String taskId;
+
+ /** Run ID - The Nth execution of the same Task (incrementing from "0") */
+ private final String runId;
+
+ /** Message IDs - All message IDs from LLM interactions in this Run */
+ private final List msgIds;
+
+ /** Messages - All messages from agent memory in this Run (for reward calculation) */
+ private final List msgs;
+
+ /** Start time of this Run (millisecond timestamp) */
+ private final long startTime;
+
+ /**
+ * Private constructor
+ *
+ * @param taskId Task ID
+ * @param runId Run ID
+ */
+ private RunExecutionContext(String taskId, String runId) {
+ if (taskId == null || taskId.isEmpty()) {
+ throw new IllegalArgumentException("Task ID cannot be null or empty");
+ }
+ if (runId == null || runId.isEmpty()) {
+ throw new IllegalArgumentException("Run ID cannot be null or empty");
+ }
+
+ this.taskId = taskId;
+ this.runId = runId;
+ this.msgIds = new CopyOnWriteArrayList<>();
+ this.msgs = new CopyOnWriteArrayList<>();
+ this.startTime = System.currentTimeMillis();
+ }
+
+ /**
+ * Create new Run execution context
+ *
+ *
Usage scenario: Created by TrainingRouter after intercepting Agent call.
+ *
+ * @param taskId Task ID
+ * @param runId Run ID
+ * @return Run execution context instance
+ * @throws IllegalArgumentException if taskId or runId is null/empty
+ */
+ public static RunExecutionContext create(String taskId, String runId) {
+ return new RunExecutionContext(taskId, runId);
+ }
+
+ /**
+ * Add Message ID
+ *
+ *
Message ID is extracted from Trinity Chat API response ({@code response.getId()}).
+ *
+ *
Thread-safe: This method can be called concurrently.
+ *
+ * @param msgId Message ID (ignored if null or empty)
+ */
+ public void addMsgId(String msgId) {
+ if (msgId != null && !msgId.isEmpty()) {
+ msgIds.add(msgId);
+ }
+ }
+
+ /**
+ * Get all Message IDs
+ *
+ *
Return value: Returns a new List copy to prevent external modification.
+ *
+ * @return Message IDs list (unmodifiable)
+ */
+ public List getMsgIds() {
+ return new ArrayList<>(msgIds);
+ }
+
+ /**
+ * Get Message IDs count
+ *
+ * @return Number of msg_ids
+ */
+ public int getMsgIdCount() {
+ return msgIds.size();
+ }
+
+ /**
+ * Get Task ID
+ *
+ * @return Task ID
+ */
+ public String getTaskId() {
+ return taskId;
+ }
+
+ /**
+ * Get Run ID
+ *
+ * @return Run ID
+ */
+ public String getRunId() {
+ return runId;
+ }
+
+ /**
+ * Get Run start time
+ *
+ * @return Start time (millisecond timestamp)
+ */
+ public long getStartTime() {
+ return startTime;
+ }
+
+ /**
+ * Get Run execution duration
+ *
+ * @return Execution duration (milliseconds)
+ */
+ public long getDuration() {
+ return System.currentTimeMillis() - startTime;
+ }
+
+ /**
+ * Check if msg_ids have been collected
+ *
+ * @return true if at least one msg_id exists
+ */
+ public boolean hasMsgIds() {
+ return !msgIds.isEmpty();
+ }
+
+ /**
+ * Add Message
+ *
+ *
Thread-safe: This method can be called concurrently.
+ *
+ * @param msg Message (ignored if null)
+ */
+ public void addMsg(Msg msg) {
+ if (msg != null) {
+ msgs.add(msg);
+ }
+ }
+
+ /**
+ * Batch set Messages (usually obtained from agent memory)
+ *
+ *
Clears existing messages and adds new message list.
+ *
+ *
Thread-safe: This method can be called concurrently.
+ *
+ * @param messages Message list (clears existing messages if null or empty)
+ */
+ public void setMessages(List messages) {
+ msgs.clear();
+ if (messages != null && !messages.isEmpty()) {
+ msgs.addAll(messages);
+ }
+ }
+
+ /**
+ * Get all Messages
+ *
+ *
Return value: Returns a new List copy to prevent external modification.
+ *
+ * @return Messages list (unmodifiable)
+ */
+ public List getMessages() {
+ return new ArrayList<>(msgs);
+ }
+
+ /**
+ * Get Messages count
+ *
+ * @return Number of msgs
+ */
+ public int getMessageCount() {
+ return msgs.size();
+ }
+
+ /**
+ * Check if messages have been collected
+ *
+ * @return true if at least one message exists
+ */
+ public boolean hasMessages() {
+ return !msgs.isEmpty();
+ }
+
+ @Override
+ public String toString() {
+ return String.format(
+ "RunExecution{task=%s, run=%s, msgIds=%d, msgs=%d, duration=%dms}",
+ taskId, runId, msgIds.size(), msgs.size(), getDuration());
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ RunExecutionContext that = (RunExecutionContext) o;
+ return taskId.equals(that.taskId) && runId.equals(that.runId);
+ }
+
+ @Override
+ public int hashCode() {
+ int result = taskId.hashCode();
+ result = 31 * result + runId.hashCode();
+ return result;
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/RunRegistry.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/RunRegistry.java
new file mode 100644
index 000000000..8145b6d98
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/RunRegistry.java
@@ -0,0 +1,158 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import java.time.Duration;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Run ID Registry
+ *
+ *
Manages execution count (Run ID) for each Task ID.
+ *
+ *
Run ID is used to identify the Nth execution of the same Task, supports:
+ *
+ *
Comparison experiments (same question, different models/parameters)
+ *
Retry mechanism (re-execution after failure)
+ *
A/B testing (same request, different strategies)
+ *
+ *
+ *
Thread safety:
+ *
+ *
Uses {@link ConcurrentHashMap} to store counters
+ *
Uses {@link AtomicInteger} to ensure concurrent safety
+ *
+ *
+ *
Memory management:
+ *
+ *
Supports manual cleanup of specific Task counters
+ *
Supports periodic cleanup of all counters (avoid memory leaks)
+ *
+ *
+ * @see TaskIdGenerator
+ * @see RunExecutionContext
+ */
+class RunRegistry {
+
+ /**
+ * Task ID → Run counter
+ * Each Task ID corresponds to an incrementing counter
+ */
+ private static final ConcurrentHashMap taskRunCounters =
+ new ConcurrentHashMap<>();
+
+ // Private constructor to prevent instantiation
+ private RunRegistry() {}
+
+ /**
+ * Allocate a new Run ID for the specified Task
+ *
+ *
Run ID increments from 0 (0, 1, 2, ...)
+ *
+ *
Thread-safe: Safe to call this method concurrently from multiple threads.
+ *
+ * @param taskId Task ID, cannot be null
+ * @return Run ID (string format, e.g. "0", "1", "2")
+ * @throws IllegalArgumentException if taskId is null
+ */
+ public static String allocateRunId(String taskId) {
+ if (taskId == null) {
+ throw new IllegalArgumentException("Task ID cannot be null");
+ }
+
+ AtomicInteger counter = taskRunCounters.computeIfAbsent(taskId, k -> new AtomicInteger(0));
+
+ int runId = counter.getAndIncrement();
+ return String.valueOf(runId);
+ }
+
+ /**
+ * Get current Run count for the specified Task
+ *
+ * @param taskId Task ID
+ * @return Current Run count (returns 0 if Task doesn't exist)
+ */
+ public static int getCurrentRunCount(String taskId) {
+ if (taskId == null) {
+ return 0;
+ }
+ AtomicInteger counter = taskRunCounters.get(taskId);
+ return counter != null ? counter.get() : 0;
+ }
+
+ /**
+ * Clean up counter for specified Task
+ *
+ *
Used to release memory and avoid memory leaks from long-running processes.
+ *
+ *
Use cases:
+ *
+ *
Manual cleanup after Task execution completes
+ *
Cleanup counter after Feedback submission succeeds
+ *
+ *
+ * @param taskId Task ID
+ */
+ public static void cleanup(String taskId) {
+ if (taskId != null) {
+ taskRunCounters.remove(taskId);
+ }
+ }
+
+ /**
+ * Clean up all counters
+ *
+ *
Warning: This operation clears all Task counters, use with caution!
+ *
+ *
Use cases:
+ *
+ *
When TrainingRunner stops
+ *
During system reset
+ *
+ */
+ public static void clearAll() {
+ taskRunCounters.clear();
+ }
+
+ /**
+ * Get number of Tasks in current registry
+ *
+ * @return Number of Tasks
+ */
+ public static int size() {
+ return taskRunCounters.size();
+ }
+
+ /**
+ * Periodically clean up old Task counters (reserved interface)
+ *
+ *
Can be extended in the future to: clean up Tasks unused for specified time.
+ *
+ *
Implementation approach:
+ *
+ *
Record last access time for each Task
+ *
Periodically scan and clean up expired Tasks
+ *
+ *
+ * @param ttl Expiration time
+ */
+ public static void cleanupOldTasks(Duration ttl) {
+ // TODO: Can implement time-based automatic cleanup in the future
+ // Current version: Manually call cleanup() for cleaning
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TaskExecutionRegistry.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TaskExecutionRegistry.java
new file mode 100644
index 000000000..b9e71b6fe
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TaskExecutionRegistry.java
@@ -0,0 +1,279 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Collectors;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Task Execution Registry
+ *
+ *
Stores execution contexts for all Runs of all Tasks, supports querying and analysis.
+ *
+ *
Core features:
+ *
+ *
Register {@link RunExecutionContext} for each Run
+ *
Query all Runs for a Task
+ *
Query specific (Task, Run)
+ *
Statistical analysis (total Task count, total Run count, etc.)
Provides clear all data interface {@link #clearAll()}
+ *
Recommendation: Periodically clean up expired data to avoid memory leaks
+ *
+ *
+ * @see RunExecutionContext
+ * @see RunRegistry
+ */
+public class TaskExecutionRegistry {
+
+ private static final Logger logger = LoggerFactory.getLogger(TaskExecutionRegistry.class);
+
+ /**
+ * Task ID → List of RunExecutionContext (sorted by Run ID)
+ */
+ private static final ConcurrentHashMap> registry =
+ new ConcurrentHashMap<>();
+
+ // Private constructor to prevent instantiation
+ private TaskExecutionRegistry() {}
+
+ /**
+ * Register a RunExecutionContext
+ *
+ *
Use case: Called by TrainingRouter after Shadow Agent execution completes.
+ *
+ * @param context Run execution context
+ */
+ public static void register(RunExecutionContext context) {
+ if (context == null) {
+ logger.warn("Attempted to register null context, ignoring");
+ return;
+ }
+
+ String taskId = context.getTaskId();
+ registry.compute(
+ taskId,
+ (k, existingList) -> {
+ List list =
+ existingList != null
+ ? existingList
+ : Collections.synchronizedList(new ArrayList<>());
+ list.add(context);
+ return list;
+ });
+
+ logger.debug(
+ "Registered context: {}, total runs for this task: {}",
+ context,
+ registry.get(taskId).size());
+ }
+
+ /**
+ * Get contexts for all Runs of specified Task
+ *
+ *
Return value: Copy of list sorted by Run ID in ascending order.
+ *
+ * @param taskId Task ID
+ * @return All Run contexts for this Task (returns empty list if Task doesn't exist)
+ */
+ public static List getRunsByTask(String taskId) {
+ if (taskId == null) {
+ return Collections.emptyList();
+ }
+
+ List contexts = registry.get(taskId);
+ if (contexts == null || contexts.isEmpty()) {
+ return Collections.emptyList();
+ }
+
+ // Return copy, sorted by runId
+ synchronized (contexts) {
+ return contexts.stream()
+ .sorted(Comparator.comparing(ctx -> Integer.parseInt(ctx.getRunId())))
+ .collect(Collectors.toList());
+ }
+ }
+
+ /**
+ * Get context for specified (Task, Run)
+ *
+ * @param taskId Task ID
+ * @param runId Run ID
+ * @return Corresponding context, returns null if not found
+ */
+ public static RunExecutionContext getRun(String taskId, String runId) {
+ if (taskId == null || runId == null) {
+ return null;
+ }
+
+ List contexts = registry.get(taskId);
+ if (contexts == null) {
+ return null;
+ }
+
+ synchronized (contexts) {
+ return contexts.stream()
+ .filter(ctx -> ctx.getRunId().equals(runId))
+ .findFirst()
+ .orElse(null);
+ }
+ }
+
+ /**
+ * Get all Task IDs
+ *
+ * @return Set of all Task IDs
+ */
+ public static Set getAllTaskIds() {
+ return new HashSet<>(registry.keySet());
+ }
+
+ /**
+ * Get Run count for specified Task
+ *
+ * @param taskId Task ID
+ * @return Run count (returns 0 if Task doesn't exist)
+ */
+ public static int getRunCount(String taskId) {
+ if (taskId == null) {
+ return 0;
+ }
+ List contexts = registry.get(taskId);
+ return contexts != null ? contexts.size() : 0;
+ }
+
+ /**
+ * Get total Task count
+ *
+ * @return Task count
+ */
+ public static int getTaskCount() {
+ return registry.size();
+ }
+
+ /**
+ * Get total Run count (all Runs of all Tasks)
+ *
+ * @return Total Run count
+ */
+ public static int getTotalRunCount() {
+ return registry.values().stream().mapToInt(List::size).sum();
+ }
+
+ /**
+ * Clean up all data for specified Task
+ *
+ *
Used to release memory and avoid memory leaks from long-running processes.
+ *
+ * @param taskId Task ID
+ * @return Number of Runs cleaned up
+ */
+ public static int cleanup(String taskId) {
+ if (taskId == null) {
+ return 0;
+ }
+
+ List removed = registry.remove(taskId);
+ int count = removed != null ? removed.size() : 0;
+
+ if (count > 0) {
+ logger.info("Cleaned up {} runs for task: {}", count, taskId);
+ }
+
+ return count;
+ }
+
+ /**
+ * Clear all data
+ *
+ *
Warning: This operation deletes all Run data for all Tasks, use with caution!
+ */
+ public static void clearAll() {
+ int taskCount = registry.size();
+ int runCount = getTotalRunCount();
+ registry.clear();
+ logger.info("Cleared all data: {} tasks, {} runs", taskCount, runCount);
+ }
+
+ /**
+ * Get statistical summary
+ *
+ * @return Statistical information
+ */
+ public static RegistryStats getStats() {
+ return new RegistryStats(getTaskCount(), getTotalRunCount());
+ }
+
+ /**
+ * Statistical information
+ */
+ public static class RegistryStats {
+ private final int taskCount;
+ private final int totalRunCount;
+
+ private RegistryStats(int taskCount, int totalRunCount) {
+ this.taskCount = taskCount;
+ this.totalRunCount = totalRunCount;
+ }
+
+ public int getTaskCount() {
+ return taskCount;
+ }
+
+ public int getTotalRunCount() {
+ return totalRunCount;
+ }
+
+ public double getAverageRunsPerTask() {
+ return taskCount > 0 ? (double) totalRunCount / taskCount : 0.0;
+ }
+
+ @Override
+ public String toString() {
+ return String.format(
+ "RegistryStats{tasks=%d, runs=%d, avg=%.2f}",
+ taskCount, totalRunCount, getAverageRunsPerTask());
+ }
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TaskIdGenerator.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TaskIdGenerator.java
new file mode 100644
index 000000000..016729bab
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TaskIdGenerator.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import java.util.UUID;
+
+/**
+ * Task ID Generator
+ *
+ *
Automatically generates unique Task ID for each Agent call.
+ *
+ *
Task ID is used to identify a user's complete request and is the top-level identifier for training data tracking.
+ *
+ *
Design notes:
+ *
+ *
Internal component, not visible to users
+ *
Automatically generates new Task ID for each Agent call
Example: {@code task-1704067200000-3e4f5a6b}
+ *
+ * @return Task ID with timestamp
+ */
+ public static String generateWithTimestamp() {
+ long timestamp = System.currentTimeMillis();
+ String randomPart = UUID.randomUUID().toString().substring(0, 8);
+ return PREFIX + "-" + timestamp + "-" + randomPart;
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TrainingConfig.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TrainingConfig.java
new file mode 100644
index 000000000..9197f6f86
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TrainingConfig.java
@@ -0,0 +1,259 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import io.agentscope.core.training.reward.RewardCalculator;
+import io.agentscope.core.training.strategy.SamplingRateStrategy;
+import io.agentscope.core.training.strategy.TrainingSelectionStrategy;
+import java.time.Duration;
+
+/**
+ * Configuration for Training Runner.
+ *
+ *
Configures various parameters for the training process, including Trinity service address, training model, selection strategy, reward calculator, etc.
+ *
+ *
Automated design:
+ *
+ *
Task ID and Run ID are fully managed automatically by the system
+ *
Users don't need to worry about task tracking details
+ *
Supports flexible training request filtering strategies
+ *
+ */
+public class TrainingConfig {
+ private final String trinityEndpoint; // Trinity service address, e.g. http://47.252.36.19:8010
+ private final String trinityApiKey; // Trinity API Key (optional, some deployments require it)
+ private final String
+ modelName; // Training model path, e.g. /home/ecs-user/models/Qwen2.5-0.5B-Instruct
+ private final TrainingSelectionStrategy
+ selectionStrategy; // Training request filtering strategy
+ private final RewardCalculator rewardCalculator;
+ private final long commitIntervalSeconds;
+ private final Duration httpTimeout;
+ private final boolean enableAutoCommit;
+ private final int shadowPoolSize;
+ private final int shadowPoolCapacity;
+ private final int repeatTime; // Number of times each task runs repeatedly, defaults to 1
+
+ private TrainingConfig(Builder builder) {
+ this.trinityEndpoint = builder.trinityEndpoint;
+ this.trinityApiKey = builder.trinityApiKey;
+ this.modelName = builder.modelName;
+ this.selectionStrategy = builder.selectionStrategy;
+ this.rewardCalculator = builder.rewardCalculator;
+ this.commitIntervalSeconds = builder.commitIntervalSeconds;
+ this.httpTimeout = builder.httpTimeout;
+ this.enableAutoCommit = builder.enableAutoCommit;
+ this.shadowPoolSize = builder.shadowPoolSize;
+ this.shadowPoolCapacity = builder.shadowPoolCapacity;
+ this.repeatTime = builder.repeatTime;
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public String getTrinityEndpoint() {
+ return trinityEndpoint;
+ }
+
+ public String getTrinityApiKey() {
+ return trinityApiKey;
+ }
+
+ public String getModelName() {
+ return modelName;
+ }
+
+ public TrainingSelectionStrategy getSelectionStrategy() {
+ return selectionStrategy;
+ }
+
+ public RewardCalculator getRewardCalculator() {
+ return rewardCalculator;
+ }
+
+ public long getCommitIntervalSeconds() {
+ return commitIntervalSeconds;
+ }
+
+ public Duration getHttpTimeout() {
+ return httpTimeout;
+ }
+
+ public boolean isEnableAutoCommit() {
+ return enableAutoCommit;
+ }
+
+ public int getShadowPoolSize() {
+ return shadowPoolSize;
+ }
+
+ public int getShadowPoolCapacity() {
+ return shadowPoolCapacity;
+ }
+
+ public int getRepeatTime() {
+ return repeatTime;
+ }
+
+ /**
+ * Get current sampling rate
+ *
+ * @return Sampling rate (0.0 ~ 1.0), returns -1 if using non-sampling strategy
+ */
+ public double getSampleRate() {
+ if (selectionStrategy instanceof SamplingRateStrategy) {
+ return ((SamplingRateStrategy) selectionStrategy).getSampleRate();
+ }
+ return -1;
+ }
+
+ public static class Builder {
+ private String trinityEndpoint;
+ private String trinityApiKey =
+ "dummy"; // Default value, some Trinity deployments don't need real key
+ private String modelName = "training-model";
+ private TrainingSelectionStrategy selectionStrategy; // Optional, defaults to 10% sampling
+ private RewardCalculator rewardCalculator;
+ private long commitIntervalSeconds = 300; // Default 5 minutes
+ private Duration httpTimeout = Duration.ofSeconds(300); // Default 5 minutes timeout
+ private boolean enableAutoCommit = true;
+ private int shadowPoolSize = 10; // Default shadow Agent thread pool size
+ private int shadowPoolCapacity = 1000; // Default shadow Agent queue capacity
+ private int repeatTime = 1; // Default each task runs 1 time
+
+ public Builder trinityEndpoint(String endpoint) {
+ this.trinityEndpoint = endpoint;
+ return this;
+ }
+
+ public Builder trinityApiKey(String apiKey) {
+ this.trinityApiKey = apiKey;
+ return this;
+ }
+
+ public Builder modelName(String modelName) {
+ this.modelName = modelName;
+ return this;
+ }
+
+ /**
+ * Set training request filtering strategy
+ *
+ *
Supports multiple strategies:
+ *
+ *
{@link SamplingRateStrategy} - Based on sampling rate
+ *
{@link io.agentscope.core.training.strategy.ExplicitMarkingStrategy} - Based on user explicit marking
When a request is selected for training, it will run repeatTime times with the same taskId,
+ * each run will be assigned an incrementing runId (0, 1, 2, ...).
+ *
+ *
Use cases:
+ *
+ *
A/B/C/D testing: Compare different strategy effects
+ *
Diversity training: Generate multiple training samples for the same task
+ *
Stability assessment: Evaluate model stability with same input
+ *
+ *
+ * @param repeatTime Repeat execution count, must be >= 1, defaults to 1
+ * @return this
+ * @throws IllegalArgumentException if repeatTime < 1
+ */
+ public Builder repeatTime(int repeatTime) {
+ if (repeatTime < 1) {
+ throw new IllegalArgumentException("repeatTime must be >= 1");
+ }
+ this.repeatTime = repeatTime;
+ return this;
+ }
+
+ public TrainingConfig build() {
+ if (trinityEndpoint == null || trinityEndpoint.isEmpty()) {
+ throw new IllegalArgumentException("Trinity endpoint must be specified");
+ }
+ if (rewardCalculator == null) {
+ throw new IllegalArgumentException("RewardCalculator must be specified");
+ }
+
+ // If no strategy specified, use default 10% sampling rate
+ if (selectionStrategy == null) {
+ selectionStrategy = SamplingRateStrategy.of(0.1);
+ }
+
+ return new TrainingConfig(this);
+ }
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TrainingRouter.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TrainingRouter.java
new file mode 100644
index 000000000..83e6a3ecf
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TrainingRouter.java
@@ -0,0 +1,351 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import io.agentscope.core.agent.Agent;
+import io.agentscope.core.hook.Hook;
+import io.agentscope.core.hook.HookEvent;
+import io.agentscope.core.hook.PostCallEvent;
+import io.agentscope.core.hook.PreCallEvent;
+import io.agentscope.core.message.Msg;
+import io.agentscope.core.training.backend.TrinityClient;
+import io.agentscope.core.training.backend.TrinityModelAdapter;
+import io.agentscope.core.training.backend.dto.FeedbackRequest;
+import io.agentscope.core.training.reward.RewardCalculator;
+import io.agentscope.core.training.strategy.SelectionDecision;
+import io.agentscope.core.training.strategy.TrainingSelectionStrategy;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import reactor.core.publisher.Mono;
+import reactor.core.scheduler.Scheduler;
+import reactor.core.scheduler.Schedulers;
+
+/**
+ * Training Router Hook
+ *
+ *
Core training router Hook implementing shadow traffic training architecture with fully automated Task/Run ID management
+ *
+ *
Automation Flow:
+ *
+ *
PreCallEvent: No modifications, let request execute normally
+ *
PostCallEvent: Filter and judge, run shadow Agent asynchronously
+ *
Auto-generate Task ID: Independent task identifier for each request
+ *
Auto-allocate Run ID: Execution count for the same Task
+ *
Shadow Agent: Replace model with TrinityModelAdapter
+ *
Auto-collect msg_ids: Extract from Trinity API responses
+ *
Calculate Reward: Based on shadow Agent execution results
+ *
Submit Feedback: Batch submit all training data
+ *
+ *
+ *
Key Design:
+ *
+ *
100% requests use production model, ensuring service quality
+ *
Sample portion of requests run shadow Agent asynchronously, no user impact
+ *
Shadow Agent uses training model, collects training data
+ *
Completely transparent to users: All IDs managed automatically
+ *
+ *
+ * @see TaskIdGenerator
+ * @see RunRegistry
+ * @see RunExecutionContext
+ */
+public class TrainingRouter implements Hook {
+ private static final Logger logger = LoggerFactory.getLogger(TrainingRouter.class);
+
+ private final TrainingConfig config;
+ private final TrinityClient trinityClient;
+ private final RewardCalculator rewardCalculator;
+ private final TrainingSelectionStrategy selectionStrategy;
+ private final Scheduler asyncScheduler;
+
+ // Save PreCallEvent inputs for retrieval during PostCallEvent
+ private final Map> callInputs = new ConcurrentHashMap<>();
+
+ public TrainingRouter(
+ TrainingConfig config,
+ TrinityClient trinityClient,
+ RewardCalculator rewardCalculator,
+ TrainingSelectionStrategy selectionStrategy) {
+ this.config = config;
+ this.trinityClient = trinityClient;
+ this.rewardCalculator = rewardCalculator;
+ this.selectionStrategy = selectionStrategy;
+ this.asyncScheduler =
+ Schedulers.newBoundedElastic(
+ config.getShadowPoolSize(),
+ config.getShadowPoolCapacity(),
+ "training-shadow");
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public Mono onEvent(T event) {
+ if (event instanceof PreCallEvent) {
+ PreCallEvent e = (PreCallEvent) event;
+ return handlePreCall(e).thenReturn((T) e);
+ } else if (event instanceof PostCallEvent) {
+ PostCallEvent e = (PostCallEvent) event;
+ return handlePostCall(e).thenReturn((T) e);
+ } else {
+ return Mono.just(event);
+ }
+ }
+
+ @Override
+ public int priority() {
+ return 500; // Low priority, runs after business logic
+ }
+
+ /**
+ * Handle PreCallEvent
+ * Save input messages for use in PostCallEvent
+ */
+ private Mono handlePreCall(PreCallEvent event) {
+ return Mono.fromRunnable(
+ () -> {
+ String agentId = event.getAgent().getAgentId();
+ callInputs.put(agentId, event.getInputMessages());
+ logger.debug("Saved input messages for agent: {}", agentId);
+ });
+ }
+
+ /**
+ * Handle PostCallEvent
+ * Auto-generate Task ID, allocate Run ID, run shadow Agent asynchronously
+ */
+ private Mono handlePostCall(PostCallEvent event) {
+ return Mono.deferContextual(
+ ctx -> {
+ String agentId = event.getAgent().getAgentId();
+ String agentName = event.getAgent().getName();
+
+ // ✅ Prevent shadow agent from triggering training (avoid recursive loop)
+ if (agentName != null && agentName.contains("-shadow")) {
+ logger.trace(
+ "Skipping training for shadow agent: {} (prevents recursive"
+ + " training)",
+ agentName);
+ callInputs.remove(agentId); // Clean up input
+ return Mono.empty();
+ }
+
+ // Get input messages
+ List inputs = callInputs.remove(agentId);
+ if (inputs == null) {
+ logger.warn("No input messages found for agent: {}", agentId);
+ return Mono.empty();
+ }
+
+ // Check if training is needed (using unified selection strategy)
+ logger.info(
+ "Checking training selection: agent={}, strategy={}, threadId={}",
+ event.getAgent().getName(),
+ selectionStrategy.getClass().getSimpleName(),
+ Thread.currentThread().getId());
+
+ SelectionDecision decision =
+ selectionStrategy.shouldSelect(
+ event.getAgent(), inputs, event.getFinalMessage(), ctx);
+
+ if (!decision.shouldTrain()) {
+ logger.info(
+ "Skip training: agent={}, reason={}",
+ event.getAgent().getName(),
+ decision.getReason());
+ return Mono.empty();
+ }
+
+ // ✅ Get or generate Task ID
+ // Prioritize user-specified taskId (supports multiple runs of same task)
+ // Otherwise auto-generate new taskId
+ String taskId =
+ decision.getMetadata() != null
+ && decision.getMetadata().containsKey("taskId")
+ ? (String) decision.getMetadata().get("taskId")
+ : TaskIdGenerator.generate();
+
+ if (decision.getMetadata() != null
+ && decision.getMetadata().containsKey("taskId")) {
+ logger.debug(
+ "Using user-specified Task ID: {} for agent: {}",
+ taskId,
+ event.getAgent().getName());
+ } else {
+ logger.debug(
+ "Auto-generated Task ID: {} for agent: {}",
+ taskId,
+ event.getAgent().getName());
+ }
+
+ // ✅ Get repeatTime configuration (how many times each task runs)
+ int repeatTime = config.getRepeatTime();
+
+ logger.info(
+ "Training triggered for task {}: repeatTime={}, reason={}, labels={}",
+ taskId,
+ repeatTime,
+ decision.getReason(),
+ decision.getLabels());
+
+ // ✅ Loop multiple runs (using same taskId, runId auto-increments)
+ for (int i = 0; i < repeatTime; i++) {
+ // Allocate new runId for each iteration
+ String currentRunId = RunRegistry.allocateRunId(taskId);
+ RunExecutionContext currentContext =
+ RunExecutionContext.create(taskId, currentRunId);
+
+ logger.info(
+ "Starting run {}/{} for task {}: {}",
+ i + 1,
+ repeatTime,
+ taskId,
+ currentContext);
+
+ // Check if this is the last run
+ boolean isLastRun = (i == repeatTime - 1);
+
+ // Run shadow Agent asynchronously (non-blocking)
+ runShadowAgent(event.getAgent(), inputs, currentContext, decision)
+ .subscribeOn(asyncScheduler)
+ .doFinally(
+ signal -> {
+ // Clean up resources only after last run
+ if (isLastRun) {
+ RunRegistry.cleanup(taskId);
+ logger.debug(
+ "Cleaned up registry for Task: {} after {}"
+ + " runs",
+ taskId,
+ repeatTime);
+ }
+ })
+ .subscribe(
+ null,
+ error ->
+ logger.error(
+ "Shadow agent failed for {}",
+ currentContext,
+ error));
+ }
+
+ return Mono.empty();
+ });
+ }
+
+ /**
+ * Run shadow Agent
+ *
+ *
Fully automated:
+ *
+ *
Auto-create TrinityModelAdapter (associated with RunExecutionContext)
+ *
Auto-collect msg_ids into RunExecutionContext
+ *
Auto-calculate Reward (based on execution results)
+ *
Auto-submit Feedback (including Task ID, Run ID, msg_ids)
+ *
+ */
+ private Mono runShadowAgent(
+ Agent productionAgent,
+ List inputs,
+ RunExecutionContext executionContext,
+ SelectionDecision decision) {
+ return Mono.defer(
+ () -> {
+ logger.info(
+ "Starting shadow agent for {}: agent={}, labels={}",
+ executionContext,
+ productionAgent.getName(),
+ decision.getLabels());
+
+ try {
+ // ✅ 1. Create Trinity model (auto-associate with RunExecutionContext)
+ TrinityModelAdapter trinityModel =
+ TrinityModelAdapter.builder()
+ .baseUrl(trinityClient.getEndpoint() + "/v1")
+ .modelName(config.getModelName())
+ .apiKey(config.getTrinityApiKey())
+ .executionContext(
+ executionContext) // ← Pass execution context
+ .build();
+
+ // 2. Clone Agent and replace model
+ Agent shadowAgent =
+ AgentCloner.cloneWithModel(productionAgent, trinityModel);
+
+ logger.debug(
+ "Shadow agent created for {}: {}",
+ executionContext,
+ shadowAgent.getName());
+
+ // ✅ 3. Execute shadow Agent
+ shadowAgent.call(inputs).block();
+
+ logger.info(
+ "Shadow agent completed: {}, duration={}ms",
+ executionContext,
+ executionContext.getDuration());
+
+ // ✅ 4. Auto-get msg_ids from RunExecutionContext
+ List msgIds = executionContext.getMsgIds();
+
+ logger.info("Collected {} msg_ids for {}", msgIds.size(), executionContext);
+
+ if (msgIds.isEmpty()) {
+ logger.warn("No msg_ids collected for {}", executionContext);
+ return Mono.empty();
+ }
+
+ // ✅ 5. Calculate reward using RewardCalculator
+ double reward = rewardCalculator.calculate(shadowAgent);
+
+ logger.info("Calculated reward: {} for {}", reward, executionContext);
+
+ // ✅ 6. Auto-submit feedback (Task ID, Run ID, msg_ids auto-filled)
+ return trinityClient
+ .feedback(
+ FeedbackRequest.builder()
+ .taskId(executionContext.getTaskId())
+ .runId(executionContext.getRunId())
+ .msgIds(msgIds)
+ .reward(reward)
+ .build())
+ .doOnSuccess(
+ v -> {
+ logger.info(
+ "Feedback submitted successfully for {}",
+ executionContext);
+
+ // ✅ 8. Save execution context to registry (for later
+ // queries)
+ TaskExecutionRegistry.register(executionContext);
+ logger.debug(
+ "Registered context: {}, total runs: {}",
+ executionContext,
+ TaskExecutionRegistry.getRunCount(
+ executionContext.getTaskId()));
+ })
+ .then();
+
+ } catch (Exception e) {
+ logger.error("Failed to create shadow agent for {}", executionContext, e);
+ return Mono.empty();
+ }
+ });
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TrainingRunner.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TrainingRunner.java
new file mode 100644
index 000000000..909e1598c
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/runner/TrainingRunner.java
@@ -0,0 +1,292 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import io.agentscope.core.agent.AgentBase;
+import io.agentscope.core.training.backend.TrinityClient;
+import io.agentscope.core.training.backend.dto.CommitRequest;
+import io.agentscope.core.training.reward.RewardCalculator;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import reactor.core.publisher.Mono;
+
+/**
+ * Training Runner - Training Process Controller
+ *
+ *
Responsible for:
+ *
+ *
Intercepting Agent requests and determining if they should be used for training
+ *
Replacing ChatClient to Trinity backend
+ *
Collecting trace data
+ *
Calculating rewards and calling feedback API
+ *
Periodically calling commit API to trigger training
+ */
+public class SamplingRateStrategy implements TrainingSelectionStrategy {
+
+ private final double sampleRate;
+ private final Random random;
+
+ private SamplingRateStrategy(double sampleRate) {
+ if (sampleRate < 0 || sampleRate > 1) {
+ throw new IllegalArgumentException("Sample rate must be between 0 and 1");
+ }
+ this.sampleRate = sampleRate;
+ this.random = new Random();
+ }
+
+ /**
+ * Create sampling rate strategy
+ *
+ * @param sampleRate Sampling rate, range [0, 1]
+ * @return Strategy instance
+ */
+ public static SamplingRateStrategy of(double sampleRate) {
+ return new SamplingRateStrategy(sampleRate);
+ }
+
+ @Override
+ public SelectionDecision shouldSelect(
+ Agent agent, List inputMessages, Msg outputMessage, ContextView reactorContext) {
+
+ if (random.nextDouble() < sampleRate) {
+ return SelectionDecision.accept("sampling-rate", "sampled");
+ }
+
+ return SelectionDecision.reject("not-sampled");
+ }
+
+ @Override
+ public int priority() {
+ return 200; // Low priority, typically used as fallback strategy
+ }
+
+ @Override
+ public String name() {
+ return "SamplingRate(" + sampleRate + ")";
+ }
+
+ public double getSampleRate() {
+ return this.sampleRate;
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/strategy/SelectionDecision.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/strategy/SelectionDecision.java
new file mode 100644
index 000000000..49de5f8a9
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/strategy/SelectionDecision.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.strategy;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Selection Decision - Selection decision result
+ *
+ *
Encapsulates the decision of whether to train, along with related metadata (labels, reason, etc.)
+ */
+public class SelectionDecision {
+
+ /** Whether training should occur */
+ private final boolean shouldTrain;
+
+ /** Decision reason (for logging and debugging) */
+ private final String reason;
+
+ /** Training labels (optional, for classifying training data) */
+ private final List labels;
+
+ /** Additional metadata (optional) */
+ private final Map metadata;
+
+ private SelectionDecision(
+ boolean shouldTrain, String reason, List labels, Map metadata) {
+ this.shouldTrain = shouldTrain;
+ this.reason = reason;
+ this.labels =
+ labels != null ? Collections.unmodifiableList(labels) : Collections.emptyList();
+ this.metadata =
+ metadata != null ? Collections.unmodifiableMap(metadata) : Collections.emptyMap();
+ }
+
+ /**
+ * Create a "should train" decision
+ */
+ public static SelectionDecision accept(String reason) {
+ return new SelectionDecision(true, reason, null, null);
+ }
+
+ /**
+ * Create a "should train" decision with labels
+ */
+ public static SelectionDecision accept(String reason, String... labels) {
+ return new SelectionDecision(true, reason, Arrays.asList(labels), null);
+ }
+
+ /**
+ * Create a "should train" decision with labels and metadata
+ */
+ public static SelectionDecision accept(
+ String reason, List labels, Map metadata) {
+ return new SelectionDecision(true, reason, labels, metadata);
+ }
+
+ /**
+ * Create a "should not train" decision
+ */
+ public static SelectionDecision reject(String reason) {
+ return new SelectionDecision(false, reason, null, null);
+ }
+
+ // Getters
+ public boolean shouldTrain() {
+ return shouldTrain;
+ }
+
+ public String getReason() {
+ return reason;
+ }
+
+ public List getLabels() {
+ return labels;
+ }
+
+ public Map getMetadata() {
+ return metadata;
+ }
+
+ @Override
+ public String toString() {
+ return "SelectionDecision{"
+ + "shouldTrain="
+ + shouldTrain
+ + ", reason='"
+ + reason
+ + '\''
+ + ", labels="
+ + labels
+ + ", metadata="
+ + metadata
+ + '}';
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/strategy/TrainingAnnotation.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/strategy/TrainingAnnotation.java
new file mode 100644
index 000000000..8bdfb3354
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/strategy/TrainingAnnotation.java
@@ -0,0 +1,207 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.strategy;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * TrainingAnnotation represents a training annotation for an agent call.
+ *
+ *
It contains:
+ *
+ *
enabled: whether this call should be used for training
+ *
taskId: (optional) user-specified task ID for this request
+ *
labels: user-defined labels for categorizing training data
+ *
metadata: additional context information
+ *
timestamp: when this annotation was created
+ *
+ */
+public class TrainingAnnotation {
+
+ /** Whether this call is marked for training */
+ private final boolean enabled;
+
+ /** Optional user-specified task ID (null means auto-generate) */
+ private final String taskId;
+
+ /** User-defined labels for this training sample */
+ private final List labels;
+
+ /** Additional metadata for this training sample */
+ private final Map metadata;
+
+ /** Timestamp when this annotation was created */
+ private final long timestamp;
+
+ /**
+ * Private constructor for builder pattern.
+ */
+ private TrainingAnnotation(
+ boolean enabled,
+ String taskId,
+ List labels,
+ Map metadata,
+ long timestamp) {
+ this.enabled = enabled;
+ this.taskId = taskId;
+ this.labels = labels != null ? labels : new ArrayList<>();
+ this.metadata = metadata != null ? metadata : new HashMap<>();
+ this.timestamp = timestamp;
+ }
+
+ // Getters
+ public boolean isEnabled() {
+ return enabled;
+ }
+
+ public String getTaskId() {
+ return taskId;
+ }
+
+ public List getLabels() {
+ return labels;
+ }
+
+ public Map getMetadata() {
+ return metadata;
+ }
+
+ public long getTimestamp() {
+ return timestamp;
+ }
+
+ /**
+ * Check if this annotation has expired based on the given TTL.
+ *
+ * @param ttlMillis Time-to-live in milliseconds
+ * @return true if expired, false otherwise
+ */
+ public boolean isExpired(long ttlMillis) {
+ return System.currentTimeMillis() - timestamp > ttlMillis;
+ }
+
+ /**
+ * Create a simple enabled annotation without labels or metadata.
+ */
+ public static TrainingAnnotation enabled() {
+ return TrainingAnnotation.builder().enabled(true).build();
+ }
+
+ /**
+ * Create an enabled annotation with labels.
+ */
+ public static TrainingAnnotation withLabels(String... labels) {
+ return TrainingAnnotation.builder().enabled(true).labels(Arrays.asList(labels)).build();
+ }
+
+ /**
+ * Create an enabled annotation with labels and metadata.
+ */
+ public static TrainingAnnotation withLabelsAndMetadata(
+ List labels, Map metadata) {
+ return TrainingAnnotation.builder()
+ .enabled(true)
+ .labels(labels != null ? new ArrayList<>(labels) : new ArrayList<>())
+ .metadata(metadata != null ? new HashMap<>(metadata) : new HashMap<>())
+ .build();
+ }
+
+ /**
+ * Create a builder for TrainingAnnotation.
+ */
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ TrainingAnnotation that = (TrainingAnnotation) o;
+ return enabled == that.enabled
+ && timestamp == that.timestamp
+ && Objects.equals(taskId, that.taskId)
+ && Objects.equals(labels, that.labels)
+ && Objects.equals(metadata, that.metadata);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(enabled, taskId, labels, metadata, timestamp);
+ }
+
+ @Override
+ public String toString() {
+ return "TrainingAnnotation{"
+ + "enabled="
+ + enabled
+ + ", taskId="
+ + taskId
+ + ", labels="
+ + labels
+ + ", metadata="
+ + metadata
+ + ", timestamp="
+ + timestamp
+ + '}';
+ }
+
+ /**
+ * Builder class for TrainingAnnotation.
+ */
+ public static class Builder {
+ private boolean enabled;
+ private String taskId;
+ private List labels = new ArrayList<>();
+ private Map metadata = new HashMap<>();
+ private long timestamp = System.currentTimeMillis();
+
+ public Builder enabled(boolean enabled) {
+ this.enabled = enabled;
+ return this;
+ }
+
+ public Builder taskId(String taskId) {
+ this.taskId = taskId;
+ return this;
+ }
+
+ public Builder labels(List labels) {
+ this.labels = labels != null ? labels : new ArrayList<>();
+ return this;
+ }
+
+ public Builder metadata(Map metadata) {
+ this.metadata = metadata != null ? metadata : new HashMap<>();
+ return this;
+ }
+
+ public Builder timestamp(long timestamp) {
+ this.timestamp = timestamp;
+ return this;
+ }
+
+ public TrainingAnnotation build() {
+ return new TrainingAnnotation(enabled, taskId, labels, metadata, timestamp);
+ }
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/strategy/TrainingContext.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/strategy/TrainingContext.java
new file mode 100644
index 000000000..1f61437d2
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/strategy/TrainingContext.java
@@ -0,0 +1,173 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.strategy;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import reactor.util.context.Context;
+import reactor.util.context.ContextView;
+
+/**
+ * TrainingContext provides API for marking agent calls for training.
+ *
+ *
This class uses Reactor Context for propagating training markers across thread boundaries.
+ *
+ *
+ *
+ * @param labels Training labels
+ * @param metadata Metadata
+ * @return Reactor Context write function
+ */
+ public static java.util.function.Function mark(
+ List labels, Map metadata) {
+ TrainingAnnotation annotation = TrainingAnnotation.withLabelsAndMetadata(labels, metadata);
+ logger.info(
+ "Creating training marker with labels and metadata: labels={}, metadata={}",
+ labels,
+ metadata);
+ return ctx -> ctx.put(REACTOR_KEY, annotation);
+ }
+
+ /**
+ * Mark current Agent call for training (metadata only)
+ *
+ * @param metadata Metadata
+ * @return Reactor Context write function
+ */
+ public static java.util.function.Function mark(Map metadata) {
+ return mark(Collections.emptyList(), metadata);
+ }
+
+ /**
+ * Get the current training marker from Reactor Context.
+ *
+ * @param reactorContext Reactor context
+ * @return TrainingAnnotation if present, null otherwise
+ */
+ public static TrainingAnnotation getCurrent(ContextView reactorContext) {
+ if (reactorContext != null && reactorContext.hasKey(REACTOR_KEY)) {
+ TrainingAnnotation annotation = reactorContext.get(REACTOR_KEY);
+ logger.trace("Retrieved annotation from Reactor Context: {}", annotation);
+ return annotation;
+ }
+ return null;
+ }
+
+ /**
+ * Check if a marker has expired based on TTL.
+ *
+ * @param annotation The annotation to check
+ * @param ttlMillis Time-to-live in milliseconds
+ * @return true if expired or marker is null, false otherwise
+ */
+ static boolean isExpired(TrainingAnnotation annotation, long ttlMillis) {
+ return annotation == null || annotation.isExpired(ttlMillis);
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/strategy/TrainingSelectionStrategy.java b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/strategy/TrainingSelectionStrategy.java
new file mode 100644
index 000000000..c8730afcc
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/main/java/io/agentscope/core/training/strategy/TrainingSelectionStrategy.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.strategy;
+
+import io.agentscope.core.agent.Agent;
+import io.agentscope.core.message.Msg;
+import java.util.List;
+import reactor.util.context.ContextView;
+
+/**
+ * Training Selection Strategy - Training request filtering strategy
+ *
+ *
Unified interface for deciding which Agent requests should be used for training.
+ *
+ *
Supports multiple filtering methods:
+ *
+ *
Explicit marking - Users manually mark important requests
+ *
Random sampling - Randomly sample by probability
+ *
Combined strategy - Combination of multiple strategies (AND/OR)
+ *
+ *
+ *
+ * @see SamplingRateStrategy
+ * @see ExplicitMarkingStrategy
+ */
+@FunctionalInterface
+public interface TrainingSelectionStrategy {
+
+ /**
+ * Decide whether the current Agent call should be used for training
+ *
+ * @param agent The Agent being called
+ * @param inputMessages Input message list
+ * @param outputMessage Output message
+ * @param reactorContext Reactor context (for async scenarios)
+ * @return Selection decision result, including whether to train and related metadata
+ */
+ SelectionDecision shouldSelect(
+ Agent agent, List inputMessages, Msg outputMessage, ContextView reactorContext);
+
+ /**
+ * Priority of the strategy (lower number means higher priority)
+ *
+ *
When multiple strategies are combined, higher priority strategies execute first
+ *
+ * @return Priority value, defaults to 100
+ */
+ default int priority() {
+ return 100;
+ }
+
+ /**
+ * Name of the strategy (for logging and debugging)
+ *
+ * @return Strategy name
+ */
+ default String name() {
+ return this.getClass().getSimpleName();
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/TrinityClientTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/TrinityClientTest.java
new file mode 100644
index 000000000..a10ff6644
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/TrinityClientTest.java
@@ -0,0 +1,221 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.backend;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import io.agentscope.core.training.backend.dto.CommitRequest;
+import io.agentscope.core.training.backend.dto.FeedbackRequest;
+import io.agentscope.core.training.util.TrainingTestConstants;
+import java.io.IOException;
+import java.time.Duration;
+import java.util.Arrays;
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import okhttp3.mockwebserver.RecordedRequest;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+import reactor.test.StepVerifier;
+
+@DisplayName("TrinityClient Unit Tests")
+class TrinityClientTest {
+
+ private MockWebServer mockServer;
+ private TrinityClient client;
+
+ @BeforeEach
+ void setUp() throws IOException {
+ mockServer = new MockWebServer();
+ mockServer.start();
+
+ String baseUrl = mockServer.url("/").toString();
+ // Remove trailing slash
+ if (baseUrl.endsWith("/")) {
+ baseUrl = baseUrl.substring(0, baseUrl.length() - 1);
+ }
+
+ client = TrinityClient.builder().endpoint(baseUrl).timeout(Duration.ofSeconds(5)).build();
+ }
+
+ @AfterEach
+ void tearDown() throws IOException {
+ mockServer.shutdown();
+ }
+
+ @Test
+ @DisplayName("Should build client with valid endpoint")
+ void shouldBuildClientWithValidEndpoint() {
+ // Act
+ TrinityClient testClient =
+ TrinityClient.builder()
+ .endpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .build();
+
+ // Assert
+ assertNotNull(testClient);
+ assertEquals(TrainingTestConstants.TEST_TRINITY_ENDPOINT, testClient.getEndpoint());
+ }
+
+ @Test
+ @DisplayName("Should throw exception when endpoint is null")
+ void shouldThrowExceptionWhenEndpointIsNull() {
+ assertThrows(IllegalArgumentException.class, () -> TrinityClient.builder().build());
+ }
+
+ @Test
+ @DisplayName("Should throw exception when endpoint is empty")
+ void shouldThrowExceptionWhenEndpointIsEmpty() {
+ assertThrows(
+ IllegalArgumentException.class, () -> TrinityClient.builder().endpoint("").build());
+ }
+
+ @Test
+ @DisplayName("Should send feedback request successfully")
+ void shouldSendFeedbackRequestSuccessfully() throws InterruptedException {
+ // Arrange
+ mockServer.enqueue(
+ new MockResponse()
+ .setBody("{\"status\": \"success\", \"message\": \"OK\"}")
+ .setHeader("Content-Type", "application/json"));
+
+ FeedbackRequest request =
+ FeedbackRequest.builder()
+ .taskId(TrainingTestConstants.TEST_TASK_ID)
+ .runId(TrainingTestConstants.TEST_RUN_ID)
+ .msgIds(
+ Arrays.asList(
+ TrainingTestConstants.TEST_MSG_ID_1,
+ TrainingTestConstants.TEST_MSG_ID_2))
+ .reward(0.8)
+ .build();
+
+ // Act & Assert
+ StepVerifier.create(client.feedback(request)).verifyComplete();
+
+ // Verify request
+ RecordedRequest recordedRequest = mockServer.takeRequest();
+ assertEquals("/feedback", recordedRequest.getPath());
+ assertEquals("POST", recordedRequest.getMethod());
+ String body = recordedRequest.getBody().readUtf8();
+ assertTrue(body.contains("\"task_id\":\"" + TrainingTestConstants.TEST_TASK_ID + "\""));
+ assertTrue(body.contains("\"run_id\":\"" + TrainingTestConstants.TEST_RUN_ID + "\""));
+ assertTrue(body.contains("\"reward\":0.8"));
+ }
+
+ @Test
+ @DisplayName("Should handle feedback API error")
+ void shouldHandleFeedbackApiError() {
+ // Arrange
+ mockServer.enqueue(
+ new MockResponse().setResponseCode(500).setBody("Internal Server Error"));
+
+ FeedbackRequest request =
+ FeedbackRequest.builder()
+ .taskId(TrainingTestConstants.TEST_TASK_ID)
+ .runId(TrainingTestConstants.TEST_RUN_ID)
+ .msgIds(Arrays.asList(TrainingTestConstants.TEST_MSG_ID_1))
+ .reward(0.5)
+ .build();
+
+ // Act & Assert
+ StepVerifier.create(client.feedback(request)).expectError(RuntimeException.class).verify();
+ }
+
+ @Test
+ @DisplayName("Should send commit request successfully")
+ void shouldSendCommitRequestSuccessfully() throws InterruptedException {
+ // Arrange
+ mockServer.enqueue(
+ new MockResponse()
+ .setBody("{\"status\": \"success\", \"message\": \"OK\"}")
+ .setHeader("Content-Type", "application/json"));
+
+ CommitRequest request =
+ CommitRequest.builder()
+ .taskId(TrainingTestConstants.TEST_TASK_ID)
+ .runId(TrainingTestConstants.TEST_RUN_ID)
+ .timeThreshold(300000L)
+ .build();
+
+ // Act & Assert
+ StepVerifier.create(client.commit(request)).verifyComplete();
+
+ // Verify request
+ RecordedRequest recordedRequest = mockServer.takeRequest();
+ assertEquals("/commit", recordedRequest.getPath());
+ assertEquals("POST", recordedRequest.getMethod());
+ }
+
+ @Test
+ @DisplayName("Should handle commit API error")
+ void shouldHandleCommitApiError() {
+ // Arrange
+ mockServer.enqueue(new MockResponse().setResponseCode(400).setBody("Bad Request"));
+
+ CommitRequest request = CommitRequest.builder().build();
+
+ // Act & Assert
+ StepVerifier.create(client.commit(request)).expectError(RuntimeException.class).verify();
+ }
+
+ @Test
+ @DisplayName("Should use custom timeout")
+ void shouldUseCustomTimeout() {
+ // Act
+ TrinityClient testClient =
+ TrinityClient.builder()
+ .endpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .timeout(Duration.ofSeconds(60))
+ .build();
+
+ // Assert
+ assertNotNull(testClient);
+ }
+
+ @Test
+ @DisplayName("Should serialize feedback request with correct JSON field names")
+ void shouldSerializeFeedbackRequestWithCorrectJsonFieldNames() throws InterruptedException {
+ // Arrange
+ mockServer.enqueue(
+ new MockResponse()
+ .setBody("{\"status\": \"success\", \"message\": \"OK\"}")
+ .setHeader("Content-Type", "application/json"));
+
+ FeedbackRequest request =
+ FeedbackRequest.builder()
+ .taskId("task-123")
+ .runId("0")
+ .msgIds(Arrays.asList("msg-1", "msg-2"))
+ .reward(0.9)
+ .build();
+
+ // Act
+ client.feedback(request).block();
+
+ // Assert
+ RecordedRequest recordedRequest = mockServer.takeRequest();
+ String body = recordedRequest.getBody().readUtf8();
+ assertTrue(body.contains("\"msg_ids\""));
+ assertTrue(body.contains("\"task_id\""));
+ assertTrue(body.contains("\"run_id\""));
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/TrinityModelAdapterTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/TrinityModelAdapterTest.java
new file mode 100644
index 000000000..a05db3499
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/TrinityModelAdapterTest.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.backend;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
+import io.agentscope.core.training.runner.RunExecutionContext;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Tests for TrinityModelAdapter.
+ */
+@DisplayName("TrinityModelAdapter Tests")
+class TrinityModelAdapterTest {
+
+ @Test
+ @DisplayName("Should build TrinityModelAdapter with all parameters")
+ void shouldBuildWithAllParameters() {
+ RunExecutionContext context = RunExecutionContext.create("task-1", "0");
+
+ TrinityModelAdapter adapter =
+ TrinityModelAdapter.builder()
+ .baseUrl("http://localhost:8080/v1")
+ .modelName("test-model")
+ .apiKey("test-api-key")
+ .executionContext(context)
+ .build();
+
+ assertNotNull(adapter);
+ assertEquals("test-model", adapter.getModelName());
+ assertEquals(context, adapter.getExecutionContext());
+ }
+
+ @Test
+ @DisplayName("Should build TrinityModelAdapter with default apiKey")
+ void shouldBuildWithDefaultApiKey() {
+ TrinityModelAdapter adapter =
+ TrinityModelAdapter.builder()
+ .baseUrl("http://localhost:8080/v1")
+ .modelName("test-model")
+ .build();
+
+ assertNotNull(adapter);
+ assertEquals("test-model", adapter.getModelName());
+ }
+
+ @Test
+ @DisplayName("Should build TrinityModelAdapter without execution context")
+ void shouldBuildWithoutExecutionContext() {
+ TrinityModelAdapter adapter =
+ TrinityModelAdapter.builder()
+ .baseUrl("http://localhost:8080/v1")
+ .modelName("test-model")
+ .apiKey("api-key")
+ .build();
+
+ assertNotNull(adapter);
+ assertNull(adapter.getExecutionContext());
+ }
+
+ @Test
+ @DisplayName("Should return correct model name")
+ void shouldReturnCorrectModelName() {
+ TrinityModelAdapter adapter =
+ TrinityModelAdapter.builder()
+ .baseUrl("http://localhost:8080/v1")
+ .modelName("custom-model-name")
+ .build();
+
+ assertEquals("custom-model-name", adapter.getModelName());
+ }
+
+ @Test
+ @DisplayName("Should return execution context when set")
+ void shouldReturnExecutionContextWhenSet() {
+ RunExecutionContext context = RunExecutionContext.create("task-abc", "2");
+
+ TrinityModelAdapter adapter =
+ TrinityModelAdapter.builder()
+ .baseUrl("http://localhost:8080/v1")
+ .modelName("model")
+ .executionContext(context)
+ .build();
+
+ RunExecutionContext retrievedContext = adapter.getExecutionContext();
+
+ assertNotNull(retrievedContext);
+ assertEquals("task-abc", retrievedContext.getTaskId());
+ assertEquals("2", retrievedContext.getRunId());
+ }
+
+ @Test
+ @DisplayName("Should create builder successfully")
+ void shouldCreateBuilderSuccessfully() {
+ TrinityModelAdapter.Builder builder = TrinityModelAdapter.builder();
+ assertNotNull(builder);
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/dto/CommitRequestTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/dto/CommitRequestTest.java
new file mode 100644
index 000000000..61ece3e70
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/dto/CommitRequestTest.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.backend.dto;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.agentscope.core.training.util.TrainingTestConstants;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+@DisplayName("CommitRequest Unit Tests")
+class CommitRequestTest {
+
+ private final ObjectMapper objectMapper = new ObjectMapper();
+
+ @Test
+ @DisplayName("Should build commit request with all fields")
+ void shouldBuildCommitRequestWithAllFields() {
+ // Act
+ CommitRequest request =
+ CommitRequest.builder()
+ .taskId(TrainingTestConstants.TEST_TASK_ID)
+ .runId(TrainingTestConstants.TEST_RUN_ID)
+ .timeThreshold(300000L)
+ .build();
+
+ // Assert
+ assertEquals(TrainingTestConstants.TEST_TASK_ID, request.getTaskId());
+ assertEquals(TrainingTestConstants.TEST_RUN_ID, request.getRunId());
+ assertEquals(300000L, request.getTimeThreshold());
+ }
+
+ @Test
+ @DisplayName("Should serialize to JSON with snake_case field names")
+ void shouldSerializeToJsonWithSnakeCaseFieldNames() throws JsonProcessingException {
+ // Arrange
+ CommitRequest request =
+ CommitRequest.builder().taskId("task-123").runId("0").timeThreshold(60000L).build();
+
+ // Act
+ String json = objectMapper.writeValueAsString(request);
+
+ // Assert
+ assertTrue(json.contains("\"task_id\":\"task-123\""));
+ assertTrue(json.contains("\"run_id\":\"0\""));
+ assertTrue(json.contains("\"time_threshold\":60000"));
+ }
+
+ @Test
+ @DisplayName("Should handle null values")
+ void shouldHandleNullValues() {
+ // Act
+ CommitRequest request = CommitRequest.builder().build();
+
+ // Assert
+ assertNull(request.getTaskId());
+ assertNull(request.getRunId());
+ assertNull(request.getTimeThreshold());
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/dto/FeedbackRequestTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/dto/FeedbackRequestTest.java
new file mode 100644
index 000000000..8bf2f2e5b
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/dto/FeedbackRequestTest.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.backend.dto;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.agentscope.core.training.util.TrainingTestConstants;
+import java.util.Arrays;
+import java.util.List;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+@DisplayName("FeedbackRequest Unit Tests")
+class FeedbackRequestTest {
+
+ private final ObjectMapper objectMapper = new ObjectMapper();
+
+ @Test
+ @DisplayName("Should build feedback request with all fields")
+ void shouldBuildFeedbackRequestWithAllFields() {
+ // Arrange
+ List msgIds =
+ Arrays.asList(
+ TrainingTestConstants.TEST_MSG_ID_1, TrainingTestConstants.TEST_MSG_ID_2);
+
+ // Act
+ FeedbackRequest request =
+ FeedbackRequest.builder()
+ .taskId(TrainingTestConstants.TEST_TASK_ID)
+ .runId(TrainingTestConstants.TEST_RUN_ID)
+ .msgIds(msgIds)
+ .reward(0.85)
+ .build();
+
+ // Assert
+ assertEquals(TrainingTestConstants.TEST_TASK_ID, request.getTaskId());
+ assertEquals(TrainingTestConstants.TEST_RUN_ID, request.getRunId());
+ assertEquals(2, request.getMsgIds().size());
+ assertEquals(0.85, request.getReward(), 0.001);
+ }
+
+ @Test
+ @DisplayName("Should serialize to JSON with snake_case field names")
+ void shouldSerializeToJsonWithSnakeCaseFieldNames() throws JsonProcessingException {
+ // Arrange
+ FeedbackRequest request =
+ FeedbackRequest.builder()
+ .taskId("task-123")
+ .runId("0")
+ .msgIds(Arrays.asList("msg-1", "msg-2"))
+ .reward(0.9)
+ .build();
+
+ // Act
+ String json = objectMapper.writeValueAsString(request);
+
+ // Assert
+ assertTrue(json.contains("\"task_id\":\"task-123\""));
+ assertTrue(json.contains("\"run_id\":\"0\""));
+ assertTrue(json.contains("\"msg_ids\""));
+ assertTrue(json.contains("\"reward\":0.9"));
+ }
+
+ @Test
+ @DisplayName("Should handle null values")
+ void shouldHandleNullValues() {
+ // Act
+ FeedbackRequest request = FeedbackRequest.builder().build();
+
+ // Assert
+ assertNull(request.getTaskId());
+ assertNull(request.getRunId());
+ assertNull(request.getMsgIds());
+ assertNull(request.getReward());
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/dto/StatusResponseTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/dto/StatusResponseTest.java
new file mode 100644
index 000000000..fb7c6765f
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/backend/dto/StatusResponseTest.java
@@ -0,0 +1,104 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.backend.dto;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+@DisplayName("StatusResponse Unit Tests")
+class StatusResponseTest {
+
+ private final ObjectMapper objectMapper = new ObjectMapper();
+
+ @Test
+ @DisplayName("Should deserialize from JSON")
+ void shouldDeserializeFromJson() throws JsonProcessingException {
+ // Arrange
+ String json = "{\"status\": \"success\", \"message\": \"OK\"}";
+
+ // Act
+ StatusResponse response = objectMapper.readValue(json, StatusResponse.class);
+
+ // Assert
+ assertNotNull(response);
+ assertEquals("success", response.getStatus());
+ assertEquals("OK", response.getMessage());
+ assertTrue(response.isSuccess());
+ }
+
+ @Test
+ @DisplayName("Should deserialize error response")
+ void shouldDeserializeErrorResponse() throws JsonProcessingException {
+ // Arrange
+ String json = "{\"status\": \"error\", \"message\": \"Internal Server Error\"}";
+
+ // Act
+ StatusResponse response = objectMapper.readValue(json, StatusResponse.class);
+
+ // Assert
+ assertNotNull(response);
+ assertEquals("error", response.getStatus());
+ assertEquals("Internal Server Error", response.getMessage());
+ assertFalse(response.isSuccess());
+ }
+
+ @Test
+ @DisplayName("Should create with constructor")
+ void shouldCreateWithConstructor() {
+ // Act
+ StatusResponse response = new StatusResponse("success", "Operation completed");
+
+ // Assert
+ assertEquals("success", response.getStatus());
+ assertEquals("Operation completed", response.getMessage());
+ assertTrue(response.isSuccess());
+ }
+
+ @Test
+ @DisplayName("Should create with default constructor and setters")
+ void shouldCreateWithDefaultConstructorAndSetters() {
+ // Act
+ StatusResponse response = new StatusResponse();
+ response.setStatus("success");
+ response.setMessage("Done");
+
+ // Assert
+ assertEquals("success", response.getStatus());
+ assertEquals("Done", response.getMessage());
+ }
+
+ @Test
+ @DisplayName("Should check isSuccess case insensitively")
+ void shouldCheckIsSuccessCaseInsensitively() {
+ // Arrange
+ StatusResponse response1 = new StatusResponse("SUCCESS", "OK");
+ StatusResponse response2 = new StatusResponse("Success", "OK");
+ StatusResponse response3 = new StatusResponse("failed", "Error");
+
+ // Assert
+ assertTrue(response1.isSuccess());
+ assertTrue(response2.isSuccess());
+ assertFalse(response3.isSuccess());
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/RunExecutionContextTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/RunExecutionContextTest.java
new file mode 100644
index 000000000..79696581c
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/RunExecutionContextTest.java
@@ -0,0 +1,336 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import io.agentscope.core.message.Msg;
+import io.agentscope.core.message.MsgRole;
+import io.agentscope.core.training.util.TrainingTestConstants;
+import io.agentscope.core.training.util.TrainingTestUtils;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+@DisplayName("RunExecutionContext Unit Tests")
+class RunExecutionContextTest {
+
+ @Test
+ @DisplayName("Should create context with valid taskId and runId")
+ void shouldCreateContextWithValidIds() {
+ // Act
+ RunExecutionContext context =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+
+ // Assert
+ assertNotNull(context);
+ assertEquals(TrainingTestConstants.TEST_TASK_ID, context.getTaskId());
+ assertEquals(TrainingTestConstants.TEST_RUN_ID, context.getRunId());
+ assertTrue(context.getStartTime() > 0);
+ }
+
+ @Test
+ @DisplayName("Should throw exception when taskId is null")
+ void shouldThrowExceptionWhenTaskIdIsNull() {
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> RunExecutionContext.create(null, TrainingTestConstants.TEST_RUN_ID));
+ }
+
+ @Test
+ @DisplayName("Should throw exception when taskId is empty")
+ void shouldThrowExceptionWhenTaskIdIsEmpty() {
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> RunExecutionContext.create("", TrainingTestConstants.TEST_RUN_ID));
+ }
+
+ @Test
+ @DisplayName("Should throw exception when runId is null")
+ void shouldThrowExceptionWhenRunIdIsNull() {
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> RunExecutionContext.create(TrainingTestConstants.TEST_TASK_ID, null));
+ }
+
+ @Test
+ @DisplayName("Should throw exception when runId is empty")
+ void shouldThrowExceptionWhenRunIdIsEmpty() {
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> RunExecutionContext.create(TrainingTestConstants.TEST_TASK_ID, ""));
+ }
+
+ @Test
+ @DisplayName("Should add msgId successfully")
+ void shouldAddMsgIdSuccessfully() {
+ // Arrange
+ RunExecutionContext context =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+
+ // Act
+ context.addMsgId(TrainingTestConstants.TEST_MSG_ID_1);
+ context.addMsgId(TrainingTestConstants.TEST_MSG_ID_2);
+
+ // Assert
+ assertEquals(2, context.getMsgIdCount());
+ assertTrue(context.getMsgIds().contains(TrainingTestConstants.TEST_MSG_ID_1));
+ assertTrue(context.getMsgIds().contains(TrainingTestConstants.TEST_MSG_ID_2));
+ }
+
+ @Test
+ @DisplayName("Should ignore null msgId")
+ void shouldIgnoreNullMsgId() {
+ // Arrange
+ RunExecutionContext context =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+
+ // Act
+ context.addMsgId(null);
+ context.addMsgId("");
+
+ // Assert
+ assertEquals(0, context.getMsgIdCount());
+ assertFalse(context.hasMsgIds());
+ }
+
+ @Test
+ @DisplayName("Should return copy of msgIds list")
+ void shouldReturnCopyOfMsgIdsList() {
+ // Arrange
+ RunExecutionContext context =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+ context.addMsgId(TrainingTestConstants.TEST_MSG_ID_1);
+
+ // Act
+ List msgIds = context.getMsgIds();
+ msgIds.add("should-not-affect-original");
+
+ // Assert
+ assertEquals(1, context.getMsgIdCount());
+ }
+
+ @Test
+ @DisplayName("Should add msgId thread-safely")
+ void shouldAddMsgIdThreadSafely() throws InterruptedException {
+ // Arrange
+ RunExecutionContext context =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+ int threadCount = 10;
+ int messagesPerThread = 100;
+ CountDownLatch latch = new CountDownLatch(threadCount);
+ ExecutorService executor = Executors.newFixedThreadPool(threadCount);
+
+ // Act
+ for (int i = 0; i < threadCount; i++) {
+ final int threadId = i;
+ executor.submit(
+ () -> {
+ for (int j = 0; j < messagesPerThread; j++) {
+ context.addMsgId("msg-" + threadId + "-" + j);
+ }
+ latch.countDown();
+ });
+ }
+ latch.await();
+ executor.shutdown();
+
+ // Assert
+ assertEquals(threadCount * messagesPerThread, context.getMsgIdCount());
+ }
+
+ @Test
+ @DisplayName("Should set and get messages")
+ void shouldSetAndGetMessages() {
+ // Arrange
+ RunExecutionContext context =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+ List messages = TrainingTestUtils.createTestMessages();
+
+ // Act
+ context.setMessages(messages);
+
+ // Assert
+ assertEquals(2, context.getMessageCount());
+ assertTrue(context.hasMessages());
+ }
+
+ @Test
+ @DisplayName("Should add single message")
+ void shouldAddSingleMessage() {
+ // Arrange
+ RunExecutionContext context =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+ Msg msg = TrainingTestUtils.createTestMessage("user", MsgRole.USER, "Hello");
+
+ // Act
+ context.addMsg(msg);
+
+ // Assert
+ assertEquals(1, context.getMessageCount());
+ assertTrue(context.hasMessages());
+ }
+
+ @Test
+ @DisplayName("Should ignore null message")
+ void shouldIgnoreNullMessage() {
+ // Arrange
+ RunExecutionContext context =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+
+ // Act
+ context.addMsg(null);
+
+ // Assert
+ assertEquals(0, context.getMessageCount());
+ assertFalse(context.hasMessages());
+ }
+
+ @Test
+ @DisplayName("Should return copy of messages list")
+ void shouldReturnCopyOfMessagesList() {
+ // Arrange
+ RunExecutionContext context =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+ context.setMessages(TrainingTestUtils.createTestMessages());
+
+ // Act
+ List messages = context.getMessages();
+ messages.clear();
+
+ // Assert
+ assertEquals(2, context.getMessageCount());
+ }
+
+ @Test
+ @DisplayName("Should clear messages when setMessages with null")
+ void shouldClearMessagesWhenSetMessagesWithNull() {
+ // Arrange
+ RunExecutionContext context =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+ context.setMessages(TrainingTestUtils.createTestMessages());
+
+ // Act
+ context.setMessages(null);
+
+ // Assert
+ assertEquals(0, context.getMessageCount());
+ assertFalse(context.hasMessages());
+ }
+
+ @Test
+ @DisplayName("Should calculate duration correctly")
+ void shouldCalculateDurationCorrectly() throws InterruptedException {
+ // Arrange
+ RunExecutionContext context =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+
+ // Act
+ Thread.sleep(100);
+ long duration = context.getDuration();
+
+ // Assert
+ assertTrue(duration >= 100, "Duration should be at least 100ms");
+ }
+
+ @Test
+ @DisplayName("Should check hasMsgIds correctly")
+ void shouldCheckHasMsgIdsCorrectly() {
+ // Arrange
+ RunExecutionContext context =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+
+ // Assert - initially false
+ assertFalse(context.hasMsgIds());
+
+ // Act
+ context.addMsgId(TrainingTestConstants.TEST_MSG_ID_1);
+
+ // Assert - now true
+ assertTrue(context.hasMsgIds());
+ }
+
+ @Test
+ @DisplayName("Should implement equals based on taskId and runId")
+ void shouldImplementEqualsBasedOnTaskIdAndRunId() {
+ // Arrange
+ RunExecutionContext context1 =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+ RunExecutionContext context2 =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+ RunExecutionContext context3 = RunExecutionContext.create("different-task", "1");
+
+ // Assert
+ assertEquals(context1, context2);
+ assertNotEquals(context1, context3);
+ }
+
+ @Test
+ @DisplayName("Should implement hashCode based on taskId and runId")
+ void shouldImplementHashCodeBasedOnTaskIdAndRunId() {
+ // Arrange
+ RunExecutionContext context1 =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+ RunExecutionContext context2 =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+
+ // Assert
+ assertEquals(context1.hashCode(), context2.hashCode());
+ }
+
+ @Test
+ @DisplayName("Should produce readable toString output")
+ void shouldProduceReadableToStringOutput() {
+ // Arrange
+ RunExecutionContext context =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+ context.addMsgId(TrainingTestConstants.TEST_MSG_ID_1);
+
+ // Act
+ String str = context.toString();
+
+ // Assert
+ assertTrue(str.contains(TrainingTestConstants.TEST_TASK_ID));
+ assertTrue(str.contains(TrainingTestConstants.TEST_RUN_ID));
+ assertTrue(str.contains("msgIds=1"));
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/RunRegistryTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/RunRegistryTest.java
new file mode 100644
index 000000000..3d3edb1be
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/RunRegistryTest.java
@@ -0,0 +1,210 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import io.agentscope.core.training.util.TrainingTestConstants;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+@DisplayName("RunRegistry Unit Tests")
+class RunRegistryTest {
+
+ @BeforeEach
+ void setUp() {
+ RunRegistry.clearAll();
+ }
+
+ @AfterEach
+ void tearDown() {
+ RunRegistry.clearAll();
+ }
+
+ @Test
+ @DisplayName("Should allocate run ID starting from 0")
+ void shouldAllocateRunIdStartingFromZero() {
+ // Act
+ String runId = RunRegistry.allocateRunId(TrainingTestConstants.TEST_TASK_ID);
+
+ // Assert
+ assertEquals("0", runId);
+ }
+
+ @Test
+ @DisplayName("Should allocate incrementing run IDs for same task")
+ void shouldAllocateIncrementingRunIdsForSameTask() {
+ // Act
+ String runId1 = RunRegistry.allocateRunId(TrainingTestConstants.TEST_TASK_ID);
+ String runId2 = RunRegistry.allocateRunId(TrainingTestConstants.TEST_TASK_ID);
+ String runId3 = RunRegistry.allocateRunId(TrainingTestConstants.TEST_TASK_ID);
+
+ // Assert
+ assertEquals("0", runId1);
+ assertEquals("1", runId2);
+ assertEquals("2", runId3);
+ }
+
+ @Test
+ @DisplayName("Should allocate separate run IDs for different tasks")
+ void shouldAllocateSeparateRunIdsForDifferentTasks() {
+ // Act
+ String runId1 = RunRegistry.allocateRunId("task-1");
+ String runId2 = RunRegistry.allocateRunId("task-2");
+ String runId3 = RunRegistry.allocateRunId("task-1");
+
+ // Assert
+ assertEquals("0", runId1);
+ assertEquals("0", runId2);
+ assertEquals("1", runId3);
+ }
+
+ @Test
+ @DisplayName("Should throw exception when taskId is null")
+ void shouldThrowExceptionWhenTaskIdIsNull() {
+ assertThrows(IllegalArgumentException.class, () -> RunRegistry.allocateRunId(null));
+ }
+
+ @Test
+ @DisplayName("Should get current run count")
+ void shouldGetCurrentRunCount() {
+ // Arrange
+ RunRegistry.allocateRunId(TrainingTestConstants.TEST_TASK_ID);
+ RunRegistry.allocateRunId(TrainingTestConstants.TEST_TASK_ID);
+ RunRegistry.allocateRunId(TrainingTestConstants.TEST_TASK_ID);
+
+ // Act
+ int count = RunRegistry.getCurrentRunCount(TrainingTestConstants.TEST_TASK_ID);
+
+ // Assert
+ assertEquals(3, count);
+ }
+
+ @Test
+ @DisplayName("Should return 0 for non-existent task")
+ void shouldReturnZeroForNonExistentTask() {
+ // Act
+ int count = RunRegistry.getCurrentRunCount("non-existent-task");
+
+ // Assert
+ assertEquals(0, count);
+ }
+
+ @Test
+ @DisplayName("Should return 0 for null taskId in getCurrentRunCount")
+ void shouldReturnZeroForNullTaskIdInGetCurrentRunCount() {
+ // Act
+ int count = RunRegistry.getCurrentRunCount(null);
+
+ // Assert
+ assertEquals(0, count);
+ }
+
+ @Test
+ @DisplayName("Should cleanup task")
+ void shouldCleanupTask() {
+ // Arrange
+ RunRegistry.allocateRunId(TrainingTestConstants.TEST_TASK_ID);
+ RunRegistry.allocateRunId(TrainingTestConstants.TEST_TASK_ID);
+
+ // Act
+ RunRegistry.cleanup(TrainingTestConstants.TEST_TASK_ID);
+
+ // Assert
+ assertEquals(0, RunRegistry.getCurrentRunCount(TrainingTestConstants.TEST_TASK_ID));
+ }
+
+ @Test
+ @DisplayName("Should handle cleanup for null taskId")
+ void shouldHandleCleanupForNullTaskId() {
+ // Act - should not throw
+ RunRegistry.cleanup(null);
+
+ // Assert
+ assertEquals(0, RunRegistry.size());
+ }
+
+ @Test
+ @DisplayName("Should clear all tasks")
+ void shouldClearAllTasks() {
+ // Arrange
+ RunRegistry.allocateRunId("task-1");
+ RunRegistry.allocateRunId("task-2");
+ RunRegistry.allocateRunId("task-3");
+
+ // Act
+ RunRegistry.clearAll();
+
+ // Assert
+ assertEquals(0, RunRegistry.size());
+ }
+
+ @Test
+ @DisplayName("Should return correct size")
+ void shouldReturnCorrectSize() {
+ // Arrange
+ RunRegistry.allocateRunId("task-1");
+ RunRegistry.allocateRunId("task-2");
+ RunRegistry.allocateRunId("task-3");
+
+ // Act
+ int size = RunRegistry.size();
+
+ // Assert
+ assertEquals(3, size);
+ }
+
+ @Test
+ @DisplayName("Should allocate run IDs thread-safely")
+ void shouldAllocateRunIdsThreadSafely() throws InterruptedException {
+ // Arrange
+ int threadCount = 10;
+ int allocationsPerThread = 100;
+ CountDownLatch latch = new CountDownLatch(threadCount);
+ ExecutorService executor = Executors.newFixedThreadPool(threadCount);
+ Set allRunIds = java.util.Collections.synchronizedSet(new HashSet<>());
+
+ // Act
+ for (int i = 0; i < threadCount; i++) {
+ executor.submit(
+ () -> {
+ for (int j = 0; j < allocationsPerThread; j++) {
+ String runId =
+ RunRegistry.allocateRunId(TrainingTestConstants.TEST_TASK_ID);
+ allRunIds.add(runId);
+ }
+ latch.countDown();
+ });
+ }
+ latch.await();
+ executor.shutdown();
+
+ // Assert - all run IDs should be unique
+ assertEquals(threadCount * allocationsPerThread, allRunIds.size());
+ assertEquals(
+ threadCount * allocationsPerThread,
+ RunRegistry.getCurrentRunCount(TrainingTestConstants.TEST_TASK_ID));
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TaskExecutionRegistryTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TaskExecutionRegistryTest.java
new file mode 100644
index 000000000..03c96238e
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TaskExecutionRegistryTest.java
@@ -0,0 +1,317 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import io.agentscope.core.training.util.TrainingTestConstants;
+import java.util.List;
+import java.util.Set;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+@DisplayName("TaskExecutionRegistry Unit Tests")
+class TaskExecutionRegistryTest {
+
+ @BeforeEach
+ void setUp() {
+ TaskExecutionRegistry.clearAll();
+ }
+
+ @AfterEach
+ void tearDown() {
+ TaskExecutionRegistry.clearAll();
+ }
+
+ @Test
+ @DisplayName("Should register execution context")
+ void shouldRegisterExecutionContext() {
+ // Arrange
+ RunExecutionContext context =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+
+ // Act
+ TaskExecutionRegistry.register(context);
+
+ // Assert
+ assertEquals(1, TaskExecutionRegistry.getTaskCount());
+ assertEquals(1, TaskExecutionRegistry.getTotalRunCount());
+ }
+
+ @Test
+ @DisplayName("Should ignore null context")
+ void shouldIgnoreNullContext() {
+ // Act
+ TaskExecutionRegistry.register(null);
+
+ // Assert
+ assertEquals(0, TaskExecutionRegistry.getTaskCount());
+ }
+
+ @Test
+ @DisplayName("Should get runs by task")
+ void shouldGetRunsByTask() {
+ // Arrange
+ RunExecutionContext ctx0 =
+ RunExecutionContext.create(TrainingTestConstants.TEST_TASK_ID, "0");
+ RunExecutionContext ctx1 =
+ RunExecutionContext.create(TrainingTestConstants.TEST_TASK_ID, "1");
+ RunExecutionContext ctx2 =
+ RunExecutionContext.create(TrainingTestConstants.TEST_TASK_ID, "2");
+
+ TaskExecutionRegistry.register(ctx2);
+ TaskExecutionRegistry.register(ctx0);
+ TaskExecutionRegistry.register(ctx1);
+
+ // Act
+ List runs =
+ TaskExecutionRegistry.getRunsByTask(TrainingTestConstants.TEST_TASK_ID);
+
+ // Assert
+ assertEquals(3, runs.size());
+ // Should be sorted by runId
+ assertEquals("0", runs.get(0).getRunId());
+ assertEquals("1", runs.get(1).getRunId());
+ assertEquals("2", runs.get(2).getRunId());
+ }
+
+ @Test
+ @DisplayName("Should return empty list for non-existent task")
+ void shouldReturnEmptyListForNonExistentTask() {
+ // Act
+ List runs = TaskExecutionRegistry.getRunsByTask("non-existent");
+
+ // Assert
+ assertTrue(runs.isEmpty());
+ }
+
+ @Test
+ @DisplayName("Should return empty list for null taskId")
+ void shouldReturnEmptyListForNullTaskId() {
+ // Act
+ List runs = TaskExecutionRegistry.getRunsByTask(null);
+
+ // Assert
+ assertTrue(runs.isEmpty());
+ }
+
+ @Test
+ @DisplayName("Should get specific run")
+ void shouldGetSpecificRun() {
+ // Arrange
+ RunExecutionContext ctx =
+ RunExecutionContext.create(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+ ctx.addMsgId(TrainingTestConstants.TEST_MSG_ID_1);
+ TaskExecutionRegistry.register(ctx);
+
+ // Act
+ RunExecutionContext result =
+ TaskExecutionRegistry.getRun(
+ TrainingTestConstants.TEST_TASK_ID, TrainingTestConstants.TEST_RUN_ID);
+
+ // Assert
+ assertNotNull(result);
+ assertEquals(TrainingTestConstants.TEST_TASK_ID, result.getTaskId());
+ assertEquals(TrainingTestConstants.TEST_RUN_ID, result.getRunId());
+ }
+
+ @Test
+ @DisplayName("Should return null for non-existent run")
+ void shouldReturnNullForNonExistentRun() {
+ // Act
+ RunExecutionContext result = TaskExecutionRegistry.getRun("task", "999");
+
+ // Assert
+ assertNull(result);
+ }
+
+ @Test
+ @DisplayName("Should return null when taskId or runId is null")
+ void shouldReturnNullWhenTaskIdOrRunIdIsNull() {
+ // Assert
+ assertNull(TaskExecutionRegistry.getRun(null, "0"));
+ assertNull(TaskExecutionRegistry.getRun("task", null));
+ }
+
+ @Test
+ @DisplayName("Should get all task IDs")
+ void shouldGetAllTaskIds() {
+ // Arrange
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-1", "0"));
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-2", "0"));
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-3", "0"));
+
+ // Act
+ Set taskIds = TaskExecutionRegistry.getAllTaskIds();
+
+ // Assert
+ assertEquals(3, taskIds.size());
+ assertTrue(taskIds.contains("task-1"));
+ assertTrue(taskIds.contains("task-2"));
+ assertTrue(taskIds.contains("task-3"));
+ }
+
+ @Test
+ @DisplayName("Should get run count for task")
+ void shouldGetRunCountForTask() {
+ // Arrange
+ TaskExecutionRegistry.register(
+ RunExecutionContext.create(TrainingTestConstants.TEST_TASK_ID, "0"));
+ TaskExecutionRegistry.register(
+ RunExecutionContext.create(TrainingTestConstants.TEST_TASK_ID, "1"));
+ TaskExecutionRegistry.register(
+ RunExecutionContext.create(TrainingTestConstants.TEST_TASK_ID, "2"));
+
+ // Act
+ int count = TaskExecutionRegistry.getRunCount(TrainingTestConstants.TEST_TASK_ID);
+
+ // Assert
+ assertEquals(3, count);
+ }
+
+ @Test
+ @DisplayName("Should return 0 for non-existent task run count")
+ void shouldReturnZeroForNonExistentTaskRunCount() {
+ // Act
+ int count = TaskExecutionRegistry.getRunCount("non-existent");
+
+ // Assert
+ assertEquals(0, count);
+ }
+
+ @Test
+ @DisplayName("Should return 0 for null taskId run count")
+ void shouldReturnZeroForNullTaskIdRunCount() {
+ // Act
+ int count = TaskExecutionRegistry.getRunCount(null);
+
+ // Assert
+ assertEquals(0, count);
+ }
+
+ @Test
+ @DisplayName("Should get total run count")
+ void shouldGetTotalRunCount() {
+ // Arrange
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-1", "0"));
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-1", "1"));
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-2", "0"));
+
+ // Act
+ int total = TaskExecutionRegistry.getTotalRunCount();
+
+ // Assert
+ assertEquals(3, total);
+ }
+
+ @Test
+ @DisplayName("Should cleanup task")
+ void shouldCleanupTask() {
+ // Arrange
+ TaskExecutionRegistry.register(
+ RunExecutionContext.create(TrainingTestConstants.TEST_TASK_ID, "0"));
+ TaskExecutionRegistry.register(
+ RunExecutionContext.create(TrainingTestConstants.TEST_TASK_ID, "1"));
+ TaskExecutionRegistry.register(RunExecutionContext.create("other-task", "0"));
+
+ // Act
+ int cleaned = TaskExecutionRegistry.cleanup(TrainingTestConstants.TEST_TASK_ID);
+
+ // Assert
+ assertEquals(2, cleaned);
+ assertEquals(1, TaskExecutionRegistry.getTaskCount());
+ assertEquals(0, TaskExecutionRegistry.getRunCount(TrainingTestConstants.TEST_TASK_ID));
+ }
+
+ @Test
+ @DisplayName("Should return 0 when cleanup null taskId")
+ void shouldReturnZeroWhenCleanupNullTaskId() {
+ // Act
+ int cleaned = TaskExecutionRegistry.cleanup(null);
+
+ // Assert
+ assertEquals(0, cleaned);
+ }
+
+ @Test
+ @DisplayName("Should clear all data")
+ void shouldClearAllData() {
+ // Arrange
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-1", "0"));
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-2", "0"));
+
+ // Act
+ TaskExecutionRegistry.clearAll();
+
+ // Assert
+ assertEquals(0, TaskExecutionRegistry.getTaskCount());
+ assertEquals(0, TaskExecutionRegistry.getTotalRunCount());
+ }
+
+ @Test
+ @DisplayName("Should get stats")
+ void shouldGetStats() {
+ // Arrange
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-1", "0"));
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-1", "1"));
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-2", "0"));
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-2", "1"));
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-2", "2"));
+
+ // Act
+ TaskExecutionRegistry.RegistryStats stats = TaskExecutionRegistry.getStats();
+
+ // Assert
+ assertEquals(2, stats.getTaskCount());
+ assertEquals(5, stats.getTotalRunCount());
+ assertEquals(2.5, stats.getAverageRunsPerTask(), 0.001);
+ }
+
+ @Test
+ @DisplayName("Should return 0 average when no tasks")
+ void shouldReturnZeroAverageWhenNoTasks() {
+ // Act
+ TaskExecutionRegistry.RegistryStats stats = TaskExecutionRegistry.getStats();
+
+ // Assert
+ assertEquals(0, stats.getTaskCount());
+ assertEquals(0, stats.getTotalRunCount());
+ assertEquals(0.0, stats.getAverageRunsPerTask(), 0.001);
+ }
+
+ @Test
+ @DisplayName("Should produce readable stats toString")
+ void shouldProduceReadableStatsToString() {
+ // Arrange
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-1", "0"));
+ TaskExecutionRegistry.register(RunExecutionContext.create("task-1", "1"));
+
+ // Act
+ String str = TaskExecutionRegistry.getStats().toString();
+
+ // Assert
+ assertTrue(str.contains("tasks=1"));
+ assertTrue(str.contains("runs=2"));
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TaskIdGeneratorTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TaskIdGeneratorTest.java
new file mode 100644
index 000000000..558562ecc
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TaskIdGeneratorTest.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+@DisplayName("TaskIdGenerator Unit Tests")
+class TaskIdGeneratorTest {
+
+ @BeforeEach
+ void setUp() {
+ // Clean up any state before each test
+ }
+
+ @AfterEach
+ void tearDown() {
+ // Clean up after each test
+ }
+
+ @Test
+ @DisplayName("Should generate task ID with correct prefix")
+ void shouldGenerateTaskIdWithCorrectPrefix() {
+ // Act
+ String taskId = TaskIdGenerator.generate();
+
+ // Assert
+ assertNotNull(taskId);
+ assertTrue(taskId.startsWith("task-"));
+ }
+
+ @Test
+ @DisplayName("Should generate unique task IDs")
+ void shouldGenerateUniqueTaskIds() {
+ // Act
+ String taskId1 = TaskIdGenerator.generate();
+ String taskId2 = TaskIdGenerator.generate();
+ String taskId3 = TaskIdGenerator.generate();
+
+ // Assert
+ assertNotEquals(taskId1, taskId2);
+ assertNotEquals(taskId2, taskId3);
+ assertNotEquals(taskId1, taskId3);
+ }
+
+ @Test
+ @DisplayName("Should generate task ID with timestamp")
+ void shouldGenerateTaskIdWithTimestamp() {
+ // Act
+ String taskId = TaskIdGenerator.generateWithTimestamp();
+
+ // Assert
+ assertNotNull(taskId);
+ assertTrue(taskId.startsWith("task-"));
+ // Should contain a timestamp (digits) after "task-"
+ String[] parts = taskId.split("-");
+ assertEquals(3, parts.length);
+ assertTrue(parts[1].matches("\\d+"), "Second part should be timestamp digits");
+ }
+
+ @Test
+ @DisplayName("Should generate unique task IDs with timestamp")
+ void shouldGenerateUniqueTaskIdsWithTimestamp() throws InterruptedException {
+ // Act
+ String taskId1 = TaskIdGenerator.generateWithTimestamp();
+ Thread.sleep(10); // Small delay to ensure different timestamps
+ String taskId2 = TaskIdGenerator.generateWithTimestamp();
+
+ // Assert
+ assertNotEquals(taskId1, taskId2);
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TrainingConfigTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TrainingConfigTest.java
new file mode 100644
index 000000000..80e20afa1
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TrainingConfigTest.java
@@ -0,0 +1,389 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+
+import io.agentscope.core.training.reward.RewardCalculator;
+import io.agentscope.core.training.strategy.ExplicitMarkingStrategy;
+import io.agentscope.core.training.strategy.SamplingRateStrategy;
+import io.agentscope.core.training.util.TrainingTestConstants;
+import java.time.Duration;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+@DisplayName("TrainingConfig Unit Tests")
+class TrainingConfigTest {
+
+ @Test
+ @DisplayName("Should build config with all required fields")
+ void shouldBuildConfigWithAllRequiredFields() {
+ // Arrange
+ RewardCalculator calculator = mock(RewardCalculator.class);
+
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(calculator)
+ .build();
+
+ // Assert
+ assertNotNull(config);
+ assertEquals(TrainingTestConstants.TEST_TRINITY_ENDPOINT, config.getTrinityEndpoint());
+ assertNotNull(config.getRewardCalculator());
+ }
+
+ @Test
+ @DisplayName("Should throw exception when trinityEndpoint is null")
+ void shouldThrowExceptionWhenTrinityEndpointIsNull() {
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ TrainingConfig.builder()
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build());
+ }
+
+ @Test
+ @DisplayName("Should throw exception when trinityEndpoint is empty")
+ void shouldThrowExceptionWhenTrinityEndpointIsEmpty() {
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ TrainingConfig.builder()
+ .trinityEndpoint("")
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build());
+ }
+
+ @Test
+ @DisplayName("Should throw exception when rewardCalculator is null")
+ void shouldThrowExceptionWhenRewardCalculatorIsNull() {
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .build());
+ }
+
+ @Test
+ @DisplayName("Should use default selectionStrategy when not specified")
+ void shouldUseDefaultSelectionStrategyWhenNotSpecified() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ // Assert
+ assertNotNull(config.getSelectionStrategy());
+ assertTrue(config.getSelectionStrategy() instanceof SamplingRateStrategy);
+ assertEquals(0.1, config.getSampleRate(), 0.001);
+ }
+
+ @Test
+ @DisplayName("Should use custom selectionStrategy when specified")
+ void shouldUseCustomSelectionStrategyWhenSpecified() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .selectionStrategy(SamplingRateStrategy.of(0.5))
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ // Assert
+ assertEquals(0.5, config.getSampleRate(), 0.001);
+ }
+
+ @Test
+ @DisplayName("Should use ExplicitMarkingStrategy")
+ void shouldUseExplicitMarkingStrategy() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .selectionStrategy(ExplicitMarkingStrategy.create())
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ // Assert
+ assertTrue(config.getSelectionStrategy() instanceof ExplicitMarkingStrategy);
+ assertEquals(-1, config.getSampleRate(), 0.001);
+ }
+
+ @Test
+ @DisplayName("Should use default commitIntervalSeconds")
+ void shouldUseDefaultCommitIntervalSeconds() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ // Assert
+ assertEquals(300, config.getCommitIntervalSeconds());
+ }
+
+ @Test
+ @DisplayName("Should use custom commitIntervalSeconds")
+ void shouldUseCustomCommitIntervalSeconds() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .commitIntervalSeconds(60)
+ .build();
+
+ // Assert
+ assertEquals(60, config.getCommitIntervalSeconds());
+ }
+
+ @Test
+ @DisplayName("Should use default httpTimeout")
+ void shouldUseDefaultHttpTimeout() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ // Assert
+ assertEquals(Duration.ofSeconds(300), config.getHttpTimeout());
+ }
+
+ @Test
+ @DisplayName("Should use custom httpTimeout")
+ void shouldUseCustomHttpTimeout() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .httpTimeout(Duration.ofSeconds(60))
+ .build();
+
+ // Assert
+ assertEquals(Duration.ofSeconds(60), config.getHttpTimeout());
+ }
+
+ @Test
+ @DisplayName("Should use default repeatTime")
+ void shouldUseDefaultRepeatTime() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ // Assert
+ assertEquals(1, config.getRepeatTime());
+ }
+
+ @Test
+ @DisplayName("Should use custom repeatTime")
+ void shouldUseCustomRepeatTime() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .repeatTime(3)
+ .build();
+
+ // Assert
+ assertEquals(3, config.getRepeatTime());
+ }
+
+ @Test
+ @DisplayName("Should throw exception when repeatTime is less than 1")
+ void shouldThrowExceptionWhenRepeatTimeLessThanOne() {
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .repeatTime(0)
+ .build());
+ }
+
+ @Test
+ @DisplayName("Should use default shadowPoolSize")
+ void shouldUseDefaultShadowPoolSize() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ // Assert
+ assertEquals(10, config.getShadowPoolSize());
+ }
+
+ @Test
+ @DisplayName("Should use custom shadowPoolSize")
+ void shouldUseCustomShadowPoolSize() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .shadowPoolSize(20)
+ .build();
+
+ // Assert
+ assertEquals(20, config.getShadowPoolSize());
+ }
+
+ @Test
+ @DisplayName("Should use default shadowPoolCapacity")
+ void shouldUseDefaultShadowPoolCapacity() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ // Assert
+ assertEquals(1000, config.getShadowPoolCapacity());
+ }
+
+ @Test
+ @DisplayName("Should use default trinityApiKey")
+ void shouldUseDefaultTrinityApiKey() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ // Assert
+ assertEquals("dummy", config.getTrinityApiKey());
+ }
+
+ @Test
+ @DisplayName("Should use custom trinityApiKey")
+ void shouldUseCustomTrinityApiKey() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .trinityApiKey("custom-api-key")
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ // Assert
+ assertEquals("custom-api-key", config.getTrinityApiKey());
+ }
+
+ @Test
+ @DisplayName("Should use default modelName")
+ void shouldUseDefaultModelName() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ // Assert
+ assertEquals("training-model", config.getModelName());
+ }
+
+ @Test
+ @DisplayName("Should use custom modelName")
+ void shouldUseCustomModelName() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .modelName("custom-model")
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ // Assert
+ assertEquals("custom-model", config.getModelName());
+ }
+
+ @Test
+ @DisplayName("Should use default enableAutoCommit")
+ void shouldUseDefaultEnableAutoCommit() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ // Assert
+ assertTrue(config.isEnableAutoCommit());
+ }
+
+ @Test
+ @DisplayName("Should use custom enableAutoCommit")
+ void shouldUseCustomEnableAutoCommit() {
+ // Act
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .enableAutoCommit(false)
+ .build();
+
+ // Assert
+ assertFalse(config.isEnableAutoCommit());
+ }
+
+ @Test
+ @DisplayName("Should throw exception when sampleRate is invalid using deprecated method")
+ @SuppressWarnings("deprecation")
+ void shouldThrowExceptionWhenSampleRateIsInvalid() {
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .sampleRate(-0.1)
+ .build());
+
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .rewardCalculator(mock(RewardCalculator.class))
+ .sampleRate(1.1)
+ .build());
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TrainingRouterTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TrainingRouterTest.java
new file mode 100644
index 000000000..c18dff4fe
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TrainingRouterTest.java
@@ -0,0 +1,173 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyList;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import io.agentscope.core.agent.Agent;
+import io.agentscope.core.hook.ErrorEvent;
+import io.agentscope.core.hook.PostCallEvent;
+import io.agentscope.core.hook.PreCallEvent;
+import io.agentscope.core.message.Msg;
+import io.agentscope.core.message.MsgRole;
+import io.agentscope.core.training.backend.TrinityClient;
+import io.agentscope.core.training.reward.RewardCalculator;
+import io.agentscope.core.training.strategy.SelectionDecision;
+import io.agentscope.core.training.strategy.TrainingSelectionStrategy;
+import java.util.Collections;
+import java.util.List;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import reactor.test.StepVerifier;
+import reactor.util.context.Context;
+
+/**
+ * Tests for TrainingRouter.
+ */
+@ExtendWith(MockitoExtension.class)
+@DisplayName("TrainingRouter Tests")
+class TrainingRouterTest {
+
+ @Mock private TrinityClient trinityClient;
+
+ @Mock private RewardCalculator rewardCalculator;
+
+ @Mock private TrainingSelectionStrategy selectionStrategy;
+
+ @Mock private Agent mockAgent;
+
+ private TrainingConfig config;
+ private TrainingRouter router;
+
+ @BeforeEach
+ void setUp() {
+ config =
+ TrainingConfig.builder()
+ .trinityEndpoint("http://localhost:8080")
+ .modelName("test-model")
+ .selectionStrategy(selectionStrategy)
+ .rewardCalculator(rewardCalculator)
+ .shadowPoolSize(2)
+ .shadowPoolCapacity(10)
+ .build();
+
+ router = new TrainingRouter(config, trinityClient, rewardCalculator, selectionStrategy);
+ }
+
+ @Test
+ @DisplayName("Should return correct priority")
+ void shouldReturnCorrectPriority() {
+ assertEquals(500, router.priority());
+ }
+
+ @Test
+ @DisplayName("Should handle PreCallEvent and save input messages")
+ void shouldHandlePreCallEventAndSaveInputMessages() {
+ when(mockAgent.getAgentId()).thenReturn("agent-123");
+
+ List inputMessages =
+ Collections.singletonList(
+ Msg.builder().role(MsgRole.USER).textContent("Hello").build());
+
+ PreCallEvent event = new PreCallEvent(mockAgent, inputMessages);
+
+ StepVerifier.create(router.onEvent(event)).expectNext(event).verifyComplete();
+ }
+
+ @Test
+ @DisplayName("Should skip training for shadow agent")
+ void shouldSkipTrainingForShadowAgent() {
+ when(mockAgent.getAgentId()).thenReturn("agent-shadow-123");
+ when(mockAgent.getName()).thenReturn("TestAgent-shadow");
+
+ Msg outputMsg = Msg.builder().role(MsgRole.ASSISTANT).textContent("Response").build();
+ PostCallEvent event = new PostCallEvent(mockAgent, outputMsg);
+
+ StepVerifier.create(router.onEvent(event)).expectNext(event).verifyComplete();
+
+ verify(selectionStrategy, never()).shouldSelect(any(), anyList(), any(), any());
+ }
+
+ @Test
+ @DisplayName("Should skip training when selection decision is false")
+ void shouldSkipTrainingWhenSelectionDecisionIsFalse() {
+ when(mockAgent.getAgentId()).thenReturn("agent-123");
+ when(mockAgent.getName()).thenReturn("TestAgent");
+
+ List inputMessages =
+ Collections.singletonList(
+ Msg.builder().role(MsgRole.USER).textContent("Hello").build());
+ PreCallEvent preEvent = new PreCallEvent(mockAgent, inputMessages);
+
+ router.onEvent(preEvent).block();
+
+ Msg outputMsg = Msg.builder().role(MsgRole.ASSISTANT).textContent("Response").build();
+ PostCallEvent postEvent = new PostCallEvent(mockAgent, outputMsg);
+
+ SelectionDecision rejectDecision = SelectionDecision.reject("Not selected");
+ when(selectionStrategy.shouldSelect(any(), anyList(), any(), any()))
+ .thenReturn(rejectDecision);
+
+ StepVerifier.create(router.onEvent(postEvent).contextWrite(Context.empty()))
+ .expectNext(postEvent)
+ .verifyComplete();
+
+ verify(trinityClient, never()).feedback(any());
+ }
+
+ @Test
+ @DisplayName("Should return empty when no input messages found")
+ void shouldReturnEmptyWhenNoInputMessagesFound() {
+ when(mockAgent.getAgentId()).thenReturn("agent-no-input");
+ when(mockAgent.getName()).thenReturn("TestAgent");
+
+ Msg outputMsg = Msg.builder().role(MsgRole.ASSISTANT).textContent("Response").build();
+ PostCallEvent event = new PostCallEvent(mockAgent, outputMsg);
+
+ StepVerifier.create(router.onEvent(event).contextWrite(Context.empty()))
+ .expectNext(event)
+ .verifyComplete();
+ }
+
+ @Test
+ @DisplayName("Should pass through other event types like ErrorEvent")
+ void shouldPassThroughOtherEventTypes() {
+ ErrorEvent errorEvent = new ErrorEvent(mockAgent, new RuntimeException("Test error"));
+
+ StepVerifier.create(router.onEvent(errorEvent)).expectNext(errorEvent).verifyComplete();
+ }
+
+ @Test
+ @DisplayName("Should create router with valid config")
+ void shouldCreateRouterWithValidConfig() {
+ TrainingRouter newRouter =
+ new TrainingRouter(config, trinityClient, rewardCalculator, selectionStrategy);
+
+ assertNotNull(newRouter);
+ assertEquals(500, newRouter.priority());
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TrainingRunnerTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TrainingRunnerTest.java
new file mode 100644
index 000000000..51a05baeb
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/runner/TrainingRunnerTest.java
@@ -0,0 +1,202 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.runner;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+
+import io.agentscope.core.training.reward.RewardCalculator;
+import io.agentscope.core.training.strategy.SamplingRateStrategy;
+import java.io.IOException;
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Tests for TrainingRunner.
+ */
+@DisplayName("TrainingRunner Tests")
+class TrainingRunnerTest {
+
+ private MockWebServer mockServer;
+ private String mockEndpoint;
+
+ @BeforeEach
+ void setUp() throws IOException {
+ mockServer = new MockWebServer();
+ mockServer.start();
+ mockEndpoint = mockServer.url("/").toString().replaceAll("/$", "");
+ }
+
+ @AfterEach
+ void tearDown() throws IOException {
+ mockServer.shutdown();
+ }
+
+ @Test
+ @DisplayName("Should build TrainingRunner with config")
+ void shouldBuildTrainingRunnerWithConfig() {
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(mockEndpoint)
+ .modelName("test-model")
+ .selectionStrategy(SamplingRateStrategy.of(0.1))
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ TrainingRunner runner = TrainingRunner.builder().config(config).build();
+
+ assertNotNull(runner);
+ assertFalse(runner.isRunning());
+ assertEquals(config, runner.getConfig());
+ }
+
+ @Test
+ @DisplayName("Should build TrainingRunner with builder methods")
+ void shouldBuildTrainingRunnerWithBuilderMethods() {
+ TrainingRunner runner =
+ TrainingRunner.builder()
+ .trinityEndpoint(mockEndpoint)
+ .modelName("test-model")
+ .selectionStrategy(SamplingRateStrategy.of(0.1))
+ .rewardCalculator(mock(RewardCalculator.class))
+ .commitIntervalSeconds(300)
+ .repeatTime(3)
+ .build();
+
+ assertNotNull(runner);
+ assertFalse(runner.isRunning());
+ }
+
+ @Test
+ @DisplayName("Should throw when config is set and builder methods are used")
+ void shouldThrowWhenConfigSetAndBuilderMethodsUsed() {
+ TrainingConfig config =
+ TrainingConfig.builder()
+ .trinityEndpoint(mockEndpoint)
+ .modelName("test-model")
+ .selectionStrategy(SamplingRateStrategy.of(0.1))
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ assertThrows(
+ IllegalStateException.class,
+ () ->
+ TrainingRunner.builder()
+ .config(config)
+ .trinityEndpoint("http://other-endpoint")
+ .build());
+ }
+
+ @Test
+ @DisplayName("Should throw when no config provided")
+ void shouldThrowWhenNoConfigProvided() {
+ assertThrows(IllegalStateException.class, () -> TrainingRunner.builder().build());
+ }
+
+ @Test
+ @DisplayName("Should start and stop runner")
+ void shouldStartAndStopRunner() throws InterruptedException {
+ mockServer.enqueue(
+ new MockResponse()
+ .setBody("{\"status\": \"success\", \"message\": \"OK\"}")
+ .addHeader("Content-Type", "application/json"));
+
+ TrainingRunner runner =
+ TrainingRunner.builder()
+ .trinityEndpoint(mockEndpoint)
+ .modelName("test-model")
+ .selectionStrategy(SamplingRateStrategy.of(0.1))
+ .rewardCalculator(mock(RewardCalculator.class))
+ .commitIntervalSeconds(0)
+ .build();
+
+ assertFalse(runner.isRunning());
+
+ runner.start();
+ assertTrue(runner.isRunning());
+
+ runner.start();
+ assertTrue(runner.isRunning());
+
+ runner.stop();
+ assertFalse(runner.isRunning());
+
+ runner.stop();
+ assertFalse(runner.isRunning());
+ }
+
+ @Test
+ @DisplayName("Should execute commit")
+ void shouldExecuteCommit() throws Exception {
+ mockServer.enqueue(
+ new MockResponse()
+ .setBody("{\"status\": \"success\", \"message\": \"OK\"}")
+ .addHeader("Content-Type", "application/json"));
+
+ TrainingRunner runner =
+ TrainingRunner.builder()
+ .trinityEndpoint(mockEndpoint)
+ .modelName("test-model")
+ .selectionStrategy(SamplingRateStrategy.of(0.1))
+ .rewardCalculator(mock(RewardCalculator.class))
+ .build();
+
+ runner.commit().block();
+
+ assertEquals(1, mockServer.getRequestCount());
+ }
+
+ @Test
+ @DisplayName("Should schedule periodic commits when interval > 0")
+ void shouldSchedulePeriodicCommits() throws Exception {
+ mockServer.enqueue(
+ new MockResponse()
+ .setBody("{\"status\": \"success\", \"message\": \"OK\"}")
+ .addHeader("Content-Type", "application/json"));
+
+ mockServer.enqueue(
+ new MockResponse()
+ .setBody("{\"status\": \"success\", \"message\": \"OK\"}")
+ .addHeader("Content-Type", "application/json"));
+
+ TrainingRunner runner =
+ TrainingRunner.builder()
+ .trinityEndpoint(mockEndpoint)
+ .modelName("test-model")
+ .selectionStrategy(SamplingRateStrategy.of(0.1))
+ .rewardCalculator(mock(RewardCalculator.class))
+ .commitIntervalSeconds(1)
+ .build();
+
+ runner.start();
+ assertTrue(runner.isRunning());
+
+ Thread.sleep(1500);
+
+ runner.stop();
+
+ assertTrue(mockServer.getRequestCount() >= 1);
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/ExplicitMarkingStrategyTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/ExplicitMarkingStrategyTest.java
new file mode 100644
index 000000000..5921e5258
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/ExplicitMarkingStrategyTest.java
@@ -0,0 +1,218 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.strategy;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import io.agentscope.core.agent.Agent;
+import io.agentscope.core.message.Msg;
+import io.agentscope.core.message.MsgRole;
+import io.agentscope.core.training.util.TrainingTestConstants;
+import io.agentscope.core.training.util.TrainingTestUtils;
+import java.time.Duration;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+import reactor.util.context.Context;
+import reactor.util.context.ContextView;
+
+@DisplayName("ExplicitMarkingStrategy Unit Tests")
+class ExplicitMarkingStrategyTest {
+
+ private Agent mockAgent;
+ private List testInputs;
+ private Msg testOutput;
+
+ @BeforeEach
+ void setUp() {
+ mockAgent = TrainingTestUtils.createMockAgent(TrainingTestConstants.TEST_AGENT_NAME);
+ testInputs = TrainingTestUtils.createTestMessages();
+ testOutput =
+ TrainingTestUtils.createTestMessage("assistant", MsgRole.ASSISTANT, "Response");
+ }
+
+ @Test
+ @DisplayName("Should create strategy with default TTL")
+ void shouldCreateStrategyWithDefaultTTL() {
+ // Act
+ ExplicitMarkingStrategy strategy = ExplicitMarkingStrategy.create();
+
+ // Assert
+ assertNotNull(strategy);
+ assertEquals("ExplicitMarking", strategy.name());
+ }
+
+ @Test
+ @DisplayName("Should create strategy with custom TTL")
+ void shouldCreateStrategyWithCustomTTL() {
+ // Act
+ ExplicitMarkingStrategy strategy = ExplicitMarkingStrategy.withTTL(Duration.ofMinutes(5));
+
+ // Assert
+ assertNotNull(strategy);
+ }
+
+ @Test
+ @DisplayName("Should reject when no annotation in context")
+ void shouldRejectWhenNoAnnotationInContext() {
+ // Arrange
+ ExplicitMarkingStrategy strategy = ExplicitMarkingStrategy.create();
+
+ // Act
+ SelectionDecision decision =
+ strategy.shouldSelect(mockAgent, testInputs, testOutput, Context.empty());
+
+ // Assert
+ assertFalse(decision.shouldTrain());
+ assertEquals("no-explicit-marking", decision.getReason());
+ }
+
+ @Test
+ @DisplayName("Should accept when annotation is enabled")
+ void shouldAcceptWhenAnnotationIsEnabled() {
+ // Arrange
+ ExplicitMarkingStrategy strategy = ExplicitMarkingStrategy.create();
+ TrainingAnnotation annotation = TrainingAnnotation.enabled();
+ ContextView context = Context.of(TrainingContext.REACTOR_KEY, annotation);
+
+ // Act
+ SelectionDecision decision =
+ strategy.shouldSelect(mockAgent, testInputs, testOutput, context);
+
+ // Assert
+ assertTrue(decision.shouldTrain());
+ assertEquals("explicit-marking", decision.getReason());
+ }
+
+ @Test
+ @DisplayName("Should accept with labels from annotation")
+ void shouldAcceptWithLabelsFromAnnotation() {
+ // Arrange
+ ExplicitMarkingStrategy strategy = ExplicitMarkingStrategy.create();
+ TrainingAnnotation annotation = TrainingAnnotation.withLabels("high-quality", "production");
+ ContextView context = Context.of(TrainingContext.REACTOR_KEY, annotation);
+
+ // Act
+ SelectionDecision decision =
+ strategy.shouldSelect(mockAgent, testInputs, testOutput, context);
+
+ // Assert
+ assertTrue(decision.shouldTrain());
+ assertTrue(decision.getLabels().contains("high-quality"));
+ assertTrue(decision.getLabels().contains("production"));
+ }
+
+ @Test
+ @DisplayName("Should pass through taskId from annotation metadata")
+ void shouldPassThroughTaskIdFromAnnotationMetadata() {
+ // Arrange
+ ExplicitMarkingStrategy strategy = ExplicitMarkingStrategy.create();
+ TrainingAnnotation annotation =
+ TrainingAnnotation.builder()
+ .enabled(true)
+ .taskId("custom-task-id")
+ .labels(Arrays.asList("test"))
+ .build();
+ ContextView context = Context.of(TrainingContext.REACTOR_KEY, annotation);
+
+ // Act
+ SelectionDecision decision =
+ strategy.shouldSelect(mockAgent, testInputs, testOutput, context);
+
+ // Assert
+ assertTrue(decision.shouldTrain());
+ assertEquals("custom-task-id", decision.getMetadata().get("taskId"));
+ }
+
+ @Test
+ @DisplayName("Should pass through metadata from annotation")
+ void shouldPassThroughMetadataFromAnnotation() {
+ // Arrange
+ ExplicitMarkingStrategy strategy = ExplicitMarkingStrategy.create();
+ Map metadata = new HashMap<>();
+ metadata.put("userId", "user-123");
+ metadata.put("sessionId", "session-456");
+
+ TrainingAnnotation annotation =
+ TrainingAnnotation.builder()
+ .enabled(true)
+ .labels(Arrays.asList("test"))
+ .metadata(metadata)
+ .build();
+ ContextView context = Context.of(TrainingContext.REACTOR_KEY, annotation);
+
+ // Act
+ SelectionDecision decision =
+ strategy.shouldSelect(mockAgent, testInputs, testOutput, context);
+
+ // Assert
+ assertTrue(decision.shouldTrain());
+ assertEquals("user-123", decision.getMetadata().get("userId"));
+ assertEquals("session-456", decision.getMetadata().get("sessionId"));
+ }
+
+ @Test
+ @DisplayName("Should reject when annotation is expired")
+ void shouldRejectWhenAnnotationIsExpired() {
+ // Arrange
+ ExplicitMarkingStrategy strategy =
+ ExplicitMarkingStrategy.withTTL(Duration.ofMillis(1)); // Very short TTL
+
+ // Create an annotation with old timestamp
+ TrainingAnnotation annotation =
+ TrainingAnnotation.builder()
+ .enabled(true)
+ .timestamp(System.currentTimeMillis() - 1000) // 1 second ago
+ .build();
+ ContextView context = Context.of(TrainingContext.REACTOR_KEY, annotation);
+
+ // Act
+ SelectionDecision decision =
+ strategy.shouldSelect(mockAgent, testInputs, testOutput, context);
+
+ // Assert
+ assertFalse(decision.shouldTrain());
+ assertEquals("explicit-marking-expired", decision.getReason());
+ }
+
+ @Test
+ @DisplayName("Should return high priority")
+ void shouldReturnHighPriority() {
+ // Arrange
+ ExplicitMarkingStrategy strategy = ExplicitMarkingStrategy.create();
+
+ // Assert
+ assertEquals(10, strategy.priority());
+ }
+
+ @Test
+ @DisplayName("Should return correct name")
+ void shouldReturnCorrectName() {
+ // Arrange
+ ExplicitMarkingStrategy strategy = ExplicitMarkingStrategy.create();
+
+ // Assert
+ assertEquals("ExplicitMarking", strategy.name());
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/SamplingRateStrategyTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/SamplingRateStrategyTest.java
new file mode 100644
index 000000000..b72e9a4e3
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/SamplingRateStrategyTest.java
@@ -0,0 +1,170 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.strategy;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import io.agentscope.core.agent.Agent;
+import io.agentscope.core.message.Msg;
+import io.agentscope.core.message.MsgRole;
+import io.agentscope.core.training.util.TrainingTestConstants;
+import io.agentscope.core.training.util.TrainingTestUtils;
+import java.util.List;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+import reactor.util.context.Context;
+
+@DisplayName("SamplingRateStrategy Unit Tests")
+class SamplingRateStrategyTest {
+
+ private Agent mockAgent;
+ private List testInputs;
+ private Msg testOutput;
+
+ @BeforeEach
+ void setUp() {
+ mockAgent = TrainingTestUtils.createMockAgent(TrainingTestConstants.TEST_AGENT_NAME);
+ testInputs = TrainingTestUtils.createTestMessages();
+ testOutput =
+ TrainingTestUtils.createTestMessage("assistant", MsgRole.ASSISTANT, "Response");
+ }
+
+ @Test
+ @DisplayName("Should create strategy with valid sample rate")
+ void shouldCreateStrategyWithValidSampleRate() {
+ // Act
+ SamplingRateStrategy strategy = SamplingRateStrategy.of(0.5);
+
+ // Assert
+ assertNotNull(strategy);
+ assertEquals(0.5, strategy.getSampleRate(), 0.001);
+ }
+
+ @Test
+ @DisplayName("Should create strategy with zero sample rate")
+ void shouldCreateStrategyWithZeroSampleRate() {
+ // Act
+ SamplingRateStrategy strategy = SamplingRateStrategy.of(0.0);
+
+ // Assert
+ assertEquals(0.0, strategy.getSampleRate(), 0.001);
+ }
+
+ @Test
+ @DisplayName("Should create strategy with full sample rate")
+ void shouldCreateStrategyWithFullSampleRate() {
+ // Act
+ SamplingRateStrategy strategy = SamplingRateStrategy.of(1.0);
+
+ // Assert
+ assertEquals(1.0, strategy.getSampleRate(), 0.001);
+ }
+
+ @Test
+ @DisplayName("Should throw exception for negative sample rate")
+ void shouldThrowExceptionForNegativeSampleRate() {
+ assertThrows(IllegalArgumentException.class, () -> SamplingRateStrategy.of(-0.1));
+ }
+
+ @Test
+ @DisplayName("Should throw exception for sample rate greater than 1")
+ void shouldThrowExceptionForSampleRateGreaterThanOne() {
+ assertThrows(IllegalArgumentException.class, () -> SamplingRateStrategy.of(1.1));
+ }
+
+ @Test
+ @DisplayName("Should accept all with 100% sample rate")
+ void shouldAcceptAllWithFullSampleRate() {
+ // Arrange
+ SamplingRateStrategy strategy = SamplingRateStrategy.of(1.0);
+
+ // Act & Assert - run multiple times to verify consistency
+ for (int i = 0; i < 100; i++) {
+ SelectionDecision decision =
+ strategy.shouldSelect(mockAgent, testInputs, testOutput, Context.empty());
+ assertTrue(decision.shouldTrain(), "Should accept with 100% rate");
+ }
+ }
+
+ @Test
+ @DisplayName("Should reject all with 0% sample rate")
+ void shouldRejectAllWithZeroSampleRate() {
+ // Arrange
+ SamplingRateStrategy strategy = SamplingRateStrategy.of(0.0);
+
+ // Act & Assert - run multiple times to verify consistency
+ for (int i = 0; i < 100; i++) {
+ SelectionDecision decision =
+ strategy.shouldSelect(mockAgent, testInputs, testOutput, Context.empty());
+ assertFalse(decision.shouldTrain(), "Should reject with 0% rate");
+ }
+ }
+
+ @Test
+ @DisplayName("Should return correct priority")
+ void shouldReturnCorrectPriority() {
+ // Arrange
+ SamplingRateStrategy strategy = SamplingRateStrategy.of(0.5);
+
+ // Assert
+ assertEquals(200, strategy.priority());
+ }
+
+ @Test
+ @DisplayName("Should return correct name")
+ void shouldReturnCorrectName() {
+ // Arrange
+ SamplingRateStrategy strategy = SamplingRateStrategy.of(0.5);
+
+ // Assert
+ assertEquals("SamplingRate(0.5)", strategy.name());
+ }
+
+ @Test
+ @DisplayName("Should produce accept decision with correct reason")
+ void shouldProduceAcceptDecisionWithCorrectReason() {
+ // Arrange
+ SamplingRateStrategy strategy = SamplingRateStrategy.of(1.0);
+
+ // Act
+ SelectionDecision decision =
+ strategy.shouldSelect(mockAgent, testInputs, testOutput, Context.empty());
+
+ // Assert
+ assertEquals("sampling-rate", decision.getReason());
+ assertTrue(decision.getLabels().contains("sampled"));
+ }
+
+ @Test
+ @DisplayName("Should produce reject decision with correct reason")
+ void shouldProduceRejectDecisionWithCorrectReason() {
+ // Arrange
+ SamplingRateStrategy strategy = SamplingRateStrategy.of(0.0);
+
+ // Act
+ SelectionDecision decision =
+ strategy.shouldSelect(mockAgent, testInputs, testOutput, Context.empty());
+
+ // Assert
+ assertEquals("not-sampled", decision.getReason());
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/SelectionDecisionTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/SelectionDecisionTest.java
new file mode 100644
index 000000000..42d27bb27
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/SelectionDecisionTest.java
@@ -0,0 +1,133 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.strategy;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+@DisplayName("SelectionDecision Unit Tests")
+class SelectionDecisionTest {
+
+ @Test
+ @DisplayName("Should create accept decision with reason")
+ void shouldCreateAcceptDecisionWithReason() {
+ // Act
+ SelectionDecision decision = SelectionDecision.accept("test-reason");
+
+ // Assert
+ assertTrue(decision.shouldTrain());
+ assertEquals("test-reason", decision.getReason());
+ assertTrue(decision.getLabels().isEmpty());
+ assertTrue(decision.getMetadata().isEmpty());
+ }
+
+ @Test
+ @DisplayName("Should create accept decision with labels")
+ void shouldCreateAcceptDecisionWithLabels() {
+ // Act
+ SelectionDecision decision = SelectionDecision.accept("test-reason", "label1", "label2");
+
+ // Assert
+ assertTrue(decision.shouldTrain());
+ assertEquals("test-reason", decision.getReason());
+ assertEquals(2, decision.getLabels().size());
+ assertTrue(decision.getLabels().contains("label1"));
+ assertTrue(decision.getLabels().contains("label2"));
+ }
+
+ @Test
+ @DisplayName("Should create accept decision with labels and metadata")
+ void shouldCreateAcceptDecisionWithLabelsAndMetadata() {
+ // Arrange
+ List labels = Arrays.asList("label1", "label2");
+ Map metadata = new HashMap<>();
+ metadata.put("key1", "value1");
+ metadata.put("taskId", "custom-task-id");
+
+ // Act
+ SelectionDecision decision = SelectionDecision.accept("test-reason", labels, metadata);
+
+ // Assert
+ assertTrue(decision.shouldTrain());
+ assertEquals("test-reason", decision.getReason());
+ assertEquals(2, decision.getLabels().size());
+ assertEquals("value1", decision.getMetadata().get("key1"));
+ assertEquals("custom-task-id", decision.getMetadata().get("taskId"));
+ }
+
+ @Test
+ @DisplayName("Should create reject decision")
+ void shouldCreateRejectDecision() {
+ // Act
+ SelectionDecision decision = SelectionDecision.reject("not-sampled");
+
+ // Assert
+ assertFalse(decision.shouldTrain());
+ assertEquals("not-sampled", decision.getReason());
+ assertTrue(decision.getLabels().isEmpty());
+ assertTrue(decision.getMetadata().isEmpty());
+ }
+
+ @Test
+ @DisplayName("Should return unmodifiable labels list")
+ void shouldReturnUnmodifiableLabels() {
+ // Arrange
+ SelectionDecision decision = SelectionDecision.accept("test", "label1");
+
+ // Act & Assert
+ assertThrows(UnsupportedOperationException.class, () -> decision.getLabels().add("new"));
+ }
+
+ @Test
+ @DisplayName("Should return unmodifiable metadata map")
+ void shouldReturnUnmodifiableMetadata() {
+ // Arrange
+ Map metadata = new HashMap<>();
+ metadata.put("key", "value");
+ SelectionDecision decision =
+ SelectionDecision.accept("test", Arrays.asList("label"), metadata);
+
+ // Act & Assert
+ assertThrows(
+ UnsupportedOperationException.class,
+ () -> decision.getMetadata().put("new", "val"));
+ }
+
+ @Test
+ @DisplayName("Should produce readable toString")
+ void shouldProduceReadableToString() {
+ // Arrange
+ SelectionDecision decision = SelectionDecision.accept("sampling-rate", "sampled");
+
+ // Act
+ String str = decision.toString();
+
+ // Assert
+ assertTrue(str.contains("shouldTrain=true"));
+ assertTrue(str.contains("reason='sampling-rate'"));
+ assertTrue(str.contains("sampled"));
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/TrainingAnnotationTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/TrainingAnnotationTest.java
new file mode 100644
index 000000000..523857f4c
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/TrainingAnnotationTest.java
@@ -0,0 +1,235 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.strategy;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+@DisplayName("TrainingAnnotation Unit Tests")
+class TrainingAnnotationTest {
+
+ @Test
+ @DisplayName("Should create enabled annotation")
+ void shouldCreateEnabledAnnotation() {
+ // Act
+ TrainingAnnotation annotation = TrainingAnnotation.enabled();
+
+ // Assert
+ assertTrue(annotation.isEnabled());
+ assertTrue(annotation.getLabels().isEmpty());
+ assertTrue(annotation.getMetadata().isEmpty());
+ assertNull(annotation.getTaskId());
+ assertTrue(annotation.getTimestamp() > 0);
+ }
+
+ @Test
+ @DisplayName("Should create annotation with labels")
+ void shouldCreateAnnotationWithLabels() {
+ // Act
+ TrainingAnnotation annotation = TrainingAnnotation.withLabels("high-quality", "production");
+
+ // Assert
+ assertTrue(annotation.isEnabled());
+ assertEquals(2, annotation.getLabels().size());
+ assertTrue(annotation.getLabels().contains("high-quality"));
+ assertTrue(annotation.getLabels().contains("production"));
+ }
+
+ @Test
+ @DisplayName("Should create annotation with labels and metadata")
+ void shouldCreateAnnotationWithLabelsAndMetadata() {
+ // Arrange
+ Map metadata = new HashMap<>();
+ metadata.put("userId", "user-123");
+ metadata.put("sessionId", "session-456");
+
+ // Act
+ TrainingAnnotation annotation =
+ TrainingAnnotation.withLabelsAndMetadata(
+ Arrays.asList("important", "review"), metadata);
+
+ // Assert
+ assertTrue(annotation.isEnabled());
+ assertEquals(2, annotation.getLabels().size());
+ assertEquals("user-123", annotation.getMetadata().get("userId"));
+ assertEquals("session-456", annotation.getMetadata().get("sessionId"));
+ }
+
+ @Test
+ @DisplayName("Should build annotation with builder")
+ void shouldBuildAnnotationWithBuilder() {
+ // Act
+ TrainingAnnotation annotation =
+ TrainingAnnotation.builder()
+ .enabled(true)
+ .taskId("custom-task-id")
+ .labels(Arrays.asList("label1", "label2"))
+ .metadata(Map.of("key", "value"))
+ .build();
+
+ // Assert
+ assertTrue(annotation.isEnabled());
+ assertEquals("custom-task-id", annotation.getTaskId());
+ assertEquals(2, annotation.getLabels().size());
+ assertEquals("value", annotation.getMetadata().get("key"));
+ }
+
+ @Test
+ @DisplayName("Should build annotation with custom timestamp")
+ void shouldBuildAnnotationWithCustomTimestamp() {
+ // Arrange
+ long customTimestamp = 1000000L;
+
+ // Act
+ TrainingAnnotation annotation =
+ TrainingAnnotation.builder().enabled(true).timestamp(customTimestamp).build();
+
+ // Assert
+ assertEquals(customTimestamp, annotation.getTimestamp());
+ }
+
+ @Test
+ @DisplayName("Should check expiration correctly")
+ void shouldCheckExpirationCorrectly() {
+ // Arrange - create annotation with old timestamp
+ TrainingAnnotation annotation =
+ TrainingAnnotation.builder()
+ .enabled(true)
+ .timestamp(System.currentTimeMillis() - 2000) // 2 seconds ago
+ .build();
+
+ // Assert
+ assertTrue(annotation.isExpired(1000)); // 1 second TTL - should be expired
+ assertFalse(annotation.isExpired(5000)); // 5 second TTL - should not be expired
+ }
+
+ @Test
+ @DisplayName("Should not be expired with future timestamp")
+ void shouldNotBeExpiredWithFutureTimestamp() {
+ // Arrange
+ TrainingAnnotation annotation = TrainingAnnotation.enabled();
+
+ // Assert - just created, should not be expired
+ assertFalse(annotation.isExpired(60000)); // 1 minute TTL
+ }
+
+ @Test
+ @DisplayName("Should implement equals correctly")
+ void shouldImplementEqualsCorrectly() {
+ // Arrange
+ long timestamp = System.currentTimeMillis();
+ TrainingAnnotation annotation1 =
+ TrainingAnnotation.builder()
+ .enabled(true)
+ .taskId("task-1")
+ .labels(Arrays.asList("label"))
+ .timestamp(timestamp)
+ .build();
+
+ TrainingAnnotation annotation2 =
+ TrainingAnnotation.builder()
+ .enabled(true)
+ .taskId("task-1")
+ .labels(Arrays.asList("label"))
+ .timestamp(timestamp)
+ .build();
+
+ TrainingAnnotation annotation3 =
+ TrainingAnnotation.builder().enabled(false).timestamp(timestamp).build();
+
+ // Assert
+ assertEquals(annotation1, annotation2);
+ assertNotEquals(annotation1, annotation3);
+ }
+
+ @Test
+ @DisplayName("Should implement hashCode correctly")
+ void shouldImplementHashCodeCorrectly() {
+ // Arrange
+ long timestamp = System.currentTimeMillis();
+ TrainingAnnotation annotation1 =
+ TrainingAnnotation.builder()
+ .enabled(true)
+ .taskId("task-1")
+ .timestamp(timestamp)
+ .build();
+
+ TrainingAnnotation annotation2 =
+ TrainingAnnotation.builder()
+ .enabled(true)
+ .taskId("task-1")
+ .timestamp(timestamp)
+ .build();
+
+ // Assert
+ assertEquals(annotation1.hashCode(), annotation2.hashCode());
+ }
+
+ @Test
+ @DisplayName("Should produce readable toString")
+ void shouldProduceReadableToString() {
+ // Arrange
+ TrainingAnnotation annotation =
+ TrainingAnnotation.builder()
+ .enabled(true)
+ .taskId("task-123")
+ .labels(Arrays.asList("important"))
+ .build();
+
+ // Act
+ String str = annotation.toString();
+
+ // Assert
+ assertTrue(str.contains("enabled=true"));
+ assertTrue(str.contains("taskId=task-123"));
+ assertTrue(str.contains("important"));
+ }
+
+ @Test
+ @DisplayName("Should handle null labels in withLabelsAndMetadata")
+ void shouldHandleNullLabelsInWithLabelsAndMetadata() {
+ // Act
+ TrainingAnnotation annotation =
+ TrainingAnnotation.withLabelsAndMetadata(null, Map.of("key", "value"));
+
+ // Assert
+ assertNotNull(annotation.getLabels());
+ assertTrue(annotation.getLabels().isEmpty());
+ }
+
+ @Test
+ @DisplayName("Should handle null metadata in withLabelsAndMetadata")
+ void shouldHandleNullMetadataInWithLabelsAndMetadata() {
+ // Act
+ TrainingAnnotation annotation =
+ TrainingAnnotation.withLabelsAndMetadata(Arrays.asList("label"), null);
+
+ // Assert
+ assertNotNull(annotation.getMetadata());
+ assertTrue(annotation.getMetadata().isEmpty());
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/TrainingContextTest.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/TrainingContextTest.java
new file mode 100644
index 000000000..e78a3b6c9
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/strategy/TrainingContextTest.java
@@ -0,0 +1,160 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.strategy;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Function;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+import reactor.util.context.Context;
+import reactor.util.context.ContextView;
+
+@DisplayName("TrainingContext Unit Tests")
+class TrainingContextTest {
+
+ @Test
+ @DisplayName("Should create mark function")
+ void shouldCreateMarkFunction() {
+ // Act
+ Function markFn = TrainingContext.mark();
+
+ // Assert
+ assertNotNull(markFn);
+
+ // Apply the function
+ Context context = markFn.apply(Context.empty());
+ assertTrue(context.hasKey(TrainingContext.REACTOR_KEY));
+ }
+
+ @Test
+ @DisplayName("Should create mark function with labels")
+ void shouldCreateMarkFunctionWithLabels() {
+ // Act
+ Function markFn = TrainingContext.mark("high-quality", "production");
+
+ // Assert
+ Context context = markFn.apply(Context.empty());
+ TrainingAnnotation annotation = context.get(TrainingContext.REACTOR_KEY);
+
+ assertTrue(annotation.isEnabled());
+ assertEquals(2, annotation.getLabels().size());
+ assertTrue(annotation.getLabels().contains("high-quality"));
+ assertTrue(annotation.getLabels().contains("production"));
+ }
+
+ @Test
+ @DisplayName("Should create mark function with labels and metadata")
+ void shouldCreateMarkFunctionWithLabelsAndMetadata() {
+ // Arrange
+ Map metadata = new HashMap<>();
+ metadata.put("userId", "user-123");
+ metadata.put("taskId", "custom-task-id");
+
+ // Act
+ Function markFn =
+ TrainingContext.mark(Arrays.asList("important"), metadata);
+
+ // Assert
+ Context context = markFn.apply(Context.empty());
+ TrainingAnnotation annotation = context.get(TrainingContext.REACTOR_KEY);
+
+ assertTrue(annotation.isEnabled());
+ assertEquals("user-123", annotation.getMetadata().get("userId"));
+ assertEquals("custom-task-id", annotation.getMetadata().get("taskId"));
+ }
+
+ @Test
+ @DisplayName("Should create mark function with only metadata")
+ void shouldCreateMarkFunctionWithOnlyMetadata() {
+ // Arrange
+ Map metadata = new HashMap<>();
+ metadata.put("key", "value");
+
+ // Act
+ Function markFn = TrainingContext.mark(metadata);
+
+ // Assert
+ Context context = markFn.apply(Context.empty());
+ TrainingAnnotation annotation = context.get(TrainingContext.REACTOR_KEY);
+
+ assertTrue(annotation.isEnabled());
+ assertTrue(annotation.getLabels().isEmpty());
+ assertEquals("value", annotation.getMetadata().get("key"));
+ }
+
+ @Test
+ @DisplayName("Should get current annotation from context")
+ void shouldGetCurrentAnnotationFromContext() {
+ // Arrange
+ TrainingAnnotation annotation = TrainingAnnotation.withLabels("test-label");
+ ContextView context = Context.of(TrainingContext.REACTOR_KEY, annotation);
+
+ // Act
+ TrainingAnnotation result = TrainingContext.getCurrent(context);
+
+ // Assert
+ assertNotNull(result);
+ assertTrue(result.isEnabled());
+ assertTrue(result.getLabels().contains("test-label"));
+ }
+
+ @Test
+ @DisplayName("Should return null when no annotation in context")
+ void shouldReturnNullWhenNoAnnotationInContext() {
+ // Act
+ TrainingAnnotation result = TrainingContext.getCurrent(Context.empty());
+
+ // Assert
+ assertNull(result);
+ }
+
+ @Test
+ @DisplayName("Should return null for null context")
+ void shouldReturnNullForNullContext() {
+ // Act
+ TrainingAnnotation result = TrainingContext.getCurrent(null);
+
+ // Assert
+ assertNull(result);
+ }
+
+ @Test
+ @DisplayName("Should check expiration correctly")
+ void shouldCheckExpirationCorrectly() {
+ // Arrange - create expired annotation
+ TrainingAnnotation expiredAnnotation =
+ TrainingAnnotation.builder()
+ .enabled(true)
+ .timestamp(System.currentTimeMillis() - 2000) // 2 seconds ago
+ .build();
+
+ TrainingAnnotation validAnnotation = TrainingAnnotation.enabled();
+
+ // Assert
+ assertTrue(TrainingContext.isExpired(expiredAnnotation, 1000)); // 1 second TTL
+ assertFalse(TrainingContext.isExpired(validAnnotation, 60000)); // 1 minute TTL
+ assertTrue(TrainingContext.isExpired(null, 60000)); // null is always expired
+ }
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/util/TrainingTestConstants.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/util/TrainingTestConstants.java
new file mode 100644
index 000000000..d16bea007
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/util/TrainingTestConstants.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.util;
+
+import java.time.Duration;
+
+/**
+ * Test constants for Training module unit tests.
+ */
+public final class TrainingTestConstants {
+
+ private TrainingTestConstants() {}
+
+ // Trinity Service
+ public static final String TEST_TRINITY_ENDPOINT = "http://mock-trinity:8080";
+ public static final String TEST_MODEL_NAME = "test-model";
+ public static final String TEST_API_KEY = "test-api-key";
+
+ // Task/Run IDs
+ public static final String TEST_TASK_ID = "test-task-001";
+ public static final String TEST_RUN_ID = "0";
+
+ // Timeouts
+ public static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(5);
+ public static final Duration SHORT_TIMEOUT = Duration.ofSeconds(1);
+ public static final Duration LONG_TIMEOUT = Duration.ofSeconds(10);
+
+ // Sampling
+ public static final double DEFAULT_SAMPLE_RATE = 0.1;
+ public static final double HIGH_SAMPLE_RATE = 0.9;
+ public static final double LOW_SAMPLE_RATE = 0.01;
+
+ // Commit intervals
+ public static final long DEFAULT_COMMIT_INTERVAL = 300;
+ public static final long SHORT_COMMIT_INTERVAL = 60;
+
+ // Pool sizes
+ public static final int DEFAULT_SHADOW_POOL_SIZE = 10;
+ public static final int DEFAULT_SHADOW_POOL_CAPACITY = 1000;
+
+ // Repeat times
+ public static final int DEFAULT_REPEAT_TIME = 1;
+ public static final int MULTI_REPEAT_TIME = 3;
+
+ // Test messages
+ public static final String TEST_MSG_ID_1 = "msg-001";
+ public static final String TEST_MSG_ID_2 = "msg-002";
+ public static final String TEST_MSG_ID_3 = "msg-003";
+
+ // Agent names
+ public static final String TEST_AGENT_NAME = "TestAgent";
+ public static final String TEST_SHADOW_AGENT_NAME = "TestAgent-shadow";
+}
diff --git a/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/util/TrainingTestUtils.java b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/util/TrainingTestUtils.java
new file mode 100644
index 000000000..06b1b06a0
--- /dev/null
+++ b/agentscope-extensions/agentscope-extensions-training/src/test/java/io/agentscope/core/training/util/TrainingTestUtils.java
@@ -0,0 +1,134 @@
+/*
+ * Copyright 2024-2026 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.agentscope.core.training.util;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import io.agentscope.core.agent.Agent;
+import io.agentscope.core.message.Msg;
+import io.agentscope.core.message.MsgRole;
+import io.agentscope.core.training.backend.TrinityClient;
+import io.agentscope.core.training.reward.RewardCalculator;
+import io.agentscope.core.training.runner.RunExecutionContext;
+import io.agentscope.core.training.runner.TrainingConfig;
+import io.agentscope.core.training.strategy.SamplingRateStrategy;
+import java.util.Arrays;
+import java.util.List;
+import reactor.core.publisher.Mono;
+
+/**
+ * Test utilities for Training module unit tests.
+ */
+public final class TrainingTestUtils {
+
+ private TrainingTestUtils() {}
+
+ /**
+ * Creates a mock Agent with the given name.
+ */
+ public static Agent createMockAgent(String name) {
+ Agent agent = mock(Agent.class);
+ when(agent.getName()).thenReturn(name);
+ when(agent.getAgentId()).thenReturn("id-" + name);
+ when(agent.getDescription()).thenReturn("Mock agent: " + name);
+ return agent;
+ }
+
+ /**
+ * Creates a mock RewardCalculator that returns the specified reward.
+ */
+ public static RewardCalculator createMockRewardCalculator(double reward) {
+ RewardCalculator calculator = mock(RewardCalculator.class);
+ when(calculator.calculate(org.mockito.ArgumentMatchers.any())).thenReturn(reward);
+ return calculator;
+ }
+
+ /**
+ * Creates a mock TrinityClient.
+ */
+ public static TrinityClient createMockTrinityClient() {
+ TrinityClient client = mock(TrinityClient.class);
+ when(client.getEndpoint()).thenReturn(TrainingTestConstants.TEST_TRINITY_ENDPOINT);
+ when(client.feedback(org.mockito.ArgumentMatchers.any())).thenReturn(Mono.empty());
+ when(client.commit(org.mockito.ArgumentMatchers.any())).thenReturn(Mono.empty());
+ return client;
+ }
+
+ /**
+ * Creates a default TrainingConfig for testing.
+ */
+ public static TrainingConfig createTestConfig() {
+ return TrainingConfig.builder()
+ .trinityEndpoint(TrainingTestConstants.TEST_TRINITY_ENDPOINT)
+ .modelName(TrainingTestConstants.TEST_MODEL_NAME)
+ .selectionStrategy(
+ SamplingRateStrategy.of(TrainingTestConstants.DEFAULT_SAMPLE_RATE))
+ .rewardCalculator(createMockRewardCalculator(0.5))
+ .commitIntervalSeconds(TrainingTestConstants.DEFAULT_COMMIT_INTERVAL)
+ .build();
+ }
+
+ /**
+ * Creates a RunExecutionContext for testing.
+ */
+ public static RunExecutionContext createTestContext(String taskId, String runId) {
+ return RunExecutionContext.create(taskId, runId);
+ }
+
+ /**
+ * Creates a test Msg list.
+ */
+ public static List createTestMessages() {
+ Msg msg1 = Msg.builder().name("user").role(MsgRole.USER).textContent("Hello").build();
+ Msg msg2 =
+ Msg.builder()
+ .name("assistant")
+ .role(MsgRole.ASSISTANT)
+ .textContent("Hi there!")
+ .build();
+ return Arrays.asList(msg1, msg2);
+ }
+
+ /**
+ * Creates a single test Msg.
+ */
+ public static Msg createTestMessage(String name, MsgRole role, String content) {
+ return Msg.builder().name(name).role(role).textContent(content).build();
+ }
+
+ /**
+ * Creates a list of test message IDs.
+ */
+ public static List createTestMsgIds() {
+ return Arrays.asList(
+ TrainingTestConstants.TEST_MSG_ID_1,
+ TrainingTestConstants.TEST_MSG_ID_2,
+ TrainingTestConstants.TEST_MSG_ID_3);
+ }
+
+ /**
+ * Pauses execution for the specified milliseconds.
+ */
+ public static void sleep(long millis) {
+ try {
+ Thread.sleep(millis);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ }
+}
diff --git a/agentscope-extensions/pom.xml b/agentscope-extensions/pom.xml
index 0b66ae58c..ddc7f33be 100644
--- a/agentscope-extensions/pom.xml
+++ b/agentscope-extensions/pom.xml
@@ -45,6 +45,7 @@
agentscope-extensions-scheduleragentscope-extensions-session-redisagentscope-extensions-session-mysql
+ agentscope-extensions-trainingagentscope-micronaut-extensionsagentscope-quarkus-extensionsagentscope-spring-boot-starters
diff --git a/docs/en/task/online-training.md b/docs/en/task/online-training.md
new file mode 100644
index 000000000..1294f7700
--- /dev/null
+++ b/docs/en/task/online-training.md
@@ -0,0 +1,267 @@
+# AgentScope Training Extension
+
+## Overview
+
+### Background
+
+Agent developers typically work with open-source models, using fine-tuning methods such as **SFT (Supervised Fine-Tuning)** and **RFT (Reinforcement Fine-Tuning)** to balance Agent cost, performance, and effectiveness in specific scenarios. This extension helps Agent developers **conveniently and continuously leverage real online interaction data to optimize models and Agents**, establishing a **complete data loop from production to training systems**, enabling **"Agents that get smarter with use"** through online training.
+
+### Online Training
+
+Online Training is a training paradigm that directly utilizes real user interaction data in production or near-production environments to continuously optimize Agent behavior. Unlike traditional Offline Training—which involves collecting historical logs, building static datasets, and training models in isolated environments—online training emphasizes deep coupling with real toolchains and user behavior, achieving a "run, learn, and optimize" closed loop.
+
+#### Key Characteristics
+
+**1. Reuse Production Toolchains**
+
+Agents can directly invoke real tools deployed in production (such as APIs, databases, business systems, etc.) during training, without the need to build simulation environments or write mock tools specifically for training.
+
+**Advantage**: Avoids "training-deployment deviation" (Reality Gap) caused by inconsistencies between mock tools and actual production behavior; significantly reduces integration costs and improves the authenticity and effectiveness of training data.
+
+**2. Support Incremental Learning with Fast Cold Start**
+
+Does not depend on complete historical datasets; Agents can start learning from a small number or even single real interactions, suitable for newly launched Agents or long-tail scenarios, significantly lowering the startup threshold.
+
+
+#### Constraints
+
+**Safe Support for Read-Only Tools by Default**
+
+Since the training process may involve multiple attempts or replays, directly invoking write-operation tools (such as "place order", "deduct payment", "send message") may cause repeated execution, leading to business risks. Therefore, write operations require additional safeguards through sandbox mechanisms, idempotent design, or manual review. Users need to ensure the safety of the tools used by their Agents.
+
+**Multi-Turn Interaction Scenarios Require Explicit Modeling**
+
+Current mainstream training frameworks natively support single-turn interactions between users and Agents (user asks → Agent responds). In this interaction, Agents can have multiple interactions with LLMs.
+
+For multi-turn dialogues or complex task flows (such as book flight → select seat → payment), developers need to design additional state management, user behavior simulation, or trajectory sampling strategies.
+
+
+## Architecture
+
+
+This solution uses Trinity-RFT as the training backend. Trinity-RFT is a general-purpose, flexible, and user-friendly Large Language Model (LLM) Reinforcement Fine-Tuning (RFT) framework.
+
+Github: https://github.com/agentscope-ai/Trinity-RFT
+
+Version requirement: v0.4.0 and above
+
+The online training mode decouples three components: Agent Runner, inference service (Explorer), and training service (Trainer):
+
+- Agent Runner is responsible for running user Agent applications, processing user requests, and interacting with Explorer through RESTful APIs. This component is implemented, deployed, and managed by users themselves, with no constraints from Trinity-RFT.
+- Explorer serves as the inference service, processing requests from Agent Runner, recording trainable data (Experience), and storing data in the database. Explorer provides the following RESTful interfaces for Agent Runner to call:
+ - chat: Compatible with standard OpenAI chat completions interface, handling user dialogue requests.
+ - feedback: Receives user feedback on Agent responses.
+ - commit: Notifies Explorer to submit data to Trainer.
+- Trainer serves as the training service, retrieving new training data from the database, training the model, and storing updated model checkpoints in a shared file system for Explorer to use.
+
+Agent Runner, Explorer, and Trainer can be deployed on different servers. Among them, Agent Runner is managed by users themselves, only requiring network connectivity with Explorer and no GPU resources. Explorer and Trainer need to be deployed on GPU servers through Trinity-RFT and must ensure both can access the same shared file system so that model checkpoints saved by Trainer can be read by Explorer.
+
+### Core Features
+
+This solution provides **end-to-end online training** support natively in AgentScope Java, aiming to establish a complete loop from production to model optimization with the following goals:
+
+- **Leverage Real Online Interaction Data**: Agent developers can directly train using real request invocations and tool states from production Agent environments
+- **Minimal Setup Experience**: Agent developers only need to provide key training configurations (such as reward functions in RL) to automatically complete execution, data collection, and the entire training process
+- **Unified Training Interface Covering Mainstream Optimization Methods**: Native support for supervised fine-tuning (**SFT**), knowledge distillation, and task-specific reinforcement learning algorithms (such as **PPO**), without needing to switch frameworks or depend on other ecosystems
+
+---
+
+## Quick Start
+
+### Maven Dependency
+
+```xml
+
+ io.agentscope
+ agentscope-extensions-training
+ ${agentscope.version}
+
+```
+### Define Request Selection Logic
+
+Request selection logic is used to filter out requests that need to be used for training.
+
+#### Built-in Strategies:
+
+**SamplingRateStrategy** - Random sampling. All online requests are filtered by percentage.
+```java
+TrainingSelectionStrategy strategy = SamplingRateStrategy.of(0.1); // 10%
+```
+
+**ExplicitMarkingStrategy** - Users explicitly mark important requests
+```java
+TrainingSelectionStrategy strategy = ExplicitMarkingStrategy.create();
+
+// In your application code, explicitly mark requests for training
+TrainingContext.mark("high-quality", "user-feedback");
+agent.call(msg).block(); // This request will be used for training
+```
+#### Custom Strategy
+You can implement the TrainingSelectionStrategy interface by referring to SamplingRateStrategy or ExplicitMarkingStrategy, and customize your request filtering logic in the shouldSelect method according to your business needs.
+
+
+### Define Reward Function
+You can implement the RewardCalculator interface and customize your reward calculation logic in the calculate method according to your business needs. Generally, rewards are decimals between 0 and 1.
+### Start Training Backend
+
+#### Install Trinity
+
+Before installation, ensure your system meets the following requirements. Source installation is recommended:
+
+- **Python**: Version 3.10 to 3.12 (inclusive)
+- **CUDA**: Version >= 12.8
+- **GPU**: At least 2 GPUs
+
+```bash
+git clone https://github.com/agentscope-ai/Trinity-RFT
+cd Trinity-RFT
+pip install -e ".[dev]"
+pip install flash-attn==2.8.1
+```
+#### Configure Training Settings
+##### Write Explorer Service Configuration
+```yaml
+mode: serve # set to 'serve' for online inference service
+project: test # set your project name
+name: test # set your experiment name
+checkpoint_root_dir: CHECKPOINT_ROOT_DIR # set the root directory for checkpoints, must be an absolute path, and should be on a shared filesystem
+model:
+ model_path: /path/to/your/model # set the path to your base model
+ max_model_len: 8192
+ max_response_tokens: 2048
+ temperature: 0.7
+algorithm:
+ algorithm_type: "ppo" # current version only supports ppo for online training (group is not supported yet)
+cluster:
+ node_num: 1
+ gpu_per_node: 4 # suppose you have 4 GPUs on the node
+explorer:
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 2 # make sure tensor_parallel_size * engine_num <= node_num * gpu_per_node
+ enable_openai_api: true
+ enable_history: true
+ enable_auto_tool_choice: true
+ tool_call_parser: hermes
+ # reasoning_parser: deepseek_r1 # if using Qwen3 series models, uncomment this line
+ dtype: bfloat16
+ seed: 42
+ service_status_check_interval: 10 # check new checkpoints and update data every 10 seconds
+ proxy_port: 8010 # set the port for Explorer service
+# trainer:
+# save_interval: 1 # save checkpoint every step
+# ulysses_sequence_parallel_size: 2 # set according to your model and hardware
+buffer:
+ train_batch_size: 16
+ trainer_input:
+ experience_buffer:
+ name: exp_buffer # table name in the database
+ storage_type: sql
+ # path: your_db_url # if not provided, use a sqlite database in checkpoint_root_dir/project/name/buffer
+synchronizer:
+ sync_method: checkpoint
+ sync_interval: 1
+monitor:
+ monitor_type: tensorboard
+```
+##### Write Trainer Service Configuration
+```yaml
+mode: train # set to 'train' for training service
+project: test # set your project name, must be the same as in Explorer
+name: test # set your experiment name, must be the same as in Explorer
+checkpoint_root_dir: CHECKPOINT_ROOT_DIR # set the root directory for checkpoints, must be the same as in Explorer
+model:
+ model_path: /path/to/your/model # set the path to your base model, must be the same as in Explorer
+ max_model_len: 8192 # must be the same as in Explorer
+ max_response_tokens: 2048 # must be the same as in Explorer
+ temperature: 0.7 # must be the same as in Explorer
+algorithm:
+ algorithm_type: "ppo" # current version only supports ppo for online training (group is not supported yet)
+cluster:
+ node_num: 1
+ gpu_per_node: 4 # suppose you have 4 GPUs on the node
+buffer:
+ train_batch_size: 32 # trainer consumes 16 samples per step
+ trainer_input:
+ experience_buffer:
+ name: exp_buffer # table name in the database, must be the same as in Explorer
+ storage_type: sql
+ # path: your_db_url # if not provided, use a sqlite database in checkpoint_root_dir/project/name/buffer
+trainer:
+ save_interval: 16 # save checkpoint every step
+ ulysses_sequence_parallel_size: 1 # set according to your model and hardware
+ save_hf_checkpoint: always
+ max_checkpoints_to_keep: 5
+ trainer_config:
+ trainer:
+ balance_batch: false
+ max_actor_ckpt_to_keep: 5
+ max_critic_ckpt_to_keep: 5
+synchronizer:
+ sync_method: checkpoint
+ sync_interval: 1
+
+monitor:
+ monitor_type: tensorboard
+```
+
+#### Start Training Backend Environment
+Before starting Explorer and Trainer services, you need to start the Ray cluster
+```bash
+ray start --head
+```
+Start Explorer and Trainer services separately.
+
+```bash
+trinity run --config explorer.yaml
+trinity run --config trainer.yaml
+```
+After starting the Explorer service, the service address will be printed in the log, typically on port 8010
+### Configure Online Training and Start Agent
+
+#### Configuration Options
+
+```java
+TrainingRunner trainingRunner = TrainingRunner.builder()
+ .trinityEndpoint(TRINITY_ENDPOINT) // Trinity Explorer service address
+ .modelName(TRAINING_MODEL_NAME) // Corresponds to model_path in Trinity configuration
+ .selectionStrategy(new CustomStrategy())
+ .rewardCalculator(new CustomReward())
+ .commitIntervalSeconds(60*5)
+ .repeatTime(1)
+ .build();
+trainingRunner.start();
+```
+
+#### Complete Example
+
+```java
+import io.agentscope.core.training.runner.TrainingRunner;
+import io.agentscope.core.training.strategy.SamplingRateStrategy;
+
+// 1. Start training runner (no Task ID/Run ID needed!)
+TrainingRunner runner = TrainingRunner.builder()
+ .trinityEndpoint("http://trinity-backend:8010")
+ .modelName("/path/to/qwen-model")
+ .selectionStrategy(SamplingRateStrategy.of(0.1)) // 10% sampling
+ .rewardCalculator(agent -> 0.0) // Custom reward calculation logic
+ .commitIntervalSeconds(300) // Commit every 5 minutes
+ .build();
+
+runner.start();
+
+// 2. Use your Agent normally - training happens transparently!
+ReActAgent agent = ReActAgent.builder()
+ .name("ProductionAgent")
+ .model(gpt4Model) // Production model (GPT-4)
+ .tools(tools)
+ .build();
+
+// User requests are processed normally (using GPT-4), 10% automatically sampled for training
+Msg response = agent.call(Msg.userMsg("Search for Python tutorials")).block();
+
+// 3. Stop when training is complete
+runner.stop();
+```
+
+---
diff --git a/docs/imgs/training.svg b/docs/imgs/training.svg
new file mode 100644
index 000000000..99c204518
--- /dev/null
+++ b/docs/imgs/training.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/docs/zh/task/online-training.md b/docs/zh/task/online-training.md
new file mode 100644
index 000000000..4eac61e39
--- /dev/null
+++ b/docs/zh/task/online-training.md
@@ -0,0 +1,268 @@
+# AgentScope Training 训练扩展
+
+## 概述
+
+### 背景
+
+Agent 开发者通常基于开源模型,通过 **SFT(监督微调)**、**RFT(强化微调)** 等微调手段,在特定场景下平衡 Agent 成本、性能与效果。该插件帮助 Agent 开发者**便捷、持续地利用在线真实交互数据优化模型与 Agent**,打通从生产环境到训练系统的**全链路数据闭环**,通过在线训练,实现**"Agent 越用越聪明"**。
+
+### 在线训练
+
+在线训练(Online Training)是一种直接在生产环境或接近生产环境的实时系统中,利用真实用户交互数据持续优化智能体(Agent)行为的训练范式。与传统的离线训练(Offline Training)——即先收集历史日志、构建静态数据集、再在隔离环境中训练模型——不同,在线训练强调与真实工具链和用户行为的深度耦合,实现“边运行、边学习、边优化”的闭环。
+
+#### 核心特点
+
+**1. 复用线上真实工具链**
+
+Agent 在训练中可直接调用线上已部署的真实工具(如 API、数据库、业务系统等),无需为训练专门搭建模拟环境或编写 mock 工具。
+
+**优势**:避免因 mock 工具与线上实际行为不一致导致的"训练-部署偏差"(Reality Gap);大幅降低集成成本,提升训练数据的真实性与有效性。
+
+**2. 支持增量学习,快速冷启动**
+
+不依赖完整的历史数据集,Agent 可从少量甚至单次真实交互开始学习,适合新上线 Agent 或长尾场景,显著降低启动门槛。
+
+
+#### 约束
+
+**默认仅安全支持只读工具**
+
+因训练过程可能涉及多次尝试或重放(replay),若直接调用写操作工具(如"下单""扣款""发消息"),可能导致重复执行,引发业务风险。因此,写操作需通过沙箱机制、幂等设计或人工审核等方式额外保障安全性。用户需要自行保证 Agent 使用的工具的安全性。
+
+**多轮交互场景需显式建模**
+
+当前主流训练框架原生支持用户与 Agent 的单轮交互(用户提问 → Agent 响应)。该轮交互中,Agent 可以和 LLM 有多次交互。
+
+对于多轮对话或复杂任务流(如订机票 → 选座位 → 支付),需开发者额外设计状态管理、用户行为模拟或轨迹采样策略。
+
+
+## 架构
+
+
+该方案使用Trinity-RFT作为训练后端进行训练。Trinity-RFT 是一个通用、灵活、用户友好的大语言模型(LLM)强化微调(RFT)框架。
+
+Github地址:https://github.com/agentscope-ai/Trinity-RFT
+
+版本要求:v0.4.0及以上
+
+在线训练模式将 Agent 运行 (Agent Runner)、推理服务 (Explorer) 、训练服务 (Trainer) 三个部分解耦开来:
+
+- Agent Runner 负责运行用户的 Agent 应用,处理用户请求,并通过 restful API 与 Explorer 进行交互。该部分由用户自行实现、部署和管理,Trinity-RFT 不对该部分做任何约束。
+- Explorer 作为推理服务,处理来自 Agent Runner 的请求,记录可训练数据(Experience),并将数据存储在数据库中。 Explorer 提供以下 Restful 接口供 Agent Runner 调用:
+ - chat: 兼容标准的 openai chat completions 接口,处理用户的对话请求。
+ - feedback: 接收用户对 Agent 回答的反馈信息。
+ - commit: 告知 Explorer 向 Trainer 提交数据。
+- Trainer 作为训练服务,从数据库中获取新的训练数据,对模型进行训练,并将更新后的模型检查点存储在共享文件系统中,供 Explorer 使用。
+
+Agent Runner、Explorer 和 Trainer 可以部署在不同的服务器上。其中 Agent Runner 由用户自行管理,只需要保证网络与 Explorer 互通,无需 GPU 资源。
+而 Explorer 和 Trainer 需要通过 Trinity-RFT 部署在 GPU 服务器上,且需要保证两者可以访问同一个共享文件系统,以便 Trainer 保存的模型检查点可以被 Explorer 读取。
+
+### 核心功能
+
+本方案通过 AgentScope Java 原生支持**端到端在线训练**,旨在打通从生产环境到模型优化的全链路闭环,实现以下目标:
+
+- **利用线上真实交互数据**:Agent 开发者可直接基于生产环境中 Agent 的真实请求调用与工具状态,使用线上的数据进行训练
+- **极简使用体验**:Agent 开发者仅需提供关键训练配置(如:RL 中的奖励函数),即可自动完成执行、数据收集、训练全流程
+- **统一训练接口,覆盖主流优化方法**:原生支持监督微调(**SFT**)、知识蒸馏,以及适用于特定任务的强化学习算法(如**PPO**),无需切换框架或依赖其他生态
+
+---
+
+## 快速开始
+
+### Maven 依赖
+
+```xml
+
+ io.agentscope
+ agentscope-extensions-training
+ ${agentscope.version}
+
+```
+### 定义请求筛选逻辑
+
+请求筛选逻辑用于筛选出需要用于训练的请求。
+
+#### 内置策略:
+
+**SamplingRateStrategy** - 随机采样。所有线上请求按照百分比进行筛选。
+```java
+TrainingSelectionStrategy strategy = SamplingRateStrategy.of(0.1); // 10%
+```
+
+**ExplicitMarkingStrategy** - 用户显式标记重要请求
+```java
+TrainingSelectionStrategy strategy = ExplicitMarkingStrategy.create();
+
+// 在你的应用代码中显示标记请求用于训练
+TrainingContext.mark("high-quality", "user-feedback");
+agent.call(msg).block(); // 这个请求会被用于训练
+```
+#### 自定义策略
+您可以参考SamplingRateStrategy或者ExplicitMarkingStrategy实现TrainingSelectionStrategy接口,并在shouldSelect方法中根据您的业务需求自定义您的请求筛选逻辑。
+
+
+### 定义奖励函数
+您可以实现RewardCalculator接口,并在calculate方法中根据您的业务需求自定义您的奖励计算逻辑。一般而言,奖励为0-1之间的小数。
+### 启动训练后端
+
+#### 安装 Trinity
+
+在安装之前,请确保您的系统满足以下要求,推荐使用源码安装:
+
+- **Python**:版本 3.10 至 3.12(含)
+- **CUDA**:版本 >= 12.8
+- **GPU**:至少 2 块 GPU
+
+```bash
+git clone https://github.com/agentscope-ai/Trinity-RFT
+cd Trinity-RFT
+pip install -e ".[dev]"
+pip install flash-attn==2.8.1
+```
+#### 配置训练配置
+##### 编写explorer服务配置
+```yaml
+mode: serve # set to 'serve' for online inference service
+project: test # set your project name
+name: test # set your experiment name
+checkpoint_root_dir: CHECKPOINT_ROOT_DIR # set the root directory for checkpoints, must be an absolute path, and should be on a shared filesystem
+model:
+ model_path: /path/to/your/model # set the path to your base model
+ max_model_len: 8192
+ max_response_tokens: 2048
+ temperature: 0.7
+algorithm:
+ algorithm_type: "ppo" # current version only supports ppo for online training (group is not supported yet)
+cluster:
+ node_num: 1
+ gpu_per_node: 4 # suppose you have 4 GPUs on the node
+explorer:
+ rollout_model:
+ engine_num: 2
+ tensor_parallel_size: 2 # make sure tensor_parallel_size * engine_num <= node_num * gpu_per_node
+ enable_openai_api: true
+ enable_history: true
+ enable_auto_tool_choice: true
+ tool_call_parser: hermes
+ # reasoning_parser: deepseek_r1 # if using Qwen3 series models, uncomment this line
+ dtype: bfloat16
+ seed: 42
+ service_status_check_interval: 10 # check new checkpoints and update data every 10 seconds
+ proxy_port: 8010 # set the port for Explorer service
+# trainer:
+# save_interval: 1 # save checkpoint every step
+# ulysses_sequence_parallel_size: 2 # set according to your model and hardware
+buffer:
+ train_batch_size: 16
+ trainer_input:
+ experience_buffer:
+ name: exp_buffer # table name in the database
+ storage_type: sql
+ # path: your_db_url # if not provided, use a sqlite database in checkpoint_root_dir/project/name/buffer
+synchronizer:
+ sync_method: checkpoint
+ sync_interval: 1
+monitor:
+ monitor_type: tensorboard
+```
+##### 编写Trainner服务配置
+```yaml
+mode: train # set to 'train' for training service
+project: test # set your project name, must be the same as in Explorer
+name: test # set your experiment name, must be the same as in Explorer
+checkpoint_root_dir: CHECKPOINT_ROOT_DIR # set the root directory for checkpoints, must be the same as in Explorer
+model:
+ model_path: /path/to/your/model # set the path to your base model, must be the same as in Explorer
+ max_model_len: 8192 # must be the same as in Explorer
+ max_response_tokens: 2048 # must be the same as in Explorer
+ temperature: 0.7 # must be the same as in Explorer
+algorithm:
+ algorithm_type: "ppo" # current version only supports ppo for online training (group is not supported yet)
+cluster:
+ node_num: 1
+ gpu_per_node: 4 # suppose you have 4 GPUs on the node
+buffer:
+ train_batch_size: 32 # trainer consumes 16 samples per step
+ trainer_input:
+ experience_buffer:
+ name: exp_buffer # table name in the database, must be the same as in Explorer
+ storage_type: sql
+ # path: your_db_url # if not provided, use a sqlite database in checkpoint_root_dir/project/name/buffer
+trainer:
+ save_interval: 16 # save checkpoint every step
+ ulysses_sequence_parallel_size: 1 # set according to your model and hardware
+ save_hf_checkpoint: always
+ max_checkpoints_to_keep: 5
+ trainer_config:
+ trainer:
+ balance_batch: false
+ max_actor_ckpt_to_keep: 5
+ max_critic_ckpt_to_keep: 5
+synchronizer:
+ sync_method: checkpoint
+ sync_interval: 1
+
+monitor:
+ monitor_type: tensorboard
+```
+
+#### 启动训练后端环境
+启动 Explorer 和 Trainer 服务前需要启动 ray 集群
+```bash
+ray start --head
+```
+分别启动Explorer与Trainner服务。
+
+```bash
+trinity run --config explorer.yaml
+trinity run --config trainer.yaml
+```
+启动Explorer 服务后,会将服务地址打印在日志中,一般端口为8010
+### 配置在线训练与启动Agent
+
+#### 配置选项
+
+```java
+TrainingRunner trainingRunner = TrainingRunner.builder()
+ .trinityEndpoint(TRINITY_ENDPOINT) //Trinity Explorer服务地址
+ .modelName(TRAINING_MODEL_NAME)//对应Trinity配置中model_path
+ .selectionStrategy(new CustomStrategy())
+ .rewardCalculator(new CustomReward())
+ .commitIntervalSeconds(60*5)
+ .repeatTime(1)
+ .build();
+trainingRunner.start();
+```
+
+#### 完整示例
+
+```java
+import io.agentscope.core.training.runner.TrainingRunner;
+import io.agentscope.core.training.strategy.SamplingRateStrategy;
+
+// 1. 启动训练 runner(无需 Task ID/Run ID!)
+TrainingRunner runner = TrainingRunner.builder()
+ .trinityEndpoint("http://trinity-backend:8010")
+ .modelName("/path/to/qwen-model")
+ .selectionStrategy(SamplingRateStrategy.of(0.1)) // 10% 采样
+ .rewardCalculator(agent -> 0.0) // 自定义奖励计算逻辑
+ .commitIntervalSeconds(300) // 每 5 分钟 commit 一次
+ .build();
+
+runner.start();
+
+// 2. 正常使用你的 Agent - 完全无感知训练!
+ReActAgent agent = ReActAgent.builder()
+ .name("ProductionAgent")
+ .model(gpt4Model) // 生产模型 (GPT-4)
+ .tools(tools)
+ .build();
+
+// 用户请求正常处理(使用 GPT-4),自动采样10%请求用于训练
+Msg response = agent.call(Msg.userMsg("搜索 Python 教程")).block();
+
+// 3. 训练完成后停止
+runner.stop();
+```
+
+---