Skip to content

Commit 7ee01a8

Browse files
author
Peng Ren
committed
Add function support in projection
1 parent e9c9cfc commit 7ee01a8

15 files changed

+3979
-2376
lines changed

pymongosql/result_set.py

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,63 @@ def _build_description(self) -> None:
7575
self._description = None
7676
return
7777

78-
# Build description from projection (now in MongoDB format {field: 1})
78+
# Build description from projection output if available
7979
description = []
8080
column_aliases = getattr(self._execution_plan, "column_aliases", {})
81+
projection_functions = getattr(self._execution_plan, "projection_functions", {})
82+
projection_output = getattr(self._execution_plan, "projection_output", None)
83+
84+
if projection_output:
85+
for item in projection_output:
86+
output_name = item.get("output_name")
87+
func_info = item.get("function")
88+
89+
display_name = output_name
90+
type_code = str
91+
92+
if func_info:
93+
func_name = func_info.get("name")
94+
if func_name:
95+
from .sql.projection_functions import ProjectionFunctionRegistry
96+
97+
registry = ProjectionFunctionRegistry()
98+
for func in registry.get_all_functions():
99+
if func.function_name.upper() == str(func_name).upper():
100+
type_code = func.get_type_code()
101+
break
102+
103+
description.append((display_name, type_code, None, None, None, None, None))
104+
105+
self._description = description
106+
return
81107

82108
for field_name, include_flag in self._execution_plan.projection_stage.items():
83109
# SQL cursor description format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
84110
if include_flag == 1: # Field is included in projection
85111
# Use alias if available, otherwise use field name
86112
display_name = column_aliases.get(field_name, field_name)
87-
description.append((display_name, str, None, None, None, None, None))
113+
114+
# Determine type code based on projection function if present
115+
type_code = str
116+
if field_name in projection_functions:
117+
func_info = projection_functions[field_name]
118+
func_name = None
119+
if isinstance(func_info, dict):
120+
func_name = func_info.get("name")
121+
elif isinstance(func_info, (list, tuple)) and len(func_info) > 0:
122+
func_name = func_info[0]
123+
124+
if func_name:
125+
from .sql.projection_functions import ProjectionFunctionRegistry
126+
127+
registry = ProjectionFunctionRegistry()
128+
# Find function by name
129+
for func in registry.get_all_functions():
130+
if func.function_name.upper() == str(func_name).upper():
131+
type_code = func.get_type_code()
132+
break
133+
134+
description.append((display_name, type_code, None, None, None, None, None))
88135

89136
self._description = description
90137

@@ -129,17 +176,73 @@ def _ensure_results_available(self, count: int = 1) -> None:
129176
self._cache_exhausted = True
130177

131178
def _process_document(self, doc: Dict[str, Any]) -> Dict[str, Any]:
132-
"""Process a MongoDB document according to projection mapping"""
179+
"""Process a MongoDB document according to projection mapping and apply projection functions"""
133180
if not self._execution_plan.projection_stage:
134181
# No projection, return document as-is (including _id)
135182
return dict(doc)
136183

137184
# Apply projection mapping (now using MongoDB format {field: 1})
138185
processed = {}
186+
projection_functions = getattr(self._execution_plan, "projection_functions", {})
187+
projection_output = getattr(self._execution_plan, "projection_output", None)
188+
189+
from .sql.projection_functions import ProjectionFunctionRegistry
190+
191+
registry = ProjectionFunctionRegistry()
192+
193+
if projection_output:
194+
for item in projection_output:
195+
output_name = item.get("output_name")
196+
source_field = item.get("source_field")
197+
func_info = item.get("function")
198+
199+
value = self._get_nested_value(doc, source_field)
200+
201+
if func_info:
202+
func_name = func_info.get("name") if isinstance(func_info, dict) else None
203+
format_param = func_info.get("format_param") if isinstance(func_info, dict) else None
204+
if func_name:
205+
func_handler = registry.find_function(f"{func_name}(x)")
206+
if func_handler:
207+
value = func_handler.convert_value(value, format_param)
208+
209+
if output_name == source_field:
210+
display_key = self._mongo_to_bracket_key(source_field)
211+
else:
212+
display_key = output_name
213+
214+
processed[display_key] = value
215+
216+
return processed
217+
139218
for field_name, include_flag in self._execution_plan.projection_stage.items():
140219
if include_flag == 1: # Field is included in projection
141220
# Extract value using jmespath-compatible field path (convert numeric dot indexes to bracket form)
142221
value = self._get_nested_value(doc, field_name)
222+
223+
# Apply projection function if present
224+
if field_name in projection_functions:
225+
func_info = projection_functions[field_name]
226+
func_name = None
227+
format_param = None
228+
if isinstance(func_info, dict):
229+
func_name = func_info.get("name")
230+
format_param = func_info.get("format_param")
231+
elif isinstance(func_info, (list, tuple)):
232+
if len(func_info) >= 2:
233+
func_name = func_info[0]
234+
if len(func_info) == 2:
235+
format_param = func_info[1]
236+
elif len(func_info) > 2:
237+
extra_params = func_info[2:]
238+
if extra_params:
239+
format_param = ",".join([str(p) for p in extra_params])
240+
241+
if func_name:
242+
func_handler = registry.find_function(f"{func_name}(x)")
243+
if func_handler:
244+
value = func_handler.convert_value(value, format_param)
245+
143246
# Convert the projection key back to bracket notation for client-facing results
144247
display_key = self._mongo_to_bracket_key(field_name)
145248
processed[display_key] = value

pymongosql/sql/builder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,14 @@ def _build_query_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan":
124124
builder._execution_plan.aggregate_pipeline = parse_result.aggregate_pipeline
125125
builder._execution_plan.aggregate_options = parse_result.aggregate_options
126126

127+
# Set projection functions if present
128+
if hasattr(parse_result, "projection_functions"):
129+
builder._execution_plan.projection_functions = parse_result.projection_functions
130+
131+
# Set ordered projection outputs if present
132+
if hasattr(parse_result, "projection_output"):
133+
builder._execution_plan.projection_output = parse_result.projection_output
134+
127135
# Now build and validate
128136
plan = builder.build()
129137
return plan

pymongosql/sql/partiql/PartiQLLexer.g4

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP';
7777
CURRENT_USER: 'CURRENT_USER';
7878
CURSOR: 'CURSOR';
7979
DATE: 'DATE';
80+
DATETIME: 'DATETIME';
8081
DEALLOCATE: 'DEALLOCATE';
8182
DEC: 'DEC';
8283
DECIMAL: 'DECIMAL';
@@ -168,6 +169,7 @@ NULL: 'NULL';
168169
NULLS: 'NULLS';
169170
NULLIF: 'NULLIF';
170171
NUMERIC: 'NUMERIC';
172+
NUMBER: 'NUMBER';
171173
OCTET_LENGTH: 'OCTET_LENGTH';
172174
OF: 'OF';
173175
ON: 'ON';

0 commit comments

Comments
 (0)