From ac5994ef8194a15c38f26db1337da5ae41248588 Mon Sep 17 00:00:00 2001 From: Songkan Tang Date: Mon, 19 Jan 2026 15:16:43 +0800 Subject: [PATCH 1/9] General UDAF pushdown as scripts Signed-off-by: Songkan Tang --- .../patterns/PatternAggregationHelpers.java | 493 ++++++++++++++++++ .../sql/calcite/CalciteRelNodeVisitor.java | 128 +++-- .../udf/udaf/LogPatternAggFunction.java | 22 +- .../utils/UserDefinedFunctionUtils.java | 45 +- .../function/BuiltinFunctionName.java | 5 + .../function/PPLBuiltinOperators.java | 48 ++ .../expression/function/PPLFuncImpTable.java | 9 + .../function/PatternParserFunctionImpl.java | 87 +++- .../calcite/remote/CalcitePPLPatternsIT.java | 183 +++++++ .../org/opensearch/sql/ppl/ExplainIT.java | 1 - .../explain_patterns_brain_agg_push.yaml | 8 +- .../value/OpenSearchExprValueFactory.java | 8 + .../opensearch/request/AggregateAnalyzer.java | 12 + .../request/PatternScriptedMetricUDAF.java | 147 ++++++ .../request/ScriptedMetricUDAF.java | 235 +++++++++ .../request/ScriptedMetricUDAFRegistry.java | 84 +++ .../response/agg/ScriptedMetricParser.java | 53 ++ .../storage/script/CalciteScriptEngine.java | 25 +- ...iteScriptedMetricCombineScriptFactory.java | 63 +++ ...alciteScriptedMetricInitScriptFactory.java | 71 +++ ...CalciteScriptedMetricMapScriptFactory.java | 100 ++++ ...citeScriptedMetricReduceScriptFactory.java | 63 +++ .../ScriptedMetricDataContext.java | 181 +++++++ .../storage/serde/ScriptParameterHelper.java | 14 + .../ppl/calcite/CalcitePPLPatternsTest.java | 41 +- 25 files changed, 2033 insertions(+), 93 deletions(-) create mode 100644 common/src/main/java/org/opensearch/sql/common/patterns/PatternAggregationHelpers.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/request/PatternScriptedMetricUDAF.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAF.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAFRegistry.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/ScriptedMetricParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricCombineScriptFactory.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricInitScriptFactory.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricMapScriptFactory.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricReduceScriptFactory.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricDataContext.java diff --git a/common/src/main/java/org/opensearch/sql/common/patterns/PatternAggregationHelpers.java b/common/src/main/java/org/opensearch/sql/common/patterns/PatternAggregationHelpers.java new file mode 100644 index 00000000000..fae4b61c1eb --- /dev/null +++ b/common/src/main/java/org/opensearch/sql/common/patterns/PatternAggregationHelpers.java @@ -0,0 +1,493 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.common.patterns; + +import com.google.common.collect.ImmutableMap; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import 
java.util.stream.Collectors; + +/** + * Static helper methods for pattern aggregation operations. These methods wrap the complex logic in + * BrainLogParser and PatternUtils to be callable from UDFs in scripted metric aggregations. + */ +public final class PatternAggregationHelpers { + + private PatternAggregationHelpers() { + // Utility class + } + + /** + * Initialize pattern accumulator state. + * + * @return Empty accumulator map with logMessages buffer and patternGroupMap + */ + public static Map initPatternAccumulator() { + Map acc = new HashMap<>(); + acc.put("logMessages", new ArrayList()); + acc.put("patternGroupMap", new HashMap>()); + return acc; + } + + /** + * Initialize pattern accumulator state in-place. This method is designed for OpenSearch scripted + * metric aggregation's init_script phase, where the state map is provided by OpenSearch and must + * be modified in-place rather than replaced. + * + * @param state The mutable state map provided by OpenSearch (will be modified in-place) + * @return The same state map (for chaining/return value) + */ + @SuppressWarnings("unchecked") + public static Map initPatternState(Object state) { + Map stateMap = (Map) state; + stateMap.put("logMessages", new ArrayList()); + stateMap.put("patternGroupMap", new HashMap>()); + return stateMap; + } + + /** + * Add a log message to the accumulator (overload for Object acc and int thresholdPercentage). + * This overload handles the case when the accumulator is passed as a generic Object and + * thresholdPercentage is passed as an integer at runtime (from the script engine). + * + * @param acc Current accumulator state (as Object, will be cast to Map) + * @param logMessage The log message to process + * @param maxSampleCount Maximum samples to keep per pattern + * @param bufferLimit Maximum buffer size before triggering partial merge + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as int) + * @return Updated accumulator + */ + @SuppressWarnings("unchecked") + public static Map addLogToPattern( + Object acc, + String logMessage, + int maxSampleCount, + int bufferLimit, + int variableCountThreshold, + int thresholdPercentage) { + return addLogToPattern( + (Map) acc, + logMessage, + maxSampleCount, + bufferLimit, + variableCountThreshold, + (double) thresholdPercentage); + } + + /** + * Add a log message to the accumulator (overload for Object acc and BigDecimal + * thresholdPercentage). This overload handles the case when the accumulator is passed as a + * generic Object and thresholdPercentage is passed as BigDecimal at runtime. + * + * @param acc Current accumulator state (as Object, will be cast to Map) + * @param logMessage The log message to process + * @param maxSampleCount Maximum samples to keep per pattern + * @param bufferLimit Maximum buffer size before triggering partial merge + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as BigDecimal) + * @return Updated accumulator + */ + @SuppressWarnings("unchecked") + public static Map addLogToPattern( + Object acc, + String logMessage, + int maxSampleCount, + int bufferLimit, + int variableCountThreshold, + java.math.BigDecimal thresholdPercentage) { + return addLogToPattern( + (Map) acc, + logMessage, + maxSampleCount, + bufferLimit, + variableCountThreshold, + thresholdPercentage != null + ? 
thresholdPercentage.doubleValue() + : BrainLogParser.DEFAULT_FREQUENCY_THRESHOLD_PERCENTAGE); + } + + /** + * Add a log message to the accumulator (overload for Object acc and double thresholdPercentage). + * This overload handles the case when the accumulator is passed as a generic Object and + * thresholdPercentage is passed as double at runtime. + * + * @param acc Current accumulator state (as Object, will be cast to Map) + * @param logMessage The log message to process + * @param maxSampleCount Maximum samples to keep per pattern + * @param bufferLimit Maximum buffer size before triggering partial merge + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as double) + * @return Updated accumulator + */ + @SuppressWarnings("unchecked") + public static Map addLogToPattern( + Object acc, + String logMessage, + int maxSampleCount, + int bufferLimit, + int variableCountThreshold, + double thresholdPercentage) { + return addLogToPattern( + (Map) acc, + logMessage, + maxSampleCount, + bufferLimit, + variableCountThreshold, + thresholdPercentage); + } + + /** + * Add a log message to the accumulator (overload for int thresholdPercentage). This overload + * handles the case when thresholdPercentage is passed as an integer at runtime. + * + * @param acc Current accumulator state + * @param logMessage The log message to process + * @param maxSampleCount Maximum samples to keep per pattern + * @param bufferLimit Maximum buffer size before triggering partial merge + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as int) + * @return Updated accumulator + */ + public static Map addLogToPattern( + Map acc, + String logMessage, + int maxSampleCount, + int bufferLimit, + int variableCountThreshold, + int thresholdPercentage) { + return addLogToPattern( + acc, + logMessage, + maxSampleCount, + bufferLimit, + variableCountThreshold, + (double) thresholdPercentage); + } + + /** + * Add a log message to the accumulator (overload for BigDecimal thresholdPercentage). This + * overload handles the case when thresholdPercentage is passed as BigDecimal at runtime. + * + * @param acc Current accumulator state + * @param logMessage The log message to process + * @param maxSampleCount Maximum samples to keep per pattern + * @param bufferLimit Maximum buffer size before triggering partial merge + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as BigDecimal) + * @return Updated accumulator + */ + public static Map addLogToPattern( + Map acc, + String logMessage, + int maxSampleCount, + int bufferLimit, + int variableCountThreshold, + java.math.BigDecimal thresholdPercentage) { + return addLogToPattern( + acc, + logMessage, + maxSampleCount, + bufferLimit, + variableCountThreshold, + thresholdPercentage != null + ? thresholdPercentage.doubleValue() + : BrainLogParser.DEFAULT_FREQUENCY_THRESHOLD_PERCENTAGE); + } + + /** + * Add a log message to the accumulator and trigger partial merge if buffer is full. 
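+ *
+ * <p>Illustrative sketch only (the parameter values 10, 100000, 5, 0.3 are taken from the
+ * explain-plan fixture in this patch; the loop and {@code shardDocs} source are hypothetical):
+ * a map_script-style caller buffers one message at a time against an accumulator:
+ *
+ * <pre>{@code
+ * Map<String, Object> acc = PatternAggregationHelpers.initPatternAccumulator();
+ * for (String msg : shardDocs) { // shardDocs: hypothetical per-shard message source
+ *   acc = PatternAggregationHelpers.addLogToPattern(acc, msg, 10, 100000, 5, 0.3);
+ * }
+ * }</pre>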
+ * + * @param acc Current accumulator state + * @param logMessage The log message to process + * @param maxSampleCount Maximum samples to keep per pattern + * @param bufferLimit Maximum buffer size before triggering partial merge + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage + * @return Updated accumulator + */ + @SuppressWarnings("unchecked") + public static Map addLogToPattern( + Map acc, + String logMessage, + int maxSampleCount, + int bufferLimit, + int variableCountThreshold, + double thresholdPercentage) { + + if (logMessage == null) { + return acc; + } + + List logMessages = (List) acc.get("logMessages"); + logMessages.add(logMessage); + + // Trigger partial merge when buffer reaches limit + if (bufferLimit > 0 && logMessages.size() >= bufferLimit) { + Map> patternGroupMap = + (Map>) acc.get("patternGroupMap"); + + BrainLogParser parser = + new BrainLogParser(variableCountThreshold, (float) thresholdPercentage); + Map> partialPatterns = + parser.parseAllLogPatterns(logMessages, maxSampleCount); + + patternGroupMap = + PatternUtils.mergePatternGroups(patternGroupMap, partialPatterns, maxSampleCount); + + acc.put("patternGroupMap", patternGroupMap); + logMessages.clear(); + } + + return acc; + } + + /** + * Combine two accumulators (for combine_script phase). + * + * @param acc1 First accumulator + * @param acc2 Second accumulator + * @param maxSampleCount Maximum samples to keep per pattern + * @return Merged accumulator + */ + @SuppressWarnings("unchecked") + public static Map combinePatternAccumulators( + Map acc1, Map acc2, int maxSampleCount) { + + Map> patterns1 = + (Map>) acc1.get("patternGroupMap"); + Map> patterns2 = + (Map>) acc2.get("patternGroupMap"); + + Map> merged = + PatternUtils.mergePatternGroups(patterns1, patterns2, maxSampleCount); + + Map result = new HashMap<>(); + result.put("logMessages", new ArrayList<>()); + result.put("patternGroupMap", merged); + return result; + } + + /** + * Produce final pattern result (for reduce_script phase). 
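+ *
+ * <p>Hedged example (illustration, not part of this change): flushing a previously built
+ * accumulator into the final sorted pattern list, reusing the same Brain parameters as above:
+ *
+ * <pre>{@code
+ * List<Map<String, Object>> patterns =
+ *     PatternAggregationHelpers.producePatternResult(acc, 10, 5, 0.3, false);
+ * // each entry holds pattern, pattern_count, sample_logs; tokens are computed downstream
+ * }</pre>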
+ * + * @param acc Accumulator state + * @param maxSampleCount Maximum samples per pattern + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage + * @param showNumberedToken Whether to show numbered tokens in output + * @return List of pattern result objects sorted by count + */ + @SuppressWarnings("unchecked") + public static List> producePatternResult( + Map acc, + int maxSampleCount, + int variableCountThreshold, + double thresholdPercentage, + boolean showNumberedToken) { + + // Process any remaining logs in buffer + List logMessages = (List) acc.get("logMessages"); + Map> patternGroupMap = + (Map>) acc.get("patternGroupMap"); + + if (logMessages != null && !logMessages.isEmpty()) { + BrainLogParser parser = + new BrainLogParser(variableCountThreshold, (float) thresholdPercentage); + Map> partialPatterns = + parser.parseAllLogPatterns(logMessages, maxSampleCount); + patternGroupMap = + PatternUtils.mergePatternGroups(patternGroupMap, partialPatterns, maxSampleCount); + } + + // Format and sort final output by pattern count + return patternGroupMap.values().stream() + .sorted( + Comparator.comparing( + m -> (Long) m.get(PatternUtils.PATTERN_COUNT), + Comparator.nullsLast(Comparator.reverseOrder()))) + .map(m -> formatPatternOutput(m, showNumberedToken)) + .collect(Collectors.toList()); + } + + /** + * Produce final pattern result from states array (overload for int thresholdPercentage). This + * overload handles the case when thresholdPercentage is passed as an integer at runtime. + * + * @param states List of shard-level accumulator states + * @param maxSampleCount Maximum samples per pattern + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as int) + * @param showNumberedToken Whether to show numbered tokens in output + * @return List of pattern result objects sorted by count + */ + public static List> producePatternResultFromStates( + List states, + int maxSampleCount, + int variableCountThreshold, + int thresholdPercentage, + boolean showNumberedToken) { + return producePatternResultFromStates( + states, + maxSampleCount, + variableCountThreshold, + (double) thresholdPercentage, + showNumberedToken); + } + + /** + * Produce final pattern result from states array (overload for Object states). This overload + * handles the case when states is passed as a generic Object at runtime due to type erasure. + * + * @param states List of shard-level accumulator states (as Object, will be cast to List) + * @param maxSampleCount Maximum samples per pattern + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage + * @param showNumberedToken Whether to show numbered tokens in output + * @return List of pattern result objects sorted by count + */ + @SuppressWarnings("unchecked") + public static List> producePatternResultFromStates( + Object states, + int maxSampleCount, + int variableCountThreshold, + double thresholdPercentage, + boolean showNumberedToken) { + return producePatternResultFromStates( + (List) states, + maxSampleCount, + variableCountThreshold, + thresholdPercentage, + showNumberedToken); + } + + /** + * Produce final pattern result from states array (overload for Object states with BigDecimal). + * This overload handles the case when states is Object and thresholdPercentage is BigDecimal. 
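+ *
+ * <p>Sketch under stated assumptions ({@code s1} and {@code s2} are hypothetical shard-level
+ * states produced by the map/combine phases):
+ *
+ * <pre>{@code
+ * Object states = java.util.List.of(s1, s2);
+ * List<Map<String, Object>> result =
+ *     PatternAggregationHelpers.producePatternResultFromStates(
+ *         states, 10, 5, new java.math.BigDecimal("0.3"), true);
+ * }</pre>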
+ * + * @param states List of shard-level accumulator states (as Object, will be cast to List) + * @param maxSampleCount Maximum samples per pattern + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as BigDecimal) + * @param showNumberedToken Whether to show numbered tokens in output + * @return List of pattern result objects sorted by count + */ + @SuppressWarnings("unchecked") + public static List> producePatternResultFromStates( + Object states, + int maxSampleCount, + int variableCountThreshold, + java.math.BigDecimal thresholdPercentage, + boolean showNumberedToken) { + return producePatternResultFromStates( + (List) states, + maxSampleCount, + variableCountThreshold, + thresholdPercentage != null + ? thresholdPercentage.doubleValue() + : BrainLogParser.DEFAULT_FREQUENCY_THRESHOLD_PERCENTAGE, + showNumberedToken); + } + + /** + * Produce final pattern result from states array (overload for BigDecimal thresholdPercentage). + * This overload handles the case when thresholdPercentage is passed as BigDecimal at runtime. + * + * @param states List of shard-level accumulator states + * @param maxSampleCount Maximum samples per pattern + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage (as BigDecimal) + * @param showNumberedToken Whether to show numbered tokens in output + * @return List of pattern result objects sorted by count + */ + public static List> producePatternResultFromStates( + List states, + int maxSampleCount, + int variableCountThreshold, + java.math.BigDecimal thresholdPercentage, + boolean showNumberedToken) { + return producePatternResultFromStates( + states, + maxSampleCount, + variableCountThreshold, + thresholdPercentage != null + ? thresholdPercentage.doubleValue() + : BrainLogParser.DEFAULT_FREQUENCY_THRESHOLD_PERCENTAGE, + showNumberedToken); + } + + /** + * Produce final pattern result from states array (for reduce_script phase). This method combines + * all shard-level states and produces the final aggregated result. + * + * @param states List of shard-level accumulator states + * @param maxSampleCount Maximum samples per pattern + * @param variableCountThreshold Brain parser variable count threshold + * @param thresholdPercentage Brain parser frequency threshold percentage + * @param showNumberedToken Whether to show numbered tokens in output + * @return List of pattern result objects sorted by count + */ + @SuppressWarnings("unchecked") + public static List> producePatternResultFromStates( + List states, + int maxSampleCount, + int variableCountThreshold, + double thresholdPercentage, + boolean showNumberedToken) { + + if (states == null || states.isEmpty()) { + return new ArrayList<>(); + } + + // Combine all states into a single accumulator + Map combined = (Map) states.get(0); + for (int i = 1; i < states.size(); i++) { + Map state = (Map) states.get(i); + combined = combinePatternAccumulators(combined, state, maxSampleCount); + } + + // Produce final result from combined state + return producePatternResult( + combined, maxSampleCount, variableCountThreshold, thresholdPercentage, showNumberedToken); + } + + /** + * Format a single pattern result for output. + * + *

Note: Token extraction is NOT done here. The pattern is returned with wildcards (e.g., + * {@code <*>}) and token extraction is performed later by {@code + * PatternParserFunctionImpl.evalAggSamples()} after the data returns from OpenSearch. This + * approach avoids the XContent serialization issue where nested {@code Map>} + * structures are not properly serialized. + * + * @param patternInfo Pattern information map + * @param showNumberedToken Whether numbered tokens should be shown (determines output format) + * @return Formatted pattern output with pattern, count, and sample_logs + */ + @SuppressWarnings("unchecked") + private static Map formatPatternOutput( + Map patternInfo, boolean showNumberedToken) { + + String pattern = (String) patternInfo.get(PatternUtils.PATTERN); + Long count = (Long) patternInfo.get(PatternUtils.PATTERN_COUNT); + List sampleLogs = (List) patternInfo.get(PatternUtils.SAMPLE_LOGS); + + // For UDAF pushdown, we don't compute tokens here. + // Tokens will be computed by PatternParserFunctionImpl.evalAggSamples() after data returns + // from OpenSearch. This avoids XContent serialization issues with nested Map structures. + // The showNumberedToken flag is passed through to indicate the expected output format. + return ImmutableMap.of( + PatternUtils.PATTERN, + pattern, + PatternUtils.PATTERN_COUNT, + count, + PatternUtils.SAMPLE_LOGS, + sampleLogs); + } +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index f1bc5fd6a0d..4b6e5860e47 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -3267,68 +3267,114 @@ private void flattenParsedPattern( Boolean showNumberedToken) { List fattenedNodes = new ArrayList<>(); List projectNames = new ArrayList<>(); - // Flatten map struct fields + + // For aggregation mode with numbered tokens, we need to compute tokens locally + // using evalAggSamples. The UDAF returns pattern with wildcards and sample_logs, + // but NOT tokens (to avoid XContent serialization issues with nested Maps). + RexNode parsedPatternResult = null; + if (flattenPatternAggResult && showNumberedToken) { + // Extract pattern string (with wildcards) from UDAF result + RexNode patternStr = + PPLFuncImpTable.INSTANCE.resolve( + context.rexBuilder, + BuiltinFunctionName.INTERNAL_ITEM, + parsedNode, + context.rexBuilder.makeLiteral(PatternUtils.PATTERN)); + // Extract sample_logs from UDAF result + RexNode sampleLogs = + PPLFuncImpTable.INSTANCE.resolve( + context.rexBuilder, + BuiltinFunctionName.INTERNAL_ITEM, + explicitMapType(context, parsedNode, SqlTypeName.VARCHAR), + context.rexBuilder.makeLiteral(PatternUtils.SAMPLE_LOGS)); + RexNode showNumberedTokenLiteral = context.rexBuilder.makeLiteral(true); + + // Call evalAggSamples to transform pattern (wildcards -> numbered tokens) and compute tokens + parsedPatternResult = + PPLFuncImpTable.INSTANCE.resolve( + context.rexBuilder, + BuiltinFunctionName.INTERNAL_PATTERN_PARSER, + patternStr, + sampleLogs, + showNumberedTokenLiteral); + } + + // Flatten map struct fields - pattern + RelDataType varcharType = + context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR); + RexNode patternSource = parsedPatternResult != null ? 
parsedPatternResult : parsedNode; RexNode patternExpr = - context.rexBuilder.makeCast( - context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR), - PPLFuncImpTable.INSTANCE.resolve( - context.rexBuilder, - BuiltinFunctionName.INTERNAL_ITEM, - parsedNode, - context.rexBuilder.makeLiteral(PatternUtils.PATTERN)), - true, - true); + extractAndCastMapField(context, patternSource, PatternUtils.PATTERN, varcharType); fattenedNodes.add(context.relBuilder.alias(patternExpr, originalPatternResultAlias)); projectNames.add(originalPatternResultAlias); + if (flattenPatternAggResult) { + RelDataType bigintType = + context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BIGINT); RexNode patternCountExpr = - context.rexBuilder.makeCast( - context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BIGINT), - PPLFuncImpTable.INSTANCE.resolve( - context.rexBuilder, - BuiltinFunctionName.INTERNAL_ITEM, - parsedNode, - context.rexBuilder.makeLiteral(PatternUtils.PATTERN_COUNT)), - true, - true); + extractAndCastMapField(context, parsedNode, PatternUtils.PATTERN_COUNT, bigintType); fattenedNodes.add(context.relBuilder.alias(patternCountExpr, PatternUtils.PATTERN_COUNT)); projectNames.add(PatternUtils.PATTERN_COUNT); } + if (showNumberedToken) { + // Create MAP> type for tokens + RelDataType tokensType = + context + .rexBuilder + .getTypeFactory() + .createMapType( + varcharType, + context.rexBuilder.getTypeFactory().createArrayType(varcharType, -1)); + RexNode tokensSource = parsedPatternResult != null ? parsedPatternResult : parsedNode; RexNode tokensExpr = - context.rexBuilder.makeCast( - UserDefinedFunctionUtils.tokensMap, - PPLFuncImpTable.INSTANCE.resolve( - context.rexBuilder, - BuiltinFunctionName.INTERNAL_ITEM, - parsedNode, - context.rexBuilder.makeLiteral(PatternUtils.TOKENS)), - true, - true); + extractAndCastMapField(context, tokensSource, PatternUtils.TOKENS, tokensType); fattenedNodes.add(context.relBuilder.alias(tokensExpr, PatternUtils.TOKENS)); projectNames.add(PatternUtils.TOKENS); } + if (flattenPatternAggResult) { + RelDataType sampleLogsArrayType = + context + .rexBuilder + .getTypeFactory() + .createArrayType( + context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR), -1); RexNode sampleLogsExpr = - context.rexBuilder.makeCast( - context - .rexBuilder - .getTypeFactory() - .createArrayType( - context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR), -1), - PPLFuncImpTable.INSTANCE.resolve( - context.rexBuilder, - BuiltinFunctionName.INTERNAL_ITEM, - explicitMapType(context, parsedNode, SqlTypeName.VARCHAR), - context.rexBuilder.makeLiteral(PatternUtils.SAMPLE_LOGS)), - true, - true); + extractAndCastMapField( + context, + explicitMapType(context, parsedNode, SqlTypeName.VARCHAR), + PatternUtils.SAMPLE_LOGS, + sampleLogsArrayType); fattenedNodes.add(context.relBuilder.alias(sampleLogsExpr, PatternUtils.SAMPLE_LOGS)); projectNames.add(PatternUtils.SAMPLE_LOGS); } projectPlusOverriding(fattenedNodes, projectNames, context); } + /** + * Helper method to extract a field from a map and cast it to the specified type. Creates a + * SAFE_CAST (makeCast with safe=true) around an INTERNAL_ITEM call. 
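+ *
+ * <p>For illustration, the pattern_count extraction in flattenParsedPattern reduces to:
+ *
+ * <pre>{@code
+ * RexNode patternCountExpr =
+ *     extractAndCastMapField(context, parsedNode, PatternUtils.PATTERN_COUNT, bigintType);
+ * // roughly SAFE_CAST(ITEM(parsedNode, 'pattern_count')):BIGINT, matching the expected
+ * // plan in explain_patterns_brain_agg_push.yaml
+ * }</pre>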
+ * + * @param context The Calcite plan context + * @param source The source RexNode containing the map + * @param fieldName The name of the field to extract from the map + * @param targetType The target type to cast to + * @return A RexNode representing SAFE_CAST(INTERNAL_ITEM(source, fieldName)) + */ + private RexNode extractAndCastMapField( + CalcitePlanContext context, RexNode source, String fieldName, RelDataType targetType) { + return context.rexBuilder.makeCast( + targetType, + PPLFuncImpTable.INSTANCE.resolve( + context.rexBuilder, + BuiltinFunctionName.INTERNAL_ITEM, + source, + context.rexBuilder.makeLiteral(fieldName)), + true, + true); + } + private void buildExpandRelNode( RexInputRef arrayFieldRex, String arrayFieldName, String alias, CalcitePlanContext context) { // 3. Capture the outer row in a CorrelationId diff --git a/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java index f93a0e7c49d..2808b4face5 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java +++ b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java @@ -20,7 +20,6 @@ import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.patterns.BrainLogParser; import org.opensearch.sql.common.patterns.PatternUtils; -import org.opensearch.sql.common.patterns.PatternUtils.ParseResult; public class LogPatternAggFunction implements UserDefinedAggFunction { private int bufferLimit = 100000; @@ -190,7 +189,6 @@ public Object value(Object... argList) { partialMerge(argList); clearBuffer(); - Boolean showToken = (Boolean) argList[3]; return patternGroupMap.values().stream() .sorted( Comparator.comparing( @@ -201,24 +199,18 @@ public Object value(Object... argList) { String pattern = (String) m.get(PatternUtils.PATTERN); Long count = (Long) m.get(PatternUtils.PATTERN_COUNT); List sampleLogs = (List) m.get(PatternUtils.SAMPLE_LOGS); - Map> tokensMap = new HashMap<>(); - ParseResult parseResult = null; - if (showToken) { - parseResult = PatternUtils.parsePattern(pattern, PatternUtils.WILDCARD_PATTERN); - for (String sampleLog : sampleLogs) { - PatternUtils.extractVariables( - parseResult, sampleLog, tokensMap, PatternUtils.WILDCARD_PREFIX); - } - } + // For aggregation mode, always return pattern with wildcards (<*>, <*IP*>). + // The transformation to numbered tokens (, ) and token + // extraction is done downstream by evalAggSamples in flattenParsedPattern. + // This ensures consistent behavior between UDAF pushdown and regular + // aggregation paths. return ImmutableMap.of( PatternUtils.PATTERN, - showToken - ? parseResult.toTokenOrderString(PatternUtils.WILDCARD_PREFIX) - : pattern, + pattern, // Always return original wildcard format PatternUtils.PATTERN_COUNT, count, PatternUtils.TOKENS, - showToken ? 
tokensMap : Collections.EMPTY_MAP, + Collections.EMPTY_MAP, // Tokens computed downstream by evalAggSamples PatternUtils.SAMPLE_LOGS, sampleLogs); }) diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java index f619d966cc8..ff9bc23bc54 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java @@ -11,6 +11,7 @@ import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.*; import com.google.common.collect.ImmutableSet; +import java.lang.reflect.Type; import java.time.Instant; import java.time.ZoneId; import java.time.ZoneOffset; @@ -209,7 +210,7 @@ public static List convertToExprValues( * @return an adapted ImplementorUDF with the expr method, which is a UserDefinedFunctionBuilder */ public static ImplementorUDF adaptExprMethodToUDF( - java.lang.reflect.Type type, + Type type, String methodName, SqlReturnTypeInference returnTypeInference, NullPolicy nullPolicy, @@ -240,7 +241,7 @@ public UDFOperandMetadata getOperandMetadata() { * FunctionProperties} at the beginning to a Calcite-compatible UserDefinedFunctionBuilder. */ public static ImplementorUDF adaptExprMethodWithPropertiesToUDF( - java.lang.reflect.Type type, + Type type, String methodName, SqlReturnTypeInference returnTypeInference, NullPolicy nullPolicy, @@ -317,4 +318,44 @@ public static List prependFunctionProperties( operandsWithProperties.addFirst(properties); return Collections.unmodifiableList(operandsWithProperties); } + + /** + * Adapt a static method from any class to a UserDefinedFunctionBuilder. This is a general-purpose + * adapter that can wrap static helper methods (e.g., PatternAggregationHelpers methods) as UDFs + * for use in scripted metrics. 
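+ *
+ * <p>Usage sketch, mirroring how PPLBuiltinOperators registers PATTERN_INIT_UDF in this patch:
+ *
+ * <pre>{@code
+ * SqlOperator op =
+ *     UserDefinedFunctionUtils.adaptStaticMethodToUDF(
+ *             PatternAggregationHelpers.class,
+ *             "initPatternState",
+ *             ReturnTypes.explicit(SqlTypeName.ANY),
+ *             NullPolicy.ANY,
+ *             null)
+ *         .toUDF("PATTERN_INIT_UDF");
+ * }</pre>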
+ * + * @param type the class containing the static method + * @param methodName the name of the static method to be invoked + * @param returnTypeInference the return type inference of the UDF + * @param nullPolicy the null policy of the UDF + * @param operandMetadata type checker for operands + * @return an adapted ImplementorUDF wrapping the static method + */ + public static ImplementorUDF adaptStaticMethodToUDF( + Type type, + String methodName, + SqlReturnTypeInference returnTypeInference, + NullPolicy nullPolicy, + @Nullable UDFOperandMetadata operandMetadata) { + + NotNullImplementor implementor = + (translator, call, translatedOperands) -> { + // For static methods that work with generic objects (Map, List, etc.), + // we don't need type conversion like adaptMathFunctionToUDF + // Just pass the operands directly to the static method + return Expressions.call(type, methodName, translatedOperands); + }; + + return new ImplementorUDF(implementor, nullPolicy) { + @Override + public SqlReturnTypeInference getReturnTypeInference() { + return returnTypeInference; + } + + @Override + public UDFOperandMetadata getOperandMetadata() { + return operandMetadata; + } + }; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 50f88d47baf..6ee6a229d74 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -346,6 +346,11 @@ public enum BuiltinFunctionName { INTERNAL_PATTERN_PARSER(FunctionName.of("pattern_parser")), INTERNAL_PATTERN(FunctionName.of("pattern")), INTERNAL_UNCOLLECT_PATTERNS(FunctionName.of("uncollect_patterns")), + // Pattern aggregation UDFs for scripted metric pushdown + PATTERN_INIT_UDF(FunctionName.of("pattern_init_udf"), true), + PATTERN_ADD_UDF(FunctionName.of("pattern_add_udf"), true), + PATTERN_COMBINE_UDF(FunctionName.of("pattern_combine_udf"), true), + PATTERN_RESULT_UDF(FunctionName.of("pattern_result_udf"), true), INTERNAL_GROK(FunctionName.of("grok"), true), INTERNAL_PARSE(FunctionName.of("parse"), true), INTERNAL_REGEXP_REPLACE_PG_4(FunctionName.of("regexp_replace_pg_4"), true), diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java index 3810352cbfd..29c04625bea 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java @@ -21,12 +21,16 @@ import org.apache.calcite.adapter.enumerable.RexToLixTranslator; import org.apache.calcite.avatica.util.TimeUnit; import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexCall; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeTransforms; +import org.apache.calcite.sql.type.SqlTypeUtil; import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable; import org.apache.calcite.util.BuiltInMethod; import org.opensearch.sql.calcite.udf.udaf.FirstAggFunction; @@ -40,6 +44,7 @@ import 
org.opensearch.sql.calcite.utils.PPLOperandTypes; import org.opensearch.sql.calcite.utils.PPLReturnTypes; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; +import org.opensearch.sql.common.patterns.PatternAggregationHelpers; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.datetime.DateTimeFunctions; import org.opensearch.sql.expression.function.CollectionUDF.AppendFunctionImpl; @@ -482,6 +487,49 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { PPLReturnTypes.STRING_ARRAY, PPLOperandTypes.ANY_SCALAR_OPTIONAL_INTEGER); + // Pattern aggregation helper UDFs for scripted metric pushdown + // This UDF takes state as parameter and modifies it in-place (for OpenSearch scripted metric) + public static final SqlOperator PATTERN_INIT_UDF = + UserDefinedFunctionUtils.adaptStaticMethodToUDF( + PatternAggregationHelpers.class, + "initPatternState", + ReturnTypes.explicit(SqlTypeName.ANY), // Returns Map + NullPolicy.ANY, + null) // Takes state as parameter + .toUDF("PATTERN_INIT_UDF"); + + public static final SqlOperator PATTERN_ADD_UDF = + UserDefinedFunctionUtils.adaptStaticMethodToUDF( + PatternAggregationHelpers.class, + "addLogToPattern", + ReturnTypes.explicit(SqlTypeName.ANY), // Returns Map + NullPolicy.ANY, + null) // TODO: Add proper operand type checking + .toUDF("PATTERN_ADD_UDF"); + + public static final SqlOperator PATTERN_COMBINE_UDF = + UserDefinedFunctionUtils.adaptStaticMethodToUDF( + PatternAggregationHelpers.class, + "combinePatternAccumulators", + ReturnTypes.explicit(SqlTypeName.ANY), // Returns Map + NullPolicy.ANY, + null) // TODO: Add proper operand type checking + .toUDF("PATTERN_COMBINE_UDF"); + + public static final SqlOperator PATTERN_RESULT_UDF = + UserDefinedFunctionUtils.adaptStaticMethodToUDF( + PatternAggregationHelpers.class, + "producePatternResultFromStates", + opBinding -> { + // Returns List> - represented as ARRAY + RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + RelDataType anyType = typeFactory.createSqlType(SqlTypeName.ANY); + return SqlTypeUtil.createArrayType(typeFactory, anyType, true); + }, + NullPolicy.ANY, + null) // TODO: Add proper operand type checking + .toUDF("PATTERN_RESULT_UDF"); + public static final SqlOperator ENHANCED_COALESCE = new EnhancedCoalesceFunction().toUDF("COALESCE"); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index 2d594c48f55..4d4dbbda954 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -163,6 +163,10 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.NOW; import static org.opensearch.sql.expression.function.BuiltinFunctionName.NULLIF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.OR; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.PATTERN_ADD_UDF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.PATTERN_COMBINE_UDF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.PATTERN_INIT_UDF; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.PATTERN_RESULT_UDF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.PERCENTILE_APPROX; import static org.opensearch.sql.expression.function.BuiltinFunctionName.PERIOD_ADD; import 
static org.opensearch.sql.expression.function.BuiltinFunctionName.PERIOD_DIFF; @@ -981,6 +985,11 @@ void populate() { registerOperator(WEEKOFYEAR, PPLBuiltinOperators.WEEK); registerOperator(INTERNAL_PATTERN_PARSER, PPLBuiltinOperators.PATTERN_PARSER); + // Register pattern aggregation helper UDFs for scripted metric pushdown + registerOperator(PATTERN_INIT_UDF, PPLBuiltinOperators.PATTERN_INIT_UDF); + registerOperator(PATTERN_ADD_UDF, PPLBuiltinOperators.PATTERN_ADD_UDF); + registerOperator(PATTERN_COMBINE_UDF, PPLBuiltinOperators.PATTERN_COMBINE_UDF); + registerOperator(PATTERN_RESULT_UDF, PPLBuiltinOperators.PATTERN_RESULT_UDF); registerOperator(TONUMBER, PPLBuiltinOperators.TONUMBER); registerOperator(TOSTRING, PPLBuiltinOperators.TOSTRING); register( diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java index e4f7f1f9d1c..2c0d9c7e653 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java @@ -69,26 +69,35 @@ public Expression implement( : "PATTERN_PARSER should have 2 or 3 arguments"; RelDataType inputType = call.getOperands().get(1).getType(); - Method method = resolveEvaluationMethod(inputType); + Method method = resolveEvaluationMethod(inputType, operandCount); ScalarFunctionImpl function = (ScalarFunctionImpl) ScalarFunctionImpl.create(method); return function.getImplementor().implement(translator, call, RexImpTable.NullAs.NULL); } - private Method resolveEvaluationMethod(RelDataType inputType) { + private Method resolveEvaluationMethod(RelDataType inputType, int operandCount) { if (inputType.getSqlTypeName() == SqlTypeName.VARCHAR) { return getMethod(String.class, "evalField"); } RelDataType componentType = inputType.getComponentType(); - return (componentType.getSqlTypeName() == SqlTypeName.MAP) - ? Types.lookupMethod( - PatternParserFunctionImpl.class, - "evalAgg", - String.class, - Objects.class, - Boolean.class) - : getMethod(List.class, "evalSamples"); + if (componentType.getSqlTypeName() == SqlTypeName.MAP) { + // evalAgg: for label mode with aggregation results (array of maps) + return Types.lookupMethod( + PatternParserFunctionImpl.class, "evalAgg", String.class, Objects.class, Boolean.class); + } else if (operandCount == 3) { + // evalAggSamples: for UDAF pushdown aggregation mode + // Takes pattern (String), sample_logs (List), showNumberedToken (Boolean) + return Types.lookupMethod( + PatternParserFunctionImpl.class, + "evalAggSamples", + String.class, + List.class, + Boolean.class); + } else { + // evalSamples: for simple pattern with sample logs (2 arguments) + return getMethod(List.class, "evalSamples"); + } } private Method getMethod(Class paramType, String methodName) { @@ -126,13 +135,23 @@ public static Object evalAgg( if (bestCandidate != null) { String bestCandidatePattern = String.join(" ", bestCandidate); Map> tokensMap = new HashMap<>(); + String outputPattern = bestCandidatePattern; // Default: return as-is + if (showNumberedToken) { + // Parse pattern with wildcard format (<*>, <*IP*>, etc.) 
+ // LogPatternAggFunction.value() returns patterns in wildcard format ParseResult parseResult = - PatternUtils.parsePattern(bestCandidatePattern, PatternUtils.TOKEN_PATTERN); + PatternUtils.parsePattern(bestCandidatePattern, PatternUtils.WILDCARD_PATTERN); + + // Transform pattern from wildcards to numbered tokens (, , etc.) + outputPattern = parseResult.toTokenOrderString(PatternUtils.TOKEN_PREFIX); + + // Extract token values from the field PatternUtils.extractVariables(parseResult, field, tokensMap, PatternUtils.TOKEN_PREFIX); } + return ImmutableMap.of( - PatternUtils.PATTERN, bestCandidatePattern, + PatternUtils.PATTERN, outputPattern, PatternUtils.TOKENS, tokensMap); } else { return ImmutableMap.of(); @@ -174,6 +193,47 @@ public static Object evalSamples( tokensMap); } + /** + * Extract tokens from aggregated pattern and sample logs for UDAF pushdown. Transforms the + * pattern from wildcard format (e.g., <*>) to numbered token format (e.g., <token1>, + * <token2>) when showNumberedToken is true. + * + *

This method is designed to be called after UDAF pushdown returns from OpenSearch. The UDAF + * returns patterns with wildcards, and this method transforms them to numbered tokens and + * extracts token values from sample logs. + * + * @param pattern The pattern string with wildcards (e.g., <*>, <*IP*>) + * @param sampleLogs List of sample log messages + * @param showNumberedToken Whether to transform to numbered tokens and extract token values + * @return Map containing pattern (possibly transformed) and tokens (if showNumberedToken is true) + */ + public static Object evalAggSamples( + @Parameter(name = "pattern") String pattern, + @Parameter(name = "sample_logs") List sampleLogs, + @Parameter(name = "showNumberedToken") Boolean showNumberedToken) { + if (Strings.isBlank(pattern)) { + return EMPTY_RESULT; + } + + Map> tokensMap = new HashMap<>(); + String outputPattern = pattern; // Default: return pattern as-is (with wildcards) + + if (Boolean.TRUE.equals(showNumberedToken)) { + // Parse pattern with wildcard format (<*>, <*IP*>, etc.) + ParseResult parseResult = PatternUtils.parsePattern(pattern, PatternUtils.WILDCARD_PATTERN); + + // Transform pattern from wildcards to numbered tokens (, , etc.) + outputPattern = parseResult.toTokenOrderString(PatternUtils.TOKEN_PREFIX); + + // Extract token values from sample logs + for (String sampleLog : sampleLogs) { + PatternUtils.extractVariables(parseResult, sampleLog, tokensMap, PatternUtils.TOKEN_PREFIX); + } + } + + return ImmutableMap.of(PatternUtils.PATTERN, outputPattern, PatternUtils.TOKENS, tokensMap); + } + private static List findBestCandidate( List> candidates, List tokens) { return candidates.stream() @@ -188,7 +248,8 @@ private static float calculateScore(List tokens, List candidate) String candidateToken = candidate.get(i); if (Objects.equals(preprocessedToken, candidateToken)) { score += 1; - } else if (preprocessedToken.startsWith("<*") && candidateToken.startsWith(" vs <*IP*>) score += 1; } } diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java index 46df914e611..b298e945856 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java @@ -530,4 +530,187 @@ public void testBrainParseWithUUID_ShowNumberedToken() throws IOException { "[PlaceOrder] user_id= user_currency=USD", ImmutableMap.of("", ImmutableList.of("d664d7be-77d8-11f0-8880-0242f00b101d")))); } + + @Test + public void testBrainAggregationMode_UDAFPushdown_NotShowNumberedToken() throws IOException { + // Test UDAF pushdown for patterns BRAIN aggregation mode + // This verifies that the query is pushed down to OpenSearch as a scripted metric aggregation + JSONObject result = + executeQuery( + String.format( + "source=%s | patterns content method=brain mode=aggregation" + + " variable_count_threshold=5", + TEST_INDEX_HDFS_LOGS)); + System.out.println(result.toString()); + + // Verify schema matches expected output + verifySchema( + result, + schema("patterns_field", "string"), + schema("pattern_count", "bigint"), + schema("sample_logs", "array")); + + // Verify data rows - should match the non-pushdown results + verifyDataRows( + result, + rows( + "Verification succeeded <*> blk_<*>", + 2, + ImmutableList.of( + "Verification succeeded for blk_-1547954353065580372", + "Verification succeeded for 
blk_6996194389878584395")), + rows( + "BLOCK* NameSystem.addStoredBlock: blockMap updated: <*IP*> is added to blk_<*>" + + " size <*>", + 2, + ImmutableList.of( + "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.31.85:50010 is added to" + + " blk_-7017553867379051457 size 67108864", + "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.107.19:50010 is added" + + " to blk_-3249711809227781266 size 67108864")), + rows( + "<*> NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_<*>_<*>_r_<*>_<*>/part<*>" + + " blk_<*>", + 2, + ImmutableList.of( + "BLOCK* NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000296_0/part-00296." + + " blk_-6620182933895093708", + "BLOCK* NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000318_0/part-00318." + + " blk_2096692261399680562")), + rows( + "PacketResponder failed <*> blk_<*>", + 2, + ImmutableList.of( + "PacketResponder failed for blk_6996194389878584395", + "PacketResponder failed for blk_-1547954353065580372"))); + } + + // TODO: Re-enable this test once explain plan output format is validated + // The functional tests (testBrainAggregationMode_UDAFPushdown_NotShowNumberedToken and + // testBrainAggregationMode_UDAFPushdown_ShowNumberedToken) prove that UDAF pushdown works + // correctly. This test verifies the explain plan format, which may need adjustment. + @Test + public void testBrainAggregationMode_UDAFPushdown_VerifyPlan() throws IOException { + // Verify that UDAF pushdown is happening by checking the explain plan + String query = + String.format( + "source=%s | patterns content method=brain mode=aggregation" + + " variable_count_threshold=5", + TEST_INDEX_HDFS_LOGS); + + // Get the explain plan + String explainResult = explainQueryYaml(query); + System.out.println(explainResult); + + // Verify the plan contains evidence of UDAF pushdown + // When UDAF is pushed down, the plan should show: + // 1. CalciteLogicalIndexScan with AGGREGATION pushdown type + // 2. 
No LogicalAggregate node in the physical plan (it's pushed down) + assertTrue( + "Expected plan to contain CalciteLogicalIndexScan", + explainResult.contains("CalciteLogicalIndexScan")); + assertTrue( + "Expected plan to show AGGREGATION pushdown", + explainResult.contains("AGGREGATION") || explainResult.contains("aggregation")); + + // The plan should NOT contain a separate Aggregate node above the scan + // since it's pushed down to OpenSearch + assertFalse( + "Expected no separate LogicalAggregate node when UDAF is pushed down", + explainResult.contains("LogicalAggregate") + && explainResult.indexOf("LogicalAggregate") + > explainResult.indexOf("CalciteLogicalIndexScan")); + } + + @Test + public void testBrainAggregationMode_UDAFPushdown_ShowNumberedToken() throws IOException { + // Test UDAF pushdown for patterns BRAIN aggregation mode with numbered tokens + JSONObject result = + executeQuery( + String.format( + "source=%s | patterns content method=brain mode=aggregation" + + " show_numbered_token=true variable_count_threshold=5", + TEST_INDEX_HDFS_LOGS)); + System.out.println(result.toString()); + + // Verify schema includes tokens field + verifySchema( + result, + schema("patterns_field", "string"), + schema("pattern_count", "bigint"), + schema("tokens", "struct"), + schema("sample_logs", "array")); + + // Verify data rows with tokens + verifyDataRows( + result, + rows( + "Verification succeeded blk_", + 2, + ImmutableMap.of( + "", + ImmutableList.of("for", "for"), + "", + ImmutableList.of("-1547954353065580372", "6996194389878584395")), + ImmutableList.of( + "Verification succeeded for blk_-1547954353065580372", + "Verification succeeded for blk_6996194389878584395")), + rows( + "BLOCK* NameSystem.addStoredBlock: blockMap updated: is added to blk_" + + " size ", + 2, + ImmutableMap.of( + "", + ImmutableList.of("10.251.31.85:50010", "10.251.107.19:50010"), + "", + ImmutableList.of("67108864", "67108864"), + "", + ImmutableList.of("-7017553867379051457", "-3249711809227781266")), + ImmutableList.of( + "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.31.85:50010 is added to" + + " blk_-7017553867379051457 size 67108864", + "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.107.19:50010 is added" + + " to blk_-3249711809227781266 size 67108864")), + rows( + " NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task___r__/part" + + " blk_", + 2, + ImmutableMap.of( + "", + ImmutableList.of("0", "0"), + "", + ImmutableList.of("000296", "000318"), + "", + ImmutableList.of("-6620182933895093708", "2096692261399680562"), + "", + ImmutableList.of("-00296.", "-00318."), + "", + ImmutableList.of("BLOCK*", "BLOCK*"), + "", + ImmutableList.of("0002", "0002"), + "", + ImmutableList.of("200811092030", "200811092030")), + ImmutableList.of( + "BLOCK* NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000296_0/part-00296." + + " blk_-6620182933895093708", + "BLOCK* NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000318_0/part-00318." 
+ + " blk_2096692261399680562")), + rows( + "PacketResponder failed blk_", + 2, + ImmutableMap.of( + "", + ImmutableList.of("for", "for"), + "", + ImmutableList.of("6996194389878584395", "-1547954353065580372")), + ImmutableList.of( + "PacketResponder failed for blk_6996194389878584395", + "PacketResponder failed for blk_-1547954353065580372"))); + } } diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index 62eadd7ef5e..7e8c56f2d9b 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -433,7 +433,6 @@ public void testPatternsSimplePatternMethodWithAggPushDownExplain() throws IOExc @Test public void testPatternsBrainMethodWithAggPushDownExplain() throws IOException { - // TODO: Correct calcite expected result once pushdown is supported String expected = loadExpectedPlan("explain_patterns_brain_agg_push.yaml"); assertYamlEqualsIgnoreId( expected, diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_push.yaml index 0b2d4584804..ed98865ce43 100644 --- a/integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_push.yaml +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_push.yaml @@ -1,7 +1,7 @@ calcite: logical: | LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) - LogicalProject(patterns_field=[SAFE_CAST(ITEM($1, 'pattern'))], pattern_count=[SAFE_CAST(ITEM($1, 'pattern_count'))], tokens=[SAFE_CAST(ITEM($1, 'tokens'))], sample_logs=[SAFE_CAST(ITEM($1, 'sample_logs'))]) + LogicalProject(patterns_field=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($1, 'pattern')), ITEM($1, 'sample_logs'), true), 'pattern'))], pattern_count=[SAFE_CAST(ITEM($1, 'pattern_count'))], tokens=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($1, 'pattern')), ITEM($1, 'sample_logs'), true), 'tokens'))], sample_logs=[SAFE_CAST(ITEM($1, 'sample_logs'))]) LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) LogicalAggregate(group=[{}], patterns_field=[pattern($0, $1, $2, $3)]) LogicalProject(email=[$9], $f17=[10], $f18=[100000], $f19=[true]) @@ -11,11 +11,9 @@ calcite: LogicalValues(tuples=[[{ 0 }]]) physical: | EnumerableLimit(fetch=[10000]) - EnumerableCalc(expr#0..1=[{inputs}], expr#2=['pattern'], expr#3=[ITEM($t1, $t2)], expr#4=[SAFE_CAST($t3)], expr#5=['pattern_count'], expr#6=[ITEM($t1, $t5)], expr#7=[SAFE_CAST($t6)], expr#8=['tokens'], expr#9=[ITEM($t1, $t8)], expr#10=[SAFE_CAST($t9)], expr#11=['sample_logs'], expr#12=[ITEM($t1, $t11)], expr#13=[SAFE_CAST($t12)], patterns_field=[$t4], pattern_count=[$t7], tokens=[$t10], sample_logs=[$t13]) + EnumerableCalc(expr#0..1=[{inputs}], expr#2=['pattern'], expr#3=[ITEM($t1, $t2)], expr#4=[SAFE_CAST($t3)], expr#5=['sample_logs'], expr#6=[ITEM($t1, $t5)], expr#7=[true], expr#8=[PATTERN_PARSER($t4, $t6, $t7)], expr#9=[ITEM($t8, $t2)], expr#10=[SAFE_CAST($t9)], expr#11=['pattern_count'], expr#12=[ITEM($t1, $t11)], expr#13=[SAFE_CAST($t12)], expr#14=['tokens'], expr#15=[ITEM($t8, $t14)], expr#16=[SAFE_CAST($t15)], expr#17=[SAFE_CAST($t6)], patterns_field=[$t10], pattern_count=[$t13], tokens=[$t16], sample_logs=[$t17]) EnumerableCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) - EnumerableAggregate(group=[{}], patterns_field=[pattern($0, $1, $2, $3)]) - 
EnumerableCalc(expr#0=[{inputs}], expr#1=[10], expr#2=[100000], expr#3=[true], proj#0..3=[{exprs}]) - CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[email]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","_source":{"includes":["email"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={},patterns_field=pattern($0, $1, $2, $3))], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"patterns_field":{"scripted_metric":{"init_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQCCnsKICAib3AiOiB7CiAgICAibmFtZSI6ICJQQVRURVJOX0lOSVRfVURGIiwKICAgICJraW5kIjogIk9USEVSX0ZVTkNUSU9OIiwKICAgICJzeW50YXgiOiAiRlVOQ1RJT04iCiAgfSwKICAib3BlcmFuZHMiOiBbCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAwLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJBTlkiLAogICAgICAgICJudWxsYWJsZSI6IGZhbHNlLAogICAgICAgICJwcmVjaXNpb24iOiAtMSwKICAgICAgICAic2NhbGUiOiAtMjE0NzQ4MzY0OAogICAgICB9CiAgICB9CiAgXSwKICAiY2xhc3MiOiAib3JnLm9wZW5zZWFyY2guc3FsLmV4cHJlc3Npb24uZnVuY3Rpb24uVXNlckRlZmluZWRGdW5jdGlvbkJ1aWxkZXIkMSIsCiAgInR5cGUiOiB7CiAgICAidHlwZSI6ICJBTlkiLAogICAgIm51bGxhYmxlIjogZmFsc2UsCiAgICAicHJlY2lzaW9uIjogLTEsCiAgICAic2NhbGUiOiAtMjE0NzQ4MzY0OAogIH0sCiAgImRldGVybWluaXN0aWMiOiB0cnVlLAogICJkeW5hbWljIjogZmFsc2UKfQ==\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 0,"SOURCES":[3],"DIGESTS":["state"]}},"map_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQEW3sKICAib3AiOiB7CiAgICAibmFtZSI6ICJQQVRURVJOX0FERF9VREYiLAogICAgImtpbmQiOiAiT1RIRVJfRlVOQ1RJT04iLAogICAgInN5bnRheCI6ICJGVU5DVElPTiIKICB9LAogICJvcGVyYW5kcyI6IFsKICAgIHsKICAgICAgImR5bmFtaWNQYXJhbSI6IDAsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIkFOWSIsCiAgICAgICAgIm51bGxhYmxlIjogZmFsc2UsCiAgICAgICAgInByZWNpc2lvbiI6IC0xLAogICAgICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAxLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJWQVJDSEFSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlLAogICAgICAgICJwcmVjaXNpb24iOiAtMQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogMiwKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiSU5URUdFUiIsCiAgICAgICAgIm51bGxhYmxlIjogdHJ1ZQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogMywKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiSU5URUdFUiIsCiAgICAgICAgIm51bGxhYmxlIjogdHJ1ZQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogNCwKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiSU5URUdFUiIsCiAgICAgICAgIm51bGxhYmxlIjogdHJ1ZQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogNSwKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiRE9VQkxFIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0KICAgIH0KICBdLAogICJjbGFzcyI6ICJvcmcub3BlbnNlYXJjaC5zcWwuZXhwcmVzc2lvbi5mdW5jdGlvbi5Vc2VyRGVmaW5lZEZ1bmN0aW9uQnVpbGRlciQxIiwKICAidHlwZSI6IHsKICAgICJ0eXBlIjogIkFOWSIsCiAgICAibnVsbGFibGUiOiBmYWxzZSwKICAgICJwcmVjaXNpb24iOiAtMSwKICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgfSwKICAiZGV0ZXJtaW5pc3RpYyI6IHRydWUsCiAgImR5bmFtaWMiOiBmYWxzZQp9\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 
0,"SOURCES":[3,0,2,2,2,2],"DIGESTS":["state","email.keyword",10,100000,5,0.3]}},"combine_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQAgHsKICAiZHluYW1pY1BhcmFtIjogMCwKICAidHlwZSI6IHsKICAgICJ0eXBlIjogIkFOWSIsCiAgICAibnVsbGFibGUiOiBmYWxzZSwKICAgICJwcmVjaXNpb24iOiAtMSwKICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgfQp9\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 0,"SOURCES":[3],"DIGESTS":["state"]}},"reduce_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQEH3sKICAib3AiOiB7CiAgICAibmFtZSI6ICJQQVRURVJOX1JFU1VMVF9VREYiLAogICAgImtpbmQiOiAiT1RIRVJfRlVOQ1RJT04iLAogICAgInN5bnRheCI6ICJGVU5DVElPTiIKICB9LAogICJvcGVyYW5kcyI6IFsKICAgIHsKICAgICAgImR5bmFtaWNQYXJhbSI6IDAsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIkFOWSIsCiAgICAgICAgIm51bGxhYmxlIjogZmFsc2UsCiAgICAgICAgInByZWNpc2lvbiI6IC0xLAogICAgICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAxLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJJTlRFR0VSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAyLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJJTlRFR0VSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAzLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJET1VCTEUiLAogICAgICAgICJudWxsYWJsZSI6IHRydWUKICAgICAgfQogICAgfSwKICAgIHsKICAgICAgImR5bmFtaWNQYXJhbSI6IDQsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIkJPT0xFQU4iLAogICAgICAgICJudWxsYWJsZSI6IHRydWUKICAgICAgfQogICAgfQogIF0sCiAgImNsYXNzIjogIm9yZy5vcGVuc2VhcmNoLnNxbC5leHByZXNzaW9uLmZ1bmN0aW9uLlVzZXJEZWZpbmVkRnVuY3Rpb25CdWlsZGVyJDEiLAogICJ0eXBlIjogewogICAgInR5cGUiOiAiQVJSQVkiLAogICAgIm51bGxhYmxlIjogdHJ1ZSwKICAgICJjb21wb25lbnQiOiB7CiAgICAgICJ0eXBlIjogIkFOWSIsCiAgICAgICJudWxsYWJsZSI6IGZhbHNlLAogICAgICAicHJlY2lzaW9uIjogLTEsCiAgICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgICB9CiAgfSwKICAiZGV0ZXJtaW5pc3RpYyI6IHRydWUsCiAgImR5bmFtaWMiOiBmYWxzZQp9\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 0,"SOURCES":[3,2,2,2,2],"DIGESTS":["states",10,5,0.3,true]}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) EnumerableUncollect EnumerableCalc(expr#0=[{inputs}], expr#1=[$cor0], expr#2=[$t1.patterns_field], patterns_field=[$t2]) EnumerableValues(tuples=[[{ 0 }]]) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java index d772b3e603b..dd39c794e1a 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java @@ -195,6 +195,14 @@ private ExprValue parse( return ExprNullValue.of(); } + // Check for arrays first, even if field type is not defined in mapping. + // This handles nested arrays in aggregation results where inner fields + // (like sample_logs in pattern aggregation) may not have type mappings. + if (content.isArray() && (fieldType.isEmpty() || supportArrays)) { + ExprType type = fieldType.orElse(ARRAY); + return parseArray(content, field, type, supportArrays); + } + // Field type may be not defined in mapping if users have disabled dynamic mapping. // Then try to parse content directly based on the value itself // Besides, sub-fields of generated objects are also of type UNDEFINED. 
We parse the content diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java index 247f40b3733..71c53d3c472 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java @@ -613,6 +613,18 @@ yield switch (functionName) { !args.isEmpty() ? args.getFirst().getKey() : null, AggregationBuilders.cardinality(aggName)), new SingleValueParser(aggName)); + case INTERNAL_PATTERN -> + ScriptedMetricUDAFRegistry.INSTANCE + .lookup(functionName) + .map( + udaf -> + udaf.buildAggregation( + args, aggName, helper.cluster, helper.rowType, helper.fieldTypes)) + .orElseThrow( + () -> + new AggregateAnalyzerException( + String.format( + "No scripted metric UDAF registered for %s", functionName))); default -> throw new AggregateAnalyzer.AggregateAnalyzerException( String.format("Unsupported push-down aggregator %s", aggCall.getAggregation())); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PatternScriptedMetricUDAF.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PatternScriptedMetricUDAF.java new file mode 100644 index 00000000000..c4983e9f4f2 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PatternScriptedMetricUDAF.java @@ -0,0 +1,147 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.request; + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.type.SqlTypeName; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.expression.function.PPLBuiltinOperators; + +/** + * Scripted metric UDAF implementation for the Pattern (BRAIN) aggregation function. + * + *
<p>This implementation handles the pushdown of the pattern detection algorithm to OpenSearch,
+ * using the BrainLogParser for log pattern analysis. The four script phases are:
+ *
+ * <ul>
+ *   <li>init_script: Initializes the state with a logMessages buffer and a patternGroupMap
+ *   <li>map_script: Adds log messages to the accumulator and triggers a partial merge when the
+ *       buffer is full
+ *   <li>combine_script: Returns the shard-level state for the reduce phase
+ *   <li>reduce_script: Combines all shard states and produces the final pattern results
+ * </ul>
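+ *
+ * <p>Illustrative, abbreviated shape of the emitted aggregation (full serialized scripts
+ * omitted; see this patch's expected-plan YAML for the complete request):
+ *
+ * <pre>
+ * "aggregations": {
+ *   "patterns_field": {
+ *     "scripted_metric": {
+ *       "init_script":    {"lang": "opensearch_compounded_script", "params": {...}},
+ *       "map_script":     {"lang": "opensearch_compounded_script", "params": {...}},
+ *       "combine_script": {"lang": "opensearch_compounded_script", "params": {...}},
+ *       "reduce_script":  {"lang": "opensearch_compounded_script", "params": {...}}
+ *     }
+ *   }
+ * }
+ * </pre>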
+ */ +public class PatternScriptedMetricUDAF implements ScriptedMetricUDAF { + + // Default parameter values for pattern UDAF + private static final int DEFAULT_MAX_SAMPLE_COUNT = 10; + private static final int DEFAULT_BUFFER_LIMIT = 100000; + private static final int DEFAULT_VARIABLE_COUNT_THRESHOLD = 5; + private static final BigDecimal DEFAULT_THRESHOLD_PERCENTAGE = BigDecimal.valueOf(0.3); + + /** Singleton instance */ + public static final PatternScriptedMetricUDAF INSTANCE = new PatternScriptedMetricUDAF(); + + private PatternScriptedMetricUDAF() {} + + @Override + public BuiltinFunctionName getFunctionName() { + return BuiltinFunctionName.INTERNAL_PATTERN; + } + + @Override + public RexNode buildInitScript(ScriptContext context) { + RexBuilder rexBuilder = context.getRexBuilder(); + RexNode stateRef = context.addSpecialVariableRef("state", SqlTypeName.ANY); + return rexBuilder.makeCall(PPLBuiltinOperators.PATTERN_INIT_UDF, List.of(stateRef)); + } + + @Override + public RexNode buildMapScript(ScriptContext context, List args) { + RexBuilder rexBuilder = context.getRexBuilder(); + List mapArgs = new ArrayList<>(); + + // Add state variable reference + RexNode stateRef = context.addSpecialVariableRef("state", SqlTypeName.ANY); + mapArgs.add(stateRef); + + // Add field reference (first argument) + if (!args.isEmpty()) { + mapArgs.add(args.get(0)); + } + + // Add parameters with defaults: + // args[1] = maxSampleCount + // args[2] = bufferLimit + // args[3] = showNumberedToken (not used in map script) + // args[4] = thresholdPercentage (optional) + // args[5] = variableCountThreshold (optional) + mapArgs.add(getArgOrDefault(args, 1, makeIntLiteral(rexBuilder, DEFAULT_MAX_SAMPLE_COUNT))); + mapArgs.add(getArgOrDefault(args, 2, makeIntLiteral(rexBuilder, DEFAULT_BUFFER_LIMIT))); + mapArgs.add( + getArgOrDefault(args, 5, makeIntLiteral(rexBuilder, DEFAULT_VARIABLE_COUNT_THRESHOLD))); + mapArgs.add( + getArgOrDefault(args, 4, makeDoubleLiteral(rexBuilder, DEFAULT_THRESHOLD_PERCENTAGE))); + + return rexBuilder.makeCall(PPLBuiltinOperators.PATTERN_ADD_UDF, mapArgs); + } + + @Override + public RexNode buildCombineScript(ScriptContext context) { + // Combine script simply returns the shard-level state + return context.addSpecialVariableRef("state", SqlTypeName.ANY); + } + + @Override + public RexNode buildReduceScript(ScriptContext context, List args) { + RexBuilder rexBuilder = context.getRexBuilder(); + RexNode statesRef = context.addSpecialVariableRef("states", SqlTypeName.ANY); + + List reduceArgs = new ArrayList<>(); + reduceArgs.add(statesRef); + + // maxSampleCount + reduceArgs.add(getArgOrDefault(args, 1, makeIntLiteral(rexBuilder, DEFAULT_MAX_SAMPLE_COUNT))); + + // Determine variableCountThreshold and thresholdPercentage + RexNode variableCountThreshold = makeIntLiteral(rexBuilder, DEFAULT_VARIABLE_COUNT_THRESHOLD); + RexNode thresholdPercentage = makeDoubleLiteral(rexBuilder, DEFAULT_THRESHOLD_PERCENTAGE); + + if (args.size() > 5) { + thresholdPercentage = args.get(4); + variableCountThreshold = args.get(5); + } else if (args.size() > 4) { + RexNode arg4 = args.get(4); + SqlTypeName arg4Type = arg4.getType().getSqlTypeName(); + if (arg4Type == SqlTypeName.DOUBLE + || arg4Type == SqlTypeName.DECIMAL + || arg4Type == SqlTypeName.FLOAT) { + thresholdPercentage = arg4; + } else { + variableCountThreshold = arg4; + } + } + + reduceArgs.add(variableCountThreshold); + reduceArgs.add(thresholdPercentage); + + // showNumberedToken (default false) + reduceArgs.add(getArgOrDefault(args, 3, 
rexBuilder.makeLiteral(false))); + + return rexBuilder.makeCall(PPLBuiltinOperators.PATTERN_RESULT_UDF, reduceArgs); + } + + /** Get argument from list or return default value. */ + private static RexNode getArgOrDefault(List args, int index, RexNode defaultValue) { + return args.size() > index ? args.get(index) : defaultValue; + } + + /** Create integer literal for pattern UDAF parameters. */ + private static RexNode makeIntLiteral(RexBuilder rexBuilder, int value) { + return rexBuilder.makeLiteral( + value, rexBuilder.getTypeFactory().createSqlType(SqlTypeName.INTEGER), true); + } + + /** Create double literal for pattern UDAF parameters. */ + private static RexNode makeDoubleLiteral(RexBuilder rexBuilder, BigDecimal value) { + return rexBuilder.makeLiteral( + value, rexBuilder.getTypeFactory().createSqlType(SqlTypeName.DOUBLE), true); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAF.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAF.java new file mode 100644 index 00000000000..6e057861fe4 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAF.java @@ -0,0 +1,235 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.request; + +import java.util.List; +import java.util.Map; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.script.Script; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.opensearch.response.agg.MetricParser; +import org.opensearch.sql.opensearch.response.agg.ScriptedMetricParser; +import org.opensearch.sql.opensearch.storage.script.CompoundedScriptEngine; +import org.opensearch.sql.opensearch.storage.serde.RelJsonSerializer; +import org.opensearch.sql.opensearch.storage.serde.ScriptParameterHelper; +import org.opensearch.sql.opensearch.storage.serde.SerializationWrapper; + +/** + * Interface for User-Defined Aggregate Functions (UDAFs) that can be pushed down to OpenSearch as + * scripted metric aggregations. + * + *
<p>A scripted metric aggregation has four phases (see the sketch below):
+ *
+ * <ul>
+ *   <li>init_script: Initializes the accumulator state on each shard
+ *   <li>map_script: Processes each document, updating the accumulator
+ *   <li>combine_script: Combines shard-level states (runs on each shard)
+ *   <li>reduce_script: Produces the final result from all shard states (runs on the coordinator)
+ * </ul>
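+ *
+ * <p>Illustrative sketch of a custom implementation. MY_FUNC, MY_ADD_UDF, MY_INIT_UDF, and
+ * MY_RESULT_UDF are hypothetical names, not part of this patch:
+ *
+ * <pre>{@code
+ * public class MyScriptedMetricUDAF implements ScriptedMetricUDAF {
+ *   @Override
+ *   public BuiltinFunctionName getFunctionName() {
+ *     return BuiltinFunctionName.MY_FUNC; // hypothetical function name
+ *   }
+ *
+ *   @Override
+ *   public RexNode buildInitScript(ScriptContext ctx) {
+ *     RexNode state = ctx.addSpecialVariableRef("state", SqlTypeName.ANY);
+ *     return ctx.getRexBuilder().makeCall(MY_INIT_UDF, List.of(state));
+ *   }
+ *
+ *   @Override
+ *   public RexNode buildMapScript(ScriptContext ctx, List<RexNode> args) {
+ *     RexNode state = ctx.addSpecialVariableRef("state", SqlTypeName.ANY);
+ *     return ctx.getRexBuilder().makeCall(MY_ADD_UDF, List.of(state, args.get(0)));
+ *   }
+ *
+ *   @Override
+ *   public RexNode buildCombineScript(ScriptContext ctx) {
+ *     // Return the shard-level state unchanged
+ *     return ctx.addSpecialVariableRef("state", SqlTypeName.ANY);
+ *   }
+ *
+ *   @Override
+ *   public RexNode buildReduceScript(ScriptContext ctx, List<RexNode> args) {
+ *     RexNode states = ctx.addSpecialVariableRef("states", SqlTypeName.ANY);
+ *     return ctx.getRexBuilder().makeCall(MY_RESULT_UDF, List.of(states));
+ *   }
+ * }
+ * }</pre>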
+ * + *

Implementations should encapsulate all domain-specific logic for a particular UDAF, keeping + * the AggregateAnalyzer generic and reusable. + */ +public interface ScriptedMetricUDAF { + + /** + * Returns the function name this UDAF handles. + * + * @return The BuiltinFunctionName that this UDAF implements + */ + BuiltinFunctionName getFunctionName(); + + /** + * Build the init_script RexNode for initializing accumulator state. + * + * @param context The script context containing builders and utilities + * @return RexNode representing the init script expression + */ + RexNode buildInitScript(ScriptContext context); + + /** + * Build the map_script RexNode for processing each document. + * + * @param context The script context containing builders and utilities + * @param args The arguments from the aggregate call + * @return RexNode representing the map script expression + */ + RexNode buildMapScript(ScriptContext context, List args); + + /** + * Build the combine_script RexNode for combining shard-level states. + * + * @param context The script context containing builders and utilities + * @return RexNode representing the combine script expression + */ + RexNode buildCombineScript(ScriptContext context); + + /** + * Build the reduce_script RexNode for producing final result. + * + * @param context The script context containing builders and utilities + * @param args The arguments from the aggregate call + * @return RexNode representing the reduce script expression + */ + RexNode buildReduceScript(ScriptContext context, List args); + + /** + * Context object providing utilities for script generation. Each script phase gets its own + * context with isolated parameter helpers. + */ + class ScriptContext { + private final RexBuilder rexBuilder; + private final ScriptParameterHelper paramHelper; + private final RelOptCluster cluster; + private final RelDataType rowType; + private final Map fieldTypes; + + public ScriptContext( + RexBuilder rexBuilder, + ScriptParameterHelper paramHelper, + RelOptCluster cluster, + RelDataType rowType, + Map fieldTypes) { + this.rexBuilder = rexBuilder; + this.paramHelper = paramHelper; + this.cluster = cluster; + this.rowType = rowType; + this.fieldTypes = fieldTypes; + } + + public RexBuilder getRexBuilder() { + return rexBuilder; + } + + public ScriptParameterHelper getParamHelper() { + return paramHelper; + } + + public RelOptCluster getCluster() { + return cluster; + } + + public RelDataType getRowType() { + return rowType; + } + + public Map getFieldTypes() { + return fieldTypes; + } + + /** + * Add a special variable (like 'state' or 'states') and return its dynamic param reference. + * + * @param varName The variable name + * @param type The SQL type for the parameter + * @return RexNode representing the dynamic parameter reference + */ + public RexNode addSpecialVariableRef( + String varName, org.apache.calcite.sql.type.SqlTypeName type) { + int index = paramHelper.addSpecialVariable(varName); + return rexBuilder.makeDynamicParam(rexBuilder.getTypeFactory().createSqlType(type), index); + } + } + + /** + * Build the complete scripted metric aggregation. + * + *

This is the main entry point that creates all four scripts and assembles them into an + * OpenSearch aggregation builder. The default implementation handles the common boilerplate. + * + * @param args The arguments from the aggregate call + * @param aggName The name of the aggregation + * @param cluster The RelOptCluster for creating builders + * @param rowType The row type containing field information + * @param fieldTypes Map of field names to expression types + * @return Pair of aggregation builder and metric parser + */ + default Pair buildAggregation( + List> args, + String aggName, + RelOptCluster cluster, + RelDataType rowType, + Map fieldTypes) { + + RelJsonSerializer serializer = new RelJsonSerializer(cluster); + RexBuilder rexBuilder = cluster.getRexBuilder(); + List fieldList = rowType.getFieldList(); + + // Create parameter helpers for each script phase + ScriptParameterHelper initParamHelper = + new ScriptParameterHelper(fieldList, fieldTypes, rexBuilder); + ScriptParameterHelper mapParamHelper = + new ScriptParameterHelper(fieldList, fieldTypes, rexBuilder); + ScriptParameterHelper combineParamHelper = + new ScriptParameterHelper(fieldList, fieldTypes, rexBuilder); + ScriptParameterHelper reduceParamHelper = + new ScriptParameterHelper(fieldList, fieldTypes, rexBuilder); + + // Create contexts for each phase + ScriptContext initContext = + new ScriptContext(rexBuilder, initParamHelper, cluster, rowType, fieldTypes); + ScriptContext mapContext = + new ScriptContext(rexBuilder, mapParamHelper, cluster, rowType, fieldTypes); + ScriptContext combineContext = + new ScriptContext(rexBuilder, combineParamHelper, cluster, rowType, fieldTypes); + ScriptContext reduceContext = + new ScriptContext(rexBuilder, reduceParamHelper, cluster, rowType, fieldTypes); + + // Extract RexNodes from args + List argRefs = args.stream().map(Pair::getKey).toList(); + + // Build scripts + RexNode initRex = buildInitScript(initContext); + RexNode mapRex = buildMapScript(mapContext, argRefs); + RexNode combineRex = buildCombineScript(combineContext); + RexNode reduceRex = buildReduceScript(reduceContext, argRefs); + + // Create Script objects + Script initScript = createScript(serializer, initRex, initParamHelper); + Script mapScript = createScript(serializer, mapRex, mapParamHelper); + Script combineScript = createScript(serializer, combineRex, combineParamHelper); + Script reduceScript = createScript(serializer, reduceRex, reduceParamHelper); + + // Build scripted metric aggregation + AggregationBuilder aggBuilder = + AggregationBuilders.scriptedMetric(aggName) + .initScript(initScript) + .mapScript(mapScript) + .combineScript(combineScript) + .reduceScript(reduceScript); + + return Pair.of(aggBuilder, new ScriptedMetricParser(aggName)); + } + + /** + * Create a Script object from a RexNode expression. 
+ * + * @param serializer The JSON serializer for RexNode + * @param rexNode The expression to serialize + * @param paramHelper The parameter helper containing script parameters + * @return Script object ready for OpenSearch + */ + private static Script createScript( + RelJsonSerializer serializer, RexNode rexNode, ScriptParameterHelper paramHelper) { + String serializedCode = serializer.serialize(rexNode, paramHelper); + String wrappedCode = + SerializationWrapper.wrapWithLangType( + CompoundedScriptEngine.ScriptEngineType.CALCITE, serializedCode); + return new Script( + Script.DEFAULT_SCRIPT_TYPE, + CompoundedScriptEngine.COMPOUNDED_LANG_NAME, + wrappedCode, + paramHelper.getParameters()); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAFRegistry.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAFRegistry.java new file mode 100644 index 00000000000..dc2012100a3 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAFRegistry.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.request; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +/** + * Registry for ScriptedMetricUDAF implementations. + * + *
<p>This registry provides a lookup mechanism for finding UDAF implementations that can be
+ * pushed down to OpenSearch as scripted metric aggregations. Each UDAF implementation is
+ * registered by its function name.
+ *
+ * <p>To add a new UDAF pushdown (see the usage sketch below):
+ *
+ * <ol>
+ *   <li>Create a class implementing {@link ScriptedMetricUDAF}
+ *   <li>Register it in this registry by calling {@link #register(ScriptedMetricUDAF)}
+ * </ol>
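+ *
+ * <p>For example (MyScriptedMetricUDAF is a hypothetical implementation):
+ *
+ * <pre>{@code
+ * ScriptedMetricUDAFRegistry.INSTANCE.register(MyScriptedMetricUDAF.INSTANCE);
+ * Optional<ScriptedMetricUDAF> udaf =
+ *     ScriptedMetricUDAFRegistry.INSTANCE.lookup(BuiltinFunctionName.INTERNAL_PATTERN);
+ * }</pre>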
+ */ +public final class ScriptedMetricUDAFRegistry { + + /** Singleton instance */ + public static final ScriptedMetricUDAFRegistry INSTANCE = new ScriptedMetricUDAFRegistry(); + + private final Map udafMap; + + private ScriptedMetricUDAFRegistry() { + this.udafMap = new HashMap<>(); + registerBuiltinUDAFs(); + } + + /** Register all built-in scripted metric UDAFs. */ + private void registerBuiltinUDAFs() { + // Register Pattern (BRAIN) UDAF + register(PatternScriptedMetricUDAF.INSTANCE); + } + + /** + * Register a ScriptedMetricUDAF implementation. + * + * @param udaf The UDAF implementation to register + */ + public void register(ScriptedMetricUDAF udaf) { + udafMap.put(udaf.getFunctionName(), udaf); + } + + /** + * Look up a ScriptedMetricUDAF by function name. + * + * @param functionName The function name to look up + * @return Optional containing the UDAF if found, empty otherwise + */ + public Optional lookup(BuiltinFunctionName functionName) { + return Optional.ofNullable(udafMap.get(functionName)); + } + + /** + * Look up a ScriptedMetricUDAF by function name string. + * + * @param functionName The function name string to look up + * @return Optional containing the UDAF if found, empty otherwise + */ + public Optional lookup(String functionName) { + return BuiltinFunctionName.ofAggregation(functionName) + .flatMap(name -> Optional.ofNullable(udafMap.get(name))); + } + + /** + * Check if a function name has a registered ScriptedMetricUDAF. + * + * @param functionName The function name to check + * @return true if a UDAF is registered for this function + */ + public boolean hasUDAF(BuiltinFunctionName functionName) { + return udafMap.containsKey(functionName); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/ScriptedMetricParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/ScriptedMetricParser.java new file mode 100644 index 00000000000..a06399711e6 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/ScriptedMetricParser.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.response.agg; + +import java.util.List; +import java.util.Map; +import lombok.EqualsAndHashCode; +import lombok.RequiredArgsConstructor; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.metrics.ScriptedMetric; + +/** + * Parser for scripted metric aggregation responses. Extracts the final result from the reduce phase + * of a scripted metric aggregation. + */ +@EqualsAndHashCode +@RequiredArgsConstructor +public class ScriptedMetricParser implements MetricParser { + + private final String name; + + @Override + public String getName() { + return name; + } + + @Override + @SuppressWarnings("unchecked") + public List> parse(Aggregation agg) { + if (agg instanceof ScriptedMetric scriptedMetric) { + // Extract the final result from the reduce script + Object result = scriptedMetric.aggregation(); + // The reduce script for UDAF aggregation returns List> + // which represents the array of results. We wrap this in a single Map with + // the aggregation field name as key, so the response is 1 row containing + // the array that can be expanded by Uncollect in the query plan. + if (result instanceof List) { + return List.of(Map.of(name, result)); + } + throw new IllegalArgumentException( + String.format( + "Expected List> from scripted metric but got %s", + result == null ? 
"null" : result.getClass().getSimpleName())); + } + throw new IllegalArgumentException( + String.format( + "Expected ScriptedMetric aggregation but got %s", + agg == null ? "null" : agg.getClass().getSimpleName())); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/CalciteScriptEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/CalciteScriptEngine.java index 224d7019ec2..dc9a3f87180 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/CalciteScriptEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/CalciteScriptEngine.java @@ -75,12 +75,17 @@ import org.opensearch.script.NumberSortScript; import org.opensearch.script.ScriptContext; import org.opensearch.script.ScriptEngine; +import org.opensearch.script.ScriptedMetricAggContexts; import org.opensearch.script.StringSortScript; import org.opensearch.search.lookup.SourceLookup; import org.opensearch.sql.data.model.ExprTimestampValue; import org.opensearch.sql.opensearch.storage.script.aggregation.CalciteAggregationScriptFactory; import org.opensearch.sql.opensearch.storage.script.field.CalciteFieldScriptFactory; import org.opensearch.sql.opensearch.storage.script.filter.CalciteFilterScriptFactory; +import org.opensearch.sql.opensearch.storage.script.scriptedmetric.CalciteScriptedMetricCombineScriptFactory; +import org.opensearch.sql.opensearch.storage.script.scriptedmetric.CalciteScriptedMetricInitScriptFactory; +import org.opensearch.sql.opensearch.storage.script.scriptedmetric.CalciteScriptedMetricMapScriptFactory; +import org.opensearch.sql.opensearch.storage.script.scriptedmetric.CalciteScriptedMetricReduceScriptFactory; import org.opensearch.sql.opensearch.storage.script.sort.CalciteNumberSortScriptFactory; import org.opensearch.sql.opensearch.storage.script.sort.CalciteStringSortScriptFactory; import org.opensearch.sql.opensearch.storage.serde.RelJsonSerializer; @@ -113,6 +118,18 @@ public CalciteScriptEngine(RelOptCluster relOptCluster) { .put(NumberSortScript.CONTEXT, CalciteNumberSortScriptFactory::new) .put(StringSortScript.CONTEXT, CalciteStringSortScriptFactory::new) .put(FieldScript.CONTEXT, CalciteFieldScriptFactory::new) + .put( + ScriptedMetricAggContexts.InitScript.CONTEXT, + CalciteScriptedMetricInitScriptFactory::new) + .put( + ScriptedMetricAggContexts.MapScript.CONTEXT, + CalciteScriptedMetricMapScriptFactory::new) + .put( + ScriptedMetricAggContexts.CombineScript.CONTEXT, + CalciteScriptedMetricCombineScriptFactory::new) + .put( + ScriptedMetricAggContexts.ReduceScript.CONTEXT, + CalciteScriptedMetricReduceScriptFactory::new) .build(); @Override @@ -214,6 +231,11 @@ public Object get(String name) { case DOC_VALUE -> getFromDocValue((String) digests.get(index)); case SOURCE -> getFromSource((String) digests.get(index)); case LITERAL -> digests.get(index); + case SPECIAL_VARIABLE -> + // Special variables (state, states) are not in this context + // They should be handled by ScriptedMetricDataContext + throw new IllegalStateException( + "SPECIAL_VARIABLE " + digests.get(index) + " not supported in this context"); }; } catch (Exception e) { throw new IllegalStateException("Failed to get value for parameter " + name); @@ -245,7 +267,8 @@ public Object getFromSource(String name) { public enum Source { DOC_VALUE(0), SOURCE(1), - LITERAL(2); + LITERAL(2), + SPECIAL_VARIABLE(3); // For scripted metric state/states variables private final int value; diff --git 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricCombineScriptFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricCombineScriptFactory.java new file mode 100644 index 00000000000..9887c9d21d4 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricCombineScriptFactory.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.scriptedmetric; + +import java.util.Map; +import lombok.RequiredArgsConstructor; +import org.apache.calcite.DataContext; +import org.apache.calcite.linq4j.function.Function1; +import org.apache.calcite.rel.type.RelDataType; +import org.opensearch.script.ScriptedMetricAggContexts; + +/** + * Factory for Calcite-based CombineScript in scripted metric aggregations. Combines shard-level + * accumulators using RexNode expressions. + */ +@RequiredArgsConstructor +public class CalciteScriptedMetricCombineScriptFactory + implements ScriptedMetricAggContexts.CombineScript.Factory { + + private final Function1 function; + private final RelDataType outputType; + + @Override + public ScriptedMetricAggContexts.CombineScript newInstance( + Map params, Map state) { + return new CalciteScriptedMetricCombineScript(function, outputType, params, state); + } + + /** CombineScript that executes compiled RexNode expression. */ + private static class CalciteScriptedMetricCombineScript + extends ScriptedMetricAggContexts.CombineScript { + + private final Function1 function; + private final RelDataType outputType; + + public CalciteScriptedMetricCombineScript( + Function1 function, + RelDataType outputType, + Map params, + Map state) { + super(params, state); + this.function = function; + this.outputType = outputType; + } + + @Override + public Object execute() { + // Create data context for combine script + @SuppressWarnings("unchecked") + Map state = (Map) getState(); + DataContext dataContext = new ScriptedMetricDataContext.CombineContext(getParams(), state); + + // Execute the compiled RexNode expression + Object[] result = function.apply(dataContext); + + // Return the combined result + return (result != null && result.length > 0) ? result[0] : getState(); + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricInitScriptFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricInitScriptFactory.java new file mode 100644 index 00000000000..537faebe446 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricInitScriptFactory.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.scriptedmetric; + +import java.util.Map; +import lombok.RequiredArgsConstructor; +import org.apache.calcite.DataContext; +import org.apache.calcite.linq4j.function.Function1; +import org.apache.calcite.rel.type.RelDataType; +import org.opensearch.script.ScriptedMetricAggContexts; + +/** + * Factory for Calcite-based InitScript in scripted metric aggregations. Executes RexNode + * expressions compiled to Java code via CalciteScriptEngine. 
+ */ +@RequiredArgsConstructor +public class CalciteScriptedMetricInitScriptFactory + implements ScriptedMetricAggContexts.InitScript.Factory { + + private final Function1 function; + private final RelDataType outputType; + + @Override + public ScriptedMetricAggContexts.InitScript newInstance( + Map params, Map state) { + return new CalciteScriptedMetricInitScript(function, outputType, params, state); + } + + /** InitScript that executes compiled RexNode expression. */ + private static class CalciteScriptedMetricInitScript + extends ScriptedMetricAggContexts.InitScript { + + private final Function1 function; + private final RelDataType outputType; + + public CalciteScriptedMetricInitScript( + Function1 function, + RelDataType outputType, + Map params, + Map state) { + super(params, state); + this.function = function; + this.outputType = outputType; + } + + @Override + public void execute() { + // Create data context for init script (no document access, only params) + @SuppressWarnings("unchecked") + Map state = (Map) getState(); + DataContext dataContext = new ScriptedMetricDataContext.InitContext(getParams(), state); + + // Execute the compiled RexNode expression + Object[] result = function.apply(dataContext); + + // Store result in state + if (result != null && result.length > 0) { + // The init script typically initializes the state + // Result should be the initialized accumulator + if (result[0] instanceof Map) { + ((Map) getState()).putAll((Map) result[0]); + } else { + ((Map) getState()).put("accumulator", result[0]); + } + } + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricMapScriptFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricMapScriptFactory.java new file mode 100644 index 00000000000..954221fbef9 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricMapScriptFactory.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.scriptedmetric; + +import java.util.Map; +import lombok.RequiredArgsConstructor; +import org.apache.calcite.DataContext; +import org.apache.calcite.linq4j.function.Function1; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.lucene.index.LeafReaderContext; +import org.opensearch.script.ScriptedMetricAggContexts; +import org.opensearch.search.lookup.SearchLookup; + +/** + * Factory for Calcite-based MapScript in scripted metric aggregations. Executes RexNode expressions + * compiled to Java code with document field access. + */ +@RequiredArgsConstructor +public class CalciteScriptedMetricMapScriptFactory + implements ScriptedMetricAggContexts.MapScript.Factory { + + private final Function1 function; + private final RelDataType outputType; + + @Override + public ScriptedMetricAggContexts.MapScript.LeafFactory newFactory( + Map params, Map state, SearchLookup lookup) { + return new CalciteMapScriptLeafFactory(function, outputType, params, state, lookup); + } + + /** Leaf factory that creates MapScript instances for each segment. 
*/ + @RequiredArgsConstructor + private static class CalciteMapScriptLeafFactory + implements ScriptedMetricAggContexts.MapScript.LeafFactory { + + private final Function1 function; + private final RelDataType outputType; + private final Map params; + private final Map state; + private final SearchLookup lookup; + + @Override + public ScriptedMetricAggContexts.MapScript newInstance(LeafReaderContext ctx) { + return new CalciteScriptedMetricMapScript(function, outputType, params, state, lookup, ctx); + } + } + + /** + * MapScript that executes compiled RexNode expression for each document. + * + *
<p>The DataContext is created once in the constructor and reused for all documents to avoid
+ * object allocation overhead per document. This is safe because:
+ *
+ * <ul>
+ *   <li>the params and state references don't change between documents
+ *   <li>doc and sourceLookup are updated internally by OpenSearch before each execute() call
+ *   <li>sources and digests (derived from params) are the same for all documents
+ * </ul>
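+ *
+ * <p>Conceptually, the per-segment flow is (a simplified sketch, not the literal OpenSearch
+ * aggregator code):
+ *
+ * <pre>{@code
+ * MapScript script = leafFactory.newInstance(leafReaderContext); // DataContext built once here
+ * for (int docId : matchingDocs) {
+ *   script.setDocument(docId); // OpenSearch repositions doc values and source lookup
+ *   script.execute();          // reuses the same DataContext instance
+ * }
+ * }</pre>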
+ */ + private static class CalciteScriptedMetricMapScript extends ScriptedMetricAggContexts.MapScript { + + private final Function1 function; + private final DataContext dataContext; + + public CalciteScriptedMetricMapScript( + Function1 function, + RelDataType outputType, + Map params, + Map state, + SearchLookup lookup, + LeafReaderContext leafContext) { + super(params, state, lookup, leafContext); + this.function = function; + // Create DataContext once and reuse for all documents in this segment. + // OpenSearch updates doc values and source lookup internally before each execute(). + this.dataContext = + new ScriptedMetricDataContext.MapContext( + params, state, getDoc(), lookup.getLeafSearchLookup(leafContext).source()); + } + + @Override + public void execute() { + // Execute the compiled RexNode expression (reusing the same DataContext) + Object[] result = function.apply(dataContext); + + // Update state with result + if (result != null && result.length > 0) { + // The map script typically updates the accumulator + if (result[0] instanceof Map) { + ((Map) getState()).putAll((Map) result[0]); + } else { + ((Map) getState()).put("accumulator", result[0]); + } + } + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricReduceScriptFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricReduceScriptFactory.java new file mode 100644 index 00000000000..9a0d8797fe0 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricReduceScriptFactory.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.scriptedmetric; + +import java.util.List; +import java.util.Map; +import lombok.RequiredArgsConstructor; +import org.apache.calcite.DataContext; +import org.apache.calcite.linq4j.function.Function1; +import org.apache.calcite.rel.type.RelDataType; +import org.opensearch.script.ScriptedMetricAggContexts; + +/** + * Factory for Calcite-based ReduceScript in scripted metric aggregations. Produces final result + * from all shard-level combined results using RexNode expressions. + */ +@RequiredArgsConstructor +public class CalciteScriptedMetricReduceScriptFactory + implements ScriptedMetricAggContexts.ReduceScript.Factory { + + private final Function1 function; + private final RelDataType outputType; + + @Override + public ScriptedMetricAggContexts.ReduceScript newInstance( + Map params, List states) { + return new CalciteScriptedMetricReduceScript(function, outputType, params, states); + } + + /** ReduceScript that executes compiled RexNode expression. 
*/ + private static class CalciteScriptedMetricReduceScript + extends ScriptedMetricAggContexts.ReduceScript { + + private final Function1 function; + private final RelDataType outputType; + + public CalciteScriptedMetricReduceScript( + Function1 function, + RelDataType outputType, + Map params, + List states) { + super(params, states); + this.function = function; + this.outputType = outputType; + } + + @Override + public Object execute() { + // Create data context for reduce script + DataContext dataContext = + new ScriptedMetricDataContext.ReduceContext(getParams(), getStates()); + + // Execute the compiled RexNode expression + Object[] result = function.apply(dataContext); + + // Return the final result + return (result != null && result.length > 0) ? result[0] : getStates(); + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricDataContext.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricDataContext.java new file mode 100644 index 00000000000..b8306696b34 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricDataContext.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.scriptedmetric; + +import static org.opensearch.sql.opensearch.storage.serde.ScriptParameterHelper.DIGESTS; +import static org.opensearch.sql.opensearch.storage.serde.ScriptParameterHelper.SOURCES; + +import java.util.List; +import java.util.Map; +import org.apache.calcite.DataContext; +import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.linq4j.QueryProvider; +import org.apache.calcite.schema.SchemaPlus; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.opensearch.index.fielddata.ScriptDocValues; +import org.opensearch.search.lookup.SourceLookup; +import org.opensearch.sql.opensearch.storage.script.CalciteScriptEngine.Source; + +/** + * DataContext implementations for scripted metric aggregation script phases. Provides access to + * params, state/states variables, and document fields depending on the phase. + * + *
<p>Each script phase has its own context:
+ *
+ * <ul>
+ *   <li>{@link InitContext} - init_script: params and state
+ *   <li>{@link MapContext} - map_script: params, state, doc values, and source lookup
+ *   <li>{@link CombineContext} - combine_script: params and state
+ *   <li>{@link ReduceContext} - reduce_script: params and states (array from all shards)
+ * </ul>
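+ *
+ * <p>Dynamic parameters named "?N" resolve through the parallel SOURCES/DIGESTS lists. For
+ * instance, the map_script params emitted in this patch's expected plan are:
+ *
+ * <pre>
+ * "SOURCES": [3, 0, 2, 2, 2, 2]  // SPECIAL_VARIABLE, DOC_VALUE, then four LITERALs
+ * "DIGESTS": ["state", "email.keyword", 10, 100000, 5, 0.3]
+ * // get("?0") -> the state map; get("?1") -> doc value of email.keyword; get("?2") -> 10; ...
+ * </pre>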
+ */ +public abstract class ScriptedMetricDataContext implements DataContext { + + protected final Map params; + protected final List sources; + protected final List digests; + + protected ScriptedMetricDataContext(Map params) { + this.params = params; + this.sources = ((List) params.get(SOURCES)).stream().map(Source::fromValue).toList(); + this.digests = (List) params.get(DIGESTS); + } + + @Override + public @Nullable SchemaPlus getRootSchema() { + return null; + } + + @Override + public JavaTypeFactory getTypeFactory() { + return null; + } + + @Override + public QueryProvider getQueryProvider() { + return null; + } + + /** + * Parse dynamic parameter index from name pattern "?N". + * + * @param name The parameter name (expected format: "?0", "?1", etc.) + * @return The parameter index + * @throws IllegalArgumentException if name doesn't match expected pattern + */ + protected int parseDynamicParamIndex(String name) { + if (!name.startsWith("?")) { + throw new IllegalArgumentException( + "Unexpected parameter name format: " + name + ". Expected '?N' pattern."); + } + int index = Integer.parseInt(name.substring(1)); + if (index >= sources.size()) { + throw new IllegalArgumentException( + "Parameter index " + index + " out of bounds. Sources size: " + sources.size()); + } + return index; + } + + /** + * Base class for init and combine phases that share identical get() logic. Both phases only have + * access to params and state (no doc values). + */ + protected abstract static class StateOnlyContext extends ScriptedMetricDataContext { + protected final Map state; + + protected StateOnlyContext(Map params, Map state) { + super(params); + this.state = state; + } + + @Override + public Object get(String name) { + int index = parseDynamicParamIndex(name); + return switch (sources.get(index)) { + case SPECIAL_VARIABLE -> state; + case LITERAL -> digests.get(index); + default -> + throw new IllegalStateException( + "Unexpected source type " + sources.get(index) + " in StateOnlyContext"); + }; + } + } + + /** DataContext for InitScript phase - provides params and state. */ + public static class InitContext extends StateOnlyContext { + public InitContext(Map params, Map state) { + super(params, state); + } + } + + /** DataContext for CombineScript phase - provides params and state. */ + public static class CombineContext extends StateOnlyContext { + public CombineContext(Map params, Map state) { + super(params, state); + } + } + + /** DataContext for MapScript phase - provides params, state, doc values, and source lookup. */ + public static class MapContext extends ScriptedMetricDataContext { + private final Map state; + private final Map> doc; + private final SourceLookup sourceLookup; + + public MapContext( + Map params, + Map state, + Map> doc, + SourceLookup sourceLookup) { + super(params); + this.state = state; + this.doc = doc; + this.sourceLookup = sourceLookup; + } + + @Override + public Object get(String name) { + int index = parseDynamicParamIndex(name); + return switch (sources.get(index)) { + case SPECIAL_VARIABLE -> state; + case LITERAL -> digests.get(index); + case DOC_VALUE -> getDocValue((String) digests.get(index)); + case SOURCE -> sourceLookup != null ? 
sourceLookup.get((String) digests.get(index)) : null; + }; + } + + private Object getDocValue(String fieldName) { + if (doc != null && doc.containsKey(fieldName)) { + ScriptDocValues docValue = doc.get(fieldName); + if (docValue != null && !docValue.isEmpty()) { + return docValue.get(0); + } + } + return null; + } + } + + /** DataContext for ReduceScript phase - provides params and states array from all shards. */ + public static class ReduceContext extends ScriptedMetricDataContext { + private final List states; + + public ReduceContext(Map params, List states) { + super(params); + this.states = states; + } + + @Override + public Object get(String name) { + int index = parseDynamicParamIndex(name); + return switch (sources.get(index)) { + case SPECIAL_VARIABLE -> states; + case LITERAL -> digests.get(index); + default -> + throw new IllegalStateException( + "Unexpected source type " + sources.get(index) + " in ReduceContext"); + }; + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/serde/ScriptParameterHelper.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/serde/ScriptParameterHelper.java index 1916ab6c2c3..4915cd63827 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/serde/ScriptParameterHelper.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/serde/ScriptParameterHelper.java @@ -94,4 +94,18 @@ public Map getParameters() { } }; } + + /** + * Adds a special variable reference (like state or states in scripted metric aggregations) and + * returns the index. + * + * @param variableName The name of the special variable (e.g., "state", "states") + * @return The index in the sources/digests lists + */ + public int addSpecialVariable(String variableName) { + int index = sources.size(); + sources.add(3); // SPECIAL_VARIABLE = 3 + digests.add(variableName); + return index; + } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLPatternsTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLPatternsTest.java index c272453b829..11d16dd2914 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLPatternsTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLPatternsTest.java @@ -395,9 +395,12 @@ public void testPatternsAggregationMode_ShowNumberedToken_ForBrainMethod() { RelNode root = getRelNode(ppl); String expectedLogical = - "LogicalProject(patterns_field=[SAFE_CAST(ITEM($1, 'pattern'))]," - + " pattern_count=[SAFE_CAST(ITEM($1, 'pattern_count'))], tokens=[SAFE_CAST(ITEM($1," - + " 'tokens'))], sample_logs=[SAFE_CAST(ITEM($1, 'sample_logs'))])\n" + "LogicalProject(patterns_field=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($1," + + " 'pattern')), ITEM($1, 'sample_logs'), true), 'pattern'))]," + + " pattern_count=[SAFE_CAST(ITEM($1, 'pattern_count'))]," + + " tokens=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($1, 'pattern')), ITEM($1," + + " 'sample_logs'), true), 'tokens'))], sample_logs=[SAFE_CAST(ITEM($1," + + " 'sample_logs'))])\n" + " LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])\n" + " LogicalAggregate(group=[{}], patterns_field=[pattern($0, $1, $2, $3)])\n" + " LogicalProject(ENAME=[$1], $f8=[10], $f9=[100000], $f10=[true])\n" @@ -408,11 +411,13 @@ public void testPatternsAggregationMode_ShowNumberedToken_ForBrainMethod() { verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT TRY_CAST(`t20`.`patterns_field`['pattern'] AS STRING) `patterns_field`," - + " 
TRY_CAST(`t20`.`patterns_field`['pattern_count'] AS BIGINT) `pattern_count`," - + " TRY_CAST(`t20`.`patterns_field`['tokens'] AS MAP< VARCHAR, VARCHAR ARRAY >)" - + " `tokens`, TRY_CAST(`t20`.`patterns_field`['sample_logs'] AS ARRAY< STRING >)" - + " `sample_logs`\n" + "SELECT TRY_CAST(PATTERN_PARSER(TRY_CAST(`t20`.`patterns_field`['pattern'] AS STRING)," + + " `t20`.`patterns_field`['sample_logs'], TRUE)['pattern'] AS STRING)" + + " `patterns_field`, TRY_CAST(`t20`.`patterns_field`['pattern_count'] AS BIGINT)" + + " `pattern_count`, TRY_CAST(PATTERN_PARSER(TRY_CAST(`t20`.`patterns_field`['pattern']" + + " AS STRING), `t20`.`patterns_field`['sample_logs'], TRUE)['tokens'] AS MAP< VARCHAR," + + " VARCHAR ARRAY >) `tokens`, TRY_CAST(`t20`.`patterns_field`['sample_logs'] AS ARRAY<" + + " STRING >) `sample_logs`\n" + "FROM (SELECT `pattern`(`ENAME`, 10, 100000, TRUE) `patterns_field`\n" + "FROM `scott`.`EMP`) `$cor0`,\n" + "LATERAL UNNEST((SELECT `$cor0`.`patterns_field`\n" @@ -460,9 +465,12 @@ public void testPatternsAggregationModeWithGroupBy_ShowNumberedToken_ForBrainMet RelNode root = getRelNode(ppl); String expectedLogical = - "LogicalProject(DEPTNO=[$0], patterns_field=[SAFE_CAST(ITEM($2, 'pattern'))]," - + " pattern_count=[SAFE_CAST(ITEM($2, 'pattern_count'))], tokens=[SAFE_CAST(ITEM($2," - + " 'tokens'))], sample_logs=[SAFE_CAST(ITEM($2, 'sample_logs'))])\n" + "LogicalProject(DEPTNO=[$0]," + + " patterns_field=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($2, 'pattern'))," + + " ITEM($2, 'sample_logs'), true), 'pattern'))], pattern_count=[SAFE_CAST(ITEM($2," + + " 'pattern_count'))], tokens=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($2," + + " 'pattern')), ITEM($2, 'sample_logs'), true), 'tokens'))]," + + " sample_logs=[SAFE_CAST(ITEM($2, 'sample_logs'))])\n" + " LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{1}])\n" + " LogicalAggregate(group=[{1}], patterns_field=[pattern($0, $2, $3, $4)])\n" + " LogicalProject(ENAME=[$1], DEPTNO=[$7], $f8=[10], $f9=[100000], $f10=[true])\n" @@ -473,11 +481,14 @@ public void testPatternsAggregationModeWithGroupBy_ShowNumberedToken_ForBrainMet verifyLogical(root, expectedLogical); String expectedSparkSql = - "SELECT `$cor0`.`DEPTNO`, TRY_CAST(`t20`.`patterns_field`['pattern'] AS STRING)" + "SELECT `$cor0`.`DEPTNO`," + + " TRY_CAST(PATTERN_PARSER(TRY_CAST(`t20`.`patterns_field`['pattern'] AS STRING)," + + " `t20`.`patterns_field`['sample_logs'], TRUE)['pattern'] AS STRING)" + " `patterns_field`, TRY_CAST(`t20`.`patterns_field`['pattern_count'] AS BIGINT)" - + " `pattern_count`, TRY_CAST(`t20`.`patterns_field`['tokens'] AS MAP< VARCHAR," - + " VARCHAR ARRAY >) `tokens`, TRY_CAST(`t20`.`patterns_field`['sample_logs'] AS" - + " ARRAY< STRING >) `sample_logs`\n" + + " `pattern_count`, TRY_CAST(PATTERN_PARSER(TRY_CAST(`t20`.`patterns_field`['pattern']" + + " AS STRING), `t20`.`patterns_field`['sample_logs'], TRUE)['tokens'] AS MAP< VARCHAR," + + " VARCHAR ARRAY >) `tokens`, TRY_CAST(`t20`.`patterns_field`['sample_logs'] AS ARRAY<" + + " STRING >) `sample_logs`\n" + "FROM (SELECT `DEPTNO`, `pattern`(`ENAME`, 10, 100000, TRUE) `patterns_field`\n" + "FROM `scott`.`EMP`\n" + "GROUP BY `DEPTNO`) `$cor0`,\n" From 4015f27f02ce35c97ee1e0e308b2934afb369770 Mon Sep 17 00:00:00 2001 From: Songkan Tang Date: Thu, 22 Jan 2026 18:11:00 +0800 Subject: [PATCH 2/9] Refactor to reduce some duplicate logic Signed-off-by: Songkan Tang --- .../sql/calcite/CalciteRelNodeVisitor.java | 182 +++++++++++------- .../udf/udaf/LogPatternAggFunction.java | 
118 ++++-------- .../request/ScriptedMetricUDAF.java | 3 +- 3 files changed, 151 insertions(+), 152 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index 4b6e5860e47..60d23eef4ee 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -3259,120 +3259,156 @@ private RexNode explicitMapType( return new RexInputRef(((RexInputRef) origin).getIndex(), newMapType); } + /** + * Flattens the parsed pattern result into individual fields for projection. + * + *
<p>This method handles two scenarios:
+ *
+ * <ul>
+ *   <li>Label mode: extracts pattern (and optionally tokens) from parsedNode
+ *   <li>Aggregation mode: extracts pattern, pattern_count, tokens (optional), and sample_logs
+ * </ul>
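+ *
+ * <p>In aggregation mode with numbered tokens, the resulting projection (taken from the plan
+ * tests in this patch) has the shape:
+ *
+ * <pre>
+ * patterns_field = SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($1, 'pattern')),
+ *                                 ITEM($1, 'sample_logs'), true), 'pattern'))
+ * pattern_count  = SAFE_CAST(ITEM($1, 'pattern_count'))
+ * tokens         = SAFE_CAST(ITEM(PATTERN_PARSER(...same call...), 'tokens'))
+ * sample_logs    = SAFE_CAST(ITEM($1, 'sample_logs'))
+ * </pre>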
+ * + *

When both flattenPatternAggResult and showNumberedToken are true, the pattern and tokens + * need transformation via evalAggSamples (converting wildcards to numbered tokens). + * + * @param originalPatternResultAlias alias for the pattern field + * @param parsedNode the source RexNode containing parsed pattern data + * @param context the Calcite plan context + * @param flattenPatternAggResult true if in aggregation mode (includes pattern_count, + * sample_logs) + * @param showNumberedToken true if tokens should be extracted and pattern transformed + */ private void flattenParsedPattern( String originalPatternResultAlias, RexNode parsedNode, CalcitePlanContext context, boolean flattenPatternAggResult, - Boolean showNumberedToken) { - List fattenedNodes = new ArrayList<>(); + boolean showNumberedToken) { + List flattenedNodes = new ArrayList<>(); List projectNames = new ArrayList<>(); + RelDataType varcharType = + context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR); + // For aggregation mode with numbered tokens, we need to compute tokens locally // using evalAggSamples. The UDAF returns pattern with wildcards and sample_logs, // but NOT tokens (to avoid XContent serialization issues with nested Maps). - RexNode parsedPatternResult = null; + // The transformed result contains: pattern (with numbered tokens) and tokens map. + RexNode transformedPatternResult = null; if (flattenPatternAggResult && showNumberedToken) { - // Extract pattern string (with wildcards) from UDAF result - RexNode patternStr = - PPLFuncImpTable.INSTANCE.resolve( - context.rexBuilder, - BuiltinFunctionName.INTERNAL_ITEM, - parsedNode, - context.rexBuilder.makeLiteral(PatternUtils.PATTERN)); - // Extract sample_logs from UDAF result - RexNode sampleLogs = - PPLFuncImpTable.INSTANCE.resolve( - context.rexBuilder, - BuiltinFunctionName.INTERNAL_ITEM, - explicitMapType(context, parsedNode, SqlTypeName.VARCHAR), - context.rexBuilder.makeLiteral(PatternUtils.SAMPLE_LOGS)); - RexNode showNumberedTokenLiteral = context.rexBuilder.makeLiteral(true); - - // Call evalAggSamples to transform pattern (wildcards -> numbered tokens) and compute tokens - parsedPatternResult = - PPLFuncImpTable.INSTANCE.resolve( - context.rexBuilder, - BuiltinFunctionName.INTERNAL_PATTERN_PARSER, - patternStr, - sampleLogs, - showNumberedTokenLiteral); + transformedPatternResult = buildEvalAggSamplesCall(parsedNode, context); } - // Flatten map struct fields - pattern - RelDataType varcharType = - context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR); - RexNode patternSource = parsedPatternResult != null ? parsedPatternResult : parsedNode; + // Determine source for pattern and tokens: + // - When transformedPatternResult exists, use it (pattern/tokens need transformation) + // - pattern_count and sample_logs always come from the original parsedNode + RexNode patternAndTokensSource = + transformedPatternResult != null ? transformedPatternResult : parsedNode; + + // 1. 
Always add pattern field RexNode patternExpr = - extractAndCastMapField(context, patternSource, PatternUtils.PATTERN, varcharType); - fattenedNodes.add(context.relBuilder.alias(patternExpr, originalPatternResultAlias)); + context.rexBuilder.makeCast( + varcharType, + PPLFuncImpTable.INSTANCE.resolve( + context.rexBuilder, + BuiltinFunctionName.INTERNAL_ITEM, + patternAndTokensSource, + context.rexBuilder.makeLiteral(PatternUtils.PATTERN)), + true, + true); + flattenedNodes.add(context.relBuilder.alias(patternExpr, originalPatternResultAlias)); projectNames.add(originalPatternResultAlias); + // 2. Add pattern_count when in aggregation mode (from original parsedNode) if (flattenPatternAggResult) { - RelDataType bigintType = - context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BIGINT); + RelDataType bigintType = context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BIGINT); RexNode patternCountExpr = - extractAndCastMapField(context, parsedNode, PatternUtils.PATTERN_COUNT, bigintType); - fattenedNodes.add(context.relBuilder.alias(patternCountExpr, PatternUtils.PATTERN_COUNT)); + context.rexBuilder.makeCast( + bigintType, + PPLFuncImpTable.INSTANCE.resolve( + context.rexBuilder, + BuiltinFunctionName.INTERNAL_ITEM, + parsedNode, + context.rexBuilder.makeLiteral(PatternUtils.PATTERN_COUNT)), + true, + true); + flattenedNodes.add(context.relBuilder.alias(patternCountExpr, PatternUtils.PATTERN_COUNT)); projectNames.add(PatternUtils.PATTERN_COUNT); } + // 3. Add tokens when showNumberedToken is enabled if (showNumberedToken) { - // Create MAP> type for tokens RelDataType tokensType = - context - .rexBuilder - .getTypeFactory() - .createMapType( - varcharType, - context.rexBuilder.getTypeFactory().createArrayType(varcharType, -1)); - RexNode tokensSource = parsedPatternResult != null ? parsedPatternResult : parsedNode; + context.rexBuilder.getTypeFactory().createMapType( + varcharType, context.rexBuilder.getTypeFactory().createArrayType(varcharType, -1)); RexNode tokensExpr = - extractAndCastMapField(context, tokensSource, PatternUtils.TOKENS, tokensType); - fattenedNodes.add(context.relBuilder.alias(tokensExpr, PatternUtils.TOKENS)); + context.rexBuilder.makeCast( + tokensType, + PPLFuncImpTable.INSTANCE.resolve( + context.rexBuilder, + BuiltinFunctionName.INTERNAL_ITEM, + patternAndTokensSource, + context.rexBuilder.makeLiteral(PatternUtils.TOKENS)), + true, + true); + flattenedNodes.add(context.relBuilder.alias(tokensExpr, PatternUtils.TOKENS)); projectNames.add(PatternUtils.TOKENS); } + // 4. 
Add sample_logs when in aggregation mode (from original parsedNode) if (flattenPatternAggResult) { RelDataType sampleLogsArrayType = - context - .rexBuilder - .getTypeFactory() - .createArrayType( - context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR), -1); + context.rexBuilder.getTypeFactory().createArrayType(varcharType, -1); RexNode sampleLogsExpr = - extractAndCastMapField( - context, - explicitMapType(context, parsedNode, SqlTypeName.VARCHAR), - PatternUtils.SAMPLE_LOGS, - sampleLogsArrayType); - fattenedNodes.add(context.relBuilder.alias(sampleLogsExpr, PatternUtils.SAMPLE_LOGS)); + context.rexBuilder.makeCast( + sampleLogsArrayType, + PPLFuncImpTable.INSTANCE.resolve( + context.rexBuilder, + BuiltinFunctionName.INTERNAL_ITEM, + explicitMapType(context, parsedNode, SqlTypeName.VARCHAR), + context.rexBuilder.makeLiteral(PatternUtils.SAMPLE_LOGS)), + true, + true); + flattenedNodes.add(context.relBuilder.alias(sampleLogsExpr, PatternUtils.SAMPLE_LOGS)); projectNames.add(PatternUtils.SAMPLE_LOGS); } - projectPlusOverriding(fattenedNodes, projectNames, context); + + projectPlusOverriding(flattenedNodes, projectNames, context); } /** - * Helper method to extract a field from a map and cast it to the specified type. Creates a - * SAFE_CAST (makeCast with safe=true) around an INTERNAL_ITEM call. + * Builds the evalAggSamples call to transform pattern with wildcards to numbered tokens and + * compute the tokens map from sample logs. * + * @param parsedNode The UDAF result containing pattern and sample_logs * @param context The Calcite plan context - * @param source The source RexNode containing the map - * @param fieldName The name of the field to extract from the map - * @param targetType The target type to cast to - * @return A RexNode representing SAFE_CAST(INTERNAL_ITEM(source, fieldName)) + * @return RexNode representing the evalAggSamples call result */ - private RexNode extractAndCastMapField( - CalcitePlanContext context, RexNode source, String fieldName, RelDataType targetType) { - return context.rexBuilder.makeCast( - targetType, + private RexNode buildEvalAggSamplesCall(RexNode parsedNode, CalcitePlanContext context) { + // Extract pattern string (with wildcards) from UDAF result + RexNode patternStr = + PPLFuncImpTable.INSTANCE.resolve( + context.rexBuilder, + BuiltinFunctionName.INTERNAL_ITEM, + parsedNode, + context.rexBuilder.makeLiteral(PatternUtils.PATTERN)); + + // Extract sample_logs from UDAF result + RexNode sampleLogs = PPLFuncImpTable.INSTANCE.resolve( context.rexBuilder, BuiltinFunctionName.INTERNAL_ITEM, - source, - context.rexBuilder.makeLiteral(fieldName)), - true, - true); + explicitMapType(context, parsedNode, SqlTypeName.VARCHAR), + context.rexBuilder.makeLiteral(PatternUtils.SAMPLE_LOGS)); + + // Call evalAggSamples to transform pattern (wildcards -> numbered tokens) and compute tokens + return PPLFuncImpTable.INSTANCE.resolve( + context.rexBuilder, + BuiltinFunctionName.INTERNAL_PATTERN_PARSER, + patternStr, + sampleLogs, + context.rexBuilder.makeLiteral(true)); } private void buildExpandRelNode( diff --git a/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java index 2808b4face5..438151b7da3 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java +++ b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java @@ -5,22 +5,24 @@ package 
org.opensearch.sql.calcite.udf.udaf; -import com.google.common.collect.ImmutableMap; import java.math.BigDecimal; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.stream.Collectors; import org.opensearch.sql.calcite.udf.UserDefinedAggFunction; import org.opensearch.sql.calcite.udf.udaf.LogPatternAggFunction.LogParserAccumulator; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.patterns.BrainLogParser; -import org.opensearch.sql.common.patterns.PatternUtils; - +import org.opensearch.sql.common.patterns.PatternAggregationHelpers; + +/** + * User-defined aggregate function for log pattern extraction using the Brain algorithm. This UDAF + * is used for in-memory pattern aggregation in Calcite. For OpenSearch scripted metric pushdown, + * see {@link PatternAggregationHelpers} which provides the same logic with Map-based state. + * + *
Both implementations share the same underlying logic through {@link PatternAggregationHelpers} + * to ensure consistency. + */ public class LogPatternAggFunction implements UserDefinedAggFunction { private int bufferLimit = 100000; private int maxSampleCount = 10; @@ -35,12 +37,11 @@ public LogParserAccumulator init() { @Override public Object result(LogParserAccumulator acc) { - if (acc.size() == 0 && acc.logSize() == 0) { + if (acc.isEmpty()) { return null; } - - return acc.value( - maxSampleCount, variableCountThreshold, thresholdPercentage, showNumberedToken); + return PatternAggregationHelpers.producePatternResult( + acc.state, maxSampleCount, variableCountThreshold, thresholdPercentage, showNumberedToken); } @Override @@ -82,17 +83,21 @@ public LogParserAccumulator add( if (Objects.isNull(field)) { return acc; } + // Store parameters for result() phase this.bufferLimit = bufferLimit; this.maxSampleCount = maxSampleCount; this.showNumberedToken = showNumberedToken; this.variableCountThreshold = variableCountThreshold; this.thresholdPercentage = thresholdPercentage; - acc.evaluate(field); - if (bufferLimit > 0 && acc.logSize() == bufferLimit) { - acc.partialMerge( - maxSampleCount, variableCountThreshold, thresholdPercentage, showNumberedToken); - acc.clearBuffer(); - } + + // Delegate to shared helper logic + PatternAggregationHelpers.addLogToPattern( + acc.state, + field, + maxSampleCount, + bufferLimit, + variableCountThreshold, + thresholdPercentage); return acc; } @@ -146,75 +151,32 @@ public LogParserAccumulator add( this.variableCountThreshold); } + /** + * Accumulator for log pattern aggregation. This is a thin wrapper around the Map-based state used + * by {@link PatternAggregationHelpers}, providing type safety for Calcite UDAF while reusing the + * same underlying logic. + */ public static class LogParserAccumulator implements Accumulator { - private final List logMessages; - public Map> patternGroupMap = new HashMap<>(); - - public int size() { - return patternGroupMap.size(); - } - - public int logSize() { - return logMessages.size(); - } + /** The underlying state map, compatible with PatternAggregationHelpers */ + final Map state; public LogParserAccumulator() { - this.logMessages = new ArrayList<>(); - } - - public void evaluate(String value) { - logMessages.add(value); - } - - public void clearBuffer() { - logMessages.clear(); + this.state = PatternAggregationHelpers.initPatternAccumulator(); } - public void partialMerge(Object... argList) { - if (logMessages.isEmpty()) { - return; - } - assert argList.length == 4 : "partialMerge of LogParserAccumulator requires 4 parameters"; - int maxSampleCount = (int) argList[0]; - BrainLogParser logParser = - new BrainLogParser((int) argList[1], ((Double) argList[2]).floatValue()); - Map> partialPatternGroupMap = - logParser.parseAllLogPatterns(logMessages, maxSampleCount); - patternGroupMap = - PatternUtils.mergePatternGroups(patternGroupMap, partialPatternGroupMap, maxSampleCount); + @SuppressWarnings("unchecked") + public boolean isEmpty() { + List logMessages = (List) state.get("logMessages"); + Map patternGroupMap = (Map) state.get("patternGroupMap"); + return (logMessages == null || logMessages.isEmpty()) + && (patternGroupMap == null || patternGroupMap.isEmpty()); } @Override public Object value(Object... 
argList) { - partialMerge(argList); - clearBuffer(); - - return patternGroupMap.values().stream() - .sorted( - Comparator.comparing( - m -> (Long) m.get(PatternUtils.PATTERN_COUNT), - Comparator.nullsLast(Comparator.reverseOrder()))) - .map( - m -> { - String pattern = (String) m.get(PatternUtils.PATTERN); - Long count = (Long) m.get(PatternUtils.PATTERN_COUNT); - List sampleLogs = (List) m.get(PatternUtils.SAMPLE_LOGS); - // For aggregation mode, always return pattern with wildcards (<*>, <*IP*>). - // The transformation to numbered tokens (, ) and token - // extraction is done downstream by evalAggSamples in flattenParsedPattern. - // This ensures consistent behavior between UDAF pushdown and regular - // aggregation paths. - return ImmutableMap.of( - PatternUtils.PATTERN, - pattern, // Always return original wildcard format - PatternUtils.PATTERN_COUNT, - count, - PatternUtils.TOKENS, - Collections.EMPTY_MAP, // Tokens computed downstream by evalAggSamples - PatternUtils.SAMPLE_LOGS, - sampleLogs); - }) - .collect(Collectors.toList()); + // This method is not used directly - result() in LogPatternAggFunction handles this + throw new UnsupportedOperationException( + "Use LogPatternAggFunction.result() instead of direct value() call"); } } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAF.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAF.java index 6e057861fe4..0e61c7feef2 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAF.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAF.java @@ -12,6 +12,7 @@ import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.script.Script; import org.opensearch.search.aggregations.AggregationBuilder; @@ -136,7 +137,7 @@ public Map getFieldTypes() { * @return RexNode representing the dynamic parameter reference */ public RexNode addSpecialVariableRef( - String varName, org.apache.calcite.sql.type.SqlTypeName type) { + String varName, SqlTypeName type) { int index = paramHelper.addSpecialVariable(varName); return rexBuilder.makeDynamicParam(rexBuilder.getTypeFactory().createSqlType(type), index); } From 85cac778b1c03372b4e0fac9c19770efb4c1ec9d Mon Sep 17 00:00:00 2001 From: Songkan Tang Date: Fri, 23 Jan 2026 11:12:50 +0800 Subject: [PATCH 3/9] Organize files Signed-off-by: Songkan Tang --- .../opensearch/sql/calcite/CalciteRelNodeVisitor.java | 11 ++++++++--- .../sql/calcite/udf/udaf/LogPatternAggFunction.java | 7 +------ .../sql/opensearch/request/AggregateAnalyzer.java | 1 + .../script/scriptedmetric}/ScriptedMetricUDAF.java | 5 ++--- .../scriptedmetric}/ScriptedMetricUDAFRegistry.java | 3 ++- .../udaf}/PatternScriptedMetricUDAF.java | 3 ++- 6 files changed, 16 insertions(+), 14 deletions(-) rename opensearch/src/main/java/org/opensearch/sql/opensearch/{request => storage/script/scriptedmetric}/ScriptedMetricUDAF.java (98%) rename opensearch/src/main/java/org/opensearch/sql/opensearch/{request => storage/script/scriptedmetric}/ScriptedMetricUDAFRegistry.java (93%) rename opensearch/src/main/java/org/opensearch/sql/opensearch/{request => storage/script/scriptedmetric/udaf}/PatternScriptedMetricUDAF.java (97%) diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java 
b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index 60d23eef4ee..6b9ca4c52f3 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -3322,7 +3322,8 @@ private void flattenParsedPattern( // 2. Add pattern_count when in aggregation mode (from original parsedNode) if (flattenPatternAggResult) { - RelDataType bigintType = context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BIGINT); + RelDataType bigintType = + context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BIGINT); RexNode patternCountExpr = context.rexBuilder.makeCast( bigintType, @@ -3340,8 +3341,12 @@ private void flattenParsedPattern( // 3. Add tokens when showNumberedToken is enabled if (showNumberedToken) { RelDataType tokensType = - context.rexBuilder.getTypeFactory().createMapType( - varcharType, context.rexBuilder.getTypeFactory().createArrayType(varcharType, -1)); + context + .rexBuilder + .getTypeFactory() + .createMapType( + varcharType, + context.rexBuilder.getTypeFactory().createArrayType(varcharType, -1)); RexNode tokensExpr = context.rexBuilder.makeCast( tokensType, diff --git a/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java index 438151b7da3..dc28fd296b1 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java +++ b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java @@ -92,12 +92,7 @@ public LogParserAccumulator add( // Delegate to shared helper logic PatternAggregationHelpers.addLogToPattern( - acc.state, - field, - maxSampleCount, - bufferLimit, - variableCountThreshold, - thresholdPercentage); + acc.state, field, maxSampleCount, bufferLimit, variableCountThreshold, thresholdPercentage); return acc; } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java index 71c53d3c472..1b29034c8c6 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java @@ -98,6 +98,7 @@ import org.opensearch.sql.opensearch.response.agg.StatsParser; import org.opensearch.sql.opensearch.response.agg.TopHitsParser; import org.opensearch.sql.opensearch.storage.script.aggregation.dsl.CompositeAggregationBuilder; +import org.opensearch.sql.opensearch.storage.script.scriptedmetric.ScriptedMetricUDAFRegistry; import org.opensearch.sql.utils.Utils; /** diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAF.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAF.java similarity index 98% rename from opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAF.java rename to opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAF.java index 0e61c7feef2..8bec1db3520 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAF.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAF.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.sql.opensearch.request; +package 
org.opensearch.sql.opensearch.storage.script.scriptedmetric; import java.util.List; import java.util.Map; @@ -136,8 +136,7 @@ public Map getFieldTypes() { * @param type The SQL type for the parameter * @return RexNode representing the dynamic parameter reference */ - public RexNode addSpecialVariableRef( - String varName, SqlTypeName type) { + public RexNode addSpecialVariableRef(String varName, SqlTypeName type) { int index = paramHelper.addSpecialVariable(varName); return rexBuilder.makeDynamicParam(rexBuilder.getTypeFactory().createSqlType(type), index); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAFRegistry.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAFRegistry.java similarity index 93% rename from opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAFRegistry.java rename to opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAFRegistry.java index dc2012100a3..db181afd68b 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ScriptedMetricUDAFRegistry.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricUDAFRegistry.java @@ -3,12 +3,13 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.sql.opensearch.request; +package org.opensearch.sql.opensearch.storage.script.scriptedmetric; import java.util.HashMap; import java.util.Map; import java.util.Optional; import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.opensearch.storage.script.scriptedmetric.udaf.PatternScriptedMetricUDAF; /** * Registry for ScriptedMetricUDAF implementations. diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PatternScriptedMetricUDAF.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/udaf/PatternScriptedMetricUDAF.java similarity index 97% rename from opensearch/src/main/java/org/opensearch/sql/opensearch/request/PatternScriptedMetricUDAF.java rename to opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/udaf/PatternScriptedMetricUDAF.java index c4983e9f4f2..06caca43c7b 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PatternScriptedMetricUDAF.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/udaf/PatternScriptedMetricUDAF.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.sql.opensearch.request; +package org.opensearch.sql.opensearch.storage.script.scriptedmetric.udaf; import java.math.BigDecimal; import java.util.ArrayList; @@ -13,6 +13,7 @@ import org.apache.calcite.sql.type.SqlTypeName; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.PPLBuiltinOperators; +import org.opensearch.sql.opensearch.storage.script.scriptedmetric.ScriptedMetricUDAF; /** * Scripted metric UDAF implementation for the Pattern (BRAIN) aggregation function. 
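
Note: the patches above route all four scripted-metric phases through the shared `PatternAggregationHelpers`. Below is a minimal sketch of that phase-to-helper mapping, assuming the helper signatures shown in this series and the default parameters visible in the expected explain plans (maxSampleCount=10, bufferLimit=100000, variable_count_threshold=5, thresholdPercentage=0.3); the combine phase in the generated scripts appears to be a plain pass-through of the shard state, so it has no helper call here. This is an illustration, not the factory code itself:

```java
import java.util.List;
import java.util.Map;
import org.opensearch.sql.common.patterns.PatternAggregationHelpers;

// Hypothetical driver: two "shards" run init_script + map_script locally,
// then the coordinator runs reduce_script over the collected shard states.
@SuppressWarnings("rawtypes")
class ScriptedMetricPhaseSketch {
  static Object run(List<String> shardA, List<String> shardB) {
    // init_script: allocate one mutable state map per shard
    Map stateA = PatternAggregationHelpers.initPatternAccumulator();
    Map stateB = PatternAggregationHelpers.initPatternAccumulator();
    // map_script: buffer each document's field value into the shard-local state
    for (String doc : shardA) {
      PatternAggregationHelpers.addLogToPattern(stateA, doc, 10, 100000, 5, 0.3);
    }
    for (String doc : shardB) {
      PatternAggregationHelpers.addLogToPattern(stateB, doc, 10, 100000, 5, 0.3);
    }
    // combine_script: identity; shard state is shipped to the coordinator as-is
    // reduce_script: merge the shard states, then emit the final pattern groups
    Map merged = PatternAggregationHelpers.combinePatternAccumulators(stateA, stateB, 10);
    return PatternAggregationHelpers.producePatternResult(merged, 10, 5, 0.3, true);
  }
}
```

In the actual pushdown, each phase is serialized as a Calcite expression (see the base64 `init_script`/`map_script`/`combine_script`/`reduce_script` payloads in the expected explain plans) rather than invoked directly like this.
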
From a881d0eb9b5c446faf54e1253a43534e18aef026 Mon Sep 17 00:00:00 2001 From: Songkan Tang Date: Fri, 23 Jan 2026 12:56:57 +0800 Subject: [PATCH 4/9] Fix tests Signed-off-by: Songkan Tang --- .../calcite/remote/CalcitePPLPatternsIT.java | 37 ------------------- .../org/opensearch/sql/ppl/ExplainIT.java | 12 ++++++ ...lain_patterns_brain_agg_group_by_push.yaml | 19 ++++++++++ ...lain_patterns_brain_agg_group_by_push.yaml | 21 +++++++++++ .../explain_patterns_brain_agg_push.yaml | 4 +- .../value/OpenSearchExprValueFactory.java | 7 +++- 6 files changed, 60 insertions(+), 40 deletions(-) create mode 100644 integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_group_by_push.yaml create mode 100644 integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_group_by_push.yaml diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java index b298e945856..481ae89dfbe 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java @@ -588,43 +588,6 @@ public void testBrainAggregationMode_UDAFPushdown_NotShowNumberedToken() throws "PacketResponder failed for blk_-1547954353065580372"))); } - // TODO: Re-enable this test once explain plan output format is validated - // The functional tests (testBrainAggregationMode_UDAFPushdown_NotShowNumberedToken and - // testBrainAggregationMode_UDAFPushdown_ShowNumberedToken) prove that UDAF pushdown works - // correctly. This test verifies the explain plan format, which may need adjustment. - @Test - public void testBrainAggregationMode_UDAFPushdown_VerifyPlan() throws IOException { - // Verify that UDAF pushdown is happening by checking the explain plan - String query = - String.format( - "source=%s | patterns content method=brain mode=aggregation" - + " variable_count_threshold=5", - TEST_INDEX_HDFS_LOGS); - - // Get the explain plan - String explainResult = explainQueryYaml(query); - System.out.println(explainResult); - - // Verify the plan contains evidence of UDAF pushdown - // When UDAF is pushed down, the plan should show: - // 1. CalciteLogicalIndexScan with AGGREGATION pushdown type - // 2. 
No LogicalAggregate node in the physical plan (it's pushed down) - assertTrue( - "Expected plan to contain CalciteLogicalIndexScan", - explainResult.contains("CalciteLogicalIndexScan")); - assertTrue( - "Expected plan to show AGGREGATION pushdown", - explainResult.contains("AGGREGATION") || explainResult.contains("aggregation")); - - // The plan should NOT contain a separate Aggregate node above the scan - // since it's pushed down to OpenSearch - assertFalse( - "Expected no separate LogicalAggregate node when UDAF is pushed down", - explainResult.contains("LogicalAggregate") - && explainResult.indexOf("LogicalAggregate") - > explainResult.indexOf("CalciteLogicalIndexScan")); - } - @Test public void testBrainAggregationMode_UDAFPushdown_ShowNumberedToken() throws IOException { // Test UDAF pushdown for patterns BRAIN aggregation mode with numbered tokens diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index 7e8c56f2d9b..bf50250c4db 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -441,6 +441,18 @@ public void testPatternsBrainMethodWithAggPushDownExplain() throws IOException { + "| patterns email method=brain mode=aggregation show_numbered_token=true")); } + @Test + public void testPatternsBrainMethodWithAggGroupByPushDownExplain() throws IOException { + // Patterns with group by is only supported in Calcite mode + Assume.assumeTrue(isCalciteEnabled()); + String expected = loadExpectedPlan("explain_patterns_brain_agg_group_by_push.yaml"); + assertYamlEqualsIgnoreId( + expected, + explainQueryYaml( + "source=opensearch-sql_test_index_account" + + "| patterns email by gender method=brain mode=aggregation show_numbered_token=true")); + } + @Test public void testStatsBySpan() throws IOException { String expected = loadExpectedPlan("explain_stats_by_span.json"); diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_group_by_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_group_by_push.yaml new file mode 100644 index 00000000000..007f93e365c --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_patterns_brain_agg_group_by_push.yaml @@ -0,0 +1,19 @@ +calcite: + logical: | + LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalProject(gender=[$0], patterns_field=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($2, 'pattern')), ITEM($2, 'sample_logs'), true), 'pattern'))], pattern_count=[SAFE_CAST(ITEM($2, 'pattern_count'))], tokens=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($2, 'pattern')), ITEM($2, 'sample_logs'), true), 'tokens'))], sample_logs=[SAFE_CAST(ITEM($2, 'sample_logs'))]) + LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{1}]) + LogicalAggregate(group=[{0}], patterns_field=[pattern($1, $2, $3, $4)]) + LogicalProject(gender=[$4], email=[$9], $f17=[10], $f18=[100000], $f19=[true]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + Uncollect + LogicalProject(patterns_field=[$cor0.patterns_field]) + LogicalValues(tuples=[[{ 0 }]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableCalc(expr#0..2=[{inputs}], expr#3=['pattern'], expr#4=[ITEM($t2, $t3)], expr#5=[SAFE_CAST($t4)], expr#6=['sample_logs'], expr#7=[ITEM($t2, $t6)], expr#8=[true], expr#9=[PATTERN_PARSER($t5, $t7, $t8)], expr#10=[ITEM($t9, $t3)], 
expr#11=[SAFE_CAST($t10)], expr#12=['pattern_count'], expr#13=[ITEM($t2, $t12)], expr#14=[SAFE_CAST($t13)], expr#15=['tokens'], expr#16=[ITEM($t9, $t15)], expr#17=[SAFE_CAST($t16)], expr#18=[SAFE_CAST($t7)], gender=[$t0], patterns_field=[$t11], pattern_count=[$t14], tokens=[$t17], sample_logs=[$t18]) + EnumerableCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{1}]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={0},patterns_field=pattern($1, $2, $3, $4))], OpenSearchRequestBuilder(sourceBuilder={"from":0,"size":0,"timeout":"1m","aggregations":{"composite_buckets":{"composite":{"size":1000,"sources":[{"gender":{"terms":{"field":"gender.keyword","missing_bucket":true,"missing_order":"first","order":"asc"}}}]},"aggregations":{"patterns_field":{"scripted_metric":{"init_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQCCnsKICAib3AiOiB7CiAgICAibmFtZSI6ICJQQVRURVJOX0lOSVRfVURGIiwKICAgICJraW5kIjogIk9USEVSX0ZVTkNUSU9OIiwKICAgICJzeW50YXgiOiAiRlVOQ1RJT04iCiAgfSwKICAib3BlcmFuZHMiOiBbCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAwLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJBTlkiLAogICAgICAgICJudWxsYWJsZSI6IGZhbHNlLAogICAgICAgICJwcmVjaXNpb24iOiAtMSwKICAgICAgICAic2NhbGUiOiAtMjE0NzQ4MzY0OAogICAgICB9CiAgICB9CiAgXSwKICAiY2xhc3MiOiAib3JnLm9wZW5zZWFyY2guc3FsLmV4cHJlc3Npb24uZnVuY3Rpb24uVXNlckRlZmluZWRGdW5jdGlvbkJ1aWxkZXIkMSIsCiAgInR5cGUiOiB7CiAgICAidHlwZSI6ICJBTlkiLAogICAgIm51bGxhYmxlIjogZmFsc2UsCiAgICAicHJlY2lzaW9uIjogLTEsCiAgICAic2NhbGUiOiAtMjE0NzQ4MzY0OAogIH0sCiAgImRldGVybWluaXN0aWMiOiB0cnVlLAogICJkeW5hbWljIjogZmFsc2UKfQ==\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 0,"SOURCES":[3],"DIGESTS":["state"]}},"map_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQEW3sKICAib3AiOiB7CiAgICAibmFtZSI6ICJQQVRURVJOX0FERF9VREYiLAogICAgImtpbmQiOiAiT1RIRVJfRlVOQ1RJT04iLAogICAgInN5bnRheCI6ICJGVU5DVElPTiIKICB9LAogICJvcGVyYW5kcyI6IFsKICAgIHsKICAgICAgImR5bmFtaWNQYXJhbSI6IDAsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIkFOWSIsCiAgICAgICAgIm51bGxhYmxlIjogZmFsc2UsCiAgICAgICAgInByZWNpc2lvbiI6IC0xLAogICAgICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAxLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJWQVJDSEFSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlLAogICAgICAgICJwcmVjaXNpb24iOiAtMQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogMiwKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiSU5URUdFUiIsCiAgICAgICAgIm51bGxhYmxlIjogdHJ1ZQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogMywKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiSU5URUdFUiIsCiAgICAgICAgIm51bGxhYmxlIjogdHJ1ZQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogNCwKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiSU5URUdFUiIsCiAgICAgICAgIm51bGxhYmxlIjogdHJ1ZQogICAgICB9CiAgICB9LAogICAgewogICAgICAiZHluYW1pY1BhcmFtIjogNSwKICAgICAgInR5cGUiOiB7CiAgICAgICAgInR5cGUiOiAiRE9VQkxFIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0KICAgIH0KICBdLAogICJjbGFzcyI6ICJvcmcub3BlbnNlYXJjaC5zcWwuZXhwcmVzc2lvbi5mdW5jdGlvbi5Vc2VyRGVmaW5lZEZ1bmN0aW9uQnVpbGRlciQxIiwKICAidHlwZSI6IHsKICAgICJ0eXBlIjogIkFOWSIsCiAgICAibnVsbGFibGUiOiBmYWxzZSwKICAgICJwcmVjaXNpb24iOiAtMSwKICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgfSwKICAiZGV0ZXJtaW5pc3RpYyI6IHRydWUsCiAgImR5bmFtaWMiOiBmYWxzZQp9\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 
0,"SOURCES":[3,0,2,2,2,2],"DIGESTS":["state","email.keyword",10,100000,5,0.3]}},"combine_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQAgHsKICAiZHluYW1pY1BhcmFtIjogMCwKICAidHlwZSI6IHsKICAgICJ0eXBlIjogIkFOWSIsCiAgICAibnVsbGFibGUiOiBmYWxzZSwKICAgICJwcmVjaXNpb24iOiAtMSwKICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgfQp9\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 0,"SOURCES":[3],"DIGESTS":["state"]}},"reduce_script":{"source":"{\"langType\":\"calcite\",\"script\":\"rO0ABXQEH3sKICAib3AiOiB7CiAgICAibmFtZSI6ICJQQVRURVJOX1JFU1VMVF9VREYiLAogICAgImtpbmQiOiAiT1RIRVJfRlVOQ1RJT04iLAogICAgInN5bnRheCI6ICJGVU5DVElPTiIKICB9LAogICJvcGVyYW5kcyI6IFsKICAgIHsKICAgICAgImR5bmFtaWNQYXJhbSI6IDAsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIkFOWSIsCiAgICAgICAgIm51bGxhYmxlIjogZmFsc2UsCiAgICAgICAgInByZWNpc2lvbiI6IC0xLAogICAgICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAxLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJJTlRFR0VSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAyLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJJTlRFR0VSIiwKICAgICAgICAibnVsbGFibGUiOiB0cnVlCiAgICAgIH0KICAgIH0sCiAgICB7CiAgICAgICJkeW5hbWljUGFyYW0iOiAzLAogICAgICAidHlwZSI6IHsKICAgICAgICAidHlwZSI6ICJET1VCTEUiLAogICAgICAgICJudWxsYWJsZSI6IHRydWUKICAgICAgfQogICAgfSwKICAgIHsKICAgICAgImR5bmFtaWNQYXJhbSI6IDQsCiAgICAgICJ0eXBlIjogewogICAgICAgICJ0eXBlIjogIkJPT0xFQU4iLAogICAgICAgICJudWxsYWJsZSI6IHRydWUKICAgICAgfQogICAgfQogIF0sCiAgImNsYXNzIjogIm9yZy5vcGVuc2VhcmNoLnNxbC5leHByZXNzaW9uLmZ1bmN0aW9uLlVzZXJEZWZpbmVkRnVuY3Rpb25CdWlsZGVyJDEiLAogICJ0eXBlIjogewogICAgInR5cGUiOiAiQVJSQVkiLAogICAgIm51bGxhYmxlIjogdHJ1ZSwKICAgICJjb21wb25lbnQiOiB7CiAgICAgICJ0eXBlIjogIkFOWSIsCiAgICAgICJudWxsYWJsZSI6IGZhbHNlLAogICAgICAicHJlY2lzaW9uIjogLTEsCiAgICAgICJzY2FsZSI6IC0yMTQ3NDgzNjQ4CiAgICB9CiAgfSwKICAiZGV0ZXJtaW5pc3RpYyI6IHRydWUsCiAgImR5bmFtaWMiOiBmYWxzZQp9\"}","lang":"opensearch_compounded_script","params":{"utcTimestamp": 0,"SOURCES":[3,2,2,2,2],"DIGESTS":["states",10,5,0.3,true]}}}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) + EnumerableUncollect + EnumerableCalc(expr#0=[{inputs}], expr#1=[$cor0], expr#2=[$t1.patterns_field], patterns_field=[$t2]) + EnumerableValues(tuples=[[{ 0 }]]) diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_group_by_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_group_by_push.yaml new file mode 100644 index 00000000000..acb0a0d0113 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_group_by_push.yaml @@ -0,0 +1,21 @@ +calcite: + logical: | + LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalProject(gender=[$0], patterns_field=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($2, 'pattern')), ITEM($2, 'sample_logs'), true), 'pattern'))], pattern_count=[SAFE_CAST(ITEM($2, 'pattern_count'))], tokens=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($2, 'pattern')), ITEM($2, 'sample_logs'), true), 'tokens'))], sample_logs=[SAFE_CAST(ITEM($2, 'sample_logs'))]) + LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{1}]) + LogicalAggregate(group=[{0}], patterns_field=[pattern($1, $2, $3, $4)]) + LogicalProject(gender=[$4], email=[$9], $f17=[10], $f18=[100000], $f19=[true]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + Uncollect + 
LogicalProject(patterns_field=[$cor0.patterns_field]) + LogicalValues(tuples=[[{ 0 }]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableCalc(expr#0..2=[{inputs}], expr#3=['pattern'], expr#4=[ITEM($t2, $t3)], expr#5=[SAFE_CAST($t4)], expr#6=['sample_logs'], expr#7=[ITEM($t2, $t6)], expr#8=[true], expr#9=[PATTERN_PARSER($t5, $t7, $t8)], expr#10=[ITEM($t9, $t3)], expr#11=[SAFE_CAST($t10)], expr#12=['pattern_count'], expr#13=[ITEM($t2, $t12)], expr#14=[SAFE_CAST($t13)], expr#15=['tokens'], expr#16=[ITEM($t9, $t15)], expr#17=[SAFE_CAST($t16)], expr#18=[SAFE_CAST($t7)], gender=[$t0], patterns_field=[$t11], pattern_count=[$t14], tokens=[$t17], sample_logs=[$t18]) + EnumerableCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{1}]) + EnumerableAggregate(group=[{0}], patterns_field=[pattern($1, $2, $3, $4)]) + EnumerableCalc(expr#0..16=[{inputs}], expr#17=[10], expr#18=[100000], expr#19=[true], gender=[$t4], email=[$t9], $f17=[$t17], $f18=[$t18], $f19=[$t19]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + EnumerableUncollect + EnumerableCalc(expr#0=[{inputs}], expr#1=[$cor0], expr#2=[$t1.patterns_field], patterns_field=[$t2]) + EnumerableValues(tuples=[[{ 0 }]]) diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_push.yaml b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_push.yaml index bc9bc027e34..d58ef2abc1d 100644 --- a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_push.yaml +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_patterns_brain_agg_push.yaml @@ -1,7 +1,7 @@ calcite: logical: | LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) - LogicalProject(patterns_field=[SAFE_CAST(ITEM($1, 'pattern'))], pattern_count=[SAFE_CAST(ITEM($1, 'pattern_count'))], tokens=[SAFE_CAST(ITEM($1, 'tokens'))], sample_logs=[SAFE_CAST(ITEM($1, 'sample_logs'))]) + LogicalProject(patterns_field=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($1, 'pattern')), ITEM($1, 'sample_logs'), true), 'pattern'))], pattern_count=[SAFE_CAST(ITEM($1, 'pattern_count'))], tokens=[SAFE_CAST(ITEM(PATTERN_PARSER(SAFE_CAST(ITEM($1, 'pattern')), ITEM($1, 'sample_logs'), true), 'tokens'))], sample_logs=[SAFE_CAST(ITEM($1, 'sample_logs'))]) LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) LogicalAggregate(group=[{}], patterns_field=[pattern($0, $1, $2, $3)]) LogicalProject(email=[$9], $f17=[10], $f18=[100000], $f19=[true]) @@ -11,7 +11,7 @@ calcite: LogicalValues(tuples=[[{ 0 }]]) physical: | EnumerableLimit(fetch=[10000]) - EnumerableCalc(expr#0..1=[{inputs}], expr#2=['pattern'], expr#3=[ITEM($t1, $t2)], expr#4=[SAFE_CAST($t3)], expr#5=['pattern_count'], expr#6=[ITEM($t1, $t5)], expr#7=[SAFE_CAST($t6)], expr#8=['tokens'], expr#9=[ITEM($t1, $t8)], expr#10=[SAFE_CAST($t9)], expr#11=['sample_logs'], expr#12=[ITEM($t1, $t11)], expr#13=[SAFE_CAST($t12)], patterns_field=[$t4], pattern_count=[$t7], tokens=[$t10], sample_logs=[$t13]) + EnumerableCalc(expr#0..1=[{inputs}], expr#2=['pattern'], expr#3=[ITEM($t1, $t2)], expr#4=[SAFE_CAST($t3)], expr#5=['sample_logs'], expr#6=[ITEM($t1, $t5)], expr#7=[true], expr#8=[PATTERN_PARSER($t4, $t6, $t7)], expr#9=[ITEM($t8, $t2)], expr#10=[SAFE_CAST($t9)], expr#11=['pattern_count'], expr#12=[ITEM($t1, $t11)], expr#13=[SAFE_CAST($t12)], expr#14=['tokens'], expr#15=[ITEM($t8, $t14)], expr#16=[SAFE_CAST($t15)], expr#17=[SAFE_CAST($t6)], 
patterns_field=[$t10], pattern_count=[$t13], tokens=[$t16], sample_logs=[$t17]) EnumerableCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) EnumerableAggregate(group=[{}], patterns_field=[pattern($0, $1, $2, $3)]) EnumerableCalc(expr#0..16=[{inputs}], expr#17=[10], expr#18=[100000], expr#19=[true], email=[$t9], $f17=[$t17], $f18=[$t18], $f19=[$t19]) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java index dd39c794e1a..e66efd090af 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java @@ -198,7 +198,12 @@ private ExprValue parse( // Check for arrays first, even if field type is not defined in mapping. // This handles nested arrays in aggregation results where inner fields // (like sample_logs in pattern aggregation) may not have type mappings. - if (content.isArray() && (fieldType.isEmpty() || supportArrays)) { + // Exclude GeoPoint types as they have special array handling (e.g., [lon, lat] format). + if (content.isArray() + && (fieldType.isEmpty() || supportArrays) + && !fieldType + .map(t -> t.equals(OpenSearchDataType.of(OpenSearchDataType.MappingType.GeoPoint))) + .orElse(false)) { ExprType type = fieldType.orElse(ARRAY); return parseArray(content, field, type, supportArrays); } From 595be1caf0c4b4df6214aeb531a081b3e82134ac Mon Sep 17 00:00:00 2001 From: Songkan Tang Date: Fri, 23 Jan 2026 13:27:25 +0800 Subject: [PATCH 5/9] Add udaf pushdown enabled setting Signed-off-by: Songkan Tang --- .../sql/common/setting/Settings.java | 1 + .../calcite/remote/CalcitePPLPatternsIT.java | 283 ++++++++++-------- .../org/opensearch/sql/ppl/ExplainIT.java | 46 ++- .../opensearch/request/AggregateAnalyzer.java | 31 +- .../setting/OpenSearchSettings.java | 14 + .../storage/scan/CalciteLogicalIndexScan.java | 9 +- 6 files changed, 227 insertions(+), 157 deletions(-) diff --git a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java index 96fe2e04eea..9e766c0aa37 100644 --- a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java +++ b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java @@ -41,6 +41,7 @@ public enum Key { CALCITE_ENGINE_ENABLED("plugins.calcite.enabled"), CALCITE_FALLBACK_ALLOWED("plugins.calcite.fallback.allowed"), CALCITE_PUSHDOWN_ENABLED("plugins.calcite.pushdown.enabled"), + CALCITE_UDAF_PUSHDOWN_ENABLED("plugins.calcite.udaf_pushdown.enabled"), CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR( "plugins.calcite.pushdown.rowcount.estimation.factor"), CALCITE_SUPPORT_ALL_JOIN_TYPES("plugins.calcite.all_join_types.allowed"), diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java index 481ae89dfbe..981302aadfe 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java @@ -19,6 +19,7 @@ import java.io.IOException; import org.json.JSONObject; import org.junit.Test; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.ppl.PPLIntegTestCase; public class 
CalcitePPLPatternsIT extends PPLIntegTestCase { @@ -535,145 +536,165 @@ public void testBrainParseWithUUID_ShowNumberedToken() throws IOException { public void testBrainAggregationMode_UDAFPushdown_NotShowNumberedToken() throws IOException { // Test UDAF pushdown for patterns BRAIN aggregation mode // This verifies that the query is pushed down to OpenSearch as a scripted metric aggregation - JSONObject result = - executeQuery( - String.format( - "source=%s | patterns content method=brain mode=aggregation" - + " variable_count_threshold=5", - TEST_INDEX_HDFS_LOGS)); - System.out.println(result.toString()); + // UDAF pushdown is disabled by default, enable it for this test + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "true")); + try { + JSONObject result = + executeQuery( + String.format( + "source=%s | patterns content method=brain mode=aggregation" + + " variable_count_threshold=5", + TEST_INDEX_HDFS_LOGS)); + System.out.println(result.toString()); - // Verify schema matches expected output - verifySchema( - result, - schema("patterns_field", "string"), - schema("pattern_count", "bigint"), - schema("sample_logs", "array")); + // Verify schema matches expected output + verifySchema( + result, + schema("patterns_field", "string"), + schema("pattern_count", "bigint"), + schema("sample_logs", "array")); - // Verify data rows - should match the non-pushdown results - verifyDataRows( - result, - rows( - "Verification succeeded <*> blk_<*>", - 2, - ImmutableList.of( - "Verification succeeded for blk_-1547954353065580372", - "Verification succeeded for blk_6996194389878584395")), - rows( - "BLOCK* NameSystem.addStoredBlock: blockMap updated: <*IP*> is added to blk_<*>" - + " size <*>", - 2, - ImmutableList.of( - "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.31.85:50010 is added to" - + " blk_-7017553867379051457 size 67108864", - "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.107.19:50010 is added" - + " to blk_-3249711809227781266 size 67108864")), - rows( - "<*> NameSystem.allocateBlock:" - + " /user/root/sortrand/_temporary/_task_<*>_<*>_r_<*>_<*>/part<*>" - + " blk_<*>", - 2, - ImmutableList.of( - "BLOCK* NameSystem.allocateBlock:" - + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000296_0/part-00296." - + " blk_-6620182933895093708", - "BLOCK* NameSystem.allocateBlock:" - + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000318_0/part-00318." 
- + " blk_2096692261399680562")), - rows( - "PacketResponder failed <*> blk_<*>", - 2, - ImmutableList.of( - "PacketResponder failed for blk_6996194389878584395", - "PacketResponder failed for blk_-1547954353065580372"))); + // Verify data rows - should match the non-pushdown results + verifyDataRows( + result, + rows( + "Verification succeeded <*> blk_<*>", + 2, + ImmutableList.of( + "Verification succeeded for blk_-1547954353065580372", + "Verification succeeded for blk_6996194389878584395")), + rows( + "BLOCK* NameSystem.addStoredBlock: blockMap updated: <*IP*> is added to blk_<*>" + + " size <*>", + 2, + ImmutableList.of( + "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.31.85:50010 is added" + + " to blk_-7017553867379051457 size 67108864", + "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.107.19:50010 is added" + + " to blk_-3249711809227781266 size 67108864")), + rows( + "<*> NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_<*>_<*>_r_<*>_<*>/part<*>" + + " blk_<*>", + 2, + ImmutableList.of( + "BLOCK* NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000296_0/part-00296." + + " blk_-6620182933895093708", + "BLOCK* NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000318_0/part-00318." + + " blk_2096692261399680562")), + rows( + "PacketResponder failed <*> blk_<*>", + 2, + ImmutableList.of( + "PacketResponder failed for blk_6996194389878584395", + "PacketResponder failed for blk_-1547954353065580372"))); + } finally { + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "false")); + } } @Test public void testBrainAggregationMode_UDAFPushdown_ShowNumberedToken() throws IOException { // Test UDAF pushdown for patterns BRAIN aggregation mode with numbered tokens - JSONObject result = - executeQuery( - String.format( - "source=%s | patterns content method=brain mode=aggregation" - + " show_numbered_token=true variable_count_threshold=5", - TEST_INDEX_HDFS_LOGS)); - System.out.println(result.toString()); + // UDAF pushdown is disabled by default, enable it for this test + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "true")); + try { + JSONObject result = + executeQuery( + String.format( + "source=%s | patterns content method=brain mode=aggregation" + + " show_numbered_token=true variable_count_threshold=5", + TEST_INDEX_HDFS_LOGS)); + System.out.println(result.toString()); - // Verify schema includes tokens field - verifySchema( - result, - schema("patterns_field", "string"), - schema("pattern_count", "bigint"), - schema("tokens", "struct"), - schema("sample_logs", "array")); + // Verify schema includes tokens field + verifySchema( + result, + schema("patterns_field", "string"), + schema("pattern_count", "bigint"), + schema("tokens", "struct"), + schema("sample_logs", "array")); - // Verify data rows with tokens - verifyDataRows( - result, - rows( - "Verification succeeded blk_", - 2, - ImmutableMap.of( - "", - ImmutableList.of("for", "for"), - "", - ImmutableList.of("-1547954353065580372", "6996194389878584395")), - ImmutableList.of( - "Verification succeeded for blk_-1547954353065580372", - "Verification succeeded for blk_6996194389878584395")), - rows( - "BLOCK* NameSystem.addStoredBlock: blockMap updated: is added to blk_" - + " size ", - 2, - ImmutableMap.of( - "", - ImmutableList.of("10.251.31.85:50010", 
"10.251.107.19:50010"), - "", - ImmutableList.of("67108864", "67108864"), - "", - ImmutableList.of("-7017553867379051457", "-3249711809227781266")), - ImmutableList.of( - "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.31.85:50010 is added to" - + " blk_-7017553867379051457 size 67108864", - "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.107.19:50010 is added" - + " to blk_-3249711809227781266 size 67108864")), - rows( - " NameSystem.allocateBlock:" - + " /user/root/sortrand/_temporary/_task___r__/part" - + " blk_", - 2, - ImmutableMap.of( - "", - ImmutableList.of("0", "0"), - "", - ImmutableList.of("000296", "000318"), - "", - ImmutableList.of("-6620182933895093708", "2096692261399680562"), - "", - ImmutableList.of("-00296.", "-00318."), - "", - ImmutableList.of("BLOCK*", "BLOCK*"), - "", - ImmutableList.of("0002", "0002"), - "", - ImmutableList.of("200811092030", "200811092030")), - ImmutableList.of( - "BLOCK* NameSystem.allocateBlock:" - + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000296_0/part-00296." - + " blk_-6620182933895093708", - "BLOCK* NameSystem.allocateBlock:" - + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000318_0/part-00318." - + " blk_2096692261399680562")), - rows( - "PacketResponder failed blk_", - 2, - ImmutableMap.of( - "", - ImmutableList.of("for", "for"), - "", - ImmutableList.of("6996194389878584395", "-1547954353065580372")), - ImmutableList.of( - "PacketResponder failed for blk_6996194389878584395", - "PacketResponder failed for blk_-1547954353065580372"))); + // Verify data rows with tokens + verifyDataRows( + result, + rows( + "Verification succeeded blk_", + 2, + ImmutableMap.of( + "", + ImmutableList.of("for", "for"), + "", + ImmutableList.of("-1547954353065580372", "6996194389878584395")), + ImmutableList.of( + "Verification succeeded for blk_-1547954353065580372", + "Verification succeeded for blk_6996194389878584395")), + rows( + "BLOCK* NameSystem.addStoredBlock: blockMap updated: is added to" + + " blk_ size ", + 2, + ImmutableMap.of( + "", + ImmutableList.of("10.251.31.85:50010", "10.251.107.19:50010"), + "", + ImmutableList.of("67108864", "67108864"), + "", + ImmutableList.of("-7017553867379051457", "-3249711809227781266")), + ImmutableList.of( + "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.31.85:50010 is added" + + " to blk_-7017553867379051457 size 67108864", + "BLOCK* NameSystem.addStoredBlock: blockMap updated: 10.251.107.19:50010 is added" + + " to blk_-3249711809227781266 size 67108864")), + rows( + " NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task___r__/part" + + " blk_", + 2, + ImmutableMap.of( + "", + ImmutableList.of("0", "0"), + "", + ImmutableList.of("000296", "000318"), + "", + ImmutableList.of("-6620182933895093708", "2096692261399680562"), + "", + ImmutableList.of("-00296.", "-00318."), + "", + ImmutableList.of("BLOCK*", "BLOCK*"), + "", + ImmutableList.of("0002", "0002"), + "", + ImmutableList.of("200811092030", "200811092030")), + ImmutableList.of( + "BLOCK* NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000296_0/part-00296." + + " blk_-6620182933895093708", + "BLOCK* NameSystem.allocateBlock:" + + " /user/root/sortrand/_temporary/_task_200811092030_0002_r_000318_0/part-00318." 
+ + " blk_2096692261399680562")), + rows( + "PacketResponder failed blk_", + 2, + ImmutableMap.of( + "", + ImmutableList.of("for", "for"), + "", + ImmutableList.of("6996194389878584395", "-1547954353065580372")), + ImmutableList.of( + "PacketResponder failed for blk_6996194389878584395", + "PacketResponder failed for blk_-1547954353065580372"))); + } finally { + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "false")); + } } } diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index bf50250c4db..7d192317564 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -433,24 +433,44 @@ public void testPatternsSimplePatternMethodWithAggPushDownExplain() throws IOExc @Test public void testPatternsBrainMethodWithAggPushDownExplain() throws IOException { - String expected = loadExpectedPlan("explain_patterns_brain_agg_push.yaml"); - assertYamlEqualsIgnoreId( - expected, - explainQueryYaml( - "source=opensearch-sql_test_index_account" - + "| patterns email method=brain mode=aggregation show_numbered_token=true")); + // UDAF pushdown is disabled by default, enable it for this test + Assume.assumeTrue(isCalciteEnabled()); + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "true")); + try { + String expected = loadExpectedPlan("explain_patterns_brain_agg_push.yaml"); + assertYamlEqualsIgnoreId( + expected, + explainQueryYaml( + "source=opensearch-sql_test_index_account" + + "| patterns email method=brain mode=aggregation show_numbered_token=true")); + } finally { + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "false")); + } } @Test public void testPatternsBrainMethodWithAggGroupByPushDownExplain() throws IOException { - // Patterns with group by is only supported in Calcite mode + // Patterns with group by is only supported in Calcite mode with UDAF pushdown enabled Assume.assumeTrue(isCalciteEnabled()); - String expected = loadExpectedPlan("explain_patterns_brain_agg_group_by_push.yaml"); - assertYamlEqualsIgnoreId( - expected, - explainQueryYaml( - "source=opensearch-sql_test_index_account" - + "| patterns email by gender method=brain mode=aggregation show_numbered_token=true")); + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "true")); + try { + String expected = loadExpectedPlan("explain_patterns_brain_agg_group_by_push.yaml"); + assertYamlEqualsIgnoreId( + expected, + explainQueryYaml( + "source=opensearch-sql_test_index_account| patterns email by gender method=brain" + + " mode=aggregation show_numbered_token=true")); + } finally { + updateClusterSettings( + new ClusterSetting( + "persistent", Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), "false")); + } } @Test diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java index 1b29034c8c6..e4149c41ddd 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java @@ -142,6 +142,7 @@ public static class 
AggregateBuilderHelper { final RelOptCluster cluster; final boolean bucketNullable; final int queryBucketSize; + final boolean udafPushdownEnabled; > T build(RexNode node, T aggBuilder) { return build(node, aggBuilder::field, aggBuilder::script); @@ -614,18 +615,24 @@ yield switch (functionName) { !args.isEmpty() ? args.getFirst().getKey() : null, AggregationBuilders.cardinality(aggName)), new SingleValueParser(aggName)); - case INTERNAL_PATTERN -> - ScriptedMetricUDAFRegistry.INSTANCE - .lookup(functionName) - .map( - udaf -> - udaf.buildAggregation( - args, aggName, helper.cluster, helper.rowType, helper.fieldTypes)) - .orElseThrow( - () -> - new AggregateAnalyzerException( - String.format( - "No scripted metric UDAF registered for %s", functionName))); + case INTERNAL_PATTERN -> { + if (!helper.udafPushdownEnabled) { + throw new UnsupportedOperationException( + "UDAF pushdown is disabled. Enable it via cluster setting" + + " 'plugins.calcite.udaf_pushdown.enabled'"); + } + yield ScriptedMetricUDAFRegistry.INSTANCE + .lookup(functionName) + .map( + udaf -> + udaf.buildAggregation( + args, aggName, helper.cluster, helper.rowType, helper.fieldTypes)) + .orElseThrow( + () -> + new AggregateAnalyzerException( + String.format( + "No scripted metric UDAF registered for %s", functionName))); + } default -> throw new AggregateAnalyzer.AggregateAnalyzerException( String.format("Unsupported push-down aggregator %s", aggCall.getAggregation())); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index bd8001f589d..31ffb28924b 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -172,6 +172,13 @@ public class OpenSearchSettings extends Settings { Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting CALCITE_UDAF_PUSHDOWN_ENABLED_SETTING = + Setting.boolSetting( + Key.CALCITE_UDAF_PUSHDOWN_ENABLED.getKeyValue(), + false, + Setting.Property.NodeScope, + Setting.Property.Dynamic); + public static final Setting CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR_SETTING = Setting.doubleSetting( Key.CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR.getKeyValue(), @@ -455,6 +462,12 @@ public OpenSearchSettings(ClusterSettings clusterSettings) { Key.CALCITE_PUSHDOWN_ENABLED, CALCITE_PUSHDOWN_ENABLED_SETTING, new Updater(Key.CALCITE_PUSHDOWN_ENABLED)); + register( + settingBuilder, + clusterSettings, + Key.CALCITE_UDAF_PUSHDOWN_ENABLED, + CALCITE_UDAF_PUSHDOWN_ENABLED_SETTING, + new Updater(Key.CALCITE_UDAF_PUSHDOWN_ENABLED)); register( settingBuilder, clusterSettings, @@ -656,6 +669,7 @@ public static List> pluginSettings() { .add(CALCITE_ENGINE_ENABLED_SETTING) .add(CALCITE_FALLBACK_ALLOWED_SETTING) .add(CALCITE_PUSHDOWN_ENABLED_SETTING) + .add(CALCITE_UDAF_PUSHDOWN_ENABLED_SETTING) .add(CALCITE_PUSHDOWN_ROWCOUNT_ESTIMATION_FACTOR_SETTING) .add(CALCITE_SUPPORT_ALL_JOIN_TYPES_SETTING) .add(DEFAULT_PATTERN_METHOD_SETTING) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java index dbe8306d4b2..a2b714d6360 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java +++ 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java @@ -386,9 +386,16 @@ public AbstractRelNode pushDownAggregate(Aggregate aggregate, @Nullable Project } int queryBucketSize = osIndex.getQueryBucketSize(); boolean bucketNullable = !PPLHintUtils.ignoreNullBucket(aggregate); + boolean udafPushdownEnabled = + osIndex.getSettings().getSettingValue(Settings.Key.CALCITE_UDAF_PUSHDOWN_ENABLED); AggregateAnalyzer.AggregateBuilderHelper helper = new AggregateAnalyzer.AggregateBuilderHelper( - getRowType(), fieldTypes, getCluster(), bucketNullable, queryBucketSize); + getRowType(), + fieldTypes, + getCluster(), + bucketNullable, + queryBucketSize, + udafPushdownEnabled); final Pair, OpenSearchAggregationResponseParser> builderAndParser = AggregateAnalyzer.analyze(aggregate, project, outputFields, helper); Map extendedTypeMapping = From 006b9fd0362bb1704fc347b76fddc36c6a2f98e3 Mon Sep 17 00:00:00 2001 From: Songkan Tang Date: Fri, 23 Jan 2026 13:38:27 +0800 Subject: [PATCH 6/9] Fix compileTestJava failure Signed-off-by: Songkan Tang --- .../request/AggregateAnalyzerTest.java | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java index 2e79da953b1..de66ddd8338 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java @@ -153,7 +153,8 @@ void analyze_aggCall_simple() throws ExpressionNotAnalyzableException { List.of(countCall, avgCall, sumCall, minCall, maxCall), ImmutableBitSet.of()); Project project = createMockProject(List.of(0)); AggregateAnalyzer.AggregateBuilderHelper helper = - new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); + new AggregateAnalyzer.AggregateBuilderHelper( + rowType, fieldTypes, null, true, BUCKET_SIZE, false); Pair, OpenSearchAggregationResponseParser> result = AggregateAnalyzer.analyze(aggregate, project, outputFields, helper); assertEquals( @@ -236,7 +237,8 @@ void analyze_aggCall_extended() throws ExpressionNotAnalyzableException { List.of(varSampCall, varPopCall, stddevSampCall, stddevPopCall), ImmutableBitSet.of()); Project project = createMockProject(List.of(0)); AggregateAnalyzer.AggregateBuilderHelper helper = - new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); + new AggregateAnalyzer.AggregateBuilderHelper( + rowType, fieldTypes, null, true, BUCKET_SIZE, false); Pair, OpenSearchAggregationResponseParser> result = AggregateAnalyzer.analyze(aggregate, project, outputFields, helper); assertEquals( @@ -277,7 +279,8 @@ void analyze_groupBy() throws ExpressionNotAnalyzableException { Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of(0, 1)); Project project = createMockProject(List.of(0, 1)); AggregateAnalyzer.AggregateBuilderHelper helper = - new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); + new AggregateAnalyzer.AggregateBuilderHelper( + rowType, fieldTypes, null, true, BUCKET_SIZE, false); Pair, OpenSearchAggregationResponseParser> result = AggregateAnalyzer.analyze(aggregate, project, outputFields, helper); @@ -318,7 +321,8 @@ void analyze_aggCall_TextWithoutKeyword() { Aggregate aggregate = 
createMockAggregate(List.of(aggCall), ImmutableBitSet.of()); Project project = createMockProject(List.of(2)); AggregateAnalyzer.AggregateBuilderHelper helper = - new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); + new AggregateAnalyzer.AggregateBuilderHelper( + rowType, fieldTypes, null, true, BUCKET_SIZE, false); ExpressionNotAnalyzableException exception = assertThrows( ExpressionNotAnalyzableException.class, @@ -345,7 +349,8 @@ void analyze_groupBy_TextWithoutKeyword() { Aggregate aggregate = createMockAggregate(List.of(aggCall), ImmutableBitSet.of(0)); Project project = createMockProject(List.of(2)); AggregateAnalyzer.AggregateBuilderHelper helper = - new AggregateAnalyzer.AggregateBuilderHelper(rowType, fieldTypes, null, true, BUCKET_SIZE); + new AggregateAnalyzer.AggregateBuilderHelper( + rowType, fieldTypes, null, true, BUCKET_SIZE, false); ExpressionNotAnalyzableException exception = assertThrows( ExpressionNotAnalyzableException.class, @@ -699,7 +704,7 @@ void verify() throws ExpressionNotAnalyzableException { } AggregateAnalyzer.AggregateBuilderHelper helper = new AggregateAnalyzer.AggregateBuilderHelper( - rowType, fieldTypes, agg.getCluster(), true, BUCKET_SIZE); + rowType, fieldTypes, agg.getCluster(), true, BUCKET_SIZE, false); Pair, OpenSearchAggregationResponseParser> result = AggregateAnalyzer.analyze(agg, project, outputFields, helper); From 222b56c00dc57a5c148e12a5e56da41e87d1a7f0 Mon Sep 17 00:00:00 2001 From: Songkan Tang Date: Fri, 23 Jan 2026 13:51:18 +0800 Subject: [PATCH 7/9] Update patterns.md doc with new udaf_pushdown setting explanation Signed-off-by: Songkan Tang --- docs/user/ppl/cmd/patterns.md | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/docs/user/ppl/cmd/patterns.md b/docs/user/ppl/cmd/patterns.md index 6941efbe4f1..8e19b1f18fc 100644 --- a/docs/user/ppl/cmd/patterns.md +++ b/docs/user/ppl/cmd/patterns.md @@ -67,7 +67,7 @@ The `brain` method accepts the following parameters. By default, the Apache Calcite engine labels variables using the `<*>` placeholder. If the `show_numbered_token` option is enabled, the Calcite engine's `label` mode not only labels the text pattern but also assigns numbered placeholders to variable tokens. In `aggregation` mode, it outputs both the labeled pattern and the variable tokens for each pattern. In this case, variable placeholders use the format `` instead of `<*>`. -## Changing the default pattern method +## Changing the default pattern method To override default pattern parameters, run the following command: @@ -83,7 +83,26 @@ PUT _cluster/settings } } ``` - + +## Enabling UDAF pushdown for patterns aggregation + +When using the `patterns` command with `mode=aggregation` and `method=brain`, the aggregation can optionally be pushed down to OpenSearch as a scripted metric aggregation for parallel execution across data nodes. This can improve performance for large datasets but uses scripted metric aggregations which lack circuit breaker protection. + +By default, UDAF pushdown is **disabled**. To enable it, run the following command: + +```bash ignore +PUT _cluster/settings +{ + "persistent": { + "plugins.calcite.udaf_pushdown.enabled": true + } +} +``` + +> **Warning**: Enabling UDAF pushdown executes user-defined aggregation functions as scripted metric aggregations on OpenSearch data nodes. This bypasses certain memory circuit breakers and may cause out-of-memory errors on nodes when processing very large datasets. 
Use with caution and monitor cluster resource usage. + +When UDAF pushdown is disabled (the default), the pattern aggregation runs locally on the coordinator node after fetching the data from OpenSearch. + ## Simple pattern examples The following are examples of using the `simple_pattern` method. From 5c9fdb9468780e75af366f9b0114266e72ab5f88 Mon Sep 17 00:00:00 2001 From: Songkan Tang Date: Fri, 23 Jan 2026 15:12:23 +0800 Subject: [PATCH 8/9] Minor fixes and refactoring Signed-off-by: Songkan Tang --- .../patterns/PatternAggregationHelpers.java | 13 +++++++- .../function/PatternParserFunctionImpl.java | 2 +- docs/user/ppl/cmd/patterns.md | 2 +- .../calcite/remote/CalcitePPLPatternsIT.java | 2 -- .../opensearch/request/AggregateAnalyzer.java | 2 +- .../storage/script/CalciteScriptEngine.java | 8 ++--- ...iteScriptedMetricCombineScriptFactory.java | 7 +--- ...alciteScriptedMetricInitScriptFactory.java | 21 ++---------- ...CalciteScriptedMetricMapScriptFactory.java | 20 +++--------- ...citeScriptedMetricReduceScriptFactory.java | 7 +--- .../ScriptedMetricDataContext.java | 32 +++++++++++++++++-- .../storage/serde/ScriptParameterHelper.java | 5 ++- 12 files changed, 62 insertions(+), 59 deletions(-) diff --git a/common/src/main/java/org/opensearch/sql/common/patterns/PatternAggregationHelpers.java b/common/src/main/java/org/opensearch/sql/common/patterns/PatternAggregationHelpers.java index fae4b61c1eb..8ede948649f 100644 --- a/common/src/main/java/org/opensearch/sql/common/patterns/PatternAggregationHelpers.java +++ b/common/src/main/java/org/opensearch/sql/common/patterns/PatternAggregationHelpers.java @@ -268,8 +268,19 @@ public static Map combinePatternAccumulators( Map> merged = PatternUtils.mergePatternGroups(patterns1, patterns2, maxSampleCount); + // Merge logMessages from both accumulators to preserve buffered messages + List logMessages1 = (List) acc1.get("logMessages"); + List logMessages2 = (List) acc2.get("logMessages"); + List mergedLogMessages = new ArrayList<>(); + if (logMessages1 != null) { + mergedLogMessages.addAll(logMessages1); + } + if (logMessages2 != null) { + mergedLogMessages.addAll(logMessages2); + } + Map result = new HashMap<>(); - result.put("logMessages", new ArrayList<>()); + result.put("logMessages", mergedLogMessages); result.put("patternGroupMap", merged); return result; } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java b/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java index 2c0d9c7e653..1e7aa80617b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PatternParserFunctionImpl.java @@ -137,7 +137,7 @@ public static Object evalAgg( Map> tokensMap = new HashMap<>(); String outputPattern = bestCandidatePattern; // Default: return as-is - if (showNumberedToken) { + if (Boolean.TRUE.equals(showNumberedToken)) { // Parse pattern with wildcard format (<*>, <*IP*>, etc.) 
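// Illustrative aside, not part of the patch: the Boolean.TRUE.equals(...) change above is
// the standard null-safe idiom for a boxed Boolean that may arrive unset (presumably the
// case for showNumberedToken when the option is omitted). A minimal sketch:
//
//   Boolean flag = null;
//   if (flag) { ... }                       // throws NullPointerException on auto-unboxing
//   if (Boolean.TRUE.equals(flag)) { ... }  // safely evaluates to false for null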
// LogPatternAggFunction.value() returns patterns in wildcard format ParseResult parseResult = diff --git a/docs/user/ppl/cmd/patterns.md b/docs/user/ppl/cmd/patterns.md index 8e19b1f18fc..c880b08baf1 100644 --- a/docs/user/ppl/cmd/patterns.md +++ b/docs/user/ppl/cmd/patterns.md @@ -13,7 +13,7 @@ The `patterns` command supports the following modes: The command identifies variable parts of log messages (such as timestamps, numbers, IP addresses, and unique identifiers) and replaces them with `<*>` placeholders to create reusable patterns. For example, email addresses like `amberduke@pyrami.com` and `hattiebond@netagy.com` are replaced with the pattern `<*>@<*>.<*>`. -> **Note**: The `patterns` command is not executed on OpenSearch data nodes. It only groups log patterns from log messages that have been returned to the coordinator node. +> **Note**: By default, the `patterns` command is not executed on OpenSearch data nodes. It only groups log patterns from log messages that have been returned to the coordinator node. However, when using `mode=aggregation` with `method=brain` and the `plugins.calcite.udaf_pushdown.enabled` cluster setting is set to `true`, the aggregation may be pushed down and executed on data nodes as a scripted metric aggregation for improved performance. See [Enabling UDAF pushdown for patterns aggregation](#enabling-udaf-pushdown-for-patterns-aggregation) for more details. ## Syntax diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java index 981302aadfe..74e3f8a9958 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLPatternsIT.java @@ -547,7 +547,6 @@ public void testBrainAggregationMode_UDAFPushdown_NotShowNumberedToken() throws "source=%s | patterns content method=brain mode=aggregation" + " variable_count_threshold=5", TEST_INDEX_HDFS_LOGS)); - System.out.println(result.toString()); // Verify schema matches expected output verifySchema( @@ -613,7 +612,6 @@ public void testBrainAggregationMode_UDAFPushdown_ShowNumberedToken() throws IOE "source=%s | patterns content method=brain mode=aggregation" + " show_numbered_token=true variable_count_threshold=5", TEST_INDEX_HDFS_LOGS)); - System.out.println(result.toString()); // Verify schema includes tokens field verifySchema( diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java index e4149c41ddd..d5c31418bd5 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java @@ -617,7 +617,7 @@ yield switch (functionName) { new SingleValueParser(aggName)); case INTERNAL_PATTERN -> { if (!helper.udafPushdownEnabled) { - throw new UnsupportedOperationException( + throw new AggregateAnalyzerException( "UDAF pushdown is disabled. 
Enable it via cluster setting" + " 'plugins.calcite.udaf_pushdown.enabled'"); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/CalciteScriptEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/CalciteScriptEngine.java index dc9a3f87180..7e2c3231e74 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/CalciteScriptEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/CalciteScriptEngine.java @@ -120,16 +120,16 @@ public CalciteScriptEngine(RelOptCluster relOptCluster) { .put(FieldScript.CONTEXT, CalciteFieldScriptFactory::new) .put( ScriptedMetricAggContexts.InitScript.CONTEXT, - CalciteScriptedMetricInitScriptFactory::new) + (func, type) -> new CalciteScriptedMetricInitScriptFactory(func)) .put( ScriptedMetricAggContexts.MapScript.CONTEXT, - CalciteScriptedMetricMapScriptFactory::new) + (func, type) -> new CalciteScriptedMetricMapScriptFactory(func)) .put( ScriptedMetricAggContexts.CombineScript.CONTEXT, - CalciteScriptedMetricCombineScriptFactory::new) + (func, type) -> new CalciteScriptedMetricCombineScriptFactory(func)) .put( ScriptedMetricAggContexts.ReduceScript.CONTEXT, - CalciteScriptedMetricReduceScriptFactory::new) + (func, type) -> new CalciteScriptedMetricReduceScriptFactory(func)) .build(); @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricCombineScriptFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricCombineScriptFactory.java index 9887c9d21d4..48a9300bcb9 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricCombineScriptFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricCombineScriptFactory.java @@ -9,7 +9,6 @@ import lombok.RequiredArgsConstructor; import org.apache.calcite.DataContext; import org.apache.calcite.linq4j.function.Function1; -import org.apache.calcite.rel.type.RelDataType; import org.opensearch.script.ScriptedMetricAggContexts; /** @@ -21,12 +20,11 @@ public class CalciteScriptedMetricCombineScriptFactory implements ScriptedMetricAggContexts.CombineScript.Factory { private final Function1 function; - private final RelDataType outputType; @Override public ScriptedMetricAggContexts.CombineScript newInstance( Map params, Map state) { - return new CalciteScriptedMetricCombineScript(function, outputType, params, state); + return new CalciteScriptedMetricCombineScript(function, params, state); } /** CombineScript that executes compiled RexNode expression. 
*/ @@ -34,16 +32,13 @@ private static class CalciteScriptedMetricCombineScript extends ScriptedMetricAggContexts.CombineScript { private final Function1 function; - private final RelDataType outputType; public CalciteScriptedMetricCombineScript( Function1 function, - RelDataType outputType, Map params, Map state) { super(params, state); this.function = function; - this.outputType = outputType; } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricInitScriptFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricInitScriptFactory.java index 537faebe446..13d9dcbbbbf 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricInitScriptFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricInitScriptFactory.java @@ -9,7 +9,6 @@ import lombok.RequiredArgsConstructor; import org.apache.calcite.DataContext; import org.apache.calcite.linq4j.function.Function1; -import org.apache.calcite.rel.type.RelDataType; import org.opensearch.script.ScriptedMetricAggContexts; /** @@ -21,12 +20,11 @@ public class CalciteScriptedMetricInitScriptFactory implements ScriptedMetricAggContexts.InitScript.Factory { private final Function1 function; - private final RelDataType outputType; @Override public ScriptedMetricAggContexts.InitScript newInstance( Map params, Map state) { - return new CalciteScriptedMetricInitScript(function, outputType, params, state); + return new CalciteScriptedMetricInitScript(function, params, state); } /** InitScript that executes compiled RexNode expression. */ @@ -34,16 +32,13 @@ private static class CalciteScriptedMetricInitScript extends ScriptedMetricAggContexts.InitScript { private final Function1 function; - private final RelDataType outputType; public CalciteScriptedMetricInitScript( Function1 function, - RelDataType outputType, Map params, Map state) { super(params, state); this.function = function; - this.outputType = outputType; } @Override @@ -53,19 +48,9 @@ public void execute() { Map state = (Map) getState(); DataContext dataContext = new ScriptedMetricDataContext.InitContext(getParams(), state); - // Execute the compiled RexNode expression + // Execute the compiled RexNode expression and merge result into state Object[] result = function.apply(dataContext); - - // Store result in state - if (result != null && result.length > 0) { - // The init script typically initializes the state - // Result should be the initialized accumulator - if (result[0] instanceof Map) { - ((Map) getState()).putAll((Map) result[0]); - } else { - ((Map) getState()).put("accumulator", result[0]); - } - } + ScriptedMetricDataContext.mergeResultIntoState(result, state); } } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricMapScriptFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricMapScriptFactory.java index 954221fbef9..70128748a8f 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricMapScriptFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricMapScriptFactory.java @@ -9,7 +9,6 @@ import lombok.RequiredArgsConstructor; import org.apache.calcite.DataContext; 
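// Illustrative aside, not part of the patch: the RelDataType import removals in this and
// the sibling factories follow from dropping the unused outputType field; none of the
// scripts ever read it, so each factory now carries only the compiled expression. The
// CalciteScriptEngine hunk above adapts the (function, type) factory signature by simply
// discarding the type, e.g.:
//
//   (func, type) -> new CalciteScriptedMetricInitScriptFactory(func)  // type unused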
import org.apache.calcite.linq4j.function.Function1; -import org.apache.calcite.rel.type.RelDataType; import org.apache.lucene.index.LeafReaderContext; import org.opensearch.script.ScriptedMetricAggContexts; import org.opensearch.search.lookup.SearchLookup; @@ -23,12 +22,11 @@ public class CalciteScriptedMetricMapScriptFactory implements ScriptedMetricAggContexts.MapScript.Factory { private final Function1 function; - private final RelDataType outputType; @Override public ScriptedMetricAggContexts.MapScript.LeafFactory newFactory( Map params, Map state, SearchLookup lookup) { - return new CalciteMapScriptLeafFactory(function, outputType, params, state, lookup); + return new CalciteMapScriptLeafFactory(function, params, state, lookup); } /** Leaf factory that creates MapScript instances for each segment. */ @@ -37,14 +35,13 @@ private static class CalciteMapScriptLeafFactory implements ScriptedMetricAggContexts.MapScript.LeafFactory { private final Function1 function; - private final RelDataType outputType; private final Map params; private final Map state; private final SearchLookup lookup; @Override public ScriptedMetricAggContexts.MapScript newInstance(LeafReaderContext ctx) { - return new CalciteScriptedMetricMapScript(function, outputType, params, state, lookup, ctx); + return new CalciteScriptedMetricMapScript(function, params, state, lookup, ctx); } } @@ -67,7 +64,6 @@ private static class CalciteScriptedMetricMapScript extends ScriptedMetricAggCon public CalciteScriptedMetricMapScript( Function1 function, - RelDataType outputType, Map params, Map state, SearchLookup lookup, @@ -82,19 +78,11 @@ public CalciteScriptedMetricMapScript( } @Override + @SuppressWarnings("unchecked") public void execute() { // Execute the compiled RexNode expression (reusing the same DataContext) Object[] result = function.apply(dataContext); - - // Update state with result - if (result != null && result.length > 0) { - // The map script typically updates the accumulator - if (result[0] instanceof Map) { - ((Map) getState()).putAll((Map) result[0]); - } else { - ((Map) getState()).put("accumulator", result[0]); - } - } + ScriptedMetricDataContext.mergeResultIntoState(result, (Map) getState()); } } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricReduceScriptFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricReduceScriptFactory.java index 9a0d8797fe0..cb8a93cdecf 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricReduceScriptFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/CalciteScriptedMetricReduceScriptFactory.java @@ -10,7 +10,6 @@ import lombok.RequiredArgsConstructor; import org.apache.calcite.DataContext; import org.apache.calcite.linq4j.function.Function1; -import org.apache.calcite.rel.type.RelDataType; import org.opensearch.script.ScriptedMetricAggContexts; /** @@ -22,12 +21,11 @@ public class CalciteScriptedMetricReduceScriptFactory implements ScriptedMetricAggContexts.ReduceScript.Factory { private final Function1 function; - private final RelDataType outputType; @Override public ScriptedMetricAggContexts.ReduceScript newInstance( Map params, List states) { - return new CalciteScriptedMetricReduceScript(function, outputType, params, states); + return new CalciteScriptedMetricReduceScript(function, params, states); } /** ReduceScript 
that executes compiled RexNode expression. */ @@ -35,16 +33,13 @@ private static class CalciteScriptedMetricReduceScript extends ScriptedMetricAggContexts.ReduceScript { private final Function1 function; - private final RelDataType outputType; public CalciteScriptedMetricReduceScript( Function1 function, - RelDataType outputType, Map params, List states) { super(params, states); this.function = function; - this.outputType = outputType; } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricDataContext.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricDataContext.java index b8306696b34..f95feff6f65 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricDataContext.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/scriptedmetric/ScriptedMetricDataContext.java @@ -59,19 +59,47 @@ public QueryProvider getQueryProvider() { return null; } + /** + * Merges the execution result into the state map. This is a common operation used in init_script + * and map_script phases to update the accumulator state. + * + *
+   * <p>If the result is a Map, its entries are merged into the state. Otherwise, the result is
+   * stored under the "accumulator" key.
+   *
+   * @param result The result array from function execution (may be null or empty)
+   * @param state The state map to update
+   */
+  @SuppressWarnings("unchecked")
+  public static void mergeResultIntoState(Object[] result, Map<String, Object> state) {
+    if (result != null && result.length > 0) {
+      if (result[0] instanceof Map) {
+        state.putAll((Map<String, Object>) result[0]);
+      } else {
+        state.put("accumulator", result[0]);
+      }
+    }
+  }
+
   /**
    * Parse dynamic parameter index from name pattern "?N".
    *
    * @param name The parameter name (expected format: "?0", "?1", etc.)
    * @return The parameter index
-   * @throws IllegalArgumentException if name doesn't match expected pattern
+   * @throws IllegalArgumentException if name doesn't match expected pattern or is malformed
    */
   protected int parseDynamicParamIndex(String name) {
     if (!name.startsWith("?")) {
       throw new IllegalArgumentException(
           "Unexpected parameter name format: " + name + ". Expected '?N' pattern.");
     }
-    int index = Integer.parseInt(name.substring(1));
+    int index;
+    try {
+      index = Integer.parseInt(name.substring(1));
+    } catch (NumberFormatException e) {
+      throw new IllegalArgumentException(
+          "Malformed parameter name '" + name + "'. Expected '?N' pattern where N is an integer.",
+          e);
+    }
     if (index >= sources.size()) {
       throw new IllegalArgumentException(
           "Parameter index " + index + " out of bounds. Sources size: " + sources.size());
diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/serde/ScriptParameterHelper.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/serde/ScriptParameterHelper.java
index 4915cd63827..aa6d2de0f4d 100644
--- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/serde/ScriptParameterHelper.java
+++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/serde/ScriptParameterHelper.java
@@ -42,9 +42,12 @@ public class ScriptParameterHelper {
    *
    * <p>0 stands for DOC_VALUE
    *
-   * <p>1 stand for SOURCE
+   * <p>1 stands for SOURCE
    *
    * <p>2 stands for LITERAL
+   *
+   * <p>
3 stands for SPECIAL_VARIABLE - retrieves value from special context variables (e.g., state, + * states in scripted metric aggregations) */ List sources; From 00049de6120192d629ca7dc501d93736211a3e3a Mon Sep 17 00:00:00 2001 From: Songkan Tang Date: Fri, 23 Jan 2026 16:13:45 +0800 Subject: [PATCH 9/9] Address minor feedbacks Signed-off-by: Songkan Tang --- .../udf/udaf/LogPatternAggFunction.java | 28 ++++++++++++++--- .../value/OpenSearchExprValueFactory.java | 31 ++++++++++++------- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java index dc28fd296b1..ff207912e35 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java +++ b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/LogPatternAggFunction.java @@ -40,8 +40,8 @@ public Object result(LogParserAccumulator acc) { if (acc.isEmpty()) { return null; } - return PatternAggregationHelpers.producePatternResult( - acc.state, maxSampleCount, variableCountThreshold, thresholdPercentage, showNumberedToken); + return acc.value( + maxSampleCount, variableCountThreshold, thresholdPercentage, showNumberedToken); } @Override @@ -169,9 +169,27 @@ public boolean isEmpty() { @Override public Object value(Object... argList) { - // This method is not used directly - result() in LogPatternAggFunction handles this - throw new UnsupportedOperationException( - "Use LogPatternAggFunction.result() instead of direct value() call"); + // Return the current state for use by LogPatternAggFunction.result() + // The argList contains [maxSampleCount, variableCountThreshold, thresholdPercentage, + // showNumberedToken] + if (isEmpty()) { + return null; + } + int maxSampleCount = + argList.length > 0 && argList[0] != null ? ((Number) argList[0]).intValue() : 10; + int variableCountThreshold = + argList.length > 1 && argList[1] != null + ? ((Number) argList[1]).intValue() + : BrainLogParser.DEFAULT_VARIABLE_COUNT_THRESHOLD; + double thresholdPercentage = + argList.length > 2 && argList[2] != null + ? ((Number) argList[2]).doubleValue() + : BrainLogParser.DEFAULT_FREQUENCY_THRESHOLD_PERCENTAGE; + boolean showNumberedToken = + argList.length > 3 && argList[3] != null && Boolean.TRUE.equals(argList[3]); + + return PatternAggregationHelpers.producePatternResult( + state, maxSampleCount, variableCountThreshold, thresholdPercentage, showNumberedToken); } } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java index e66efd090af..35ea858cf9c 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/value/OpenSearchExprValueFactory.java @@ -494,24 +494,31 @@ public JsonPath getChildPath() { */ private ExprValue parseArray( Content content, String prefix, ExprType type, boolean supportArrays) { - List result = new ArrayList<>(); - // ARRAY is mapped to nested but can take the json structure of an Object. if (content.objectValue() instanceof ObjectNode) { + List result = new ArrayList<>(); result.add(parseStruct(content, prefix, supportArrays)); - // non-object type arrays are only supported when parsing inner_hits of OS response. 
-    } else if (!(type instanceof OpenSearchDataType
+      return new ExprCollectionValue(result);
+    }
+
+    // Get the array iterator once and reuse it
+    var arrayIterator = content.array();
+
+    // Handle empty arrays early
+    if (!arrayIterator.hasNext()) {
+      return supportArrays ? new ExprCollectionValue(List.of()) : ExprNullValue.of();
+    }
+
+    // non-object type arrays are only supported when parsing inner_hits of OS response.
+    if (!(type instanceof OpenSearchDataType
             && ((OpenSearchDataType) type).getExprType().equals(ARRAY))
         && !supportArrays) {
-      return parseInnerArrayValue(content.array().next(), prefix, type, supportArrays);
-    } else {
-      content
-          .array()
-          .forEachRemaining(
-              v -> {
-                result.add(parseInnerArrayValue(v, prefix, type, supportArrays));
-              });
+      return parseInnerArrayValue(arrayIterator.next(), prefix, type, supportArrays);
     }
+
+    List<ExprValue> result = new ArrayList<>();
+    arrayIterator.forEachRemaining(
+        v -> result.add(parseInnerArrayValue(v, prefix, type, supportArrays)));
     return new ExprCollectionValue(result);
   }
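The four script factories in this series drive a single scripted metric lifecycle: init_script seeds per-shard state, map_script folds each matching document into it, combine_script produces a serializable partial per shard, and reduce_script merges the partials on the coordinator. The following self-contained sketch is an editor's illustration of that flow, not code from the patch; plain Java collections stand in for the compiled Calcite Function1<DataContext, Object[]> expressions, and only mergeResultIntoState mirrors the helper added in ScriptedMetricDataContext.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Editor's sketch of the scripted_metric phases implemented by the factories above.
public final class ScriptedMetricLifecycleSketch {

  // Same contract as ScriptedMetricDataContext.mergeResultIntoState in this patch:
  // a Map result is merged into state, anything else lands under the "accumulator" key.
  @SuppressWarnings("unchecked")
  static void mergeResultIntoState(Object[] result, Map<String, Object> state) {
    if (result != null && result.length > 0) {
      if (result[0] instanceof Map) {
        state.putAll((Map<String, Object>) result[0]);
      } else {
        state.put("accumulator", result[0]);
      }
    }
  }

  @SuppressWarnings("unchecked")
  public static void main(String[] args) {
    // init_script: runs once per shard and seeds the mutable state map in place.
    Map<String, Object> state = new HashMap<>();
    mergeResultIntoState(new Object[] {Map.of("logMessages", new ArrayList<String>())}, state);

    // map_script: runs once per matching document, buffering the raw message.
    for (String doc : List.of("error at host-1", "error at host-2")) {
      ((List<String>) state.get("logMessages")).add(doc);
    }

    // combine_script: folds per-shard state into a serializable partial result.
    Object partial = state.get("logMessages");

    // reduce_script: runs on the coordinator over the partials from every shard.
    List<Object> shardPartials = List.of(partial, List.of("error at host-3"));
    List<String> merged = new ArrayList<>();
    for (Object p : shardPartials) {
      merged.addAll((List<String>) p);
    }
    System.out.println(merged.size() + " buffered messages reduced"); // prints 3
  }
}
```

In the real pushdown, PatternAggregationHelpers supplies the pattern-specific accumulator logic behind each phase, which is why the combine-phase fix earlier in this series must carry the buffered logMessages forward rather than resetting them.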