@@ -317,7 +317,9 @@ def execute_query(self, query_string: str) -> Any:
317317 else :
318318 res = table .select_all (limit = limit )
319319
320- # Apply column projection
320+ # Apply column projection or aggregates
321+ if self ._is_aggregate_query (columns ):
322+ return self ._apply_aggregates (res , columns , table )
321323 return table .project_columns (res , columns )
322324
323325 if cmd_type == 'DELETE' :
@@ -461,7 +463,7 @@ def execute_query(self, query_string: str) -> Any:
461463 }
462464
463465
464- except (DBError , TypeError ) as e :
466+ except (DBError , TypeError , ValueError ) as e :
465467 return f"Error: { e } "
466468 except Exception as e :
467469 return f"Unexpected Error: { e } "
@@ -474,6 +476,110 @@ def get_tables(self) -> List[str]:
474476 """
475477 return list (self .tables .keys ())
476478
479+ def _is_aggregate_query (self , columns_str : str ) -> bool :
480+ """Determines if a SELECT clause contains aggregate functions.
481+
482+ Args:
483+ columns_str: The columns portion of the SQL query.
484+
485+ Returns:
486+ bool: True if an aggregate function is detected.
487+ """
488+ aggr_funcs = ['COUNT' , 'SUM' , 'AVG' , 'MIN' , 'MAX' ]
489+ up_cols = columns_str .upper ()
490+ return any (func + '(' in up_cols for func in aggr_funcs )
491+
492+ def _apply_aggregates (self , rows : List [Dict [str , Any ]], columns_str : str , table_obj : 'Table' ) -> List [Dict [str , Any ]]:
493+ """Calculates SQL aggregates (SUM, AVG, etc.) in a single pass.
494+
495+ Args:
496+ rows: The result set after filtering.
497+ columns_str: The aggregate column specifications.
498+ table_obj: The target Table object for type verification.
499+
500+ Returns:
501+ List[Dict[str, Any]]: A list containing a single row with the aggregate results.
502+
503+ Raises:
504+ ValueError: If an aggregate is applied to an incompatible column type.
505+ """
506+ import re
507+ # Find all patterns like AGGR_FUNC(column)
508+ pattern = re .compile (r'(\w+)\(([\w\*]+)\)' , re .IGNORECASE )
509+ aggr_specs = pattern .findall (columns_str )
510+
511+ if not aggr_specs :
512+ return []
513+
514+ # Initialize accumulators
515+ accumulators = {}
516+ for func_raw , col in aggr_specs :
517+ func = func_raw .upper ()
518+ label = f"{ func } ({ col } )"
519+
520+ # Validation: SUM and AVG require numeric columns
521+ if func in ['SUM' , 'AVG' ]:
522+ col_type = table_obj .column_types .get (col )
523+ if col_type == 'str' :
524+ raise ValueError (f"Cannot compute { func } on non-numeric column '{ col } ' (type: STR)" )
525+
526+ if func == 'COUNT' :
527+ accumulators [label ] = {'sum' : 0 }
528+ elif func == 'SUM' :
529+ accumulators [label ] = {'sum' : 0 }
530+ elif func == 'AVG' :
531+ accumulators [label ] = {'sum' : 0 , 'count' : 0 }
532+ elif func == 'MIN' :
533+ accumulators [label ] = {'min' : float ('inf' )}
534+ elif func == 'MAX' :
535+ accumulators [label ] = {'max' : float ('-inf' )}
536+
537+ # Single pass execution
538+ total_filtered_rows = len (rows )
539+ for row in rows :
540+ for func_raw , col in aggr_specs :
541+ func = func_raw .upper ()
542+ label = f"{ func } ({ col } )"
543+
544+ if func == 'COUNT' :
545+ if col == '*' or row .get (col ) is not None :
546+ accumulators [label ]['sum' ] += 1
547+ continue
548+
549+ val = row .get (col )
550+ if val is not None :
551+ if func == 'SUM' :
552+ accumulators [label ]['sum' ] += val
553+ elif func == 'AVG' :
554+ accumulators [label ]['sum' ] += val
555+ accumulators [label ]['count' ] += 1
556+ elif func == 'MIN' :
557+ if val < accumulators [label ]['min' ]:
558+ accumulators [label ]['min' ] = val
559+ elif func == 'MAX' :
560+ if val > accumulators [label ]['max' ]:
561+ accumulators [label ]['max' ] = val
562+
563+ # Finalize results
564+ result_row = {}
565+ for func_raw , col in aggr_specs :
566+ func = func_raw .upper ()
567+ label = f"{ func } ({ col } )"
568+ acc = accumulators [label ]
569+
570+ if func == 'COUNT' :
571+ result_row [label ] = acc ['sum' ]
572+ elif func == 'SUM' :
573+ result_row [label ] = acc ['sum' ] if total_filtered_rows > 0 else 0
574+ elif func == 'AVG' :
575+ result_row [label ] = acc ['sum' ] / acc ['count' ] if acc .get ('count' , 0 ) > 0 else None
576+ elif func == 'MIN' :
577+ result_row [label ] = acc ['min' ] if acc ['min' ] != float ('inf' ) else None
578+ elif func == 'MAX' :
579+ result_row [label ] = acc ['max' ] if acc ['max' ] != float ('-inf' ) else None
580+
581+ return [result_row ]
582+
477583 def _nested_loop_join (self , left_rows : List [Dict [str , Any ]], right_rows : List [Dict [str , Any ]],
478584 left_on : Tuple [str , str ], right_on : Tuple [str , str ]) -> List [Dict [str , Any ]]:
479585 """Performs a simple Nested Loop Join. Complexity: O(N*M).
0 commit comments