From 359dd191e3d58e5eb30e83cc0f69290dfa330a98 Mon Sep 17 00:00:00 2001 From: Arpit Bandejiya Date: Thu, 14 May 2026 20:57:58 +0530 Subject: [PATCH 1/3] Dump a initial change for coordinator Signed-off-by: Arpit Bandejiya --- .../schema/OpenSearchSchemaBuilder.java | 6 + .../spi/AnalyticsSearchBackendPlugin.java | 23 ++ .../exec/AnalyticsSearchService.java | 158 ++++++-- .../exec/AnalyticsSearchTransportService.java | 343 +++++++++++++++-- .../analytics/exec/QueryContext.java | 5 + .../analytics/exec/QueryScheduler.java | 1 + .../exec/action/FetchByRowIdsAction.java | 24 ++ .../action/FetchByRowIdsArrowResponse.java | 33 ++ .../exec/action/FetchByRowIdsRequest.java | 88 +++++ .../exec/action/FetchByRowIdsResponse.java | 50 +++ .../action/FragmentExecutionResponse.java | 73 ++++ .../LateMaterializationStageExecution.java | 337 +++++++++++++++++ .../LateMaterializationStageScheduler.java | 31 ++ .../analytics/exec/stage/ResponseCodec.java | 40 ++ .../exec/stage/RowResponseCodec.java | 112 ++++++ .../stage/ShardFragmentStageExecution.java | 116 ++++-- .../stage/ShardFragmentStageScheduler.java | 18 +- .../exec/stage/StageExecutionBuilder.java | 3 +- .../planner/FieldStorageResolver.java | 13 + .../analytics/planner/PlannerImpl.java | 19 +- .../analytics/planner/dag/DAGBuilder.java | 61 ++++ .../analytics/planner/dag/QueryDAG.java | 4 + .../analytics/planner/dag/Stage.java | 30 +- .../planner/dag/StageExecutionType.java | 8 +- .../rel/OpenSearchDistributionTraitDef.java | 6 +- .../rules/LateMaterializationRule.java | 226 ++++++++++++ .../ShardFragmentStageExecutionTests.java | 3 +- .../dsl/converter/ProjectConverter.java | 5 + .../analytics/qa/LateMaterializationIT.java | 344 ++++++++++++++++++ 29 files changed, 2083 insertions(+), 97 deletions(-) create mode 100644 sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsAction.java create mode 100644 
sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsArrowResponse.java create mode 100644 sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsRequest.java create mode 100644 sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsResponse.java create mode 100644 sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FragmentExecutionResponse.java create mode 100644 sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/LateMaterializationStageExecution.java create mode 100644 sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/LateMaterializationStageScheduler.java create mode 100644 sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ResponseCodec.java create mode 100644 sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/RowResponseCodec.java create mode 100644 sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/rules/LateMaterializationRule.java create mode 100644 sandbox/qa/analytics-engine-rest/src/test/java/org/opensearch/analytics/qa/LateMaterializationIT.java diff --git a/sandbox/libs/analytics-api/src/main/java/org/opensearch/analytics/schema/OpenSearchSchemaBuilder.java b/sandbox/libs/analytics-api/src/main/java/org/opensearch/analytics/schema/OpenSearchSchemaBuilder.java index ff5dcff67b604..d2781af2a437a 100644 --- a/sandbox/libs/analytics-api/src/main/java/org/opensearch/analytics/schema/OpenSearchSchemaBuilder.java +++ b/sandbox/libs/analytics-api/src/main/java/org/opensearch/analytics/schema/OpenSearchSchemaBuilder.java @@ -117,6 +117,12 @@ private static AbstractTable buildTable(Map properties) { public RelDataType getRowType(RelDataTypeFactory typeFactory) { RelDataTypeFactory.Builder builder = typeFactory.builder(); addLeafFields(builder, 
typeFactory, properties, ""); + // Virtual row ID column — always present in parquet files, computed by analytics backend. + // Only add if not already in the mapping. + if (!properties.containsKey("__row_id__")) { + builder.add("__row_id__", typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.BIGINT), true)); + } return builder.build(); } }; diff --git a/sandbox/libs/analytics-framework/src/main/java/org/opensearch/analytics/spi/AnalyticsSearchBackendPlugin.java b/sandbox/libs/analytics-framework/src/main/java/org/opensearch/analytics/spi/AnalyticsSearchBackendPlugin.java index 59f9b1f899f92..51e6c3a16044e 100644 --- a/sandbox/libs/analytics-framework/src/main/java/org/opensearch/analytics/spi/AnalyticsSearchBackendPlugin.java +++ b/sandbox/libs/analytics-framework/src/main/java/org/opensearch/analytics/spi/AnalyticsSearchBackendPlugin.java @@ -8,6 +8,10 @@ package org.opensearch.analytics.spi; +import org.apache.arrow.memory.BufferAllocator; +import org.opensearch.analytics.backend.EngineResultStream; +import org.opensearch.index.engine.exec.IndexReaderProvider; + import java.util.List; /** @@ -119,4 +123,23 @@ default void configureFilterDelegation(FilterDelegationHandle handle, BackendExe * Called after {@link #configureFilterDelegation}. Pass {@code null} to clear. */ default void setDelegationThreadTracker(DelegationThreadTracker tracker) {} + + /** + * QTF fetch phase: reads specific rows by global row ID. + * Row IDs are passed as a BigIntVector for zero-copy transfer to native. 
+ * + * @param reader the index reader for the target shard + * @param rowIdVector Arrow BigIntVector containing global row IDs + * @param columns column names to read + * @param allocator Arrow buffer allocator for result import + * @return a result stream containing the requested rows + */ + default EngineResultStream fetchByRowIds( + IndexReaderProvider.Reader reader, + org.apache.arrow.vector.BigIntVector rowIdVector, + String[] columns, + BufferAllocator allocator + ) { + throw new UnsupportedOperationException("fetchByRowIds not implemented for [" + name() + "]"); + } } diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchService.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchService.java index 3c389add9d00c..0847b7716e26a 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchService.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchService.java @@ -9,21 +9,31 @@ package org.opensearch.analytics.exec; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.Schema; import org.opensearch.analytics.backend.AnalyticsOperationListener; +import org.opensearch.analytics.backend.EngineResultBatch; import org.opensearch.analytics.backend.EngineResultStream; import org.opensearch.analytics.backend.SearchExecEngine; import org.opensearch.analytics.backend.ShardScanExecutionContext; import org.opensearch.analytics.exec.action.FragmentExecutionRequest; 
+import org.opensearch.analytics.exec.action.FragmentExecutionResponse; import org.opensearch.analytics.exec.task.AnalyticsShardTask; import org.opensearch.analytics.spi.AnalyticsSearchBackendPlugin; import org.opensearch.analytics.spi.BackendExecutionContext; import org.opensearch.analytics.spi.DelegationDescriptor; -import org.opensearch.analytics.spi.DelegationThreadTracker; import org.opensearch.analytics.spi.FilterDelegationHandle; import org.opensearch.analytics.spi.FragmentInstructionHandler; import org.opensearch.analytics.spi.FragmentInstructionHandlerFactory; import org.opensearch.analytics.spi.InstructionNode; import org.opensearch.arrow.flight.transport.ArrowAllocatorProvider; +import org.opensearch.common.Nullable; import org.opensearch.common.concurrent.GatedCloseable; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.tasks.TaskCancelledException; @@ -31,9 +41,11 @@ import org.opensearch.index.engine.exec.IndexReaderProvider.Reader; import org.opensearch.index.shard.IndexShard; import org.opensearch.tasks.Task; -import org.opensearch.tasks.TaskResourceTrackingService; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.nio.channels.Channels; +import java.util.Iterator; import java.util.List; import java.util.Map; @@ -60,7 +72,6 @@ public class AnalyticsSearchService implements AutoCloseable { private final AnalyticsOperationListener listener; private final BufferAllocator allocator; private final NamedWriteableRegistry namedWriteableRegistry; - private TaskResourceTrackingService taskResourceTrackingService; public AnalyticsSearchService(Map backends) { this(backends, List.of(), null); @@ -86,8 +97,78 @@ public void close() { allocator.close(); } - public void setTaskResourceTrackingService(TaskResourceTrackingService service) { - this.taskResourceTrackingService = service; + public FragmentExecutionResponse executeFragment(FragmentExecutionRequest request, IndexShard shard) { + 
return executeFragment(request, shard, null); + } + + public FragmentExecutionResponse executeFragment(FragmentExecutionRequest request, IndexShard shard, AnalyticsShardTask task) { + ResolvedFragment resolved = resolveFragment(request, shard); + long startNanos = System.nanoTime(); + try (FragmentResources ctx = startFragment(request, resolved, shard, task)) { + FragmentExecutionResponse response = collectResponse(ctx.stream(), task); + long tookNanos = System.nanoTime() - startNanos; + listener.onFragmentSuccess(resolved.queryId, resolved.stageId, resolved.shardIdStr, tookNanos, response.getRowCount()); + return response; + } catch (TaskCancelledException | IllegalStateException | IllegalArgumentException e) { + listener.onFragmentFailure(resolved.queryId, resolved.stageId, resolved.shardIdStr, e); + throw e; + } catch (Exception e) { + listener.onFragmentFailure(resolved.queryId, resolved.stageId, resolved.shardIdStr, e); + throw new RuntimeException("Failed to execute fragment on " + shard.shardId(), e); + } + } + + /** + * QTF fetch phase: read specific rows by global row ID. + * Bypasses Substrait plan resolution — calls directly into backend's FFM. 
+ */ + public org.opensearch.analytics.exec.action.FetchByRowIdsResponse executeFetchByRowIds( + org.opensearch.analytics.exec.action.FetchByRowIdsRequest request, + IndexShard shard, + AnalyticsShardTask task + ) { + long startNanos = System.nanoTime(); + String shardIdStr = shard.shardId().toString(); + try { + EngineResultStream stream = executeFetchStreaming(request, shard, task); + FragmentExecutionResponse fragmentResp = collectResponse(stream, task); + long tookNanos = System.nanoTime() - startNanos; + listener.onFragmentSuccess(request.getQueryId(), 0, shardIdStr, tookNanos, fragmentResp.getRowCount()); + return new org.opensearch.analytics.exec.action.FetchByRowIdsResponse(fragmentResp.getIpcPayload(), fragmentResp.getRowCount()); + } catch (Exception e) { + listener.onFragmentFailure(request.getQueryId(), 0, shardIdStr, e); + throw new RuntimeException("Failed to execute fetch-by-row-ids on " + shard.shardId(), e); + } + } + + /** + * Streaming variant: returns the raw EngineResultStream for the fetch phase. + * Used by the streaming transport handler to send Arrow batches directly. 
+ */ + public EngineResultStream executeFetchStreaming( + org.opensearch.analytics.exec.action.FetchByRowIdsRequest request, + IndexShard shard, + AnalyticsShardTask task + ) { + IndexReaderProvider readerProvider = shard.getReaderProvider(); + if (readerProvider == null) { + throw new IllegalStateException("No ReaderProvider on " + shard.shardId()); + } + try { + GatedCloseable gatedReader = readerProvider.acquireReader(); + long[] rowIds = request.getRowIds(); + org.apache.arrow.vector.BigIntVector rowIdVector = new org.apache.arrow.vector.BigIntVector("__row_id__", allocator); + rowIdVector.allocateNew(rowIds.length); + for (int i = 0; i < rowIds.length; i++) { + rowIdVector.set(i, rowIds[i]); + } + rowIdVector.setValueCount(rowIds.length); + + AnalyticsSearchBackendPlugin backend = backends.values().iterator().next(); + return backend.fetchByRowIds(gatedReader.get(), rowIdVector, request.getColumns(), allocator); + } catch (Exception e) { + throw new RuntimeException("Failed to start fetch-by-row-ids on " + shard.shardId(), e); + } } public FragmentResources executeFragmentStreaming(FragmentExecutionRequest request, IndexShard shard, AnalyticsShardTask task) { @@ -109,7 +190,6 @@ private FragmentResources startFragment(FragmentExecutionRequest request, Resolv SearchExecEngine engine = null; EngineResultStream stream = null; BackendExecutionContext backendContext = null; - Runnable trackerCleanup = null; try { ShardScanExecutionContext ctx = buildContext(request, gatedReader.get(), resolved.plan, shard, task); AnalyticsSearchBackendPlugin backend = backends.get(resolved.plan.getBackendId()); @@ -133,33 +213,14 @@ private FragmentResources startFragment(FragmentExecutionRequest request, Resolv AnalyticsSearchBackendPlugin acceptingBackend = backends.get(acceptingBackendId); FilterDelegationHandle handle = acceptingBackend.getFilterDelegationHandle(delegation.delegatedExpressions(), ctx); backend.configureFilterDelegation(handle, backendContext); - - if (task != 
null && taskResourceTrackingService != null) { - long taskId = task.getId(); - TaskResourceTrackingService service = taskResourceTrackingService; - backend.setDelegationThreadTracker(new DelegationThreadTracker() { - @Override - public long trackStart() { - long threadId = Thread.currentThread().threadId(); - service.taskExecutionStartedOnThread(taskId, threadId); - return threadId; - } - - @Override - public void trackEnd(long threadId) { - service.taskExecutionFinishedOnThread(taskId, threadId); - } - }); - trackerCleanup = () -> backend.setDelegationThreadTracker(null); - } } engine = backend.getSearchExecEngineProvider().createSearchExecEngine(ctx, backendContext); stream = engine.execute(ctx); - return new FragmentResources(gatedReader, engine, stream, trackerCleanup); + return new FragmentResources(gatedReader, engine, stream); } catch (Exception e) { try { - new FragmentResources(gatedReader, engine, stream, trackerCleanup).close(); + new FragmentResources(gatedReader, engine, stream).close(); } catch (Exception suppressed) { e.addSuppressed(suppressed); } @@ -226,4 +287,47 @@ private ShardScanExecutionContext buildContext( return ctx; } + FragmentExecutionResponse collectResponse(EngineResultStream stream) { + return collectResponse(stream, null); + } + + FragmentExecutionResponse collectResponse(EngineResultStream stream, @Nullable AnalyticsShardTask task) { + // Serialize incoming Arrow batches as an Arrow IPC stream: one schema header + // followed by one record-batch message per incoming batch. Arrow's own + // serializer handles every Arrow type — no per-type Java code path. 
+ ByteArrayOutputStream baos = new ByteArrayOutputStream(); + WriteChannel channel = new WriteChannel(Channels.newChannel(baos)); + Schema schema = null; + int totalRows = 0; + Iterator it = stream.iterator(); + try { + while (it.hasNext()) { + if (task != null && task.isCancelled()) { + throw new TaskCancelledException("task cancelled: " + task.getReasonCancelled()); + } + EngineResultBatch batch = it.next(); + VectorSchemaRoot root = batch.getArrowRoot(); + try { + if (schema == null) { + schema = root.getSchema(); + MessageSerializer.serialize(channel, schema); + } + try (ArrowRecordBatch recordBatch = new VectorUnloader(root).getRecordBatch()) { + MessageSerializer.serialize(channel, recordBatch); + } + totalRows += root.getRowCount(); + } finally { + root.close(); + } + } + if (schema != null) { + // Write the end-of-stream marker so the reader sees a clean EOS + // instead of hitting end-of-input mid-message. + ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT); + } + } catch (IOException e) { + throw new IllegalStateException("Failed to serialize fragment output as Arrow IPC stream", e); + } + return new FragmentExecutionResponse(baos.toByteArray(), totalRows); + } } diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchTransportService.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchTransportService.java index ba22057f7915c..e6204789593ed 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchTransportService.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchTransportService.java @@ -9,26 +9,33 @@ package org.opensearch.analytics.exec; import org.opensearch.analytics.backend.EngineResultBatch; +import org.opensearch.analytics.exec.action.FetchByRowIdsAction; +import org.opensearch.analytics.exec.action.FetchByRowIdsRequest; +import 
org.opensearch.analytics.exec.action.FetchByRowIdsResponse; import org.opensearch.analytics.exec.action.FragmentExecutionAction; import org.opensearch.analytics.exec.action.FragmentExecutionArrowResponse; import org.opensearch.analytics.exec.action.FragmentExecutionRequest; +import org.opensearch.analytics.exec.action.FragmentExecutionResponse; import org.opensearch.analytics.exec.task.AnalyticsShardTask; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Nullable; import org.opensearch.common.inject.Inject; import org.opensearch.common.inject.Singleton; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.index.shard.IndexShard; import org.opensearch.indices.IndicesService; import org.opensearch.ratelimitting.admissioncontrol.enums.AdmissionControlActionType; import org.opensearch.tasks.Task; -import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.StreamTransportService; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportResponseHandler; +import org.opensearch.transport.TransportService; import org.opensearch.transport.stream.StreamErrorCode; import org.opensearch.transport.stream.StreamException; import org.opensearch.transport.stream.StreamTransportResponse; @@ -37,42 +44,67 @@ import java.util.Iterator; /** - * Stateless transport dispatch component for fragment requests. Owns the - * {@link StreamTransportService} (analytics-engine is streaming-only) and + * Stateless transport dispatch component for fragment requests. Owns + * {@link TransportService} (or {@link StreamTransportService}) and * connection lookup. * - *

Does NOT track per-query or per-node concurrency state — callers provide - * their own {@link PendingExecutions} instance to gate dispatch concurrency. + *

Does NOT track per-query or per-node concurrency + * state — callers provide their own {@link PendingExecutions} instance + * to gate dispatch concurrency. * * @opensearch.internal */ @Singleton public class AnalyticsSearchTransportService { - private final StreamTransportService transportService; + private final TransportService transportService; private final ClusterService clusterService; + private final boolean streamingEnabled; @Inject public AnalyticsSearchTransportService( - StreamTransportService streamTransportService, + TransportService transportService, + @Nullable StreamTransportService streamTransportService, ClusterService clusterService, AnalyticsSearchService searchService, - IndicesService indicesService, - TaskResourceTrackingService taskResourceTrackingService + IndicesService indicesService ) { - if (streamTransportService == null) { - throw new IllegalStateException( - "analytics-engine requires the STREAM_TRANSPORT feature flag to be enabled " - + "(opensearch.experimental.feature.stream_transport.enabled=true)" - ); - } - searchService.setTaskResourceTrackingService(taskResourceTrackingService); - this.transportService = streamTransportService; + this.streamingEnabled = streamTransportService != null; + this.transportService = this.streamingEnabled ? 
streamTransportService : transportService; this.clusterService = clusterService; - registerStreamingFragmentHandler(this.transportService, searchService, indicesService); + if (this.streamingEnabled) { + registerStreamingFragmentHandler(this.transportService, searchService, indicesService); + } else { + registerFragmentHandler(this.transportService, searchService, indicesService); + } + registerFetchHandler(this.transportService, searchService, indicesService); + } + + public boolean isStreamingEnabled() { + return streamingEnabled; + } + + private static void registerFragmentHandler( + TransportService transportService, + AnalyticsSearchService searchService, + IndicesService indicesService + ) { + transportService.registerRequestHandler( + FragmentExecutionAction.NAME, + ThreadPool.Names.SAME, + false, + true, + AdmissionControlActionType.SEARCH, + FragmentExecutionRequest::new, + (request, channel, task) -> { + IndexShard shard = indicesService.indexServiceSafe(request.getShardId().getIndex()).getShard(request.getShardId().id()); + FragmentExecutionResponse response = searchService.executeFragment(request, shard, (AnalyticsShardTask) task); + channel.sendResponse(response); + } + ); } private static void registerStreamingFragmentHandler( - StreamTransportService transportService, + TransportService transportService, AnalyticsSearchService searchService, IndicesService indicesService ) { @@ -104,6 +136,214 @@ private static void registerStreamingFragmentHandler( ); } + private static void registerFetchHandler( + TransportService transportService, + AnalyticsSearchService searchService, + IndicesService indicesService + ) { + if (transportService instanceof StreamTransportService) { + // Streaming path: send Arrow batches directly + transportService.registerRequestHandler( + FetchByRowIdsAction.NAME, + ThreadPool.Names.SAME, + false, + true, + AdmissionControlActionType.SEARCH, + FetchByRowIdsRequest::new, + (request, channel, task) -> { + IndexShard shard = 
indicesService.indexServiceSafe(request.getShardId().getIndex()).getShard(request.getShardId().id()); + try { + org.opensearch.analytics.backend.EngineResultStream stream = searchService.executeFetchStreaming( + request, + shard, + (AnalyticsShardTask) task + ); + Iterator it = stream.iterator(); + while (it.hasNext()) { + org.opensearch.analytics.backend.EngineResultBatch batch = it.next(); + channel.sendResponseBatch( + new org.opensearch.analytics.exec.action.FetchByRowIdsArrowResponse(batch.getArrowRoot()) + ); + } + channel.completeStream(); + } catch (StreamException e) { + if (e.getErrorCode() != StreamErrorCode.CANCELLED) { + channel.sendResponse(e); + } + } catch (Exception e) { + channel.sendResponse(e); + } + } + ); + } else { + // Non-streaming path: serialize to IPC bytes + transportService.registerRequestHandler( + FetchByRowIdsAction.NAME, + ThreadPool.Names.SAME, + false, + true, + AdmissionControlActionType.SEARCH, + FetchByRowIdsRequest::new, + (request, channel, task) -> { + IndexShard shard = indicesService.indexServiceSafe(request.getShardId().getIndex()).getShard(request.getShardId().id()); + FetchByRowIdsResponse response = searchService.executeFetchByRowIds(request, shard, (AnalyticsShardTask) task); + channel.sendResponse(response); + } + ); + } + } + + public void dispatchFetch( + FetchByRowIdsRequest request, + DiscoveryNode targetNode, + StreamingResponseListener listener, + Task parentTask + ) { + if (streamingEnabled) { + dispatchFetchStreaming(request, targetNode, listener, parentTask); + } else { + dispatchFetchNonStreaming(request, targetNode, listener, parentTask); + } + } + + private void dispatchFetchStreaming( + FetchByRowIdsRequest request, + DiscoveryNode targetNode, + StreamingResponseListener listener, + Task parentTask + ) { + try { + Transport.Connection connection = getConnection(null, targetNode.getId()); + transportService.sendChildRequest( + connection, + FetchByRowIdsAction.NAME, + request, + parentTask, + 
TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STREAM).build(), + new TransportResponseHandler() { + @Override + public org.opensearch.analytics.exec.action.FetchByRowIdsArrowResponse read(StreamInput in) throws IOException { + return new org.opensearch.analytics.exec.action.FetchByRowIdsArrowResponse(in); + } + + @Override + public boolean skipsDeserialization() { + return true; + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + + @Override + public void handleStreamResponse( + StreamTransportResponse stream + ) { + try { + org.opensearch.analytics.exec.action.FetchByRowIdsArrowResponse current; + org.opensearch.analytics.exec.action.FetchByRowIdsArrowResponse last = null; + while ((current = stream.nextResponse()) != null) { + if (last != null) { + listener.onStreamResponse(wrapArrowAsResponse(last), false); + } + last = current; + } + if (last != null) { + listener.onStreamResponse(wrapArrowAsResponse(last), true); + } + } catch (Exception e) { + listener.onFailure(e); + } finally { + try { + stream.close(); + } catch (Exception ignore) {} + } + } + + @Override + public void handleResponse(org.opensearch.analytics.exec.action.FetchByRowIdsArrowResponse response) { + listener.onStreamResponse(wrapArrowAsResponse(response), true); + } + + @Override + public void handleException(TransportException e) { + listener.onFailure(e); + } + } + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private void dispatchFetchNonStreaming( + FetchByRowIdsRequest request, + DiscoveryNode targetNode, + StreamingResponseListener listener, + Task parentTask + ) { + try { + Transport.Connection connection = getConnection(null, targetNode.getId()); + transportService.sendChildRequest( + connection, + FetchByRowIdsAction.NAME, + request, + parentTask, + TransportRequestOptions.EMPTY, + new TransportResponseHandler() { + @Override + public FetchByRowIdsResponse read(StreamInput in) throws IOException { + return 
new FetchByRowIdsResponse(in); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + + @Override + public void handleResponse(FetchByRowIdsResponse response) { + listener.onStreamResponse(response, true); + } + + @Override + public void handleException(TransportException e) { + listener.onFailure(e); + } + } + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private static FetchByRowIdsResponse wrapArrowAsResponse(org.opensearch.analytics.exec.action.FetchByRowIdsArrowResponse arrowResp) { + // For the streaming path, wrap the Arrow batch as IPC bytes for uniform handling + // in LateMaterializationStageExecution's assembler. TODO: pass VectorSchemaRoot directly. + org.apache.arrow.vector.VectorSchemaRoot root = arrowResp.getRoot(); + if (root == null) return new FetchByRowIdsResponse(new byte[0], 0); + try { + java.io.ByteArrayOutputStream baos = new java.io.ByteArrayOutputStream(); + org.apache.arrow.vector.ipc.WriteChannel channel = new org.apache.arrow.vector.ipc.WriteChannel( + java.nio.channels.Channels.newChannel(baos) + ); + org.apache.arrow.vector.ipc.message.MessageSerializer.serialize(channel, root.getSchema()); + try ( + org.apache.arrow.vector.ipc.message.ArrowRecordBatch batch = new org.apache.arrow.vector.VectorUnloader(root) + .getRecordBatch() + ) { + org.apache.arrow.vector.ipc.message.MessageSerializer.serialize(channel, batch); + } + org.apache.arrow.vector.ipc.ArrowStreamWriter.writeEndOfStream(channel, org.apache.arrow.vector.ipc.message.IpcOption.DEFAULT); + return new FetchByRowIdsResponse(baos.toByteArray(), root.getRowCount()); + } catch (Exception e) { + throw new RuntimeException("Failed to serialize Arrow batch to IPC", e); + } finally { + root.close(); + } + } + Transport.Connection getConnection(String clusterAlias, String nodeId) { DiscoveryNode node = clusterService.state().nodes().get(nodeId); return transportService.getConnection(node); @@ -116,15 +356,56 @@ public void 
dispatchFragmentStreaming( Task parentTask, PendingExecutions pending ) { - TransportResponseHandler handler = new TransportResponseHandler<>() { + dispatchFragment( + request, + targetNode, + listener, + parentTask, + pending, + in -> new FragmentExecutionArrowResponse(in), + TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STREAM).build(), + true + ); + } + + public void dispatchFragment( + FragmentExecutionRequest request, + DiscoveryNode targetNode, + StreamingResponseListener listener, + Task parentTask, + PendingExecutions pending + ) { + dispatchFragment( + request, + targetNode, + listener, + parentTask, + pending, + in -> new FragmentExecutionResponse(in), + TransportRequestOptions.EMPTY, + false + ); + } + + private void dispatchFragment( + FragmentExecutionRequest request, + DiscoveryNode targetNode, + StreamingResponseListener listener, + Task parentTask, + PendingExecutions pending, + Writeable.Reader reader, + TransportRequestOptions options, + boolean skipsDeserialization + ) { + TransportResponseHandler handler = new TransportResponseHandler<>() { @Override - public FragmentExecutionArrowResponse read(StreamInput in) throws IOException { - return new FragmentExecutionArrowResponse(in); + public T read(StreamInput in) throws IOException { + return reader.read(in); } @Override public boolean skipsDeserialization() { - return true; + return skipsDeserialization; } @Override @@ -133,10 +414,10 @@ public String executor() { } @Override - public void handleStreamResponse(StreamTransportResponse stream) { + public void handleStreamResponse(StreamTransportResponse stream) { try { - FragmentExecutionArrowResponse current; - FragmentExecutionArrowResponse last = null; + T current; + T last = null; while ((current = stream.nextResponse()) != null) { if (last != null) { listener.onStreamResponse(last, false); @@ -157,12 +438,9 @@ public void handleStreamResponse(StreamTransportResponse { try { Transport.Connection connection = 
getConnection(null, targetNode.getId()); diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/QueryContext.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/QueryContext.java index cd279a6ba4301..47355a5891254 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/QueryContext.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/QueryContext.java @@ -44,6 +44,7 @@ public class QueryContext { private final List operationListeners; private volatile BufferAllocator bufferAllocator; private boolean closed; // guarded by `this` + private final List resolvedShardTargets = new java.util.ArrayList<>(); public QueryContext(QueryDAG dag, Executor searchExecutor, AnalyticsQueryTask parentTask) { this(dag, searchExecutor, parentTask, DEFAULT_MAX_CONCURRENT_SHARD_REQUESTS, DEFAULT_PER_QUERY_MEMORY_LIMIT, List.of()); @@ -144,6 +145,10 @@ public void closeBufferAllocator() { } } + public List getResolvedShardTargets() { + return resolvedShardTargets; + } + // ─── Test factories ──────────────────────────────────────────────── /** Creates a test context with a synchronous executor. 
*/ diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/QueryScheduler.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/QueryScheduler.java index a32b98c452b1b..e58b3375be6a5 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/QueryScheduler.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/QueryScheduler.java @@ -103,6 +103,7 @@ private PlanWalker createWalker( opListener.onQueryFailure(queryId, e); listener.onFailure(e); }); + return new PlanWalker(config, stageExecutionBuilder, wrapped); } diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsAction.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsAction.java new file mode 100644 index 0000000000000..0ca39a5e383c8 --- /dev/null +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsAction.java @@ -0,0 +1,24 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.analytics.exec.action; + +import org.opensearch.action.ActionType; + +/** + * Transport action for QTF fetch phase: fetches specific rows by global row ID. 
+ */ +public class FetchByRowIdsAction extends ActionType { + + public static final String NAME = "indices:data/read/analytics/fetch_by_row_ids"; + public static final FetchByRowIdsAction INSTANCE = new FetchByRowIdsAction(); + + private FetchByRowIdsAction() { + super(NAME, FetchByRowIdsResponse::new); + } +} diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsArrowResponse.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsArrowResponse.java new file mode 100644 index 0000000000000..eeef07f19d962 --- /dev/null +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsArrowResponse.java @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.analytics.exec.action; + +import org.apache.arrow.vector.VectorSchemaRoot; +import org.opensearch.arrow.flight.transport.ArrowBatchResponse; +import org.opensearch.core.common.io.stream.StreamInput; + +import java.io.IOException; + +/** + * Streaming Arrow response for the QTF fetch phase. + * Carries a single Arrow batch from the data node back to the coordinator + * via the streaming transport — zero-copy, no IPC serialization. 
+ *
+ * @opensearch.internal
+ */
+public class FetchByRowIdsArrowResponse extends ArrowBatchResponse {
+
+    /** Creates a response by handing {@code root} to the {@link ArrowBatchResponse} constructor. */
+    public FetchByRowIdsArrowResponse(VectorSchemaRoot root) {
+        super(root);
+    }
+
+    /** Stream constructor; reading from {@code in} is delegated entirely to {@link ArrowBatchResponse}. */
+    public FetchByRowIdsArrowResponse(StreamInput in) throws IOException {
+        super(in);
+    }
+}
diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsRequest.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsRequest.java
new file mode 100644
index 0000000000000..feb84b8e73038
--- /dev/null
+++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsRequest.java
@@ -0,0 +1,88 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.analytics.exec.action;
+
+import org.opensearch.action.ActionRequest;
+import org.opensearch.action.ActionRequestValidationException;
+import org.opensearch.analytics.exec.task.AnalyticsShardTask;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.core.index.shard.ShardId;
+import org.opensearch.core.tasks.TaskId;
+import org.opensearch.tasks.Task;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * Transport request for QTF fetch phase.
+ * Carries global row IDs and column names to the data node for targeted row retrieval.
+ */ +public class FetchByRowIdsRequest extends ActionRequest { + + private final String queryId; + private final ShardId shardId; + private final long[] rowIds; + private final String[] columns; + + public FetchByRowIdsRequest(String queryId, ShardId shardId, long[] rowIds, String[] columns) { + this.queryId = queryId; + this.shardId = shardId; + this.rowIds = rowIds; + this.columns = columns; + } + + public FetchByRowIdsRequest(StreamInput in) throws IOException { + super(in); + this.queryId = in.readString(); + this.shardId = new ShardId(in); + this.rowIds = in.readLongArray(); + this.columns = in.readStringArray(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(queryId); + shardId.writeTo(out); + out.writeLongArray(rowIds); + out.writeStringArray(columns); + } + + public String getQueryId() { + return queryId; + } + + public ShardId getShardId() { + return shardId; + } + + public long[] getRowIds() { + return rowIds; + } + + public String[] getColumns() { + return columns; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new AnalyticsShardTask(id, type, action, getDescription(), parentTaskId, headers); + } + + @Override + public String getDescription() { + return "fetch_by_row_ids{query=" + queryId + ", shard=" + shardId + ", rows=" + rowIds.length + "}"; + } +} diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsResponse.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsResponse.java new file mode 100644 index 0000000000000..6fc1f460315e6 --- /dev/null +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FetchByRowIdsResponse.java @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: 
Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.analytics.exec.action; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; + +/** + * Transport response for QTF fetch phase. + * Carries Arrow IPC stream bytes containing the fetched rows with __row_id__. + */ +public class FetchByRowIdsResponse extends ActionResponse { + + private final byte[] ipcPayload; + private final int rowCount; + + public FetchByRowIdsResponse(byte[] ipcPayload, int rowCount) { + this.ipcPayload = ipcPayload; + this.rowCount = rowCount; + } + + public FetchByRowIdsResponse(StreamInput in) throws IOException { + super(in); + this.ipcPayload = in.readByteArray(); + this.rowCount = in.readVInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeByteArray(ipcPayload); + out.writeVInt(rowCount); + } + + public byte[] getIpcPayload() { + return ipcPayload; + } + + public int getRowCount() { + return rowCount; + } +} diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FragmentExecutionResponse.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FragmentExecutionResponse.java new file mode 100644 index 0000000000000..d54c9c6b04469 --- /dev/null +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/action/FragmentExecutionResponse.java @@ -0,0 +1,73 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.analytics.exec.action; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import java.io.IOException; + +/** + * Transport response carrying the output of a shard fragment execution as an + * Arrow IPC stream payload (schema header + zero or more record-batch messages, + * produced by {@link org.apache.arrow.vector.ipc.ArrowStreamWriter}). + * + *

Arrow IPC handles every Arrow type natively (temporal, string-view, + * dictionary, nested) without hand-rolled per-type serialization. Previously, + * this response carried {@code List} rows and relied on + * {@code StreamOutput.writeGenericValue} — which does not support Java 8+ + * temporal types like {@link java.time.LocalDateTime} and so failed the moment + * a shard emitted a batch with a Timestamp column. + * + *

Wire format: {@code ipcPayload (byte[]) + rowCount (vint)}. The row count + * is the total across all batches in the payload, cached for metrics / logging + * so consumers don't have to decode the payload just to report "N rows handled". + * + * @opensearch.internal + */ +public class FragmentExecutionResponse extends ActionResponse { + + private final byte[] ipcPayload; + private final int rowCount; + + public FragmentExecutionResponse(byte[] ipcPayload, int rowCount) { + this.ipcPayload = ipcPayload; + this.rowCount = rowCount; + } + + public FragmentExecutionResponse(StreamInput in) throws IOException { + super(in); + this.ipcPayload = in.readByteArray(); + this.rowCount = in.readVInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeByteArray(ipcPayload); + out.writeVInt(rowCount); + } + + /** + * Arrow IPC stream bytes — a schema message followed by zero or more record + * batch messages, as written by {@link org.apache.arrow.vector.ipc.ArrowStreamWriter}. + * An empty array means the fragment produced no output at all (no schema, + * no rows). + */ + public byte[] getIpcPayload() { + return ipcPayload; + } + + /** + * Total number of rows across all batches in {@link #getIpcPayload()}. + */ + public int getRowCount() { + return rowCount; + } +} diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/LateMaterializationStageExecution.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/LateMaterializationStageExecution.java new file mode 100644 index 0000000000000..148d071e4353b --- /dev/null +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/LateMaterializationStageExecution.java @@ -0,0 +1,337 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.analytics.exec.stage; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.analytics.backend.ExchangeSource; +import org.opensearch.analytics.exec.AnalyticsSearchTransportService; +import org.opensearch.analytics.exec.QueryContext; +import org.opensearch.analytics.exec.StreamingResponseListener; +import org.opensearch.analytics.exec.action.FetchByRowIdsRequest; +import org.opensearch.analytics.exec.action.FetchByRowIdsResponse; +import org.opensearch.analytics.planner.dag.ShardExecutionTarget; +import org.opensearch.analytics.planner.dag.Stage; +import org.opensearch.analytics.spi.ExchangeSink; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +/** + * Late materialization (QTF) root stage execution. + * + *

Lifecycle: started by PlanWalker after the child COORDINATOR_REDUCE stage completes. + * Reads the reduced output (sorted + limited rows with __row_id__ + shard_id), + * builds a position map, dispatches fetch requests per shard, assembles the final + * globally-sorted result, and feeds it into the output sink. + * + * @opensearch.internal + */ +public final class LateMaterializationStageExecution extends AbstractStageExecution + implements + DataProducer, + org.opensearch.analytics.spi.DataConsumer { + + private static final Logger logger = LogManager.getLogger(LateMaterializationStageExecution.class); + + private final QueryContext config; + private final ExchangeSink outputSink; + private final AnalyticsSearchTransportService dispatcher; + private final org.opensearch.analytics.exec.RowProducingSink inputSink; + + LateMaterializationStageExecution( + Stage stage, + QueryContext config, + ExchangeSink outputSink, + AnalyticsSearchTransportService dispatcher + ) { + super(stage); + this.config = config; + this.outputSink = outputSink; + this.dispatcher = dispatcher; + this.inputSink = new org.opensearch.analytics.exec.RowProducingSink(); + } + + @Override + public ExchangeSink inputSink(int childStageId) { + return inputSink; + } + + @Override + public void start() { + if (transitionTo(StageExecution.State.RUNNING) == false) return; + + Iterable reducedResult = inputSink.readResult(); + BufferAllocator allocator = config.bufferAllocator(); + List shardTargets = config.getResolvedShardTargets(); + + VectorSchemaRoot firstBatch = reducedResult.iterator().hasNext() ? 
reducedResult.iterator().next() : null; + if (firstBatch == null || firstBatch.getVector("__row_id__") == null || firstBatch.getVector("shard_id") == null) { + // Not a QTF result — pass through unchanged + for (VectorSchemaRoot batch : reducedResult) { + outputSink.feed(batch); + } + transitionTo(StageExecution.State.SUCCEEDED); + return; + } + + // Derive fetch columns from schema (exclude __row_id__ and shard_id) + String[] fetchColumns = firstBatch.getSchema() + .getFields() + .stream() + .map(Field::getName) + .filter(name -> !"__row_id__".equals(name) && !"shard_id".equals(name)) + .toArray(String[]::new); + + // Build position map + PositionMap positionMap = buildPositionMap(reducedResult); + logger.info("[LateMat] Position map built: totalRows={}, shards={}", positionMap.totalRows(), positionMap.shardCount()); + + // Close reduced batches — we've extracted what we need + for (VectorSchemaRoot batch : reducedResult) { + batch.close(); + } + + if (positionMap.totalRows() == 0) { + transitionTo(StageExecution.State.SUCCEEDED); + return; + } + + // Dispatch fetches + dispatchFetches(positionMap, fetchColumns, shardTargets, allocator); + } + + @Override + public void cancel(String reason) { + transitionTo(StageExecution.State.CANCELLED); + } + + @Override + public ExchangeSource outputSource() { + if (outputSink instanceof ExchangeSource source) { + return source; + } + throw new UnsupportedOperationException("outputSink does not implement ExchangeSource"); + } + + // ── Position Map ───────────────────────────────────────────────────────────── + + private PositionMap buildPositionMap(Iterable reducedResult) { + PositionMap map = new PositionMap(); + int pos = 0; + for (VectorSchemaRoot batch : reducedResult) { + FieldVector rowIdRaw = batch.getVector("__row_id__"); + IntVector shardIdCol = (IntVector) batch.getVector("shard_id"); + for (int i = 0; i < batch.getRowCount(); i++) { + int shard = shardIdCol.get(i); + long rowId; + if (rowIdRaw instanceof 
BigIntVector bigInt) { + rowId = bigInt.get(i); + } else { + rowId = ((Number) rowIdRaw.getObject(i)).longValue(); + } + map.put(shard, rowId, pos); + pos++; + } + } + return map; + } + + // ── Fetch Dispatch ─────────────────────────────────────────────────────────── + + private void dispatchFetches( + PositionMap positionMap, + String[] fetchColumns, + List shardTargets, + BufferAllocator allocator + ) { + Map fetchPlan = positionMap.getPerShardFetchPlan(); + AtomicInteger remaining = new AtomicInteger(fetchPlan.size()); + List fetchResults = java.util.Collections.synchronizedList(new ArrayList<>()); + + for (Map.Entry entry : fetchPlan.entrySet()) { + int shardOrdinal = entry.getKey(); + long[] rowIds = entry.getValue(); + + if (shardOrdinal >= shardTargets.size()) { + captureFailure( + new IllegalStateException("[LateMat] Shard ordinal " + shardOrdinal + " exceeds target count " + shardTargets.size()) + ); + transitionTo(StageExecution.State.FAILED); + return; + } + + ShardExecutionTarget target = shardTargets.get(shardOrdinal); + FetchByRowIdsRequest fetchReq = new FetchByRowIdsRequest(config.queryId(), target.shardId(), rowIds, fetchColumns); + + dispatcher.dispatchFetch( + fetchReq, + target.node(), + new FetchResponseListener(shardOrdinal, positionMap, fetchResults, remaining, allocator), + config.parentTask() + ); + } + } + + // ── Assembly ───────────────────────────────────────────────────────────────── + + private class FetchResponseListener implements StreamingResponseListener { + private final int shardOrdinal; + private final PositionMap positionMap; + private final List fetchResults; + private final AtomicInteger remaining; + private final BufferAllocator allocator; + + FetchResponseListener( + int shardOrdinal, + PositionMap positionMap, + List fetchResults, + AtomicInteger remaining, + BufferAllocator allocator + ) { + this.shardOrdinal = shardOrdinal; + this.positionMap = positionMap; + this.fetchResults = fetchResults; + this.remaining = 
remaining; + this.allocator = allocator; + } + + @Override + public void onStreamResponse(FetchByRowIdsResponse response, boolean isLast) { + fetchResults.add(new FetchResult(shardOrdinal, response)); + if (isLast && remaining.decrementAndGet() == 0) { + assembleAndDeliver(); + } + } + + @Override + public void onFailure(Exception e) { + captureFailure(e); + transitionTo(StageExecution.State.FAILED); + } + + private void assembleAndDeliver() { + try { + VectorSchemaRoot assembled = assembleResult(fetchResults, positionMap, allocator); + outputSink.feed(assembled); + transitionTo(StageExecution.State.SUCCEEDED); + } catch (Exception e) { + captureFailure(e); + transitionTo(StageExecution.State.FAILED); + } + } + } + + private VectorSchemaRoot assembleResult(List fetchResults, PositionMap positionMap, BufferAllocator allocator) { + int totalRows = positionMap.totalRows(); + VectorSchemaRoot output = null; + List outputFields = null; + + for (FetchResult fr : fetchResults) { + byte[] ipc = fr.response().getIpcPayload(); + if (ipc == null || ipc.length == 0) continue; + int shardOrdinal = fr.shardOrdinal(); + + try (var reader = new org.apache.arrow.vector.ipc.ArrowStreamReader(new java.io.ByteArrayInputStream(ipc), allocator)) { + while (reader.loadNextBatch()) { + VectorSchemaRoot batch = reader.getVectorSchemaRoot(); + int batchRows = batch.getRowCount(); + + if (output == null) { + outputFields = batch.getSchema().getFields().stream().filter(f -> !"__row_id__".equals(f.getName())).toList(); + output = VectorSchemaRoot.create(new Schema(outputFields), allocator); + output.allocateNew(); + output.setRowCount(totalRows); + for (FieldVector v : output.getFieldVectors()) { + v.setValueCount(totalRows); + } + } + + FieldVector rowIdRaw = batch.getVector("__row_id__"); + for (int i = 0; i < batchRows; i++) { + long rowId; + if (rowIdRaw instanceof BigIntVector bigInt) { + rowId = bigInt.get(i); + } else { + rowId = ((Number) rowIdRaw.getObject(i)).longValue(); + } + int 
destPos = positionMap.getPosition(shardOrdinal, rowId); + for (Field f : outputFields) { + FieldVector src = batch.getVector(f.getName()); + FieldVector dst = output.getVector(f.getName()); + dst.copyFrom(i, destPos, src); + } + } + } + } catch (Exception e) { + if (output != null) output.close(); + throw new RuntimeException("[LateMat] Failed to decode fetch response", e); + } + } + + if (output == null) { + return VectorSchemaRoot.create(new Schema(List.of()), allocator); + } + return output; + } + + // ── Supporting types ───────────────────────────────────────────────────────── + + private record FetchResult(int shardOrdinal, FetchByRowIdsResponse response) { + } + + static class PositionMap { + private final Map positionLookup = new HashMap<>(); + private final Map> perShardRowIds = new HashMap<>(); + private int totalRows = 0; + + void put(int shard, long rowId, int position) { + positionLookup.put(encode(shard, rowId), position); + perShardRowIds.computeIfAbsent(shard, k -> new ArrayList<>()).add(rowId); + totalRows++; + } + + int getPosition(int shard, long rowId) { + Integer pos = positionLookup.get(encode(shard, rowId)); + if (pos == null) { + throw new IllegalStateException("[LateMat] No position for shard=" + shard + " rowId=" + rowId); + } + return pos; + } + + Map getPerShardFetchPlan() { + return perShardRowIds.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().stream().mapToLong(Long::longValue).toArray())); + } + + int totalRows() { + return totalRows; + } + + int shardCount() { + return perShardRowIds.size(); + } + + private static long encode(int shard, long rowId) { + return ((long) shard << 40) | rowId; + } + } +} diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/LateMaterializationStageScheduler.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/LateMaterializationStageScheduler.java new file mode 100644 index 
0000000000000..b05dbddd30b91
--- /dev/null
+++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/LateMaterializationStageScheduler.java
@@ -0,0 +1,31 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.analytics.exec.stage;
+
+import org.opensearch.analytics.exec.AnalyticsSearchTransportService;
+import org.opensearch.analytics.exec.QueryContext;
+import org.opensearch.analytics.planner.dag.Stage;
+import org.opensearch.analytics.spi.ExchangeSink;
+
+/**
+ * Creates a {@link LateMaterializationStageExecution} for the QTF fetch + assembly phase.
+ */
+final class LateMaterializationStageScheduler implements StageScheduler {
+
+    // Handed unchanged to every execution this scheduler creates.
+    private final AnalyticsSearchTransportService transport;
+
+    LateMaterializationStageScheduler(AnalyticsSearchTransportService transport) {
+        this.transport = transport;
+    }
+
+    /**
+     * Builds a new execution for {@code stage}. Note the argument order differs from this
+     * method's own: the execution constructor takes {@code config} before {@code sink}.
+     */
+    @Override
+    public StageExecution createExecution(Stage stage, ExchangeSink sink, QueryContext config) {
+        return new LateMaterializationStageExecution(stage, config, sink, transport);
+    }
+}
diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ResponseCodec.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ResponseCodec.java
new file mode 100644
index 0000000000000..528b3a93e2b1f
--- /dev/null
+++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ResponseCodec.java
@@ -0,0 +1,40 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */ + +package org.opensearch.analytics.exec.stage; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.opensearch.core.action.ActionResponse; + +/** + * Decodes a transport response into an Arrow {@link VectorSchemaRoot} for + * the coordinator-side sink. Implementations handle the specific wire + * format — {@code Object[]} rows (current), Arrow IPC (Flight), or any + * future format. + * + *

The codec is injected into {@link ShardFragmentStageExecution} at + * construction time by the scheduler. Swapping the codec swaps the + * serialization format without touching stage execution logic. + * + * @param the transport response type + * @opensearch.internal + */ +@FunctionalInterface +public interface ResponseCodec { + + /** + * Decodes a transport response into an Arrow {@link VectorSchemaRoot}. + * The returned VSR is owned by the caller (the sink). + * + * @param response the transport response + * @param allocator the buffer allocator for Arrow vectors + * @return a new VectorSchemaRoot; caller owns and must close it + */ + VectorSchemaRoot decode(R response, BufferAllocator allocator); +} diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/RowResponseCodec.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/RowResponseCodec.java new file mode 100644 index 0000000000000..38aef31c1337b --- /dev/null +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/RowResponseCodec.java @@ -0,0 +1,112 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.analytics.exec.stage; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.types.pojo.Schema; +import org.opensearch.analytics.exec.action.FragmentExecutionResponse; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * {@link ResponseCodec} that deserializes the Arrow IPC stream payload carried + * by {@link FragmentExecutionResponse} into a single consolidated + * {@link VectorSchemaRoot}. Uses Arrow's {@link ArrowStreamReader} for message + * sequencing (schema header, record batches, end-of-stream) and + * {@link FieldVector#makeTransferPair} / {@link FieldVector#copyFromSafe} to + * move data into a caller-owned root — both are supported by every vector + * kind, including the view vectors ({@code Utf8View}, {@code BinaryView}) that + * DataFusion emits for aggregate group keys. + * + *

Deliberately avoids {@code VectorSchemaRootAppender} — its underlying + * {@code VectorAppender} rejects view vectors with + * {@code UnsupportedOperationException}. + * + * @opensearch.internal + */ +public final class RowResponseCodec implements ResponseCodec { + + /** Singleton instance — stateless, thread-safe. */ + public static final RowResponseCodec INSTANCE = new RowResponseCodec(); + + private RowResponseCodec() {} + + @Override + public VectorSchemaRoot decode(FragmentExecutionResponse response, BufferAllocator allocator) { + if (allocator == null) { + throw new IllegalArgumentException("BufferAllocator must not be null"); + } + byte[] payload = response.getIpcPayload(); + if (payload == null || payload.length == 0) { + return VectorSchemaRoot.create(new Schema(List.of()), allocator); + } + + List batches = new ArrayList<>(); + Schema schema; + try (ArrowStreamReader reader = new ArrowStreamReader(new ByteArrayInputStream(payload), allocator)) { + VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot(); + schema = readerRoot.getSchema(); + while (reader.loadNextBatch()) { + // Transfer each batch's buffers out of the reader's reused root into an + // independent root owned by this codec. Transfer works for all vector + // kinds (including views) and is zero-copy when allocators match. 
+ VectorSchemaRoot batchRoot = VectorSchemaRoot.create(schema, allocator); + int rowCount = readerRoot.getRowCount(); + for (int i = 0; i < readerRoot.getFieldVectors().size(); i++) { + readerRoot.getVector(i).makeTransferPair(batchRoot.getVector(i)).transfer(); + } + batchRoot.setRowCount(rowCount); + batches.add(batchRoot); + } + } catch (IOException e) { + for (VectorSchemaRoot b : batches) + b.close(); + throw new IllegalStateException("Failed to decode Arrow IPC payload from fragment response", e); + } + + if (batches.isEmpty()) { + return VectorSchemaRoot.create(schema, allocator); + } + if (batches.size() == 1) { + return batches.get(0); + } + // Multiple batches — concatenate via per-cell copyFromSafe. Slower than columnar + // append but is the only operation Arrow Java implements for every vector type + // (VectorAppender throws on view vectors; see class javadoc). + int totalRows = batches.stream().mapToInt(VectorSchemaRoot::getRowCount).sum(); + VectorSchemaRoot combined = VectorSchemaRoot.create(schema, allocator); + try { + combined.allocateNew(); + for (int f = 0; f < combined.getFieldVectors().size(); f++) { + FieldVector dst = combined.getVector(f); + int offset = 0; + for (VectorSchemaRoot batch : batches) { + FieldVector src = batch.getVector(f); + int rows = batch.getRowCount(); + for (int r = 0; r < rows; r++) { + dst.copyFromSafe(r, offset + r, src); + } + offset += rows; + } + dst.setValueCount(totalRows); + } + combined.setRowCount(totalRows); + return combined; + } finally { + for (VectorSchemaRoot b : batches) + b.close(); + } + } +} diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageExecution.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageExecution.java index 83ae3dd0b4577..e2bc303a59b3e 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageExecution.java 
+++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageExecution.java @@ -16,11 +16,14 @@ import org.opensearch.analytics.exec.StreamingResponseListener; import org.opensearch.analytics.exec.action.FragmentExecutionArrowResponse; import org.opensearch.analytics.exec.action.FragmentExecutionRequest; +import org.opensearch.analytics.exec.action.FragmentExecutionResponse; import org.opensearch.analytics.planner.dag.ExecutionTarget; import org.opensearch.analytics.planner.dag.ShardExecutionTarget; import org.opensearch.analytics.planner.dag.Stage; import org.opensearch.analytics.spi.ExchangeSink; +import org.opensearch.arrow.flight.transport.ArrowBatchResponse; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.action.ActionResponse; import java.util.List; import java.util.Map; @@ -29,16 +32,17 @@ import java.util.function.Function; /** - * Leaf stage execution that dispatches fragment work to data-node shards via - * Arrow streaming, feeding resulting batches into the parent stage's - * {@link ExchangeSink}. + * Leaf stage execution that dispatches fragment work to data-node shards. + * + *

Handles both Arrow streaming and row (codec-decoded) responses, feeding + * resulting batches into the parent stage's {@link ExchangeSink}. * *

One-shot: constructed, {@link #start()} called once, listener * signaled on completion, then discarded. * * @opensearch.internal */ -final class ShardFragmentStageExecution extends AbstractStageExecution implements DataProducer { +public final class ShardFragmentStageExecution extends AbstractStageExecution implements DataProducer { private final AtomicInteger inFlight = new AtomicInteger(0); @@ -47,15 +51,20 @@ final class ShardFragmentStageExecution extends AbstractStageExecution implement private final ClusterService clusterService; private final Function requestBuilder; private final AnalyticsSearchTransportService dispatcher; + private final ResponseCodec responseCodec; private final Map pendingPerNode = new ConcurrentHashMap<>(); + // QTF: ordinal → target mapping so coordinator can dispatch fetches to the right shard/node. + private final List resolvedTargets = new java.util.ArrayList<>(); + ShardFragmentStageExecution( Stage stage, QueryContext config, ExchangeSink outputSink, ClusterService clusterService, Function requestBuilder, - AnalyticsSearchTransportService dispatcher + AnalyticsSearchTransportService dispatcher, + ResponseCodec responseCodec ) { super(stage); this.config = config; @@ -63,6 +72,11 @@ final class ShardFragmentStageExecution extends AbstractStageExecution implement this.clusterService = clusterService; this.requestBuilder = requestBuilder; this.dispatcher = dispatcher; + this.responseCodec = responseCodec; + } + + private boolean useArrowStreaming() { + return dispatcher.isStreamingEnabled(); } @Override @@ -75,43 +89,58 @@ public void start() { if (transitionTo(StageExecution.State.RUNNING) == false) return; inFlight.set(resolved.size()); for (ExecutionTarget target : resolved) { - dispatchShardTask((ShardExecutionTarget) target); + resolvedTargets.add((ShardExecutionTarget) target); + } + // Populate context targets BEFORE dispatch (local dispatch is synchronous) + if (stage.isInjectShardOrdinal()) { + 
config.getResolvedShardTargets().addAll(resolvedTargets); + } + for (int i = 0; i < resolvedTargets.size(); i++) { + dispatchShardTask(resolvedTargets.get(i), i); } } - private void dispatchShardTask(ShardExecutionTarget target) { + private void dispatchShardTask(ShardExecutionTarget target, int shardOrdinal) { FragmentExecutionRequest request = requestBuilder.apply(target); PendingExecutions pending = pendingFor(target); - dispatcher.dispatchFragmentStreaming(request, target.node(), responseListener(), config.parentTask(), pending); + if (useArrowStreaming()) { + dispatcher.dispatchFragmentStreaming( + request, + target.node(), + responseListener(FragmentExecutionArrowResponse::getRoot, shardOrdinal), + config.parentTask(), + pending + ); + } else { + dispatcher.dispatchFragment( + request, + target.node(), + responseListener(r -> responseCodec.decode(r, config.bufferAllocator()), shardOrdinal), + config.parentTask(), + pending + ); + } } - private StreamingResponseListener responseListener() { + private StreamingResponseListener responseListener( + Function toVsr, + int shardOrdinal + ) { return new StreamingResponseListener<>() { - // Runs inline on the per-stream virtual thread driving handleStreamResponse. - // Must NOT offload to a thread pool: reordering across batches would let the - // isLast=true task race ahead, flip state to SUCCEEDED, and drop queued - // earlier batches via the isDone() short-circuit. 
@Override - public void onStreamResponse(FragmentExecutionArrowResponse response, boolean isLast) { + public void onStreamResponse(T response, boolean isLast) { if (isDone()) { - VectorSchemaRoot root = response.getRoot(); - if (root != null) { - root.close(); - } + releaseResponseResources(response); return; } - VectorSchemaRoot vsr = response.getRoot(); - try { - outputSink.feed(vsr); - } catch (Exception e) { - // Without this guard the exception only surfaces on the stream's virtual - // thread; inFlight never decrements and the stage hangs to QUERY_TIMEOUT. - captureFailure(new RuntimeException("Stage " + stage.getStageId() + " sink feed failed", e)); - metrics.incrementTasksFailed(); - onShardTerminated(); - return; + VectorSchemaRoot vsr = toVsr.apply(response); + + if (stage.isInjectShardOrdinal()) { + vsr = injectShardId(vsr, shardOrdinal); } + + outputSink.feed(vsr); metrics.addRowsProcessed(vsr.getRowCount()); if (isLast) { @@ -129,6 +158,12 @@ public void onFailure(Exception e) { }; } + private static void releaseResponseResources(T response) { + if (response instanceof ArrowBatchResponse arrowResp && arrowResp.getRoot() != null) { + arrowResp.getRoot().close(); + } + } + private void onShardTerminated() { if (inFlight.decrementAndGet() == 0) { Exception captured = getFailure(); @@ -154,6 +189,11 @@ public ExchangeSource outputSource() { throw new UnsupportedOperationException("outputSink does not implement ExchangeSource"); } + /** QTF: returns the ordered list of shard targets for fetch dispatch. 
*/ + public List getResolvedTargets() { + return java.util.Collections.unmodifiableList(resolvedTargets); + } + private boolean isDone() { StageExecution.State s = getState(); return s == StageExecution.State.SUCCEEDED || s == StageExecution.State.FAILED || s == StageExecution.State.CANCELLED; @@ -162,4 +202,24 @@ private boolean isDone() { private PendingExecutions pendingFor(ShardExecutionTarget target) { return pendingPerNode.computeIfAbsent(target.node().getId(), n -> new PendingExecutions(config.maxConcurrentShardRequests())); } + + /** + * QTF: Inject a shard_id column into the Arrow batch so the coordinator + * can track which shard each row came from after the reduce merge. + */ + private static VectorSchemaRoot injectShardId(VectorSchemaRoot batch, int shardId) { + org.apache.arrow.vector.IntVector shardIdVector = new org.apache.arrow.vector.IntVector( + "shard_id", + batch.getFieldVectors().get(0).getAllocator() + ); + shardIdVector.allocateNew(batch.getRowCount()); + for (int i = 0; i < batch.getRowCount(); i++) { + shardIdVector.set(i, shardId); + } + shardIdVector.setValueCount(batch.getRowCount()); + + java.util.List vectors = new java.util.ArrayList<>(batch.getFieldVectors()); + vectors.add(shardIdVector); + return new VectorSchemaRoot(vectors); + } } diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageScheduler.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageScheduler.java index dd120de7b4c6d..da616cfdce341 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageScheduler.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageScheduler.java @@ -11,6 +11,7 @@ import org.opensearch.analytics.exec.AnalyticsSearchTransportService; import org.opensearch.analytics.exec.QueryContext; import 
org.opensearch.analytics.exec.action.FragmentExecutionRequest; +import org.opensearch.analytics.exec.action.FragmentExecutionResponse; import org.opensearch.analytics.planner.dag.ShardExecutionTarget; import org.opensearch.analytics.planner.dag.Stage; import org.opensearch.analytics.planner.dag.StagePlan; @@ -30,16 +31,31 @@ * and doesn't care whether it is a root sink or a parent-provided child sink * — {@link StageExecutionBuilder} resolves that distinction before calling. * + *

Injects a {@link ResponseCodec} into the execution to decouple the wire + * format from stage logic. The default codec ({@link RowResponseCodec}) handles + * the current {@code Object[]} row format; a future Arrow IPC codec would be + * swapped in here. + * * @opensearch.internal */ final class ShardFragmentStageScheduler implements StageScheduler { private final ClusterService clusterService; private final AnalyticsSearchTransportService transport; + private final ResponseCodec responseCodec; ShardFragmentStageScheduler(ClusterService clusterService, AnalyticsSearchTransportService transport) { + this(clusterService, transport, RowResponseCodec.INSTANCE); + } + + ShardFragmentStageScheduler( + ClusterService clusterService, + AnalyticsSearchTransportService transport, + ResponseCodec responseCodec + ) { this.clusterService = clusterService; this.transport = transport; + this.responseCodec = responseCodec; } @Override @@ -57,7 +73,7 @@ public StageExecution createExecution(Stage stage, ExchangeSink sink, QueryConte // This keeps target resolution out of the build phase so cancellation before // dispatch doesn't pay for cluster-state routing, and leaves room for shuffle // reads whose targets depend on child manifests only available at dispatch time. 
- return new ShardFragmentStageExecution(stage, config, sink, clusterService, requestBuilder, transport); + return new ShardFragmentStageExecution(stage, config, sink, clusterService, requestBuilder, transport, responseCodec); } private static List buildPlanAlternatives(Stage stage) { diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/StageExecutionBuilder.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/StageExecutionBuilder.java index cf5d907717161..355902ad3efb1 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/StageExecutionBuilder.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/StageExecutionBuilder.java @@ -57,6 +57,7 @@ public StageExecutionBuilder(ClusterService clusterService, AnalyticsSearchTrans registerScheduler(StageExecutionType.SHARD_FRAGMENT, new ShardFragmentStageScheduler(clusterService, dispatcher)); registerScheduler(StageExecutionType.COORDINATOR_REDUCE, new LocalStageScheduler()); registerScheduler(StageExecutionType.LOCAL_PASSTHROUGH, (stage, sink, config) -> new PassThroughStageExecution(stage, sink)); + registerScheduler(StageExecutionType.LATE_MATERIALIZATION, new LateMaterializationStageScheduler(dispatcher)); } /** @@ -88,7 +89,7 @@ public StageExecution buildRootExecution(Stage rootStage, QueryContext config) { */ public StageExecution buildExecution(Stage stage, StageExecution parentExec, QueryContext config) { ExchangeSink sink = switch (stage.getExecutionType()) { - case SHARD_FRAGMENT, COORDINATOR_REDUCE, LOCAL_PASSTHROUGH -> resolveRowSink(stage, parentExec); + case SHARD_FRAGMENT, COORDINATOR_REDUCE, LOCAL_PASSTHROUGH, LATE_MATERIALIZATION -> resolveRowSink(stage, parentExec); }; return buildStageExecution(stage, sink, config); } diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/FieldStorageResolver.java 
b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/FieldStorageResolver.java index 2c4bad3a9b866..c09f976069105 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/FieldStorageResolver.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/FieldStorageResolver.java @@ -69,6 +69,19 @@ public FieldStorageResolver(IndexMetadata indexMetadata) { this.fieldStorage = new HashMap<>(); populateFromProperties(properties, "", primaryFormat); + // Virtual row ID column — always in parquet, computed by analytics backend. + this.fieldStorage.put( + "__row_id__", + new FieldStorageInfo( + "__row_id__", + "long", + FieldType.fromMappingType("long"), + List.of(primaryFormat), + List.of(), + List.of(), + false + ) + ); } @SuppressWarnings("unchecked") diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/PlannerImpl.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/PlannerImpl.java index 26794af1b2093..f86b9600678e9 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/PlannerImpl.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/PlannerImpl.java @@ -27,6 +27,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.analytics.planner.rel.OpenSearchDistributionTraitDef; +import org.opensearch.analytics.planner.rules.LateMaterializationRule; import org.opensearch.analytics.planner.rules.OpenSearchAggregateReduceRule; import org.opensearch.analytics.planner.rules.OpenSearchAggregateRule; import org.opensearch.analytics.planner.rules.OpenSearchAggregateSplitRule; @@ -146,6 +147,22 @@ public static RelNode markAndOptimize(RelNode rawRelNode, PlannerContext context RelNode result = volcanoPlanner.findBestExp(); LOGGER.info("After CBO:\n{}", RelOptUtil.toString(result)); - 
return result; + + // Phase 3: Post-CBO rewrites — late materialization (QTF) detection. + // Runs on the marked OpenSearch nodes after CBO has inserted exchanges. + // The rule narrows projections to sort/filter columns + __row_id__ when + // beneficial, enabling the DAGBuilder to detect QTF eligibility downstream. + HepProgramBuilder postCboBuilder = new HepProgramBuilder(); + postCboBuilder.addMatchOrder(HepMatchOrder.TOP_DOWN); + postCboBuilder.addRuleInstance(new LateMaterializationRule()); + HepPlanner postCboPlanner = new HepPlanner(postCboBuilder.build()); + postCboPlanner.setRoot(result); + RelNode afterPostCbo = postCboPlanner.findBestExp(); + + if (afterPostCbo != result) { + LOGGER.info("After post-CBO (QTF rewrite):\n{}", RelOptUtil.toString(afterPostCbo)); + } + + return afterPostCbo; } } diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/DAGBuilder.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/DAGBuilder.java index ebf4b1d84a1ce..bfc06067f841f 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/DAGBuilder.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/DAGBuilder.java @@ -12,6 +12,7 @@ import org.opensearch.analytics.planner.CapabilityRegistry; import org.opensearch.analytics.planner.CapabilityResolutionUtils; import org.opensearch.analytics.planner.rel.OpenSearchExchangeReducer; +import org.opensearch.analytics.planner.rel.OpenSearchProject; import org.opensearch.analytics.planner.rel.OpenSearchRelNode; import org.opensearch.analytics.planner.rel.OpenSearchStageInputScan; import org.opensearch.analytics.spi.ExchangeSinkProvider; @@ -69,9 +70,69 @@ public static QueryDAG build(RelNode cboOutput, CapabilityRegistry registry, Clu TargetResolver rootTargetResolver = childStages.isEmpty() ? 
new ShardTargetResolver(rootFragment, clusterService) : null; Stage rootStage = new Stage(counter[0]++, rootFragment, childStages, null, sinkProvider, rootTargetResolver); + + // QTF: if the root fragment's output schema contains __row_id__, the + // LateMaterializationRule has narrowed the projection to sort/filter + // columns + __row_id__. Wrap the reduce stage with a late-materialization + // root that handles the fetch + assembly phase. + if (isLateMaterializationEligible(rootFragment)) { + return wrapWithLateMaterialization(counter, rootStage); + } + return new QueryDAG(UUID.randomUUID().toString(), rootStage); } + /** + * Checks whether the {@code LateMaterializationRule} has rewritten the plan for QTF. + * + *

Detection: the root fragment's top node must be an {@code OpenSearchProject} + * whose output schema includes {@code __row_id__}. This is a reliable signal because: + *

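+ * <ul>
+ *   <li>The scan always has {@code __row_id__} in its row type (added by schema builder),
+ *       but a normal Project does not reference it — only the QTF-rewritten Project does.
+ *       NOTE(review): a query that projects {@code __row_id__} explicitly also satisfies this
+ *       check, although {@code LateMaterializationRule} skips such projects — confirm intended.</li>
+ *   <li>If there is no Project (e.g. {@code SELECT *}), the rule did not fire, so we
+ *       should not trigger late materialization.</li>
+ * </ul>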
+ */ + private static boolean isLateMaterializationEligible(RelNode rootFragment) { + if (rootFragment instanceof OpenSearchProject project) { + return project.getRowType().getFieldNames().contains("__row_id__"); + } + return false; + } + + /** + * Wraps a 2-stage DAG (shard + reduce) into a 3-stage QTF DAG: + *
+     *   LateMaterializationStageExecution (root)
+     *       └─ LocalStageExecution (COORDINATOR_REDUCE)  — sort + limit
+     *             └─ ShardFragmentStageExecution (SHARD_FRAGMENT, injectShardOrdinal=true)
+     * 
+ */ + private static QueryDAG wrapWithLateMaterialization(int[] counter, Stage reduceStage) { + // Mark the shard fragment child to inject shard_id into every batch + for (Stage child : reduceStage.getChildStages()) { + if (child.getTargetResolver() != null) { + child.setInjectShardOrdinal(true); + } + } + + // The reduce stage becomes a child of the new late-mat root. + // Late-mat root has no sink provider, no target resolver, and no plan fragment + // (it is a pure execution stage — position map + fetch + assembly). + Stage lateMatRoot = new Stage( + counter[0]++, + null, // no plan fragment (pure execution stage) + List.of(reduceStage), + null, // no exchange info (root) + null, // no sink provider + null, // no target resolver + true // lateMaterialization = true + ); + + return new QueryDAG(UUID.randomUUID().toString(), lateMatRoot); + } + private static RelNode sever( RelNode node, int[] counter, diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/QueryDAG.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/QueryDAG.java index b1020f48b0ae3..53c4b82e41db5 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/QueryDAG.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/QueryDAG.java @@ -30,9 +30,13 @@ public String toString() { private static void appendStage(StringBuilder sb, Stage stage, int depth) { String indent = " ".repeat(depth); sb.append(indent).append("Stage ").append(stage.getStageId()); + sb.append(" [").append(stage.getExecutionType()).append("]"); if (stage.getExchangeInfo() != null) { sb.append(" exchange=").append(stage.getExchangeInfo().distributionType()); } + if (stage.isInjectShardOrdinal()) { + sb.append(" injectShardOrdinal=true"); + } sb.append("\n"); if (stage.getFragment() != null) { for (String line : RelOptUtil.toString(stage.getFragment()).split("\n")) { diff 
--git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/Stage.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/Stage.java index 61e5668b5dda9..a422f6a92c5a0 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/Stage.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/Stage.java @@ -47,6 +47,7 @@ public class Stage { private final StageExecutionType executionType; private List planAlternatives; private FragmentInstructionHandlerFactory instructionHandlerFactory; + private boolean injectShardOrdinal; public Stage( int stageId, @@ -55,6 +56,23 @@ public Stage( ExchangeInfo exchangeInfo, ExchangeSinkProvider exchangeSinkProvider, TargetResolver targetResolver + ) { + this(stageId, fragment, childStages, exchangeInfo, exchangeSinkProvider, targetResolver, false); + } + + /** + * Constructs a stage with an optional late-materialization override. + * When {@code lateMaterialization} is true, the execution type is forced to + * {@link StageExecutionType#LATE_MATERIALIZATION} regardless of the derived type. + */ + public Stage( + int stageId, + RelNode fragment, + List childStages, + ExchangeInfo exchangeInfo, + ExchangeSinkProvider exchangeSinkProvider, + TargetResolver targetResolver, + boolean lateMaterialization ) { this.stageId = stageId; this.fragment = fragment; @@ -62,7 +80,9 @@ public Stage( this.exchangeInfo = exchangeInfo; this.exchangeSinkProvider = exchangeSinkProvider; this.targetResolver = targetResolver; - this.executionType = setStageExecutionType(exchangeSinkProvider, targetResolver); + this.executionType = lateMaterialization + ? 
StageExecutionType.LATE_MATERIALIZATION + : setStageExecutionType(exchangeSinkProvider, targetResolver); this.planAlternatives = List.of(); } @@ -128,6 +148,14 @@ public void setInstructionHandlerFactory(FragmentInstructionHandlerFactory instr this.instructionHandlerFactory = instructionHandlerFactory; } + public boolean isInjectShardOrdinal() { + return injectShardOrdinal; + } + + public void setInjectShardOrdinal(boolean injectShardOrdinal) { + this.injectShardOrdinal = injectShardOrdinal; + } + private StageExecutionType setStageExecutionType(ExchangeSinkProvider exchangeSinkProvider, TargetResolver targetResolver) { if (targetResolver != null) { return StageExecutionType.SHARD_FRAGMENT; diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/StageExecutionType.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/StageExecutionType.java index 0393479ca7c5a..cfdcdb6cdcf5a 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/StageExecutionType.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/dag/StageExecutionType.java @@ -35,5 +35,11 @@ public enum StageExecutionType { * stages sitting above children that already produced the final rows. * A single-stage query that scans shards is {@link #SHARD_FRAGMENT}, not this. */ - LOCAL_PASSTHROUGH + LOCAL_PASSTHROUGH, + /** + * Late materialization (QTF) root stage. Consumes the reduced output from + * a child COORDINATOR_REDUCE stage, builds a position map from (shard_id, __row_id__), + * dispatches fetch-by-row-id requests per shard, and assembles the final result. 
+ */ + LATE_MATERIALIZATION } diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/rel/OpenSearchDistributionTraitDef.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/rel/OpenSearchDistributionTraitDef.java index 771688ad8cfba..4d84aa80dd749 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/rel/OpenSearchDistributionTraitDef.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/rel/OpenSearchDistributionTraitDef.java @@ -110,9 +110,9 @@ public RelNode convert(RelOptPlanner planner, RelNode rel, OpenSearchDistributio List reduceViable = CapabilityResolutionUtils.filterByReduceCapability(registry, viableBackends); result = new OpenSearchExchangeReducer(rel.getCluster(), rel.getTraitSet().replace(toTrait), rel, reduceViable); } else { - // TODO: implement HASH/RANGE shuffle exchange when joins and shuffle aggregates are added. - // Requires DataTransferCapability producer/consumer intersection for shuffle impl selection. - throw new UnsupportedOperationException("HASH/RANGE exchange not yet implemented [toTrait=" + toTrait + "]"); + // Unsupported conversion (e.g. SINGLETON→RANDOM) — return null to let + // Volcano discard this option and find an alternative plan. 
+ return null; } return planner.register(result, rel); diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/rules/LateMaterializationRule.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/rules/LateMaterializationRule.java new file mode 100644 index 0000000000000..8af18a748aba4 --- /dev/null +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/planner/rules/LateMaterializationRule.java @@ -0,0 +1,226 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.analytics.planner.rules; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexVisitorImpl; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.analytics.planner.RelNodeUtils; +import org.opensearch.analytics.planner.rel.OpenSearchExchangeReducer; +import org.opensearch.analytics.planner.rel.OpenSearchFilter; +import org.opensearch.analytics.planner.rel.OpenSearchProject; +import org.opensearch.analytics.planner.rel.OpenSearchSort; +import org.opensearch.analytics.planner.rel.OpenSearchTableScan; + +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; + +/** + * Post-CBO rewrite rule for Query-Then-Fetch (QTF) late materialization. + * + *

 Detects the pattern {@code Project -> Sort(with LIMIT) -> [ExchangeReducer ->] [Filter ->] Scan} + * and narrows the projection to only the columns needed for sort and filter, + * plus the {@code __row_id__} column. Columns that are only needed for the + * final output (pureProjectColumns) are deferred to a later fetch phase. + * + *

 The DAGBuilder detects {@code __row_id__} in the root fragment's output + * schema and wraps the DAG with a {@code LateMaterializationStageExecution} + * that handles the fetch phase. + * + *

This rule only fires when: + *

+ * <ul>
+ *   <li>Sort has a fetch (LIMIT) — without limit, all rows are needed anyway</li>
+ *   <li>There are columns in the project that are NOT used by sort/filter
+ *       (pureProjectColumns is non-empty)</li>
+ *   <li>The query does not already project {@code __row_id__} (no user-driven rewrite)</li>
+ * </ul>
+ * + * @opensearch.internal + */ +public class LateMaterializationRule extends RelOptRule { + + private static final Logger LOGGER = LogManager.getLogger(LateMaterializationRule.class); + static final String ROW_ID_COLUMN = "__row_id__"; + + /** + * Pattern: Project -> Sort -> Filter -> Scan (with optional filter). + * We use a broad match on Project -> Sort -> any, and inspect the subtree + * in onMatch to handle the optional filter. + */ + public LateMaterializationRule() { + super(operand(OpenSearchProject.class, operand(OpenSearchSort.class, operand(RelNode.class, any()))), "LateMaterializationRule"); + } + + @Override + public void onMatch(RelOptRuleCall call) { + OpenSearchProject project = call.rel(0); + OpenSearchSort sort = call.rel(1); + RelNode sortChild = RelNodeUtils.unwrapHep(call.rel(2)); + + // Only fire when Sort has a limit (fetch != null) + if (sort.fetch == null) { + return; + } + + // Walk below the sort to find the filter (optional) and scan. + // Multi-shard plans have an ExchangeReducer between sort and filter/scan: + // Project -> Sort -> ExchangeReducer -> [Filter ->] Scan + // Single-shard plans have filter/scan directly: + // Project -> Sort -> [Filter ->] Scan + RelNode belowExchange = sortChild; + if (sortChild instanceof OpenSearchExchangeReducer reducer) { + belowExchange = RelNodeUtils.unwrapHep(reducer.getInput()); + } + + OpenSearchFilter filter = null; + OpenSearchTableScan scan; + + if (belowExchange instanceof OpenSearchFilter osFilter) { + filter = osFilter; + RelNode filterChild = RelNodeUtils.unwrapHep(osFilter.getInput()); + if (filterChild instanceof OpenSearchTableScan osScan) { + scan = osScan; + } else { + return; // Not the expected shape + } + } else if (belowExchange instanceof OpenSearchTableScan osScan) { + scan = osScan; + } else { + return; // Not the expected shape (e.g. aggregate below sort) + } + + // Find the __row_id__ column index in the sort's input row type. 
+ // Filter, ExchangeReducer, and Scan all pass through the same row type, + // so we use the sort's row type (which equals its input's row type). + RelDataType sortInputRowType = sort.getRowType(); + int rowIdIndex = findFieldIndex(sortInputRowType, ROW_ID_COLUMN); + if (rowIdIndex < 0) { + // __row_id__ not in schema — cannot do QTF + return; + } + + // Check if the project already outputs __row_id__ + if (projectAlreadyOutputsRowId(project)) { + return; + } + + // Collect column indices used by sort collation. + // Collation field indices reference the sort's input row type. + Set computationIndices = new LinkedHashSet<>(); + + for (RelFieldCollation fieldCollation : sort.getCollation().getFieldCollations()) { + computationIndices.add(fieldCollation.getFieldIndex()); + } + + // Collect column indices used by the filter condition (if present). + // Filter condition indices also reference the scan's row type, which is + // identical to the sort's input row type (Filter/ExchangeReducer pass through). + if (filter != null) { + collectInputRefs(filter.getCondition(), computationIndices); + } + + // Collect column indices referenced by the project's expressions. + // Project expressions reference the sort's output row type (= sort's input row type). 
+ Set projectIndices = new LinkedHashSet<>(); + for (RexNode expr : project.getProjects()) { + collectInputRefs(expr, projectIndices); + } + + // pureProjectColumns = columns referenced only by project, not by sort/filter + Set pureProjectIndices = new LinkedHashSet<>(projectIndices); + pureProjectIndices.removeAll(computationIndices); + + if (pureProjectIndices.isEmpty()) { + // All project columns are needed for sort/filter — no benefit from QTF + return; + } + + LOGGER.info( + "QTF rewrite: {} pure project columns can be deferred to fetch phase (sort/filter need {} columns)", + pureProjectIndices.size(), + computationIndices.size() + ); + + // Build the narrowed column set: sort/filter columns + __row_id__ + Set narrowedIndices = new LinkedHashSet<>(computationIndices); + narrowedIndices.add(rowIdIndex); + + // Build the new projection expressions: one RexInputRef per narrowed column + // referencing the sort's row type. + List newProjectExprs = new ArrayList<>(); + List newFieldNames = new ArrayList<>(); + for (int idx : narrowedIndices) { + RelDataTypeField field = sortInputRowType.getFieldList().get(idx); + newProjectExprs.add(new RexInputRef(idx, field.getType())); + newFieldNames.add(field.getName()); + } + + // Build the new row type for the narrowed project + RelDataType newRowType = project.getCluster() + .getTypeFactory() + .createStructType(newProjectExprs.stream().map(RexNode::getType).toList(), newFieldNames); + + // Create the new narrowed project, reusing sort as input + OpenSearchProject narrowedProject = new OpenSearchProject( + project.getCluster(), + project.getTraitSet(), + RelNodeUtils.unwrapHep(project.getInput()), + newProjectExprs, + newRowType, + project.getViableBackends() + ); + + LOGGER.info("QTF rewrite applied: narrowed projection from {} to {} columns", project.getRowType().getFieldCount(), newFieldNames); + + call.transformTo(narrowedProject); + } + + /** + * Checks whether the project already outputs __row_id__. 
+ */ + private boolean projectAlreadyOutputsRowId(OpenSearchProject project) { + return project.getRowType().getFieldNames().contains(ROW_ID_COLUMN); + } + + /** + * Finds the index of a field by name in the row type, or -1 if not found. + */ + static int findFieldIndex(RelDataType rowType, String fieldName) { + List fields = rowType.getFieldList(); + for (int i = 0; i < fields.size(); i++) { + if (fields.get(i).getName().equals(fieldName)) { + return i; + } + } + return -1; + } + + /** + * Collects all RexInputRef indices from the given expression tree. + */ + private void collectInputRefs(RexNode node, Set indices) { + node.accept(new RexVisitorImpl(true) { + @Override + public Void visitInputRef(RexInputRef inputRef) { + indices.add(inputRef.getIndex()); + return null; + } + }); + } +} diff --git a/sandbox/plugins/analytics-engine/src/test/java/org/opensearch/analytics/exec/stage/ShardFragmentStageExecutionTests.java b/sandbox/plugins/analytics-engine/src/test/java/org/opensearch/analytics/exec/stage/ShardFragmentStageExecutionTests.java index 34e4c3a156f84..3707d5e626d54 100644 --- a/sandbox/plugins/analytics-engine/src/test/java/org/opensearch/analytics/exec/stage/ShardFragmentStageExecutionTests.java +++ b/sandbox/plugins/analytics-engine/src/test/java/org/opensearch/analytics/exec/stage/ShardFragmentStageExecutionTests.java @@ -20,6 +20,7 @@ import org.opensearch.analytics.exec.QueryContext; import org.opensearch.analytics.exec.StreamingResponseListener; import org.opensearch.analytics.exec.action.FragmentExecutionArrowResponse; +import org.opensearch.analytics.exec.stage.RowResponseCodec; import org.opensearch.analytics.exec.action.FragmentExecutionRequest; import org.opensearch.analytics.exec.task.AnalyticsQueryTask; import org.opensearch.analytics.planner.dag.ShardExecutionTarget; @@ -134,7 +135,7 @@ private ShardFragmentStageExecution buildExecution( List.of(new FragmentExecutionRequest.PlanAlternative("test-backend", new byte[0], List.of())) ); - return 
new ShardFragmentStageExecution(stage, config, sink, clusterService, requestBuilder, dispatcher); + return new ShardFragmentStageExecution(stage, config, sink, clusterService, requestBuilder, dispatcher, RowResponseCodec.INSTANCE); } private VectorSchemaRoot createTestBatch(int rows) { diff --git a/sandbox/plugins/dsl-query-executor/src/main/java/org/opensearch/dsl/converter/ProjectConverter.java b/sandbox/plugins/dsl-query-executor/src/main/java/org/opensearch/dsl/converter/ProjectConverter.java index 5bdf52613e370..a4e512e0f1820 100644 --- a/sandbox/plugins/dsl-query-executor/src/main/java/org/opensearch/dsl/converter/ProjectConverter.java +++ b/sandbox/plugins/dsl-query-executor/src/main/java/org/opensearch/dsl/converter/ProjectConverter.java @@ -149,6 +149,11 @@ private void resolveField( ) throws ConversionException { RelDataTypeField field = rowType.getField(fieldName, false, false); if (field == null) { + // __row_id__ is a virtual column computed by the analytics backend. + // The DSL schema doesn't know about it — skip silently. + if ("__row_id__".equals(fieldName)) { + return; + } throw new ConversionException("Field '" + fieldName + "' not found in schema"); } if (seen.add(field.getName())) { diff --git a/sandbox/qa/analytics-engine-rest/src/test/java/org/opensearch/analytics/qa/LateMaterializationIT.java b/sandbox/qa/analytics-engine-rest/src/test/java/org/opensearch/analytics/qa/LateMaterializationIT.java new file mode 100644 index 0000000000000..cb19bb681647f --- /dev/null +++ b/sandbox/qa/analytics-engine-rest/src/test/java/org/opensearch/analytics/qa/LateMaterializationIT.java @@ -0,0 +1,344 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.analytics.qa; + +import org.apache.lucene.tests.util.LuceneTestCase.AwaitsFix; +import org.opensearch.client.Request; +import org.opensearch.client.Response; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +/** + * End-to-end debug IT for the Late Materialization (QTF) flow. + * + *

 1 shard, 2 segments, 5 docs total. Designed to trace the full flow: + * query phase → shard_id injection → reduce → position map → fetch → assembly. + * + *

+ * Segment 1: alice(score=100, city=NYC), bob(score=200, city=SF), dave(score=50, city=LA)
+ * Segment 2: carol(score=150, city=NYC), eve(score=300, city=SF)
+ * 
+ * + *

Query: SELECT __row_id__, name, score FROM t ORDER BY score LIMIT 3 + *

Expected: dave(50), alice(100), carol(150) + */ +public class LateMaterializationIT extends AnalyticsRestTestCase { + + private static final String INDEX = "late_mat_debug"; + private static boolean ready = false; + + private void setup() throws IOException { + if (ready) return; + + try { client().performRequest(new Request("DELETE", "/" + INDEX)); } catch (Exception ignored) {} + + // 1 shard (analytics engine doesn't support multi-shard distribution yet) + // The QTF flow is debugged with multi-segment within one shard. + Request create = new Request("PUT", "/" + INDEX); + create.setJsonEntity("{" + + "\"settings\":{" + + " \"number_of_shards\":1,\"number_of_replicas\":0," + + " \"index.pluggable.dataformat.enabled\":true," + + " \"index.pluggable.dataformat\":\"composite\"," + + " \"index.composite.primary_data_format\":\"parquet\"," + + " \"index.composite.secondary_data_formats\":\"lucene\"" + + "}," + + "\"mappings\":{\"properties\":{" + + " \"name\":{\"type\":\"keyword\"}," + + " \"score\":{\"type\":\"integer\"}," + + " \"city\":{\"type\":\"keyword\"}" + + "}}}"); + client().performRequest(create); + + Request health = new Request("GET", "/_cluster/health/" + INDEX); + health.addParameter("wait_for_status", "green"); + health.addParameter("timeout", "30s"); + client().performRequest(health); + + // Segment 1: 3 docs (distributed across 2 shards by hash) + bulk("{\"index\":{}}\n{\"name\":\"alice\",\"score\":100,\"city\":\"NYC\"}\n" + + "{\"index\":{}}\n{\"name\":\"bob\",\"score\":200,\"city\":\"SF\"}\n" + + "{\"index\":{}}\n{\"name\":\"dave\",\"score\":50,\"city\":\"LA\"}\n"); + client().performRequest(new Request("POST", "/" + INDEX + "/_flush?force=true")); + + // Segment 2: 2 more docs + bulk("{\"index\":{}}\n{\"name\":\"carol\",\"score\":150,\"city\":\"NYC\"}\n" + + "{\"index\":{}}\n{\"name\":\"eve\",\"score\":300,\"city\":\"SF\"}\n"); + client().performRequest(new Request("POST", "/" + INDEX + "/_flush?force=true")); + + ready = true; + } + + /** + * 
Basic QTF query: projects __row_id__ + sort key + data columns. + * This triggers the full QTF flow. + * + * Watch the logs for: + * - [ShardFragmentStageExecution] shard_id injection + * - [QTF] Position map built + * - [QTF] Dispatching fetch + */ + public void testQtfSortByScore() throws IOException { + setup(); + + String ppl = "source = " + INDEX + " | sort score | fields __row_id__, name, score | head 3"; + List> rows = executePplRows(ppl); + + logger.info("[LateMat-IT] Results for sort-by-score:"); + for (int i = 0; i < rows.size(); i++) { + logger.info(" row {}: {}", i, rows.get(i)); + } + + // Verify we got results (exact values depend on QTF wiring status) + assertNotNull("Should have results", rows); + assertTrue("Should have at least 1 row", rows.size() >= 1); + } + + /** + * QTF with filter: only city='NYC' rows, sorted by score. + * Expected: alice(100), carol(150) — both from shard 0. + */ + public void testQtfFilteredSort() throws IOException { + setup(); + + String ppl = "source = " + INDEX + " | where city = 'NYC' | sort score | fields __row_id__, name, score"; + List> rows = executePplRows(ppl); + + logger.info("[LateMat-IT] Results for filtered sort (city=NYC):"); + for (int i = 0; i < rows.size(); i++) { + logger.info(" row {}: {}", i, rows.get(i)); + } + + assertNotNull(rows); + assertEquals("NYC has 2 docs", 2, rows.size()); + } + + /** + * Full scan no filter — all 5 docs sorted by score. 
+ * Expected order: dave(50), alice(100), carol(150), bob(200), eve(300) + */ + public void testQtfFullScan() throws IOException { + setup(); + + String ppl = "source = " + INDEX + " | sort score | fields __row_id__, name, score"; + List> rows = executePplRows(ppl); + + logger.info("[LateMat-IT] Results for full scan:"); + for (int i = 0; i < rows.size(); i++) { + logger.info(" row {}: {}", i, rows.get(i)); + } + + assertNotNull(rows); + assertEquals("Should have all 5 docs", 5, rows.size()); + } + + // ── Multi-shard test ── + + private static final String INDEX_MULTI = "late_mat_multi_shard"; + private static boolean multiReady = false; + + private void setupMultiShard() throws IOException { + if (multiReady) return; + + try { client().performRequest(new Request("DELETE", "/" + INDEX_MULTI)); } catch (Exception ignored) {} + + Request create = new Request("PUT", "/" + INDEX_MULTI); + create.setJsonEntity("{" + + "\"settings\":{" + + " \"number_of_shards\":2,\"number_of_replicas\":0," + + " \"index.pluggable.dataformat.enabled\":true," + + " \"index.pluggable.dataformat\":\"composite\"," + + " \"index.composite.primary_data_format\":\"parquet\"," + + " \"index.composite.secondary_data_formats\":\"lucene\"" + + "}," + + "\"mappings\":{\"properties\":{" + + " \"name\":{\"type\":\"keyword\"}," + + " \"score\":{\"type\":\"integer\"}," + + " \"city\":{\"type\":\"keyword\"}" + + "}}}"); + client().performRequest(create); + + Request health = new Request("GET", "/_cluster/health/" + INDEX_MULTI); + health.addParameter("wait_for_status", "green"); + health.addParameter("timeout", "30s"); + client().performRequest(health); + + String[] names = {"alice", "bob", "carol", "dave", "eve", "frank", "grace", "heidi", + "ivan", "judy", "karl", "laura", "mike", "nina", "oscar", "pat", "quinn", "rosa", + "steve", "tina", "uma", "vic", "wendy", "xena", "yuri", "zara", "adam", "beth", "chad", "diana"}; + String[] cities = {"NYC", "SF", "LA", "NYC", "SF"}; + + // Segment 1: first 15 
docs + StringBuilder bulk1 = new StringBuilder(); + for (int i = 0; i < 15; i++) { + bulk1.append("{\"index\":{}}\n"); + bulk1.append("{\"name\":\"").append(names[i]).append("\",\"score\":").append((i + 1) * 10) + .append(",\"city\":\"").append(cities[i % cities.length]).append("\"}\n"); + } + bulkTo(INDEX_MULTI, bulk1.toString()); + client().performRequest(new Request("POST", "/" + INDEX_MULTI + "/_flush?force=true")); + + // Segment 2: next 15 docs + StringBuilder bulk2 = new StringBuilder(); + for (int i = 15; i < 30; i++) { + bulk2.append("{\"index\":{}}\n"); + bulk2.append("{\"name\":\"").append(names[i]).append("\",\"score\":").append((i + 1) * 10) + .append(",\"city\":\"").append(cities[i % cities.length]).append("\"}\n"); + } + bulkTo(INDEX_MULTI, bulk2.toString()); + client().performRequest(new Request("POST", "/" + INDEX_MULTI + "/_flush?force=true")); + + multiReady = true; + } + + /** + * Multi-shard QTF: 2 shards, 30 docs. + * Tests whether the position map + fetch correctly handles multiple shards. + */ + public void testQtfMultiShard() throws IOException { + setupMultiShard(); + + String ppl = "source = " + INDEX_MULTI + " | where score > 100 | sort score | fields __row_id__, name, score"; + List> rows = executePplRows(ppl); + + logger.info("[LateMat-IT] Results for multi-shard filtered sort (score > 100):"); + for (int i = 0; i < rows.size(); i++) { + logger.info(" row {}: {}", i, rows.get(i)); + } + + assertNotNull(rows); + // 30 docs with scores 10,20,...,300. 
score > 100 means scores 110-300 = 20 rows + assertEquals("Should have 20 rows with score > 100", 20, rows.size()); + } + + // ── Multi-index test ── + + private static final String INDEX_MI_A = "late_mat_mi_a"; + private static final String INDEX_MI_B = "late_mat_mi_b"; + private static boolean multiIndexReady = false; + + private void setupMultiIndex() throws IOException { + if (multiIndexReady) return; + + String[] indices = {INDEX_MI_A, INDEX_MI_B}; + for (String idx : indices) { + try { client().performRequest(new Request("DELETE", "/" + idx)); } catch (Exception ignored) {} + + Request create = new Request("PUT", "/" + idx); + create.setJsonEntity("{" + + "\"settings\":{" + + " \"number_of_shards\":2,\"number_of_replicas\":0," + + " \"index.pluggable.dataformat.enabled\":true," + + " \"index.pluggable.dataformat\":\"composite\"," + + " \"index.composite.primary_data_format\":\"parquet\"," + + " \"index.composite.secondary_data_formats\":\"lucene\"" + + "}," + + "\"mappings\":{\"properties\":{" + + " \"name\":{\"type\":\"keyword\"}," + + " \"score\":{\"type\":\"integer\"}," + + " \"city\":{\"type\":\"keyword\"}" + + "}}}"); + client().performRequest(create); + + Request health = new Request("GET", "/_cluster/health/" + idx); + health.addParameter("wait_for_status", "green"); + health.addParameter("timeout", "30s"); + client().performRequest(health); + } + + // Index A: 20 docs across 2 segments + StringBuilder bulkA1 = new StringBuilder(); + for (int i = 0; i < 10; i++) { + bulkA1.append("{\"index\":{}}\n"); + bulkA1.append("{\"name\":\"a_").append(i).append("\",\"score\":").append((i + 1) * 5) + .append(",\"city\":\"NYC\"}\n"); + } + bulkTo(INDEX_MI_A, bulkA1.toString()); + client().performRequest(new Request("POST", "/" + INDEX_MI_A + "/_flush?force=true")); + + StringBuilder bulkA2 = new StringBuilder(); + for (int i = 10; i < 20; i++) { + bulkA2.append("{\"index\":{}}\n"); + bulkA2.append("{\"name\":\"a_").append(i).append("\",\"score\":").append((i + 1) * 
5) + .append(",\"city\":\"SF\"}\n"); + } + bulkTo(INDEX_MI_A, bulkA2.toString()); + client().performRequest(new Request("POST", "/" + INDEX_MI_A + "/_flush?force=true")); + + // Index B: 15 docs across 2 segments + StringBuilder bulkB1 = new StringBuilder(); + for (int i = 0; i < 8; i++) { + bulkB1.append("{\"index\":{}}\n"); + bulkB1.append("{\"name\":\"b_").append(i).append("\",\"score\":").append((i + 1) * 7) + .append(",\"city\":\"LA\"}\n"); + } + bulkTo(INDEX_MI_B, bulkB1.toString()); + client().performRequest(new Request("POST", "/" + INDEX_MI_B + "/_flush?force=true")); + + StringBuilder bulkB2 = new StringBuilder(); + for (int i = 8; i < 15; i++) { + bulkB2.append("{\"index\":{}}\n"); + bulkB2.append("{\"name\":\"b_").append(i).append("\",\"score\":").append((i + 1) * 7) + .append(",\"city\":\"NYC\"}\n"); + } + bulkTo(INDEX_MI_B, bulkB2.toString()); + client().performRequest(new Request("POST", "/" + INDEX_MI_B + "/_flush?force=true")); + + multiIndexReady = true; + } + + /** + * Multi-index + multi-shard + multi-segment QTF. + * 2 indices × 2 shards × 2 segments. Tests ordinal space spanning indices. 
+ */ + @AwaitsFix(bugUrl = "https://github.com/opensearch-project/OpenSearch/issues/0000") + public void testQtfMultiIndex() throws IOException { + setupMultiIndex(); + + String ppl = "source = " + INDEX_MI_A + "," + INDEX_MI_B + + " | where score > 50 | sort score | fields __row_id__, name, score"; + List> rows = executePplRows(ppl); + + logger.info("[LateMat-IT] Results for multi-index sort (score > 50):"); + for (int i = 0; i < rows.size(); i++) { + logger.info(" row {}: {}", i, rows.get(i)); + } + + assertNotNull(rows); + assertTrue("Should have rows from both indices", rows.size() >= 1); + } + + // ── Helpers ── + + private void bulk(String body) throws IOException { + bulkTo(INDEX, body); + } + + private void bulkTo(String index, String body) throws IOException { + Request req = new Request("POST", "/" + index + "/_bulk"); + req.setJsonEntity(body); + req.addParameter("refresh", "true"); + client().performRequest(req); + } + + private List> executePplRows(String ppl) throws IOException { + logger.info("[LateMat-IT] Executing: {}", ppl); + Request req = new Request("POST", "/_analytics/ppl"); + req.setJsonEntity("{\"query\": \"" + escapeJson(ppl) + "\"}"); + Response resp = client().performRequest(req); + Map parsed = assertOkAndParse(resp, "PPL"); + + @SuppressWarnings("unchecked") + List> rows = (List>) parsed.get("rows"); + assertNotNull("No rows in response", rows); + return rows; + } +} From 5c42fb86f00fec0a8b45d2026dc5d115a92047b2 Mon Sep 17 00:00:00 2001 From: Arpit Bandejiya Date: Thu, 14 May 2026 21:19:01 +0530 Subject: [PATCH 2/3] Clean the codec Signed-off-by: Arpit Bandejiya --- .../analytics/exec/stage/ResponseCodec.java | 40 ------- .../exec/stage/RowResponseCodec.java | 112 ------------------ .../stage/ShardFragmentStageExecution.java | 93 +++++---------- .../stage/ShardFragmentStageScheduler.java | 18 +-- 4 files changed, 32 insertions(+), 231 deletions(-) delete mode 100644 
sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ResponseCodec.java delete mode 100644 sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/RowResponseCodec.java diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ResponseCodec.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ResponseCodec.java deleted file mode 100644 index 528b3a93e2b1f..0000000000000 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ResponseCodec.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.analytics.exec.stage; - -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.opensearch.core.action.ActionResponse; - -/** - * Decodes a transport response into an Arrow {@link VectorSchemaRoot} for - * the coordinator-side sink. Implementations handle the specific wire - * format — {@code Object[]} rows (current), Arrow IPC (Flight), or any - * future format. - * - *

The codec is injected into {@link ShardFragmentStageExecution} at - * construction time by the scheduler. Swapping the codec swaps the - * serialization format without touching stage execution logic. - * - * @param the transport response type - * @opensearch.internal - */ -@FunctionalInterface -public interface ResponseCodec { - - /** - * Decodes a transport response into an Arrow {@link VectorSchemaRoot}. - * The returned VSR is owned by the caller (the sink). - * - * @param response the transport response - * @param allocator the buffer allocator for Arrow vectors - * @return a new VectorSchemaRoot; caller owns and must close it - */ - VectorSchemaRoot decode(R response, BufferAllocator allocator); -} diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/RowResponseCodec.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/RowResponseCodec.java deleted file mode 100644 index 38aef31c1337b..0000000000000 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/RowResponseCodec.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. 
- */ - -package org.opensearch.analytics.exec.stage; - -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.ipc.ArrowStreamReader; -import org.apache.arrow.vector.types.pojo.Schema; -import org.opensearch.analytics.exec.action.FragmentExecutionResponse; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -/** - * {@link ResponseCodec} that deserializes the Arrow IPC stream payload carried - * by {@link FragmentExecutionResponse} into a single consolidated - * {@link VectorSchemaRoot}. Uses Arrow's {@link ArrowStreamReader} for message - * sequencing (schema header, record batches, end-of-stream) and - * {@link FieldVector#makeTransferPair} / {@link FieldVector#copyFromSafe} to - * move data into a caller-owned root — both are supported by every vector - * kind, including the view vectors ({@code Utf8View}, {@code BinaryView}) that - * DataFusion emits for aggregate group keys. - * - *

Deliberately avoids {@code VectorSchemaRootAppender} — its underlying - * {@code VectorAppender} rejects view vectors with - * {@code UnsupportedOperationException}. - * - * @opensearch.internal - */ -public final class RowResponseCodec implements ResponseCodec { - - /** Singleton instance — stateless, thread-safe. */ - public static final RowResponseCodec INSTANCE = new RowResponseCodec(); - - private RowResponseCodec() {} - - @Override - public VectorSchemaRoot decode(FragmentExecutionResponse response, BufferAllocator allocator) { - if (allocator == null) { - throw new IllegalArgumentException("BufferAllocator must not be null"); - } - byte[] payload = response.getIpcPayload(); - if (payload == null || payload.length == 0) { - return VectorSchemaRoot.create(new Schema(List.of()), allocator); - } - - List batches = new ArrayList<>(); - Schema schema; - try (ArrowStreamReader reader = new ArrowStreamReader(new ByteArrayInputStream(payload), allocator)) { - VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot(); - schema = readerRoot.getSchema(); - while (reader.loadNextBatch()) { - // Transfer each batch's buffers out of the reader's reused root into an - // independent root owned by this codec. Transfer works for all vector - // kinds (including views) and is zero-copy when allocators match. 
- VectorSchemaRoot batchRoot = VectorSchemaRoot.create(schema, allocator); - int rowCount = readerRoot.getRowCount(); - for (int i = 0; i < readerRoot.getFieldVectors().size(); i++) { - readerRoot.getVector(i).makeTransferPair(batchRoot.getVector(i)).transfer(); - } - batchRoot.setRowCount(rowCount); - batches.add(batchRoot); - } - } catch (IOException e) { - for (VectorSchemaRoot b : batches) - b.close(); - throw new IllegalStateException("Failed to decode Arrow IPC payload from fragment response", e); - } - - if (batches.isEmpty()) { - return VectorSchemaRoot.create(schema, allocator); - } - if (batches.size() == 1) { - return batches.get(0); - } - // Multiple batches — concatenate via per-cell copyFromSafe. Slower than columnar - // append but is the only operation Arrow Java implements for every vector type - // (VectorAppender throws on view vectors; see class javadoc). - int totalRows = batches.stream().mapToInt(VectorSchemaRoot::getRowCount).sum(); - VectorSchemaRoot combined = VectorSchemaRoot.create(schema, allocator); - try { - combined.allocateNew(); - for (int f = 0; f < combined.getFieldVectors().size(); f++) { - FieldVector dst = combined.getVector(f); - int offset = 0; - for (VectorSchemaRoot batch : batches) { - FieldVector src = batch.getVector(f); - int rows = batch.getRowCount(); - for (int r = 0; r < rows; r++) { - dst.copyFromSafe(r, offset + r, src); - } - offset += rows; - } - dst.setValueCount(totalRows); - } - combined.setRowCount(totalRows); - return combined; - } finally { - for (VectorSchemaRoot b : batches) - b.close(); - } - } -} diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageExecution.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageExecution.java index e2bc303a59b3e..d4f4b135e7043 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageExecution.java 
+++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageExecution.java @@ -16,15 +16,14 @@ import org.opensearch.analytics.exec.StreamingResponseListener; import org.opensearch.analytics.exec.action.FragmentExecutionArrowResponse; import org.opensearch.analytics.exec.action.FragmentExecutionRequest; -import org.opensearch.analytics.exec.action.FragmentExecutionResponse; import org.opensearch.analytics.planner.dag.ExecutionTarget; import org.opensearch.analytics.planner.dag.ShardExecutionTarget; import org.opensearch.analytics.planner.dag.Stage; import org.opensearch.analytics.spi.ExchangeSink; -import org.opensearch.arrow.flight.transport.ArrowBatchResponse; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.core.action.ActionResponse; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -32,17 +31,16 @@ import java.util.function.Function; /** - * Leaf stage execution that dispatches fragment work to data-node shards. - * - *

Handles both Arrow streaming and row (codec-decoded) responses, feeding - * resulting batches into the parent stage's {@link ExchangeSink}. + * Leaf stage execution that dispatches fragment work to data-node shards via + * Arrow streaming, feeding resulting batches into the parent stage's + * {@link ExchangeSink}. * *

One-shot: constructed, {@link #start()} called once, listener * signaled on completion, then discarded. * * @opensearch.internal */ -public final class ShardFragmentStageExecution extends AbstractStageExecution implements DataProducer { +final class ShardFragmentStageExecution extends AbstractStageExecution implements DataProducer { private final AtomicInteger inFlight = new AtomicInteger(0); @@ -51,11 +49,8 @@ public final class ShardFragmentStageExecution extends AbstractStageExecution im private final ClusterService clusterService; private final Function requestBuilder; private final AnalyticsSearchTransportService dispatcher; - private final ResponseCodec responseCodec; private final Map pendingPerNode = new ConcurrentHashMap<>(); - - // QTF: ordinal → target mapping so coordinator can dispatch fetches to the right shard/node. - private final List resolvedTargets = new java.util.ArrayList<>(); + private final List resolvedTargets = new ArrayList<>(); ShardFragmentStageExecution( Stage stage, @@ -63,8 +58,7 @@ public final class ShardFragmentStageExecution extends AbstractStageExecution im ExchangeSink outputSink, ClusterService clusterService, Function requestBuilder, - AnalyticsSearchTransportService dispatcher, - ResponseCodec responseCodec + AnalyticsSearchTransportService dispatcher ) { super(stage); this.config = config; @@ -72,11 +66,6 @@ public final class ShardFragmentStageExecution extends AbstractStageExecution im this.clusterService = clusterService; this.requestBuilder = requestBuilder; this.dispatcher = dispatcher; - this.responseCodec = responseCodec; - } - - private boolean useArrowStreaming() { - return dispatcher.isStreamingEnabled(); } @Override @@ -91,7 +80,7 @@ public void start() { for (ExecutionTarget target : resolved) { resolvedTargets.add((ShardExecutionTarget) target); } - // Populate context targets BEFORE dispatch (local dispatch is synchronous) + // QTF: populate context targets BEFORE dispatch (local dispatch is synchronous) if 
(stage.isInjectShardOrdinal()) { config.getResolvedShardTargets().addAll(resolvedTargets); } @@ -103,44 +92,33 @@ public void start() { private void dispatchShardTask(ShardExecutionTarget target, int shardOrdinal) { FragmentExecutionRequest request = requestBuilder.apply(target); PendingExecutions pending = pendingFor(target); - if (useArrowStreaming()) { - dispatcher.dispatchFragmentStreaming( - request, - target.node(), - responseListener(FragmentExecutionArrowResponse::getRoot, shardOrdinal), - config.parentTask(), - pending - ); - } else { - dispatcher.dispatchFragment( - request, - target.node(), - responseListener(r -> responseCodec.decode(r, config.bufferAllocator()), shardOrdinal), - config.parentTask(), - pending - ); - } + dispatcher.dispatchFragmentStreaming(request, target.node(), responseListener(shardOrdinal), config.parentTask(), pending); } - private StreamingResponseListener responseListener( - Function toVsr, - int shardOrdinal - ) { + private StreamingResponseListener responseListener(int shardOrdinal) { return new StreamingResponseListener<>() { @Override - public void onStreamResponse(T response, boolean isLast) { + public void onStreamResponse(FragmentExecutionArrowResponse response, boolean isLast) { if (isDone()) { - releaseResponseResources(response); + VectorSchemaRoot root = response.getRoot(); + if (root != null) { + root.close(); + } return; } - VectorSchemaRoot vsr = toVsr.apply(response); - - if (stage.isInjectShardOrdinal()) { - vsr = injectShardId(vsr, shardOrdinal); + VectorSchemaRoot vsr = response.getRoot(); + try { + if (stage.isInjectShardOrdinal()) { + vsr = injectShardId(vsr, shardOrdinal); + } + outputSink.feed(vsr); + } catch (Exception e) { + captureFailure(new RuntimeException("Stage " + stage.getStageId() + " sink feed failed", e)); + metrics.incrementTasksFailed(); + onShardTerminated(); + return; } - - outputSink.feed(vsr); metrics.addRowsProcessed(vsr.getRowCount()); if (isLast) { @@ -158,12 +136,6 @@ public void 
onFailure(Exception e) { }; } - private static void releaseResponseResources(T response) { - if (response instanceof ArrowBatchResponse arrowResp && arrowResp.getRoot() != null) { - arrowResp.getRoot().close(); - } - } - private void onShardTerminated() { if (inFlight.decrementAndGet() == 0) { Exception captured = getFailure(); @@ -174,7 +146,6 @@ private void onShardTerminated() { @Override public void cancel(String reason) { if (transitionTo(StageExecution.State.CANCELLED) == false) return; - // Cancelling the parent task propagates to data-node shard tasks via TaskCancellationService. org.opensearch.tasks.Task parentTask = config.parentTask(); if (parentTask instanceof org.opensearch.tasks.CancellableTask ct && ct.isCancelled() == false) { ct.cancel(reason); @@ -189,9 +160,9 @@ public ExchangeSource outputSource() { throw new UnsupportedOperationException("outputSink does not implement ExchangeSource"); } - /** QTF: returns the ordered list of shard targets for fetch dispatch. */ + /** Returns the ordered list of resolved shard targets. */ public List getResolvedTargets() { - return java.util.Collections.unmodifiableList(resolvedTargets); + return Collections.unmodifiableList(resolvedTargets); } private boolean isDone() { @@ -203,10 +174,6 @@ private PendingExecutions pendingFor(ShardExecutionTarget target) { return pendingPerNode.computeIfAbsent(target.node().getId(), n -> new PendingExecutions(config.maxConcurrentShardRequests())); } - /** - * QTF: Inject a shard_id column into the Arrow batch so the coordinator - * can track which shard each row came from after the reduce merge. 
- */ private static VectorSchemaRoot injectShardId(VectorSchemaRoot batch, int shardId) { org.apache.arrow.vector.IntVector shardIdVector = new org.apache.arrow.vector.IntVector( "shard_id", @@ -218,7 +185,7 @@ private static VectorSchemaRoot injectShardId(VectorSchemaRoot batch, int shardI } shardIdVector.setValueCount(batch.getRowCount()); - java.util.List vectors = new java.util.ArrayList<>(batch.getFieldVectors()); + List vectors = new ArrayList<>(batch.getFieldVectors()); vectors.add(shardIdVector); return new VectorSchemaRoot(vectors); } diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageScheduler.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageScheduler.java index da616cfdce341..2572d195c8df2 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageScheduler.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/stage/ShardFragmentStageScheduler.java @@ -11,7 +11,6 @@ import org.opensearch.analytics.exec.AnalyticsSearchTransportService; import org.opensearch.analytics.exec.QueryContext; import org.opensearch.analytics.exec.action.FragmentExecutionRequest; -import org.opensearch.analytics.exec.action.FragmentExecutionResponse; import org.opensearch.analytics.planner.dag.ShardExecutionTarget; import org.opensearch.analytics.planner.dag.Stage; import org.opensearch.analytics.planner.dag.StagePlan; @@ -31,10 +30,7 @@ * and doesn't care whether it is a root sink or a parent-provided child sink * — {@link StageExecutionBuilder} resolves that distinction before calling. * - *

Injects a {@link ResponseCodec} into the execution to decouple the wire - * format from stage logic. The default codec ({@link RowResponseCodec}) handles - * the current {@code Object[]} row format; a future Arrow IPC codec would be - * swapped in here. + *

Uses Arrow streaming transport for shard communication. * * @opensearch.internal */ @@ -42,20 +38,10 @@ final class ShardFragmentStageScheduler implements StageScheduler { private final ClusterService clusterService; private final AnalyticsSearchTransportService transport; - private final ResponseCodec responseCodec; ShardFragmentStageScheduler(ClusterService clusterService, AnalyticsSearchTransportService transport) { - this(clusterService, transport, RowResponseCodec.INSTANCE); - } - - ShardFragmentStageScheduler( - ClusterService clusterService, - AnalyticsSearchTransportService transport, - ResponseCodec responseCodec - ) { this.clusterService = clusterService; this.transport = transport; - this.responseCodec = responseCodec; } @Override @@ -73,7 +59,7 @@ public StageExecution createExecution(Stage stage, ExchangeSink sink, QueryConte // This keeps target resolution out of the build phase so cancellation before // dispatch doesn't pay for cluster-state routing, and leaves room for shuffle // reads whose targets depend on child manifests only available at dispatch time. 
- return new ShardFragmentStageExecution(stage, config, sink, clusterService, requestBuilder, transport, responseCodec); + return new ShardFragmentStageExecution(stage, config, sink, clusterService, requestBuilder, transport); } private static List buildPlanAlternatives(Stage stage) { From 402cb3851317ff4eb346ce4af793a942f11a5db1 Mon Sep 17 00:00:00 2001 From: Arpit Bandejiya Date: Thu, 14 May 2026 21:34:43 +0530 Subject: [PATCH 3/3] refactor a bit --- .../exec/AnalyticsSearchService.java | 204 ++++----- .../exec/AnalyticsSearchTransportService.java | 402 +++++------------- 2 files changed, 176 insertions(+), 430 deletions(-) diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchService.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchService.java index 0847b7716e26a..e7a2ca8140bb6 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchService.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchService.java @@ -9,31 +9,21 @@ package org.opensearch.analytics.exec; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.VectorUnloader; -import org.apache.arrow.vector.ipc.ArrowStreamWriter; -import org.apache.arrow.vector.ipc.WriteChannel; -import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; -import org.apache.arrow.vector.ipc.message.IpcOption; -import org.apache.arrow.vector.ipc.message.MessageSerializer; -import org.apache.arrow.vector.types.pojo.Schema; import org.opensearch.analytics.backend.AnalyticsOperationListener; -import org.opensearch.analytics.backend.EngineResultBatch; import org.opensearch.analytics.backend.EngineResultStream; import org.opensearch.analytics.backend.SearchExecEngine; import org.opensearch.analytics.backend.ShardScanExecutionContext; import 
org.opensearch.analytics.exec.action.FragmentExecutionRequest; -import org.opensearch.analytics.exec.action.FragmentExecutionResponse; import org.opensearch.analytics.exec.task.AnalyticsShardTask; import org.opensearch.analytics.spi.AnalyticsSearchBackendPlugin; import org.opensearch.analytics.spi.BackendExecutionContext; import org.opensearch.analytics.spi.DelegationDescriptor; +import org.opensearch.analytics.spi.DelegationThreadTracker; import org.opensearch.analytics.spi.FilterDelegationHandle; import org.opensearch.analytics.spi.FragmentInstructionHandler; import org.opensearch.analytics.spi.FragmentInstructionHandlerFactory; import org.opensearch.analytics.spi.InstructionNode; import org.opensearch.arrow.flight.transport.ArrowAllocatorProvider; -import org.opensearch.common.Nullable; import org.opensearch.common.concurrent.GatedCloseable; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.tasks.TaskCancelledException; @@ -41,11 +31,9 @@ import org.opensearch.index.engine.exec.IndexReaderProvider.Reader; import org.opensearch.index.shard.IndexShard; import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskResourceTrackingService; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.nio.channels.Channels; -import java.util.Iterator; import java.util.List; import java.util.Map; @@ -72,6 +60,7 @@ public class AnalyticsSearchService implements AutoCloseable { private final AnalyticsOperationListener listener; private final BufferAllocator allocator; private final NamedWriteableRegistry namedWriteableRegistry; + private TaskResourceTrackingService taskResourceTrackingService; public AnalyticsSearchService(Map backends) { this(backends, List.of(), null); @@ -97,78 +86,8 @@ public void close() { allocator.close(); } - public FragmentExecutionResponse executeFragment(FragmentExecutionRequest request, IndexShard shard) { - return executeFragment(request, shard, null); - } - - public 
FragmentExecutionResponse executeFragment(FragmentExecutionRequest request, IndexShard shard, AnalyticsShardTask task) { - ResolvedFragment resolved = resolveFragment(request, shard); - long startNanos = System.nanoTime(); - try (FragmentResources ctx = startFragment(request, resolved, shard, task)) { - FragmentExecutionResponse response = collectResponse(ctx.stream(), task); - long tookNanos = System.nanoTime() - startNanos; - listener.onFragmentSuccess(resolved.queryId, resolved.stageId, resolved.shardIdStr, tookNanos, response.getRowCount()); - return response; - } catch (TaskCancelledException | IllegalStateException | IllegalArgumentException e) { - listener.onFragmentFailure(resolved.queryId, resolved.stageId, resolved.shardIdStr, e); - throw e; - } catch (Exception e) { - listener.onFragmentFailure(resolved.queryId, resolved.stageId, resolved.shardIdStr, e); - throw new RuntimeException("Failed to execute fragment on " + shard.shardId(), e); - } - } - - /** - * QTF fetch phase: read specific rows by global row ID. - * Bypasses Substrait plan resolution — calls directly into backend's FFM. 
- */ - public org.opensearch.analytics.exec.action.FetchByRowIdsResponse executeFetchByRowIds( - org.opensearch.analytics.exec.action.FetchByRowIdsRequest request, - IndexShard shard, - AnalyticsShardTask task - ) { - long startNanos = System.nanoTime(); - String shardIdStr = shard.shardId().toString(); - try { - EngineResultStream stream = executeFetchStreaming(request, shard, task); - FragmentExecutionResponse fragmentResp = collectResponse(stream, task); - long tookNanos = System.nanoTime() - startNanos; - listener.onFragmentSuccess(request.getQueryId(), 0, shardIdStr, tookNanos, fragmentResp.getRowCount()); - return new org.opensearch.analytics.exec.action.FetchByRowIdsResponse(fragmentResp.getIpcPayload(), fragmentResp.getRowCount()); - } catch (Exception e) { - listener.onFragmentFailure(request.getQueryId(), 0, shardIdStr, e); - throw new RuntimeException("Failed to execute fetch-by-row-ids on " + shard.shardId(), e); - } - } - - /** - * Streaming variant: returns the raw EngineResultStream for the fetch phase. - * Used by the streaming transport handler to send Arrow batches directly. 
- */ - public EngineResultStream executeFetchStreaming( - org.opensearch.analytics.exec.action.FetchByRowIdsRequest request, - IndexShard shard, - AnalyticsShardTask task - ) { - IndexReaderProvider readerProvider = shard.getReaderProvider(); - if (readerProvider == null) { - throw new IllegalStateException("No ReaderProvider on " + shard.shardId()); - } - try { - GatedCloseable gatedReader = readerProvider.acquireReader(); - long[] rowIds = request.getRowIds(); - org.apache.arrow.vector.BigIntVector rowIdVector = new org.apache.arrow.vector.BigIntVector("__row_id__", allocator); - rowIdVector.allocateNew(rowIds.length); - for (int i = 0; i < rowIds.length; i++) { - rowIdVector.set(i, rowIds[i]); - } - rowIdVector.setValueCount(rowIds.length); - - AnalyticsSearchBackendPlugin backend = backends.values().iterator().next(); - return backend.fetchByRowIds(gatedReader.get(), rowIdVector, request.getColumns(), allocator); - } catch (Exception e) { - throw new RuntimeException("Failed to start fetch-by-row-ids on " + shard.shardId(), e); - } + public void setTaskResourceTrackingService(TaskResourceTrackingService service) { + this.taskResourceTrackingService = service; } public FragmentResources executeFragmentStreaming(FragmentExecutionRequest request, IndexShard shard, AnalyticsShardTask task) { @@ -190,6 +109,7 @@ private FragmentResources startFragment(FragmentExecutionRequest request, Resolv SearchExecEngine engine = null; EngineResultStream stream = null; BackendExecutionContext backendContext = null; + Runnable trackerCleanup = null; try { ShardScanExecutionContext ctx = buildContext(request, gatedReader.get(), resolved.plan, shard, task); AnalyticsSearchBackendPlugin backend = backends.get(resolved.plan.getBackendId()); @@ -213,14 +133,33 @@ private FragmentResources startFragment(FragmentExecutionRequest request, Resolv AnalyticsSearchBackendPlugin acceptingBackend = backends.get(acceptingBackendId); FilterDelegationHandle handle = 
acceptingBackend.getFilterDelegationHandle(delegation.delegatedExpressions(), ctx); backend.configureFilterDelegation(handle, backendContext); + + if (task != null && taskResourceTrackingService != null) { + long taskId = task.getId(); + TaskResourceTrackingService service = taskResourceTrackingService; + backend.setDelegationThreadTracker(new DelegationThreadTracker() { + @Override + public long trackStart() { + long threadId = Thread.currentThread().threadId(); + service.taskExecutionStartedOnThread(taskId, threadId); + return threadId; + } + + @Override + public void trackEnd(long threadId) { + service.taskExecutionFinishedOnThread(taskId, threadId); + } + }); + trackerCleanup = () -> backend.setDelegationThreadTracker(null); + } } engine = backend.getSearchExecEngineProvider().createSearchExecEngine(ctx, backendContext); stream = engine.execute(ctx); - return new FragmentResources(gatedReader, engine, stream); + return new FragmentResources(gatedReader, engine, stream, trackerCleanup); } catch (Exception e) { try { - new FragmentResources(gatedReader, engine, stream).close(); + new FragmentResources(gatedReader, engine, stream, trackerCleanup).close(); } catch (Exception suppressed) { e.addSuppressed(suppressed); } @@ -287,47 +226,64 @@ private ShardScanExecutionContext buildContext( return ctx; } - FragmentExecutionResponse collectResponse(EngineResultStream stream) { - return collectResponse(stream, null); - } - - FragmentExecutionResponse collectResponse(EngineResultStream stream, @Nullable AnalyticsShardTask task) { - // Serialize incoming Arrow batches as an Arrow IPC stream: one schema header - // followed by one record-batch message per incoming batch. Arrow's own - // serializer handles every Arrow type — no per-type Java code path. 
- ByteArrayOutputStream baos = new ByteArrayOutputStream(); - WriteChannel channel = new WriteChannel(Channels.newChannel(baos)); - Schema schema = null; - int totalRows = 0; - Iterator it = stream.iterator(); + /** + * QTF fetch phase: retrieves specific rows by global row ID. + */ + public org.opensearch.analytics.exec.action.FetchByRowIdsResponse executeFetchByRowIds( + org.opensearch.analytics.exec.action.FetchByRowIdsRequest request, + IndexShard shard, + AnalyticsShardTask task + ) { try { - while (it.hasNext()) { - if (task != null && task.isCancelled()) { - throw new TaskCancelledException("task cancelled: " + task.getReasonCancelled()); + IndexReaderProvider readerProvider = shard.getReaderProvider(); + if (readerProvider == null) { + throw new IllegalStateException("No ReaderProvider on " + shard.shardId()); + } + try (GatedCloseable gatedReader = readerProvider.acquireReader()) { + long[] rowIds = request.getRowIds(); + org.apache.arrow.vector.BigIntVector rowIdVector = new org.apache.arrow.vector.BigIntVector("__row_id__", allocator); + rowIdVector.allocateNew(rowIds.length); + for (int i = 0; i < rowIds.length; i++) { + rowIdVector.set(i, rowIds[i]); } - EngineResultBatch batch = it.next(); - VectorSchemaRoot root = batch.getArrowRoot(); - try { - if (schema == null) { - schema = root.getSchema(); - MessageSerializer.serialize(channel, schema); - } - try (ArrowRecordBatch recordBatch = new VectorUnloader(root).getRecordBatch()) { - MessageSerializer.serialize(channel, recordBatch); + rowIdVector.setValueCount(rowIds.length); + + AnalyticsSearchBackendPlugin backend = backends.values().iterator().next(); + EngineResultStream stream = backend.fetchByRowIds(gatedReader.get(), rowIdVector, request.getColumns(), allocator); + + // Serialize stream to Arrow IPC bytes + java.io.ByteArrayOutputStream baos = new java.io.ByteArrayOutputStream(); + org.apache.arrow.vector.ipc.WriteChannel channel = + new 
org.apache.arrow.vector.ipc.WriteChannel(java.nio.channels.Channels.newChannel(baos)); + org.apache.arrow.vector.types.pojo.Schema schema = null; + int totalRows = 0; + java.util.Iterator it = stream.iterator(); + while (it.hasNext()) { + org.opensearch.analytics.backend.EngineResultBatch batch = it.next(); + org.apache.arrow.vector.VectorSchemaRoot root = batch.getArrowRoot(); + try { + if (schema == null) { + schema = root.getSchema(); + org.apache.arrow.vector.ipc.message.MessageSerializer.serialize(channel, schema); + } + try (org.apache.arrow.vector.ipc.message.ArrowRecordBatch rb = + new org.apache.arrow.vector.VectorUnloader(root).getRecordBatch()) { + org.apache.arrow.vector.ipc.message.MessageSerializer.serialize(channel, rb); + } + totalRows += root.getRowCount(); + } finally { + root.close(); } - totalRows += root.getRowCount(); - } finally { - root.close(); } + if (schema != null) { + org.apache.arrow.vector.ipc.ArrowStreamWriter.writeEndOfStream( + channel, org.apache.arrow.vector.ipc.message.IpcOption.DEFAULT); + } + rowIdVector.close(); + return new org.opensearch.analytics.exec.action.FetchByRowIdsResponse(baos.toByteArray(), totalRows); } - if (schema != null) { - // Write the end-of-stream marker so the reader sees a clean EOS - // instead of hitting end-of-input mid-message. 
- ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT); - } - } catch (IOException e) { - throw new IllegalStateException("Failed to serialize fragment output as Arrow IPC stream", e); + } catch (Exception e) { + throw new RuntimeException("Failed to execute fetch-by-row-ids on " + shard.shardId(), e); } - return new FragmentExecutionResponse(baos.toByteArray(), totalRows); } } diff --git a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchTransportService.java b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchTransportService.java index e6204789593ed..05bd44e2bed3e 100644 --- a/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchTransportService.java +++ b/sandbox/plugins/analytics-engine/src/main/java/org/opensearch/analytics/exec/AnalyticsSearchTransportService.java @@ -15,27 +15,23 @@ import org.opensearch.analytics.exec.action.FragmentExecutionAction; import org.opensearch.analytics.exec.action.FragmentExecutionArrowResponse; import org.opensearch.analytics.exec.action.FragmentExecutionRequest; -import org.opensearch.analytics.exec.action.FragmentExecutionResponse; import org.opensearch.analytics.exec.task.AnalyticsShardTask; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.Nullable; import org.opensearch.common.inject.Inject; import org.opensearch.common.inject.Singleton; -import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.index.shard.IndexShard; import org.opensearch.indices.IndicesService; import org.opensearch.ratelimitting.admissioncontrol.enums.AdmissionControlActionType; import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.threadpool.ThreadPool; 
import org.opensearch.transport.StreamTransportService; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportResponseHandler; -import org.opensearch.transport.TransportService; import org.opensearch.transport.stream.StreamErrorCode; import org.opensearch.transport.stream.StreamException; import org.opensearch.transport.stream.StreamTransportResponse; @@ -44,67 +40,43 @@ import java.util.Iterator; /** - * Stateless transport dispatch component for fragment requests. Owns - * {@link TransportService} (or {@link StreamTransportService}) and + * Stateless transport dispatch component for fragment requests. Owns the + * {@link StreamTransportService} (analytics-engine is streaming-only) and * connection lookup. * - *

<p>Does NOT track per-query or per-node concurrency - * state — callers provide their own {@link PendingExecutions} instance - * to gate dispatch concurrency. + *

<p>Does NOT track per-query or per-node concurrency state — callers provide + * their own {@link PendingExecutions} instance to gate dispatch concurrency. * * @opensearch.internal */ @Singleton public class AnalyticsSearchTransportService { - private final TransportService transportService; + private final StreamTransportService transportService; private final ClusterService clusterService; - private final boolean streamingEnabled; @Inject public AnalyticsSearchTransportService( - TransportService transportService, - @Nullable StreamTransportService streamTransportService, + StreamTransportService streamTransportService, ClusterService clusterService, AnalyticsSearchService searchService, - IndicesService indicesService + IndicesService indicesService, + TaskResourceTrackingService taskResourceTrackingService ) { - this.streamingEnabled = streamTransportService != null; - this.transportService = this.streamingEnabled ? streamTransportService : transportService; - this.clusterService = clusterService; - if (this.streamingEnabled) { - registerStreamingFragmentHandler(this.transportService, searchService, indicesService); - } else { - registerFragmentHandler(this.transportService, searchService, indicesService); + if (streamTransportService == null) { + throw new IllegalStateException( + "analytics-engine requires the STREAM_TRANSPORT feature flag to be enabled " + + "(opensearch.experimental.feature.stream_transport.enabled=true)" + ); } + searchService.setTaskResourceTrackingService(taskResourceTrackingService); + this.transportService = streamTransportService; + this.clusterService = clusterService; + registerStreamingFragmentHandler(this.transportService, searchService, indicesService); registerFetchHandler(this.transportService, searchService, indicesService); } - public boolean isStreamingEnabled() { - return streamingEnabled; - } - - private static void registerFragmentHandler( - TransportService transportService, - AnalyticsSearchService searchService, -
IndicesService indicesService - ) { - transportService.registerRequestHandler( - FragmentExecutionAction.NAME, - ThreadPool.Names.SAME, - false, - true, - AdmissionControlActionType.SEARCH, - FragmentExecutionRequest::new, - (request, channel, task) -> { - IndexShard shard = indicesService.indexServiceSafe(request.getShardId().getIndex()).getShard(request.getShardId().id()); - FragmentExecutionResponse response = searchService.executeFragment(request, shard, (AnalyticsShardTask) task); - channel.sendResponse(response); - } - ); - } - private static void registerStreamingFragmentHandler( - TransportService transportService, + StreamTransportService transportService, AnalyticsSearchService searchService, IndicesService indicesService ) { @@ -136,214 +108,6 @@ private static void registerStreamingFragmentHandler( ); } - private static void registerFetchHandler( - TransportService transportService, - AnalyticsSearchService searchService, - IndicesService indicesService - ) { - if (transportService instanceof StreamTransportService) { - // Streaming path: send Arrow batches directly - transportService.registerRequestHandler( - FetchByRowIdsAction.NAME, - ThreadPool.Names.SAME, - false, - true, - AdmissionControlActionType.SEARCH, - FetchByRowIdsRequest::new, - (request, channel, task) -> { - IndexShard shard = indicesService.indexServiceSafe(request.getShardId().getIndex()).getShard(request.getShardId().id()); - try { - org.opensearch.analytics.backend.EngineResultStream stream = searchService.executeFetchStreaming( - request, - shard, - (AnalyticsShardTask) task - ); - Iterator it = stream.iterator(); - while (it.hasNext()) { - org.opensearch.analytics.backend.EngineResultBatch batch = it.next(); - channel.sendResponseBatch( - new org.opensearch.analytics.exec.action.FetchByRowIdsArrowResponse(batch.getArrowRoot()) - ); - } - channel.completeStream(); - } catch (StreamException e) { - if (e.getErrorCode() != StreamErrorCode.CANCELLED) { - channel.sendResponse(e); - } - 
} catch (Exception e) { - channel.sendResponse(e); - } - } - ); - } else { - // Non-streaming path: serialize to IPC bytes - transportService.registerRequestHandler( - FetchByRowIdsAction.NAME, - ThreadPool.Names.SAME, - false, - true, - AdmissionControlActionType.SEARCH, - FetchByRowIdsRequest::new, - (request, channel, task) -> { - IndexShard shard = indicesService.indexServiceSafe(request.getShardId().getIndex()).getShard(request.getShardId().id()); - FetchByRowIdsResponse response = searchService.executeFetchByRowIds(request, shard, (AnalyticsShardTask) task); - channel.sendResponse(response); - } - ); - } - } - - public void dispatchFetch( - FetchByRowIdsRequest request, - DiscoveryNode targetNode, - StreamingResponseListener listener, - Task parentTask - ) { - if (streamingEnabled) { - dispatchFetchStreaming(request, targetNode, listener, parentTask); - } else { - dispatchFetchNonStreaming(request, targetNode, listener, parentTask); - } - } - - private void dispatchFetchStreaming( - FetchByRowIdsRequest request, - DiscoveryNode targetNode, - StreamingResponseListener listener, - Task parentTask - ) { - try { - Transport.Connection connection = getConnection(null, targetNode.getId()); - transportService.sendChildRequest( - connection, - FetchByRowIdsAction.NAME, - request, - parentTask, - TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STREAM).build(), - new TransportResponseHandler() { - @Override - public org.opensearch.analytics.exec.action.FetchByRowIdsArrowResponse read(StreamInput in) throws IOException { - return new org.opensearch.analytics.exec.action.FetchByRowIdsArrowResponse(in); - } - - @Override - public boolean skipsDeserialization() { - return true; - } - - @Override - public String executor() { - return ThreadPool.Names.SAME; - } - - @Override - public void handleStreamResponse( - StreamTransportResponse stream - ) { - try { - org.opensearch.analytics.exec.action.FetchByRowIdsArrowResponse current; - 
org.opensearch.analytics.exec.action.FetchByRowIdsArrowResponse last = null; - while ((current = stream.nextResponse()) != null) { - if (last != null) { - listener.onStreamResponse(wrapArrowAsResponse(last), false); - } - last = current; - } - if (last != null) { - listener.onStreamResponse(wrapArrowAsResponse(last), true); - } - } catch (Exception e) { - listener.onFailure(e); - } finally { - try { - stream.close(); - } catch (Exception ignore) {} - } - } - - @Override - public void handleResponse(org.opensearch.analytics.exec.action.FetchByRowIdsArrowResponse response) { - listener.onStreamResponse(wrapArrowAsResponse(response), true); - } - - @Override - public void handleException(TransportException e) { - listener.onFailure(e); - } - } - ); - } catch (Exception e) { - listener.onFailure(e); - } - } - - private void dispatchFetchNonStreaming( - FetchByRowIdsRequest request, - DiscoveryNode targetNode, - StreamingResponseListener listener, - Task parentTask - ) { - try { - Transport.Connection connection = getConnection(null, targetNode.getId()); - transportService.sendChildRequest( - connection, - FetchByRowIdsAction.NAME, - request, - parentTask, - TransportRequestOptions.EMPTY, - new TransportResponseHandler() { - @Override - public FetchByRowIdsResponse read(StreamInput in) throws IOException { - return new FetchByRowIdsResponse(in); - } - - @Override - public String executor() { - return ThreadPool.Names.SAME; - } - - @Override - public void handleResponse(FetchByRowIdsResponse response) { - listener.onStreamResponse(response, true); - } - - @Override - public void handleException(TransportException e) { - listener.onFailure(e); - } - } - ); - } catch (Exception e) { - listener.onFailure(e); - } - } - - private static FetchByRowIdsResponse wrapArrowAsResponse(org.opensearch.analytics.exec.action.FetchByRowIdsArrowResponse arrowResp) { - // For the streaming path, wrap the Arrow batch as IPC bytes for uniform handling - // in 
LateMaterializationStageExecution's assembler. TODO: pass VectorSchemaRoot directly. - org.apache.arrow.vector.VectorSchemaRoot root = arrowResp.getRoot(); - if (root == null) return new FetchByRowIdsResponse(new byte[0], 0); - try { - java.io.ByteArrayOutputStream baos = new java.io.ByteArrayOutputStream(); - org.apache.arrow.vector.ipc.WriteChannel channel = new org.apache.arrow.vector.ipc.WriteChannel( - java.nio.channels.Channels.newChannel(baos) - ); - org.apache.arrow.vector.ipc.message.MessageSerializer.serialize(channel, root.getSchema()); - try ( - org.apache.arrow.vector.ipc.message.ArrowRecordBatch batch = new org.apache.arrow.vector.VectorUnloader(root) - .getRecordBatch() - ) { - org.apache.arrow.vector.ipc.message.MessageSerializer.serialize(channel, batch); - } - org.apache.arrow.vector.ipc.ArrowStreamWriter.writeEndOfStream(channel, org.apache.arrow.vector.ipc.message.IpcOption.DEFAULT); - return new FetchByRowIdsResponse(baos.toByteArray(), root.getRowCount()); - } catch (Exception e) { - throw new RuntimeException("Failed to serialize Arrow batch to IPC", e); - } finally { - root.close(); - } - } - Transport.Connection getConnection(String clusterAlias, String nodeId) { DiscoveryNode node = clusterService.state().nodes().get(nodeId); return transportService.getConnection(node); @@ -356,56 +120,15 @@ public void dispatchFragmentStreaming( Task parentTask, PendingExecutions pending ) { - dispatchFragment( - request, - targetNode, - listener, - parentTask, - pending, - in -> new FragmentExecutionArrowResponse(in), - TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STREAM).build(), - true - ); - } - - public void dispatchFragment( - FragmentExecutionRequest request, - DiscoveryNode targetNode, - StreamingResponseListener listener, - Task parentTask, - PendingExecutions pending - ) { - dispatchFragment( - request, - targetNode, - listener, - parentTask, - pending, - in -> new FragmentExecutionResponse(in), - 
TransportRequestOptions.EMPTY, - false - ); - } - - private void dispatchFragment( - FragmentExecutionRequest request, - DiscoveryNode targetNode, - StreamingResponseListener listener, - Task parentTask, - PendingExecutions pending, - Writeable.Reader reader, - TransportRequestOptions options, - boolean skipsDeserialization - ) { - TransportResponseHandler handler = new TransportResponseHandler<>() { + TransportResponseHandler handler = new TransportResponseHandler<>() { @Override - public T read(StreamInput in) throws IOException { - return reader.read(in); + public FragmentExecutionArrowResponse read(StreamInput in) throws IOException { + return new FragmentExecutionArrowResponse(in); } @Override public boolean skipsDeserialization() { - return skipsDeserialization; + return true; } @Override @@ -414,10 +137,10 @@ public String executor() { } @Override - public void handleStreamResponse(StreamTransportResponse stream) { + public void handleStreamResponse(StreamTransportResponse stream) { try { - T current; - T last = null; + FragmentExecutionArrowResponse current; + FragmentExecutionArrowResponse last = null; while ((current = stream.nextResponse()) != null) { if (last != null) { listener.onStreamResponse(last, false); @@ -438,9 +161,12 @@ public void handleStreamResponse(StreamTransportResponse stream) { } @Override - public void handleResponse(T response) { - listener.onStreamResponse(response, true); - pending.finishAndRunNext(); + public void handleResponse(FragmentExecutionArrowResponse response) { + try { + listener.onStreamResponse(response, true); + } finally { + pending.finishAndRunNext(); + } } @Override @@ -453,6 +179,7 @@ public void handleException(TransportException e) { } }; + TransportRequestOptions options = TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STREAM).build(); pending.tryRun(() -> { try { Transport.Connection connection = getConnection(null, targetNode.getId()); @@ -466,4 +193,67 @@ public void 
handleException(TransportException e) { } }); } + + // ── QTF Fetch ──────────────────────────────────────────────────────────────── + + private static void registerFetchHandler( + StreamTransportService transportService, + AnalyticsSearchService searchService, + IndicesService indicesService + ) { + transportService.registerRequestHandler( + FetchByRowIdsAction.NAME, + ThreadPool.Names.SAME, + false, + true, + AdmissionControlActionType.SEARCH, + FetchByRowIdsRequest::new, + (request, channel, task) -> { + IndexShard shard = indicesService.indexServiceSafe(request.getShardId().getIndex()).getShard(request.getShardId().id()); + FetchByRowIdsResponse response = searchService.executeFetchByRowIds(request, shard, (AnalyticsShardTask) task); + channel.sendResponse(response); + } + ); + } + + public void dispatchFetch( + FetchByRowIdsRequest request, + DiscoveryNode targetNode, + StreamingResponseListener listener, + Task parentTask + ) { + try { + Transport.Connection connection = getConnection(null, targetNode.getId()); + transportService.sendChildRequest( + connection, + FetchByRowIdsAction.NAME, + request, + parentTask, + TransportRequestOptions.EMPTY, + new TransportResponseHandler() { + @Override + public FetchByRowIdsResponse read(StreamInput in) throws IOException { + return new FetchByRowIdsResponse(in); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + + @Override + public void handleResponse(FetchByRowIdsResponse response) { + listener.onStreamResponse(response, true); + } + + @Override + public void handleException(TransportException e) { + listener.onFailure(e); + } + } + ); + } catch (Exception e) { + listener.onFailure(e); + } + } }