diff --git a/grammar/Expr.g4 b/grammar/Expr.g4 index 072bf52e..c49a5a02 100644 --- a/grammar/Expr.g4 +++ b/grammar/Expr.g4 @@ -26,10 +26,9 @@ expr | expr op=('+' | '-') expr # addsub | expr COMPARISON expr # comparison | IDENTIFIER '(' expr ')' # function - | IDENTIFIER '[' shift (',' shift)* ']' # timeShift - | IDENTIFIER '[' expr (',' expr )* ']' # timeIndex - | IDENTIFIER '[' shift1=shift '..' shift2=shift ']' # timeShiftRange - | IDENTIFIER '[' expr '..' expr ']' # timeRange + | IDENTIFIER '[' shift ']' # timeShift + | IDENTIFIER '[' expr ']' # timeIndex + | TIME_SUM '(' ((expr | shift | timeShiftRange | timeRange) ',')? IDENTIFIER ')' #timeSum ; atom @@ -37,6 +36,9 @@ atom | IDENTIFIER # identifier ; +timeShiftRange: shift1=shift '..' shift2=shift; +timeRange: expr1=expr '..' expr2=expr; + // a shift is required to be either "t" or "t + ..." or "t - ..." // Note: simply defining it as "shift: TIME ('+' | '-') expr" won't work // because the minus sign will not have the expected precedence: @@ -74,5 +76,6 @@ NUMBER : DIGIT+ ('.' DIGIT+)?; TIME : 't'; IDENTIFIER : CHAR CHAR_OR_DIGIT*; COMPARISON : ( '=' | '>=' | '<=' ); +TIME_SUM : 'sum'; WS: (' ' | '\t' | '\r'| '\n') -> skip; diff --git a/requirements-dev.txt b/requirements-dev.txt index 647280ff..4f090b4c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,3 +8,4 @@ types-PyYAML~=6.0.12.12 antlr4-tools~=0.2.1 pandas~=2.0.3 pandas-stubs<=2.0.3 +types-PyYAML~=6.0.12 diff --git a/src/andromede/expression/__init__.py b/src/andromede/expression/__init__.py index 2fe9b94d..f0723175 100644 --- a/src/andromede/expression/__init__.py +++ b/src/andromede/expression/__init__.py @@ -11,29 +11,28 @@ # This file is part of the Antares project. from .copy import CopyVisitor, copy_expression -from .degree import ExpressionDegreeVisitor, compute_degree -from .evaluate import EvaluationContext, EvaluationVisitor, ValueProvider, evaluate -from .evaluate_parameters import ( - ParameterResolver, - ParameterValueProvider, - resolve_parameters, -) +from .evaluate_parameters import ValueProvider from .expression import ( AdditionNode, - Comparator, ComparisonNode, + ComponentParameterNode, DivisionNode, ExpressionNode, + ExpressionRange, + InstancesTimeIndex, LiteralNode, MultiplicationNode, NegationNode, ParameterNode, + PortFieldAggregatorNode, + PortFieldNode, + ScenarioOperatorName, + ScenarioOperatorNode, SubstractionNode, - VariableNode, - literal, - param, - sum_expressions, - var, + TimeAggregatorName, + TimeAggregatorNode, + TimeOperatorName, + TimeOperatorNode, ) from .print import PrinterVisitor, print_expr from .visitor import ExpressionVisitor, visit diff --git a/src/andromede/expression/context_adder.py b/src/andromede/expression/context_adder.py index 812e95f7..5480a632 100644 --- a/src/andromede/expression/context_adder.py +++ b/src/andromede/expression/context_adder.py @@ -13,13 +13,7 @@ from dataclasses import dataclass from . import CopyVisitor -from .expression import ( - ComponentParameterNode, - ComponentVariableNode, - ExpressionNode, - ParameterNode, - VariableNode, -) +from .expression import ComponentParameterNode, ExpressionNode, ParameterNode from .visitor import visit @@ -32,21 +26,10 @@ class ContextAdder(CopyVisitor): component_id: str - def variable(self, node: VariableNode) -> ExpressionNode: - return ComponentVariableNode(self.component_id, node.name) - def parameter(self, node: ParameterNode) -> ExpressionNode: return ComponentParameterNode(self.component_id, node.name) - def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode: - raise ValueError( - "This expression has already been associated to another component." - ) - - def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode: - raise ValueError( - "This expression has already been associated to another component." - ) + # Nothing is done is a component parameter node is encountered as it may have been generated from port resolution def add_component_context(id: str, expression: ExpressionNode) -> ExpressionNode: diff --git a/src/andromede/expression/copy.py b/src/andromede/expression/copy.py index c135ee59..677aaf15 100644 --- a/src/andromede/expression/copy.py +++ b/src/andromede/expression/copy.py @@ -11,30 +11,23 @@ # This file is part of the Antares project. from dataclasses import dataclass -from typing import List, Union, cast +from typing import List, cast from .expression import ( - AdditionNode, ComparisonNode, ComponentParameterNode, - ComponentVariableNode, - DivisionNode, ExpressionNode, ExpressionRange, InstancesTimeIndex, LiteralNode, - MultiplicationNode, - NegationNode, ParameterNode, PortFieldAggregatorNode, PortFieldNode, ScenarioOperatorNode, - SubstractionNode, TimeAggregatorNode, TimeOperatorNode, - VariableNode, ) -from .visitor import ExpressionVisitor, ExpressionVisitorOperations, T, visit +from .visitor import ExpressionVisitorOperations, visit @dataclass(frozen=True) @@ -51,15 +44,9 @@ def comparison(self, node: ComparisonNode) -> ExpressionNode: visit(node.left, self), visit(node.right, self), node.comparator ) - def variable(self, node: VariableNode) -> ExpressionNode: - return VariableNode(node.name) - def parameter(self, node: ParameterNode) -> ExpressionNode: return ParameterNode(node.name) - def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode: - return ComponentVariableNode(node.component_id, node.name) - def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode: return ComponentParameterNode(node.component_id, node.name) @@ -69,9 +56,11 @@ def copy_expression_range( return ExpressionRange( start=visit(expression_range.start, self), stop=visit(expression_range.stop, self), - step=visit(expression_range.step, self) - if expression_range.step is not None - else None, + step=( + visit(expression_range.step, self) + if expression_range.step is not None + else None + ), ) def copy_instances_index( diff --git a/src/andromede/expression/degree.py b/src/andromede/expression/degree.py deleted file mode 100644 index cfd175cd..00000000 --- a/src/andromede/expression/degree.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -import andromede.expression.scenario_operator -from andromede.expression.expression import ( - ComponentParameterNode, - ComponentVariableNode, - PortFieldAggregatorNode, - PortFieldNode, - TimeOperatorNode, -) - -from .expression import ( - AdditionNode, - ComparisonNode, - DivisionNode, - ExpressionNode, - LiteralNode, - MultiplicationNode, - NegationNode, - ParameterNode, - ScenarioOperatorNode, - SubstractionNode, - TimeAggregatorNode, - VariableNode, -) -from .visitor import ExpressionVisitor, T, visit - - -class ExpressionDegreeVisitor(ExpressionVisitor[int]): - """ - Computes degree of expression with respect to variables. - """ - - def literal(self, node: LiteralNode) -> int: - return 0 - - def negation(self, node: NegationNode) -> int: - return visit(node.operand, self) - - # TODO: Take into account simplification that can occur with literal coefficient for add, sub, mult, div - def addition(self, node: AdditionNode) -> int: - return max(visit(node.left, self), visit(node.right, self)) - - def substraction(self, node: SubstractionNode) -> int: - return max(visit(node.left, self), visit(node.right, self)) - - def multiplication(self, node: MultiplicationNode) -> int: - return visit(node.left, self) + visit(node.right, self) - - def division(self, node: DivisionNode) -> int: - right_degree = visit(node.right, self) - if right_degree != 0: - raise ValueError("Degree computation not implemented for divisions.") - return visit(node.left, self) - - def comparison(self, node: ComparisonNode) -> int: - return max(visit(node.left, self), visit(node.right, self)) - - def variable(self, node: VariableNode) -> int: - return 1 - - def parameter(self, node: ParameterNode) -> int: - return 0 - - def comp_variable(self, node: ComponentVariableNode) -> int: - return 1 - - def comp_parameter(self, node: ComponentParameterNode) -> int: - return 0 - - def time_operator(self, node: TimeOperatorNode) -> int: - if node.name in ["TimeShift", "TimeEvaluation"]: - return visit(node.operand, self) - else: - return NotImplemented - - def time_aggregator(self, node: TimeAggregatorNode) -> int: - if node.name in ["TimeSum"]: - return visit(node.operand, self) - else: - return NotImplemented - - def scenario_operator(self, node: ScenarioOperatorNode) -> int: - scenario_operator_cls = getattr( - andromede.expression.scenario_operator, node.name - ) - # TODO: Carefully check if this formula is correct - return scenario_operator_cls.degree() * visit(node.operand, self) - - def port_field(self, node: PortFieldNode) -> int: - return 1 - - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> int: - return visit(node.operand, self) - - -def compute_degree(expression: ExpressionNode) -> int: - return visit(expression, ExpressionDegreeVisitor()) - - -def is_constant(expr: ExpressionNode) -> bool: - """ - True if the expression has no variable. - """ - return compute_degree(expr) == 0 - - -def is_linear(expr: ExpressionNode) -> bool: - """ - True if the expression is linear with respect to variables. - """ - return compute_degree(expr) <= 1 diff --git a/src/andromede/expression/equality.py b/src/andromede/expression/equality.py index a2deeb27..b6efebc0 100644 --- a/src/andromede/expression/equality.py +++ b/src/andromede/expression/equality.py @@ -14,25 +14,23 @@ from dataclasses import dataclass from typing import Optional -from andromede.expression import ( +from andromede.expression.expression import ( AdditionNode, + BinaryOperatorNode, ComparisonNode, + ComponentParameterNode, DivisionNode, ExpressionNode, + ExpressionRange, + InstancesTimeIndex, LiteralNode, MultiplicationNode, NegationNode, ParameterNode, - SubstractionNode, - VariableNode, -) -from andromede.expression.expression import ( - BinaryOperatorNode, - ExpressionRange, - InstancesTimeIndex, PortFieldAggregatorNode, PortFieldNode, ScenarioOperatorNode, + SubstractionNode, TimeAggregatorNode, TimeOperatorNode, ) @@ -72,10 +70,12 @@ def visit(self, left: ExpressionNode, right: ExpressionNode) -> bool: return self.multiplication(left, right) if isinstance(left, ComparisonNode) and isinstance(right, ComparisonNode): return self.comparison(left, right) - if isinstance(left, VariableNode) and isinstance(right, VariableNode): - return self.variable(left, right) if isinstance(left, ParameterNode) and isinstance(right, ParameterNode): return self.parameter(left, right) + if isinstance(left, ComponentParameterNode) and isinstance( + right, ComponentParameterNode + ): + return self.comp_parameter(left, right) if isinstance(left, TimeOperatorNode) and isinstance(right, TimeOperatorNode): return self.time_operator(left, right) if isinstance(left, TimeAggregatorNode) and isinstance( @@ -124,12 +124,14 @@ def division(self, left: DivisionNode, right: DivisionNode) -> bool: def comparison(self, left: ComparisonNode, right: ComparisonNode) -> bool: return left.comparator == right.comparator and self._visit_operands(left, right) - def variable(self, left: VariableNode, right: VariableNode) -> bool: - return left.name == right.name - def parameter(self, left: ParameterNode, right: ParameterNode) -> bool: return left.name == right.name + def comp_parameter( + self, left: ComponentParameterNode, right: ComponentParameterNode + ) -> bool: + return left.component_id == right.component_id and left.name == right.name + def expression_range(self, left: ExpressionRange, right: ExpressionRange) -> bool: if not self.visit(left.start, right.start): return False @@ -183,7 +185,10 @@ def port_field_aggregator( def expressions_equal( - left: ExpressionNode, right: ExpressionNode, abs_tol: float = 0, rel_tol: float = 0 + left: ExpressionNode, + right: ExpressionNode, + abs_tol: float = 0, + rel_tol: float = 0, ) -> bool: """ True if both expression nodes are equal. Literal values may be compared with absolute or relative tolerance. diff --git a/src/andromede/expression/evaluate.py b/src/andromede/expression/evaluate.py deleted file mode 100644 index b51c0e86..00000000 --- a/src/andromede/expression/evaluate.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import Dict - -from andromede.expression.expression import ( - ComponentParameterNode, - ComponentVariableNode, - PortFieldAggregatorNode, - PortFieldNode, - TimeOperatorNode, -) - -from .expression import ( - AdditionNode, - ComparisonNode, - DivisionNode, - ExpressionNode, - LiteralNode, - MultiplicationNode, - NegationNode, - ParameterNode, - ScenarioOperatorNode, - SubstractionNode, - TimeAggregatorNode, - VariableNode, -) -from .visitor import ExpressionVisitor, ExpressionVisitorOperations, T, visit - - -class ValueProvider(ABC): - """ - Implementations are in charge of mapping parameters and variables to their values. - Depending on the implementation, evaluation may require a component id or not. - """ - - @abstractmethod - def get_variable_value(self, name: str) -> float: - ... - - @abstractmethod - def get_parameter_value(self, name: str) -> float: - ... - - @abstractmethod - def get_component_variable_value(self, component_id: str, name: str) -> float: - ... - - @abstractmethod - def get_component_parameter_value(self, component_id: str, name: str) -> float: - ... - - # TODO: Should this really be an abstract method ? Or maybe, only the Provider in _make_value_provider should implement it. And the context attribute in the InstancesIndexVisitor is a ValueProvider that implements the parameter_is_constant_over_time method. Maybe create a child class of ValueProvider like TimeValueProvider ? - @abstractmethod - def parameter_is_constant_over_time(self, name: str) -> bool: - ... - - -@dataclass(frozen=True) -class EvaluationContext(ValueProvider): - """ - Simple value provider relying on dictionaries. - Does not support component variables/parameters. - """ - - variables: Dict[str, float] = field(default_factory=dict) - parameters: Dict[str, float] = field(default_factory=dict) - - def get_variable_value(self, name: str) -> float: - return self.variables[name] - - def get_parameter_value(self, name: str) -> float: - return self.parameters[name] - - def get_component_variable_value(self, component_id: str, name: str) -> float: - raise NotImplementedError() - - def get_component_parameter_value(self, component_id: str, name: str) -> float: - raise NotImplementedError() - - def parameter_is_constant_over_time(self, name: str) -> bool: - raise NotImplementedError() - - -@dataclass(frozen=True) -class EvaluationVisitor(ExpressionVisitorOperations[float]): - """ - Evaluates the expression with respect to the provided context - (variables and parameters values). - """ - - context: ValueProvider - - def literal(self, node: LiteralNode) -> float: - return node.value - - def comparison(self, node: ComparisonNode) -> float: - raise ValueError("Cannot evaluate comparison operator.") - - def variable(self, node: VariableNode) -> float: - return self.context.get_variable_value(node.name) - - def parameter(self, node: ParameterNode) -> float: - return self.context.get_parameter_value(node.name) - - def comp_parameter(self, node: ComponentParameterNode) -> float: - return self.context.get_component_parameter_value(node.component_id, node.name) - - def comp_variable(self, node: ComponentVariableNode) -> float: - return self.context.get_component_variable_value(node.component_id, node.name) - - def time_operator(self, node: TimeOperatorNode) -> float: - raise NotImplementedError() - - def time_aggregator(self, node: TimeAggregatorNode) -> float: - raise NotImplementedError() - - def scenario_operator(self, node: ScenarioOperatorNode) -> float: - raise NotImplementedError() - - def port_field(self, node: PortFieldNode) -> float: - raise NotImplementedError() - - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> float: - raise NotImplementedError() - - -def evaluate(expression: ExpressionNode, value_provider: ValueProvider) -> float: - return visit(expression, EvaluationVisitor(value_provider)) - - -@dataclass(frozen=True) -class InstancesIndexVisitor(EvaluationVisitor): - """ - Evaluates an expression given as instances index which should have no variable and constant parameter values. - """ - - def variable(self, node: VariableNode) -> float: - raise ValueError("An instance index expression cannot contain variable") - - def parameter(self, node: ParameterNode) -> float: - if not self.context.parameter_is_constant_over_time(node.name): - raise ValueError( - "Parameter given in an instance index expression must be constant over time" - ) - return self.context.get_parameter_value(node.name) - - def time_operator(self, node: TimeOperatorNode) -> float: - raise ValueError("An instance index expression cannot contain time operator") - - def time_aggregator(self, node: TimeAggregatorNode) -> float: - raise ValueError("An instance index expression cannot contain time aggregator") diff --git a/src/andromede/expression/evaluate_parameters.py b/src/andromede/expression/evaluate_parameters.py index 7c734260..d202ee2d 100644 --- a/src/andromede/expression/evaluate_parameters.py +++ b/src/andromede/expression/evaluate_parameters.py @@ -10,57 +10,247 @@ # # This file is part of the Antares project. -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import List +import operator +from dataclasses import dataclass, field +from typing import Callable, Dict, List -from andromede.expression.evaluate import InstancesIndexVisitor, ValueProvider - -from .copy import CopyVisitor from .expression import ( + AdditionNode, + ComparisonNode, ComponentParameterNode, + DivisionNode, ExpressionNode, ExpressionRange, InstancesTimeIndex, LiteralNode, + MultiplicationNode, + NegationNode, ParameterNode, + PortFieldAggregatorNode, + PortFieldNode, + ScenarioOperatorName, + ScenarioOperatorNode, + SubstractionNode, + TimeAggregatorName, + TimeAggregatorNode, + TimeOperatorName, + TimeOperatorNode, ) -from .visitor import visit +from .indexing_structure import RowIndex +from .value_provider import TimeScenarioIndex, TimeScenarioIndices, ValueProvider +from .visitor import ExpressionVisitor, visit + + +# TODO: (almost) same function as in linear_expression _merge_dicts +def _merge_dicts( + lhs: Dict[TimeScenarioIndex, float], + rhs: Dict[TimeScenarioIndex, float], + merge_func: Callable[[float, float], float], + neutral: float, +) -> Dict[TimeScenarioIndex, float]: + res = {} + for k, v in lhs.items(): + res[k] = merge_func(v, rhs.get(k, neutral)) + for k, v in rhs.items(): + if k not in lhs: + res[k] = merge_func(neutral, v) + return res -class ParameterValueProvider(ABC): - @abstractmethod - def get_parameter_value(self, name: str) -> float: - ... +# It is better to return a list of float than a single float to minimize the number of calls to the visit function. We access values of the parameters at different time steps with a single visit of the tree +@dataclass(frozen=True) +class ParameterEvaluationVisitor(ExpressionVisitor[Dict[TimeScenarioIndex, float]]): + """ + Evaluates the expression with respect to the provided context + (variables and parameters values). + """ + + context: ValueProvider + # Useful to keep track of row id for time shift + row_id: RowIndex # TODO to be included in ValueProvider ? + time_scenario_indices: TimeScenarioIndices = field(init=False) + + def __post_init__(self) -> None: + object.__setattr__( + self, + "time_scenario_indices", + TimeScenarioIndices([self.row_id.time], [self.row_id.scenario]), + ) - @abstractmethod - def get_component_parameter_value(self, component_id: str, name: str) -> float: - ... + # Redefine common operations so that it works as expected on lists (i.e. summing element wise rather than appending to it) + def negation(self, node: NegationNode) -> Dict[TimeScenarioIndex, float]: + return {k: -v for k, v in visit(node.operand, self).items()} + + def addition(self, node: AdditionNode) -> Dict[TimeScenarioIndex, float]: + left_value = visit(node.left, self) + right_value = visit(node.right, self) + result = _merge_dicts(left_value, right_value, operator.add, 0) + return result + + def substraction(self, node: SubstractionNode) -> Dict[TimeScenarioIndex, float]: + left_value = visit(node.left, self) + right_value = visit(node.right, self) + result = _merge_dicts(left_value, right_value, operator.sub, 0) + return result + + def multiplication( + self, node: MultiplicationNode + ) -> Dict[TimeScenarioIndex, float]: + left_value = visit(node.left, self) + right_value = visit(node.right, self) + result = _merge_dicts(left_value, right_value, operator.mul, 1) + return result + + def division(self, node: DivisionNode) -> Dict[TimeScenarioIndex, float]: + left_value = visit(node.left, self) + right_value = visit(node.right, self) + result = _merge_dicts(left_value, right_value, operator.truediv, 1) + return result + + def literal(self, node: LiteralNode) -> Dict[TimeScenarioIndex, float]: + result = {} + for time in self.time_scenario_indices.time_indices: + for scenario in self.time_scenario_indices.scenario_indices: + result[TimeScenarioIndex(time, scenario)] = node.value + return result + + def comparison(self, node: ComparisonNode) -> Dict[TimeScenarioIndex, float]: + raise ValueError("Cannot evaluate comparison operator.") + + def parameter(self, node: ParameterNode) -> Dict[TimeScenarioIndex, float]: + return self.context.get_parameter_value(node.name, self.time_scenario_indices) + + def comp_parameter( + self, node: ComponentParameterNode + ) -> Dict[TimeScenarioIndex, float]: + return self.context.get_component_parameter_value( + node.component_id, node.name, self.time_scenario_indices + ) + + def time_operator(self, node: TimeOperatorNode) -> Dict[TimeScenarioIndex, float]: + self.time_scenario_indices.time_indices = get_time_ids_from_instances_index( + node.instances_index, self.context, self.row_id + ) + if node.name == TimeOperatorName.SHIFT: + self.time_scenario_indices.time_indices = [ + self.row_id.time + op_id + for op_id in self.time_scenario_indices.time_indices + ] + elif node.name != TimeOperatorName.EVALUATION: + return NotImplemented + return visit(node.operand, self) + + def time_aggregator( + self, node: TimeAggregatorNode + ) -> Dict[TimeScenarioIndex, float]: + if node.name == TimeAggregatorName.TIME_SUM: + if not isinstance(node.operand, TimeOperatorNode): + # Sum over all time block + self.time_scenario_indices.time_indices = list( + range(self.context.block_length()) + ) + # Time indices for the case where node.operand is a TimeOperatorNode are handled in time_operator function directly + operand_dict = visit(node.operand, self) + result = {} + for scenario in self.time_scenario_indices.scenario_indices: + result[TimeScenarioIndex(self.row_id.time, scenario)] = sum( + operand_dict[k] + for k in operand_dict.keys() + if k.scenario == scenario + ) + # As the sum aggregates on time, time indices on which to evaluate parent expression collapses on row_id.time + self.time_scenario_indices.time_indices = [self.row_id.time] + return result + else: + return NotImplemented + + def scenario_operator( + self, node: ScenarioOperatorNode + ) -> Dict[TimeScenarioIndex, float]: + if node.name == ScenarioOperatorName.EXPECTATION: + self.time_scenario_indices.scenario_indices = list( + range(self.context.scenarios()) + ) + operand_dict = visit(node.operand, self) + result = {} + for time in self.time_scenario_indices.time_indices: + # TODO: Make this more general to consider weighted expectations + result[TimeScenarioIndex(time, self.row_id.scenario)] = ( + 1 + / self.context.scenarios() + * sum( + operand_dict[k] for k in operand_dict.keys() if k.time == time + ) + ) + # As the expectation aggregates on scenario, scenario indices on which to evaluate parent expression collapses on row_id.scenario + self.time_scenario_indices.scenario_indices = [self.row_id.scenario] + return result + + else: + return NotImplemented + + def port_field(self, node: PortFieldNode) -> Dict[TimeScenarioIndex, float]: + raise ValueError("Port fields must be resolved before evaluating parameters") + + def port_field_aggregator( + self, node: PortFieldAggregatorNode + ) -> Dict[TimeScenarioIndex, float]: + raise ValueError("Port fields must be resolved before evaluating parameters") + + +def check_resolved_expr( + resolved_expr: Dict[TimeScenarioIndex, float], row_id: RowIndex +) -> None: + # Check that the resolved expression has been correctly time and scenario aggregated so that only a float is left + if len(resolved_expr) != 1: + raise ValueError("Evaluation of expression cannot be reduced to a float value") + if TimeScenarioIndex(row_id.time, row_id.scenario) not in resolved_expr: + raise ValueError( + "Expression has a time operator but not time aggregator, maybe you are missing a sum(), necessary even on one element" + ) + + +def resolve_coefficient( + expression: ExpressionNode, value_provider: ValueProvider, row_id: RowIndex +) -> float: + result = visit(expression, ParameterEvaluationVisitor(value_provider, row_id)) + check_resolved_expr(result, row_id) + return result[TimeScenarioIndex(row_id.time, row_id.scenario)] @dataclass(frozen=True) -class ParameterResolver(CopyVisitor): +class InstancesIndexVisitor(ParameterEvaluationVisitor): """ - Duplicates the AST with replacement of parameter nodes by literal nodes. + Evaluates an expression given as instances index which should have no variable and constant parameter values. """ - context: ParameterValueProvider + # Probably useless as parameter nodes should have already be replaced by component parameter nodes ? + def parameter(self, node: ParameterNode) -> Dict[TimeScenarioIndex, float]: + if not self.context.parameter_is_constant_over_time(node.name): + raise ValueError( + "Parameter given in an instance index expression must be constant over time" + ) - def parameter(self, node: ParameterNode) -> ExpressionNode: - value: float = self.context.get_parameter_value(node.name) - return LiteralNode(value) + return self.context.get_parameter_value(node.name, self.time_scenario_indices) - def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode: - value: float = self.context.get_component_parameter_value( - node.component_id, node.name + def comp_parameter( + self, node: ComponentParameterNode + ) -> Dict[TimeScenarioIndex, float]: + if not self.context.parameter_is_constant_over_time(node.name): + raise ValueError( + "Parameter given in an instance index expression must be constant over time" + ) + return self.context.get_component_parameter_value( + node.component_id, node.name, self.time_scenario_indices ) - return LiteralNode(value) + def time_operator(self, node: TimeOperatorNode) -> Dict[TimeScenarioIndex, float]: + raise ValueError("An instance index expression cannot contain time operator") -def resolve_parameters( - expression: ExpressionNode, parameter_provider: ParameterValueProvider -) -> ExpressionNode: - return visit(expression, ParameterResolver(parameter_provider)) + def time_aggregator( + self, node: TimeAggregatorNode + ) -> Dict[TimeScenarioIndex, float]: + raise ValueError("An instance index expression cannot contain time aggregator") def float_to_int(value: float) -> int: @@ -70,29 +260,40 @@ def float_to_int(value: float) -> int: raise ValueError(f"{value} is not an integer.") -def evaluate_time_id(expr: ExpressionNode, value_provider: ValueProvider) -> int: - float_time_id = visit(expr, InstancesIndexVisitor(value_provider)) +def evaluate_time_id( + expr: ExpressionNode, value_provider: ValueProvider, row_id: RowIndex +) -> int: + float_time_id_in_list = visit(expr, InstancesIndexVisitor(value_provider, row_id)) + check_resolved_expr(float_time_id_in_list, row_id) try: - time_id = float_to_int(float_time_id) + time_id = float_to_int( + float_time_id_in_list[TimeScenarioIndex(row_id.time, row_id.scenario)] + ) except ValueError: print(f"{expr} does not represent an integer time index.") return time_id def get_time_ids_from_instances_index( - instances_index: InstancesTimeIndex, value_provider: ValueProvider + instances_index: InstancesTimeIndex, value_provider: ValueProvider, row_id: RowIndex ) -> List[int]: time_ids = [] if isinstance(instances_index.expressions, list): # List[ExpressionNode] for expr in instances_index.expressions: - time_ids.append(evaluate_time_id(expr, value_provider)) + time_ids.append(evaluate_time_id(expr, value_provider, row_id)) elif isinstance(instances_index.expressions, ExpressionRange): # ExpressionRange - start_id = evaluate_time_id(instances_index.expressions.start, value_provider) - stop_id = evaluate_time_id(instances_index.expressions.stop, value_provider) + start_id = evaluate_time_id( + instances_index.expressions.start, value_provider, row_id + ) + stop_id = evaluate_time_id( + instances_index.expressions.stop, value_provider, row_id + ) step_id = 1 if instances_index.expressions.step is not None: - step_id = evaluate_time_id(instances_index.expressions.step, value_provider) + step_id = evaluate_time_id( + instances_index.expressions.step, value_provider, row_id + ) # ExpressionRange includes stop_id whereas range excludes it time_ids = list(range(start_id, stop_id + 1, step_id)) diff --git a/src/andromede/expression/expression.py b/src/andromede/expression/expression.py index eb03dd5b..d2fee655 100644 --- a/src/andromede/expression/expression.py +++ b/src/andromede/expression/expression.py @@ -14,18 +14,11 @@ Defines the model for generic expressions. """ import enum -import inspect -from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional, Sequence, Union +import math +from dataclasses import dataclass +from typing import Any, Callable, List, Optional, Union -import andromede.expression.port_operator -import andromede.expression.scenario_operator -import andromede.expression.time_operator - - -class Instances(enum.Enum): - SIMPLE = "SIMPLE" - MULTIPLE = "MULTIPLE" +EPS = 10 ** (-16) @dataclass(frozen=True) @@ -40,34 +33,32 @@ class ExpressionNode: >>> expr = -var('x') + 5 / param('p') """ - instances: Instances = field(init=False, default=Instances.SIMPLE) - def __neg__(self) -> "ExpressionNode": - return NegationNode(self) + return _negate_node(self) def __add__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node(rhs, lambda x: AdditionNode(self, x)) + return _apply_if_node(rhs, lambda x: _add_node(self, x)) def __radd__(self, lhs: Any) -> "ExpressionNode": - return _apply_if_node(lhs, lambda x: AdditionNode(x, self)) + return _apply_if_node(lhs, lambda x: _add_node(x, self)) def __sub__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node(rhs, lambda x: SubstractionNode(self, x)) + return _apply_if_node(rhs, lambda x: _substract_node(self, x)) def __rsub__(self, lhs: Any) -> "ExpressionNode": - return _apply_if_node(lhs, lambda x: SubstractionNode(x, self)) + return _apply_if_node(lhs, lambda x: _substract_node(x, self)) def __mul__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node(rhs, lambda x: MultiplicationNode(self, x)) + return _apply_if_node(rhs, lambda x: _multiply_node(self, x)) def __rmul__(self, lhs: Any) -> "ExpressionNode": - return _apply_if_node(lhs, lambda x: MultiplicationNode(x, self)) + return _apply_if_node(lhs, lambda x: _multiply_node(x, self)) def __truediv__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node(rhs, lambda x: DivisionNode(self, x)) + return _apply_if_node(rhs, lambda x: _divide_node(self, x)) def __rtruediv__(self, lhs: Any) -> "ExpressionNode": - return _apply_if_node(lhs, lambda x: DivisionNode(x, self)) + return _apply_if_node(lhs, lambda x: _divide_node(x, self)) def __le__(self, rhs: Any) -> "ExpressionNode": return _apply_if_node( @@ -84,15 +75,20 @@ def __eq__(self, rhs: Any) -> "ExpressionNode": # type: ignore def sum(self) -> "ExpressionNode": if isinstance(self, TimeOperatorNode): - return TimeAggregatorNode(self, "TimeSum", stay_roll=True) + return TimeAggregatorNode(self, TimeAggregatorName.TIME_SUM, stay_roll=True) else: return _apply_if_node( - self, lambda x: TimeAggregatorNode(x, "TimeSum", stay_roll=False) + self, + lambda x: TimeAggregatorNode( + x, TimeAggregatorName.TIME_SUM, stay_roll=False + ), ) def sum_connections(self) -> "ExpressionNode": if isinstance(self, PortFieldNode): - return PortFieldAggregatorNode(self, aggregator="PortSum") + return PortFieldAggregatorNode( + self, aggregator=PortFieldAggregatorName.PORT_SUM + ) raise ValueError( f"sum_connections() applies only for PortFieldNode, whereas the current node is of type {type(self)}." ) @@ -100,58 +96,224 @@ def sum_connections(self) -> "ExpressionNode": def shift( self, expressions: Union[ - int, "ExpressionNode", List["ExpressionNode"], "ExpressionRange" + int, + "ExpressionNode", + List["ExpressionNode"], + "ExpressionRange", ], ) -> "ExpressionNode": return _apply_if_node( self, - lambda x: TimeOperatorNode(x, "TimeShift", InstancesTimeIndex(expressions)), + lambda x: TimeOperatorNode( + x, TimeOperatorName.SHIFT, InstancesTimeIndex(expressions) + ), ) def eval( self, expressions: Union[ - int, "ExpressionNode", List["ExpressionNode"], "ExpressionRange" + int, + "ExpressionNode", + List["ExpressionNode"], + "ExpressionRange", ], ) -> "ExpressionNode": return _apply_if_node( self, lambda x: TimeOperatorNode( - x, "TimeEvaluation", InstancesTimeIndex(expressions) + x, TimeOperatorName.EVALUATION, InstancesTimeIndex(expressions) ), ) def expec(self) -> "ExpressionNode": - return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation")) + return _apply_if_node( + self, lambda x: ScenarioOperatorNode(x, ScenarioOperatorName.EXPECTATION) + ) def variance(self) -> "ExpressionNode": - return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Variance")) + return _apply_if_node( + self, lambda x: ScenarioOperatorNode(x, ScenarioOperatorName.VARIANCE) + ) -def _wrap_in_node(obj: Any) -> ExpressionNode: +def wrap_in_node(obj: Any) -> ExpressionNode: if isinstance(obj, ExpressionNode): return obj elif isinstance(obj, float) or isinstance(obj, int): return LiteralNode(float(obj)) - raise TypeError(f"Unable to wrap {obj} into an expression node") + # Do not raise excpetion so that we can return NotImplemented in _apply_if_node + # raise TypeError(f"Unable to wrap {obj} into an expression node") def _apply_if_node( obj: Any, func: Callable[["ExpressionNode"], "ExpressionNode"] ) -> "ExpressionNode": - if as_node := _wrap_in_node(obj): + if as_node := wrap_in_node(obj): return func(as_node) else: return NotImplemented -@dataclass(frozen=True, eq=False) -class VariableNode(ExpressionNode): - name: str +def is_zero(node: ExpressionNode) -> bool: + # Faster implementation than expressions equal for this particular cases + return isinstance(node, LiteralNode) and math.isclose(node.value, 0, abs_tol=EPS) + + +def is_one(node: ExpressionNode) -> bool: + # Faster implementation than expressions equal for this particular cases + return isinstance(node, LiteralNode) and math.isclose(node.value, 1) + + +def is_minus_one(node: ExpressionNode) -> bool: + # Faster implementation than expressions equal for this particular cases + return isinstance(node, LiteralNode) and math.isclose(node.value, -1) + + +def _negate_node(node: ExpressionNode) -> ExpressionNode: + if isinstance(node, LiteralNode): + return LiteralNode(-node.value) + elif isinstance(node, NegationNode): + return node.operand + else: + return NegationNode(node) -def var(name: str) -> VariableNode: - return VariableNode(name) +def _add_node(lhs: ExpressionNode, rhs: ExpressionNode) -> ExpressionNode: + if is_zero(lhs): + return rhs + if is_zero(rhs): + return lhs + # TODO: How can we use the equality visitor here (simple import -> circular import) -> equality visitor code is placed at the bottom of this file... + if expressions_equal(lhs, -rhs): + return LiteralNode(0) + if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode): + return LiteralNode(lhs.value + rhs.value) + if _are_parameter_nodes_equal(lhs, rhs): + return MultiplicationNode(LiteralNode(2), lhs) + if (lhs_is_param := isinstance(lhs, ParameterNode)) or ( + rhs_is_param := isinstance(rhs, ParameterNode) + ): + if lhs_is_param: + param_node = lhs + other = rhs + elif rhs_is_param: + param_node = rhs + other = lhs + + if isinstance(other, MultiplicationNode): + if _are_parameter_nodes_equal(param_node, other.left): + return MultiplicationNode( + _add_node(LiteralNode(1), other.right), param_node + ) + elif _are_parameter_nodes_equal(param_node, other.right): + return MultiplicationNode( + _add_node(LiteralNode(1), other.left), param_node + ) + + if isinstance(lhs, MultiplicationNode) and isinstance(rhs, MultiplicationNode): + if _are_parameter_nodes_equal(lhs.left, rhs.left): + return MultiplicationNode(_add_node(lhs.right, rhs.right), lhs.left) + elif _are_parameter_nodes_equal(lhs.left, rhs.right): + return MultiplicationNode(_add_node(lhs.right, rhs.left), lhs.left) + elif _are_parameter_nodes_equal(lhs.right, rhs.left): + return MultiplicationNode(_add_node(lhs.left, rhs.right), lhs.right) + elif _are_parameter_nodes_equal(lhs.right, rhs.right): + return MultiplicationNode(_add_node(lhs.left, rhs.left), lhs.right) + else: + return AdditionNode(lhs, rhs) + + +# Better if we could use equality visitor +def _are_parameter_nodes_equal(lhs: ExpressionNode, rhs: ExpressionNode) -> bool: + return ( + isinstance(lhs, ParameterNode) + and isinstance(rhs, ParameterNode) + and lhs.name == rhs.name + ) + + +def _substract_node(lhs: ExpressionNode, rhs: ExpressionNode) -> ExpressionNode: + if is_zero(lhs): + return -rhs + if is_zero(rhs): + return lhs + if expressions_equal(lhs, rhs): + return LiteralNode(0) + if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode): + return LiteralNode(lhs.value - rhs.value) + if _are_parameter_nodes_equal(lhs, -rhs): + return MultiplicationNode(LiteralNode(2), lhs) + if (lhs_is_param := isinstance(lhs, ParameterNode)) or ( + rhs_is_param := isinstance(rhs, ParameterNode) + ): + if lhs_is_param: + param_node = lhs + other = rhs + elif rhs_is_param: + param_node = rhs + other = lhs + + if isinstance(other, MultiplicationNode): + if _are_parameter_nodes_equal(param_node, other.left): + if lhs_is_param: + return MultiplicationNode( + _substract_node(LiteralNode(1), other.right), param_node + ) + elif rhs_is_param: + return MultiplicationNode( + _substract_node(other.right, LiteralNode(1)), param_node + ) + elif _are_parameter_nodes_equal(param_node, other.right): + if lhs_is_param: + return MultiplicationNode( + _substract_node(LiteralNode(1), other.left), param_node + ) + elif rhs_is_param: + return MultiplicationNode( + _substract_node(other.left, LiteralNode(1)), param_node + ) + + if isinstance(lhs, MultiplicationNode) and isinstance(rhs, MultiplicationNode): + if _are_parameter_nodes_equal(lhs.left, rhs.left): + return MultiplicationNode(_substract_node(lhs.right, rhs.right), lhs.left) + elif _are_parameter_nodes_equal(lhs.left, rhs.right): + return MultiplicationNode(_substract_node(lhs.right, rhs.left), lhs.left) + elif _are_parameter_nodes_equal(lhs.right, rhs.left): + return MultiplicationNode(_substract_node(lhs.left, rhs.right), lhs.right) + elif _are_parameter_nodes_equal(lhs.right, rhs.right): + return MultiplicationNode(_substract_node(lhs.left, rhs.left), lhs.right) + else: + return SubstractionNode(lhs, rhs) + + +def _multiply_node(lhs: ExpressionNode, rhs: ExpressionNode) -> ExpressionNode: + if is_zero(lhs) or is_zero(rhs): + return LiteralNode(0) + if is_one(lhs): + return rhs + if is_one(rhs): + return lhs + if is_minus_one(lhs): + return -rhs + if is_minus_one(rhs): + return -lhs + if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode): + return LiteralNode(lhs.value * rhs.value) + else: + return MultiplicationNode(lhs, rhs) + + +def _divide_node(lhs: ExpressionNode, rhs: ExpressionNode) -> ExpressionNode: + if is_one(rhs): + return lhs + if is_minus_one(rhs): + return -lhs + if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode): + # This could raise division by 0 error + return LiteralNode(lhs.value / rhs.value) + + else: + return DivisionNode(lhs, rhs) @dataclass(frozen=True, eq=False) @@ -164,19 +326,11 @@ class PortFieldNode(ExpressionNode): field_name: str -def port_field(port_name: str, field_name: str) -> PortFieldNode: - return PortFieldNode(port_name, field_name) - - @dataclass(frozen=True, eq=False) class ParameterNode(ExpressionNode): name: str -def param(name: str) -> ParameterNode: - return ParameterNode(name) - - @dataclass(frozen=True, eq=False) class ComponentParameterNode(ExpressionNode): """ @@ -191,26 +345,12 @@ class ComponentParameterNode(ExpressionNode): name: str -def comp_param(component_id: str, name: str) -> ComponentParameterNode: - return ComponentParameterNode(component_id, name) - - -@dataclass(frozen=True, eq=False) -class ComponentVariableNode(ExpressionNode): - """ - Represents one variable of one component. - - When building actual equations for a system, - we need to associated each variable to its - actual component, at some point. - """ - - component_id: str - name: str +def param(name: str) -> ExpressionNode: + return ParameterNode(name) -def comp_var(component_id: str, name: str) -> ComponentVariableNode: - return ComponentVariableNode(component_id, name) +def comp_param(component_id: str, name: str) -> ExpressionNode: + return ComponentParameterNode(component_id, name) @dataclass(frozen=True, eq=False) @@ -218,33 +358,32 @@ class LiteralNode(ExpressionNode): value: float -def literal(value: float) -> LiteralNode: +def literal(value: float) -> ExpressionNode: return LiteralNode(value) +def is_unbound(expr: ExpressionNode) -> bool: + return isinstance(expr, LiteralNode) and (abs(expr.value) == float("inf")) + + @dataclass(frozen=True, eq=False) class UnaryOperatorNode(ExpressionNode): operand: ExpressionNode - def __post_init__(self) -> None: - object.__setattr__(self, "instances", self.operand.instances) + +class PortFieldAggregatorName(enum.Enum): + # String value of enum must match the name of the PortAggregator class in port_operator.py + PORT_SUM = "PortSum" @dataclass(frozen=True, eq=False) class PortFieldAggregatorNode(UnaryOperatorNode): - aggregator: str + aggregator: PortFieldAggregatorName def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.port_operator, inspect.isclass - ) - if issubclass(cls, andromede.expression.port_operator.PortAggregator) - ] - if self.aggregator not in valid_names: - raise NotImplementedError( - f"{self.aggregator} is not a valid port aggregator, valid port aggregators are {valid_names}" + if not isinstance(self.aggregator, PortFieldAggregatorName): + raise TypeError( + f"PortFieldAggregatorNode.name should of class PortFieldAggregatorName, but {self.aggregator} of type {type(self.aggregator)} was given" ) @@ -258,18 +397,6 @@ class BinaryOperatorNode(ExpressionNode): left: ExpressionNode right: ExpressionNode - def __post_init__(self) -> None: - binary_operator_post_init(self, "apply binary operation with") - - -def binary_operator_post_init(node: BinaryOperatorNode, operation: str) -> None: - if node.left.instances != node.right.instances: - raise ValueError( - f"Cannot {operation} {node.left} and {node.right} as they do not have the same number of instances." - ) - else: - object.__setattr__(node, "instances", node.left.instances) - class Comparator(enum.Enum): LESS_THAN = "LESS_THAN" @@ -281,35 +408,28 @@ class Comparator(enum.Enum): class ComparisonNode(BinaryOperatorNode): comparator: Comparator - def __post_init__(self) -> None: - binary_operator_post_init(self, "compare") - @dataclass(frozen=True, eq=False) class AdditionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "add") + pass @dataclass(frozen=True, eq=False) class SubstractionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "substract") + pass @dataclass(frozen=True, eq=False) class MultiplicationNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "multiply") + pass @dataclass(frozen=True, eq=False) class DivisionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "divide") + pass -@dataclass(frozen=True, eq=False) +@dataclass(frozen=True) class ExpressionRange: start: ExpressionNode stop: ExpressionNode @@ -319,9 +439,17 @@ def __post_init__(self) -> None: for attribute in self.__dict__: value = getattr(self, attribute) object.__setattr__( - self, attribute, _wrap_in_node(value) if value is not None else value + self, attribute, wrap_in_node(value) if value is not None else value ) + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, ExpressionRange) + and expressions_equal(self.start, other.start) + and expressions_equal(self.stop, other.stop) + and expressions_equal_if_present(self.step, other.step) + ) + IntOrExpr = Union[int, ExpressionNode] @@ -330,13 +458,13 @@ def expression_range( start: IntOrExpr, stop: IntOrExpr, step: Optional[IntOrExpr] = None ) -> ExpressionRange: return ExpressionRange( - start=_wrap_in_node(start), - stop=_wrap_in_node(stop), - step=None if step is None else _wrap_in_node(step), + start=wrap_in_node(start), + stop=wrap_in_node(stop), + step=None if step is None else wrap_in_node(step), ) -@dataclass +@dataclass(frozen=True) class InstancesTimeIndex: """ Defines a set of time indices on which a time operator operates. @@ -365,9 +493,36 @@ def __init__( ) if isinstance(expressions, (int, ExpressionNode)): - self.expressions = [_wrap_in_node(expressions)] + object.__setattr__(self, "expressions", [wrap_in_node(expressions)]) else: - self.expressions = expressions + object.__setattr__(self, "expressions", expressions) + + def __hash__(self) -> int: + # Maybe if/else not needed and always using the tuple works ? + if isinstance(self.expressions, list): + return hash(tuple(self.expressions)) + else: + return hash(self.expressions) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, InstancesTimeIndex): + if isinstance(self.expressions, list) and all( + isinstance(x, ExpressionNode) for x in self.expressions + ): + return ( + isinstance(other.expressions, list) + and all(isinstance(x, ExpressionNode) for x in other.expressions) + and all( + expressions_equal(left_expr, right_expr) + for left_expr, right_expr in zip( + self.expressions, other.expressions + ) + ) + ) + elif isinstance(self.expressions, ExpressionRange): + return self.expressions == other.expressions + else: + return False def is_simple(self) -> bool: if isinstance(self.expressions, list): @@ -377,76 +532,225 @@ def is_simple(self) -> bool: return False +class TimeOperatorName(enum.Enum): + # String value of enum must match the name of the TimeOperator class in time_operator.py + SHIFT = "TimeShift" + EVALUATION = "TimeEvaluation" + + +class TimeAggregatorName(enum.Enum): + # String value of enum must match the name of the TimeAggregator class in time_operator.py + TIME_SUM = "TimeSum" + + @dataclass(frozen=True, eq=False) class TimeOperatorNode(UnaryOperatorNode): - name: str + name: TimeOperatorName instances_index: InstancesTimeIndex def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.time_operator, inspect.isclass - ) - if issubclass(cls, andromede.expression.time_operator.TimeOperator) - ] - if self.name not in valid_names: - raise ValueError( - f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" - ) - if self.operand.instances == Instances.SIMPLE: - if self.instances_index.is_simple(): - object.__setattr__(self, "instances", Instances.SIMPLE) - else: - object.__setattr__(self, "instances", Instances.MULTIPLE) - else: - raise ValueError( - "Cannot apply time operator on an expression that already represents multiple instances" + if not isinstance(self.name, TimeOperatorName): + raise TypeError( + f"TimeOperatorNode.name should of class TimeOperatorName, but {self.name} of type {type(self.name)} was given" ) @dataclass(frozen=True, eq=False) class TimeAggregatorNode(UnaryOperatorNode): - name: str - stay_roll: bool + name: TimeAggregatorName + stay_roll: bool # TODO: Is it still useful ? def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.time_operator, inspect.isclass - ) - if issubclass(cls, andromede.expression.time_operator.TimeAggregator) - ] - if self.name not in valid_names: - raise ValueError( - f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" + if not isinstance(self.name, TimeAggregatorName): + raise TypeError( + f"TimeAggregatorNode.name should of class TimeAggregatorName, but {self.name} of type {type(self.name)} was given" ) - object.__setattr__(self, "instances", Instances.SIMPLE) + + +class ScenarioOperatorName(enum.Enum): + # String value of enum must match the name of the ScenarioOperator class in scenario_operator.py + EXPECTATION = "Expectation" + VARIANCE = "Variance" @dataclass(frozen=True, eq=False) class ScenarioOperatorNode(UnaryOperatorNode): - name: str + name: ScenarioOperatorName def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.scenario_operator, inspect.isclass + if not isinstance(self.name, ScenarioOperatorName): + raise TypeError( + f"ScenarioOperatorNode.name should of class ScenarioOperatorName, but {self.name} of type {type(self.name)} was given" + ) + + +@dataclass(frozen=True) +class EqualityVisitor: + abs_tol: float = 0 + rel_tol: float = 0 + + def __post_init__(self) -> None: + if self.abs_tol < 0: + raise ValueError( + f"Absolute comparison tolerance must be >= 0, got {self.abs_tol}" ) - if issubclass(cls, andromede.expression.scenario_operator.ScenarioOperator) - ] - if self.name not in valid_names: + if self.rel_tol < 0: raise ValueError( - f"{self.name} is not a valid scenario operator, valid scenario operators are {valid_names}" + f"Relative comparison tolerance must be >= 0, got {self.rel_tol}" ) - object.__setattr__(self, "instances", Instances.SIMPLE) + def visit(self, left: ExpressionNode, right: ExpressionNode) -> bool: + if left.__class__ != right.__class__: + return False + if isinstance(left, LiteralNode) and isinstance(right, LiteralNode): + return self.literal(left, right) + if isinstance(left, NegationNode) and isinstance(right, NegationNode): + return self.negation(left, right) + if isinstance(left, AdditionNode) and isinstance(right, AdditionNode): + return self.addition(left, right) + if isinstance(left, SubstractionNode) and isinstance(right, SubstractionNode): + return self.substraction(left, right) + if isinstance(left, DivisionNode) and isinstance(right, DivisionNode): + return self.division(left, right) + if isinstance(left, MultiplicationNode) and isinstance( + right, MultiplicationNode + ): + return self.multiplication(left, right) + if isinstance(left, ComparisonNode) and isinstance(right, ComparisonNode): + return self.comparison(left, right) + if isinstance(left, ParameterNode) and isinstance(right, ParameterNode): + return self.parameter(left, right) + if isinstance(left, ComponentParameterNode) and isinstance( + right, ComponentParameterNode + ): + return self.comp_parameter(left, right) + if isinstance(left, TimeOperatorNode) and isinstance(right, TimeOperatorNode): + return self.time_operator(left, right) + if isinstance(left, TimeAggregatorNode) and isinstance( + right, TimeAggregatorNode + ): + return self.time_aggregator(left, right) + if isinstance(left, ScenarioOperatorNode) and isinstance( + right, ScenarioOperatorNode + ): + return self.scenario_operator(left, right) + if isinstance(left, PortFieldNode) and isinstance(right, PortFieldNode): + return self.port_field(left, right) + if isinstance(left, PortFieldAggregatorNode) and isinstance( + right, PortFieldAggregatorNode + ): + return self.port_field_aggregator(left, right) + raise NotImplementedError(f"Equality not implemented for {left.__class__}") -def sum_expressions(expressions: Sequence[ExpressionNode]) -> ExpressionNode: - if len(expressions) == 0: - return LiteralNode(0) - if len(expressions) == 1: - return expressions[0] - return expressions[0] + sum_expressions(expressions[1:]) + def literal(self, left: LiteralNode, right: LiteralNode) -> bool: + return math.isclose( + left.value, right.value, abs_tol=self.abs_tol, rel_tol=self.rel_tol + ) + + def _visit_operands( + self, left: BinaryOperatorNode, right: BinaryOperatorNode + ) -> bool: + return self.visit(left.left, right.left) and self.visit(left.right, right.right) + + def negation(self, left: NegationNode, right: NegationNode) -> bool: + return self.visit(left.operand, right.operand) + + def addition(self, left: AdditionNode, right: AdditionNode) -> bool: + # TODO: Commutativty ??? Cannot detect that a+b == b+a ? Do we want to do this ? + return self._visit_operands(left, right) + + def substraction(self, left: SubstractionNode, right: SubstractionNode) -> bool: + return self._visit_operands(left, right) + + def multiplication( + self, left: MultiplicationNode, right: MultiplicationNode + ) -> bool: + return self._visit_operands(left, right) + + def division(self, left: DivisionNode, right: DivisionNode) -> bool: + return self._visit_operands(left, right) + + def comparison(self, left: ComparisonNode, right: ComparisonNode) -> bool: + return left.comparator == right.comparator and self._visit_operands(left, right) + + def parameter(self, left: ParameterNode, right: ParameterNode) -> bool: + return left.name == right.name + + def comp_parameter( + self, left: ComponentParameterNode, right: ComponentParameterNode + ) -> bool: + return left.component_id == right.component_id and left.name == right.name + + def expression_range(self, left: ExpressionRange, right: ExpressionRange) -> bool: + if not self.visit(left.start, right.start): + return False + if not self.visit(left.stop, right.stop): + return False + if left.step is not None and right.step is not None: + return self.visit(left.step, right.step) + return left.step is None and right.step is None + + def instances_index(self, lhs: InstancesTimeIndex, rhs: InstancesTimeIndex) -> bool: + if isinstance(lhs.expressions, ExpressionRange) and isinstance( + rhs.expressions, ExpressionRange + ): + return self.expression_range(lhs.expressions, rhs.expressions) + if isinstance(lhs.expressions, list) and isinstance(rhs.expressions, list): + return len(lhs.expressions) == len(rhs.expressions) and all( + self.visit(l, r) for l, r in zip(lhs.expressions, rhs.expressions) + ) + return False + + def time_operator(self, left: TimeOperatorNode, right: TimeOperatorNode) -> bool: + return ( + left.name == right.name + and self.instances_index(left.instances_index, right.instances_index) + and self.visit(left.operand, right.operand) + ) + + def time_aggregator( + self, left: TimeAggregatorNode, right: TimeAggregatorNode + ) -> bool: + return ( + left.name == right.name + and left.stay_roll == right.stay_roll + and self.visit(left.operand, right.operand) + ) + + def scenario_operator( + self, left: ScenarioOperatorNode, right: ScenarioOperatorNode + ) -> bool: + return left.name == right.name and self.visit(left.operand, right.operand) + + def port_field(self, left: PortFieldNode, right: PortFieldNode) -> bool: + return left.port_name == right.port_name and left.field_name == right.field_name + + def port_field_aggregator( + self, left: PortFieldAggregatorNode, right: PortFieldAggregatorNode + ) -> bool: + return left.aggregator == right.aggregator and self.visit( + left.operand, right.operand + ) + + +def expressions_equal( + left: ExpressionNode, + right: ExpressionNode, + abs_tol: float = 0, + rel_tol: float = 0, +) -> bool: + """ + True if both expression nodes are equal. Literal values may be compared with absolute or relative tolerance. + """ + return EqualityVisitor(abs_tol, rel_tol).visit(left, right) + + +def expressions_equal_if_present( + lhs: Optional[ExpressionNode], rhs: Optional[ExpressionNode] +) -> bool: + if lhs is None and rhs is None: + return True + elif lhs is None or rhs is None: + return False + else: + return expressions_equal(lhs, rhs) diff --git a/src/andromede/expression/indexing.py b/src/andromede/expression/indexing.py index 11051dd5..aaad0881 100644 --- a/src/andromede/expression/indexing.py +++ b/src/andromede/expression/indexing.py @@ -11,31 +11,8 @@ # This file is part of the Antares project. from abc import ABC, abstractmethod -from dataclasses import dataclass -import andromede.expression.time_operator -from andromede.expression.indexing_structure import IndexingStructure - -from .expression import ( - AdditionNode, - ComparisonNode, - ComponentParameterNode, - ComponentVariableNode, - DivisionNode, - ExpressionNode, - LiteralNode, - MultiplicationNode, - NegationNode, - ParameterNode, - PortFieldAggregatorNode, - PortFieldNode, - ScenarioOperatorNode, - SubstractionNode, - TimeAggregatorNode, - TimeOperatorNode, - VariableNode, -) -from .visitor import ExpressionVisitor, T, visit +from .indexing_structure import IndexingStructure class IndexingStructureProvider(ABC): @@ -58,85 +35,3 @@ def get_component_parameter_structure( self, component_id: str, name: str ) -> IndexingStructure: ... - - -@dataclass(frozen=True) -class TimeScenarioIndexingVisitor(ExpressionVisitor[IndexingStructure]): - """ - Determines if the expression represents a single expression or an expression that should be instantiated for all time steps. - """ - - context: IndexingStructureProvider - - def literal(self, node: LiteralNode) -> IndexingStructure: - return IndexingStructure(False, False) - - def negation(self, node: NegationNode) -> IndexingStructure: - return visit(node.operand, self) - - def addition(self, node: AdditionNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - def substraction(self, node: SubstractionNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - def multiplication(self, node: MultiplicationNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - def division(self, node: DivisionNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - def comparison(self, node: ComparisonNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - def variable(self, node: VariableNode) -> IndexingStructure: - time = self.context.get_variable_structure(node.name).time == True - scenario = self.context.get_variable_structure(node.name).scenario == True - return IndexingStructure(time, scenario) - - def parameter(self, node: ParameterNode) -> IndexingStructure: - time = self.context.get_parameter_structure(node.name).time == True - scenario = self.context.get_parameter_structure(node.name).scenario == True - return IndexingStructure(time, scenario) - - def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure: - return self.context.get_component_variable_structure( - node.component_id, node.name - ) - - def comp_parameter(self, node: ComponentParameterNode) -> IndexingStructure: - return self.context.get_component_parameter_structure( - node.component_id, node.name - ) - - def time_operator(self, node: TimeOperatorNode) -> IndexingStructure: - time_operator_cls = getattr(andromede.expression.time_operator, node.name) - if time_operator_cls.rolling(): - return visit(node.operand, self) - else: - return IndexingStructure(False, visit(node.operand, self).scenario) - - def time_aggregator(self, node: TimeAggregatorNode) -> IndexingStructure: - if node.stay_roll: - return visit(node.operand, self) - else: - return IndexingStructure(False, visit(node.operand, self).scenario) - - def scenario_operator(self, node: ScenarioOperatorNode) -> IndexingStructure: - return IndexingStructure(visit(node.operand, self).time, False) - - def port_field(self, node: PortFieldNode) -> IndexingStructure: - raise ValueError( - "Port fields must be resolved before computing indexing structure." - ) - - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> IndexingStructure: - raise ValueError( - "Port fields aggregators must be resolved before computing indexing structure." - ) - - -def compute_indexation( - expression: ExpressionNode, provider: IndexingStructureProvider -) -> IndexingStructure: - return visit(expression, TimeScenarioIndexingVisitor(provider)) diff --git a/src/andromede/expression/indexing_structure.py b/src/andromede/expression/indexing_structure.py index 746b07ff..2c81e708 100644 --- a/src/andromede/expression/indexing_structure.py +++ b/src/andromede/expression/indexing_structure.py @@ -38,3 +38,17 @@ def is_time_scenario_varying(self) -> bool: def is_constant(self) -> bool: return (not self.is_time_varying()) and (not self.is_scenario_varying()) + + +# Contrary to IndexingStructure, time and scenario are integers to "count/identify" the constraint whereas IndexingStructure is used used to know whether or not an expression is indexed by time or scenario. +@dataclass(frozen=True) +class RowIndex: + """ + Indexing of rows in a problem. + """ + + time: int + scenario: int + + def __str__(self) -> str: + return f"t{self.time}_s{self.scenario}" diff --git a/src/andromede/expression/linear_expression.py b/src/andromede/expression/linear_expression.py new file mode 100644 index 00000000..5050888f --- /dev/null +++ b/src/andromede/expression/linear_expression.py @@ -0,0 +1,1090 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +""" +Specific modelling for "instantiated" linear expressions, +with only variables and literal coefficients. +""" +import dataclasses +from dataclasses import dataclass, field +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Mapping, + Optional, + Sequence, + TypeVar, + Union, + cast, + overload, +) + +from .context_adder import add_component_context +from .equality import expressions_equal +from .evaluate_parameters import check_resolved_expr, resolve_coefficient +from .expression import ( + ExpressionNode, + ExpressionRange, + InstancesTimeIndex, + LiteralNode, + ScenarioOperatorName, + ScenarioOperatorNode, + TimeAggregatorName, + TimeAggregatorNode, + TimeOperatorName, + TimeOperatorNode, + is_minus_one, + is_one, + is_unbound, + is_zero, + literal, + wrap_in_node, +) +from .indexing import IndexingStructureProvider +from .indexing_structure import IndexingStructure, RowIndex +from .port_operator import PortAggregator, PortSum +from .print import print_expr +from .scenario_operator import Expectation, ScenarioAggregator +from .time_operator import ( + TimeAggregator, + TimeEvaluation, + TimeOperator, + TimeShift, + TimeSum, +) +from .value_provider import TimeScenarioIndex, TimeScenarioIndices, ValueProvider + + +@dataclass(frozen=True) +class TermKey: + """ + Utility class to provide key for a term that contains all term information except coefficient + """ + + component_id: str + variable_name: str + time_operator: Optional[TimeOperator] + time_aggregator: Optional[TimeAggregator] + scenario_aggregator: Optional[ScenarioAggregator] + + # Used for test_expression_parsing + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, TermKey) + and self.component_id == other.component_id + and self.variable_name == other.variable_name + and self.time_operator == other.time_operator + and self.time_aggregator == other.time_aggregator + and self.scenario_aggregator == other.scenario_aggregator + ) + + +@dataclass(frozen=True) +class Term: + """ + One term in a linear expression: for example the "10x" par in "10x + 5y + 5" + + Args: + coefficient: the coefficient for that term, for example "10" in "10x" + variable_name: the name of the variable, for example "x" in "10x" + """ + + coefficient: ExpressionNode + component_id: str + variable_name: str + structure: IndexingStructure = field( + default=IndexingStructure(time=True, scenario=True) + ) + time_operator: Optional[TimeOperator] = None + time_aggregator: Optional[TimeAggregator] = None + scenario_aggregator: Optional[ScenarioAggregator] = None + + def __post_init__(self) -> None: + object.__setattr__(self, "coefficient", wrap_in_node(self.coefficient)) + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, Term) + and expressions_equal(self.coefficient, other.coefficient) + and self.component_id == other.component_id + and self.variable_name == other.variable_name + and self.structure == other.structure + and self.time_operator == other.time_operator + and self.time_aggregator == other.time_aggregator + and self.scenario_aggregator == other.scenario_aggregator + ) + + def is_zero(self) -> bool: + return is_zero(self.coefficient) + + def str_for_coeff(self) -> str: + str_for_coeff = "" + if is_one(self.coefficient): + str_for_coeff = "+" + elif is_minus_one(self.coefficient): + str_for_coeff = "-" + else: + str_for_coeff = print_expr(self.coefficient) + return str_for_coeff + + def __str__(self) -> str: + # Useful for debugging tests + result = self.str_for_coeff() + str(self.variable_name) + if self.time_operator is not None: + result += f".{str(self.time_operator)}" + if self.time_aggregator is not None: + result += f".{str(self.time_aggregator)}" + if self.scenario_aggregator is not None: + result += f".{str(self.scenario_aggregator)}" + return result + + def evaluate(self, context: ValueProvider, time_scenario_index: RowIndex) -> float: + # TODO: Take care of component variables, multiple time scenarios, operators, etc + time_scenario_indices = TimeScenarioIndices( + [time_scenario_index.time], [time_scenario_index.scenario] + ) + # Probably very error prone + if self.component_id: + variable_value = context.get_component_variable_value( + self.component_id, self.variable_name, time_scenario_indices + ) + else: + variable_value = context.get_variable_value( + self.variable_name, time_scenario_indices + ) + check_resolved_expr(variable_value, time_scenario_index) + return ( + resolve_coefficient(self.coefficient, context, time_scenario_index) + * variable_value[ + TimeScenarioIndex( + time_scenario_index.time, time_scenario_index.scenario + ) + ] + ) + + def compute_indexation( + self, provider: IndexingStructureProvider + ) -> IndexingStructure: + return IndexingStructure( + self._compute_time_indexing(provider), + self._compute_scenario_indexing(provider), + ) + + def _compute_time_indexing(self, provider: IndexingStructureProvider) -> bool: + if (self.time_aggregator and not self.time_aggregator.stay_roll) or ( + self.time_operator and not self.time_operator.rolling() + ): + time = False + else: + if self.component_id: + time = provider.get_component_variable_structure( + self.component_id, self.variable_name + ).time + else: + time = provider.get_variable_structure(self.variable_name).time + return time + + def _compute_scenario_indexing(self, provider: IndexingStructureProvider) -> bool: + if self.scenario_aggregator: + scenario = False + else: + # TODO: Improve this if/else structure, probably simplify IndexingStructureProvider + if self.component_id: + scenario = provider.get_component_variable_structure( + self.component_id, self.variable_name + ).scenario + + else: + scenario = provider.get_variable_structure(self.variable_name).scenario + return scenario + + def sum( + self, + shift: Union[ + int, + "ExpressionNode", + List["ExpressionNode"], + "ExpressionRange", + None, + ] = None, + eval: Union[ + int, + "ExpressionNode", + List["ExpressionNode"], + "ExpressionRange", + None, + ] = None, + ) -> "Term": + if shift is not None and eval is not None: + raise ValueError("Only shift or eval arguments should specified, not both.") + + # The shift or eval operators applies on the variable, then it will define at which time step the term coefficient * variable will be evaluated + + if shift is not None: + return dataclasses.replace( + self, + time_operator=TimeShift(InstancesTimeIndex(shift)), + time_aggregator=TimeSum(stay_roll=True), + ) + elif eval is not None: + return dataclasses.replace( + self, + time_operator=TimeEvaluation(InstancesTimeIndex(eval)), + time_aggregator=TimeSum(stay_roll=True), + ) + else: # x.sum() -> Sum over all time block + return dataclasses.replace(self, time_aggregator=TimeSum(stay_roll=False)) + + def shift( + self, + expressions: Union[ + int, + "ExpressionNode", + List["ExpressionNode"], + "ExpressionRange", + ], + ) -> "Term": + """ + Shorthand for shift on a single time step + + To refer to x[t-1], it is more natural to write x.shift(-1) than x.sum(shift=-1). + + This function provides the shorthand x.sum(shift=expr), valid only in the case when expr refers to a single time step. + + """ + + # The behavior is richer/different than the previous implementation (with linear expr as trees) as we can now apply a shift operator on a whole expression, rather than just on the variables of an expression + + # Example : (param("p") * var("x")).shift(1) + # Previous behavior : p[t]x[t-1] + # New behavior : p[t-1]x[t-1] + + if not InstancesTimeIndex(expressions).is_simple(): + raise ValueError( + "The shift operator can only be applied on expressions refering to a single time step. To apply a shifting sum on multiple time indices on an expression x, you should use x.sum(shift=...)" + ) + + else: + return self.sum(shift=expressions) + + def eval( + self, + expressions: Union[ + int, + "ExpressionNode", + List["ExpressionNode"], + "ExpressionRange", + ], + ) -> "Term": + """ + Shorthand for eval on a single time step + + To refer to x[1], it is more natural to write x.eval(1) than x.sum(eval=1). + + This function provides the shorthand x.sum(eval=expr), valid only in the case when expr refers to a single time step. + + """ + + # The behavior is richer/different than the previous implementation (with linear expr as trees) as we can now apply a eval operator on a whole expression, rather than just on the variables of an expression + + # Example : (param("p") * var("x")).eval(1) + # Previous behavior : p[t]x[1] + # New behavior : p[1]x[1] + + if not InstancesTimeIndex(expressions).is_simple(): + raise ValueError( + "The eval operator can only be applied on expressions refering to a single time step. To apply a evaluating sum on multiple time indices on an expression x, you should use x.sum(eval=...)" + ) + + else: + return self.sum(eval=expressions) + + def expec(self) -> "Term": + # TODO: Do we need checks, in case a scenario operator is already specified ? + return dataclasses.replace(self, scenario_aggregator=Expectation()) + + +def generate_key(term: Term) -> TermKey: + return TermKey( + term.component_id, + term.variable_name, + term.time_operator, + term.time_aggregator, + term.scenario_aggregator, + ) + + +@dataclass(frozen=True) +class PortFieldId: + port_name: str + field_name: str + + +@dataclass(eq=True, frozen=True) +class PortFieldKey: + """ + Identifies the expression node for one component and one port variable. + """ + + component_id: str + port_variable_id: PortFieldId + + +@dataclass(frozen=True) +class PortFieldTerm: + coefficient: ExpressionNode + port_name: str + field_name: str + aggregator: Optional[PortAggregator] = None + + def __str__(self) -> str: + result = f"{self.port_name}.{self.field_name}" + if self.aggregator is not None: + result += f".{str(self.aggregator)}" + return result + + def sum_connections(self) -> "PortFieldTerm": + if self.aggregator is not None: + raise ValueError(f"Port field {str(self)} already has a port aggregator") + return dataclasses.replace(self, aggregator=PortSum()) + + +T_val = TypeVar("T_val", bound=Union[Term, PortFieldTerm]) + + +def _get_neutral_term(term: T_val, neutral: float) -> T_val: + return dataclasses.replace(term, coefficient=wrap_in_node(neutral)) + + +@overload +def _merge_dicts( + lhs: Dict[TermKey, Term], + rhs: Dict[TermKey, Term], + merge_func: Callable[[Term, Term], Term], + neutral: float, +) -> Dict[TermKey, Term]: + ... + + +@overload +def _merge_dicts( + lhs: Dict[PortFieldId, PortFieldTerm], + rhs: Dict[PortFieldId, PortFieldTerm], + merge_func: Callable[[PortFieldTerm, PortFieldTerm], PortFieldTerm], + neutral: float, +) -> Dict[PortFieldId, PortFieldTerm]: + ... + + +def _merge_dicts(lhs, rhs, merge_func, neutral): + res = {} + for k, v in lhs.items(): + res[k] = merge_func(v, rhs.get(k, _get_neutral_term(v, neutral))) + for k, v in rhs.items(): + if k not in lhs: + res[k] = merge_func(_get_neutral_term(v, neutral), v) + return res + + +def _merge_is_possible(lhs: T_val, rhs: T_val) -> None: + if isinstance(lhs, Term) and isinstance(rhs, Term): + _merge_term_is_possible(lhs, rhs) + elif isinstance(lhs, PortFieldTerm) and isinstance(rhs, PortFieldTerm): + _merge_port_terms_is_possible(lhs, rhs) + else: + raise TypeError("Cannot merge terms of different types") + + +def _merge_term_is_possible(lhs: Term, rhs: Term) -> None: + if lhs.component_id != rhs.component_id or lhs.variable_name != rhs.variable_name: + raise ValueError("Cannot merge terms for different variables") + if ( + lhs.time_operator != rhs.time_operator + or lhs.time_aggregator != rhs.time_aggregator + or lhs.scenario_aggregator != rhs.scenario_aggregator + ): + raise ValueError("Cannot merge terms with different operators") + if lhs.structure != rhs.structure: + raise ValueError("Cannot merge terms with different structures") + + +def _merge_port_terms_is_possible(lhs: PortFieldTerm, rhs: PortFieldTerm) -> None: + if lhs.port_name != rhs.port_name or lhs.field_name != rhs.field_name: + raise ValueError("Cannot merge terms for different ports") + if lhs.aggregator != rhs.aggregator: + raise ValueError("Cannot merge port terms with different aggregators") + + +def _add_terms(lhs: T_val, rhs: T_val) -> T_val: + _merge_is_possible(lhs, rhs) + return dataclasses.replace(lhs, coefficient=lhs.coefficient + rhs.coefficient) + + +def _substract_terms(lhs: T_val, rhs: T_val) -> T_val: + _merge_is_possible(lhs, rhs) + return dataclasses.replace(lhs, coefficient=lhs.coefficient - rhs.coefficient) + + +class LinearExpression: + """ + Represents a linear expression with respect to variable names, for example 10x + 5y + 2. + + Operators may be used for construction. + + Args: + terms: the list of variable terms, for example 10x and 5y in "10x + 5y + 2". + constant: the constant term, for example 2 in "10x + 5y + 2" + + Examples: + Operators may be used for construction: + + >>> LinearExpression([], 10) + LinearExpression([Term + LinearExpression([Term + """ + + terms: Dict[TermKey, Term] + constant: ExpressionNode + port_field_terms: Dict[PortFieldId, PortFieldTerm] + + def __init__( + self, + terms: Optional[Union[Dict[TermKey, Term], List[Term]]] = None, + constant: Optional[Union[float, ExpressionNode]] = None, + port_field_terms: Optional[ + Union[Dict[PortFieldId, PortFieldTerm], List[PortFieldTerm]] + ] = None, + ) -> None: + if constant is None: + self.constant = LiteralNode(0) + else: + self.constant = wrap_in_node(constant) + + self.terms = {} + if terms is not None: + # Allows to give two different syntax in the constructor: + # - List[Term] is natural + # - Dict[str, Term] is useful when constructing a linear expression from the terms of another expression + if isinstance(terms, dict): + for term_key, term in terms.items(): + if not term.is_zero(): + self.terms[term_key] = term + elif isinstance(terms, list): + for term in terms: + if not term.is_zero(): + self.terms[generate_key(term)] = term + else: + raise TypeError( + f"Terms must be either of type Dict[TermKey, Term] or List[Term], whereas {terms} is of type {type(terms)}" + ) + + self.port_field_terms = {} + if port_field_terms is not None: + if isinstance(port_field_terms, dict): + for port_field_term_key, port_field_term in port_field_terms.items(): + self.port_field_terms[port_field_term_key] = port_field_term + elif isinstance(port_field_terms, list): + for port_field_term in port_field_terms: + self.port_field_terms[ + PortFieldId( + port_field_term.port_name, port_field_term.field_name + ) + ] = port_field_term + else: + raise TypeError( + f"Port field terms must be either of type Dict[PortFieldKey, PortFieldTerm] or List[PortFieldTerm], whereas {port_field_terms} is of type {type(port_field_terms)}" + ) + + def is_zero(self) -> bool: + # TODO : Contribution of portfield ? + return len(self.terms) == 0 and is_zero(self.constant) + + def str_for_constant(self) -> str: + if is_zero(self.constant): + return "" + else: + const_str = print_expr(self.constant) + if const_str.startswith("+"): + return f" + {const_str[1:]}" + elif const_str.startswith("-"): + return f" + ({const_str})" + else: + return f" + {print_expr(self.constant)}" + + def __str__(self) -> str: + # Useful for debugging tests + result = "" + if self.is_zero(): + result += "0" + else: + for term in self.terms.values(): + result += str(term) + + result += self.str_for_constant() + + return result + + def __le__(self, rhs: Any) -> "StandaloneConstraint": + return StandaloneConstraint( + expression=self - rhs, + lower_bound=wrap_in_linear_expr(literal(-float("inf"))), + upper_bound=wrap_in_linear_expr(literal(0)), + ) + + def __ge__(self, rhs: Any) -> "StandaloneConstraint": + return StandaloneConstraint( + expression=self - rhs, + lower_bound=wrap_in_linear_expr(literal(0)), + upper_bound=wrap_in_linear_expr(literal(float("inf"))), + ) + + def __eq__(self, rhs: Any) -> "StandaloneConstraint": # type: ignore + return StandaloneConstraint( + expression=self - rhs, + lower_bound=wrap_in_linear_expr(literal(0)), + upper_bound=wrap_in_linear_expr(literal(0)), + ) + + def __iadd__( + self, rhs: Union["LinearExpression", int, float] + ) -> "LinearExpression": + rhs = wrap_in_linear_expr(rhs) + self.constant += rhs.constant + + aggregated_terms = _merge_dicts(self.terms, rhs.terms, _add_terms, 0) + self.terms = aggregated_terms + + aggregated_port_terms = _merge_dicts( + self.port_field_terms, rhs.port_field_terms, _add_terms, 0 + ) + self.port_field_terms = aggregated_port_terms + + self.remove_zeros_from_terms() + return self + + def __add__(self, rhs: Union["LinearExpression", int, float]) -> "LinearExpression": + result = LinearExpression() + result += self + result += rhs + return result + + def __radd__(self, rhs: int) -> "LinearExpression": + return self.__add__(rhs) + + def __isub__( + self, rhs: Union["LinearExpression", int, float] + ) -> "LinearExpression": + rhs = wrap_in_linear_expr(rhs) + self.constant -= rhs.constant + + aggregated_terms = _merge_dicts(self.terms, rhs.terms, _substract_terms, 0) + self.terms = aggregated_terms + + aggregated_port_terms = _merge_dicts( + self.port_field_terms, rhs.port_field_terms, _substract_terms, 0 + ) + self.port_field_terms = aggregated_port_terms + + self.remove_zeros_from_terms() + return self + + def __sub__(self, rhs: Union["LinearExpression", int, float]) -> "LinearExpression": + result = LinearExpression() + result += self + result -= rhs + return result + + def __rsub__(self, rhs: int) -> "LinearExpression": + return -self + rhs + + def __neg__(self) -> "LinearExpression": + result = LinearExpression() + result -= self + return result + + def __imul__( + self, rhs: Union["LinearExpression", int, float] + ) -> "LinearExpression": + rhs = wrap_in_linear_expr(rhs) + + if not (self.is_constant() or rhs.is_constant()): + raise ValueError("Cannot multiply two non constant expression") + else: + if rhs.is_constant(): + left_expr = self + const_expr = rhs + else: # self is constant + left_expr = rhs + const_expr = self + if is_zero(const_expr.constant): + return LinearExpression() + elif is_one(const_expr.constant): + _copy_expression(left_expr, self) + else: + left_expr.constant *= const_expr.constant + for term_key, term in left_expr.terms.items(): + left_expr.terms[term_key] = dataclasses.replace( + term, coefficient=term.coefficient * const_expr.constant + ) + for port_term_key, port_term in left_expr.port_field_terms.items(): + left_expr.port_field_terms[port_term_key] = dataclasses.replace( + port_term, + coefficient=port_term.coefficient * const_expr.constant, + ) + _copy_expression(left_expr, self) + return self + + def __mul__(self, rhs: Union["LinearExpression", int, float]) -> "LinearExpression": + result = LinearExpression() + result += self + result *= rhs + return result + + def __rmul__(self, rhs: int) -> "LinearExpression": + return self.__mul__(rhs) + + def __itruediv__( + self, rhs: Union["LinearExpression", int, float] + ) -> "LinearExpression": + rhs = wrap_in_linear_expr(rhs) + + if not rhs.is_constant(): + raise ValueError("Cannot divide by a non constant expression") + else: + if is_zero(rhs.constant): + raise ZeroDivisionError("Cannot divide expression by zero") + elif is_one(rhs.constant): + return self + else: + self.constant /= rhs.constant + for term_key, term in self.terms.items(): + self.terms[term_key] = dataclasses.replace( + term, coefficient=term.coefficient / rhs.constant + ) + for port_term_key, port_term in self.port_field_terms.items(): + self.port_field_terms[port_term_key] = dataclasses.replace( + port_term, coefficient=port_term.coefficient / rhs.constant + ) + return self + + def __truediv__( + self, rhs: Union["LinearExpression", int, float] + ) -> "LinearExpression": + result = LinearExpression() + result += self + result /= rhs + + return result + + def __rtruediv__(self, rhs: Union[int, float]) -> "LinearExpression": + return self.__truediv__(rhs) + + def remove_zeros_from_terms(self) -> None: + # TODO: Not optimized, checks could be done directly when doing operations on self.linear_term to avoid copies + for term_key, term in self.terms.copy().items(): + if is_zero(term.coefficient): + del self.terms[term_key] + for port_term_key, port_term in self.port_field_terms.copy().items(): + if is_zero(port_term.coefficient): + del self.port_field_terms[port_term_key] + + # Function used only in tests... + def evaluate(self, context: ValueProvider, time_scenario_index: RowIndex) -> float: + return sum( + [ + term.evaluate(context, time_scenario_index) + for term in self.terms.values() + ] + ) + resolve_coefficient(self.constant, context, time_scenario_index) + + def is_constant(self) -> bool: + # Constant expr like x-x could be seen as non constant as we do not simplify coefficient tree... + return not self.terms and not self.port_field_terms + + def is_unbound(self) -> bool: + return is_unbound(self.constant) + + def compute_indexation( + self, provider: IndexingStructureProvider + ) -> IndexingStructure: + """ + Computes the (time, scenario) indexing of a linear expression. + + Time and scenario indexation is driven by the indexation of variables in the expression. If a single term is indexed by time (resp. scenario), then the linear expression is indexed by time (resp. scenario). + """ + indexing = IndexingStructure(False, False) + for term in self.terms.values(): + indexing = indexing | term.compute_indexation(provider) + + return indexing + + def sum( + self, + shift: Union[ + int, + "ExpressionNode", + List["ExpressionNode"], + "ExpressionRange", + None, + ] = None, + eval: Union[ + int, + "ExpressionNode", + List["ExpressionNode"], + "ExpressionRange", + None, + ] = None, + ) -> "LinearExpression": + """ + Examples: + >>> x.sum(shift=[1, 2, 4]) represents x[t+1] + x[t+2] + x[t+4] + + No variables allowed in shift argument, but parameter trees are ok + + It is assumed that the shift operator is linear and distributes to all terms and to the constant of the linear expression on which it is applied. + + Examples: + >>> (param("a") * var("x") + param("b")).sum(shift=[1, 2, 4]) represents a[t+1]x[t+1] + b[t+1] + a[t+2]x[t+2] + b[t+2] + a[t+4]x[t+4] + b[t+4] + """ + + if shift is not None and eval is not None: + raise ValueError("Only shift or eval arguments should specified, not both.") + + if shift is not None: + sum_args = {"shift": shift} + + result_constant = TimeAggregatorNode( + TimeOperatorNode( + self.constant, + TimeOperatorName.SHIFT, + InstancesTimeIndex(shift), + ), + TimeAggregatorName.TIME_SUM, + stay_roll=True, + ) + elif eval is not None: + sum_args = {"eval": eval} + + result_constant = TimeAggregatorNode( + TimeOperatorNode( + self.constant, + TimeOperatorName.EVALUATION, + InstancesTimeIndex(eval), + ), + TimeAggregatorName.TIME_SUM, + stay_roll=True, + ) + else: # x.sum() -> Sum over all time block + sum_args = {} + + result_constant = TimeAggregatorNode( + self.constant, + TimeAggregatorName.TIME_SUM, + stay_roll=False, + ) + + return LinearExpression(self._apply_operator(sum_args), result_constant) + + def _apply_operator( + self, + sum_args: Mapping[ + str, + Union[ + int, + "ExpressionNode", + List["ExpressionNode"], + "ExpressionRange", + None, + ], + ], + ) -> Dict[TermKey, Term]: + result_terms = {} + for term in self.terms.values(): + term_with_operator = term.sum(**sum_args) + result_terms[generate_key(term_with_operator)] = term_with_operator + + return result_terms + + def shift( + self, + expressions: Union[ + int, + "ExpressionNode", + List["ExpressionNode"], + "ExpressionRange", + ], + ) -> "LinearExpression": + """ + Shorthand for shift on a single time step + + To refer to x[t-1], it is more natural to write x.shift(-1) than x.sum(shift=-1). + + This function provides the shorthand x.sum(shift=expr), valid only in the case when expr refers to a single time step. + + """ + + if not InstancesTimeIndex(expressions).is_simple(): + raise ValueError( + "The shift operator can only be applied on expressions refering to a single time step. To apply a shifting sum on multiple time indices on an expression x, you should use x.sum(shift=...)" + ) + + else: + return self.sum(shift=expressions) + + def eval( + self, + expressions: Union[ + int, + "ExpressionNode", + List["ExpressionNode"], + "ExpressionRange", + ], + ) -> "LinearExpression": + """ + Shorthand for eval on a single time step + + To refer to x[1], it is more natural to write x.eval(1) than x.sum(eval=1). + + This function provides the shorthand x.sum(eval=expr), valid only in the case when expr refers to a single time step. + + """ + + if not InstancesTimeIndex(expressions).is_simple(): + raise ValueError( + "The eval operator can only be applied on expressions refering to a single time step. To apply a evaluation sum on multiple time indices on an expression x, you should use x.sum(eval=...)" + ) + + else: + return self.sum(eval=expressions) + + def expec(self) -> "LinearExpression": + """ + Expectation of linear expression. As the operator is linear, it distributes over all terms and the constant + """ + + result_terms = {} + for term in self.terms.values(): + term_with_operator = term.expec() + result_terms[generate_key(term_with_operator)] = term_with_operator + + result_constant = ScenarioOperatorNode( + self.constant, ScenarioOperatorName.EXPECTATION + ) + result_expr = LinearExpression(result_terms, result_constant) + return result_expr + + def sum_connections(self) -> "LinearExpression": + if not self.is_zero(): + raise ValueError( + "sum_connections only after an expression created with port_field" + ) + port_field_terms = {} + for port_field_key, port_field_value in self.port_field_terms.items(): + port_field_terms[port_field_key] = port_field_value.sum_connections() + return LinearExpression(port_field_terms=port_field_terms) + + def resolve_port( + self, + component_id: str, + ports_expressions: Dict[PortFieldKey, List["LinearExpression"]], + ) -> "LinearExpression": + port_expr = LinearExpression() + for port_term in self.port_field_terms.values(): + expressions = ports_expressions.get( + PortFieldKey( + component_id, + PortFieldId(port_term.port_name, port_term.field_name), + ), + [], + ) + if port_term.aggregator is None: + if len(expressions) != 1: + raise ValueError( + f"Invalid number of expression for port : {port_term.port_name}" + ) + else: + if port_term.aggregator != PortSum(): + raise NotImplementedError("Only PortSum is supported.") + + port_expr += sum_expressions( + [port_term.coefficient * expression for expression in expressions] + ) + self_without_ports = LinearExpression(self.terms, self.constant) + return self_without_ports + port_expr + + def add_component_context(self, component_id: str) -> "LinearExpression": + result_terms = {} + for term in self.terms.values(): + # Some terms may involve variable from other component if they arise from previous port resolution + result_term = dataclasses.replace( + term, + component_id=term.component_id if term.component_id else component_id, + coefficient=add_component_context(component_id, term.coefficient), + time_operator=( + dataclasses.replace( + term.time_operator, + time_ids=_add_component_context_to_instances_index( + component_id, term.time_operator.time_ids + ), + ) + if term.time_operator + else None + ), + ) + result_terms[generate_key(result_term)] = result_term + result_constant = add_component_context(component_id, self.constant) + return LinearExpression(result_terms, result_constant, self.port_field_terms) + + +def _add_component_context_to_expression_range( + component_id: str, expression_range: ExpressionRange +) -> ExpressionRange: + return ExpressionRange( + start=add_component_context(component_id, expression_range.start), + stop=add_component_context(component_id, expression_range.stop), + step=( + add_component_context(component_id, expression_range.step) + if expression_range.step is not None + else None + ), + ) + + +def _add_component_context_to_instances_index( + component_id: str, instances_index: InstancesTimeIndex +) -> InstancesTimeIndex: + expressions = instances_index.expressions + if isinstance(expressions, ExpressionRange): + return InstancesTimeIndex( + _add_component_context_to_expression_range(component_id, expressions) + ) + if isinstance(expressions, list): + expressions_list = cast(List[ExpressionNode], expressions) + copy = [add_component_context(component_id, e) for e in expressions_list] + return InstancesTimeIndex(copy) + raise ValueError("Unexpected type in instances index") + + +def linear_expressions_equal(lhs: LinearExpression, rhs: LinearExpression) -> bool: + return ( + isinstance(lhs, LinearExpression) + and isinstance(rhs, LinearExpression) + and expressions_equal(lhs.constant, rhs.constant) + and lhs.terms == rhs.terms + ) + + +def linear_expressions_equal_if_present( + lhs: Optional[LinearExpression], rhs: Optional[LinearExpression] +) -> bool: + if lhs is None and rhs is None: + return True + elif lhs is None or rhs is None: + return False + else: + return linear_expressions_equal(lhs, rhs) + + +# TODO: Is this function useful ? Could we just rely on the sum operator overloading ? Only the case with an empty list may make the function useful +def sum_expressions( + expressions: Sequence[LinearExpression], +) -> Union[LinearExpression, Literal[0]]: + if len(expressions) == 0: + return wrap_in_linear_expr(literal(0)) + else: + return sum(expressions) + + +@dataclass +class StandaloneConstraint: + """ + A standalone constraint, with rigid initialization. + """ + + expression: LinearExpression + lower_bound: LinearExpression + upper_bound: LinearExpression + + def __post_init__( + self, + ) -> None: + for bound in [self.lower_bound, self.upper_bound]: + if not bound.is_constant(): + raise ValueError( + f"The bounds of a constraint should not contain variables, {str(bound)} was given." + ) + + def __str__(self) -> str: + return f"{str(self.lower_bound)} <= {str(self.expression)} <= {str(self.upper_bound)}" + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, StandaloneConstraint) + and linear_expressions_equal_if_present(self.expression, other.expression) + and linear_expressions_equal_if_present(self.lower_bound, other.lower_bound) + and linear_expressions_equal_if_present(self.upper_bound, other.upper_bound) + ) + + +def wrap_in_linear_expr(obj: Any) -> LinearExpression: + if isinstance(obj, LinearExpression): + return obj + elif isinstance(obj, float) or isinstance(obj, int): + return LinearExpression([], LiteralNode(float(obj))) + elif isinstance(obj, ExpressionNode): + return LinearExpression([], obj) + raise TypeError(f"Unable to wrap {obj} into a linear expression") + + +def wrap_in_linear_expr_if_present(obj: Any) -> Union[None, LinearExpression]: + if obj is None: + return None + else: + return wrap_in_linear_expr(obj) + + +def _copy_expression(src: LinearExpression, dst: LinearExpression) -> None: + dst.terms = src.terms + dst.constant = src.constant + + +def var(name: str) -> LinearExpression: + # TODO: At term build time, no information on the variable structure is known, we use a default time, scenario varying, maybe discard structure as term attribute ? + return LinearExpression( + [Term(coefficient=LiteralNode(1), component_id="", variable_name=name)], + LiteralNode(0), + ) + + +def comp_var(component_id: str, name: str) -> LinearExpression: + return LinearExpression( + [ + Term( + coefficient=LiteralNode(1), + component_id=component_id, + variable_name=name, + ) + ], + LiteralNode(0), + ) + + +def port_field(port_name: str, field_name: str) -> LinearExpression: + return LinearExpression( + port_field_terms=[PortFieldTerm(literal(1), port_name, field_name)] + ) + + +def is_linear(expr: LinearExpression) -> bool: + return True diff --git a/src/andromede/expression/parsing/antlr/Expr.interp b/src/andromede/expression/parsing/antlr/Expr.interp index bf05ae28..436305a4 100644 --- a/src/andromede/expression/parsing/antlr/Expr.interp +++ b/src/andromede/expression/parsing/antlr/Expr.interp @@ -8,13 +8,14 @@ null '*' '+' '[' -',' ']' +',' '..' null 't' null null +'sum' null token symbolic names: @@ -34,16 +35,19 @@ NUMBER TIME IDENTIFIER COMPARISON +TIME_SUM WS rule names: fullexpr expr atom +timeShiftRange +timeRange shift shift_expr right_expr atn: -[4, 1, 16, 131, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 37, 8, 1, 10, 1, 12, 1, 40, 9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 49, 8, 1, 10, 1, 12, 1, 52, 9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 70, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 81, 8, 1, 10, 1, 12, 1, 84, 9, 1, 1, 2, 1, 2, 3, 2, 88, 8, 2, 1, 3, 1, 3, 3, 3, 92, 8, 3, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 3, 4, 102, 8, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 5, 4, 110, 8, 4, 10, 4, 12, 4, 113, 9, 4, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 3, 5, 121, 8, 5, 1, 5, 1, 5, 1, 5, 5, 5, 126, 8, 5, 10, 5, 12, 5, 129, 9, 5, 1, 5, 0, 3, 2, 8, 10, 6, 0, 2, 4, 6, 8, 10, 0, 2, 1, 0, 5, 6, 2, 0, 2, 2, 7, 7, 144, 0, 12, 1, 0, 0, 0, 2, 69, 1, 0, 0, 0, 4, 87, 1, 0, 0, 0, 6, 89, 1, 0, 0, 0, 8, 101, 1, 0, 0, 0, 10, 120, 1, 0, 0, 0, 12, 13, 3, 2, 1, 0, 13, 14, 5, 0, 0, 1, 14, 1, 1, 0, 0, 0, 15, 16, 6, 1, -1, 0, 16, 70, 3, 4, 2, 0, 17, 18, 5, 14, 0, 0, 18, 19, 5, 1, 0, 0, 19, 70, 5, 14, 0, 0, 20, 21, 5, 2, 0, 0, 21, 70, 3, 2, 1, 10, 22, 23, 5, 3, 0, 0, 23, 24, 3, 2, 1, 0, 24, 25, 5, 4, 0, 0, 25, 70, 1, 0, 0, 0, 26, 27, 5, 14, 0, 0, 27, 28, 5, 3, 0, 0, 28, 29, 3, 2, 1, 0, 29, 30, 5, 4, 0, 0, 30, 70, 1, 0, 0, 0, 31, 32, 5, 14, 0, 0, 32, 33, 5, 8, 0, 0, 33, 38, 3, 6, 3, 0, 34, 35, 5, 9, 0, 0, 35, 37, 3, 6, 3, 0, 36, 34, 1, 0, 0, 0, 37, 40, 1, 0, 0, 0, 38, 36, 1, 0, 0, 0, 38, 39, 1, 0, 0, 0, 39, 41, 1, 0, 0, 0, 40, 38, 1, 0, 0, 0, 41, 42, 5, 10, 0, 0, 42, 70, 1, 0, 0, 0, 43, 44, 5, 14, 0, 0, 44, 45, 5, 8, 0, 0, 45, 50, 3, 2, 1, 0, 46, 47, 5, 9, 0, 0, 47, 49, 3, 2, 1, 0, 48, 46, 1, 0, 0, 0, 49, 52, 1, 0, 0, 0, 50, 48, 1, 0, 0, 0, 50, 51, 1, 0, 0, 0, 51, 53, 1, 0, 0, 0, 52, 50, 1, 0, 0, 0, 53, 54, 5, 10, 0, 0, 54, 70, 1, 0, 0, 0, 55, 56, 5, 14, 0, 0, 56, 57, 5, 8, 0, 0, 57, 58, 3, 6, 3, 0, 58, 59, 5, 11, 0, 0, 59, 60, 3, 6, 3, 0, 60, 61, 5, 10, 0, 0, 61, 70, 1, 0, 0, 0, 62, 63, 5, 14, 0, 0, 63, 64, 5, 8, 0, 0, 64, 65, 3, 2, 1, 0, 65, 66, 5, 11, 0, 0, 66, 67, 3, 2, 1, 0, 67, 68, 5, 10, 0, 0, 68, 70, 1, 0, 0, 0, 69, 15, 1, 0, 0, 0, 69, 17, 1, 0, 0, 0, 69, 20, 1, 0, 0, 0, 69, 22, 1, 0, 0, 0, 69, 26, 1, 0, 0, 0, 69, 31, 1, 0, 0, 0, 69, 43, 1, 0, 0, 0, 69, 55, 1, 0, 0, 0, 69, 62, 1, 0, 0, 0, 70, 82, 1, 0, 0, 0, 71, 72, 10, 8, 0, 0, 72, 73, 7, 0, 0, 0, 73, 81, 3, 2, 1, 9, 74, 75, 10, 7, 0, 0, 75, 76, 7, 1, 0, 0, 76, 81, 3, 2, 1, 8, 77, 78, 10, 6, 0, 0, 78, 79, 5, 15, 0, 0, 79, 81, 3, 2, 1, 7, 80, 71, 1, 0, 0, 0, 80, 74, 1, 0, 0, 0, 80, 77, 1, 0, 0, 0, 81, 84, 1, 0, 0, 0, 82, 80, 1, 0, 0, 0, 82, 83, 1, 0, 0, 0, 83, 3, 1, 0, 0, 0, 84, 82, 1, 0, 0, 0, 85, 88, 5, 12, 0, 0, 86, 88, 5, 14, 0, 0, 87, 85, 1, 0, 0, 0, 87, 86, 1, 0, 0, 0, 88, 5, 1, 0, 0, 0, 89, 91, 5, 13, 0, 0, 90, 92, 3, 8, 4, 0, 91, 90, 1, 0, 0, 0, 91, 92, 1, 0, 0, 0, 92, 7, 1, 0, 0, 0, 93, 94, 6, 4, -1, 0, 94, 95, 7, 1, 0, 0, 95, 102, 3, 4, 2, 0, 96, 97, 7, 1, 0, 0, 97, 98, 5, 3, 0, 0, 98, 99, 3, 2, 1, 0, 99, 100, 5, 4, 0, 0, 100, 102, 1, 0, 0, 0, 101, 93, 1, 0, 0, 0, 101, 96, 1, 0, 0, 0, 102, 111, 1, 0, 0, 0, 103, 104, 10, 4, 0, 0, 104, 105, 7, 0, 0, 0, 105, 110, 3, 10, 5, 0, 106, 107, 10, 3, 0, 0, 107, 108, 7, 1, 0, 0, 108, 110, 3, 10, 5, 0, 109, 103, 1, 0, 0, 0, 109, 106, 1, 0, 0, 0, 110, 113, 1, 0, 0, 0, 111, 109, 1, 0, 0, 0, 111, 112, 1, 0, 0, 0, 112, 9, 1, 0, 0, 0, 113, 111, 1, 0, 0, 0, 114, 115, 6, 5, -1, 0, 115, 116, 5, 3, 0, 0, 116, 117, 3, 2, 1, 0, 117, 118, 5, 4, 0, 0, 118, 121, 1, 0, 0, 0, 119, 121, 3, 4, 2, 0, 120, 114, 1, 0, 0, 0, 120, 119, 1, 0, 0, 0, 121, 127, 1, 0, 0, 0, 122, 123, 10, 3, 0, 0, 123, 124, 7, 0, 0, 0, 124, 126, 3, 10, 5, 4, 125, 122, 1, 0, 0, 0, 126, 129, 1, 0, 0, 0, 127, 125, 1, 0, 0, 0, 127, 128, 1, 0, 0, 0, 128, 11, 1, 0, 0, 0, 129, 127, 1, 0, 0, 0, 12, 38, 50, 69, 80, 82, 87, 91, 101, 109, 111, 120, 127] \ No newline at end of file +[4, 1, 17, 129, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 52, 8, 1, 1, 1, 1, 1, 3, 1, 56, 8, 1, 1, 1, 1, 1, 3, 1, 60, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 71, 8, 1, 10, 1, 12, 1, 74, 9, 1, 1, 2, 1, 2, 3, 2, 78, 8, 2, 1, 3, 1, 3, 1, 3, 1, 3, 1, 4, 1, 4, 1, 4, 1, 4, 1, 5, 1, 5, 3, 5, 90, 8, 5, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 3, 6, 100, 8, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 5, 6, 108, 8, 6, 10, 6, 12, 6, 111, 9, 6, 1, 7, 1, 7, 1, 7, 1, 7, 1, 7, 1, 7, 3, 7, 119, 8, 7, 1, 7, 1, 7, 1, 7, 5, 7, 124, 8, 7, 10, 7, 12, 7, 127, 9, 7, 1, 7, 0, 3, 2, 12, 14, 8, 0, 2, 4, 6, 8, 10, 12, 14, 0, 2, 1, 0, 5, 6, 2, 0, 2, 2, 7, 7, 141, 0, 16, 1, 0, 0, 0, 2, 59, 1, 0, 0, 0, 4, 77, 1, 0, 0, 0, 6, 79, 1, 0, 0, 0, 8, 83, 1, 0, 0, 0, 10, 87, 1, 0, 0, 0, 12, 99, 1, 0, 0, 0, 14, 118, 1, 0, 0, 0, 16, 17, 3, 2, 1, 0, 17, 18, 5, 0, 0, 1, 18, 1, 1, 0, 0, 0, 19, 20, 6, 1, -1, 0, 20, 60, 3, 4, 2, 0, 21, 22, 5, 14, 0, 0, 22, 23, 5, 1, 0, 0, 23, 60, 5, 14, 0, 0, 24, 25, 5, 2, 0, 0, 25, 60, 3, 2, 1, 9, 26, 27, 5, 3, 0, 0, 27, 28, 3, 2, 1, 0, 28, 29, 5, 4, 0, 0, 29, 60, 1, 0, 0, 0, 30, 31, 5, 14, 0, 0, 31, 32, 5, 3, 0, 0, 32, 33, 3, 2, 1, 0, 33, 34, 5, 4, 0, 0, 34, 60, 1, 0, 0, 0, 35, 36, 5, 14, 0, 0, 36, 37, 5, 8, 0, 0, 37, 38, 3, 10, 5, 0, 38, 39, 5, 9, 0, 0, 39, 60, 1, 0, 0, 0, 40, 41, 5, 14, 0, 0, 41, 42, 5, 8, 0, 0, 42, 43, 3, 2, 1, 0, 43, 44, 5, 9, 0, 0, 44, 60, 1, 0, 0, 0, 45, 46, 5, 16, 0, 0, 46, 55, 5, 3, 0, 0, 47, 52, 3, 2, 1, 0, 48, 52, 3, 10, 5, 0, 49, 52, 3, 6, 3, 0, 50, 52, 3, 8, 4, 0, 51, 47, 1, 0, 0, 0, 51, 48, 1, 0, 0, 0, 51, 49, 1, 0, 0, 0, 51, 50, 1, 0, 0, 0, 52, 53, 1, 0, 0, 0, 53, 54, 5, 10, 0, 0, 54, 56, 1, 0, 0, 0, 55, 51, 1, 0, 0, 0, 55, 56, 1, 0, 0, 0, 56, 57, 1, 0, 0, 0, 57, 58, 5, 14, 0, 0, 58, 60, 5, 4, 0, 0, 59, 19, 1, 0, 0, 0, 59, 21, 1, 0, 0, 0, 59, 24, 1, 0, 0, 0, 59, 26, 1, 0, 0, 0, 59, 30, 1, 0, 0, 0, 59, 35, 1, 0, 0, 0, 59, 40, 1, 0, 0, 0, 59, 45, 1, 0, 0, 0, 60, 72, 1, 0, 0, 0, 61, 62, 10, 7, 0, 0, 62, 63, 7, 0, 0, 0, 63, 71, 3, 2, 1, 8, 64, 65, 10, 6, 0, 0, 65, 66, 7, 1, 0, 0, 66, 71, 3, 2, 1, 7, 67, 68, 10, 5, 0, 0, 68, 69, 5, 15, 0, 0, 69, 71, 3, 2, 1, 6, 70, 61, 1, 0, 0, 0, 70, 64, 1, 0, 0, 0, 70, 67, 1, 0, 0, 0, 71, 74, 1, 0, 0, 0, 72, 70, 1, 0, 0, 0, 72, 73, 1, 0, 0, 0, 73, 3, 1, 0, 0, 0, 74, 72, 1, 0, 0, 0, 75, 78, 5, 12, 0, 0, 76, 78, 5, 14, 0, 0, 77, 75, 1, 0, 0, 0, 77, 76, 1, 0, 0, 0, 78, 5, 1, 0, 0, 0, 79, 80, 3, 10, 5, 0, 80, 81, 5, 11, 0, 0, 81, 82, 3, 10, 5, 0, 82, 7, 1, 0, 0, 0, 83, 84, 3, 2, 1, 0, 84, 85, 5, 11, 0, 0, 85, 86, 3, 2, 1, 0, 86, 9, 1, 0, 0, 0, 87, 89, 5, 13, 0, 0, 88, 90, 3, 12, 6, 0, 89, 88, 1, 0, 0, 0, 89, 90, 1, 0, 0, 0, 90, 11, 1, 0, 0, 0, 91, 92, 6, 6, -1, 0, 92, 93, 7, 1, 0, 0, 93, 100, 3, 4, 2, 0, 94, 95, 7, 1, 0, 0, 95, 96, 5, 3, 0, 0, 96, 97, 3, 2, 1, 0, 97, 98, 5, 4, 0, 0, 98, 100, 1, 0, 0, 0, 99, 91, 1, 0, 0, 0, 99, 94, 1, 0, 0, 0, 100, 109, 1, 0, 0, 0, 101, 102, 10, 4, 0, 0, 102, 103, 7, 0, 0, 0, 103, 108, 3, 14, 7, 0, 104, 105, 10, 3, 0, 0, 105, 106, 7, 1, 0, 0, 106, 108, 3, 14, 7, 0, 107, 101, 1, 0, 0, 0, 107, 104, 1, 0, 0, 0, 108, 111, 1, 0, 0, 0, 109, 107, 1, 0, 0, 0, 109, 110, 1, 0, 0, 0, 110, 13, 1, 0, 0, 0, 111, 109, 1, 0, 0, 0, 112, 113, 6, 7, -1, 0, 113, 114, 5, 3, 0, 0, 114, 115, 3, 2, 1, 0, 115, 116, 5, 4, 0, 0, 116, 119, 1, 0, 0, 0, 117, 119, 3, 4, 2, 0, 118, 112, 1, 0, 0, 0, 118, 117, 1, 0, 0, 0, 119, 125, 1, 0, 0, 0, 120, 121, 10, 3, 0, 0, 121, 122, 7, 0, 0, 0, 122, 124, 3, 14, 7, 4, 123, 120, 1, 0, 0, 0, 124, 127, 1, 0, 0, 0, 125, 123, 1, 0, 0, 0, 125, 126, 1, 0, 0, 0, 126, 15, 1, 0, 0, 0, 127, 125, 1, 0, 0, 0, 12, 51, 55, 59, 70, 72, 77, 89, 99, 107, 109, 118, 125] \ No newline at end of file diff --git a/src/andromede/expression/parsing/antlr/Expr.tokens b/src/andromede/expression/parsing/antlr/Expr.tokens index 9401c83a..1c48b7d6 100644 --- a/src/andromede/expression/parsing/antlr/Expr.tokens +++ b/src/andromede/expression/parsing/antlr/Expr.tokens @@ -13,7 +13,8 @@ NUMBER=12 TIME=13 IDENTIFIER=14 COMPARISON=15 -WS=16 +TIME_SUM=16 +WS=17 '.'=1 '-'=2 '('=3 @@ -22,7 +23,8 @@ WS=16 '*'=6 '+'=7 '['=8 -','=9 -']'=10 +']'=9 +','=10 '..'=11 't'=13 +'sum'=16 diff --git a/src/andromede/expression/parsing/antlr/ExprLexer.interp b/src/andromede/expression/parsing/antlr/ExprLexer.interp index 2e85e1b7..c29ef905 100644 --- a/src/andromede/expression/parsing/antlr/ExprLexer.interp +++ b/src/andromede/expression/parsing/antlr/ExprLexer.interp @@ -8,13 +8,14 @@ null '*' '+' '[' -',' ']' +',' '..' null 't' null null +'sum' null token symbolic names: @@ -34,6 +35,7 @@ NUMBER TIME IDENTIFIER COMPARISON +TIME_SUM WS rule names: @@ -55,6 +57,7 @@ NUMBER TIME IDENTIFIER COMPARISON +TIME_SUM WS channel names: @@ -65,4 +68,4 @@ mode names: DEFAULT_MODE atn: -[4, 0, 16, 103, 6, -1, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 2, 8, 7, 8, 2, 9, 7, 9, 2, 10, 7, 10, 2, 11, 7, 11, 2, 12, 7, 12, 2, 13, 7, 13, 2, 14, 7, 14, 2, 15, 7, 15, 2, 16, 7, 16, 2, 17, 7, 17, 2, 18, 7, 18, 1, 0, 1, 0, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 3, 1, 4, 1, 4, 1, 5, 1, 5, 1, 6, 1, 6, 1, 7, 1, 7, 1, 8, 1, 8, 1, 9, 1, 9, 1, 10, 1, 10, 1, 10, 1, 11, 1, 11, 1, 12, 1, 12, 1, 13, 1, 13, 3, 13, 69, 8, 13, 1, 14, 4, 14, 72, 8, 14, 11, 14, 12, 14, 73, 1, 14, 1, 14, 4, 14, 78, 8, 14, 11, 14, 12, 14, 79, 3, 14, 82, 8, 14, 1, 15, 1, 15, 1, 16, 1, 16, 5, 16, 88, 8, 16, 10, 16, 12, 16, 91, 9, 16, 1, 17, 1, 17, 1, 17, 1, 17, 1, 17, 3, 17, 98, 8, 17, 1, 18, 1, 18, 1, 18, 1, 18, 0, 0, 19, 1, 1, 3, 2, 5, 3, 7, 4, 9, 5, 11, 6, 13, 7, 15, 8, 17, 9, 19, 10, 21, 11, 23, 0, 25, 0, 27, 0, 29, 12, 31, 13, 33, 14, 35, 15, 37, 16, 1, 0, 3, 1, 0, 48, 57, 3, 0, 65, 90, 95, 95, 97, 122, 3, 0, 9, 10, 13, 13, 32, 32, 106, 0, 1, 1, 0, 0, 0, 0, 3, 1, 0, 0, 0, 0, 5, 1, 0, 0, 0, 0, 7, 1, 0, 0, 0, 0, 9, 1, 0, 0, 0, 0, 11, 1, 0, 0, 0, 0, 13, 1, 0, 0, 0, 0, 15, 1, 0, 0, 0, 0, 17, 1, 0, 0, 0, 0, 19, 1, 0, 0, 0, 0, 21, 1, 0, 0, 0, 0, 29, 1, 0, 0, 0, 0, 31, 1, 0, 0, 0, 0, 33, 1, 0, 0, 0, 0, 35, 1, 0, 0, 0, 0, 37, 1, 0, 0, 0, 1, 39, 1, 0, 0, 0, 3, 41, 1, 0, 0, 0, 5, 43, 1, 0, 0, 0, 7, 45, 1, 0, 0, 0, 9, 47, 1, 0, 0, 0, 11, 49, 1, 0, 0, 0, 13, 51, 1, 0, 0, 0, 15, 53, 1, 0, 0, 0, 17, 55, 1, 0, 0, 0, 19, 57, 1, 0, 0, 0, 21, 59, 1, 0, 0, 0, 23, 62, 1, 0, 0, 0, 25, 64, 1, 0, 0, 0, 27, 68, 1, 0, 0, 0, 29, 71, 1, 0, 0, 0, 31, 83, 1, 0, 0, 0, 33, 85, 1, 0, 0, 0, 35, 97, 1, 0, 0, 0, 37, 99, 1, 0, 0, 0, 39, 40, 5, 46, 0, 0, 40, 2, 1, 0, 0, 0, 41, 42, 5, 45, 0, 0, 42, 4, 1, 0, 0, 0, 43, 44, 5, 40, 0, 0, 44, 6, 1, 0, 0, 0, 45, 46, 5, 41, 0, 0, 46, 8, 1, 0, 0, 0, 47, 48, 5, 47, 0, 0, 48, 10, 1, 0, 0, 0, 49, 50, 5, 42, 0, 0, 50, 12, 1, 0, 0, 0, 51, 52, 5, 43, 0, 0, 52, 14, 1, 0, 0, 0, 53, 54, 5, 91, 0, 0, 54, 16, 1, 0, 0, 0, 55, 56, 5, 44, 0, 0, 56, 18, 1, 0, 0, 0, 57, 58, 5, 93, 0, 0, 58, 20, 1, 0, 0, 0, 59, 60, 5, 46, 0, 0, 60, 61, 5, 46, 0, 0, 61, 22, 1, 0, 0, 0, 62, 63, 7, 0, 0, 0, 63, 24, 1, 0, 0, 0, 64, 65, 7, 1, 0, 0, 65, 26, 1, 0, 0, 0, 66, 69, 3, 25, 12, 0, 67, 69, 3, 23, 11, 0, 68, 66, 1, 0, 0, 0, 68, 67, 1, 0, 0, 0, 69, 28, 1, 0, 0, 0, 70, 72, 3, 23, 11, 0, 71, 70, 1, 0, 0, 0, 72, 73, 1, 0, 0, 0, 73, 71, 1, 0, 0, 0, 73, 74, 1, 0, 0, 0, 74, 81, 1, 0, 0, 0, 75, 77, 5, 46, 0, 0, 76, 78, 3, 23, 11, 0, 77, 76, 1, 0, 0, 0, 78, 79, 1, 0, 0, 0, 79, 77, 1, 0, 0, 0, 79, 80, 1, 0, 0, 0, 80, 82, 1, 0, 0, 0, 81, 75, 1, 0, 0, 0, 81, 82, 1, 0, 0, 0, 82, 30, 1, 0, 0, 0, 83, 84, 5, 116, 0, 0, 84, 32, 1, 0, 0, 0, 85, 89, 3, 25, 12, 0, 86, 88, 3, 27, 13, 0, 87, 86, 1, 0, 0, 0, 88, 91, 1, 0, 0, 0, 89, 87, 1, 0, 0, 0, 89, 90, 1, 0, 0, 0, 90, 34, 1, 0, 0, 0, 91, 89, 1, 0, 0, 0, 92, 98, 5, 61, 0, 0, 93, 94, 5, 62, 0, 0, 94, 98, 5, 61, 0, 0, 95, 96, 5, 60, 0, 0, 96, 98, 5, 61, 0, 0, 97, 92, 1, 0, 0, 0, 97, 93, 1, 0, 0, 0, 97, 95, 1, 0, 0, 0, 98, 36, 1, 0, 0, 0, 99, 100, 7, 2, 0, 0, 100, 101, 1, 0, 0, 0, 101, 102, 6, 18, 0, 0, 102, 38, 1, 0, 0, 0, 7, 0, 68, 73, 79, 81, 89, 97, 1, 6, 0, 0] \ No newline at end of file +[4, 0, 17, 109, 6, -1, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 2, 8, 7, 8, 2, 9, 7, 9, 2, 10, 7, 10, 2, 11, 7, 11, 2, 12, 7, 12, 2, 13, 7, 13, 2, 14, 7, 14, 2, 15, 7, 15, 2, 16, 7, 16, 2, 17, 7, 17, 2, 18, 7, 18, 2, 19, 7, 19, 1, 0, 1, 0, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 3, 1, 4, 1, 4, 1, 5, 1, 5, 1, 6, 1, 6, 1, 7, 1, 7, 1, 8, 1, 8, 1, 9, 1, 9, 1, 10, 1, 10, 1, 10, 1, 11, 1, 11, 1, 12, 1, 12, 1, 13, 1, 13, 3, 13, 71, 8, 13, 1, 14, 4, 14, 74, 8, 14, 11, 14, 12, 14, 75, 1, 14, 1, 14, 4, 14, 80, 8, 14, 11, 14, 12, 14, 81, 3, 14, 84, 8, 14, 1, 15, 1, 15, 1, 16, 1, 16, 5, 16, 90, 8, 16, 10, 16, 12, 16, 93, 9, 16, 1, 17, 1, 17, 1, 17, 1, 17, 1, 17, 3, 17, 100, 8, 17, 1, 18, 1, 18, 1, 18, 1, 18, 1, 19, 1, 19, 1, 19, 1, 19, 0, 0, 20, 1, 1, 3, 2, 5, 3, 7, 4, 9, 5, 11, 6, 13, 7, 15, 8, 17, 9, 19, 10, 21, 11, 23, 0, 25, 0, 27, 0, 29, 12, 31, 13, 33, 14, 35, 15, 37, 16, 39, 17, 1, 0, 3, 1, 0, 48, 57, 3, 0, 65, 90, 95, 95, 97, 122, 3, 0, 9, 10, 13, 13, 32, 32, 112, 0, 1, 1, 0, 0, 0, 0, 3, 1, 0, 0, 0, 0, 5, 1, 0, 0, 0, 0, 7, 1, 0, 0, 0, 0, 9, 1, 0, 0, 0, 0, 11, 1, 0, 0, 0, 0, 13, 1, 0, 0, 0, 0, 15, 1, 0, 0, 0, 0, 17, 1, 0, 0, 0, 0, 19, 1, 0, 0, 0, 0, 21, 1, 0, 0, 0, 0, 29, 1, 0, 0, 0, 0, 31, 1, 0, 0, 0, 0, 33, 1, 0, 0, 0, 0, 35, 1, 0, 0, 0, 0, 37, 1, 0, 0, 0, 0, 39, 1, 0, 0, 0, 1, 41, 1, 0, 0, 0, 3, 43, 1, 0, 0, 0, 5, 45, 1, 0, 0, 0, 7, 47, 1, 0, 0, 0, 9, 49, 1, 0, 0, 0, 11, 51, 1, 0, 0, 0, 13, 53, 1, 0, 0, 0, 15, 55, 1, 0, 0, 0, 17, 57, 1, 0, 0, 0, 19, 59, 1, 0, 0, 0, 21, 61, 1, 0, 0, 0, 23, 64, 1, 0, 0, 0, 25, 66, 1, 0, 0, 0, 27, 70, 1, 0, 0, 0, 29, 73, 1, 0, 0, 0, 31, 85, 1, 0, 0, 0, 33, 87, 1, 0, 0, 0, 35, 99, 1, 0, 0, 0, 37, 101, 1, 0, 0, 0, 39, 105, 1, 0, 0, 0, 41, 42, 5, 46, 0, 0, 42, 2, 1, 0, 0, 0, 43, 44, 5, 45, 0, 0, 44, 4, 1, 0, 0, 0, 45, 46, 5, 40, 0, 0, 46, 6, 1, 0, 0, 0, 47, 48, 5, 41, 0, 0, 48, 8, 1, 0, 0, 0, 49, 50, 5, 47, 0, 0, 50, 10, 1, 0, 0, 0, 51, 52, 5, 42, 0, 0, 52, 12, 1, 0, 0, 0, 53, 54, 5, 43, 0, 0, 54, 14, 1, 0, 0, 0, 55, 56, 5, 91, 0, 0, 56, 16, 1, 0, 0, 0, 57, 58, 5, 93, 0, 0, 58, 18, 1, 0, 0, 0, 59, 60, 5, 44, 0, 0, 60, 20, 1, 0, 0, 0, 61, 62, 5, 46, 0, 0, 62, 63, 5, 46, 0, 0, 63, 22, 1, 0, 0, 0, 64, 65, 7, 0, 0, 0, 65, 24, 1, 0, 0, 0, 66, 67, 7, 1, 0, 0, 67, 26, 1, 0, 0, 0, 68, 71, 3, 25, 12, 0, 69, 71, 3, 23, 11, 0, 70, 68, 1, 0, 0, 0, 70, 69, 1, 0, 0, 0, 71, 28, 1, 0, 0, 0, 72, 74, 3, 23, 11, 0, 73, 72, 1, 0, 0, 0, 74, 75, 1, 0, 0, 0, 75, 73, 1, 0, 0, 0, 75, 76, 1, 0, 0, 0, 76, 83, 1, 0, 0, 0, 77, 79, 5, 46, 0, 0, 78, 80, 3, 23, 11, 0, 79, 78, 1, 0, 0, 0, 80, 81, 1, 0, 0, 0, 81, 79, 1, 0, 0, 0, 81, 82, 1, 0, 0, 0, 82, 84, 1, 0, 0, 0, 83, 77, 1, 0, 0, 0, 83, 84, 1, 0, 0, 0, 84, 30, 1, 0, 0, 0, 85, 86, 5, 116, 0, 0, 86, 32, 1, 0, 0, 0, 87, 91, 3, 25, 12, 0, 88, 90, 3, 27, 13, 0, 89, 88, 1, 0, 0, 0, 90, 93, 1, 0, 0, 0, 91, 89, 1, 0, 0, 0, 91, 92, 1, 0, 0, 0, 92, 34, 1, 0, 0, 0, 93, 91, 1, 0, 0, 0, 94, 100, 5, 61, 0, 0, 95, 96, 5, 62, 0, 0, 96, 100, 5, 61, 0, 0, 97, 98, 5, 60, 0, 0, 98, 100, 5, 61, 0, 0, 99, 94, 1, 0, 0, 0, 99, 95, 1, 0, 0, 0, 99, 97, 1, 0, 0, 0, 100, 36, 1, 0, 0, 0, 101, 102, 5, 115, 0, 0, 102, 103, 5, 117, 0, 0, 103, 104, 5, 109, 0, 0, 104, 38, 1, 0, 0, 0, 105, 106, 7, 2, 0, 0, 106, 107, 1, 0, 0, 0, 107, 108, 6, 19, 0, 0, 108, 40, 1, 0, 0, 0, 7, 0, 70, 75, 81, 83, 91, 99, 1, 6, 0, 0] \ No newline at end of file diff --git a/src/andromede/expression/parsing/antlr/ExprLexer.py b/src/andromede/expression/parsing/antlr/ExprLexer.py index 1ad7f368..aa15abd0 100644 --- a/src/andromede/expression/parsing/antlr/ExprLexer.py +++ b/src/andromede/expression/parsing/antlr/ExprLexer.py @@ -1,9 +1,7 @@ -# Generated from Expr.g4 by ANTLR 4.13.1 -import sys -from io import StringIO - +# Generated from Expr.g4 by ANTLR 4.13.2 from antlr4 import * - +from io import StringIO +import sys if sys.version_info[1] > 5: from typing import TextIO else: @@ -12,945 +10,50 @@ def serializedATN(): return [ - 4, - 0, - 16, - 103, - 6, - -1, - 2, - 0, - 7, - 0, - 2, - 1, - 7, - 1, - 2, - 2, - 7, - 2, - 2, - 3, - 7, - 3, - 2, - 4, - 7, - 4, - 2, - 5, - 7, - 5, - 2, - 6, - 7, - 6, - 2, - 7, - 7, - 7, - 2, - 8, - 7, - 8, - 2, - 9, - 7, - 9, - 2, - 10, - 7, - 10, - 2, - 11, - 7, - 11, - 2, - 12, - 7, - 12, - 2, - 13, - 7, - 13, - 2, - 14, - 7, - 14, - 2, - 15, - 7, - 15, - 2, - 16, - 7, - 16, - 2, - 17, - 7, - 17, - 2, - 18, - 7, - 18, - 1, - 0, - 1, - 0, - 1, - 1, - 1, - 1, - 1, - 2, - 1, - 2, - 1, - 3, - 1, - 3, - 1, - 4, - 1, - 4, - 1, - 5, - 1, - 5, - 1, - 6, - 1, - 6, - 1, - 7, - 1, - 7, - 1, - 8, - 1, - 8, - 1, - 9, - 1, - 9, - 1, - 10, - 1, - 10, - 1, - 10, - 1, - 11, - 1, - 11, - 1, - 12, - 1, - 12, - 1, - 13, - 1, - 13, - 3, - 13, - 69, - 8, - 13, - 1, - 14, - 4, - 14, - 72, - 8, - 14, - 11, - 14, - 12, - 14, - 73, - 1, - 14, - 1, - 14, - 4, - 14, - 78, - 8, - 14, - 11, - 14, - 12, - 14, - 79, - 3, - 14, - 82, - 8, - 14, - 1, - 15, - 1, - 15, - 1, - 16, - 1, - 16, - 5, - 16, - 88, - 8, - 16, - 10, - 16, - 12, - 16, - 91, - 9, - 16, - 1, - 17, - 1, - 17, - 1, - 17, - 1, - 17, - 1, - 17, - 3, - 17, - 98, - 8, - 17, - 1, - 18, - 1, - 18, - 1, - 18, - 1, - 18, - 0, - 0, - 19, - 1, - 1, - 3, - 2, - 5, - 3, - 7, - 4, - 9, - 5, - 11, - 6, - 13, - 7, - 15, - 8, - 17, - 9, - 19, - 10, - 21, - 11, - 23, - 0, - 25, - 0, - 27, - 0, - 29, - 12, - 31, - 13, - 33, - 14, - 35, - 15, - 37, - 16, - 1, - 0, - 3, - 1, - 0, - 48, - 57, - 3, - 0, - 65, - 90, - 95, - 95, - 97, - 122, - 3, - 0, - 9, - 10, - 13, - 13, - 32, - 32, - 106, - 0, - 1, - 1, - 0, - 0, - 0, - 0, - 3, - 1, - 0, - 0, - 0, - 0, - 5, - 1, - 0, - 0, - 0, - 0, - 7, - 1, - 0, - 0, - 0, - 0, - 9, - 1, - 0, - 0, - 0, - 0, - 11, - 1, - 0, - 0, - 0, - 0, - 13, - 1, - 0, - 0, - 0, - 0, - 15, - 1, - 0, - 0, - 0, - 0, - 17, - 1, - 0, - 0, - 0, - 0, - 19, - 1, - 0, - 0, - 0, - 0, - 21, - 1, - 0, - 0, - 0, - 0, - 29, - 1, - 0, - 0, - 0, - 0, - 31, - 1, - 0, - 0, - 0, - 0, - 33, - 1, - 0, - 0, - 0, - 0, - 35, - 1, - 0, - 0, - 0, - 0, - 37, - 1, - 0, - 0, - 0, - 1, - 39, - 1, - 0, - 0, - 0, - 3, - 41, - 1, - 0, - 0, - 0, - 5, - 43, - 1, - 0, - 0, - 0, - 7, - 45, - 1, - 0, - 0, - 0, - 9, - 47, - 1, - 0, - 0, - 0, - 11, - 49, - 1, - 0, - 0, - 0, - 13, - 51, - 1, - 0, - 0, - 0, - 15, - 53, - 1, - 0, - 0, - 0, - 17, - 55, - 1, - 0, - 0, - 0, - 19, - 57, - 1, - 0, - 0, - 0, - 21, - 59, - 1, - 0, - 0, - 0, - 23, - 62, - 1, - 0, - 0, - 0, - 25, - 64, - 1, - 0, - 0, - 0, - 27, - 68, - 1, - 0, - 0, - 0, - 29, - 71, - 1, - 0, - 0, - 0, - 31, - 83, - 1, - 0, - 0, - 0, - 33, - 85, - 1, - 0, - 0, - 0, - 35, - 97, - 1, - 0, - 0, - 0, - 37, - 99, - 1, - 0, - 0, - 0, - 39, - 40, - 5, - 46, - 0, - 0, - 40, - 2, - 1, - 0, - 0, - 0, - 41, - 42, - 5, - 45, - 0, - 0, - 42, - 4, - 1, - 0, - 0, - 0, - 43, - 44, - 5, - 40, - 0, - 0, - 44, - 6, - 1, - 0, - 0, - 0, - 45, - 46, - 5, - 41, - 0, - 0, - 46, - 8, - 1, - 0, - 0, - 0, - 47, - 48, - 5, - 47, - 0, - 0, - 48, - 10, - 1, - 0, - 0, - 0, - 49, - 50, - 5, - 42, - 0, - 0, - 50, - 12, - 1, - 0, - 0, - 0, - 51, - 52, - 5, - 43, - 0, - 0, - 52, - 14, - 1, - 0, - 0, - 0, - 53, - 54, - 5, - 91, - 0, - 0, - 54, - 16, - 1, - 0, - 0, - 0, - 55, - 56, - 5, - 44, - 0, - 0, - 56, - 18, - 1, - 0, - 0, - 0, - 57, - 58, - 5, - 93, - 0, - 0, - 58, - 20, - 1, - 0, - 0, - 0, - 59, - 60, - 5, - 46, - 0, - 0, - 60, - 61, - 5, - 46, - 0, - 0, - 61, - 22, - 1, - 0, - 0, - 0, - 62, - 63, - 7, - 0, - 0, - 0, - 63, - 24, - 1, - 0, - 0, - 0, - 64, - 65, - 7, - 1, - 0, - 0, - 65, - 26, - 1, - 0, - 0, - 0, - 66, - 69, - 3, - 25, - 12, - 0, - 67, - 69, - 3, - 23, - 11, - 0, - 68, - 66, - 1, - 0, - 0, - 0, - 68, - 67, - 1, - 0, - 0, - 0, - 69, - 28, - 1, - 0, - 0, - 0, - 70, - 72, - 3, - 23, - 11, - 0, - 71, - 70, - 1, - 0, - 0, - 0, - 72, - 73, - 1, - 0, - 0, - 0, - 73, - 71, - 1, - 0, - 0, - 0, - 73, - 74, - 1, - 0, - 0, - 0, - 74, - 81, - 1, - 0, - 0, - 0, - 75, - 77, - 5, - 46, - 0, - 0, - 76, - 78, - 3, - 23, - 11, - 0, - 77, - 76, - 1, - 0, - 0, - 0, - 78, - 79, - 1, - 0, - 0, - 0, - 79, - 77, - 1, - 0, - 0, - 0, - 79, - 80, - 1, - 0, - 0, - 0, - 80, - 82, - 1, - 0, - 0, - 0, - 81, - 75, - 1, - 0, - 0, - 0, - 81, - 82, - 1, - 0, - 0, - 0, - 82, - 30, - 1, - 0, - 0, - 0, - 83, - 84, - 5, - 116, - 0, - 0, - 84, - 32, - 1, - 0, - 0, - 0, - 85, - 89, - 3, - 25, - 12, - 0, - 86, - 88, - 3, - 27, - 13, - 0, - 87, - 86, - 1, - 0, - 0, - 0, - 88, - 91, - 1, - 0, - 0, - 0, - 89, - 87, - 1, - 0, - 0, - 0, - 89, - 90, - 1, - 0, - 0, - 0, - 90, - 34, - 1, - 0, - 0, - 0, - 91, - 89, - 1, - 0, - 0, - 0, - 92, - 98, - 5, - 61, - 0, - 0, - 93, - 94, - 5, - 62, - 0, - 0, - 94, - 98, - 5, - 61, - 0, - 0, - 95, - 96, - 5, - 60, - 0, - 0, - 96, - 98, - 5, - 61, - 0, - 0, - 97, - 92, - 1, - 0, - 0, - 0, - 97, - 93, - 1, - 0, - 0, - 0, - 97, - 95, - 1, - 0, - 0, - 0, - 98, - 36, - 1, - 0, - 0, - 0, - 99, - 100, - 7, - 2, - 0, - 0, - 100, - 101, - 1, - 0, - 0, - 0, - 101, - 102, - 6, - 18, - 0, - 0, - 102, - 38, - 1, - 0, - 0, - 0, - 7, - 0, - 68, - 73, - 79, - 81, - 89, - 97, - 1, - 6, - 0, - 0, + 4,0,17,109,6,-1,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5, + 2,6,7,6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2, + 13,7,13,2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7, + 19,1,0,1,0,1,1,1,1,1,2,1,2,1,3,1,3,1,4,1,4,1,5,1,5,1,6,1,6,1,7,1, + 7,1,8,1,8,1,9,1,9,1,10,1,10,1,10,1,11,1,11,1,12,1,12,1,13,1,13,3, + 13,71,8,13,1,14,4,14,74,8,14,11,14,12,14,75,1,14,1,14,4,14,80,8, + 14,11,14,12,14,81,3,14,84,8,14,1,15,1,15,1,16,1,16,5,16,90,8,16, + 10,16,12,16,93,9,16,1,17,1,17,1,17,1,17,1,17,3,17,100,8,17,1,18, + 1,18,1,18,1,18,1,19,1,19,1,19,1,19,0,0,20,1,1,3,2,5,3,7,4,9,5,11, + 6,13,7,15,8,17,9,19,10,21,11,23,0,25,0,27,0,29,12,31,13,33,14,35, + 15,37,16,39,17,1,0,3,1,0,48,57,3,0,65,90,95,95,97,122,3,0,9,10,13, + 13,32,32,112,0,1,1,0,0,0,0,3,1,0,0,0,0,5,1,0,0,0,0,7,1,0,0,0,0,9, + 1,0,0,0,0,11,1,0,0,0,0,13,1,0,0,0,0,15,1,0,0,0,0,17,1,0,0,0,0,19, + 1,0,0,0,0,21,1,0,0,0,0,29,1,0,0,0,0,31,1,0,0,0,0,33,1,0,0,0,0,35, + 1,0,0,0,0,37,1,0,0,0,0,39,1,0,0,0,1,41,1,0,0,0,3,43,1,0,0,0,5,45, + 1,0,0,0,7,47,1,0,0,0,9,49,1,0,0,0,11,51,1,0,0,0,13,53,1,0,0,0,15, + 55,1,0,0,0,17,57,1,0,0,0,19,59,1,0,0,0,21,61,1,0,0,0,23,64,1,0,0, + 0,25,66,1,0,0,0,27,70,1,0,0,0,29,73,1,0,0,0,31,85,1,0,0,0,33,87, + 1,0,0,0,35,99,1,0,0,0,37,101,1,0,0,0,39,105,1,0,0,0,41,42,5,46,0, + 0,42,2,1,0,0,0,43,44,5,45,0,0,44,4,1,0,0,0,45,46,5,40,0,0,46,6,1, + 0,0,0,47,48,5,41,0,0,48,8,1,0,0,0,49,50,5,47,0,0,50,10,1,0,0,0,51, + 52,5,42,0,0,52,12,1,0,0,0,53,54,5,43,0,0,54,14,1,0,0,0,55,56,5,91, + 0,0,56,16,1,0,0,0,57,58,5,93,0,0,58,18,1,0,0,0,59,60,5,44,0,0,60, + 20,1,0,0,0,61,62,5,46,0,0,62,63,5,46,0,0,63,22,1,0,0,0,64,65,7,0, + 0,0,65,24,1,0,0,0,66,67,7,1,0,0,67,26,1,0,0,0,68,71,3,25,12,0,69, + 71,3,23,11,0,70,68,1,0,0,0,70,69,1,0,0,0,71,28,1,0,0,0,72,74,3,23, + 11,0,73,72,1,0,0,0,74,75,1,0,0,0,75,73,1,0,0,0,75,76,1,0,0,0,76, + 83,1,0,0,0,77,79,5,46,0,0,78,80,3,23,11,0,79,78,1,0,0,0,80,81,1, + 0,0,0,81,79,1,0,0,0,81,82,1,0,0,0,82,84,1,0,0,0,83,77,1,0,0,0,83, + 84,1,0,0,0,84,30,1,0,0,0,85,86,5,116,0,0,86,32,1,0,0,0,87,91,3,25, + 12,0,88,90,3,27,13,0,89,88,1,0,0,0,90,93,1,0,0,0,91,89,1,0,0,0,91, + 92,1,0,0,0,92,34,1,0,0,0,93,91,1,0,0,0,94,100,5,61,0,0,95,96,5,62, + 0,0,96,100,5,61,0,0,97,98,5,60,0,0,98,100,5,61,0,0,99,94,1,0,0,0, + 99,95,1,0,0,0,99,97,1,0,0,0,100,36,1,0,0,0,101,102,5,115,0,0,102, + 103,5,117,0,0,103,104,5,109,0,0,104,38,1,0,0,0,105,106,7,2,0,0,106, + 107,1,0,0,0,107,108,6,19,0,0,108,40,1,0,0,0,7,0,70,75,81,83,91,99, + 1,6,0,0 ] - class ExprLexer(Lexer): + atn = ATNDeserializer().deserialize(serializedATN()) - decisionsToDFA = [DFA(ds, i) for i, ds in enumerate(atn.decisionToState)] + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] T__0 = 1 T__1 = 2 @@ -967,59 +70,32 @@ class ExprLexer(Lexer): TIME = 13 IDENTIFIER = 14 COMPARISON = 15 - WS = 16 + TIME_SUM = 16 + WS = 17 - channelNames = ["DEFAULT_TOKEN_CHANNEL", "HIDDEN"] + channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] - modeNames = ["DEFAULT_MODE"] + modeNames = [ "DEFAULT_MODE" ] - literalNames = [ - "", - "'.'", - "'-'", - "'('", - "')'", - "'/'", - "'*'", - "'+'", - "'['", - "','", - "']'", - "'..'", - "'t'", - ] + literalNames = [ "", + "'.'", "'-'", "'('", "')'", "'/'", "'*'", "'+'", "'['", "']'", + "','", "'..'", "'t'", "'sum'" ] - symbolicNames = ["", "NUMBER", "TIME", "IDENTIFIER", "COMPARISON", "WS"] + symbolicNames = [ "", + "NUMBER", "TIME", "IDENTIFIER", "COMPARISON", "TIME_SUM", "WS" ] - ruleNames = [ - "T__0", - "T__1", - "T__2", - "T__3", - "T__4", - "T__5", - "T__6", - "T__7", - "T__8", - "T__9", - "T__10", - "DIGIT", - "CHAR", - "CHAR_OR_DIGIT", - "NUMBER", - "TIME", - "IDENTIFIER", - "COMPARISON", - "WS", - ] + ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6", + "T__7", "T__8", "T__9", "T__10", "DIGIT", "CHAR", "CHAR_OR_DIGIT", + "NUMBER", "TIME", "IDENTIFIER", "COMPARISON", "TIME_SUM", + "WS" ] grammarFileName = "Expr.g4" - def __init__(self, input=None, output: TextIO = sys.stdout): + def __init__(self, input=None, output:TextIO = sys.stdout): super().__init__(input, output) - self.checkVersion("4.13.1") - self._interp = LexerATNSimulator( - self, self.atn, self.decisionsToDFA, PredictionContextCache() - ) + self.checkVersion("4.13.2") + self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) self._actions = None self._predicates = None + + diff --git a/src/andromede/expression/parsing/antlr/ExprLexer.tokens b/src/andromede/expression/parsing/antlr/ExprLexer.tokens index 9401c83a..1c48b7d6 100644 --- a/src/andromede/expression/parsing/antlr/ExprLexer.tokens +++ b/src/andromede/expression/parsing/antlr/ExprLexer.tokens @@ -13,7 +13,8 @@ NUMBER=12 TIME=13 IDENTIFIER=14 COMPARISON=15 -WS=16 +TIME_SUM=16 +WS=17 '.'=1 '-'=2 '('=3 @@ -22,7 +23,8 @@ WS=16 '*'=6 '+'=7 '['=8 -','=9 -']'=10 +']'=9 +','=10 '..'=11 't'=13 +'sum'=16 diff --git a/src/andromede/expression/parsing/antlr/ExprParser.py b/src/andromede/expression/parsing/antlr/ExprParser.py index 8f312fe9..8fef9021 100644 --- a/src/andromede/expression/parsing/antlr/ExprParser.py +++ b/src/andromede/expression/parsing/antlr/ExprParser.py @@ -1,1296 +1,130 @@ -# Generated from Expr.g4 by ANTLR 4.13.1 +# Generated from Expr.g4 by ANTLR 4.13.2 # encoding: utf-8 -import sys -from io import StringIO - from antlr4 import * - +from io import StringIO +import sys if sys.version_info[1] > 5: - from typing import TextIO + from typing import TextIO else: - from typing.io import TextIO - + from typing.io import TextIO def serializedATN(): return [ - 4, - 1, - 16, - 131, - 2, - 0, - 7, - 0, - 2, - 1, - 7, - 1, - 2, - 2, - 7, - 2, - 2, - 3, - 7, - 3, - 2, - 4, - 7, - 4, - 2, - 5, - 7, - 5, - 1, - 0, - 1, - 0, - 1, - 0, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 5, - 1, - 37, - 8, - 1, - 10, - 1, - 12, - 1, - 40, - 9, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 5, - 1, - 49, - 8, - 1, - 10, - 1, - 12, - 1, - 52, - 9, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 3, - 1, - 70, - 8, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 5, - 1, - 81, - 8, - 1, - 10, - 1, - 12, - 1, - 84, - 9, - 1, - 1, - 2, - 1, - 2, - 3, - 2, - 88, - 8, - 2, - 1, - 3, - 1, - 3, - 3, - 3, - 92, - 8, - 3, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 3, - 4, - 102, - 8, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 5, - 4, - 110, - 8, - 4, - 10, - 4, - 12, - 4, - 113, - 9, - 4, - 1, - 5, - 1, - 5, - 1, - 5, - 1, - 5, - 1, - 5, - 1, - 5, - 3, - 5, - 121, - 8, - 5, - 1, - 5, - 1, - 5, - 1, - 5, - 5, - 5, - 126, - 8, - 5, - 10, - 5, - 12, - 5, - 129, - 9, - 5, - 1, - 5, - 0, - 3, - 2, - 8, - 10, - 6, - 0, - 2, - 4, - 6, - 8, - 10, - 0, - 2, - 1, - 0, - 5, - 6, - 2, - 0, - 2, - 2, - 7, - 7, - 144, - 0, - 12, - 1, - 0, - 0, - 0, - 2, - 69, - 1, - 0, - 0, - 0, - 4, - 87, - 1, - 0, - 0, - 0, - 6, - 89, - 1, - 0, - 0, - 0, - 8, - 101, - 1, - 0, - 0, - 0, - 10, - 120, - 1, - 0, - 0, - 0, - 12, - 13, - 3, - 2, - 1, - 0, - 13, - 14, - 5, - 0, - 0, - 1, - 14, - 1, - 1, - 0, - 0, - 0, - 15, - 16, - 6, - 1, - -1, - 0, - 16, - 70, - 3, - 4, - 2, - 0, - 17, - 18, - 5, - 14, - 0, - 0, - 18, - 19, - 5, - 1, - 0, - 0, - 19, - 70, - 5, - 14, - 0, - 0, - 20, - 21, - 5, - 2, - 0, - 0, - 21, - 70, - 3, - 2, - 1, - 10, - 22, - 23, - 5, - 3, - 0, - 0, - 23, - 24, - 3, - 2, - 1, - 0, - 24, - 25, - 5, - 4, - 0, - 0, - 25, - 70, - 1, - 0, - 0, - 0, - 26, - 27, - 5, - 14, - 0, - 0, - 27, - 28, - 5, - 3, - 0, - 0, - 28, - 29, - 3, - 2, - 1, - 0, - 29, - 30, - 5, - 4, - 0, - 0, - 30, - 70, - 1, - 0, - 0, - 0, - 31, - 32, - 5, - 14, - 0, - 0, - 32, - 33, - 5, - 8, - 0, - 0, - 33, - 38, - 3, - 6, - 3, - 0, - 34, - 35, - 5, - 9, - 0, - 0, - 35, - 37, - 3, - 6, - 3, - 0, - 36, - 34, - 1, - 0, - 0, - 0, - 37, - 40, - 1, - 0, - 0, - 0, - 38, - 36, - 1, - 0, - 0, - 0, - 38, - 39, - 1, - 0, - 0, - 0, - 39, - 41, - 1, - 0, - 0, - 0, - 40, - 38, - 1, - 0, - 0, - 0, - 41, - 42, - 5, - 10, - 0, - 0, - 42, - 70, - 1, - 0, - 0, - 0, - 43, - 44, - 5, - 14, - 0, - 0, - 44, - 45, - 5, - 8, - 0, - 0, - 45, - 50, - 3, - 2, - 1, - 0, - 46, - 47, - 5, - 9, - 0, - 0, - 47, - 49, - 3, - 2, - 1, - 0, - 48, - 46, - 1, - 0, - 0, - 0, - 49, - 52, - 1, - 0, - 0, - 0, - 50, - 48, - 1, - 0, - 0, - 0, - 50, - 51, - 1, - 0, - 0, - 0, - 51, - 53, - 1, - 0, - 0, - 0, - 52, - 50, - 1, - 0, - 0, - 0, - 53, - 54, - 5, - 10, - 0, - 0, - 54, - 70, - 1, - 0, - 0, - 0, - 55, - 56, - 5, - 14, - 0, - 0, - 56, - 57, - 5, - 8, - 0, - 0, - 57, - 58, - 3, - 6, - 3, - 0, - 58, - 59, - 5, - 11, - 0, - 0, - 59, - 60, - 3, - 6, - 3, - 0, - 60, - 61, - 5, - 10, - 0, - 0, - 61, - 70, - 1, - 0, - 0, - 0, - 62, - 63, - 5, - 14, - 0, - 0, - 63, - 64, - 5, - 8, - 0, - 0, - 64, - 65, - 3, - 2, - 1, - 0, - 65, - 66, - 5, - 11, - 0, - 0, - 66, - 67, - 3, - 2, - 1, - 0, - 67, - 68, - 5, - 10, - 0, - 0, - 68, - 70, - 1, - 0, - 0, - 0, - 69, - 15, - 1, - 0, - 0, - 0, - 69, - 17, - 1, - 0, - 0, - 0, - 69, - 20, - 1, - 0, - 0, - 0, - 69, - 22, - 1, - 0, - 0, - 0, - 69, - 26, - 1, - 0, - 0, - 0, - 69, - 31, - 1, - 0, - 0, - 0, - 69, - 43, - 1, - 0, - 0, - 0, - 69, - 55, - 1, - 0, - 0, - 0, - 69, - 62, - 1, - 0, - 0, - 0, - 70, - 82, - 1, - 0, - 0, - 0, - 71, - 72, - 10, - 8, - 0, - 0, - 72, - 73, - 7, - 0, - 0, - 0, - 73, - 81, - 3, - 2, - 1, - 9, - 74, - 75, - 10, - 7, - 0, - 0, - 75, - 76, - 7, - 1, - 0, - 0, - 76, - 81, - 3, - 2, - 1, - 8, - 77, - 78, - 10, - 6, - 0, - 0, - 78, - 79, - 5, - 15, - 0, - 0, - 79, - 81, - 3, - 2, - 1, - 7, - 80, - 71, - 1, - 0, - 0, - 0, - 80, - 74, - 1, - 0, - 0, - 0, - 80, - 77, - 1, - 0, - 0, - 0, - 81, - 84, - 1, - 0, - 0, - 0, - 82, - 80, - 1, - 0, - 0, - 0, - 82, - 83, - 1, - 0, - 0, - 0, - 83, - 3, - 1, - 0, - 0, - 0, - 84, - 82, - 1, - 0, - 0, - 0, - 85, - 88, - 5, - 12, - 0, - 0, - 86, - 88, - 5, - 14, - 0, - 0, - 87, - 85, - 1, - 0, - 0, - 0, - 87, - 86, - 1, - 0, - 0, - 0, - 88, - 5, - 1, - 0, - 0, - 0, - 89, - 91, - 5, - 13, - 0, - 0, - 90, - 92, - 3, - 8, - 4, - 0, - 91, - 90, - 1, - 0, - 0, - 0, - 91, - 92, - 1, - 0, - 0, - 0, - 92, - 7, - 1, - 0, - 0, - 0, - 93, - 94, - 6, - 4, - -1, - 0, - 94, - 95, - 7, - 1, - 0, - 0, - 95, - 102, - 3, - 4, - 2, - 0, - 96, - 97, - 7, - 1, - 0, - 0, - 97, - 98, - 5, - 3, - 0, - 0, - 98, - 99, - 3, - 2, - 1, - 0, - 99, - 100, - 5, - 4, - 0, - 0, - 100, - 102, - 1, - 0, - 0, - 0, - 101, - 93, - 1, - 0, - 0, - 0, - 101, - 96, - 1, - 0, - 0, - 0, - 102, - 111, - 1, - 0, - 0, - 0, - 103, - 104, - 10, - 4, - 0, - 0, - 104, - 105, - 7, - 0, - 0, - 0, - 105, - 110, - 3, - 10, - 5, - 0, - 106, - 107, - 10, - 3, - 0, - 0, - 107, - 108, - 7, - 1, - 0, - 0, - 108, - 110, - 3, - 10, - 5, - 0, - 109, - 103, - 1, - 0, - 0, - 0, - 109, - 106, - 1, - 0, - 0, - 0, - 110, - 113, - 1, - 0, - 0, - 0, - 111, - 109, - 1, - 0, - 0, - 0, - 111, - 112, - 1, - 0, - 0, - 0, - 112, - 9, - 1, - 0, - 0, - 0, - 113, - 111, - 1, - 0, - 0, - 0, - 114, - 115, - 6, - 5, - -1, - 0, - 115, - 116, - 5, - 3, - 0, - 0, - 116, - 117, - 3, - 2, - 1, - 0, - 117, - 118, - 5, - 4, - 0, - 0, - 118, - 121, - 1, - 0, - 0, - 0, - 119, - 121, - 3, - 4, - 2, - 0, - 120, - 114, - 1, - 0, - 0, - 0, - 120, - 119, - 1, - 0, - 0, - 0, - 121, - 127, - 1, - 0, - 0, - 0, - 122, - 123, - 10, - 3, - 0, - 0, - 123, - 124, - 7, - 0, - 0, - 0, - 124, - 126, - 3, - 10, - 5, - 4, - 125, - 122, - 1, - 0, - 0, - 0, - 126, - 129, - 1, - 0, - 0, - 0, - 127, - 125, - 1, - 0, - 0, - 0, - 127, - 128, - 1, - 0, - 0, - 0, - 128, - 11, - 1, - 0, - 0, - 0, - 129, - 127, - 1, - 0, - 0, - 0, - 12, - 38, - 50, - 69, - 80, - 82, - 87, - 91, - 101, - 109, - 111, - 120, - 127, + 4,1,17,129,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5,2,6,7, + 6,2,7,7,7,1,0,1,0,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,1,1,1,1,1,1,1,1,1,1,3,1,52,8,1,1,1,1,1,3,1,56,8,1,1,1,1,1,3,1, + 60,8,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,5,1,71,8,1,10,1,12,1, + 74,9,1,1,2,1,2,3,2,78,8,2,1,3,1,3,1,3,1,3,1,4,1,4,1,4,1,4,1,5,1, + 5,3,5,90,8,5,1,6,1,6,1,6,1,6,1,6,1,6,1,6,1,6,3,6,100,8,6,1,6,1,6, + 1,6,1,6,1,6,1,6,5,6,108,8,6,10,6,12,6,111,9,6,1,7,1,7,1,7,1,7,1, + 7,1,7,3,7,119,8,7,1,7,1,7,1,7,5,7,124,8,7,10,7,12,7,127,9,7,1,7, + 0,3,2,12,14,8,0,2,4,6,8,10,12,14,0,2,1,0,5,6,2,0,2,2,7,7,141,0,16, + 1,0,0,0,2,59,1,0,0,0,4,77,1,0,0,0,6,79,1,0,0,0,8,83,1,0,0,0,10,87, + 1,0,0,0,12,99,1,0,0,0,14,118,1,0,0,0,16,17,3,2,1,0,17,18,5,0,0,1, + 18,1,1,0,0,0,19,20,6,1,-1,0,20,60,3,4,2,0,21,22,5,14,0,0,22,23,5, + 1,0,0,23,60,5,14,0,0,24,25,5,2,0,0,25,60,3,2,1,9,26,27,5,3,0,0,27, + 28,3,2,1,0,28,29,5,4,0,0,29,60,1,0,0,0,30,31,5,14,0,0,31,32,5,3, + 0,0,32,33,3,2,1,0,33,34,5,4,0,0,34,60,1,0,0,0,35,36,5,14,0,0,36, + 37,5,8,0,0,37,38,3,10,5,0,38,39,5,9,0,0,39,60,1,0,0,0,40,41,5,14, + 0,0,41,42,5,8,0,0,42,43,3,2,1,0,43,44,5,9,0,0,44,60,1,0,0,0,45,46, + 5,16,0,0,46,55,5,3,0,0,47,52,3,2,1,0,48,52,3,10,5,0,49,52,3,6,3, + 0,50,52,3,8,4,0,51,47,1,0,0,0,51,48,1,0,0,0,51,49,1,0,0,0,51,50, + 1,0,0,0,52,53,1,0,0,0,53,54,5,10,0,0,54,56,1,0,0,0,55,51,1,0,0,0, + 55,56,1,0,0,0,56,57,1,0,0,0,57,58,5,14,0,0,58,60,5,4,0,0,59,19,1, + 0,0,0,59,21,1,0,0,0,59,24,1,0,0,0,59,26,1,0,0,0,59,30,1,0,0,0,59, + 35,1,0,0,0,59,40,1,0,0,0,59,45,1,0,0,0,60,72,1,0,0,0,61,62,10,7, + 0,0,62,63,7,0,0,0,63,71,3,2,1,8,64,65,10,6,0,0,65,66,7,1,0,0,66, + 71,3,2,1,7,67,68,10,5,0,0,68,69,5,15,0,0,69,71,3,2,1,6,70,61,1,0, + 0,0,70,64,1,0,0,0,70,67,1,0,0,0,71,74,1,0,0,0,72,70,1,0,0,0,72,73, + 1,0,0,0,73,3,1,0,0,0,74,72,1,0,0,0,75,78,5,12,0,0,76,78,5,14,0,0, + 77,75,1,0,0,0,77,76,1,0,0,0,78,5,1,0,0,0,79,80,3,10,5,0,80,81,5, + 11,0,0,81,82,3,10,5,0,82,7,1,0,0,0,83,84,3,2,1,0,84,85,5,11,0,0, + 85,86,3,2,1,0,86,9,1,0,0,0,87,89,5,13,0,0,88,90,3,12,6,0,89,88,1, + 0,0,0,89,90,1,0,0,0,90,11,1,0,0,0,91,92,6,6,-1,0,92,93,7,1,0,0,93, + 100,3,4,2,0,94,95,7,1,0,0,95,96,5,3,0,0,96,97,3,2,1,0,97,98,5,4, + 0,0,98,100,1,0,0,0,99,91,1,0,0,0,99,94,1,0,0,0,100,109,1,0,0,0,101, + 102,10,4,0,0,102,103,7,0,0,0,103,108,3,14,7,0,104,105,10,3,0,0,105, + 106,7,1,0,0,106,108,3,14,7,0,107,101,1,0,0,0,107,104,1,0,0,0,108, + 111,1,0,0,0,109,107,1,0,0,0,109,110,1,0,0,0,110,13,1,0,0,0,111,109, + 1,0,0,0,112,113,6,7,-1,0,113,114,5,3,0,0,114,115,3,2,1,0,115,116, + 5,4,0,0,116,119,1,0,0,0,117,119,3,4,2,0,118,112,1,0,0,0,118,117, + 1,0,0,0,119,125,1,0,0,0,120,121,10,3,0,0,121,122,7,0,0,0,122,124, + 3,14,7,4,123,120,1,0,0,0,124,127,1,0,0,0,125,123,1,0,0,0,125,126, + 1,0,0,0,126,15,1,0,0,0,127,125,1,0,0,0,12,51,55,59,70,72,77,89,99, + 107,109,118,125 ] +class ExprParser ( Parser ): -class ExprParser(Parser): grammarFileName = "Expr.g4" atn = ATNDeserializer().deserialize(serializedATN()) - decisionsToDFA = [DFA(ds, i) for i, ds in enumerate(atn.decisionToState)] + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] sharedContextCache = PredictionContextCache() - literalNames = [ - "", - "'.'", - "'-'", - "'('", - "')'", - "'/'", - "'*'", - "'+'", - "'['", - "','", - "']'", - "'..'", - "", - "'t'", - ] + literalNames = [ "", "'.'", "'-'", "'('", "')'", "'/'", "'*'", + "'+'", "'['", "']'", "','", "'..'", "", "'t'", + "", "", "'sum'" ] - symbolicNames = [ - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "NUMBER", - "TIME", - "IDENTIFIER", - "COMPARISON", - "WS", - ] + symbolicNames = [ "", "", "", "", + "", "", "", "", + "", "", "", "", + "NUMBER", "TIME", "IDENTIFIER", "COMPARISON", "TIME_SUM", + "WS" ] RULE_fullexpr = 0 RULE_expr = 1 RULE_atom = 2 - RULE_shift = 3 - RULE_shift_expr = 4 - RULE_right_expr = 5 + RULE_timeShiftRange = 3 + RULE_timeRange = 4 + RULE_shift = 5 + RULE_shift_expr = 6 + RULE_right_expr = 7 - ruleNames = ["fullexpr", "expr", "atom", "shift", "shift_expr", "right_expr"] + ruleNames = [ "fullexpr", "expr", "atom", "timeShiftRange", "timeRange", + "shift", "shift_expr", "right_expr" ] EOF = Token.EOF - T__0 = 1 - T__1 = 2 - T__2 = 3 - T__3 = 4 - T__4 = 5 - T__5 = 6 - T__6 = 7 - T__7 = 8 - T__8 = 9 - T__9 = 10 - T__10 = 11 - NUMBER = 12 - TIME = 13 - IDENTIFIER = 14 - COMPARISON = 15 - WS = 16 - - def __init__(self, input: TokenStream, output: TextIO = sys.stdout): + T__0=1 + T__1=2 + T__2=3 + T__3=4 + T__4=5 + T__5=6 + T__6=7 + T__7=8 + T__8=9 + T__9=10 + T__10=11 + NUMBER=12 + TIME=13 + IDENTIFIER=14 + COMPARISON=15 + TIME_SUM=16 + WS=17 + + def __init__(self, input:TokenStream, output:TextIO = sys.stdout): super().__init__(input, output) - self.checkVersion("4.13.1") - self._interp = ParserATNSimulator( - self, self.atn, self.decisionsToDFA, self.sharedContextCache - ) + self.checkVersion("4.13.2") + self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) self._predicates = None + + + class FullexprContext(ParserRuleContext): - __slots__ = "parser" + __slots__ = 'parser' - def __init__( - self, parser, parent: ParserRuleContext = None, invokingState: int = -1 - ): + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser def expr(self): - return self.getTypedRuleContext(ExprParser.ExprContext, 0) + return self.getTypedRuleContext(ExprParser.ExprContext,0) + def EOF(self): return self.getToken(ExprParser.EOF, 0) @@ -1298,20 +132,24 @@ def EOF(self): def getRuleIndex(self): return ExprParser.RULE_fullexpr - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitFullexpr"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitFullexpr" ): return visitor.visitFullexpr(self) else: return visitor.visitChildren(self) + + + def fullexpr(self): + localctx = ExprParser.FullexprContext(self, self._ctx, self.state) self.enterRule(localctx, 0, self.RULE_fullexpr) try: self.enterOuterAlt(localctx, 1) - self.state = 12 + self.state = 16 self.expr(0) - self.state = 13 + self.state = 17 self.match(ExprParser.EOF) except RecognitionException as re: localctx.exception = re @@ -1321,278 +159,264 @@ def fullexpr(self): self.exitRule() return localctx + class ExprContext(ParserRuleContext): - __slots__ = "parser" + __slots__ = 'parser' - def __init__( - self, parser, parent: ParserRuleContext = None, invokingState: int = -1 - ): + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser + def getRuleIndex(self): return ExprParser.RULE_expr - def copyFrom(self, ctx: ParserRuleContext): + + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) + + class TimeSumContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def TIME_SUM(self): + return self.getToken(ExprParser.TIME_SUM, 0) + def IDENTIFIER(self): + return self.getToken(ExprParser.IDENTIFIER, 0) + def expr(self): + return self.getTypedRuleContext(ExprParser.ExprContext,0) + + def shift(self): + return self.getTypedRuleContext(ExprParser.ShiftContext,0) + + def timeShiftRange(self): + return self.getTypedRuleContext(ExprParser.TimeShiftRangeContext,0) + + def timeRange(self): + return self.getTypedRuleContext(ExprParser.TimeRangeContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTimeSum" ): + return visitor.visitTimeSum(self) + else: + return visitor.visitChildren(self) + + class NegationContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) def expr(self): - return self.getTypedRuleContext(ExprParser.ExprContext, 0) + return self.getTypedRuleContext(ExprParser.ExprContext,0) + - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitNegation"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitNegation" ): return visitor.visitNegation(self) else: return visitor.visitChildren(self) + class UnsignedAtomContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) def atom(self): - return self.getTypedRuleContext(ExprParser.AtomContext, 0) + return self.getTypedRuleContext(ExprParser.AtomContext,0) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitUnsignedAtom"): + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitUnsignedAtom" ): return visitor.visitUnsignedAtom(self) else: return visitor.visitChildren(self) + class ExpressionContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) def expr(self): - return self.getTypedRuleContext(ExprParser.ExprContext, 0) + return self.getTypedRuleContext(ExprParser.ExprContext,0) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitExpression"): + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitExpression" ): return visitor.visitExpression(self) else: return visitor.visitChildren(self) + class TimeIndexContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) def IDENTIFIER(self): return self.getToken(ExprParser.IDENTIFIER, 0) + def expr(self): + return self.getTypedRuleContext(ExprParser.ExprContext,0) - def expr(self, i: int = None): - if i is None: - return self.getTypedRuleContexts(ExprParser.ExprContext) - else: - return self.getTypedRuleContext(ExprParser.ExprContext, i) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitTimeIndex"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTimeIndex" ): return visitor.visitTimeIndex(self) else: return visitor.visitChildren(self) + class ComparisonContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) - def expr(self, i: int = None): + def expr(self, i:int=None): if i is None: return self.getTypedRuleContexts(ExprParser.ExprContext) else: - return self.getTypedRuleContext(ExprParser.ExprContext, i) + return self.getTypedRuleContext(ExprParser.ExprContext,i) def COMPARISON(self): return self.getToken(ExprParser.COMPARISON, 0) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitComparison"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitComparison" ): return visitor.visitComparison(self) else: return visitor.visitChildren(self) + class TimeShiftContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) def IDENTIFIER(self): return self.getToken(ExprParser.IDENTIFIER, 0) + def shift(self): + return self.getTypedRuleContext(ExprParser.ShiftContext,0) - def shift(self, i: int = None): - if i is None: - return self.getTypedRuleContexts(ExprParser.ShiftContext) - else: - return self.getTypedRuleContext(ExprParser.ShiftContext, i) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitTimeShift"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTimeShift" ): return visitor.visitTimeShift(self) else: return visitor.visitChildren(self) + class FunctionContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) def IDENTIFIER(self): return self.getToken(ExprParser.IDENTIFIER, 0) - def expr(self): - return self.getTypedRuleContext(ExprParser.ExprContext, 0) + return self.getTypedRuleContext(ExprParser.ExprContext,0) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitFunction"): + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitFunction" ): return visitor.visitFunction(self) else: return visitor.visitChildren(self) + class AddsubContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) - self.op = None # Token + self.op = None # Token self.copyFrom(ctx) - def expr(self, i: int = None): + def expr(self, i:int=None): if i is None: return self.getTypedRuleContexts(ExprParser.ExprContext) else: - return self.getTypedRuleContext(ExprParser.ExprContext, i) + return self.getTypedRuleContext(ExprParser.ExprContext,i) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitAddsub"): + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitAddsub" ): return visitor.visitAddsub(self) else: return visitor.visitChildren(self) - class TimeShiftRangeContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext - super().__init__(parser) - self.shift1 = None # ShiftContext - self.shift2 = None # ShiftContext - self.copyFrom(ctx) - - def IDENTIFIER(self): - return self.getToken(ExprParser.IDENTIFIER, 0) - - def shift(self, i: int = None): - if i is None: - return self.getTypedRuleContexts(ExprParser.ShiftContext) - else: - return self.getTypedRuleContext(ExprParser.ShiftContext, i) - - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitTimeShiftRange"): - return visitor.visitTimeShiftRange(self) - else: - return visitor.visitChildren(self) class PortFieldContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) - def IDENTIFIER(self, i: int = None): + def IDENTIFIER(self, i:int=None): if i is None: return self.getTokens(ExprParser.IDENTIFIER) else: return self.getToken(ExprParser.IDENTIFIER, i) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitPortField"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitPortField" ): return visitor.visitPortField(self) else: return visitor.visitChildren(self) + class MuldivContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) - self.op = None # Token + self.op = None # Token self.copyFrom(ctx) - def expr(self, i: int = None): + def expr(self, i:int=None): if i is None: return self.getTypedRuleContexts(ExprParser.ExprContext) else: - return self.getTypedRuleContext(ExprParser.ExprContext, i) + return self.getTypedRuleContext(ExprParser.ExprContext,i) + - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitMuldiv"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitMuldiv" ): return visitor.visitMuldiv(self) else: return visitor.visitChildren(self) - class TimeRangeContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def IDENTIFIER(self): - return self.getToken(ExprParser.IDENTIFIER, 0) - - def expr(self, i: int = None): - if i is None: - return self.getTypedRuleContexts(ExprParser.ExprContext) - else: - return self.getTypedRuleContext(ExprParser.ExprContext, i) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitTimeRange"): - return visitor.visitTimeRange(self) - else: - return visitor.visitChildren(self) - def expr(self, _p: int = 0): + def expr(self, _p:int=0): _parentctx = self._ctx _parentState = self.state localctx = ExprParser.ExprContext(self, self._ctx, _parentState) _prevctx = localctx _startState = 2 self.enterRecursionRule(localctx, 2, self.RULE_expr, _p) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 69 + self.state = 59 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input, 2, self._ctx) + la_ = self._interp.adaptivePredict(self._input,2,self._ctx) if la_ == 1: localctx = ExprParser.UnsignedAtomContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 16 + self.state = 20 self.atom() pass @@ -1600,11 +424,11 @@ def expr(self, _p: int = 0): localctx = ExprParser.PortFieldContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 17 + self.state = 21 self.match(ExprParser.IDENTIFIER) - self.state = 18 + self.state = 22 self.match(ExprParser.T__0) - self.state = 19 + self.state = 23 self.match(ExprParser.IDENTIFIER) pass @@ -1612,21 +436,21 @@ def expr(self, _p: int = 0): localctx = ExprParser.NegationContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 20 + self.state = 24 self.match(ExprParser.T__1) - self.state = 21 - self.expr(10) + self.state = 25 + self.expr(9) pass elif la_ == 4: localctx = ExprParser.ExpressionContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 22 + self.state = 26 self.match(ExprParser.T__2) - self.state = 23 + self.state = 27 self.expr(0) - self.state = 24 + self.state = 28 self.match(ExprParser.T__3) pass @@ -1634,13 +458,13 @@ def expr(self, _p: int = 0): localctx = ExprParser.FunctionContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 26 + self.state = 30 self.match(ExprParser.IDENTIFIER) - self.state = 27 + self.state = 31 self.match(ExprParser.T__2) - self.state = 28 + self.state = 32 self.expr(0) - self.state = 29 + self.state = 33 self.match(ExprParser.T__3) pass @@ -1648,177 +472,144 @@ def expr(self, _p: int = 0): localctx = ExprParser.TimeShiftContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 31 + self.state = 35 self.match(ExprParser.IDENTIFIER) - self.state = 32 + self.state = 36 self.match(ExprParser.T__7) - self.state = 33 + self.state = 37 self.shift() self.state = 38 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la == 9: - self.state = 34 - self.match(ExprParser.T__8) - self.state = 35 - self.shift() - self.state = 40 - self._errHandler.sync(self) - _la = self._input.LA(1) - - self.state = 41 - self.match(ExprParser.T__9) + self.match(ExprParser.T__8) pass elif la_ == 7: localctx = ExprParser.TimeIndexContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 43 + self.state = 40 self.match(ExprParser.IDENTIFIER) - self.state = 44 + self.state = 41 self.match(ExprParser.T__7) - self.state = 45 + self.state = 42 self.expr(0) - self.state = 50 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la == 9: - self.state = 46 - self.match(ExprParser.T__8) - self.state = 47 - self.expr(0) - self.state = 52 - self._errHandler.sync(self) - _la = self._input.LA(1) - - self.state = 53 - self.match(ExprParser.T__9) + self.state = 43 + self.match(ExprParser.T__8) pass elif la_ == 8: - localctx = ExprParser.TimeShiftRangeContext(self, localctx) + localctx = ExprParser.TimeSumContext(self, localctx) self._ctx = localctx _prevctx = localctx + self.state = 45 + self.match(ExprParser.TIME_SUM) + self.state = 46 + self.match(ExprParser.T__2) self.state = 55 - self.match(ExprParser.IDENTIFIER) - self.state = 56 - self.match(ExprParser.T__7) + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,1,self._ctx) + if la_ == 1: + self.state = 51 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,0,self._ctx) + if la_ == 1: + self.state = 47 + self.expr(0) + pass + + elif la_ == 2: + self.state = 48 + self.shift() + pass + + elif la_ == 3: + self.state = 49 + self.timeShiftRange() + pass + + elif la_ == 4: + self.state = 50 + self.timeRange() + pass + + + self.state = 53 + self.match(ExprParser.T__9) + + self.state = 57 - localctx.shift1 = self.shift() + self.match(ExprParser.IDENTIFIER) self.state = 58 - self.match(ExprParser.T__10) - self.state = 59 - localctx.shift2 = self.shift() - self.state = 60 - self.match(ExprParser.T__9) + self.match(ExprParser.T__3) pass - elif la_ == 9: - localctx = ExprParser.TimeRangeContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 62 - self.match(ExprParser.IDENTIFIER) - self.state = 63 - self.match(ExprParser.T__7) - self.state = 64 - self.expr(0) - self.state = 65 - self.match(ExprParser.T__10) - self.state = 66 - self.expr(0) - self.state = 67 - self.match(ExprParser.T__9) - pass self._ctx.stop = self._input.LT(-1) - self.state = 82 + self.state = 72 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 4, self._ctx) - while _alt != 2 and _alt != ATN.INVALID_ALT_NUMBER: - if _alt == 1: + _alt = self._interp.adaptivePredict(self._input,4,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: if self._parseListeners is not None: self.triggerExitRuleEvent() _prevctx = localctx - self.state = 80 + self.state = 70 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input, 3, self._ctx) + la_ = self._interp.adaptivePredict(self._input,3,self._ctx) if la_ == 1: - localctx = ExprParser.MuldivContext( - self, ExprParser.ExprContext(self, _parentctx, _parentState) - ) - self.pushNewRecursionContext( - localctx, _startState, self.RULE_expr - ) - self.state = 71 - if not self.precpred(self._ctx, 8): + localctx = ExprParser.MuldivContext(self, ExprParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 61 + if not self.precpred(self._ctx, 7): from antlr4.error.Errors import FailedPredicateException - - raise FailedPredicateException( - self, "self.precpred(self._ctx, 8)" - ) - self.state = 72 + raise FailedPredicateException(self, "self.precpred(self._ctx, 7)") + self.state = 62 localctx.op = self._input.LT(1) _la = self._input.LA(1) - if not (_la == 5 or _la == 6): + if not(_la==5 or _la==6): localctx.op = self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) self.consume() - self.state = 73 - self.expr(9) + self.state = 63 + self.expr(8) pass elif la_ == 2: - localctx = ExprParser.AddsubContext( - self, ExprParser.ExprContext(self, _parentctx, _parentState) - ) - self.pushNewRecursionContext( - localctx, _startState, self.RULE_expr - ) - self.state = 74 - if not self.precpred(self._ctx, 7): + localctx = ExprParser.AddsubContext(self, ExprParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 64 + if not self.precpred(self._ctx, 6): from antlr4.error.Errors import FailedPredicateException - - raise FailedPredicateException( - self, "self.precpred(self._ctx, 7)" - ) - self.state = 75 + raise FailedPredicateException(self, "self.precpred(self._ctx, 6)") + self.state = 65 localctx.op = self._input.LT(1) _la = self._input.LA(1) - if not (_la == 2 or _la == 7): + if not(_la==2 or _la==7): localctx.op = self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) self.consume() - self.state = 76 - self.expr(8) + self.state = 66 + self.expr(7) pass elif la_ == 3: - localctx = ExprParser.ComparisonContext( - self, ExprParser.ExprContext(self, _parentctx, _parentState) - ) - self.pushNewRecursionContext( - localctx, _startState, self.RULE_expr - ) - self.state = 77 - if not self.precpred(self._ctx, 6): + localctx = ExprParser.ComparisonContext(self, ExprParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 67 + if not self.precpred(self._ctx, 5): from antlr4.error.Errors import FailedPredicateException - - raise FailedPredicateException( - self, "self.precpred(self._ctx, 6)" - ) - self.state = 78 + raise FailedPredicateException(self, "self.precpred(self._ctx, 5)") + self.state = 68 self.match(ExprParser.COMPARISON) - self.state = 79 - self.expr(7) + self.state = 69 + self.expr(6) pass - self.state = 84 + + self.state = 74 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 4, self._ctx) + _alt = self._interp.adaptivePredict(self._input,4,self._ctx) except RecognitionException as re: localctx.exception = re @@ -1828,70 +619,75 @@ def expr(self, _p: int = 0): self.unrollRecursionContexts(_parentctx) return localctx + class AtomContext(ParserRuleContext): - __slots__ = "parser" + __slots__ = 'parser' - def __init__( - self, parser, parent: ParserRuleContext = None, invokingState: int = -1 - ): + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser + def getRuleIndex(self): return ExprParser.RULE_atom - def copyFrom(self, ctx: ParserRuleContext): + + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) + + class NumberContext(AtomContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.AtomContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.AtomContext super().__init__(parser) self.copyFrom(ctx) def NUMBER(self): return self.getToken(ExprParser.NUMBER, 0) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitNumber"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitNumber" ): return visitor.visitNumber(self) else: return visitor.visitChildren(self) + class IdentifierContext(AtomContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.AtomContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.AtomContext super().__init__(parser) self.copyFrom(ctx) def IDENTIFIER(self): return self.getToken(ExprParser.IDENTIFIER, 0) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitIdentifier"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitIdentifier" ): return visitor.visitIdentifier(self) else: return visitor.visitChildren(self) + + def atom(self): + localctx = ExprParser.AtomContext(self, self._ctx, self.state) self.enterRule(localctx, 4, self.RULE_atom) try: - self.state = 87 + self.state = 77 self._errHandler.sync(self) token = self._input.LA(1) if token in [12]: localctx = ExprParser.NumberContext(self, localctx) self.enterOuterAlt(localctx, 1) - self.state = 85 + self.state = 75 self.match(ExprParser.NUMBER) pass elif token in [14]: localctx = ExprParser.IdentifierContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 86 + self.state = 76 self.match(ExprParser.IDENTIFIER) pass else: @@ -1905,12 +701,109 @@ def atom(self): self.exitRule() return localctx + + class TimeShiftRangeContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.shift1 = None # ShiftContext + self.shift2 = None # ShiftContext + + def shift(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(ExprParser.ShiftContext) + else: + return self.getTypedRuleContext(ExprParser.ShiftContext,i) + + + def getRuleIndex(self): + return ExprParser.RULE_timeShiftRange + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTimeShiftRange" ): + return visitor.visitTimeShiftRange(self) + else: + return visitor.visitChildren(self) + + + + + def timeShiftRange(self): + + localctx = ExprParser.TimeShiftRangeContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_timeShiftRange) + try: + self.enterOuterAlt(localctx, 1) + self.state = 79 + localctx.shift1 = self.shift() + self.state = 80 + self.match(ExprParser.T__10) + self.state = 81 + localctx.shift2 = self.shift() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class TimeRangeContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.expr1 = None # ExprContext + self.expr2 = None # ExprContext + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(ExprParser.ExprContext) + else: + return self.getTypedRuleContext(ExprParser.ExprContext,i) + + + def getRuleIndex(self): + return ExprParser.RULE_timeRange + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTimeRange" ): + return visitor.visitTimeRange(self) + else: + return visitor.visitChildren(self) + + + + + def timeRange(self): + + localctx = ExprParser.TimeRangeContext(self, self._ctx, self.state) + self.enterRule(localctx, 8, self.RULE_timeRange) + try: + self.enterOuterAlt(localctx, 1) + self.state = 83 + localctx.expr1 = self.expr(0) + self.state = 84 + self.match(ExprParser.T__10) + self.state = 85 + localctx.expr2 = self.expr(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + class ShiftContext(ParserRuleContext): - __slots__ = "parser" + __slots__ = 'parser' - def __init__( - self, parser, parent: ParserRuleContext = None, invokingState: int = -1 - ): + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser @@ -1918,32 +811,38 @@ def TIME(self): return self.getToken(ExprParser.TIME, 0) def shift_expr(self): - return self.getTypedRuleContext(ExprParser.Shift_exprContext, 0) + return self.getTypedRuleContext(ExprParser.Shift_exprContext,0) + def getRuleIndex(self): return ExprParser.RULE_shift - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitShift"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitShift" ): return visitor.visitShift(self) else: return visitor.visitChildren(self) + + + def shift(self): + localctx = ExprParser.ShiftContext(self, self._ctx, self.state) - self.enterRule(localctx, 6, self.RULE_shift) - self._la = 0 # Token type + self.enterRule(localctx, 10, self.RULE_shift) + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 89 + self.state = 87 self.match(ExprParser.TIME) - self.state = 91 + self.state = 89 self._errHandler.sync(self) _la = self._input.LA(1) - if _la == 2 or _la == 7: - self.state = 90 + if _la==2 or _la==7: + self.state = 88 self.shift_expr(0) + except RecognitionException as re: localctx.exception = re self._errHandler.reportError(self, re) @@ -1952,122 +851,129 @@ def shift(self): self.exitRule() return localctx + class Shift_exprContext(ParserRuleContext): - __slots__ = "parser" + __slots__ = 'parser' - def __init__( - self, parser, parent: ParserRuleContext = None, invokingState: int = -1 - ): + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser + def getRuleIndex(self): return ExprParser.RULE_shift_expr - def copyFrom(self, ctx: ParserRuleContext): + + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) + class SignedAtomContext(Shift_exprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.Shift_exprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.Shift_exprContext super().__init__(parser) - self.op = None # Token + self.op = None # Token self.copyFrom(ctx) def atom(self): - return self.getTypedRuleContext(ExprParser.AtomContext, 0) + return self.getTypedRuleContext(ExprParser.AtomContext,0) + - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitSignedAtom"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitSignedAtom" ): return visitor.visitSignedAtom(self) else: return visitor.visitChildren(self) + class SignedExpressionContext(Shift_exprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.Shift_exprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.Shift_exprContext super().__init__(parser) - self.op = None # Token + self.op = None # Token self.copyFrom(ctx) def expr(self): - return self.getTypedRuleContext(ExprParser.ExprContext, 0) + return self.getTypedRuleContext(ExprParser.ExprContext,0) + - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitSignedExpression"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitSignedExpression" ): return visitor.visitSignedExpression(self) else: return visitor.visitChildren(self) + class ShiftMuldivContext(Shift_exprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.Shift_exprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.Shift_exprContext super().__init__(parser) - self.op = None # Token + self.op = None # Token self.copyFrom(ctx) def shift_expr(self): - return self.getTypedRuleContext(ExprParser.Shift_exprContext, 0) + return self.getTypedRuleContext(ExprParser.Shift_exprContext,0) def right_expr(self): - return self.getTypedRuleContext(ExprParser.Right_exprContext, 0) + return self.getTypedRuleContext(ExprParser.Right_exprContext,0) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitShiftMuldiv"): + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitShiftMuldiv" ): return visitor.visitShiftMuldiv(self) else: return visitor.visitChildren(self) + class ShiftAddsubContext(Shift_exprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.Shift_exprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.Shift_exprContext super().__init__(parser) - self.op = None # Token + self.op = None # Token self.copyFrom(ctx) def shift_expr(self): - return self.getTypedRuleContext(ExprParser.Shift_exprContext, 0) + return self.getTypedRuleContext(ExprParser.Shift_exprContext,0) def right_expr(self): - return self.getTypedRuleContext(ExprParser.Right_exprContext, 0) + return self.getTypedRuleContext(ExprParser.Right_exprContext,0) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitShiftAddsub"): + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitShiftAddsub" ): return visitor.visitShiftAddsub(self) else: return visitor.visitChildren(self) - def shift_expr(self, _p: int = 0): + + + def shift_expr(self, _p:int=0): _parentctx = self._ctx _parentState = self.state localctx = ExprParser.Shift_exprContext(self, self._ctx, _parentState) _prevctx = localctx - _startState = 8 - self.enterRecursionRule(localctx, 8, self.RULE_shift_expr, _p) - self._la = 0 # Token type + _startState = 12 + self.enterRecursionRule(localctx, 12, self.RULE_shift_expr, _p) + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 101 + self.state = 99 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input, 7, self._ctx) + la_ = self._interp.adaptivePredict(self._input,7,self._ctx) if la_ == 1: localctx = ExprParser.SignedAtomContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 94 + self.state = 92 localctx.op = self._input.LT(1) _la = self._input.LA(1) - if not (_la == 2 or _la == 7): + if not(_la==2 or _la==7): localctx.op = self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) self.consume() - self.state = 95 + self.state = 93 self.atom() pass @@ -2075,95 +981,77 @@ def shift_expr(self, _p: int = 0): localctx = ExprParser.SignedExpressionContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 96 + self.state = 94 localctx.op = self._input.LT(1) _la = self._input.LA(1) - if not (_la == 2 or _la == 7): + if not(_la==2 or _la==7): localctx.op = self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) self.consume() - self.state = 97 + self.state = 95 self.match(ExprParser.T__2) - self.state = 98 + self.state = 96 self.expr(0) - self.state = 99 + self.state = 97 self.match(ExprParser.T__3) pass + self._ctx.stop = self._input.LT(-1) - self.state = 111 + self.state = 109 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 9, self._ctx) - while _alt != 2 and _alt != ATN.INVALID_ALT_NUMBER: - if _alt == 1: + _alt = self._interp.adaptivePredict(self._input,9,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: if self._parseListeners is not None: self.triggerExitRuleEvent() _prevctx = localctx - self.state = 109 + self.state = 107 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input, 8, self._ctx) + la_ = self._interp.adaptivePredict(self._input,8,self._ctx) if la_ == 1: - localctx = ExprParser.ShiftMuldivContext( - self, - ExprParser.Shift_exprContext( - self, _parentctx, _parentState - ), - ) - self.pushNewRecursionContext( - localctx, _startState, self.RULE_shift_expr - ) - self.state = 103 + localctx = ExprParser.ShiftMuldivContext(self, ExprParser.Shift_exprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_shift_expr) + self.state = 101 if not self.precpred(self._ctx, 4): from antlr4.error.Errors import FailedPredicateException - - raise FailedPredicateException( - self, "self.precpred(self._ctx, 4)" - ) - self.state = 104 + raise FailedPredicateException(self, "self.precpred(self._ctx, 4)") + self.state = 102 localctx.op = self._input.LT(1) _la = self._input.LA(1) - if not (_la == 5 or _la == 6): + if not(_la==5 or _la==6): localctx.op = self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) self.consume() - self.state = 105 + self.state = 103 self.right_expr(0) pass elif la_ == 2: - localctx = ExprParser.ShiftAddsubContext( - self, - ExprParser.Shift_exprContext( - self, _parentctx, _parentState - ), - ) - self.pushNewRecursionContext( - localctx, _startState, self.RULE_shift_expr - ) - self.state = 106 + localctx = ExprParser.ShiftAddsubContext(self, ExprParser.Shift_exprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_shift_expr) + self.state = 104 if not self.precpred(self._ctx, 3): from antlr4.error.Errors import FailedPredicateException - - raise FailedPredicateException( - self, "self.precpred(self._ctx, 3)" - ) - self.state = 107 + raise FailedPredicateException(self, "self.precpred(self._ctx, 3)") + self.state = 105 localctx.op = self._input.LT(1) _la = self._input.LA(1) - if not (_la == 2 or _la == 7): + if not(_la==2 or _la==7): localctx.op = self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) self.consume() - self.state = 108 + self.state = 106 self.right_expr(0) pass - self.state = 113 + + self.state = 111 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 9, self._ctx) + _alt = self._interp.adaptivePredict(self._input,9,self._ctx) except RecognitionException as re: localctx.exception = re @@ -2173,84 +1061,90 @@ def shift_expr(self, _p: int = 0): self.unrollRecursionContexts(_parentctx) return localctx + class Right_exprContext(ParserRuleContext): - __slots__ = "parser" + __slots__ = 'parser' - def __init__( - self, parser, parent: ParserRuleContext = None, invokingState: int = -1 - ): + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser + def getRuleIndex(self): return ExprParser.RULE_right_expr - def copyFrom(self, ctx: ParserRuleContext): + + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) + class RightExpressionContext(Right_exprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.Right_exprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.Right_exprContext super().__init__(parser) self.copyFrom(ctx) def expr(self): - return self.getTypedRuleContext(ExprParser.ExprContext, 0) + return self.getTypedRuleContext(ExprParser.ExprContext,0) + - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitRightExpression"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitRightExpression" ): return visitor.visitRightExpression(self) else: return visitor.visitChildren(self) + class RightMuldivContext(Right_exprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.Right_exprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.Right_exprContext super().__init__(parser) - self.op = None # Token + self.op = None # Token self.copyFrom(ctx) - def right_expr(self, i: int = None): + def right_expr(self, i:int=None): if i is None: return self.getTypedRuleContexts(ExprParser.Right_exprContext) else: - return self.getTypedRuleContext(ExprParser.Right_exprContext, i) + return self.getTypedRuleContext(ExprParser.Right_exprContext,i) + - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitRightMuldiv"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitRightMuldiv" ): return visitor.visitRightMuldiv(self) else: return visitor.visitChildren(self) + class RightAtomContext(Right_exprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.Right_exprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.Right_exprContext super().__init__(parser) self.copyFrom(ctx) def atom(self): - return self.getTypedRuleContext(ExprParser.AtomContext, 0) + return self.getTypedRuleContext(ExprParser.AtomContext,0) + - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitRightAtom"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitRightAtom" ): return visitor.visitRightAtom(self) else: return visitor.visitChildren(self) - def right_expr(self, _p: int = 0): + + + def right_expr(self, _p:int=0): _parentctx = self._ctx _parentState = self.state localctx = ExprParser.Right_exprContext(self, self._ctx, _parentState) _prevctx = localctx - _startState = 10 - self.enterRecursionRule(localctx, 10, self.RULE_right_expr, _p) - self._la = 0 # Token type + _startState = 14 + self.enterRecursionRule(localctx, 14, self.RULE_right_expr, _p) + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 120 + self.state = 118 self._errHandler.sync(self) token = self._input.LA(1) if token in [3]: @@ -2258,59 +1152,51 @@ def right_expr(self, _p: int = 0): self._ctx = localctx _prevctx = localctx - self.state = 115 + self.state = 113 self.match(ExprParser.T__2) - self.state = 116 + self.state = 114 self.expr(0) - self.state = 117 + self.state = 115 self.match(ExprParser.T__3) pass elif token in [12, 14]: localctx = ExprParser.RightAtomContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 119 + self.state = 117 self.atom() pass else: raise NoViableAltException(self) self._ctx.stop = self._input.LT(-1) - self.state = 127 + self.state = 125 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 11, self._ctx) - while _alt != 2 and _alt != ATN.INVALID_ALT_NUMBER: - if _alt == 1: + _alt = self._interp.adaptivePredict(self._input,11,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: if self._parseListeners is not None: self.triggerExitRuleEvent() _prevctx = localctx - localctx = ExprParser.RightMuldivContext( - self, - ExprParser.Right_exprContext(self, _parentctx, _parentState), - ) - self.pushNewRecursionContext( - localctx, _startState, self.RULE_right_expr - ) - self.state = 122 + localctx = ExprParser.RightMuldivContext(self, ExprParser.Right_exprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_right_expr) + self.state = 120 if not self.precpred(self._ctx, 3): from antlr4.error.Errors import FailedPredicateException - - raise FailedPredicateException( - self, "self.precpred(self._ctx, 3)" - ) - self.state = 123 + raise FailedPredicateException(self, "self.precpred(self._ctx, 3)") + self.state = 121 localctx.op = self._input.LT(1) _la = self._input.LA(1) - if not (_la == 5 or _la == 6): + if not(_la==5 or _la==6): localctx.op = self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) self.consume() - self.state = 124 - self.right_expr(4) - self.state = 129 + self.state = 122 + self.right_expr(4) + self.state = 127 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 11, self._ctx) + _alt = self._interp.adaptivePredict(self._input,11,self._ctx) except RecognitionException as re: localctx.exception = re @@ -2320,35 +1206,47 @@ def right_expr(self, _p: int = 0): self.unrollRecursionContexts(_parentctx) return localctx - def sempred(self, localctx: RuleContext, ruleIndex: int, predIndex: int): + + + def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): if self._predicates == None: self._predicates = dict() self._predicates[1] = self.expr_sempred - self._predicates[4] = self.shift_expr_sempred - self._predicates[5] = self.right_expr_sempred + self._predicates[6] = self.shift_expr_sempred + self._predicates[7] = self.right_expr_sempred pred = self._predicates.get(ruleIndex, None) if pred is None: raise Exception("No predicate with index:" + str(ruleIndex)) else: return pred(localctx, predIndex) - def expr_sempred(self, localctx: ExprContext, predIndex: int): - if predIndex == 0: - return self.precpred(self._ctx, 8) + def expr_sempred(self, localctx:ExprContext, predIndex:int): + if predIndex == 0: + return self.precpred(self._ctx, 7) + + + if predIndex == 1: + return self.precpred(self._ctx, 6) + + + if predIndex == 2: + return self.precpred(self._ctx, 5) + + + def shift_expr_sempred(self, localctx:Shift_exprContext, predIndex:int): + if predIndex == 3: + return self.precpred(self._ctx, 4) + + + if predIndex == 4: + return self.precpred(self._ctx, 3) + - if predIndex == 1: - return self.precpred(self._ctx, 7) + def right_expr_sempred(self, localctx:Right_exprContext, predIndex:int): + if predIndex == 5: + return self.precpred(self._ctx, 3) + - if predIndex == 2: - return self.precpred(self._ctx, 6) - def shift_expr_sempred(self, localctx: Shift_exprContext, predIndex: int): - if predIndex == 3: - return self.precpred(self._ctx, 4) - if predIndex == 4: - return self.precpred(self._ctx, 3) - def right_expr_sempred(self, localctx: Right_exprContext, predIndex: int): - if predIndex == 5: - return self.precpred(self._ctx, 3) diff --git a/src/andromede/expression/parsing/antlr/ExprVisitor.py b/src/andromede/expression/parsing/antlr/ExprVisitor.py index 0e924349..ee405aa2 100644 --- a/src/andromede/expression/parsing/antlr/ExprVisitor.py +++ b/src/andromede/expression/parsing/antlr/ExprVisitor.py @@ -1,6 +1,5 @@ -# Generated from Expr.g4 by ANTLR 4.13.1 +# Generated from Expr.g4 by ANTLR 4.13.2 from antlr4 import * - if "." in __name__: from .ExprParser import ExprParser else: @@ -8,99 +7,127 @@ # This class defines a complete generic visitor for a parse tree produced by ExprParser. - class ExprVisitor(ParseTreeVisitor): + # Visit a parse tree produced by ExprParser#fullexpr. - def visitFullexpr(self, ctx: ExprParser.FullexprContext): + def visitFullexpr(self, ctx:ExprParser.FullexprContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by ExprParser#timeSum. + def visitTimeSum(self, ctx:ExprParser.TimeSumContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#negation. - def visitNegation(self, ctx: ExprParser.NegationContext): + def visitNegation(self, ctx:ExprParser.NegationContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#unsignedAtom. - def visitUnsignedAtom(self, ctx: ExprParser.UnsignedAtomContext): + def visitUnsignedAtom(self, ctx:ExprParser.UnsignedAtomContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#expression. - def visitExpression(self, ctx: ExprParser.ExpressionContext): + def visitExpression(self, ctx:ExprParser.ExpressionContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#timeIndex. - def visitTimeIndex(self, ctx: ExprParser.TimeIndexContext): + def visitTimeIndex(self, ctx:ExprParser.TimeIndexContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#comparison. - def visitComparison(self, ctx: ExprParser.ComparisonContext): + def visitComparison(self, ctx:ExprParser.ComparisonContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#timeShift. - def visitTimeShift(self, ctx: ExprParser.TimeShiftContext): + def visitTimeShift(self, ctx:ExprParser.TimeShiftContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#function. - def visitFunction(self, ctx: ExprParser.FunctionContext): + def visitFunction(self, ctx:ExprParser.FunctionContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#addsub. - def visitAddsub(self, ctx: ExprParser.AddsubContext): + def visitAddsub(self, ctx:ExprParser.AddsubContext): return self.visitChildren(ctx) - # Visit a parse tree produced by ExprParser#timeShiftRange. - def visitTimeShiftRange(self, ctx: ExprParser.TimeShiftRangeContext): - return self.visitChildren(ctx) # Visit a parse tree produced by ExprParser#portField. - def visitPortField(self, ctx: ExprParser.PortFieldContext): + def visitPortField(self, ctx:ExprParser.PortFieldContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#muldiv. - def visitMuldiv(self, ctx: ExprParser.MuldivContext): + def visitMuldiv(self, ctx:ExprParser.MuldivContext): return self.visitChildren(ctx) - # Visit a parse tree produced by ExprParser#timeRange. - def visitTimeRange(self, ctx: ExprParser.TimeRangeContext): - return self.visitChildren(ctx) # Visit a parse tree produced by ExprParser#number. - def visitNumber(self, ctx: ExprParser.NumberContext): + def visitNumber(self, ctx:ExprParser.NumberContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#identifier. - def visitIdentifier(self, ctx: ExprParser.IdentifierContext): + def visitIdentifier(self, ctx:ExprParser.IdentifierContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by ExprParser#timeShiftRange. + def visitTimeShiftRange(self, ctx:ExprParser.TimeShiftRangeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by ExprParser#timeRange. + def visitTimeRange(self, ctx:ExprParser.TimeRangeContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#shift. - def visitShift(self, ctx: ExprParser.ShiftContext): + def visitShift(self, ctx:ExprParser.ShiftContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#signedAtom. - def visitSignedAtom(self, ctx: ExprParser.SignedAtomContext): + def visitSignedAtom(self, ctx:ExprParser.SignedAtomContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#signedExpression. - def visitSignedExpression(self, ctx: ExprParser.SignedExpressionContext): + def visitSignedExpression(self, ctx:ExprParser.SignedExpressionContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#shiftMuldiv. - def visitShiftMuldiv(self, ctx: ExprParser.ShiftMuldivContext): + def visitShiftMuldiv(self, ctx:ExprParser.ShiftMuldivContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#shiftAddsub. - def visitShiftAddsub(self, ctx: ExprParser.ShiftAddsubContext): + def visitShiftAddsub(self, ctx:ExprParser.ShiftAddsubContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#rightExpression. - def visitRightExpression(self, ctx: ExprParser.RightExpressionContext): + def visitRightExpression(self, ctx:ExprParser.RightExpressionContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#rightMuldiv. - def visitRightMuldiv(self, ctx: ExprParser.RightMuldivContext): + def visitRightMuldiv(self, ctx:ExprParser.RightMuldivContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#rightAtom. - def visitRightAtom(self, ctx: ExprParser.RightAtomContext): + def visitRightAtom(self, ctx:ExprParser.RightAtomContext): return self.visitChildren(ctx) -del ExprParser + +del ExprParser \ No newline at end of file diff --git a/src/andromede/expression/parsing/parse_expression.py b/src/andromede/expression/parsing/parse_expression.py index e96a70f1..f4ed9142 100644 --- a/src/andromede/expression/parsing/parse_expression.py +++ b/src/andromede/expression/parsing/parse_expression.py @@ -12,16 +12,22 @@ from dataclasses import dataclass from typing import Set -from antlr4 import CommonTokenStream, DiagnosticErrorListener, InputStream +from antlr4 import CommonTokenStream, InputStream from antlr4.error.ErrorStrategy import BailErrorStrategy -from andromede.expression import ExpressionNode, literal, param, var from andromede.expression.equality import expressions_equal from andromede.expression.expression import ( Comparator, ComparisonNode, ExpressionRange, - PortFieldNode, + literal, + param, +) +from andromede.expression.linear_expression import ( + LinearExpression, + port_field, + var, + wrap_in_linear_expr, ) from andromede.expression.parsing.antlr.ExprLexer import ExprLexer from andromede.expression.parsing.antlr.ExprParser import ExprParser @@ -52,19 +58,19 @@ class ExpressionNodeBuilderVisitor(ExprVisitor): identifiers: ModelIdentifiers - def visitFullexpr(self, ctx: ExprParser.FullexprContext) -> ExpressionNode: + def visitFullexpr(self, ctx: ExprParser.FullexprContext) -> LinearExpression: return ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#number. - def visitNumber(self, ctx: ExprParser.NumberContext) -> ExpressionNode: + def visitNumber(self, ctx: ExprParser.NumberContext) -> LinearExpression: return literal(float(ctx.NUMBER().getText())) # type: ignore # Visit a parse tree produced by ExprParser#identifier. - def visitIdentifier(self, ctx: ExprParser.IdentifierContext) -> ExpressionNode: + def visitIdentifier(self, ctx: ExprParser.IdentifierContext) -> LinearExpression: return self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore # Visit a parse tree produced by ExprParser#division. - def visitMuldiv(self, ctx: ExprParser.MuldivContext) -> ExpressionNode: + def visitMuldiv(self, ctx: ExprParser.MuldivContext) -> LinearExpression: left = ctx.expr(0).accept(self) # type: ignore right = ctx.expr(1).accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -75,7 +81,7 @@ def visitMuldiv(self, ctx: ExprParser.MuldivContext) -> ExpressionNode: raise ValueError(f"Invalid operator {op}") # Visit a parse tree produced by ExprParser#subtraction. - def visitAddsub(self, ctx: ExprParser.AddsubContext) -> ExpressionNode: + def visitAddsub(self, ctx: ExprParser.AddsubContext) -> LinearExpression: left = ctx.expr(0).accept(self) # type: ignore right = ctx.expr(1).accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -86,18 +92,20 @@ def visitAddsub(self, ctx: ExprParser.AddsubContext) -> ExpressionNode: raise ValueError(f"Invalid operator {op}") # Visit a parse tree produced by ExprParser#negation. - def visitNegation(self, ctx: ExprParser.NegationContext) -> ExpressionNode: + def visitNegation(self, ctx: ExprParser.NegationContext) -> LinearExpression: return -ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#expression. - def visitExpression(self, ctx: ExprParser.ExpressionContext) -> ExpressionNode: + def visitExpression(self, ctx: ExprParser.ExpressionContext) -> LinearExpression: return ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#unsignedAtom. - def visitUnsignedAtom(self, ctx: ExprParser.UnsignedAtomContext) -> ExpressionNode: + def visitUnsignedAtom( + self, ctx: ExprParser.UnsignedAtomContext + ) -> LinearExpression: return ctx.atom().accept(self) # type: ignore - def _convert_identifier(self, identifier: str) -> ExpressionNode: + def _convert_identifier(self, identifier: str) -> LinearExpression: if self.identifiers.is_variable(identifier): return var(identifier) elif self.identifiers.is_parameter(identifier): @@ -105,70 +113,84 @@ def _convert_identifier(self, identifier: str) -> ExpressionNode: raise ValueError(f"{identifier} is not a valid variable or parameter name.") # Visit a parse tree produced by ExprParser#portField. - def visitPortField(self, ctx: ExprParser.PortFieldContext) -> ExpressionNode: - return PortFieldNode( + def visitPortField(self, ctx: ExprParser.PortFieldContext) -> LinearExpression: + return port_field( port_name=ctx.IDENTIFIER(0).getText(), # type: ignore field_name=ctx.IDENTIFIER(1).getText(), # type: ignore ) # Visit a parse tree produced by ExprParser#comparison. - def visitComparison(self, ctx: ExprParser.ComparisonContext) -> ExpressionNode: + def visitComparison(self, ctx: ExprParser.ComparisonContext) -> LinearExpression: op = ctx.COMPARISON().getText() # type: ignore exp1 = ctx.expr(0).accept(self) # type: ignore exp2 = ctx.expr(1).accept(self) # type: ignore comp = { - "=": Comparator.EQUAL, - "<=": Comparator.LESS_THAN, - ">=": Comparator.GREATER_THAN, + "=": LinearExpression.__eq__, + "<=": LinearExpression.__le__, + ">=": LinearExpression.__ge__, }[op] - return ComparisonNode(exp1, exp2, comp) + return comp(exp1, exp2) # Visit a parse tree produced by ExprParser#timeShift. - def visitTimeIndex(self, ctx: ExprParser.TimeIndexContext) -> ExpressionNode: - shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore - time_shifts = [e.accept(self) for e in ctx.expr()] # type: ignore - return shifted_expr.eval(time_shifts) - - # Visit a parse tree produced by ExprParser#rangeTimeShift. - def visitTimeRange(self, ctx: ExprParser.TimeRangeContext) -> ExpressionNode: + def visitTimeIndex(self, ctx: ExprParser.TimeIndexContext) -> LinearExpression: shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore - expressions = [e.accept(self) for e in ctx.expr()] # type: ignore - return shifted_expr.eval(ExpressionRange(expressions[0], expressions[1])) + time_shift = ctx.expr().accept(self) # type: ignore + return shifted_expr.eval(time_shift) - def visitTimeShift(self, ctx: ExprParser.TimeShiftContext) -> ExpressionNode: + def visitTimeShift(self, ctx: ExprParser.TimeShiftContext) -> LinearExpression: shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore - time_shifts = [s.accept(self) for s in ctx.shift()] # type: ignore + time_shift = ctx.shift().accept(self) # type: ignore # specifics for x[t] ... - if len(time_shifts) == 1 and expressions_equal(time_shifts[0], literal(0)): + if expressions_equal(time_shift, literal(0)): + # TODO: Should expression simplification be handled only in linear expression building rather than here in the parsing ? return shifted_expr - return shifted_expr.shift(time_shifts) + return shifted_expr.shift(time_shift) + + # Visit a parse tree produced by ExprParser#timeSum. + def visitTimeSum(self, ctx:ExprParser.TimeSumContext) -> LinearExpression: + summed_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) + expr = ctx.expr() + shift = ctx.shift() + time_range = ctx.timeRange() + time_shift_range = ctx.timeShiftRange() + if expr is not None: + return summed_expr.sum(eval=expr.accept(self)) + elif shift is not None: + return summed_expr.sum(shift=shift.accept(self)) + elif time_range is not None: + return summed_expr.sum(eval=time_range.accept(self)) + elif time_shift_range is not None: + return summed_expr.sum(shift=time_shift_range.accept(self)) + else: + return summed_expr.sum() + + # Visit a parse tree produced by ExprParser#timeShiftRange. + def visitTimeShiftRange(self, ctx:ExprParser.TimeShiftRangeContext): + return ExpressionRange(ctx.shift1, ctx.shift2) - def visitTimeShiftRange( - self, ctx: ExprParser.TimeShiftRangeContext - ) -> ExpressionNode: - shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore - shift1 = ctx.shift1.accept(self) # type: ignore - shift2 = ctx.shift2.accept(self) # type: ignore - return shifted_expr.shift(ExpressionRange(shift1, shift2)) + + # Visit a parse tree produced by ExprParser#timeRange. + def visitTimeRange(self, ctx:ExprParser.TimeRangeContext): + return ExpressionRange(ctx.expr1, ctx.expr2) # Visit a parse tree produced by ExprParser#function. - def visitFunction(self, ctx: ExprParser.FunctionContext) -> ExpressionNode: + def visitFunction(self, ctx: ExprParser.FunctionContext) -> LinearExpression: function_name: str = ctx.IDENTIFIER().getText() # type: ignore - operand: ExpressionNode = ctx.expr().accept(self) # type: ignore + operand: LinearExpression = ctx.expr().accept(self) # type: ignore fn = _FUNCTIONS.get(function_name, None) if fn is None: raise ValueError(f"Encountered invalid function name {function_name}") return fn(operand) # Visit a parse tree produced by ExprParser#shift. - def visitShift(self, ctx: ExprParser.ShiftContext) -> ExpressionNode: + def visitShift(self, ctx: ExprParser.ShiftContext) -> LinearExpression: if ctx.shift_expr() is None: # type: ignore return literal(0) shift = ctx.shift_expr().accept(self) # type: ignore return shift # Visit a parse tree produced by ExprParser#shiftAddsub. - def visitShiftAddsub(self, ctx: ExprParser.ShiftAddsubContext) -> ExpressionNode: + def visitShiftAddsub(self, ctx: ExprParser.ShiftAddsubContext) -> LinearExpression: left = ctx.shift_expr().accept(self) # type: ignore right = ctx.right_expr().accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -179,7 +201,7 @@ def visitShiftAddsub(self, ctx: ExprParser.ShiftAddsubContext) -> ExpressionNode raise ValueError(f"Invalid operator {op}") # Visit a parse tree produced by ExprParser#shiftMuldiv. - def visitShiftMuldiv(self, ctx: ExprParser.ShiftMuldivContext) -> ExpressionNode: + def visitShiftMuldiv(self, ctx: ExprParser.ShiftMuldivContext) -> LinearExpression: left = ctx.shift_expr().accept(self) # type: ignore right = ctx.right_expr().accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -192,14 +214,14 @@ def visitShiftMuldiv(self, ctx: ExprParser.ShiftMuldivContext) -> ExpressionNode # Visit a parse tree produced by ExprParser#signedExpression. def visitSignedExpression( self, ctx: ExprParser.SignedExpressionContext - ) -> ExpressionNode: + ) -> LinearExpression: if ctx.op.text == "-": # type: ignore return -ctx.expr().accept(self) # type: ignore else: return ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#signedAtom. - def visitSignedAtom(self, ctx: ExprParser.SignedAtomContext) -> ExpressionNode: + def visitSignedAtom(self, ctx: ExprParser.SignedAtomContext) -> LinearExpression: if ctx.op.text == "-": # type: ignore return -ctx.atom().accept(self) # type: ignore else: @@ -208,11 +230,11 @@ def visitSignedAtom(self, ctx: ExprParser.SignedAtomContext) -> ExpressionNode: # Visit a parse tree produced by ExprParser#rightExpression. def visitRightExpression( self, ctx: ExprParser.RightExpressionContext - ) -> ExpressionNode: + ) -> LinearExpression: return ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#rightMuldiv. - def visitRightMuldiv(self, ctx: ExprParser.RightMuldivContext) -> ExpressionNode: + def visitRightMuldiv(self, ctx: ExprParser.RightMuldivContext) -> LinearExpression: left = ctx.right_expr(0).accept(self) # type: ignore right = ctx.right_expr(1).accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -223,14 +245,14 @@ def visitRightMuldiv(self, ctx: ExprParser.RightMuldivContext) -> ExpressionNode raise ValueError(f"Invalid operator {op}") # Visit a parse tree produced by ExprParser#rightAtom. - def visitRightAtom(self, ctx: ExprParser.RightAtomContext) -> ExpressionNode: + def visitRightAtom(self, ctx: ExprParser.RightAtomContext) -> LinearExpression: return ctx.atom().accept(self) # type: ignore _FUNCTIONS = { - "sum": ExpressionNode.sum, - "sum_connections": ExpressionNode.sum_connections, - "expec": ExpressionNode.expec, + "sum": LinearExpression.sum, + "sum_connections": LinearExpression.sum_connections, + "expec": LinearExpression.expec, } @@ -238,7 +260,9 @@ class AntaresParseException(Exception): pass -def parse_expression(expression: str, identifiers: ModelIdentifiers) -> ExpressionNode: +def parse_expression( + expression: str, identifiers: ModelIdentifiers +) -> LinearExpression: """ Parses a string expression to create the corresponding AST representation. """ diff --git a/src/andromede/expression/port_operator.py b/src/andromede/expression/port_operator.py index 875d4f32..845ae693 100644 --- a/src/andromede/expression/port_operator.py +++ b/src/andromede/expression/port_operator.py @@ -30,4 +30,5 @@ class PortAggregator: @dataclass(frozen=True) class PortSum(PortAggregator): - pass + def __str__(self) -> str: + return "PortSum" diff --git a/src/andromede/expression/port_resolver.py b/src/andromede/expression/port_resolver.py deleted file mode 100644 index 6f333408..00000000 --- a/src/andromede/expression/port_resolver.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Dict, List - -from andromede.expression import CopyVisitor, sum_expressions, visit -from andromede.expression.expression import ( - AdditionNode, - ExpressionNode, - LiteralNode, - PortFieldAggregatorNode, - PortFieldNode, -) -from andromede.model.model import PortFieldId - - -@dataclass(eq=True, frozen=True) -class PortFieldKey: - """ - Identifies the expression node for one component and one port variable. - """ - - component_id: str - port_variable_id: PortFieldId - - -@dataclass(frozen=True) -class PortResolver(CopyVisitor): - """ - Duplicates the AST with replacement of port field nodes by - their corresponding expression. - """ - - component_id: str - ports_expressions: Dict[PortFieldKey, List[ExpressionNode]] - - def port_field(self, node: PortFieldNode) -> ExpressionNode: - expressions = self.ports_expressions[ - PortFieldKey( - self.component_id, PortFieldId(node.port_name, node.field_name) - ) - ] - if len(expressions) != 1: - raise ValueError( - f"Invalid number of expression for port : {node.port_name}" - ) - else: - return expressions[0] - - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> ExpressionNode: - if node.aggregator != "PortSum": - raise NotImplementedError("Only PortSum is supported.") - port_field_node = node.operand - if not isinstance(port_field_node, PortFieldNode): - raise ValueError(f"Should be a portFieldNode : {port_field_node}") - - expressions = self.ports_expressions.get( - PortFieldKey( - self.component_id, - PortFieldId(port_field_node.port_name, port_field_node.field_name), - ), - [], - ) - return sum_expressions(expressions) - - -def resolve_port( - expression: ExpressionNode, - component_id: str, - ports_expressions: Dict[PortFieldKey, List[ExpressionNode]], -) -> ExpressionNode: - return visit(expression, PortResolver(component_id, ports_expressions)) diff --git a/src/andromede/expression/print.py b/src/andromede/expression/print.py index c01ae76f..6b1e1c84 100644 --- a/src/andromede/expression/print.py +++ b/src/andromede/expression/print.py @@ -13,29 +13,23 @@ from dataclasses import dataclass from typing import Dict -from andromede.expression.expression import ( - ComponentParameterNode, - ComponentVariableNode, - ExpressionNode, - PortFieldAggregatorNode, - PortFieldNode, -) -from andromede.expression.visitor import T - from .expression import ( AdditionNode, Comparator, ComparisonNode, + ComponentParameterNode, DivisionNode, + ExpressionNode, LiteralNode, MultiplicationNode, NegationNode, ParameterNode, + PortFieldAggregatorNode, + PortFieldNode, ScenarioOperatorNode, SubstractionNode, TimeAggregatorNode, TimeOperatorNode, - VariableNode, ) from .visitor import ExpressionVisitor, visit @@ -86,15 +80,9 @@ def comparison(self, node: ComparisonNode) -> str: right_value = visit(node.right, self) return f"{left_value} {op} {right_value}" - def variable(self, node: VariableNode) -> str: - return node.name - def parameter(self, node: ParameterNode) -> str: return node.name - def comp_variable(self, node: ComponentVariableNode) -> str: - return f"{node.component_id}.{node.name}" - def comp_parameter(self, node: ComponentParameterNode) -> str: return f"{node.component_id}.{node.name}" diff --git a/src/andromede/expression/scenario_operator.py b/src/andromede/expression/scenario_operator.py index 77f164e1..6eee08df 100644 --- a/src/andromede/expression/scenario_operator.py +++ b/src/andromede/expression/scenario_operator.py @@ -19,7 +19,7 @@ @dataclass(frozen=True) -class ScenarioOperator(ABC): +class ScenarioAggregator(ABC): def __str__(self) -> str: return NotImplemented @@ -30,7 +30,7 @@ def degree(cls) -> int: @dataclass(frozen=True) -class Expectation(ScenarioOperator): +class Expectation(ScenarioAggregator): def __str__(self) -> str: return "expec()" @@ -40,7 +40,7 @@ def degree(cls) -> int: @dataclass(frozen=True) -class Variance(ScenarioOperator): +class Variance(ScenarioAggregator): def __str__(self) -> str: return "variance()" diff --git a/src/andromede/expression/time_operator.py b/src/andromede/expression/time_operator.py index 63059528..3d078920 100644 --- a/src/andromede/expression/time_operator.py +++ b/src/andromede/expression/time_operator.py @@ -18,6 +18,8 @@ from dataclasses import dataclass from typing import Any, List, Tuple +from andromede.expression.expression import InstancesTimeIndex + @dataclass(frozen=True) class TimeOperator(ABC): @@ -27,24 +29,15 @@ class TimeOperator(ABC): - is_rolling: bool, if true, this means that the time_ids are to be understood relatively to the current timestep of the context AND that the represented expression will have to be instanciated for all timesteps. Otherwise, the time_ids are "absolute" times and the expression only has to be instantiated once. """ - time_ids: List[int] + time_ids: InstancesTimeIndex @classmethod @abstractmethod def rolling(cls) -> bool: raise NotImplementedError - def __post_init__(self) -> None: - if isinstance(self.time_ids, int): - object.__setattr__(self, "time_ids", [self.time_ids]) - elif isinstance(self.time_ids, range): - object.__setattr__(self, "time_ids", list(self.time_ids)) - - def key(self) -> Tuple[int, ...]: - return tuple(self.time_ids) - - def size(self) -> int: - return len(self.time_ids) + def key(self) -> InstancesTimeIndex: + return self.time_ids @dataclass(frozen=True) diff --git a/src/andromede/expression/value_provider.py b/src/andromede/expression/value_provider.py new file mode 100644 index 00000000..f15b943a --- /dev/null +++ b/src/andromede/expression/value_provider.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, List + + +@dataclass +class TimeScenarioIndices: + time_indices: List[int] + scenario_indices: List[int] + + +# TODO: Already define in study module, factorize this +@dataclass(frozen=True) +class TimeScenarioIndex: + time: int + scenario: int + + +# Given a list of time_indices and of scenario_indices, the value provider will get the parameter value for all couple (time, scenario) for time in time_indices and scenario in scenario_indices +class ValueProvider(ABC): + """ + Implementations are in charge of mapping parameters and variables to their values. + Depending on the implementation, evaluation may require a component id or not. + """ + + # TODO: To be removed, or should we keep it to evaluate solutions ? + @abstractmethod + def get_variable_value( + self, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: + ... + + # Need to have time_scenarios_indices as function argument as we do not want to create a Provider each time we have to get the value of a parameter at a different (time, scenario) index + @abstractmethod + def get_parameter_value( + self, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: + ... + + # TODO: To be removed ? + @abstractmethod + def get_component_variable_value( + self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: + ... + + @abstractmethod + def get_component_parameter_value( + self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: + ... + + # TODO: Should this really be an abstract method ? Or maybe, only the Provider in _make_value_provider should implement it. And the context attribute in the InstancesIndexVisitor is a ValueProvider that implements the parameter_is_constant_over_time method. Maybe create a child class of ValueProvider like TimeValueProvider ? + @abstractmethod + def parameter_is_constant_over_time(self, name: str) -> bool: + ... + + # There is probably a better place to put this.. + # Which is useful when evaluating the TimeSum operator over the whole block, to know which time steps to look for to get the parameter values + @staticmethod + @abstractmethod + def block_length() -> int: + ... + + @staticmethod + @abstractmethod + def scenarios() -> int: + ... diff --git a/src/andromede/expression/visitor.py b/src/andromede/expression/visitor.py index 25bbfb02..37d9f507 100644 --- a/src/andromede/expression/visitor.py +++ b/src/andromede/expression/visitor.py @@ -13,15 +13,13 @@ """ Defines abstract base class for visitors of expressions. """ -import typing from abc import ABC, abstractmethod from typing import Generic, Protocol, TypeVar -from andromede.expression.expression import ( +from .expression import ( AdditionNode, ComparisonNode, ComponentParameterNode, - ComponentVariableNode, DivisionNode, ExpressionNode, LiteralNode, @@ -34,7 +32,6 @@ SubstractionNode, TimeAggregatorNode, TimeOperatorNode, - VariableNode, ) T = TypeVar("T") @@ -77,10 +74,6 @@ def division(self, node: DivisionNode) -> T: def comparison(self, node: ComparisonNode) -> T: ... - @abstractmethod - def variable(self, node: VariableNode) -> T: - ... - @abstractmethod def parameter(self, node: ParameterNode) -> T: ... @@ -89,10 +82,6 @@ def parameter(self, node: ParameterNode) -> T: def comp_parameter(self, node: ComponentParameterNode) -> T: ... - @abstractmethod - def comp_variable(self, node: ComponentVariableNode) -> T: - ... - @abstractmethod def time_operator(self, node: TimeOperatorNode) -> T: ... @@ -122,14 +111,10 @@ def visit(root: ExpressionNode, visitor: ExpressionVisitor[T]) -> T: return visitor.literal(root) elif isinstance(root, NegationNode): return visitor.negation(root) - elif isinstance(root, VariableNode): - return visitor.variable(root) elif isinstance(root, ParameterNode): return visitor.parameter(root) elif isinstance(root, ComponentParameterNode): return visitor.comp_parameter(root) - elif isinstance(root, ComponentVariableNode): - return visitor.comp_variable(root) elif isinstance(root, AdditionNode): return visitor.addition(root) elif isinstance(root, MultiplicationNode): diff --git a/src/andromede/libs/standard.py b/src/andromede/libs/standard.py index 9a6df24a..073d74d2 100644 --- a/src/andromede/libs/standard.py +++ b/src/andromede/libs/standard.py @@ -13,9 +13,10 @@ """ The standard module contains the definition of standard models. """ -from andromede.expression import literal, param, var -from andromede.expression.expression import ExpressionRange, port_field + +from andromede.expression.expression import ExpressionRange, literal, param from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression import port_field, var, wrap_in_linear_expr from andromede.model.constraint import Constraint from andromede.model.model import ModelPort, PortFieldDefinition, PortFieldId, model from andromede.model.parameter import float_parameter, int_parameter @@ -35,7 +36,7 @@ binding_constraints=[ Constraint( name="Balance", - expression=port_field("balance_port", "flow").sum_connections() + expression_init=port_field("balance_port", "flow").sum_connections() == literal(0), ) ], @@ -52,7 +53,7 @@ binding_constraints=[ Constraint( name="Balance", - expression=port_field("balance_port", "flow").sum_connections() + expression_init=port_field("balance_port", "flow").sum_connections() == var("spillage") - var("unsupplied_energy"), ) ], @@ -80,11 +81,11 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port_from", "flow"), - definition=-var("flow"), + definition_init=-var("flow"), ), PortFieldDefinition( port_field=PortFieldId("balance_port_to", "flow"), - definition=var("flow"), + definition_init=var("flow"), ), ], ) @@ -101,7 +102,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=-param("demand"), + definition_init=-param("demand"), ) ], ) @@ -120,12 +121,12 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= param("p_max") + name="Max generation", expression_init=var("generation") <= param("p_max") ), ], objective_operational_contribution=(param("cost") * var("generation")) @@ -145,17 +146,19 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= param("p_max") + name="Max generation", expression_init=var("generation") <= param("p_max") ), Constraint( name="Min generation", - expression=var("generation") - param("p_min"), - lower_bound=literal(0), + expression_init=var("generation") - param("p_min"), + lower_bound=wrap_in_linear_expr( + literal(0) + ), # wrap_in_linear_expr is not needed as done in __post_init__, it is used here only for type checking... ), # To test both ways of setting constraints ], objective_operational_contribution=(param("cost") * var("generation")) @@ -179,16 +182,16 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= param("p_max") + name="Max generation", expression_init=var("generation") <= param("p_max") ), Constraint( name="Total storage", - expression=var("generation").sum() <= param("full_storage"), + expression_init=var("generation").sum() <= param("full_storage"), ), ], objective_operational_contribution=(param("cost") * var("generation")) @@ -236,7 +239,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ @@ -254,17 +257,18 @@ ), Constraint( "Min up time", - var("nb_start") - .shift(ExpressionRange(-param("d_min_up") + 1, literal(0))) - .sum() + var("nb_start").sum( + shift=ExpressionRange(-param("d_min_up") + 1, literal(0)) + ) <= var("nb_on"), ), + # TODO : Improve API so that we are not forced to use sum() on one shifted element for ExpressionNode Constraint( "Min down time", - var("nb_stop") - .shift(ExpressionRange(-param("d_min_down") + 1, literal(0))) - .sum() - <= param("nb_units_max").shift(-param("d_min_down")) - var("nb_on"), + var("nb_stop").sum( + shift=ExpressionRange(-param("d_min_down") + 1, literal(0)) + ) + <= param("nb_units_max").shift(-param("d_min_down")).sum() - var("nb_on"), ), # It also works by writing ExpressionRange(-param("d_min_down") + 1, 0) as ExpressionRange's __post_init__ wraps integers to literal nodes. However, MyPy does not seem to infer that ExpressionRange's attributes are necessarily of ExpressionNode type and raises an error if the arguments in the constructor are integer (whereas it runs correctly), this why we specify it here with literal(0) instead of 0. ], @@ -313,7 +317,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ @@ -331,17 +335,17 @@ ), Constraint( "Min up time", - var("nb_start") - .shift(ExpressionRange(-param("d_min_up") + 1, literal(0))) - .sum() + var("nb_start").sum( + shift=ExpressionRange(-param("d_min_up") + 1, literal(0)) + ) <= var("nb_on"), ), Constraint( "Min down time", - var("nb_stop") - .shift(ExpressionRange(-param("d_min_down") + 1, literal(0))) - .sum() - <= param("nb_units_max").shift(-param("d_min_down")) - var("nb_on"), + var("nb_stop").sum( + shift=ExpressionRange(-param("d_min_down") + 1, literal(0)) + ) + <= param("nb_units_max").shift(-param("d_min_down")).sum() - var("nb_on"), ), ], objective_operational_contribution=(param("cost") * var("generation")) @@ -357,7 +361,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=-var("spillage"), + definition_init=-var("spillage"), ) ], objective_operational_contribution=(param("cost") * var("spillage")).sum().expec(), @@ -371,7 +375,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("unsupplied_energy"), + definition_init=var("unsupplied_energy"), ) ], objective_operational_contribution=(param("cost") * var("unsupplied_energy")) @@ -410,13 +414,13 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("withdrawal") - var("injection"), + definition_init=var("withdrawal") - var("injection"), ) ], constraints=[ Constraint( name="Level", - expression=var("level") + expression_init=var("level") - var("level").shift(-1) - param("efficiency") * var("injection") + var("withdrawal") diff --git a/src/andromede/libs/standard_sc.py b/src/andromede/libs/standard_sc.py index 25200638..12243bdc 100644 --- a/src/andromede/libs/standard_sc.py +++ b/src/andromede/libs/standard_sc.py @@ -10,9 +10,10 @@ # # This file is part of the Antares project. -from andromede.expression import literal, param, var -from andromede.expression.expression import port_field -from andromede.libs.standard import BALANCE_PORT_TYPE, CONSTANT, TIME_AND_SCENARIO_FREE + +from andromede.expression.expression import literal, param +from andromede.expression.linear_expression import port_field, var +from andromede.libs.standard import BALANCE_PORT_TYPE, CONSTANT from andromede.model import ( Constraint, ModelPort, @@ -40,11 +41,11 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("FlowDI", "flow"), - definition=-var("input"), + definition_init=-var("input"), ), PortFieldDefinition( port_field=PortFieldId("FlowDO", "flow"), - definition=var("input") * param("alpha"), + definition_init=var("input") * param("alpha"), ), ], ) @@ -67,15 +68,15 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("FlowDI1", "flow"), - definition=-var("input1"), + definition_init=-var("input1"), ), PortFieldDefinition( port_field=PortFieldId("FlowDI2", "flow"), - definition=-var("input2"), + definition_init=-var("input2"), ), PortFieldDefinition( port_field=PortFieldId("FlowDO", "flow"), - definition=var("input1") * param("alpha1") + definition_init=var("input1") * param("alpha1") + var("input2") * param("alpha2"), ), ], @@ -95,17 +96,17 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("FlowDI1", "flow"), - definition=var("input1"), + definition_init=var("input1"), ), PortFieldDefinition( port_field=PortFieldId("FlowDI2", "flow"), - definition=var("input2"), + definition_init=var("input2"), ), ], binding_constraints=[ Constraint( name="Conversion", - expression=var("input1") + var("input2") + expression_init=var("input1") + var("input2") == port_field("FlowDO", "flow").sum_connections(), ) ], @@ -124,13 +125,14 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("FlowDO", "flow"), - definition=var("input") * param("alpha"), + definition_init=var("input") * param("alpha"), ), ], binding_constraints=[ Constraint( name="Conversion", - expression=var("input") == port_field("FlowDI", "flow").sum_connections(), + expression_init=var("input") + == port_field("FlowDI", "flow").sum_connections(), ) ], ) @@ -163,11 +165,11 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("FlowP", "flow"), - definition=var("p"), + definition_init=var("p"), ), PortFieldDefinition( port_field=PortFieldId("OutCO2", "Q"), - definition=var("p") * param("emission_rate"), + definition_init=var("p") * param("emission_rate"), ), ], objective_operational_contribution=(param("cost") * var("p")).sum().expec(), @@ -184,7 +186,7 @@ binding_constraints=[ Constraint( name="Bound CO2", - expression=port_field("emissionCO2", "Q").sum_connections() + expression_init=port_field("emissionCO2", "Q").sum_connections() <= param("quota"), ) ], @@ -200,13 +202,13 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port_e", "flow"), - definition=var("p"), + definition_init=var("p"), ) ], binding_constraints=[ Constraint( name="Balance", - expression=var("p") + expression_init=var("p") == port_field("balance_port_n", "flow").sum_connections(), ) ], @@ -249,13 +251,13 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("withdrawal") - var("injection"), + definition_init=var("withdrawal") - var("injection"), ) ], constraints=[ Constraint( name="Level", - expression=var("level") + expression_init=var("level") - var("level").shift(-1) - param("efficiency") * var("injection") + var("withdrawal") diff --git a/src/andromede/model/common.py b/src/andromede/model/common.py index 5763f360..ffd0ffe5 100644 --- a/src/andromede/model/common.py +++ b/src/andromede/model/common.py @@ -14,6 +14,12 @@ Module for common classes used in models. """ from enum import Enum +from typing import Union + +from andromede.expression.expression import ExpressionNode +from andromede.expression.linear_expression import LinearExpression + +ValueOrExprNodeOrLinearExpr = Union[int, float, ExpressionNode, LinearExpression] class ValueType(Enum): diff --git a/src/andromede/model/constraint.py b/src/andromede/model/constraint.py index 9852046e..a4cd09a8 100644 --- a/src/andromede/model/constraint.py +++ b/src/andromede/model/constraint.py @@ -9,21 +9,16 @@ # SPDX-License-Identifier: MPL-2.0 # # This file is part of the Antares project. -from dataclasses import dataclass -from typing import Any, Optional +from dataclasses import InitVar, dataclass, field +from typing import Any, Union -from andromede.expression.degree import is_constant -from andromede.expression.equality import ( - expressions_equal, - expressions_equal_if_present, +from andromede.expression.expression import ExpressionNode, literal +from andromede.expression.linear_expression import ( + LinearExpression, + StandaloneConstraint, + linear_expressions_equal, + wrap_in_linear_expr, ) -from andromede.expression.expression import ( - Comparator, - ComparisonNode, - ExpressionNode, - literal, -) -from andromede.expression.print import print_expr from andromede.model.common import ProblemContext @@ -36,66 +31,54 @@ class Constraint: """ name: str - expression: ExpressionNode - lower_bound: ExpressionNode - upper_bound: ExpressionNode - context: ProblemContext + # Used only for mypy type checking, we could have done the same by using only the attribute expression + expression_init: InitVar[ + Union[ExpressionNode, LinearExpression, StandaloneConstraint] + ] + expression: LinearExpression = field(init=False) + lower_bound: LinearExpression = field( + default=wrap_in_linear_expr(literal(-float("inf"))) + ) + upper_bound: LinearExpression = field( + default=wrap_in_linear_expr(literal(float("inf"))) + ) + context: ProblemContext = field(default=ProblemContext.OPERATIONAL) - def __init__( + def __post_init__( self, - name: str, - expression: ExpressionNode, - lower_bound: Optional[ExpressionNode] = None, - upper_bound: Optional[ExpressionNode] = None, - context: ProblemContext = ProblemContext.OPERATIONAL, + expression_init: Union[ExpressionNode, LinearExpression, StandaloneConstraint], ) -> None: - self.name = name - self.context = context + self.lower_bound = wrap_in_linear_expr(self.lower_bound) + self.upper_bound = wrap_in_linear_expr(self.upper_bound) - if isinstance(expression, ComparisonNode): - if lower_bound is not None or upper_bound is not None: + if isinstance(expression_init, StandaloneConstraint): + # Case where constraint is initialized with something like Constraint(var("x") <= var("y")) + if not self.lower_bound.is_unbound() or not self.upper_bound.is_unbound(): raise ValueError( "Both comparison between two expressions and a bound are specfied, set either only a comparison between expressions or a single linear expression with bounds." ) - merged_expr = expression.left - expression.right - self.expression = merged_expr + self.lower_bound = expression_init.lower_bound + self.upper_bound = expression_init.upper_bound + self.expression = expression_init.expression - if expression.comparator == Comparator.LESS_THAN: - # lhs - rhs <= 0 - self.upper_bound = literal(0) - self.lower_bound = literal(-float("inf")) - elif expression.comparator == Comparator.GREATER_THAN: - # lhs - rhs >= 0 - self.lower_bound = literal(0) - self.upper_bound = literal(float("inf")) - else: # lhs - rhs == 0 - self.lower_bound = literal(0) - self.upper_bound = literal(0) else: - for bound in [lower_bound, upper_bound]: - if bound is not None and not is_constant(bound): + self.expression = wrap_in_linear_expr(expression_init) + for bound in [self.lower_bound, self.upper_bound]: + if not bound.is_constant(): raise ValueError( - f"The bounds of a constraint should not contain variables, {print_expr(bound)} was given." + f"The bounds of a constraint should not contain variables, {str(bound)} was given." ) - self.expression = expression - if lower_bound is not None: - self.lower_bound = lower_bound - else: - self.lower_bound = literal(-float("inf")) - - if upper_bound is not None: - self.upper_bound = upper_bound - else: - self.upper_bound = literal(float("inf")) - def __eq__(self, other: Any) -> bool: if not isinstance(other, Constraint): return False return ( self.name == other.name - and expressions_equal(self.expression, other.expression) - and expressions_equal_if_present(self.lower_bound, other.lower_bound) - and expressions_equal_if_present(self.upper_bound, other.upper_bound) + and linear_expressions_equal(self.expression, other.expression) + and linear_expressions_equal(self.lower_bound, other.lower_bound) + and linear_expressions_equal(self.upper_bound, other.upper_bound) ) + + def __str__(self) -> str: + return f"{str(self.lower_bound)} <= {str(self.expression)} <= {str(self.upper_bound)}" diff --git a/src/andromede/model/model.py b/src/andromede/model/model.py index 3997f43d..93011db6 100644 --- a/src/andromede/model/model.py +++ b/src/andromede/model/model.py @@ -16,43 +16,43 @@ defining parameters, variables, and equations. """ import itertools -from dataclasses import dataclass, field -from typing import Dict, Iterable, Optional +from dataclasses import InitVar, dataclass, field +from typing import Dict, Iterable, Optional, Union -from andromede.expression import ( +from andromede.expression.expression import ( AdditionNode, + BinaryOperatorNode, ComparisonNode, + ComponentParameterNode, DivisionNode, ExpressionNode, - ExpressionVisitor, LiteralNode, MultiplicationNode, NegationNode, ParameterNode, - SubstractionNode, - VariableNode, -) -from andromede.expression.degree import is_linear -from andromede.expression.expression import ( - BinaryOperatorNode, - ComponentParameterNode, - ComponentVariableNode, PortFieldAggregatorNode, PortFieldNode, ScenarioOperatorNode, + SubstractionNode, TimeAggregatorNode, TimeOperatorNode, ) -from andromede.expression.indexing import IndexingStructureProvider, compute_indexation +from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.visitor import T, visit +from andromede.expression.linear_expression import ( + LinearExpression, + is_linear, + wrap_in_linear_expr, + wrap_in_linear_expr_if_present, +) +from andromede.expression.visitor import ExpressionVisitor, visit +from andromede.model.common import ValueOrExprNodeOrLinearExpr from andromede.model.constraint import Constraint from andromede.model.parameter import Parameter from andromede.model.port import PortType from andromede.model.variable import Variable -# TODO: Introduce bool_variable ? def _make_structure_provider(model: "Model") -> IndexingStructureProvider: class Provider(IndexingStructureProvider): def get_parameter_structure(self, name: str) -> IndexingStructure: @@ -79,19 +79,18 @@ def get_component_variable_structure( def _is_objective_contribution_valid( - model: "Model", objective_contribution: ExpressionNode + model: "Model", objective_contribution: LinearExpression ) -> bool: if not is_linear(objective_contribution): raise ValueError("Objective contribution must be a linear expression.") data_structure_provider = _make_structure_provider(model) - objective_structure = compute_indexation( - objective_contribution, data_structure_provider + objective_structure = objective_contribution.compute_indexation( + data_structure_provider ) if objective_structure != IndexingStructure(time=False, scenario=False): raise ValueError("Objective contribution should be a real-valued expression.") - # TODO: We should also check that the number of instances is equal to 1, but this would require a linearization here, do not want to do that for now... return True @@ -121,14 +120,21 @@ class PortFieldDefinition: """ port_field: PortFieldId - definition: ExpressionNode - - def __post_init__(self) -> None: + # Used only for type checking... + definition_init: InitVar[Union[ExpressionNode, LinearExpression]] + definition: LinearExpression = field(init=False) + + def __post_init__( + self, definition_init: Union[ExpressionNode, LinearExpression] + ) -> None: + object.__setattr__(self, "definition", wrap_in_linear_expr(definition_init)) _validate_port_field_expression(self) def port_field_def( - port_name: str, field_name: str, definition: ExpressionNode + port_name: str, + field_name: str, + definition: Union[ExpressionNode, LinearExpression], ) -> PortFieldDefinition: return PortFieldDefinition(PortFieldId(port_name, field_name), definition) @@ -146,8 +152,8 @@ class Model: inter_block_dyn: bool = False parameters: Dict[str, Parameter] = field(default_factory=dict) variables: Dict[str, Variable] = field(default_factory=dict) - objective_operational_contribution: Optional[ExpressionNode] = None - objective_investment_contribution: Optional[ExpressionNode] = None + objective_operational_contribution: Optional[LinearExpression] = None + objective_investment_contribution: Optional[LinearExpression] = None ports: Dict[str, ModelPort] = field(default_factory=dict) # key = port name port_fields_definitions: Dict[PortFieldId, PortFieldDefinition] = field( default_factory=dict @@ -190,8 +196,8 @@ def model( binding_constraints: Optional[Iterable[Constraint]] = None, parameters: Optional[Iterable[Parameter]] = None, variables: Optional[Iterable[Variable]] = None, - objective_operational_contribution: Optional[ExpressionNode] = None, - objective_investment_contribution: Optional[ExpressionNode] = None, + objective_operational_contribution: Optional[ValueOrExprNodeOrLinearExpr] = None, + objective_investment_contribution: Optional[ValueOrExprNodeOrLinearExpr] = None, inter_block_dyn: bool = False, ports: Optional[Iterable[ModelPort]] = None, port_fields_definitions: Optional[Iterable[PortFieldDefinition]] = None, @@ -212,25 +218,31 @@ def model( return Model( id=id, constraints={c.name: c for c in constraints} if constraints else {}, - binding_constraints={c.name: c for c in binding_constraints} - if binding_constraints - else {}, + binding_constraints=( + {c.name: c for c in binding_constraints} if binding_constraints else {} + ), parameters={p.name: p for p in parameters} if parameters else {}, variables={v.name: v for v in variables} if variables else {}, - objective_operational_contribution=objective_operational_contribution, - objective_investment_contribution=objective_investment_contribution, + objective_operational_contribution=wrap_in_linear_expr_if_present( + objective_operational_contribution + ), + objective_investment_contribution=wrap_in_linear_expr_if_present( + objective_investment_contribution + ), inter_block_dyn=inter_block_dyn, ports=existing_port_names, - port_fields_definitions={d.port_field: d for d in port_fields_definitions} - if port_fields_definitions - else {}, + port_fields_definitions=( + {d.port_field: d for d in port_fields_definitions} + if port_fields_definitions + else {} + ), ) class _PortFieldExpressionChecker(ExpressionVisitor[None]): """ Visits the whole expression to check there is no: - comparison, other port field, component-associated parametrs or variables... + comparison, other port field, component-associated parameters or variables... """ def literal(self, node: LiteralNode) -> None: @@ -258,9 +270,6 @@ def division(self, node: DivisionNode) -> None: def comparison(self, node: ComparisonNode) -> None: raise ValueError("Port definition cannot contain a comparison operator.") - def variable(self, node: VariableNode) -> None: - pass - def parameter(self, node: ParameterNode) -> None: pass @@ -269,11 +278,6 @@ def comp_parameter(self, node: ComponentParameterNode) -> None: "Port definition must not contain a parameter associated to a component." ) - def comp_variable(self, node: ComponentVariableNode) -> None: - raise ValueError( - "Port definition must not contain a variable associated to a component." - ) - def time_operator(self, node: TimeOperatorNode) -> None: visit(node.operand, self) @@ -291,4 +295,34 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> None: def _validate_port_field_expression(definition: PortFieldDefinition) -> None: - visit(definition.definition, _PortFieldExpressionChecker()) + """ + Check there is no: + comparison, other port field, component-associated parameters or variables... + """ + _check_port_field_expression_type(definition) + _check_no_reference_to_other_port_field(definition) + _check_no_component_associated_variable_or_parameter(definition) + + +def _check_port_field_expression_type(definition: PortFieldDefinition) -> None: + if not isinstance(definition.definition, LinearExpression): + raise TypeError( + f"Port field definition should be a LinearExpression, not a {type(definition.definition)}" + ) + + +def _check_no_reference_to_other_port_field(definition: PortFieldDefinition) -> None: + if definition.definition.port_field_terms: + raise ValueError("Port definition cannot reference another port field.") + + +def _check_no_component_associated_variable_or_parameter( + definition: PortFieldDefinition, +) -> None: + for term in definition.definition.terms.values(): + if term.component_id: + raise ValueError( + "Port definition must not contain a variable associated to a component." + ) + visit(term.coefficient, _PortFieldExpressionChecker()) + visit(definition.definition.constant, _PortFieldExpressionChecker()) diff --git a/src/andromede/model/probability_law.py b/src/andromede/model/probability_law.py deleted file mode 100644 index 62e3dc6a..00000000 --- a/src/andromede/model/probability_law.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -""" -Describes probability distributions used in the models -""" - -from abc import ABC -from dataclasses import dataclass -from typing import List - -import numpy as np - -from andromede.expression.expression import ExpressionNode - - -class AbstractProbabilityLaw(ABC): - def get_sample(self, size: int) -> List[float]: - return NotImplemented - - -@dataclass(frozen=True) -class Normal(AbstractProbabilityLaw): - mean: ExpressionNode - standard_deviation: ExpressionNode - - def get_sample(self, size: int) -> List[float]: - return NotImplemented - - -@dataclass(frozen=True) -class Uniform(AbstractProbabilityLaw): - lower_bound: ExpressionNode - upper_bound: ExpressionNode - - def get_sample(self, size: int) -> List[float]: - return NotImplemented - - -@dataclass(frozen=True) -class UniformIntegers(AbstractProbabilityLaw): - lower_bound: ExpressionNode - upper_bound: ExpressionNode - - def get_sample(self, size: int) -> List[float]: - return NotImplemented diff --git a/src/andromede/model/resolve_library.py b/src/andromede/model/resolve_library.py index 546acdbd..22820e2e 100644 --- a/src/andromede/model/resolve_library.py +++ b/src/andromede/model/resolve_library.py @@ -9,10 +9,16 @@ # SPDX-License-Identifier: MPL-2.0 # # This file is part of the Antares project. -from typing import Dict, List, Optional +from typing import Dict, List, Optional, TypedDict, Union -from andromede.expression import ExpressionNode +# from andromede.expression import ExpressionNode +from andromede.expression.expression import ExpressionNode from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression import ( + LinearExpression, + StandaloneConstraint, + wrap_in_linear_expr_if_present, +) from andromede.expression.parsing.parse_expression import ( ModelIdentifiers, parse_expression, @@ -123,7 +129,7 @@ def _to_parameter(param: InputParameter) -> Parameter: def _to_expression_if_present( expr: Optional[str], identifiers: ModelIdentifiers -) -> Optional[ExpressionNode]: +) -> Optional[LinearExpression]: if not expr: return None return parse_expression(expr, identifiers) @@ -136,18 +142,39 @@ def _to_variable(var: InputVariable, identifiers: ModelIdentifiers) -> Variable: var.variable_type ], structure=IndexingStructure(var.time_dependent, var.scenario_dependent), - lower_bound=_to_expression_if_present(var.lower_bound, identifiers), - upper_bound=_to_expression_if_present(var.upper_bound, identifiers), + lower_bound=wrap_in_linear_expr_if_present( + _to_expression_if_present(var.lower_bound, identifiers) + ), + upper_bound=wrap_in_linear_expr_if_present( + _to_expression_if_present(var.upper_bound, identifiers) + ), context=ProblemContext.OPERATIONAL, ) +# Used only for mypy +class ConstraintKwargs(TypedDict, total=False): + name: str + expression_init: Union[ExpressionNode, LinearExpression, StandaloneConstraint] + lower_bound: LinearExpression + upper_bound: LinearExpression + + def _to_constraint( constraint: InputConstraint, identifiers: ModelIdentifiers ) -> Constraint: - return Constraint( - name=constraint.name, - expression=parse_expression(constraint.expression, identifiers), - lower_bound=_to_expression_if_present(constraint.lower_bound, identifiers), - upper_bound=_to_expression_if_present(constraint.upper_bound, identifiers), + kwargs: ConstraintKwargs = { + "name": constraint.name, + "expression_init": parse_expression(constraint.expression, identifiers), + } + lb = wrap_in_linear_expr_if_present( + _to_expression_if_present(constraint.lower_bound, identifiers) + ) + ub = wrap_in_linear_expr_if_present( + _to_expression_if_present(constraint.upper_bound, identifiers) ) + if lb is not None: + kwargs["lower_bound"] = lb + if ub is not None: + kwargs["upper_bound"] = ub + return Constraint(**kwargs) diff --git a/src/andromede/model/variable.py b/src/andromede/model/variable.py index 109fbd9e..a1d94447 100644 --- a/src/andromede/model/variable.py +++ b/src/andromede/model/variable.py @@ -13,11 +13,19 @@ from dataclasses import dataclass from typing import Any, Optional -from andromede.expression import ExpressionNode, literal -from andromede.expression.degree import is_constant -from andromede.expression.equality import expressions_equal_if_present +from andromede.expression.expression import literal from andromede.expression.indexing_structure import IndexingStructure -from andromede.model.common import ProblemContext, ValueType +from andromede.expression.linear_expression import ( + LinearExpression, + linear_expressions_equal_if_present, + wrap_in_linear_expr, + wrap_in_linear_expr_if_present, +) +from andromede.model.common import ( + ProblemContext, + ValueOrExprNodeOrLinearExpr, + ValueType, +) @dataclass @@ -28,15 +36,15 @@ class Variable: name: str data_type: ValueType - lower_bound: Optional[ExpressionNode] - upper_bound: Optional[ExpressionNode] + lower_bound: Optional[LinearExpression] + upper_bound: Optional[LinearExpression] structure: IndexingStructure context: ProblemContext def __post_init__(self) -> None: - if self.lower_bound and not is_constant(self.lower_bound): + if self.lower_bound and not self.lower_bound.is_constant(): raise ValueError("Lower bounds of variables must be constant") - if self.upper_bound and not is_constant(self.upper_bound): + if self.upper_bound and not self.upper_bound.is_constant(): raise ValueError("Upper bounds of variables must be constant") def __eq__(self, other: Any) -> bool: @@ -45,21 +53,26 @@ def __eq__(self, other: Any) -> bool: return ( self.name == other.name and self.data_type == other.data_type - and expressions_equal_if_present(self.lower_bound, other.lower_bound) - and expressions_equal_if_present(self.upper_bound, other.upper_bound) + and linear_expressions_equal_if_present(self.lower_bound, other.lower_bound) + and linear_expressions_equal_if_present(self.upper_bound, other.upper_bound) and self.structure == other.structure ) def int_variable( name: str, - lower_bound: Optional[ExpressionNode] = None, - upper_bound: Optional[ExpressionNode] = None, + lower_bound: Optional[ValueOrExprNodeOrLinearExpr] = None, + upper_bound: Optional[ValueOrExprNodeOrLinearExpr] = None, structure: IndexingStructure = IndexingStructure(True, True), context: ProblemContext = ProblemContext.OPERATIONAL, ) -> Variable: return Variable( - name, ValueType.INTEGER, lower_bound, upper_bound, structure, context + name, + ValueType.INTEGER, + wrap_in_linear_expr_if_present(lower_bound), + wrap_in_linear_expr_if_present(upper_bound), + structure, + context, ) @@ -68,14 +81,28 @@ def bool_var( structure: IndexingStructure = IndexingStructure(True, True), context: ProblemContext = ProblemContext.OPERATIONAL, ) -> Variable: - return Variable(name, ValueType.BOOL, literal(0), literal(1), structure, context) + return Variable( + name, + ValueType.BOOL, + wrap_in_linear_expr(literal(0)), + wrap_in_linear_expr(literal(1)), + structure, + context, + ) def float_variable( name: str, - lower_bound: Optional[ExpressionNode] = None, - upper_bound: Optional[ExpressionNode] = None, + lower_bound: Optional[ValueOrExprNodeOrLinearExpr] = None, + upper_bound: Optional[ValueOrExprNodeOrLinearExpr] = None, structure: IndexingStructure = IndexingStructure(True, True), context: ProblemContext = ProblemContext.OPERATIONAL, ) -> Variable: - return Variable(name, ValueType.FLOAT, lower_bound, upper_bound, structure, context) + return Variable( + name, + ValueType.FLOAT, + wrap_in_linear_expr_if_present(lower_bound), + wrap_in_linear_expr_if_present(upper_bound), + structure, + context, + ) diff --git a/src/andromede/simulation/__init__.py b/src/andromede/simulation/__init__.py index 23e0a855..94cac7d3 100644 --- a/src/andromede/simulation/__init__.py +++ b/src/andromede/simulation/__init__.py @@ -14,7 +14,8 @@ BendersDecomposedProblem, build_benders_decomposed_problem, ) -from .optimization import BlockBorderManagement, OptimizationProblem, build_problem +from .optimization import OptimizationProblem, build_problem +from .optimization_context import BlockBorderManagement from .output_values import BendersSolution, OutputValues from .runner import BendersRunner, MergeMPSRunner from .strategy import MergedProblemStrategy, ModelSelectionStrategy diff --git a/src/andromede/simulation/benders_decomposed.py b/src/andromede/simulation/benders_decomposed.py index 08d718f1..f5ff23a0 100644 --- a/src/andromede/simulation/benders_decomposed.py +++ b/src/andromede/simulation/benders_decomposed.py @@ -18,25 +18,20 @@ import pathlib from typing import Any, Dict, List, Optional -from andromede.simulation.optimization import ( - BlockBorderManagement, - OptimizationProblem, - build_problem, -) -from andromede.simulation.output_values import ( +from andromede.study.data import DataBase +from andromede.study.network import Network +from andromede.utils import read_json, serialize, serialize_json + +from .optimization import OptimizationProblem, build_problem +from .optimization_context import BlockBorderManagement +from .output_values import ( BendersDecomposedSolution, BendersMergedSolution, BendersSolution, ) -from andromede.simulation.runner import BendersRunner, MergeMPSRunner -from andromede.simulation.strategy import ( - InvestmentProblemStrategy, - OperationalProblemStrategy, -) -from andromede.simulation.time_block import TimeBlock -from andromede.study.data import DataBase -from andromede.study.network import Network -from andromede.utils import read_json, serialize, serialize_json +from .runner import BendersRunner, MergeMPSRunner +from .strategy import InvestmentProblemStrategy, OperationalProblemStrategy +from .time_block import TimeBlock class BendersDecomposedProblem: diff --git a/src/andromede/simulation/linear_expression.py b/src/andromede/simulation/linear_expression.py deleted file mode 100644 index 344a0988..00000000 --- a/src/andromede/simulation/linear_expression.py +++ /dev/null @@ -1,418 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -""" -Specific modelling for "instantiated" linear expressions, -with only variables and literal coefficients. -""" -from dataclasses import dataclass, field -from typing import Callable, Dict, List, Optional, TypeVar, Union - -from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.scenario_operator import ScenarioOperator -from andromede.expression.time_operator import TimeAggregator, TimeOperator -from andromede.model.model import PortFieldId - -T = TypeVar("T") - -EPS = 10 ** (-16) - - -def is_close_abs(value: float, other_value: float, eps: float) -> bool: - return abs(value - other_value) < eps - - -def is_zero(value: float) -> bool: - return is_close_abs(value, 0, EPS) - - -def is_one(value: float) -> bool: - return is_close_abs(value, 1, EPS) - - -def is_minus_one(value: float) -> bool: - return is_close_abs(value, -1, EPS) - - -@dataclass(frozen=True) -class TermKey: - - """ - Utility class to provide key for a term that contains all term information except coefficient - """ - - component_id: str - variable_name: str - time_operator: Optional[TimeOperator] - time_aggregator: Optional[TimeAggregator] - scenario_operator: Optional[ScenarioOperator] - - -@dataclass(frozen=True) -class Term: - """ - One term in a linear expression: for example the "10x" par in "10x + 5y + 5" - - Args: - coefficient: the coefficient for that term, for example "10" in "10x" - variable_name: the name of the variable, for example "x" in "10x" - """ - - coefficient: float - component_id: str - variable_name: str - structure: IndexingStructure = field( - default=IndexingStructure(time=True, scenario=True) - ) - time_operator: Optional[TimeOperator] = None - time_aggregator: Optional[TimeAggregator] = None - scenario_operator: Optional[ScenarioOperator] = None - - # TODO: It may be useful to define __add__, __sub__, etc on terms, which should return a linear expression ? - - def is_zero(self) -> bool: - return is_zero(self.coefficient) - - def str_for_coeff(self) -> str: - str_for_coeff = "" - if is_one(self.coefficient): - str_for_coeff = "+" - elif is_minus_one(self.coefficient): - str_for_coeff = "-" - else: - str_for_coeff = "{:+g}".format(self.coefficient) - return str_for_coeff - - def __str__(self) -> str: - # Useful for debugging tests - result = self.str_for_coeff() + str(self.variable_name) - if self.time_operator is not None: - result += f".{str(self.time_operator)}" - if self.time_aggregator is not None: - result += f".{str(self.time_aggregator)}" - if self.scenario_operator is not None: - result += f".{str(self.scenario_operator)}" - return result - - def number_of_instances(self) -> int: - if self.time_aggregator is not None: - return self.time_aggregator.size() - else: - if self.time_operator is not None: - return self.time_operator.size() - else: - return 1 - - -def generate_key(term: Term) -> TermKey: - return TermKey( - term.component_id, - term.variable_name, - term.time_operator, - term.time_aggregator, - term.scenario_operator, - ) - - -def _merge_dicts( - lhs: Dict[TermKey, Term], - rhs: Dict[TermKey, Term], - merge_func: Callable[[Term, Term], Term], - neutral: float, -) -> Dict[TermKey, Term]: - res = {} - for k, v in lhs.items(): - res[k] = merge_func( - v, - rhs.get( - k, - Term( - neutral, - v.component_id, - v.variable_name, - v.structure, - v.time_operator, - v.time_aggregator, - v.scenario_operator, - ), - ), - ) - for k, v in rhs.items(): - if k not in lhs: - res[k] = merge_func( - Term( - neutral, - v.component_id, - v.variable_name, - v.structure, - v.time_operator, - v.time_aggregator, - v.scenario_operator, - ), - v, - ) - return res - - -def _merge_is_possible(lhs: Term, rhs: Term) -> None: - if lhs.component_id != rhs.component_id or lhs.variable_name != rhs.variable_name: - raise ValueError("Cannot merge terms for different variables") - if ( - lhs.time_operator != rhs.time_operator - or lhs.time_aggregator != rhs.time_aggregator - or lhs.scenario_operator != rhs.scenario_operator - ): - raise ValueError("Cannot merge terms with different operators") - if lhs.structure != rhs.structure: - raise ValueError("Cannot merge terms with different structures") - - -def _add_terms(lhs: Term, rhs: Term) -> Term: - _merge_is_possible(lhs, rhs) - return Term( - lhs.coefficient + rhs.coefficient, - lhs.component_id, - lhs.variable_name, - lhs.structure, - lhs.time_operator, - lhs.time_aggregator, - lhs.scenario_operator, - ) - - -def _substract_terms(lhs: Term, rhs: Term) -> Term: - _merge_is_possible(lhs, rhs) - return Term( - lhs.coefficient - rhs.coefficient, - lhs.component_id, - lhs.variable_name, - lhs.structure, - lhs.time_operator, - lhs.time_aggregator, - lhs.scenario_operator, - ) - - -class LinearExpression: - """ - Represents a linear expression with respect to variable names, for example 10x + 5y + 2. - - Operators may be used for construction. - - Args: - terms: the list of variable terms, for example 10x and 5y in "10x + 5y + 2". - constant: the constant term, for example 2 in "10x + 5y + 2" - - Examples: - Operators may be used for construction: - - >>> LinearExpression([], 10) + LinearExpression([Term(10, "x")], 0) - LinearExpression([Term(10, "x")], 10) - """ - - terms: Dict[TermKey, Term] - constant: float - - def __init__( - self, - terms: Optional[Union[Dict[TermKey, Term], List[Term]]] = None, - constant: Optional[float] = None, - ) -> None: - self.constant = 0 - self.terms = {} - - if constant is not None: - # += b - self.constant = constant - if terms is not None: - # Allows to give two different syntax in the constructor: - # - List[Term] is natural - # - Dict[str, Term] is useful when constructing a linear expression from the terms of another expression - if isinstance(terms, dict): - for term_key, term in terms.items(): - if not term.is_zero(): - self.terms[term_key] = term - elif isinstance(terms, list): - for term in terms: - if not term.is_zero(): - self.terms[generate_key(term)] = term - else: - raise TypeError( - f"Terms must be either of type Dict[str, Term] or List[Term], whereas {terms} is of type {type(terms)}" - ) - - def is_zero(self) -> bool: - return len(self.terms) == 0 and is_zero(self.constant) - - def str_for_constant(self) -> str: - if is_zero(self.constant): - return "" - else: - return "{:+g}".format(self.constant) - - def __str__(self) -> str: - # Useful for debugging tests - result = "" - if self.is_zero(): - result += "0" - else: - for term in self.terms.values(): - result += str(term) - - result += self.str_for_constant() - - return result - - def __eq__(self, rhs: object) -> bool: - return ( - isinstance(rhs, LinearExpression) - and is_close_abs(self.constant, rhs.constant, EPS) - and self.terms - == rhs.terms # /!\ There may be float equality comparison in the terms values - ) - - def __iadd__(self, rhs: "LinearExpression") -> "LinearExpression": - if not isinstance(rhs, LinearExpression): - return NotImplemented - self.constant += rhs.constant - aggregated_terms = _merge_dicts(self.terms, rhs.terms, _add_terms, 0) - self.terms = aggregated_terms - self.remove_zeros_from_terms() - return self - - def __add__(self, rhs: "LinearExpression") -> "LinearExpression": - result = LinearExpression() - result += self - result += rhs - return result - - def __isub__(self, rhs: "LinearExpression") -> "LinearExpression": - if not isinstance(rhs, LinearExpression): - return NotImplemented - self.constant -= rhs.constant - aggregated_terms = _merge_dicts(self.terms, rhs.terms, _substract_terms, 0) - self.terms = aggregated_terms - self.remove_zeros_from_terms() - return self - - def __sub__(self, rhs: "LinearExpression") -> "LinearExpression": - result = LinearExpression() - result += self - result -= rhs - return result - - def __neg__(self) -> "LinearExpression": - result = LinearExpression() - result -= self - return result - - def __imul__(self, rhs: "LinearExpression") -> "LinearExpression": - if not isinstance(rhs, LinearExpression): - return NotImplemented - - if self.terms and rhs.terms: - raise ValueError("Cannot multiply two non constant expression") - else: - if self.terms: - left_expr = self - const_expr = rhs - else: - # It is possible that both expr are constant - left_expr = rhs - const_expr = self - if is_close_abs(const_expr.constant, 0, EPS): - return LinearExpression() - elif is_close_abs(const_expr.constant, 1, EPS): - _copy_expression(left_expr, self) - else: - left_expr.constant *= const_expr.constant - for term_key, term in left_expr.terms.items(): - left_expr.terms[term_key] = Term( - term.coefficient * const_expr.constant, - term.component_id, - term.variable_name, - term.structure, - term.time_operator, - term.time_aggregator, - term.scenario_operator, - ) - _copy_expression(left_expr, self) - return self - - def __mul__(self, rhs: "LinearExpression") -> "LinearExpression": - result = LinearExpression() - result += self - result *= rhs - return result - - def __itruediv__(self, rhs: "LinearExpression") -> "LinearExpression": - if not isinstance(rhs, LinearExpression): - return NotImplemented - - if rhs.terms: - raise ValueError("Cannot divide by a non constant expression") - else: - if is_close_abs(rhs.constant, 0, EPS): - raise ZeroDivisionError("Cannot divide expression by zero") - elif is_close_abs(rhs.constant, 1, EPS): - return self - else: - self.constant /= rhs.constant - for term_key, term in self.terms.items(): - self.terms[term_key] = Term( - term.coefficient / rhs.constant, - term.component_id, - term.variable_name, - term.structure, - term.time_operator, - term.time_aggregator, - term.scenario_operator, - ) - return self - - def __truediv__(self, rhs: "LinearExpression") -> "LinearExpression": - result = LinearExpression() - result += self - result /= rhs - - return result - - def remove_zeros_from_terms(self) -> None: - # TODO: Not optimized, checks could be done directly when doing operations on self.linear_term to avoid copies - for term_key, term in self.terms.copy().items(): - if is_close_abs(term.coefficient, 0, EPS): - del self.terms[term_key] - - def is_valid(self) -> bool: - nb_instances = None - for term in self.terms.values(): - term_instances = term.number_of_instances() - if nb_instances is None: - nb_instances = term_instances - else: - if term_instances != nb_instances: - raise ValueError( - "The terms of the linear expression {self} do not have the same number of instances" - ) - return True - - def number_of_instances(self) -> int: - if self.is_valid(): - # All terms have the same number of instances, just pick one - return self.terms[next(iter(self.terms))].number_of_instances() - else: - raise ValueError(f"{self} is not a valid linear expression") - - -def _copy_expression(src: LinearExpression, dst: LinearExpression) -> None: - dst.terms = src.terms - dst.constant = src.constant diff --git a/src/andromede/simulation/linear_expression_resolver.py b/src/andromede/simulation/linear_expression_resolver.py new file mode 100644 index 00000000..7bfbf2ca --- /dev/null +++ b/src/andromede/simulation/linear_expression_resolver.py @@ -0,0 +1,126 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +from dataclasses import dataclass +from typing import Dict, List + +import ortools.linear_solver.pywraplp as lp + +from andromede.expression.evaluate_parameters import ( + get_time_ids_from_instances_index, + resolve_coefficient, +) +from andromede.expression.indexing_structure import RowIndex +from andromede.expression.linear_expression import LinearExpression, Term +from andromede.expression.scenario_operator import Expectation +from andromede.expression.time_operator import TimeShift +from andromede.expression.value_provider import ( + TimeScenarioIndex, + TimeScenarioIndices, + ValueProvider, +) + +from .optimization_context import OptimizationContext +from .resolved_linear_expression import ResolvedLinearExpression, ResolvedTerm + + +@dataclass +class LinearExpressionResolver: + context: OptimizationContext + value_provider: ValueProvider + + def resolve( + self, expression: LinearExpression, row_id: RowIndex + ) -> ResolvedLinearExpression: + resolved_terms = [] + for term in expression.terms.values(): + # Here, the value provide is used only to evaluate possible time operator args if the term has one + resolved_variables = self.resolve_variables(term, row_id) + + # TODO: Contrary to the time aggregator that does a sum which is the default behaviour when append resolved terms, expectation performs an averaging, so weights must be included in coefficients. We feel here that we could generalize time and scenario aggregation over variables with more general operators, the following lines are very specific to expectation with same weights over all scenarios + weight: float = 1 + if isinstance(term.scenario_aggregator, Expectation): + weight = 1 / self.value_provider.scenarios() + + for ts_id, lp_variable in resolved_variables.items(): + # TODO: Could we check in which case coeff resolution leads to the same result for each element in the for loop ? When there is only a literal, etc, etc ? + resolved_coeff = resolve_coefficient( + term.coefficient, + self.value_provider, + RowIndex(ts_id.time, ts_id.scenario), + ) + resolved_terms.append( + ResolvedTerm(weight * resolved_coeff, lp_variable) + ) + + resolved_constant = resolve_coefficient( + expression.constant, self.value_provider, row_id + ) + return ResolvedLinearExpression(resolved_terms, resolved_constant) + + def resolve_constant_expr( + self, expression: LinearExpression, row_id: RowIndex + ) -> float: + if not expression.is_constant(): + raise ValueError(f"{str(self)} is not a constant expression") + return resolve_coefficient(expression.constant, self.value_provider, row_id) + + def resolve_variables( + self, term: Term, row_id: RowIndex + ) -> Dict[TimeScenarioIndex, lp.Variable]: + solver_vars = {} + operator_ts_ids = self._row_id_to_term_time_scenario_id(term, row_id) + for time in operator_ts_ids.time_indices: + for scenario in operator_ts_ids.scenario_indices: + solver_vars[ + TimeScenarioIndex(time, scenario) + ] = self.context.get_component_variable( + time, + scenario, + term.component_id, + term.variable_name, + # At term build time, no information on the variable structure is known, we use it now + self.context.network.get_component(term.component_id) + .model.variables[term.variable_name] + .structure, + ) + return solver_vars + + def _row_id_to_term_time_scenario_id( + self, term: Term, row_id: RowIndex + ) -> TimeScenarioIndices: + operator_time_ids = self._compute_operator_time_ids(term, row_id) + + operator_scenario_ids = self._compute_operator_scenario_ids(term, row_id) + return TimeScenarioIndices(operator_time_ids, operator_scenario_ids) + + def _compute_operator_scenario_ids(self, term: Term, row_id: RowIndex) -> List[int]: + if term.scenario_aggregator: + operator_scenario_ids = list(range(self.context.scenarios)) + else: + operator_scenario_ids = [row_id.scenario] + return operator_scenario_ids + + def _compute_operator_time_ids(self, term: Term, row_id: RowIndex) -> List[int]: + if not term.time_operator and not term.time_aggregator: + operator_time_ids = [row_id.time] + elif term.time_operator: + operator_time_ids = get_time_ids_from_instances_index( + term.time_operator.time_ids, self.value_provider, row_id + ) + if isinstance(term.time_operator, TimeShift): + operator_time_ids = [ + row_id.time + time_id for time_id in operator_time_ids + ] + else: # Case time_aggregator but no time_operator i.e. sum over whole block + operator_time_ids = list(range(self.context.block_length())) + return operator_time_ids diff --git a/src/andromede/simulation/linearize.py b/src/andromede/simulation/linearize.py deleted file mode 100644 index 9fe6738a..00000000 --- a/src/andromede/simulation/linearize.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -import dataclasses -from dataclasses import dataclass -from typing import Optional - -import andromede.expression.scenario_operator -import andromede.expression.time_operator -from andromede.expression.evaluate import ValueProvider -from andromede.expression.evaluate_parameters import get_time_ids_from_instances_index -from andromede.expression.expression import ( - ComparisonNode, - ComponentParameterNode, - ComponentVariableNode, - ExpressionNode, - LiteralNode, - ParameterNode, - PortFieldAggregatorNode, - PortFieldNode, - ScenarioOperatorNode, - TimeAggregatorNode, - TimeOperatorNode, - VariableNode, -) -from andromede.expression.indexing import IndexingStructureProvider -from andromede.expression.visitor import ExpressionVisitorOperations, T, visit -from andromede.simulation.linear_expression import LinearExpression, Term, generate_key - - -@dataclass(frozen=True) -class LinearExpressionBuilder(ExpressionVisitorOperations[LinearExpression]): - """ - Reduces a generic expression to a linear expression. - - Parameters should have been evaluated first. - """ - - structure_provider: IndexingStructureProvider - value_provider: Optional[ValueProvider] = None - - def literal(self, node: LiteralNode) -> LinearExpression: - return LinearExpression([], node.value) - - def comparison(self, node: ComparisonNode) -> LinearExpression: - raise ValueError("Linear expression cannot contain a comparison operator.") - - def variable(self, node: VariableNode) -> LinearExpression: - raise ValueError( - "Variables need to be associated with their component ID before linearization." - ) - - def parameter(self, node: ParameterNode) -> LinearExpression: - raise ValueError("Parameters must be evaluated before linearization.") - - def comp_variable(self, node: ComponentVariableNode) -> LinearExpression: - return LinearExpression( - [ - Term( - 1, - node.component_id, - node.name, - self.structure_provider.get_component_variable_structure( - node.component_id, node.name - ), - ) - ], - 0, - ) - - def comp_parameter(self, node: ComponentParameterNode) -> LinearExpression: - raise ValueError("Parameters must be evaluated before linearization.") - - def time_operator(self, node: TimeOperatorNode) -> LinearExpression: - if self.value_provider is None: - raise ValueError( - "A value provider must be specified to linearize a time operator node. This is required in order to evaluate the value of potential parameters used to specified the time ids on which the time operator applies." - ) - - operand_expr = visit(node.operand, self) - time_operator_cls = getattr(andromede.expression.time_operator, node.name) - time_ids = get_time_ids_from_instances_index( - node.instances_index, self.value_provider - ) - - result_terms = {} - for term in operand_expr.terms.values(): - term_with_operator = dataclasses.replace( - term, time_operator=time_operator_cls(time_ids) - ) - result_terms[generate_key(term_with_operator)] = term_with_operator - - # TODO: How can we apply a shift on a parameter ? It seems impossible for now as parameters must already be evaluated... - result_expr = LinearExpression(result_terms, operand_expr.constant) - return result_expr - - def time_aggregator(self, node: TimeAggregatorNode) -> LinearExpression: - # TODO: Very similar to time_operator, may be factorized - operand_expr = visit(node.operand, self) - time_aggregator_cls = getattr(andromede.expression.time_operator, node.name) - result_terms = {} - for term in operand_expr.terms.values(): - term_with_operator = dataclasses.replace( - term, time_aggregator=time_aggregator_cls(node.stay_roll) - ) - result_terms[generate_key(term_with_operator)] = term_with_operator - - result_expr = LinearExpression(result_terms, operand_expr.constant) - return result_expr - - def scenario_operator(self, node: ScenarioOperatorNode) -> LinearExpression: - scenario_operator_cls = getattr( - andromede.expression.scenario_operator, node.name - ) - if scenario_operator_cls.degree() > 1: - raise ValueError( - f"Cannot linearize expression with a non-linear operator: {scenario_operator_cls.__name__}" - ) - - operand_expr = visit(node.operand, self) - result_terms = {} - for term in operand_expr.terms.values(): - term_with_operator = dataclasses.replace( - term, scenario_operator=scenario_operator_cls() - ) - result_terms[generate_key(term_with_operator)] = term_with_operator - - result_expr = LinearExpression(result_terms, operand_expr.constant) - return result_expr - - def port_field(self, node: PortFieldNode) -> LinearExpression: - raise ValueError("Port fields must be replaced before linearization.") - - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> LinearExpression: - raise ValueError( - "Port fields aggregators must be replaced before linearization." - ) - - -def linearize_expression( - expression: ExpressionNode, - structure_provider: IndexingStructureProvider, - value_provider: Optional[ValueProvider] = None, -) -> LinearExpression: - return visit( - expression, LinearExpressionBuilder(structure_provider, value_provider) - ) diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 7542c0e0..755116ff 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -16,418 +16,46 @@ """ import math -from abc import ABC, abstractmethod from dataclasses import dataclass -from enum import Enum -from typing import Dict, Iterable, List, Optional import ortools.linear_solver.pywraplp as lp -from andromede.expression import ( - EvaluationVisitor, - ExpressionNode, - ParameterValueProvider, - ValueProvider, - resolve_parameters, - visit, -) -from andromede.expression.context_adder import add_component_context -from andromede.expression.indexing import IndexingStructureProvider, compute_indexation +from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.port_resolver import PortFieldKey, resolve_port -from andromede.expression.scenario_operator import Expectation -from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum +from andromede.expression.linear_expression import LinearExpression, RowIndex from andromede.model.common import ValueType from andromede.model.constraint import Constraint from andromede.model.model import PortFieldId -from andromede.simulation.linear_expression import LinearExpression, Term -from andromede.simulation.linearize import linearize_expression -from andromede.simulation.strategy import MergedProblemStrategy, ModelSelectionStrategy -from andromede.simulation.time_block import TimeBlock from andromede.study.data import DataBase from andromede.study.network import Component, Network -from andromede.utils import get_or_add - - -@dataclass(eq=True, frozen=True) -class TimestepComponentVariableKey: - """ - Identifies the solver variable for one timestep and one component variable. - """ - - component_id: str - variable_name: str - block_timestep: Optional[int] = None - scenario: Optional[int] = None - - -def _get_parameter_value( - context: "OptimizationContext", - block_timestep: int, - scenario: int, - component_id: str, - name: str, -) -> float: - data = context.database.get_data(component_id, name) - absolute_timestep = context.block_timestep_to_absolute_timestep(block_timestep) - return data.get_value(absolute_timestep, scenario) - - -# TODO: Maybe add the notion of constant parameter in the model -# TODO : And constant over scenarios ? -def _parameter_is_constant_over_time( - component: Component, - name: str, - context: "OptimizationContext", - block_timestep: int, - scenario: int, -) -> bool: - data = context.database.get_data(component.id, name) - return data.get_value(block_timestep, scenario) == IndexingStructure( - time=False, scenario=False - ) - - -class TimestepValueProvider(ABC): - """ - Interface which provides numerical values for individual timesteps. - """ - - @abstractmethod - def get_value(self, block_timestep: int, scenario: int) -> float: - raise NotImplementedError() - - -def _make_value_provider( - context: "OptimizationContext", - block_timestep: int, - scenario: int, - component: Component, -) -> ValueProvider: - """ - Create a value provider which takes its values from - the parameter values as defined in the network data. - - Cannot evaluate expressions which contain variables. - """ - - class Provider(ValueProvider): - def get_component_variable_value(self, component_id: str, name: str) -> float: - raise NotImplementedError( - "Cannot provide variable value at problem build time." - ) - - def get_component_parameter_value(self, component_id: str, name: str) -> float: - return _get_parameter_value( - context, block_timestep, scenario, component_id, name - ) - - def get_variable_value(self, name: str) -> float: - raise NotImplementedError( - "Cannot provide variable value at problem build time." - ) - - def get_parameter_value(self, name: str) -> float: - raise ValueError( - "Parameter must be associated to its component before resolution." - ) - - def parameter_is_constant_over_time(self, name: str) -> bool: - return not component.model.parameters[name].structure.time - - return Provider() - - -@dataclass(frozen=True) -class ExpressionTimestepValueProvider(TimestepValueProvider): - context: "OptimizationContext" - component: Component - expression: ExpressionNode - # OptimizationContext has knowledge of the block, so that get_value only needs block_timestep and scenario to get the correct data value - - def get_value(self, block_timestep: int, scenario: int) -> float: - param_value_provider = _make_value_provider( - self.context, block_timestep, scenario, self.component - ) - visitor = EvaluationVisitor(param_value_provider) - return visit(self.expression, visitor) - - -def _make_parameter_value_provider( - context: "OptimizationContext", - block_timestep: int, - scenario: int, -) -> ParameterValueProvider: - """ - A value provider which takes its values from - the parameter values as defined in the network data. - - Cannot evaluate expressions which contain variables. - """ - - class Provider(ParameterValueProvider): - def get_component_parameter_value(self, component_id: str, name: str) -> float: - return _get_parameter_value( - context, block_timestep, scenario, component_id, name - ) - - def get_parameter_value(self, name: str) -> float: - raise ValueError( - "Parameters should have been associated with their component before resolution." - ) - - return Provider() - - -def _make_data_structure_provider( - network: Network, component: Component -) -> IndexingStructureProvider: - """ - Retrieve information in data structure (parameter and variable) from the model - """ - - class Provider(IndexingStructureProvider): - def get_component_variable_structure( - self, component_id: str, name: str - ) -> IndexingStructure: - return network.get_component(component_id).model.variables[name].structure - - def get_component_parameter_structure( - self, component_id: str, name: str - ) -> IndexingStructure: - return network.get_component(component_id).model.parameters[name].structure - - def get_parameter_structure(self, name: str) -> IndexingStructure: - return component.model.parameters[name].structure - - def get_variable_structure(self, name: str) -> IndexingStructure: - return component.model.variables[name].structure - - return Provider() - - -@dataclass(frozen=True) -class ComponentContext: - """ - Helper class to fill the optimization problem with component-related equations and variables. - """ - - opt_context: "OptimizationContext" - component: Component - - def get_values(self, expression: ExpressionNode) -> TimestepValueProvider: - """ - The returned value provider will evaluate the provided expression. - """ - return ExpressionTimestepValueProvider( - self.opt_context, self.component, expression - ) - - def add_variable( - self, - block_timestep: int, - scenario: int, - model_var_name: str, - variable: lp.Variable, - ) -> None: - self.opt_context.register_component_variable( - block_timestep, scenario, self.component.id, model_var_name, variable - ) - - def get_variable( - self, block_timestep: int, scenario: int, variable_name: str - ) -> lp.Variable: - return self.opt_context.get_component_variable( - block_timestep, - scenario, - self.component.id, - variable_name, - self.component.model.variables[variable_name].structure, - ) - - def linearize_expression( - self, - block_timestep: int, - scenario: int, - expression: ExpressionNode, - ) -> LinearExpression: - parameters_valued_provider = _make_parameter_value_provider( - self.opt_context, block_timestep, scenario - ) - evaluated_expr = resolve_parameters(expression, parameters_valued_provider) - - value_provider = _make_value_provider( - self.opt_context, block_timestep, scenario, self.component - ) - structure_provider = _make_data_structure_provider( - self.opt_context.network, self.component - ) - - return linearize_expression(evaluated_expr, structure_provider, value_provider) - - -class BlockBorderManagement(Enum): - """ - Class to specify the way of handling the time horizon (or time block) border. - - IGNORE_OUT_OF_FRAME: Ignore terms in constraints that lead to out of horizon data - - CYCLE: Consider all timesteps to be specified modulo the horizon length, this is the actual functioning of Antares - """ - - IGNORE_OUT_OF_FRAME = "IGNORE" - CYCLE = "CYCLE" - - -@dataclass -class SolverVariableInfo: - """ - Helper class for constructing the structure file - for Benders solver. It keeps track of the corresponding - column of the variable in the MPS format as well as if it is - present in the objective function or not - """ - - name: str - column_id: int - is_in_objective: bool - - -class OptimizationContext: - """ - Helper class to build the optimization problem. - Maintains some mappings between model and solver objects. - Also provides navigation method in the model (components by node ...). - """ - - def __init__( - self, - network: Network, - database: DataBase, - block: TimeBlock, - scenarios: int, - border_management: BlockBorderManagement, - ): - self._network = network - self._database = database - self._block = block - self._scenarios = scenarios - self._border_management = border_management - self._component_variables: Dict[TimestepComponentVariableKey, lp.Variable] = {} - self._solver_variables: Dict[lp.Variable, SolverVariableInfo] = {} - self._connection_fields_expressions: Dict[ - PortFieldKey, List[ExpressionNode] - ] = {} - - @property - def network(self) -> Network: - return self._network - - @property - def scenarios(self) -> int: - return self._scenarios - - def block_length(self) -> int: - return len(self._block.timesteps) - - @property - def connection_fields_expressions(self) -> Dict[PortFieldKey, List[ExpressionNode]]: - return self._connection_fields_expressions - - # TODO: Need to think about data processing when creating blocks with varying or inequal time steps length (aggregation, sum ?, mean of data ?) - def block_timestep_to_absolute_timestep(self, block_timestep: int) -> int: - return self._block.timesteps[block_timestep] - - @property - def database(self) -> DataBase: - return self._database - - def _manage_border_timesteps(self, timestep: int) -> int: - if self._border_management == BlockBorderManagement.CYCLE: - return timestep % self.block_length() - else: - raise NotImplementedError - - def get_time_indices(self, index_structure: IndexingStructure) -> Iterable[int]: - return range(self.block_length()) if index_structure.time else range(1) - - def get_scenario_indices(self, index_structure: IndexingStructure) -> Iterable[int]: - return range(self.scenarios) if index_structure.scenario else range(1) - - # TODO: API to improve, variable_structure guides which of the indices block_timestep and scenario should be used - def get_component_variable( - self, - block_timestep: int, - scenario: int, - component_id: str, - variable_name: str, - variable_structure: IndexingStructure, - ) -> lp.Variable: - block_timestep = self._manage_border_timesteps(block_timestep) - - # TODO: Improve design, variable_structure defines indexing - if variable_structure.time == False: - block_timestep = 0 - if variable_structure.scenario == False: - scenario = 0 - - return self._component_variables[ - TimestepComponentVariableKey( - component_id, variable_name, block_timestep, scenario - ) - ] - - def get_all_component_variables( - self, - ) -> Dict[TimestepComponentVariableKey, lp.Variable]: - return self._component_variables - - def register_component_variable( - self, - block_timestep: int, - scenario: int, - component_id: str, - variable_name: str, - variable: lp.Variable, - ) -> None: - key = TimestepComponentVariableKey( - component_id, variable_name, block_timestep, scenario - ) - if key not in self._component_variables: - self._solver_variables[variable] = SolverVariableInfo( - variable.name(), len(self._solver_variables), False - ) - self._component_variables[key] = variable - - def get_component_context(self, component: Component) -> ComponentContext: - return ComponentContext(self, component) - - def register_connection_fields_expressions( - self, - component_id: str, - port_name: str, - field_name: str, - expression: ExpressionNode, - ) -> None: - key = PortFieldKey(component_id, PortFieldId(port_name, field_name)) - get_or_add(self._connection_fields_expressions, key, lambda: []).append( - expression - ) +from .linear_expression_resolver import LinearExpressionResolver +from .optimization_context import ( + BlockBorderManagement, + ComponentContext, + OptimizationContext, + make_data_structure_provider, + make_value_provider, +) +from .resolved_linear_expression import ResolvedLinearExpression +from .strategy import MergedProblemStrategy, ModelSelectionStrategy +from .time_block import TimeBlock def _get_indexing( constraint: Constraint, provider: IndexingStructureProvider ) -> IndexingStructure: return ( - compute_indexation(constraint.expression, provider) - or compute_indexation(constraint.lower_bound, provider) - or compute_indexation(constraint.upper_bound, provider) + constraint.expression.compute_indexation(provider) + or constraint.lower_bound.compute_indexation(provider) + or constraint.upper_bound.compute_indexation(provider) ) def _compute_indexing_structure( context: ComponentContext, constraint: Constraint ) -> IndexingStructure: - data_structure_provider = _make_data_structure_provider( + data_structure_provider = make_data_structure_provider( context.opt_context.network, context.component ) constraint_indexing = _get_indexing(constraint, data_structure_provider) @@ -435,19 +63,20 @@ def _compute_indexing_structure( def _instantiate_model_expression( - model_expression: ExpressionNode, + model_expression: LinearExpression, component_id: str, optimization_context: OptimizationContext, -) -> ExpressionNode: +) -> LinearExpression: """ Performs common operations that are necessary on model expressions before their actual use: 1. add component ID for variables and parameters of THIS component 2. replace port fields by their definition """ - with_component = add_component_context(component_id, model_expression) - with_component_and_ports = resolve_port( - with_component, component_id, optimization_context.connection_fields_expressions + # We need to resolve ports before adding component context as binding constraints with ports may involve parameters from the current component + with_ports = model_expression.resolve_port( + component_id, optimization_context.connection_fields_expressions ) + with_component_and_ports = with_ports.add_component_context(component_id) return with_component_and_ports @@ -461,78 +90,56 @@ def _create_constraint( """ constraint_indexing = _compute_indexing_structure(context, constraint) - # Perf: Perform linearization (tree traversing) without timesteps so that we can get the number of instances for the expression (from the time_ids of operators) - linear_expr = context.linearize_expression(0, 0, constraint.expression) - # Will there be cases where instances > 1 ? If not, maybe just a check that get_number_of_instances == 1 is sufficient ? Anyway, the function should be implemented - instances_per_time_step = linear_expr.number_of_instances() + value_provider = make_value_provider(context.opt_context, context.component) + expression_resolver = LinearExpressionResolver(context.opt_context, value_provider) for block_timestep in context.opt_context.get_time_indices(constraint_indexing): for scenario in context.opt_context.get_scenario_indices(constraint_indexing): - linear_expr_at_t = context.linearize_expression( - block_timestep, scenario, constraint.expression + row_id = RowIndex(block_timestep, scenario) + + resolved_expr = expression_resolver.resolve(constraint.expression, row_id) + resolved_lb = expression_resolver.resolve_constant_expr( + constraint.lower_bound, row_id + ) + resolved_ub = expression_resolver.resolve_constant_expr( + constraint.upper_bound, row_id ) - # What happens if there is some time_operator in the bounds ? + + # What happens if there is some time_operator in the bounds ? -> Pb réglé avec le nouveau design ! constraint_data = ConstraintData( name=constraint.name, - lower_bound=context.get_values(constraint.lower_bound).get_value( - block_timestep, scenario - ), - upper_bound=context.get_values(constraint.upper_bound).get_value( - block_timestep, scenario - ), - expression=linear_expr_at_t, - ) - make_constraint( - solver, - context.opt_context, - block_timestep, - scenario, - constraint_data, - instances_per_time_step, + lower_bound=resolved_lb, + upper_bound=resolved_ub, + expression=resolved_expr, ) + make_constraint(solver, row_id, constraint_data) def _create_objective( solver: lp.Solver, opt_context: OptimizationContext, component: Component, - component_context: ComponentContext, - objective_contribution: ExpressionNode, + objective_contribution: LinearExpression, ) -> None: instantiated_expr = _instantiate_model_expression( objective_contribution, component.id, opt_context ) # We have already checked in the model creation that the objective contribution is neither indexed by time nor by scenario - linear_expr = component_context.linearize_expression(0, 0, instantiated_expr) - obj: lp.Objective = solver.Objective() - for term in linear_expr.terms.values(): - # TODO : How to handle the scenario operator in a general manner ? - if isinstance(term.scenario_operator, Expectation): - weight = 1 / opt_context.scenarios - scenario_ids = range(opt_context.scenarios) - else: - weight = 1 - scenario_ids = range(1) - - for scenario in scenario_ids: - solver_vars = _get_solver_vars( - term, - opt_context, - 0, - scenario, - 0, - ) + value_provider = make_value_provider(opt_context, component) + expression_resolver = LinearExpressionResolver(opt_context, value_provider) + resolved_expr = expression_resolver.resolve(instantiated_expr, RowIndex(0, 0)) - for solver_var in solver_vars: - opt_context._solver_variables[solver_var].is_in_objective = True - obj.SetCoefficient( - solver_var, - obj.GetCoefficient(solver_var) + weight * term.coefficient, - ) + obj: lp.Objective = solver.Objective() + for term in resolved_expr.terms: + opt_context._solver_variables[term.variable].is_in_objective = True + obj.SetCoefficient( + term.variable, + obj.GetCoefficient(term.variable) + term.coefficient, + ) # This should have no effect on the optimization - obj.SetOffset(linear_expr.constant + obj.offset()) + obj.SetOffset(resolved_expr.constant + obj.offset()) @dataclass @@ -540,129 +147,35 @@ class ConstraintData: name: str lower_bound: float upper_bound: float - expression: LinearExpression - - -def _get_solver_vars( - term: Term, - context: OptimizationContext, - block_timestep: int, - scenario: int, - instance: int, -) -> List[lp.Variable]: - solver_vars = [] - if isinstance(term.time_aggregator, TimeSum): - if isinstance(term.time_operator, TimeShift): - for time_id in term.time_operator.time_ids: - solver_vars.append( - context.get_component_variable( - block_timestep + time_id, - scenario, - term.component_id, - term.variable_name, - term.structure, - ) - ) - elif isinstance(term.time_operator, TimeEvaluation): - for time_id in term.time_operator.time_ids: - solver_vars.append( - context.get_component_variable( - time_id, - scenario, - term.component_id, - term.variable_name, - term.structure, - ) - ) - else: # time_operator is None, retrieve variable for each time step of the block. What happens if we do x.sum() with x not being indexed by time ? Is there a check that it is a valid expression ? - for time_id in range(context.block_length()): - solver_vars.append( - context.get_component_variable( - block_timestep + time_id, - scenario, - term.component_id, - term.variable_name, - term.structure, - ) - ) - - else: # time_aggregator is None - if isinstance(term.time_operator, TimeShift): - solver_vars.append( - context.get_component_variable( - block_timestep + term.time_operator.time_ids[instance], - scenario, - term.component_id, - term.variable_name, - term.structure, - ) - ) - elif isinstance(term.time_operator, TimeEvaluation): - solver_vars.append( - context.get_component_variable( - term.time_operator.time_ids[instance], - scenario, - term.component_id, - term.variable_name, - term.structure, - ) - ) - else: # time_operator is None - # TODO: horrible tous ces if/else - solver_vars.append( - context.get_component_variable( - block_timestep, - scenario, - term.component_id, - term.variable_name, - term.structure, - ) - ) - return solver_vars + expression: ResolvedLinearExpression def make_constraint( solver: lp.Solver, - context: OptimizationContext, - block_timestep: int, - scenario: int, + row_id: RowIndex, data: ConstraintData, - instances: int, -) -> Dict[str, lp.Constraint]: +) -> lp.Constraint: """ Adds constraint to the solver. """ - solver_constraints = {} - constraint_name = f"{data.name}_t{block_timestep}_s{scenario}" - for instance in range(instances): - if instances > 1: - constraint_name += f"_{instance}" - - solver_constraint: lp.Constraint = solver.Constraint(constraint_name) - constant: float = 0 - for term in data.expression.terms.values(): - solver_vars = _get_solver_vars( - term, - context, - block_timestep, - scenario, - instance, - ) - for solver_var in solver_vars: - coefficient = term.coefficient + solver_constraint.GetCoefficient( - solver_var - ) - solver_constraint.SetCoefficient(solver_var, coefficient) - # TODO: On pourrait aussi faire que l'objet Constraint n'ait pas de terme constant dans son expression et que les constantes soit déjà prises en compte dans les bornes, ça simplifierait le traitement ici - constant += data.expression.constant - solver_constraint.SetBounds( - data.lower_bound - constant, data.upper_bound - constant + constraint_name = f"{data.name}_{str(row_id)}" + + solver_constraint: lp.Constraint = solver.Constraint(constraint_name) + constant: float = 0 + + for term in data.expression.terms: + solver_constraint.SetCoefficient( + term.variable, + term.coefficient + solver_constraint.GetCoefficient(term.variable), ) + constant += data.expression.constant + + solver_constraint.SetBounds( + data.lower_bound - constant, data.upper_bound - constant + ) - # TODO: this dictionary does not make sense, we override the content when there are multiple instances - solver_constraints[constraint_name] = solver_constraint - return solver_constraints + return solver_constraint class OptimizationProblem: @@ -700,8 +213,8 @@ def _register_connection_fields_definitions(self) -> None: ) ) expression_node = port_definition.definition # type: ignore - instantiated_expression = add_component_context( - master_port.component.id, expression_node + instantiated_expression = expression_node.add_component_context( + master_port.component.id ) self.context.register_connection_fields_expressions( component_id=cnx.port1.component.id, @@ -721,6 +234,9 @@ def _create_variables(self) -> None: component_context = self.context.get_component_context(component) model = component.model + value_provider = make_value_provider(self.context, component) + expression_resolver = LinearExpressionResolver(self.context, value_provider) + for model_var in self.strategy.get_variables(model): var_indexing = model_var.structure instantiated_lb_expr = None @@ -750,13 +266,13 @@ def _create_variables(self) -> None: lower_bound = -self.solver.infinity() upper_bound = self.solver.infinity() if instantiated_lb_expr: - lower_bound = component_context.get_values( - instantiated_lb_expr - ).get_value(block_timestep, scenario) + lower_bound = expression_resolver.resolve_constant_expr( + instantiated_lb_expr, RowIndex(block_timestep, scenario) + ) if instantiated_ub_expr: - upper_bound = component_context.get_values( - instantiated_ub_expr - ).get_value(block_timestep, scenario) + upper_bound = expression_resolver.resolve_constant_expr( + instantiated_ub_expr, RowIndex(block_timestep, scenario) + ) scenario_suffix = ( f"_s{scenario}" @@ -817,7 +333,7 @@ def _create_constraints(self) -> None: instantiated_constraint = Constraint( name=f"{component.id}_{constraint.name}", - expression=instantiated_expr, + expression_init=instantiated_expr, lower_bound=instantiated_lb, upper_bound=instantiated_ub, ) @@ -829,7 +345,6 @@ def _create_constraints(self) -> None: def _create_objectives(self) -> None: for component in self.context.network.all_components: - component_context = self.context.get_component_context(component) model = component.model for objective in self.strategy.get_objectives(model): @@ -838,7 +353,6 @@ def _create_objectives(self) -> None: self.solver, self.context, component, - component_context, objective, ) diff --git a/src/andromede/simulation/optimization_context.py b/src/andromede/simulation/optimization_context.py new file mode 100644 index 00000000..789cd456 --- /dev/null +++ b/src/andromede/simulation/optimization_context.py @@ -0,0 +1,352 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Dict, Iterable, List, Optional + +import ortools.linear_solver.pywraplp as lp + +from andromede.expression.evaluate_parameters import ValueProvider +from andromede.expression.indexing import IndexingStructureProvider +from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression import ( + LinearExpression, + PortFieldId, + PortFieldKey, +) +from andromede.expression.value_provider import TimeScenarioIndex, TimeScenarioIndices +from andromede.study.data import DataBase +from andromede.study.network import Component, Network +from andromede.utils import get_or_add + +from .time_block import TimeBlock + + +@dataclass(eq=True, frozen=True) +class TimestepComponentVariableKey: + """ + Identifies the solver variable for one timestep and one component variable. + """ + + component_id: str + variable_name: str + block_timestep: Optional[int] = None + scenario: Optional[int] = None + + +@dataclass +class SolverVariableInfo: + """ + Helper class for constructing the structure file + for Benders solver. It keeps track of the corresponding + column of the variable in the MPS format as well as if it is + present in the objective function or not + """ + + name: str + column_id: int + is_in_objective: bool + + +class BlockBorderManagement(Enum): + """ + Class to specify the way of handling the time horizon (or time block) border. + - IGNORE_OUT_OF_FRAME: Ignore terms in constraints that lead to out of horizon data + - CYCLE: Consider all timesteps to be specified modulo the horizon length, this is the actual functioning of Antares + """ + + IGNORE_OUT_OF_FRAME = "IGNORE" + CYCLE = "CYCLE" + + +class OptimizationContext: + """ + Helper class to build the optimization problem. + Maintains some mappings between model and solver objects. + Also provides navigation method in the model (components by node ...). + """ + + def __init__( + self, + network: Network, + database: DataBase, + block: TimeBlock, + scenarios: int, + border_management: BlockBorderManagement, + ): + self._network = network + self._database = database + self._block = block + self._scenarios = scenarios + self._border_management = border_management + self._component_variables: Dict[TimestepComponentVariableKey, lp.Variable] = {} + self._solver_variables: Dict[lp.Variable, SolverVariableInfo] = {} + self._connection_fields_expressions: Dict[ + PortFieldKey, List[LinearExpression] + ] = {} + + @property + def network(self) -> Network: + return self._network + + @property + def scenarios(self) -> int: + return self._scenarios + + def block_length(self) -> int: + return len(self._block.timesteps) + + @property + def connection_fields_expressions( + self, + ) -> Dict[PortFieldKey, List[LinearExpression]]: + return self._connection_fields_expressions + + # TODO: Need to think about data processing when creating blocks with varying or inequal time steps length (aggregation, sum ?, mean of data ?) + def block_timestep_to_absolute_timestep(self, block_timestep: int) -> int: + return self._block.timesteps[block_timestep] + + @property + def database(self) -> DataBase: + return self._database + + def _manage_border_timesteps(self, timestep: int) -> int: + if self._border_management == BlockBorderManagement.CYCLE: + return timestep % self.block_length() + else: + raise NotImplementedError + + def get_time_indices(self, index_structure: IndexingStructure) -> Iterable[int]: + return range(self.block_length()) if index_structure.time else range(1) + + def get_scenario_indices(self, index_structure: IndexingStructure) -> Iterable[int]: + return range(self.scenarios) if index_structure.scenario else range(1) + + # TODO: API to improve, variable_structure guides which of the indices block_timestep and scenario should be used + def get_component_variable( + self, + block_timestep: int, + scenario: int, + component_id: str, + variable_name: str, + variable_structure: IndexingStructure, + ) -> lp.Variable: + block_timestep = self._manage_border_timesteps(block_timestep) + + # TODO: Improve design, variable_structure defines indexing + if variable_structure.time == False: + block_timestep = 0 + if variable_structure.scenario == False: + scenario = 0 + + return self._component_variables[ + TimestepComponentVariableKey( + component_id, variable_name, block_timestep, scenario + ) + ] + + def get_all_component_variables( + self, + ) -> Dict[TimestepComponentVariableKey, lp.Variable]: + return self._component_variables + + def register_component_variable( + self, + block_timestep: int, + scenario: int, + component_id: str, + variable_name: str, + variable: lp.Variable, + ) -> None: + key = TimestepComponentVariableKey( + component_id, variable_name, block_timestep, scenario + ) + if key not in self._component_variables: + self._solver_variables[variable] = SolverVariableInfo( + variable.name(), len(self._solver_variables), False + ) + self._component_variables[key] = variable + + def get_component_context(self, component: Component) -> "ComponentContext": + return ComponentContext(self, component) + + def register_connection_fields_expressions( + self, + component_id: str, + port_name: str, + field_name: str, + expression: LinearExpression, + ) -> None: + key = PortFieldKey(component_id, PortFieldId(port_name, field_name)) + get_or_add(self._connection_fields_expressions, key, lambda: []).append( + expression + ) + + +def _get_parameter_value( + context: OptimizationContext, + block_timestep: int, + scenario: int, + component_id: str, + name: str, +) -> float: + data = context.database.get_data(component_id, name) + absolute_timestep = context.block_timestep_to_absolute_timestep(block_timestep) + return data.get_value(absolute_timestep, scenario) + + +def make_value_provider( + context: "OptimizationContext", + component: Component, +) -> ValueProvider: + """ + Create a value provider which takes its values from + the parameter values as defined in the network data. + + Cannot evaluate expressions which contain variables. + """ + + class Provider(ValueProvider): + def get_component_variable_value( + self, + component_id: str, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + raise NotImplementedError( + "Cannot provide variable value at problem build time." + ) + + def get_component_parameter_value( + self, + component_id: str, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + result = {} + param_index = ( + context.network.get_component(component_id) + .model.parameters[name] + .structure + ) + for block_timestep in time_scenarios_indices.time_indices: + for scenario in time_scenarios_indices.scenario_indices: + result[ + TimeScenarioIndex(block_timestep, scenario) + ] = _get_parameter_value( + context, + _get_data_time_key(block_timestep, param_index), + _get_data_scenario_key(scenario, param_index), + component_id, + name, + ) + return result + + def get_variable_value( + self, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + raise NotImplementedError( + "Cannot provide variable value at problem build time." + ) + + def get_parameter_value( + self, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + raise ValueError( + "Parameter must be associated to its component before resolution." + ) + + def parameter_is_constant_over_time(self, name: str) -> bool: + return not component.model.parameters[name].structure.time + + @staticmethod + def block_length() -> int: + return context.block_length() + + @staticmethod + def scenarios() -> int: + return context.scenarios + + return Provider() + + +def make_data_structure_provider( + network: Network, component: Component +) -> IndexingStructureProvider: + """ + Retrieve information in data structure (parameter and variable) from the model + """ + + class Provider(IndexingStructureProvider): + def get_component_variable_structure( + self, component_id: str, name: str + ) -> IndexingStructure: + return network.get_component(component_id).model.variables[name].structure + + def get_component_parameter_structure( + self, component_id: str, name: str + ) -> IndexingStructure: + return network.get_component(component_id).model.parameters[name].structure + + def get_parameter_structure(self, name: str) -> IndexingStructure: + return component.model.parameters[name].structure + + def get_variable_structure(self, name: str) -> IndexingStructure: + return component.model.variables[name].structure + + return Provider() + + +@dataclass(frozen=True) +class ComponentContext: + """ + Helper class to fill the optimization problem with component-related equations and variables. + """ + + opt_context: OptimizationContext + component: Component + + def add_variable( + self, + block_timestep: int, + scenario: int, + model_var_name: str, + variable: lp.Variable, + ) -> None: + self.opt_context.register_component_variable( + block_timestep, scenario, self.component.id, model_var_name, variable + ) + + def get_variable( + self, block_timestep: int, scenario: int, variable_name: str + ) -> lp.Variable: + return self.opt_context.get_component_variable( + block_timestep, + scenario, + self.component.id, + variable_name, + self.component.model.variables[variable_name].structure, + ) + + +def _get_data_time_key(block_timestep: int, data_indexing: IndexingStructure) -> int: + return block_timestep if data_indexing.time else 0 + + +def _get_data_scenario_key(scenario: int, data_indexing: IndexingStructure) -> int: + return scenario if data_indexing.scenario else 0 diff --git a/src/andromede/simulation/output_values.py b/src/andromede/simulation/output_values.py index cf94e7c0..d7977358 100644 --- a/src/andromede/simulation/output_values.py +++ b/src/andromede/simulation/output_values.py @@ -17,7 +17,7 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union, cast -from andromede.simulation.optimization import OptimizationProblem +from .optimization import OptimizationProblem from andromede.study.data import TimeScenarioIndex diff --git a/src/andromede/simulation/resolved_linear_expression.py b/src/andromede/simulation/resolved_linear_expression.py new file mode 100644 index 00000000..acc82d7d --- /dev/null +++ b/src/andromede/simulation/resolved_linear_expression.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +""" +Specific modelling for "resolved" linear expressions, +with only variables and literal coefficients. +""" + +from dataclasses import dataclass, field +from typing import List + +import ortools.linear_solver.pywraplp as lp + + +@dataclass(frozen=True) +class ResolvedTerm: + """ + Represents a term where parameters and variables id have been resolved, in the form of couple (coefficient, variable_id) + """ + + coefficient: float + variable: lp.Variable + + +@dataclass +class ResolvedLinearExpression: + """ + Represents a linear expression where parameters and variables id have been resolved, in the form of couple (coefficient, variable_id) and a constant + """ + + terms: List[ResolvedTerm] = field(default_factory=list) + constant: float = field(default=0) + + def is_constant(self) -> bool: + # Constant expr like x-x could be seen as non constant as we do not simplify coefficient tree... + return not self.terms diff --git a/src/andromede/simulation/strategy.py b/src/andromede/simulation/strategy.py index 75e34c65..288cced2 100644 --- a/src/andromede/simulation/strategy.py +++ b/src/andromede/simulation/strategy.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from typing import Generator, Optional -from andromede.expression import ExpressionNode +from andromede.expression.linear_expression import LinearExpression from andromede.model import Constraint, Model, ProblemContext, Variable @@ -43,7 +43,7 @@ def _keep_from_context(self, context: ProblemContext) -> bool: @abstractmethod def get_objectives( self, model: Model - ) -> Generator[Optional[ExpressionNode], None, None]: + ) -> Generator[Optional[LinearExpression], None, None]: ... @@ -53,7 +53,7 @@ def _keep_from_context(self, context: ProblemContext) -> bool: def get_objectives( self, model: Model - ) -> Generator[Optional[ExpressionNode], None, None]: + ) -> Generator[Optional[LinearExpression], None, None]: yield model.objective_operational_contribution yield model.objective_investment_contribution @@ -66,7 +66,7 @@ def _keep_from_context(self, context: ProblemContext) -> bool: def get_objectives( self, model: Model - ) -> Generator[Optional[ExpressionNode], None, None]: + ) -> Generator[Optional[LinearExpression], None, None]: yield model.objective_investment_contribution @@ -78,5 +78,5 @@ def _keep_from_context(self, context: ProblemContext) -> bool: def get_objectives( self, model: Model - ) -> Generator[Optional[ExpressionNode], None, None]: + ) -> Generator[Optional[LinearExpression], None, None]: yield model.objective_operational_contribution diff --git a/src/andromede/study/data.py b/src/andromede/study/data.py index 50d4acda..7fa7b384 100644 --- a/src/andromede/study/data.py +++ b/src/andromede/study/data.py @@ -17,7 +17,7 @@ import pandas as pd -from andromede.study.network import Network +from .network import Network @dataclass(frozen=True) diff --git a/src/andromede/study/resolve_components.py b/src/andromede/study/resolve_components.py index d6650e78..327d2c45 100644 --- a/src/andromede/study/resolve_components.py +++ b/src/andromede/study/resolve_components.py @@ -13,30 +13,18 @@ from pathlib import Path from typing import Dict, Iterable, List, Optional -import pandas as pd - from andromede.model import Model from andromede.model.library import Library -from andromede.study import ( - Component, + +from .data import ( + AbstractDataStructure, ConstantData, DataBase, - Network, - Node, - PortRef, - PortsConnection, -) -from andromede.study.data import ( - AbstractDataStructure, - TimeScenarioIndex, TimeScenarioSeriesData, load_ts_from_txt, ) -from andromede.study.parsing import ( - InputComponent, - InputComponents, - InputPortConnections, -) +from .network import Component, Network, Node, PortRef, PortsConnection +from .parsing import InputComponent, InputComponents, InputPortConnections @dataclass(frozen=True) diff --git a/tests/functional/test_andromede.py b/tests/functional/test_andromede.py index 8e837a45..16837352 100644 --- a/tests/functional/test_andromede.py +++ b/tests/functional/test_andromede.py @@ -13,28 +13,22 @@ import pandas as pd import pytest -from andromede.expression import literal, param, var +from andromede.expression.expression import literal, param from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression import var from andromede.libs.standard import ( BALANCE_PORT_TYPE, DEMAND_MODEL, GENERATOR_MODEL, - GENERATOR_MODEL_WITH_PMIN, - LINK_MODEL, NODE_BALANCE_MODEL, SHORT_TERM_STORAGE_SIMPLE, SPILLAGE_MODEL, - THERMAL_CLUSTER_MODEL_HD, UNSUPPLIED_ENERGY_MODEL, ) from andromede.model import Model, ModelPort, float_parameter, float_variable, model from andromede.model.model import PortFieldDefinition, PortFieldId -from andromede.simulation import ( - BlockBorderManagement, - OutputValues, - TimeBlock, - build_problem, -) +from andromede.simulation import TimeBlock, build_problem +from andromede.simulation.optimization_context import BlockBorderManagement from andromede.study import ( ConstantData, DataBase, @@ -151,7 +145,7 @@ def test_variable_bound() -> None: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], objective_operational_contribution=(param("cost") * var("generation")) diff --git a/tests/functional/test_andromede_yml.py b/tests/functional/test_andromede_yml.py index 61812b62..c9326d65 100644 --- a/tests/functional/test_andromede_yml.py +++ b/tests/functional/test_andromede_yml.py @@ -1,24 +1,15 @@ import pandas as pd import pytest -from andromede.expression import literal, param, var -from andromede.expression.indexing_structure import IndexingStructure -from andromede.model import Model, ModelPort, float_parameter, float_variable, model from andromede.model.library import Library -from andromede.model.model import PortFieldDefinition, PortFieldId -from andromede.simulation import ( - BlockBorderManagement, - OutputValues, - TimeBlock, - build_problem, -) +from andromede.simulation import OutputValues, TimeBlock, build_problem +from andromede.simulation.optimization_context import BlockBorderManagement from andromede.study import ( ConstantData, DataBase, Network, Node, PortRef, - TimeScenarioIndex, TimeScenarioSeriesData, create_component, ) diff --git a/tests/functional/test_performance.py b/tests/functional/test_performance.py index 1c50af1c..5bf0a0e5 100644 --- a/tests/functional/test_performance.py +++ b/tests/functional/test_performance.py @@ -10,152 +10,87 @@ # # This file is part of the Antares project. -from typing import cast import pytest -from andromede.expression.expression import ExpressionNode, literal, param, var -from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.expression import param +from andromede.expression.indexing_structure import RowIndex +from andromede.expression.linear_expression import literal, var, wrap_in_linear_expr from andromede.libs.standard import ( - BALANCE_PORT_TYPE, DEMAND_MODEL, GENERATOR_MODEL, GENERATOR_MODEL_WITH_STORAGE, NODE_BALANCE_MODEL, ) -from andromede.model import float_parameter, float_variable, model -from andromede.simulation import TimeBlock, build_problem -from andromede.study import ( - ConstantData, - DataBase, - Network, - Node, - PortRef, - create_component, -) +from andromede.simulation.optimization import build_problem +from andromede.simulation.time_block import TimeBlock +from andromede.study.data import ConstantData, DataBase +from andromede.study.network import Network, Node, PortRef, create_component from tests.unittests.test_utils import generate_scalar_matrix_data +from tests.utils import EvaluationContext -def test_large_sum_inside_model_with_loop() -> None: +def test_large_number_of_parameters_sum() -> None: """ Test performance when the problem involves an expression with a high number of terms. - Here the objective function is the sum over nb_terms terms on a for-loop inside the model - This test pass with 476 terms but fails with 477 locally due to recursion depth, - and even less terms are possible with Jenkins... + This test pass with 476 terms but fails with 477 locally due to recursion depth, and even less terms are possible with Jenkins... """ nb_terms = 500 - time_blocks = [TimeBlock(0, [0])] - scenarios = 1 - database = DataBase() - + parameters_value = {} for i in range(1, nb_terms): - database.add_data("simple_cost", f"cost_{i}", ConstantData(1 / i)) + parameters_value[f"cost_{i}"] = 1 / i + # Still the recursion depth error with parameters with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): - SIMPLE_COST_MODEL = model( - id="SIMPLE_COST", - parameters=[ - float_parameter(f"cost_{i}", IndexingStructure(False, False)) - for i in range(1, nb_terms) - ], - objective_operational_contribution=cast( - ExpressionNode, sum(param(f"cost_{i}") for i in range(1, nb_terms)) - ), - ) + expr = sum(wrap_in_linear_expr(param(f"cost_{i}")) for i in range(1, nb_terms)) + expr.evaluate(EvaluationContext(parameters=parameters_value), RowIndex(0, 0)) - # Won't run because last statement will raise the error - network = Network("test") - cost_model = create_component(model=SIMPLE_COST_MODEL, id="simple_cost") - network.add_component(cost_model) - problem = build_problem(network, database, time_blocks[0], scenarios) - status = problem.solver.Solve() - - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == sum( - [1 / i for i in range(1, nb_terms)] - ) - - -def test_large_sum_outside_model_with_loop() -> None: +def test_large_number_of_identical_parameters_sum() -> None: """ - Test performance when the problem involves an expression with a high number of terms. - Here the objective function is the sum over nb_terms terms on a for-loop outside the model + With identical parameters sum, a simplification is performed online to avoid the recursivity. """ - nb_terms = 10_000 - - time_blocks = [TimeBlock(0, [0])] - scenarios = 1 - database = DataBase() - - obj_coeff = sum([1 / i for i in range(1, nb_terms)]) - - SIMPLE_COST_MODEL = model( - id="SIMPLE_COST", - parameters=[], - objective_operational_contribution=literal(obj_coeff), - ) + nb_terms = 500 - network = Network("test") + parameters_value = {"cost": 1.0} - simple_model = create_component( - model=SIMPLE_COST_MODEL, - id="simple_cost", + # Still the recursion depth error with parameters + # with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): + expr = sum(wrap_in_linear_expr(param("cost")) for _ in range(nb_terms)) + assert ( + expr.evaluate(EvaluationContext(parameters=parameters_value), RowIndex(0, 0)) + == nb_terms ) - network.add_component(simple_model) - problem = build_problem(network, database, time_blocks[0], scenarios) - status = problem.solver.Solve() - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == obj_coeff - - -def test_large_sum_inside_model_with_sum_operator() -> None: +def test_large_number_of_literal_sum() -> None: """ - Test performance when the problem involves an expression with a high number of terms. - Here the objective function is the sum over nb_terms terms with the sum() operator inside the model + Literal sums are computed online to avoid recursivity """ - nb_terms = 10_000 - - scenarios = 1 - time_blocks = [TimeBlock(0, list(range(nb_terms)))] - database = DataBase() + nb_terms = 500 - # Weird values when the "cost" varies over time and we use the sum() operator: - # For testing purposes, will use a const value since the problem seems to come when - # we try to linearize nb_terms variables with nb_terms distinct parameters - # TODO check the sum() operator for time-variable parameters - database.add_data("simple_cost", "cost", ConstantData(3)) - - SIMPLE_COST_MODEL = model( - id="SIMPLE_COST", - parameters=[ - float_parameter("cost", IndexingStructure(False, False)), - ], - variables=[ - float_variable( - "var", - lower_bound=literal(1), - upper_bound=literal(2), - structure=IndexingStructure(True, False), - ), - ], - objective_operational_contribution=(param("cost") * var("var")).sum(), - ) + # # Still the recursion depth error with parameters + # with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): + expr = sum(wrap_in_linear_expr(literal(1)) for _ in range(nb_terms)) + assert expr.evaluate(EvaluationContext(), RowIndex(0, 0)) == nb_terms - network = Network("test") - cost_model = create_component(model=SIMPLE_COST_MODEL, id="simple_cost") - network.add_component(cost_model) +def test_large_number_of_variables_sum() -> None: + """ + Test performance when the problem involves an expression with a high number of terms. No problem when there is a large number of variables as this is derecusified. + """ + nb_terms = 500 - problem = build_problem(network, database, time_blocks[0], scenarios) - status = problem.solver.Solve() + variables_value = {} + for i in range(1, nb_terms): + variables_value[f"cost_{i}"] = 1 / i - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == 3 * nb_terms + expr = sum(var(f"cost_{i}") for i in range(1, nb_terms)) + assert expr.evaluate( + EvaluationContext(variables=variables_value), RowIndex(0, 0) + ) == sum(1 / i for i in range(1, nb_terms)) def test_large_sum_of_port_connections() -> None: @@ -196,14 +131,13 @@ def test_large_sum_of_port_connections() -> None: PortRef(generators[gen_id], "balance_port"), PortRef(node, "balance_port") ) - with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): - problem = build_problem(network, database, time_block, scenarios) + # Raised recursion error with previous implementation + problem = build_problem(network, database, time_block, scenarios) - # Won't run because last statement will raise the error - status = problem.solver.Solve() + status = problem.solver.Solve() - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == 5 * nb_generators + assert status == problem.solver.OPTIMAL + assert problem.solver.Objective().Value() == 5 * nb_generators def test_basic_balance_on_whole_year() -> None: diff --git a/tests/functional/test_xpansion.py b/tests/functional/test_xpansion.py index 76ff45eb..3757b8d2 100644 --- a/tests/functional/test_xpansion.py +++ b/tests/functional/test_xpansion.py @@ -13,15 +13,15 @@ import pandas as pd import pytest -from andromede.expression.expression import literal, param, port_field, var +from andromede.expression.expression import literal, param from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression import var from andromede.libs.standard import ( BALANCE_PORT_TYPE, CONSTANT, DEMAND_MODEL, GENERATOR_MODEL, NODE_BALANCE_MODEL, - NODE_WITH_SPILL_AND_ENS_MODEL, ) from andromede.model import ( Constraint, @@ -38,7 +38,6 @@ MergedProblemStrategy, OutputValues, TimeBlock, - build_benders_decomposed_problem, build_problem, ) from andromede.study import ( @@ -82,12 +81,12 @@ def thermal_candidate() -> Model: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= var("p_max") + name="Max generation", expression_init=var("generation") <= var("p_max") ) ], objective_operational_contribution=(param("op_cost") * var("generation")) @@ -127,16 +126,17 @@ def discrete_candidate() -> Model: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= var("p_max") + name="Max generation", expression_init=var("generation") <= var("p_max") ), Constraint( name="Max investment", - expression=var("p_max") == param("p_max_per_unit") * var("nb_units"), + expression_init=var("p_max") + == param("p_max_per_unit") * var("nb_units"), context=INVESTMENT, ), ], @@ -376,10 +376,10 @@ def test_generation_xpansion_two_time_steps_two_scenarios( status = problem.solver.Solve() assert status == problem.solver.OPTIMAL - # assert problem.solver.NumVariables() == 2 * scenarios * horizon + 1 - # assert ( - # problem.solver.NumConstraints() == 3 * scenarios * horizon - # ) # Flow balance, Max generation for each cluster + assert problem.solver.NumVariables() == 2 * scenarios * horizon + 1 + assert ( + problem.solver.NumConstraints() == 3 * scenarios * horizon + ) # Flow balance, Max generation for each cluster assert problem.solver.Objective().Value() == pytest.approx( 490 * 300 + 0.5 * (10 * 300 + 10 * 300 + 40 * 200) diff --git a/tests/integration/test_benders_decomposed.py b/tests/integration/test_benders_decomposed.py index 1aef0492..c9a0db33 100644 --- a/tests/integration/test_benders_decomposed.py +++ b/tests/integration/test_benders_decomposed.py @@ -13,8 +13,9 @@ import pandas as pd import pytest -from andromede.expression.expression import literal, param, var +from andromede.expression.expression import literal, param from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression import var from andromede.libs.standard import ( BALANCE_PORT_TYPE, CONSTANT, @@ -82,12 +83,12 @@ def thermal_candidate() -> Model: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= var("p_max") + name="Max generation", expression_init=var("generation") <= var("p_max") ) ], objective_operational_contribution=(param("op_cost") * var("generation")) @@ -128,16 +129,17 @@ def discrete_candidate() -> Model: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= var("p_max") + name="Max generation", expression_init=var("generation") <= var("p_max") ), Constraint( name="Max investment", - expression=var("p_max") == param("p_max_per_unit") * var("nb_units"), + expression_init=var("p_max") + == param("p_max_per_unit") * var("nb_units"), context=INVESTMENT, ), ], diff --git a/tests/models/test_electrolyzer.py b/tests/models/test_electrolyzer.py index e1323070..d6ef3808 100644 --- a/tests/models/test_electrolyzer.py +++ b/tests/models/test_electrolyzer.py @@ -10,8 +10,8 @@ # # This file is part of the Antares project. -from andromede.expression import literal, param, var -from andromede.expression.expression import port_field +from andromede.expression.expression import literal, param +from andromede.expression.linear_expression import port_field, var from andromede.libs.standard import CONSTANT, TIME_AND_SCENARIO_FREE from andromede.model import ( Constraint, @@ -41,7 +41,7 @@ binding_constraints=[ Constraint( name="Balance", - expression=port_field("electrical_port", "flow").sum_connections() + expression_init=port_field("electrical_port", "flow").sum_connections() == literal(0), ) ], @@ -60,7 +60,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("electrical_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], objective_operational_contribution=(param("cost") * var("generation")) @@ -76,7 +76,8 @@ binding_constraints=[ Constraint( name="Balance", - expression=port_field("h2_port", "flow").sum_connections() == literal(0), + expression_init=port_field("h2_port", "flow").sum_connections() + == literal(0), ) ], ) @@ -90,7 +91,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("h2_port", "flow"), - definition=-param("demand"), + definition_init=-param("demand"), ) ], ) @@ -111,17 +112,17 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("electrical_port", "flow"), - definition=-var("electrical_input"), + definition_init=-var("electrical_input"), ), PortFieldDefinition( port_field=PortFieldId("h2_port", "flow"), - definition=var("h2_output"), + definition_init=var("h2_output"), ), ], constraints=[ Constraint( name="Conversion", - expression=var("h2_output") + expression_init=var("h2_output") == var("electrical_input") * param("efficiency"), ) ], diff --git a/tests/models/test_short_term_storage_complex.py b/tests/models/test_short_term_storage_complex.py index 0d23c050..ef28c7e3 100644 --- a/tests/models/test_short_term_storage_complex.py +++ b/tests/models/test_short_term_storage_complex.py @@ -14,7 +14,8 @@ UNSUPPLIED_ENERGY_MODEL, ) from andromede.libs.standard_sc import SHORT_TERM_STORAGE_COMPLEX -from andromede.simulation import BlockBorderManagement, TimeBlock, build_problem +from andromede.simulation import TimeBlock, build_problem +from andromede.simulation.optimization_context import BlockBorderManagement from andromede.study import ( ConstantData, DataBase, diff --git a/tests/unittests/expressions/parsing/test_expression_parsing.py b/tests/unittests/expressions/parsing/test_expression_parsing.py index 9620f4be..81ab09f5 100644 --- a/tests/unittests/expressions/parsing/test_expression_parsing.py +++ b/tests/unittests/expressions/parsing/test_expression_parsing.py @@ -9,18 +9,30 @@ # SPDX-License-Identifier: MPL-2.0 # # This file is part of the Antares project. -from typing import Set +from typing import Set, Union import pytest -from andromede.expression import ExpressionNode, literal, param, print_expr, var from andromede.expression.equality import expressions_equal -from andromede.expression.expression import ExpressionRange, port_field +from andromede.expression.expression import ( + ExpressionNode, + ExpressionRange, + literal, + param, +) +from andromede.expression.linear_expression import ( + LinearExpression, + StandaloneConstraint, + linear_expressions_equal, + port_field, + var, +) from andromede.expression.parsing.parse_expression import ( AntaresParseException, ModelIdentifiers, parse_expression, ) +from andromede.expression.print import print_expr @pytest.mark.parametrize( @@ -46,22 +58,22 @@ ( {"x"}, {}, - "x[-1..5]", - var("x").eval(ExpressionRange(-literal(1), literal(5))), + "sum(-1..5, x)", + var("x").sum(eval=ExpressionRange(-literal(1), literal(5))), ), ({"x"}, {}, "x[1]", var("x").eval(1)), ({"x"}, {}, "x[t-1]", var("x").shift(-literal(1))), ( {"x"}, {}, - "x[t-1, t+4]", - var("x").shift([-literal(1), literal(4)]), + "sum((t-1, t+4), x)", # TODO: Should raise ValueError: shift always with sum + var("x").sum(shift=[-literal(1), literal(4)]), ), ( {"x"}, {}, "x[t-1+1]", - var("x").shift(-literal(1) + literal(1)), + var("x"), # Simplifications are applied very early in parsing !!!! ), ( {"x"}, @@ -90,26 +102,26 @@ ( {"x"}, {}, - "x[t-1, t, t+4]", - var("x").shift([-literal(1), literal(0), literal(4)]), + "x[t-1, t, t+4]", # TODO: Should raise ValueError: shift always with sum + var("x").sum(shift=[-literal(1), literal(0), literal(4)]), ), ( {"x"}, {}, - "x[t-1..t+5]", - var("x").shift(ExpressionRange(-literal(1), literal(5))), + "x[t-1..t+5]", # TODO: Should raise ValueError: shift always with sum + var("x").sum(shift=ExpressionRange(-literal(1), literal(5))), ), ( {"x"}, {}, - "x[t-1..t]", - var("x").shift(ExpressionRange(-literal(1), literal(0))), + "x[t-1..t]", # TODO: Should raise ValueError: shift always with sum + var("x").sum(shift=ExpressionRange(-literal(1), literal(0))), ), ( {"x"}, {}, - "x[t..t+5]", - var("x").shift(ExpressionRange(literal(0), literal(5))), + "x[t..t+5]", # TODO: Should raise ValueError: shift always with sum + var("x").sum(shift=ExpressionRange(literal(0), literal(5))), ), ({"x"}, {}, "x[t]", var("x")), ({"x"}, {"p"}, "x[t+p]", var("x").shift(param("p"))), @@ -117,7 +129,7 @@ {"x"}, {}, "sum(x[-1..5])", - var("x").eval(ExpressionRange(-literal(1), literal(5))).sum(), + var("x").sum(eval=ExpressionRange(-literal(1), literal(5))), ), ({}, {}, "sum_connections(port.f)", port_field("port", "f").sum_connections()), ( @@ -134,9 +146,9 @@ {"nb_start", "nb_on"}, {"d_min_up"}, "sum(nb_start[-d_min_up + 1 .. 0]) <= nb_on", - var("nb_start") - .eval(ExpressionRange(-param("d_min_up") + 1, literal(0))) - .sum() + var("nb_start").sum( + eval=(ExpressionRange(-param("d_min_up") + 1, literal(0))) + ) <= var("nb_on"), ), ( @@ -151,13 +163,19 @@ def test_parsing_visitor( variables: Set[str], parameters: Set[str], expression_str: str, - expected: ExpressionNode, + expected: Union[ExpressionNode, LinearExpression, StandaloneConstraint], ) -> None: identifiers = ModelIdentifiers(variables, parameters) expr = parse_expression(expression_str, identifiers) print() - print(print_expr(expr)) - assert expressions_equal(expr, expected) + print(f"Expected: \n {str(expected)}") + print(f"Parsed: \n {str(expr)}") + if isinstance(expected, ExpressionNode): + assert expressions_equal(expr, expected) + elif isinstance(expected, LinearExpression): + assert linear_expressions_equal(expr, expected) + elif isinstance(expected, StandaloneConstraint): + assert expected == expr @pytest.mark.parametrize( @@ -168,6 +186,9 @@ def test_parsing_visitor( "x[t+1-t]", "x[2*t]", "x[t 4]", + "x[t..4]", + "x[t+1..t+4]", + "x[1..4]", ], ) def test_parse_cancellation_should_throw(expression_str: str) -> None: diff --git a/tests/unittests/expressions/test_equality.py b/tests/unittests/expressions/test_equality.py index 5d5fd5c2..0d3a74c8 100644 --- a/tests/unittests/expressions/test_equality.py +++ b/tests/unittests/expressions/test_equality.py @@ -10,38 +10,65 @@ # # This file is part of the Antares project. -import math import pytest -from andromede.expression import ExpressionNode, copy_expression, literal, param, var +from andromede.expression.copy import copy_expression from andromede.expression.equality import expressions_equal from andromede.expression.expression import ( - ExpressionRange, + ExpressionNode, + InstancesTimeIndex, + TimeAggregatorName, TimeAggregatorNode, + TimeOperatorName, + TimeOperatorNode, expression_range, + literal, + param, ) -def shifted_x() -> ExpressionNode: - return var("x").shift(expression_range(0, 2)) +def shifted_param() -> ExpressionNode: + return TimeOperatorNode( + param("q"), TimeOperatorName.SHIFT, InstancesTimeIndex(expression_range(0, 2)) + ) @pytest.mark.parametrize( "expr", [ - var("x"), + param("q"), param("p"), - var("x") + 1, - var("x") - 1, - var("x") / 2, - var("x") * 3, - var("x").shift(expression_range(1, 10, 2)).sum(), - var("x").shift(expression_range(1, param("p"))).sum(), - TimeAggregatorNode(shifted_x(), name="TimeSum", stay_roll=True), - TimeAggregatorNode(shifted_x(), name="TimeAggregator", stay_roll=True), - var("x") + 5 <= 2, - var("x").expec(), + param("q") + 1, + param("q") - 1, + param("q") / 2, + param("q") * 3, + TimeAggregatorNode( + TimeOperatorNode( + param("q"), + TimeOperatorName.SHIFT, + InstancesTimeIndex(expression_range(1, 10, 2)), + ), + TimeAggregatorName.TIME_SUM, + stay_roll=True, + ), + TimeAggregatorNode( + TimeOperatorNode( + param("q"), + TimeOperatorName.SHIFT, + InstancesTimeIndex(expression_range(1, param("p"))), + ), + TimeAggregatorName.TIME_SUM, + stay_roll=True, + ), + TimeAggregatorNode( + shifted_param(), name=TimeAggregatorName.TIME_SUM, stay_roll=True + ), + TimeAggregatorNode( + shifted_param(), name=TimeAggregatorName.TIME_SUM, stay_roll=True + ), + param("q") + 5 <= 2, + param("q").expec(), ], ) def test_equals(expr: ExpressionNode) -> None: @@ -52,26 +79,58 @@ def test_equals(expr: ExpressionNode) -> None: @pytest.mark.parametrize( "rhs, lhs", [ - (var("x"), var("y")), + (param("q"), param("y")), (literal(1), literal(2)), - (var("x") + 1, var("x")), - ( - var("x").shift(expression_range(1, param("p"))).sum(), - var("x").shift(expression_range(1, param("q"))).sum(), - ), + (param("q") + 1, param("q")), ( - var("x").shift(expression_range(1, 10, 2)).sum(), - var("x").shift(expression_range(1, 10, 3)).sum(), + TimeAggregatorNode( + TimeOperatorNode( + param("q"), + TimeOperatorName.SHIFT, + InstancesTimeIndex(expression_range(1, param("p"))), + ), + TimeAggregatorName.TIME_SUM, + stay_roll=True, + ), + TimeAggregatorNode( + TimeOperatorNode( + param("q"), + TimeOperatorName.SHIFT, + InstancesTimeIndex(expression_range(1, param("q"))), + ), + TimeAggregatorName.TIME_SUM, + stay_roll=True, + ), ), ( - TimeAggregatorNode(shifted_x(), name="TimeSum", stay_roll=True), - TimeAggregatorNode(shifted_x(), name="TimeSum", stay_roll=False), + TimeAggregatorNode( + TimeOperatorNode( + param("q"), + TimeOperatorName.SHIFT, + InstancesTimeIndex(expression_range(1, 10, 2)), + ), + TimeAggregatorName.TIME_SUM, + stay_roll=True, + ), + TimeAggregatorNode( + TimeOperatorNode( + param("q"), + TimeOperatorName.SHIFT, + InstancesTimeIndex(expression_range(1, 10, 3)), + ), + TimeAggregatorName.TIME_SUM, + stay_roll=True, + ), ), ( - TimeAggregatorNode(shifted_x(), name="TimeSum", stay_roll=True), - TimeAggregatorNode(shifted_x(), name="TimeAggregator", stay_roll=True), + TimeAggregatorNode( + shifted_param(), name=TimeAggregatorName.TIME_SUM, stay_roll=True + ), + TimeAggregatorNode( + shifted_param(), name=TimeAggregatorName.TIME_SUM, stay_roll=False + ), ), - (var("x").expec(), var("y").expec()), + (param("q").expec(), param("y").expec()), ], ) def test_not_equals(lhs: ExpressionNode, rhs: ExpressionNode) -> None: diff --git a/tests/unittests/expressions/test_expressions.py b/tests/unittests/expressions/test_expressions.py index 81bebcd3..699dd1a0 100644 --- a/tests/unittests/expressions/test_expressions.py +++ b/tests/unittests/expressions/test_expressions.py @@ -10,46 +10,46 @@ # # This file is part of the Antares project. +# This test file should contain tests on ExpressionNode objects rather than on LinearExpression objects which are already tested in test_linear_expressions.py ... + +import re from dataclasses import dataclass, field from typing import Dict import pytest -from andromede.expression import ( - AdditionNode, - DivisionNode, - EvaluationContext, - EvaluationVisitor, - ExpressionDegreeVisitor, +from andromede.expression.equality import expressions_equal +from andromede.expression.expression import ( + ComponentParameterNode, ExpressionNode, + ExpressionRange, + InstancesTimeIndex, LiteralNode, ParameterNode, - ParameterValueProvider, - PrinterVisitor, - ValueProvider, - VariableNode, + TimeAggregatorName, + TimeAggregatorNode, + TimeOperatorName, + TimeOperatorNode, + comp_param, literal, param, - resolve_parameters, - sum_expressions, - var, - visit, ) -from andromede.expression.equality import expressions_equal -from andromede.expression.expression import ( - ComponentParameterNode, - ComponentVariableNode, - ExpressionRange, - Instances, - comp_param, +from andromede.expression.indexing import IndexingStructureProvider +from andromede.expression.indexing_structure import IndexingStructure, RowIndex +from andromede.expression.linear_expression import ( + LinearExpression, + StandaloneConstraint, + Term, + TermKey, comp_var, - port_field, + linear_expressions_equal, + sum_expressions, + var, + wrap_in_linear_expr, ) -from andromede.expression.indexing import IndexingStructureProvider, compute_indexation -from andromede.expression.indexing_structure import IndexingStructure -from andromede.model.model import PortFieldId -from andromede.simulation.linear_expression import LinearExpression, Term -from andromede.simulation.linearize import linearize_expression +from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum +from andromede.expression.value_provider import TimeScenarioIndex, TimeScenarioIndices +from tests.utils import EvaluationContext, ValueProvider @dataclass(frozen=True) @@ -72,124 +72,421 @@ class ComponentEvaluationContext(ValueProvider): variables: Dict[ComponentValueKey, float] = field(default_factory=dict) parameters: Dict[ComponentValueKey, float] = field(default_factory=dict) - def get_variable_value(self, name: str) -> float: + def get_variable_value( + self, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: raise NotImplementedError() - def get_parameter_value(self, name: str) -> float: + def get_parameter_value( + self, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: raise NotImplementedError() - def get_component_variable_value(self, component_id: str, name: str) -> float: - return self.variables[comp_key(component_id, name)] + def get_component_variable_value( + self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: + return {TimeScenarioIndex(0, 0): self.variables[comp_key(component_id, name)]} - def get_component_parameter_value(self, component_id: str, name: str) -> float: - return self.parameters[comp_key(component_id, name)] + def get_component_parameter_value( + self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: + return {TimeScenarioIndex(0, 0): self.parameters[comp_key(component_id, name)]} def parameter_is_constant_over_time(self, name: str) -> bool: raise NotImplementedError() + @staticmethod + def block_length() -> int: + raise NotImplementedError() + + @staticmethod + def scenarios() -> int: + raise NotImplementedError() -def test_comp_parameter() -> None: - add_node = AdditionNode(LiteralNode(1), ComponentVariableNode("comp1", "x")) - expr = DivisionNode(add_node, ComponentParameterNode("comp1", "p")) - assert visit(expr, PrinterVisitor()) == "((1 + comp1.x) / comp1.p)" +# TODO: Redundant with tests in test_linear_expressions_efficient ? +def test_comp_parameter() -> None: + expr1 = LinearExpression([], 1) + LinearExpression([Term(1, "comp1", "x")]) + expr2 = expr1 / LinearExpression(constant=ComponentParameterNode("comp1", "p")) + assert str(expr2) == "(1.0 / comp1.p)x + (1.0 / comp1.p)" context = ComponentEvaluationContext( variables={comp_key("comp1", "x"): 3}, parameters={comp_key("comp1", "p"): 4} ) - assert visit(expr, EvaluationVisitor(context)) == 1 + # Need to specify at which (t, w) to evaluate as the information is not contained anymore within the value provider + assert expr2.evaluate(context, RowIndex(0, 0)) == 1 +# TODO: Find a better name def test_ast() -> None: - add_node = AdditionNode(LiteralNode(1), VariableNode("x")) - expr = DivisionNode(add_node, ParameterNode("p")) + expr1 = LinearExpression([], 1) + LinearExpression([Term(1, "", "x")]) + expr2 = expr1 / LinearExpression(constant=ParameterNode("p")) - assert visit(expr, PrinterVisitor()) == "((1 + x) / p)" + assert str(expr2) == "(1.0 / p)x + (1.0 / p)" context = EvaluationContext(variables={"x": 3}, parameters={"p": 4}) - assert visit(expr, EvaluationVisitor(context)) == 1 + assert expr2.evaluate(context, RowIndex(0, 0)) == 1 def test_operators() -> None: x = var("x") p = param("p") - expr: ExpressionNode = (5 * x + 3) / p - 2 + expr: LinearExpression = (5 * x + 3) / p - 2 - assert visit(expr, PrinterVisitor()) == "((((5.0 * x) + 3.0) / p) - 2.0)" + assert str(expr) == "(5.0 / p)x + ((3.0 / p) - 2.0)" context = EvaluationContext(variables={"x": 3}, parameters={"p": 4}) - assert visit(expr, EvaluationVisitor(context)) == pytest.approx(2.5, 1e-16) - - assert visit(-expr, EvaluationVisitor(context)) == pytest.approx(-2.5, 1e-16) - - -def test_degree() -> None: - x = var("x") - p = param("p") - expr = (5 * x + 3) / p - - assert visit(expr, ExpressionDegreeVisitor()) == 1 + assert expr.evaluate(context, RowIndex(0, 0)) == pytest.approx(2.5, 1e-16) - expr = x * expr - assert visit(expr, ExpressionDegreeVisitor()) == 2 + assert -expr.evaluate(context, RowIndex(0, 0)) == pytest.approx(-2.5, 1e-16) -@pytest.mark.xfail(reason="Degree simplification not implemented") def test_degree_computation_should_take_into_account_simplifications() -> None: x = var("x") expr = x - x - assert visit(expr, ExpressionDegreeVisitor()) == 0 - - expr = LiteralNode(0) * x - assert visit(expr, ExpressionDegreeVisitor()) == 0 - - -def test_parameters_resolution() -> None: - class TestParamProvider(ParameterValueProvider): - def get_component_parameter_value(self, component_id: str, name: str) -> float: - raise NotImplementedError() - - def get_parameter_value(self, name: str) -> float: - return 2 - - x = var("x") - p = param("p") - expr = (5 * x + 3) / p - assert resolve_parameters(expr, TestParamProvider()) == (5 * x + 3) / 2 - + assert expr.is_constant() + + expr = 0 * x + assert expr.is_constant() + assert expr.is_zero() + + +# TODO: Write tests on ExpressionNodes for tree simplification, do the same for multiplication, substraction, etc +@pytest.mark.parametrize( + "e1, e2, expected", + [ + ( + var("x"), + -var("x"), + LinearExpression(), + ), + ( + param("p"), + -param("p"), + LinearExpression(), + ), + ( + var("x"), + -var("y"), + var("x") - var("y"), + ), + ( + comp_var("c1", "x"), + var("x"), + comp_var("c1", "x") + var("x"), + ), + ( + comp_var("c1", "x"), + comp_var("c2", "x"), + comp_var("c1", "x") + comp_var("c2", "x"), + ), + ( + comp_param("c1", "p"), + comp_param("c2", "p"), + comp_param("c1", "p") + comp_param("c2", "p"), + ), + ( + comp_var("c1", "x"), + comp_param("c1", "p"), + comp_var("c1", "x") + comp_param("c1", "p"), + ), + ( + param("p1"), + param("p2"), + param("p1") + param("p2"), + ), + ( + var("x"), + var("x"), + 2 * var("x"), + ), + ( + param("p"), + param("p"), + 2 * param("p"), + ), + ( + literal(4) * param("p"), + param("p"), + 5 * param("p"), + ), + ( + param("p"), + param("p") * param("q"), + (1 + param("q")) + * param("p"), # Equality visitor not able to handle commutativity + ), + ], +) +def test_addition( + e1: LinearExpression, + e2: LinearExpression, + expected: LinearExpression, +) -> None: + assert linear_expressions_equal( + wrap_in_linear_expr(e1) + wrap_in_linear_expr(e2), wrap_in_linear_expr(expected) + ) -def test_linearization() -> None: - x = comp_var("c", "x") - expr = (5 * x + 3) / 2 - provider = StructureProvider() - assert linearize_expression(expr, provider) == LinearExpression( - [Term(2.5, "c", "x")], 1.5 +@pytest.mark.parametrize( + "e1, e2, expected", + [ + ( + var("x"), + -var("x"), + 2 * var("x"), + ), + ( + param("p"), + -param("p"), + 2 * param("p"), + ), + ( + var("x"), + -var("y"), + var("x") + var("y"), + ), + ( + comp_var("c1", "x"), + var("x"), + comp_var("c1", "x") - var("x"), + ), + ( + comp_var("c1", "x"), + comp_var("c2", "x"), + comp_var("c1", "x") - comp_var("c2", "x"), + ), + ( + comp_param("c1", "p"), + comp_param("c2", "p"), + comp_param("c1", "p") - comp_param("c2", "p"), + ), + ( + comp_var("c1", "x"), + comp_param("c1", "p"), + comp_var("c1", "x") - comp_param("c1", "p"), + ), + ( + param("p1"), + param("p2"), + param("p1") - param("p2"), + ), + ( + var("x"), + var("x"), + LinearExpression(), + ), + ( + param("p"), + param("p"), + LinearExpression(), + ), + ( + literal(4) * param("p"), + param("p"), + 3 * param("p"), + ), + ( + param("p"), + param("p") * param("q"), + (1 - param("q")) + * param("p"), # Equality visitor not able to handle commutativity + ), + ], +) +def test_substraction( + e1: LinearExpression, + e2: LinearExpression, + expected: LinearExpression, +) -> None: + assert linear_expressions_equal( + wrap_in_linear_expr(e1) - wrap_in_linear_expr(e2), wrap_in_linear_expr(expected) ) - with pytest.raises(ValueError): - linearize_expression(param("p") * x, provider) +@pytest.mark.parametrize( + "lhs, rhs", + [ + ( + (5 * comp_var("c", "x") + 3) / 2, + LinearExpression([Term(2.5, "c", "x")], 1.5), + ), + ( + param("p") * comp_var("c", "x"), + LinearExpression( + [Term(ParameterNode("p"), "c", "x")], + ), + ), + ( + param("p") * comp_var("c", "x"), + LinearExpression( + [Term(ParameterNode("p"), "c", "x")], + ), + ), + ], +) +def test_linear_expression_equality( + lhs: LinearExpression, rhs: LinearExpression +) -> None: + assert linear_expressions_equal(lhs, rhs) -def test_linearization_of_non_linear_expressions_should_raise_value_error() -> None: - x = var("x") - expr = x.variance() - provider = StructureProvider() - with pytest.raises(ValueError) as exc: - linearize_expression(expr, provider) - assert ( - str(exc.value) - == "Cannot linearize expression with a non-linear operator: Variance" +def test_standalone_constraint() -> None: + cst = StandaloneConstraint( + var("x"), wrap_in_linear_expr(literal(0)), wrap_in_linear_expr(literal(10)) ) + assert str(cst) == "0 <= +x <= + 10" + def test_comparison() -> None: x = var("x") p = param("p") - expr: ExpressionNode = (5 * x + 3) >= p - 2 - assert visit(expr, PrinterVisitor()) == "((5.0 * x) + 3.0) >= (p - 2.0)" + expr_geq = (5 * x + 3) >= p - 2 + expr_leq = (5 * x + 3) <= p - 2 + expr_eq = (5 * x + 3) == p - 2 + + assert str(expr_geq) == "0 <= 5.0x + (3.0 - (p - 2.0)) <= + inf" + assert str(expr_leq) == " + (-inf) <= 5.0x + (3.0 - (p - 2.0)) <= 0" + assert str(expr_eq) == "0 <= 5.0x + (3.0 - (p - 2.0)) <= 0" + + +# TODO: Maybe imagine other use cases, that should be forbidden (composition of operators...) +@pytest.mark.parametrize( + "expr, expec_terms, expec_constant", + [ + ( + (var("x") + var("y") + literal(1)).shift(1), + { + TermKey( + "", + "x", + TimeShift(InstancesTimeIndex(1)), + time_aggregator=TimeSum( + stay_roll=True + ), # The internal representation of shift(1) is sum(shift=1) + scenario_aggregator=None, + ): Term( + LiteralNode(1), + "", + "x", + time_operator=TimeShift( + InstancesTimeIndex(1), + ), + time_aggregator=TimeSum(stay_roll=True), + ), + TermKey( + "", + "y", + TimeShift( + InstancesTimeIndex(1), + ), + time_aggregator=TimeSum(stay_roll=True), + scenario_aggregator=None, + ): Term( + LiteralNode(1), + "", + "y", + time_operator=TimeShift(InstancesTimeIndex(1)), + time_aggregator=TimeSum(stay_roll=True), + ), + }, + TimeAggregatorNode( + TimeOperatorNode( + LiteralNode(1), TimeOperatorName.SHIFT, InstancesTimeIndex(1) + ), + TimeAggregatorName.TIME_SUM, + stay_roll=True, + ), # TODO: Could it be simplified online ? + ), + ( + (var("x") + var("y") + literal(1)).eval(1), + { + TermKey( + "", + "x", + TimeEvaluation(InstancesTimeIndex(1)), + time_aggregator=TimeSum( + stay_roll=True + ), # The internal representation of eval(1) is sum(eval=1) + scenario_aggregator=None, + ): Term( + LiteralNode(1), + "", + "x", + time_operator=TimeEvaluation( + InstancesTimeIndex(1), + ), + time_aggregator=TimeSum(stay_roll=True), + ), + TermKey( + "", + "y", + TimeEvaluation( + InstancesTimeIndex(1), + ), + time_aggregator=TimeSum(stay_roll=True), + scenario_aggregator=None, + ): Term( + LiteralNode(1), + "", + "y", + time_operator=TimeEvaluation(InstancesTimeIndex(1)), + time_aggregator=TimeSum(stay_roll=True), + ), + }, + TimeAggregatorNode( + TimeOperatorNode( + LiteralNode(1), TimeOperatorName.EVALUATION, InstancesTimeIndex(1) + ), + TimeAggregatorName.TIME_SUM, + stay_roll=True, + ), # TODO: Could it be simplified online ? + ), + ( + (var("x") + var("y") + literal(1)).sum(), + { + TermKey( + "", + "x", + time_operator=None, + time_aggregator=TimeSum(stay_roll=False), + scenario_aggregator=None, + ): Term( + LiteralNode(1), # Sum is not distributed to coeff + "", + "x", + time_operator=None, + time_aggregator=TimeSum(stay_roll=False), + ), + TermKey( + "", + "y", + time_operator=None, + time_aggregator=TimeSum(stay_roll=False), + scenario_aggregator=None, + ): Term( + LiteralNode(1), # Sum is not distributed to coeff + "", + "y", + time_operator=None, + time_aggregator=TimeSum(stay_roll=False), + ), + }, + TimeAggregatorNode( + LiteralNode(1), TimeAggregatorName.TIME_SUM, stay_roll=False + ), # TODO: Could it be simplified online ? + ), + ], +) +def test_operators_are_correctly_distributed_over_terms( + expr: LinearExpression, + expec_terms: Dict[TermKey, Term], + expec_constant: ExpressionNode, +) -> None: + assert expr.terms == expec_terms + assert expressions_equal(expr.constant, expec_constant) class StructureProvider(IndexingStructureProvider): @@ -210,65 +507,81 @@ def get_variable_structure(self, name: str) -> IndexingStructure: return IndexingStructure(True, True) -def test_shift() -> None: +def test_shift_on_time_step_list_raises_value_error() -> None: x = var("x") - expr = x.shift(ExpressionRange(literal(1), literal(4))) - - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(True, True) - assert expr.instances == Instances.MULTIPLE + with pytest.raises( + ValueError, + match=re.escape( + "The shift operator can only be applied on expressions refering to a single time step. To apply a shifting sum on multiple time indices on an expression x, you should use x.sum(shift=...)" + ), + ): + _ = x.shift(ExpressionRange(1, 4)) -def test_shifting_sum() -> None: +def test_eval_on_time_step_list_raises_value_error() -> None: x = var("x") - expr = x.shift(ExpressionRange(literal(1), literal(4))).sum() - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(True, True) - assert expr.instances == Instances.SIMPLE - - -def test_eval() -> None: - x = var("x") - expr = x.eval(ExpressionRange(literal(1), literal(4))) - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(False, True) - assert expr.instances == Instances.MULTIPLE - - -def test_eval_sum() -> None: - x = var("x") - expr = x.eval(ExpressionRange(literal(1), literal(4))).sum() - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(False, True) - assert expr.instances == Instances.SIMPLE - - -def test_sum_over_whole_block() -> None: - x = var("x") - expr = x.sum() + with pytest.raises( + ValueError, + match=re.escape( + "The eval operator can only be applied on expressions refering to a single time step. To apply a evaluation sum on multiple time indices on an expression x, you should use x.sum(eval=...)" + ), + ): + _ = x.eval(ExpressionRange(1, 4)) + + +# TODO: Shoudl be moved to test_linear_expression_efficient +@pytest.mark.parametrize( + "linear_expr, expected_indexation", + [ + ( + var("x").shift(1), + IndexingStructure(True, True), + ), + ( + var("x").sum(shift=ExpressionRange(1, 4)), + IndexingStructure(True, True), + ), + ( + var("x").eval(1), + IndexingStructure(False, True), + ), + ( + var("x").sum(eval=ExpressionRange(1, 4)), + IndexingStructure(False, True), + ), + ( + var("x").sum(), + IndexingStructure(False, True), + ), + ( + var("x").expec(), + IndexingStructure(True, False), + ), + ( + var("x").sum().expec(), + IndexingStructure(False, False), + ), + ( + var("x").shift(1).expec(), + IndexingStructure(True, False), + ), + ( + var("x").eval(1).expec(), + IndexingStructure(False, False), + ), + ], +) +def test_compute_indexation( + linear_expr: LinearExpression, expected_indexation: IndexingStructure +) -> None: provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(False, True) - assert expr.instances == Instances.SIMPLE + assert linear_expr.compute_indexation(provider) == expected_indexation def test_forbidden_composition_should_raise_value_error() -> None: x = var("x") with pytest.raises(ValueError): - _ = x.shift(ExpressionRange(literal(1), literal(4))) + var("y") - - -def test_expectation() -> None: - x = var("x") - expr = x.expec() - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(True, False) - assert expr.instances == Instances.SIMPLE + _ = x.shift(ExpressionRange(1, 4)) + var("y") def test_indexing_structure_comparison() -> None: @@ -301,13 +614,42 @@ def get_variable_structure(self, name: str) -> IndexingStructure: provider = CustomStructureProvider() - assert compute_indexation(expr, provider) == IndexingStructure(True, True) - - -def test_sum_expressions() -> None: - assert expressions_equal(sum_expressions([]), literal(0)) - assert expressions_equal(sum_expressions([literal(1)]), literal(1)) - assert expressions_equal(sum_expressions([literal(1), var("x")]), 1 + var("x")) - assert expressions_equal( - sum_expressions([literal(1), var("x"), param("p")]), 1 + (var("x") + param("p")) - ) + assert expr.compute_indexation(provider) == IndexingStructure(True, True) + + +@pytest.mark.parametrize( + "sum_expr, expected", + [ + (sum_expressions([]), literal(0)), + (sum_expressions([wrap_in_linear_expr(literal(1))]), literal(1)), + (sum_expressions([wrap_in_linear_expr(literal(1)), var("x")]), 1 + var("x")), + ( + sum_expressions( + [ + wrap_in_linear_expr(literal(1)), + var("x"), + wrap_in_linear_expr(param("p")), + ] + ), + (1 + var("x")) + param("p"), + ), + ], +) +def test_sum_expressions( + sum_expr: LinearExpression, expected: LinearExpression +) -> None: + assert linear_expressions_equal(sum_expr, wrap_in_linear_expr(expected)) + + +@pytest.mark.parametrize( + "expr, unbound", + [ + (literal(float("inf")), True), + (literal(float("-inf")), True), + (literal(-float("inf")), True), + (var("x") + literal(float("-inf")), True), + (var("x") + literal(4), False), + ], +) +def test_is_unbound(expr: LinearExpression, unbound: bool) -> None: + assert wrap_in_linear_expr(expr).is_unbound() == unbound diff --git a/tests/unittests/expressions/test_linear_expressions.py b/tests/unittests/expressions/test_linear_expressions.py index 2564aa8f..b7e8771c 100644 --- a/tests/unittests/expressions/test_linear_expressions.py +++ b/tests/unittests/expressions/test_linear_expressions.py @@ -14,45 +14,19 @@ import pytest +from andromede.expression.expression import expression_range, param +from andromede.expression.linear_expression import ( + LinearExpression, + PortFieldId, + PortFieldTerm, + Term, + _copy_expression, + linear_expressions_equal, + var, + wrap_in_linear_expr, +) from andromede.expression.scenario_operator import Expectation from andromede.expression.time_operator import TimeShift, TimeSum -from andromede.simulation.linear_expression import LinearExpression, Term, TermKey - - -@pytest.mark.parametrize( - "term, expected", - [ - (Term(1, "c", "x"), "+x"), - (Term(-1, "c", "x"), "-x"), - (Term(2.50, "c", "x"), "+2.5x"), - (Term(-3, "c", "x"), "-3x"), - (Term(-3, "c", "x", time_operator=TimeShift([-1])), "-3x.shift([-1])"), - (Term(-3, "c", "x", time_aggregator=TimeSum(True)), "-3x.sum(True)"), - ( - Term( - -3, - "c", - "x", - time_operator=TimeShift([2, 3]), - time_aggregator=TimeSum(False), - ), - "-3x.shift([2, 3]).sum(False)", - ), - (Term(-3, "c", "x", scenario_operator=Expectation()), "-3x.expec()"), - ( - Term( - -3, - "c", - "x", - time_aggregator=TimeSum(True), - scenario_operator=Expectation(), - ), - "-3x.sum(True).expec()", - ), - ], -) -def test_printing_term(term: Term, expected: str) -> None: - assert str(term) == expected @pytest.mark.parametrize( @@ -60,9 +34,9 @@ def test_printing_term(term: Term, expected: str) -> None: [ (0, "x", 0, "0"), (1, "x", 0, "+x"), - (1, "x", 1, "+x+1"), - (3.7, "x", 1, "+3.7x+1"), - (0, "x", 1, "+1"), + (1, "x", 1, "+x + 1.0"), + (3.7, "x", 1, "3.7x + 1.0"), + (0, "x", 1, " + 1.0"), ], ) def test_affine_expression_printing_should_reflect_required_formatting( @@ -72,17 +46,49 @@ def test_affine_expression_printing_should_reflect_required_formatting( assert str(expr) == expec_str +@pytest.mark.parametrize( + "expr", + [ + var("x"), + wrap_in_linear_expr(param("p")), + var("x") + 1, + var("x") - 1, + var("x") / 2, + var("x") * 3, + var("x").sum(shift=expression_range(1, 10, 2)), + var("x").sum(shift=expression_range(1, param("p"))), + var("x").expec(), + ], +) +def test_linear_expressions_equal(expr: LinearExpression) -> None: + copy = LinearExpression() + _copy_expression(expr, copy) + assert linear_expressions_equal(expr, copy) + + @pytest.mark.parametrize( "lhs, rhs", [ - (LinearExpression([], 1) + LinearExpression([], 3), LinearExpression([], 4)), - (LinearExpression([], 4) / LinearExpression([], 2), LinearExpression([], 2)), - (LinearExpression([], 4) * LinearExpression([], 2), LinearExpression([], 8)), - (LinearExpression([], 4) - LinearExpression([], 2), LinearExpression([], 2)), + ( + LinearExpression([], 1) + LinearExpression([], 3), + LinearExpression([], 4), + ), + ( + LinearExpression([], 4) / LinearExpression([], 2), + LinearExpression([], 2), + ), + ( + LinearExpression([], 4) * LinearExpression([], 2), + LinearExpression([], 8), + ), + ( + LinearExpression([], 4) - LinearExpression([], 2), + LinearExpression([], 2), + ), ], ) def test_constant_expressions(lhs: LinearExpression, rhs: LinearExpression) -> None: - assert lhs == rhs + assert linear_expressions_equal(lhs, rhs) @pytest.mark.parametrize( @@ -93,7 +99,7 @@ def test_constant_expressions(lhs: LinearExpression, rhs: LinearExpression) -> N ], ) def test_instantiate_linear_expression_from_dict( - terms_dict: Dict[TermKey, Term], + terms_dict: Dict[str, Term], constant: float, exp_terms: Dict[str, Term], exp_constant: float, @@ -103,6 +109,26 @@ def test_instantiate_linear_expression_from_dict( assert expr.constant == exp_constant +@pytest.mark.parametrize( + "expr, expected", + [ + (LinearExpression(), True), + (LinearExpression([]), True), + (LinearExpression([], 0, {}), True), + (LinearExpression([Term(1, "c", "x")], 0, {}), False), + (LinearExpression([], 1, {}), False), + ( + LinearExpression( + [], 1, {PortFieldId("p", "f"): PortFieldTerm(1, "p", "f")} + ), + False, + ), + ], +) +def test_is_zero(expr: LinearExpression, expected: bool) -> None: + assert expr.is_zero() == expected + + @pytest.mark.parametrize( "e1, e2, expected", [ @@ -123,8 +149,8 @@ def test_instantiate_linear_expression_from_dict( ), ( LinearExpression(), - LinearExpression([Term(10, "c", "x", time_operator=TimeShift([-1]))]), - LinearExpression([Term(10, "c", "x", time_operator=TimeShift([-1]))]), + LinearExpression([Term(10, "c", "x", TimeShift(-1))]), + LinearExpression([Term(10, "c", "x", TimeShift(-1))]), ), ( LinearExpression(), @@ -137,9 +163,12 @@ def test_instantiate_linear_expression_from_dict( ), ( LinearExpression([Term(10, "c", "x")]), - LinearExpression([Term(10, "c", "x", time_operator=TimeShift([-1]))]), + LinearExpression([Term(10, "c", "x", time_operator=TimeShift(-1))]), LinearExpression( - [Term(10, "c", "x"), Term(10, "c", "x", time_operator=TimeShift([-1]))] + [ + Term(10, "c", "x"), + Term(10, "c", "x", time_operator=TimeShift(-1)), + ] ), ), ( @@ -150,8 +179,8 @@ def test_instantiate_linear_expression_from_dict( 10, "c", "x", - time_operator=TimeShift([-1]), - scenario_operator=Expectation(), + time_operator=TimeShift(-1), + scenario_aggregator=Expectation(), ) ] ), @@ -162,8 +191,8 @@ def test_instantiate_linear_expression_from_dict( 10, "c", "x", - time_operator=TimeShift([-1]), - scenario_operator=Expectation(), + time_operator=TimeShift(-1), + scenario_aggregator=Expectation(), ), ] ), @@ -171,15 +200,11 @@ def test_instantiate_linear_expression_from_dict( ], ) def test_addition( - e1: LinearExpression, e2: LinearExpression, expected: LinearExpression + e1: LinearExpression, + e2: LinearExpression, + expected: LinearExpression, ) -> None: - assert e1 + e2 == expected - - -def test_addition_of_linear_expressions_with_different_number_of_instances_should_raise_value_error() -> ( - None -): - pass + assert linear_expressions_equal(e1 + e2, expected) def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_from_terms() -> ( @@ -216,8 +241,8 @@ def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_fr 10, "c", "x", - time_operator=TimeShift([-1]), - scenario_operator=Expectation(), + time_operator=TimeShift(-1), + scenario_aggregator=Expectation(), ) ], 3, @@ -229,8 +254,8 @@ def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_fr 20, "c", "x", - time_operator=TimeShift([-1]), - scenario_operator=Expectation(), + time_operator=TimeShift(-1), + scenario_aggregator=Expectation(), ) ], 6, @@ -239,10 +264,12 @@ def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_fr ], ) def test_multiplication( - e1: LinearExpression, e2: LinearExpression, expected: LinearExpression + e1: LinearExpression, + e2: LinearExpression, + expected: LinearExpression, ) -> None: - assert e1 * e2 == expected - assert e2 * e1 == expected + assert linear_expressions_equal(e1 * e2, expected) + assert linear_expressions_equal(e2 * e1, expected) def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> None: @@ -267,9 +294,9 @@ def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> 10, "c", "x", - time_operator=TimeShift([-1]), + time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], 5, @@ -280,9 +307,9 @@ def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> -10, "c", "x", - time_operator=TimeShift([-1]), + time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], -5, @@ -291,7 +318,7 @@ def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> ], ) def test_negation(e1: LinearExpression, expected: LinearExpression) -> None: - assert -e1 == expected + assert linear_expressions_equal(-e1, expected) @pytest.mark.parametrize( @@ -314,8 +341,8 @@ def test_negation(e1: LinearExpression, expected: LinearExpression) -> None: ), ( LinearExpression(), - LinearExpression([Term(10, "c", "x", time_operator=TimeShift([-1]))]), - LinearExpression([Term(-10, "c", "x", time_operator=TimeShift([-1]))]), + LinearExpression([Term(10, "c", "x", time_operator=TimeShift(-1))]), + LinearExpression([Term(-10, "c", "x", time_operator=TimeShift(-1))]), ), ( LinearExpression(), @@ -328,9 +355,12 @@ def test_negation(e1: LinearExpression, expected: LinearExpression) -> None: ), ( LinearExpression([Term(10, "c", "x")]), - LinearExpression([Term(10, "c", "x", time_operator=TimeShift([-1]))]), + LinearExpression([Term(10, "c", "x", time_operator=TimeShift(-1))]), LinearExpression( - [Term(10, "c", "x"), Term(-10, "c", "x", time_operator=TimeShift([-1]))] + [ + Term(10, "c", "x"), + Term(-10, "c", "x", time_operator=TimeShift(-1)), + ] ), ), ( @@ -341,9 +371,9 @@ def test_negation(e1: LinearExpression, expected: LinearExpression) -> None: 10, "c", "x", - time_operator=TimeShift([-1]), + time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ] ), @@ -354,9 +384,9 @@ def test_negation(e1: LinearExpression, expected: LinearExpression) -> None: -10, "c", "x", - time_operator=TimeShift([-1]), + time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ), ] ), @@ -364,9 +394,11 @@ def test_negation(e1: LinearExpression, expected: LinearExpression) -> None: ], ) def test_substraction( - e1: LinearExpression, e2: LinearExpression, expected: LinearExpression + e1: LinearExpression, + e2: LinearExpression, + expected: LinearExpression, ) -> None: - assert e1 - e2 == expected + assert linear_expressions_equal(e1 - e2, expected) @pytest.mark.parametrize( @@ -389,9 +421,9 @@ def test_substraction( 10, "c", "x", - time_operator=TimeShift([-1]), + time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], 15, @@ -403,9 +435,9 @@ def test_substraction( 2, "c", "x", - time_operator=TimeShift([-1]), + time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], 3, @@ -414,9 +446,11 @@ def test_substraction( ], ) def test_division( - e1: LinearExpression, e2: LinearExpression, expected: LinearExpression + e1: LinearExpression, + e2: LinearExpression, + expected: LinearExpression, ) -> None: - assert e1 / e2 == expected + assert linear_expressions_equal(e1 / e2, expected) def test_division_by_zero_sould_raise_zero_division_error() -> None: @@ -441,6 +475,6 @@ def test_imul_preserve_identity() -> None: e1 = LinearExpression([], 15) e2 = e1 e1 *= LinearExpression([], 2) - assert e1 == LinearExpression([], 30) - assert e2 == e1 + assert linear_expressions_equal(e1, LinearExpression([], 30)) + assert linear_expressions_equal(e2, e1) assert e2 is e1 diff --git a/tests/unittests/expressions/test_port_resolver.py b/tests/unittests/expressions/test_port_resolver.py index ad1370b4..ed30fcd9 100644 --- a/tests/unittests/expressions/test_port_resolver.py +++ b/tests/unittests/expressions/test_port_resolver.py @@ -12,37 +12,57 @@ from typing import Dict, List -from andromede.expression import ExpressionNode, var +import pytest + from andromede.expression.equality import expressions_equal -from andromede.expression.expression import port_field -from andromede.expression.port_resolver import PortFieldKey, resolve_port -from andromede.model.model import PortFieldId +from andromede.expression.linear_expression import ( + LinearExpression, + PortFieldId, + PortFieldKey, + linear_expressions_equal, + port_field, + var, +) -def test_port_field_resolution() -> None: - ports_expressions: Dict[PortFieldKey, List[ExpressionNode]] = {} +@pytest.mark.parametrize( + "port_expr, expected", + [ + (port_field("port", "field") + 2, var("flow") + 2), + (port_field("port", "field") - 2, var("flow") - 2), + (port_field("port", "field") * 2, 2 * var("flow")), + (port_field("port", "field") / 2, var("flow") / 2), + (port_field("port", "field") * 0, LinearExpression()), + ], +) +def test_port_field_resolution( + port_expr: LinearExpression, expected: LinearExpression +) -> None: + ports_expressions: Dict[PortFieldKey, List[LinearExpression]] = {} key = PortFieldKey("com_id", PortFieldId(field_name="field", port_name="port")) expression = var("flow") ports_expressions[key] = [expression] - expression_2 = port_field("port", "field") + 2 + print() + print(port_expr.resolve_port("com_id", ports_expressions)) + print(expected) - assert expressions_equal( - resolve_port(expression_2, "com_id", ports_expressions), var("flow") + 2 + assert linear_expressions_equal( + port_expr.resolve_port("com_id", ports_expressions), expected ) def test_port_field_resolution_sum() -> None: - ports_expressions: Dict[PortFieldKey, List[ExpressionNode]] = {} + ports_expressions: Dict[PortFieldKey, List[LinearExpression]] = {} key = PortFieldKey("com_id", PortFieldId(field_name="field", port_name="port")) ports_expressions[key] = [var("flow1"), var("flow2")] expression_2 = port_field("port", "field").sum_connections() - assert expressions_equal( - resolve_port(expression_2, "com_id", ports_expressions), + assert linear_expressions_equal( + expression_2.resolve_port("com_id", ports_expressions), var("flow1") + var("flow2"), ) diff --git a/tests/unittests/expressions/test_resolve_coefficients.py b/tests/unittests/expressions/test_resolve_coefficients.py new file mode 100644 index 00000000..f5619f0c --- /dev/null +++ b/tests/unittests/expressions/test_resolve_coefficients.py @@ -0,0 +1,334 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +import math +import re +from typing import Dict + +import pytest + +from andromede.expression.evaluate_parameters import resolve_coefficient +from andromede.expression.expression import ( + Comparator, + ComparisonNode, + ExpressionNode, + ExpressionRange, + InstancesTimeIndex, + LiteralNode, + PortFieldAggregatorName, + PortFieldAggregatorNode, + PortFieldNode, + TimeOperatorName, + TimeOperatorNode, + comp_param, + literal, + param, +) +from andromede.expression.indexing_structure import IndexingStructure, RowIndex +from andromede.expression.value_provider import ( + TimeScenarioIndex, + TimeScenarioIndices, + ValueProvider, +) + +p_values = { + TimeScenarioIndex(0, 0): 1.0, + TimeScenarioIndex(1, 0): 2.0, + TimeScenarioIndex(2, 0): 3.0, + TimeScenarioIndex(3, 0): 7.0, + TimeScenarioIndex(0, 1): 4.0, + TimeScenarioIndex(1, 1): 5.0, + TimeScenarioIndex(2, 1): 6.0, + TimeScenarioIndex(3, 1): 8.0, +} + +# A time constant parameter that can be put as TimeShift arg +comp_q_values = { + TimeScenarioIndex(0, 0): 2.0, + TimeScenarioIndex(0, 1): 1.0, +} + + +def _get_data_time_key(block_timestep: int, data_indexing: IndexingStructure) -> int: + return block_timestep if data_indexing.time else 0 + + +def _get_data_scenario_key(scenario: int, data_indexing: IndexingStructure) -> int: + return scenario if data_indexing.scenario else 0 + + +class CustomValueProvider(ValueProvider): + def get_component_variable_value( + self, + component_id: str, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + raise NotImplementedError( + "Cannot provide variable value at problem build time." + ) + + def get_component_parameter_value( + self, + component_id: str, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + result = {} + param_indexing = IndexingStructure(False, True) + for block_timestep in time_scenarios_indices.time_indices: + for scenario in time_scenarios_indices.scenario_indices: + result[TimeScenarioIndex(block_timestep, scenario)] = comp_q_values[ + TimeScenarioIndex( + _get_data_time_key(block_timestep, param_indexing), + _get_data_scenario_key(scenario, param_indexing), + ) + ] + return result + + def get_variable_value( + self, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + raise NotImplementedError( + "Cannot provide variable value at problem build time." + ) + + def get_parameter_value( + self, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + result = {} + param_indexing = IndexingStructure(True, True) + for block_timestep in time_scenarios_indices.time_indices: + for scenario in time_scenarios_indices.scenario_indices: + result[TimeScenarioIndex(block_timestep, scenario)] = p_values[ + TimeScenarioIndex( + _get_data_time_key(block_timestep, param_indexing), + _get_data_scenario_key(scenario, param_indexing), + ) + ] + return result + + def parameter_is_constant_over_time(self, name: str) -> bool: + return True + + @staticmethod + def block_length() -> int: + return 4 + + @staticmethod + def scenarios() -> int: + return 2 + + +@pytest.fixture +def provider() -> CustomValueProvider: + return CustomValueProvider() + + +@pytest.mark.parametrize( + "port_node", + [ + (PortFieldNode("port", "field")), + ( + PortFieldAggregatorNode( + PortFieldNode("port", "field"), PortFieldAggregatorName.PORT_SUM + ) + ), + ], +) +def test_resolve_coefficient_raises_value_error_on_port_field_node( + port_node: ExpressionNode, provider: CustomValueProvider +) -> None: + with pytest.raises( + ValueError, match="Port fields must be resolved before evaluating parameters" + ): + resolve_coefficient(port_node, provider, RowIndex(0, 0)) + + +def test_resolve_coefficient_raises_value_error_on_comparison_node( + provider: CustomValueProvider, +) -> None: + expr = ComparisonNode(LiteralNode(0), param("p"), Comparator.EQUAL) + with pytest.raises(ValueError, match="Cannot evaluate comparison operator."): + resolve_coefficient(expr, provider, RowIndex(0, 0)) + + +@pytest.mark.parametrize( + "expr", + [ + ( + TimeOperatorNode( + param("p"), + TimeOperatorName.SHIFT, + InstancesTimeIndex([LiteralNode(1), LiteralNode(2)]), + ) + ), + ( + TimeOperatorNode( + param("p"), + TimeOperatorName.EVALUATION, + InstancesTimeIndex([LiteralNode(1), LiteralNode(2)]), + ) + ), + ], +) +def test_resolve_coefficient_raises_value_error_on_expressions_that_are_not_aggregated_on_a_single_time_and_scenario( + expr: ExpressionNode, provider: CustomValueProvider +) -> None: + with pytest.raises( + ValueError, match="Evaluation of expression cannot be reduced to a float value" + ): + resolve_coefficient(expr, provider, RowIndex(0, 0)) + + +@pytest.mark.parametrize( + "expr", + [ + (param("p").shift(2)), + (param("p").eval(2)), + param("p").shift(comp_param("c", "q")), + ], +) +def test_resolve_coefficient_on_expression_with_shift_but_without_sum_raises_value_error( + expr: ExpressionNode, + provider: CustomValueProvider, +) -> None: + with pytest.raises( + ValueError, + match=re.escape( + "Expression has a time operator but not time aggregator, maybe you are missing a sum(), necessary even on one element" + ), + ): + resolve_coefficient(expr, provider, RowIndex(0, 0)) + + +@pytest.mark.parametrize( + "expr", + [ + ( + TimeOperatorNode( + param("p"), + TimeOperatorName.EVALUATION, + InstancesTimeIndex([comp_param("c", "q")]), + ) + ), + ( + TimeOperatorNode( + param("p"), + TimeOperatorName.SHIFT, + InstancesTimeIndex([param("q")]), + ) + ), + ], +) +def test_resolve_coefficient_with_no_time_varying_parameter_in_time_operator_argument_raises_value_error( + expr: ExpressionNode, +) -> None: + class TimeVaryingParameterValueProvider(CustomValueProvider): + def parameter_is_constant_over_time(self, name: str) -> bool: + return False + + provider = TimeVaryingParameterValueProvider() + + with pytest.raises( + ValueError, + match="Parameter given in an instance index expression must be constant over time", + ): + resolve_coefficient(expr, provider, RowIndex(0, 0)) + + +@pytest.mark.parametrize( + "expr, row_id, expected", + [ + (param("p"), RowIndex(0, 0), 1.0), + (comp_param("c", "q"), RowIndex(0, 0), 2.0), + (-comp_param("c", "q"), RowIndex(0, 0), -2.0), + (param("p") + comp_param("c", "q"), RowIndex(0, 0), 3.0), + (param("p") - comp_param("c", "q"), RowIndex(0, 0), -1.0), + (param("p") * LiteralNode(2), RowIndex(0, 0), 2.0), + (param("p") / LiteralNode(2), RowIndex(0, 0), 0.5), + ], +) +def test_resolve_coefficient_on_elementary_operations( + expr: ExpressionNode, + row_id: RowIndex, + expected: float, + provider: CustomValueProvider, +) -> None: + assert math.isclose(resolve_coefficient(expr, provider, row_id), expected) + + +@pytest.mark.parametrize( + "expr, row_id, expected", + [ + (param("p").shift(2).sum(), RowIndex(0, 0), 3.0), + (param("p").shift(-1).sum(), RowIndex(2, 1), 5.0), + (literal(0).shift(-1).sum(), RowIndex(0, 0), 0.0), + (param("p").eval(2).sum(), RowIndex(0, 0), 3.0), + (param("p").eval(2).sum(), RowIndex(2, 0), 3.0), + (param("p").shift(ExpressionRange(0, 3)).sum(), RowIndex(0, 0), 13.0), + (param("p").eval(ExpressionRange(1, 2)).sum(), RowIndex(0, 0), 5.0), + (param("p").eval(ExpressionRange(0, 3, 2)).sum(), RowIndex(0, 0), 4.0), + (param("p").shift(comp_param("c", "q")).sum(), RowIndex(1, 0), 7.0), + (param("p").shift(comp_param("c", "q")).sum(), RowIndex(1, 1), 6.0), + (param("p").sum(), RowIndex(0, 0), 13.0), + (param("p").sum(), RowIndex(2, 1), 23.0), + (comp_param("c", "q").sum(), RowIndex(0, 0), 2 * 4.0), + ], +) +def test_resolve_coefficient_on_time_shift_and_sum( + expr: ExpressionNode, + row_id: RowIndex, + expected: float, + provider: CustomValueProvider, +) -> None: + assert math.isclose(resolve_coefficient(expr, provider, row_id), expected) + + +@pytest.mark.parametrize( + "expr, row_id, expected", + [ + (param("p").expec(), RowIndex(0, 0), 2.5), + (param("p").expec(), RowIndex(1, 1), 3.5), + (comp_param("c", "q").expec(), RowIndex(1, 1), 1.5), + ], +) +def test_resolve_coefficient_on_expectation( + expr: ExpressionNode, + row_id: RowIndex, + expected: float, + provider: CustomValueProvider, +) -> None: + assert math.isclose(resolve_coefficient(expr, provider, row_id), expected) + + +@pytest.mark.parametrize( + "expr, row_id, expected", + [ + (param("p").expec().sum(), RowIndex(0, 0), 18.0), + (param("p").sum().expec(), RowIndex(0, 0), 18.0), + (param("p").shift(comp_param("c", "q")).sum().expec(), RowIndex(1, 0), 6.5), + (param("p").expec().shift(comp_param("c", "q")).sum(), RowIndex(1, 0), 7.5), + (param("p").shift(comp_param("c", "q")).expec().sum(), RowIndex(1, 0), 6.5), + ], +) +def test_resolve_coefficient_on_sum_and_expectation( + expr: ExpressionNode, + row_id: RowIndex, + expected: float, + provider: CustomValueProvider, +) -> None: + assert math.isclose(resolve_coefficient(expr, provider, row_id), expected) diff --git a/tests/unittests/expressions/test_term.py b/tests/unittests/expressions/test_term.py new file mode 100644 index 00000000..539ba994 --- /dev/null +++ b/tests/unittests/expressions/test_term.py @@ -0,0 +1,140 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +import pytest + +from andromede.expression.expression import LiteralNode +from andromede.expression.linear_expression import Term +from andromede.expression.scenario_operator import Expectation, Variance +from andromede.expression.time_operator import TimeShift, TimeSum + + +@pytest.mark.parametrize( + "term, expected", + [ + (Term(1, "c", "x"), "+x"), + (Term(-1, "c", "x"), "-x"), + (Term(2.50, "c", "x"), "2.5x"), + (Term(-3, "c", "x"), "-3.0x"), + (Term(-3, "c", "x", time_operator=TimeShift(-1)), "-3.0x.shift(-1)"), + (Term(-3, "c", "x", time_aggregator=TimeSum(True)), "-3.0x.sum(True)"), + ( + Term( + -3, + "c", + "x", + time_operator=TimeShift([2, 3]), + time_aggregator=TimeSum(False), + ), + "-3.0x.shift([2, 3]).sum(False)", + ), + ( + Term(-3, "c", "x", scenario_aggregator=Expectation()), + "-3.0x.expec()", + ), + ( + Term( + -3, + "c", + "x", + time_aggregator=TimeSum(True), + scenario_aggregator=Expectation(), + ), + "-3.0x.sum(True).expec()", + ), + ], +) +def test_printing_term(term: Term, expected: str) -> None: + assert str(term) == expected + + +@pytest.mark.parametrize( + "lhs, rhs, expected", + [ + (Term(1, "c", "x"), Term(1, "c", "x"), True), + (Term(1, "c", "x"), Term(2, "c", "x"), False), + ( + Term(LiteralNode(1), "c", "x"), + Term(LiteralNode(2), "c", "x"), + False, + ), + (Term(-1, "c", "x"), Term(-1, "", "x"), False), + (Term(2.50, "c", "x"), Term(2.50, "c", ""), False), + (Term(-3, "c", "x"), Term(-3, "c", "y"), False), + ( + Term(-3, "c", "x", time_operator=TimeShift(-1)), + Term(-3, "c", "x", time_operator=TimeShift(-1)), + True, + ), + ( + Term(-3, "c", "x", time_operator=TimeShift(-1)), + Term(-3, "c", "x"), + False, + ), + ( + Term(-3, "c", "x", time_aggregator=TimeSum(True)), + Term(-3, "c", "x", time_aggregator=TimeSum(True)), + True, + ), + ( + Term(-3, "c", "x", time_aggregator=TimeSum(True)), + Term(-3, "c", "x", time_operator=TimeShift(-1)), + False, + ), + ( + Term( + -3, + "c", + "x", + time_operator=TimeShift([2, 3]), + time_aggregator=TimeSum(False), + ), + Term( + -3, + "c", + "x", + time_operator=TimeShift([1, 3]), + time_aggregator=TimeSum(False), + ), + False, + ), + ( + Term(-3, "c", "x", scenario_aggregator=Expectation()), + Term(-3, "c", "x", scenario_aggregator=Expectation()), + True, + ), + ( + Term(-3, "c", "x", scenario_aggregator=Expectation()), + Term(-3, "c", "x", scenario_aggregator=Variance()), + False, + ), + ( + Term( + -3, + "c", + "x", + time_aggregator=TimeSum(True), + scenario_aggregator=Expectation(), + ), + Term( + -3, + "c", + "x", + time_aggregator=TimeSum(False), + scenario_aggregator=Expectation(), + ), + False, + ), + ], +) +def test_term_equality(lhs: Term, rhs: Term, expected: bool) -> None: + assert (lhs == rhs) == expected diff --git a/tests/unittests/model/test_model_parsing.py b/tests/unittests/model/test_model_parsing.py index 6512ed98..a79fcec5 100644 --- a/tests/unittests/model/test_model_parsing.py +++ b/tests/unittests/model/test_model_parsing.py @@ -9,13 +9,12 @@ # SPDX-License-Identifier: MPL-2.0 # # This file is part of the Antares project. -import io from pathlib import Path import pytest -from andromede.expression import literal, param, var -from andromede.expression.expression import port_field +from andromede.expression.expression import literal, param +from andromede.expression.linear_expression import port_field, var from andromede.expression.parsing.parse_expression import AntaresParseException from andromede.libs.standard import CONSTANT from andromede.model import ( @@ -62,7 +61,7 @@ def test_library_parsing(data_dir: Path) -> None: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId(port_name="injection_port", field_name="flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], objective_operational_contribution=(param("cost") * var("generation")) @@ -101,13 +100,13 @@ def test_library_parsing(data_dir: Path) -> None: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId(port_name="injection_port", field_name="flow"), - definition=var("injection") - var("withdrawal"), + definition_init=var("injection") - var("withdrawal"), ) ], constraints=[ Constraint( name="Level equation", - expression=var("level") + expression_init=var("level") - var("level").shift(-literal(1)) - param("efficiency") * var("injection") + var("withdrawal") @@ -163,7 +162,8 @@ def test_library_port_model_ok_parsing(data_dir: Path) -> None: constraints=[ Constraint( name="Level equation", - expression=port_field("injection_port", "flow") == var("withdrawal"), + expression_init=port_field("injection_port", "flow") + == var("withdrawal"), ) ], ) diff --git a/tests/unittests/study/test_components_parsing.py b/tests/unittests/study/test_components_parsing.py index f72c5c2f..f5ad10c6 100644 --- a/tests/unittests/study/test_components_parsing.py +++ b/tests/unittests/study/test_components_parsing.py @@ -5,7 +5,8 @@ from andromede.model.parsing import InputLibrary, parse_yaml_library from andromede.model.resolve_library import resolve_library -from andromede.simulation import BlockBorderManagement, TimeBlock, build_problem +from andromede.simulation import TimeBlock, build_problem +from andromede.simulation.optimization_context import BlockBorderManagement from andromede.study import TimeScenarioIndex, TimeScenarioSeriesData from andromede.study.parsing import InputComponents, parse_yaml_components from andromede.study.resolve_components import ( diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index d41628dc..d30f4509 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -15,8 +15,9 @@ import pandas as pd import pytest -from andromede.expression import param, var +from andromede.expression.expression import param from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression import var from andromede.libs.standard import ( BALANCE_PORT_TYPE, CONSTANT, @@ -80,12 +81,13 @@ def mock_generator_with_fixed_scenario_time_varying_param() -> Model: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= param("p_max") + name="Max generation", + expression_init=var("generation") <= param("p_max"), ) ], objective_operational_contribution=(param("cost") * var("generation")) @@ -108,12 +110,13 @@ def mock_generator_with_scenario_varying_fixed_time_param() -> Model: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= param("p_max") + name="Max generation", + expression_init=var("generation") <= param("p_max"), ) ], objective_operational_contribution=(param("cost") * var("generation")) diff --git a/tests/unittests/test_model.py b/tests/unittests/test_model.py index 3917af54..bf5cae24 100644 --- a/tests/unittests/test_model.py +++ b/tests/unittests/test_model.py @@ -10,20 +10,23 @@ # # This file is part of the Antares project. +import re +from typing import Optional, Type + import pytest -from andromede.expression.expression import ( - ExpressionNode, - ExpressionRange, - comp_param, +from andromede.expression.expression import ExpressionRange, comp_param, param +from andromede.expression.linear_expression import ( + LinearExpression, comp_var, + linear_expressions_equal, literal, - param, port_field, var, + wrap_in_linear_expr, ) -from andromede.model import Constraint, float_parameter, float_variable, model -from andromede.model.model import PortFieldDefinition, port_field_def +from andromede.model import Constraint, float_variable, model +from andromede.model.model import port_field_def @pytest.mark.parametrize( @@ -36,8 +39,28 @@ literal(10), "my_constraint", 2 * var("my_var"), - literal(5), + wrap_in_linear_expr(literal(5)), + wrap_in_linear_expr(literal(10)), + ), + ( + "my_constraint", + 2 * var("my_var"), + None, literal(10), + "my_constraint", + 2 * var("my_var"), + wrap_in_linear_expr(literal(-float("inf"))), + wrap_in_linear_expr(literal(10)), + ), + ( + "my_constraint", + 2 * var("my_var"), + literal(5), + None, + "my_constraint", + 2 * var("my_var"), + wrap_in_linear_expr(literal(5)), + wrap_in_linear_expr(literal(float("inf"))), ), ( "my_constraint", @@ -46,8 +69,8 @@ None, "my_constraint", 2 * var("my_var"), - literal(-float("inf")), - literal(float("inf")), + wrap_in_linear_expr(literal(-float("inf"))), + wrap_in_linear_expr(literal(float("inf"))), ), ( "my_constraint", @@ -56,8 +79,8 @@ None, "my_constraint", 2 * var("my_var") - param("p"), - literal(-float("inf")), - literal(0), + wrap_in_linear_expr(literal(-float("inf"))), + wrap_in_linear_expr(literal(0)), ), ( "my_constraint", @@ -66,8 +89,8 @@ None, "my_constraint", 2 * var("my_var") - param("p"), - literal(0), - literal(float("inf")), + wrap_in_linear_expr(literal(0)), + wrap_in_linear_expr(literal(float("inf"))), ), ( "my_constraint", @@ -76,8 +99,8 @@ None, "my_constraint", 2 * var("my_var") - param("p"), - literal(0), - literal(0), + wrap_in_linear_expr(literal(0)), + wrap_in_linear_expr(literal(0)), ), ( "my_constraint", @@ -86,8 +109,8 @@ None, "my_constraint", 2 * var("my_var").expec() - param("p"), - literal(0), - literal(0), + wrap_in_linear_expr(literal(0)), + wrap_in_linear_expr(literal(0)), ), ( "my_constraint", @@ -96,26 +119,34 @@ None, "my_constraint", 2 * var("my_var").shift(-1) - param("p"), - literal(0), - literal(0), + wrap_in_linear_expr(literal(0)), + wrap_in_linear_expr(literal(0)), ), ], ) def test_constraint_instantiation( name: str, - expression: ExpressionNode, - lb: ExpressionNode, - ub: ExpressionNode, + expression: LinearExpression, + lb: Optional[LinearExpression], + ub: Optional[LinearExpression], exp_name: str, - exp_expr: ExpressionNode, - exp_lb: ExpressionNode, - exp_ub: ExpressionNode, + exp_expr: LinearExpression, + exp_lb: LinearExpression, + exp_ub: LinearExpression, ) -> None: - constraint = Constraint(name, expression, lb, ub) + if lb is None and ub is None: + constraint = Constraint(name, expression) + elif lb is None: + constraint = Constraint(name, expression, upper_bound=ub) + elif ub is None: + constraint = Constraint(name, expression, lower_bound=lb) + else: + constraint = Constraint(name, expression, lower_bound=lb, upper_bound=ub) + assert constraint.name == exp_name - assert constraint.expression == exp_expr - assert constraint.lower_bound == exp_lb - assert constraint.upper_bound == exp_ub + assert linear_expressions_equal(constraint.expression, exp_expr) + assert linear_expressions_equal(constraint.lower_bound, exp_lb) + assert linear_expressions_equal(constraint.upper_bound, exp_ub) def test_if_both_comparison_expression_and_bound_given_for_constraint_init_then_it_should_raise_a_value_error() -> ( @@ -134,13 +165,11 @@ def test_if_a_bound_is_not_constant_then_it_should_raise_a_value_error() -> None Constraint("my_constraint", 2 * var("my_var"), var("x")) assert ( str(exc.value) - == "The bounds of a constraint should not contain variables, x was given." + == "The bounds of a constraint should not contain variables, +x was given." ) -def test_writing_p_min_max_constraint_should_represent_all_expected_constraints() -> ( - None -): +def test_writing_p_min_max_constraint_should_not_raise_exception() -> None: """ Aim at representing the following mathematical constraints: For all t, p_min <= p[t] <= p_max * alpha[t] where p_min, p_max are literal paramters and alpha is an input timeseries @@ -160,19 +189,19 @@ def test_writing_p_min_max_constraint_should_represent_all_expected_constraints( assert False, f"Writing p_min and p_max constraints raises an exception: {exc}" -def test_writing_min_up_constraint_should_represent_all_expected_constraints() -> None: +def test_writing_min_up_constraint_should_not_raise_exception() -> None: """ Aim at representing the following mathematical constraints: For all t, for all t' in [t+1, t+d_min_up], off_on[k,t,w] <= on[k,t',w] """ try: - d_min_up = literal(3) + d_min_up = 3 off_on = var("off_on") on = var("on") _ = Constraint( "min_up_time", - off_on <= on.shift(ExpressionRange(literal(1), d_min_up)).sum(), + off_on <= on.sum(shift=ExpressionRange(1, d_min_up)), ) # Later on, the goal is to assert that when this constraint is sent to the solver, it correctly builds: for all t, for all t' in [t+1, t+d_min_up], off_on[k,t,w] <= on[k,t',w] @@ -181,6 +210,7 @@ def test_writing_min_up_constraint_should_represent_all_expected_constraints() - assert False, f"Writing min_up constraints raises an exception: {exc}" +@pytest.mark.skip(reason="Variance not implemented") def test_instantiating_a_model_with_non_linear_scenario_operator_in_the_objective_should_raise_type_error() -> ( None ): @@ -194,25 +224,47 @@ def test_instantiating_a_model_with_non_linear_scenario_operator_in_the_objectiv @pytest.mark.parametrize( - "expression", + "expression, error_type, error_msg", [ - var("x") <= 0, - comp_var("c", "x"), - comp_param("c", "x"), - port_field("p", "f"), - port_field("p", "f").sum_connections(), + ( + var("x") <= 0, + TypeError, + "Unable to wrap + (-inf) <= +x <= 0 into a linear expression", + ), + ( + comp_var("c", "x"), + ValueError, + "Port definition must not contain a variable associated to a component.", + ), + ( + comp_param("c", "x"), + ValueError, + "Port definition must not contain a parameter associated to a component.", + ), + ( + port_field("p", "f"), + ValueError, + "Port definition cannot reference another port field.", + ), + ( + port_field("p", "f").sum_connections(), + ValueError, + "Port definition cannot reference another port field.", + ), ], ) -def test_invalid_port_field_definition_should_raise(expression: ExpressionNode) -> None: - with pytest.raises(ValueError) as exc: +def test_invalid_port_field_definition_should_raise( + expression: LinearExpression, error_type: Type, error_msg: str +) -> None: + with pytest.raises(error_type, match=re.escape(error_msg)): port_field_def(port_name="p", field_name="f", definition=expression) def test_constraint_equals() -> None: # checks in particular that expressions are correctly compared - assert Constraint(name="c", expression=var("x") <= param("p")) == Constraint( - name="c", expression=var("x") <= param("p") + assert Constraint(name="c", expression_init=var("x") <= param("p")) == Constraint( + name="c", expression_init=var("x") <= param("p") ) - assert Constraint(name="c", expression=var("x") <= param("p")) != Constraint( - name="c", expression=var("y") <= param("p") + assert Constraint(name="c", expression_init=var("x") <= param("p")) != Constraint( + name="c", expression_init=var("y") <= param("p") ) diff --git a/tests/unittests/test_output_values.py b/tests/unittests/test_output_values.py index 89b3637e..f3d01c13 100644 --- a/tests/unittests/test_output_values.py +++ b/tests/unittests/test_output_values.py @@ -15,11 +15,8 @@ import ortools.linear_solver.pywraplp as lp from andromede.simulation import OutputValues -from andromede.simulation.optimization import ( - OptimizationContext, - OptimizationProblem, - TimestepComponentVariableKey, -) +from andromede.simulation.optimization import OptimizationContext, OptimizationProblem +from andromede.simulation.optimization_context import TimestepComponentVariableKey def test_component_and_flow_output_object() -> None: diff --git a/tests/unittests/test_port.py b/tests/unittests/test_port.py index 2ff4547f..6f173dee 100644 --- a/tests/unittests/test_port.py +++ b/tests/unittests/test_port.py @@ -12,10 +12,13 @@ import pytest -from andromede.expression import literal -from andromede.expression.expression import port_field +from andromede.expression.expression import literal +from andromede.expression.linear_expression import port_field from andromede.libs.standard import DEMAND_MODEL from andromede.model import Constraint, ModelPort, PortType, model +from andromede.model.constraint import Constraint +from andromede.model.model import ModelPort, model +from andromede.model.port import PortType from andromede.study import Node, PortRef, PortsConnection, create_component @@ -28,7 +31,8 @@ def test_port_type_compatibility_ko() -> None: constraints=[ Constraint( name="Balance", - expression=port_field("balance_port", "flow").sum() == literal(0), + expression_init=port_field("balance_port", "flow").sum_connections() + == literal(0), ) ], ) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..cf5ef16c --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +from dataclasses import dataclass, field +from typing import Dict + +from andromede.expression.value_provider import ( + TimeScenarioIndex, + TimeScenarioIndices, + ValueProvider, +) + + +# Used only for tests +@dataclass(frozen=True) +class EvaluationContext(ValueProvider): + """ + Simple value provider relying on dictionaries. + Does not support component variables/parameters. + """ + + variables: Dict[str, float] = field(default_factory=dict) + parameters: Dict[str, float] = field(default_factory=dict) + + def get_variable_value( + self, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: + return {TimeScenarioIndex(0, 0): self.variables[name]} + + def get_parameter_value( + self, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: + return {TimeScenarioIndex(0, 0): self.parameters[name]} + + def get_component_variable_value( + self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: + raise NotImplementedError() + + def get_component_parameter_value( + self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: + raise NotImplementedError() + + def parameter_is_constant_over_time(self, name: str) -> bool: + raise NotImplementedError() + + @staticmethod + def block_length() -> int: + raise NotImplementedError() + + @staticmethod + def scenarios() -> int: + raise NotImplementedError()