From 58c12f83b37ed2c9d2d8cc2178176fd4fe2c9891 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Fri, 21 Jun 2024 11:43:45 +0200 Subject: [PATCH 01/51] WIP --- src/andromede/expression/__init__.py | 27 +- src/andromede/expression/equality.py | 57 +- .../expression/expression_efficient.py | 425 +++++++++++++++ .../expression/linear_expression_efficient.py | 495 ++++++++++++++++++ .../expression/parsing/parse_expression.py | 110 ++-- src/andromede/expression/port_resolver.py | 31 +- src/andromede/expression/print.py | 22 +- src/andromede/expression/visitor.py | 48 +- src/andromede/model/model.py | 93 +++- src/andromede/model/resolve_library.py | 5 +- src/andromede/model/variable.py | 19 +- src/andromede/simulation/linear_expression.py | 1 - src/andromede/simulation/optimization.py | 26 +- src/andromede/simulation/strategy.py | 11 +- .../unittests/expressions/test_expressions.py | 1 + .../expressions/test_expressions_efficient.py | 308 +++++++++++ .../test_linear_expressions_efficient.py | 495 ++++++++++++++++++ 17 files changed, 2034 insertions(+), 140 deletions(-) create mode 100644 src/andromede/expression/expression_efficient.py create mode 100644 src/andromede/expression/linear_expression_efficient.py create mode 100644 tests/unittests/expressions/test_expressions_efficient.py create mode 100644 tests/unittests/expressions/test_linear_expressions_efficient.py diff --git a/src/andromede/expression/__init__.py b/src/andromede/expression/__init__.py index 2fe9b94d..07825dee 100644 --- a/src/andromede/expression/__init__.py +++ b/src/andromede/expression/__init__.py @@ -18,22 +18,35 @@ ParameterValueProvider, resolve_parameters, ) -from .expression import ( + +# from .expression import ( +# AdditionNode, +# Comparator, +# ComparisonNode, +# DivisionNode, +# ExpressionNode, +# LiteralNode, +# MultiplicationNode, +# NegationNode, +# ParameterNode, +# SubstractionNode, +# VariableNode, +# literal, +# param, +# sum_expressions, +# var, +# ) +from .expression_efficient import ( AdditionNode, Comparator, ComparisonNode, DivisionNode, - ExpressionNode, + ExpressionNodeEfficient, LiteralNode, MultiplicationNode, NegationNode, ParameterNode, SubstractionNode, - VariableNode, - literal, - param, - sum_expressions, - var, ) from .print import PrinterVisitor, print_expr from .visitor import ExpressionVisitor, visit diff --git a/src/andromede/expression/equality.py b/src/andromede/expression/equality.py index a2deeb27..67eb1e80 100644 --- a/src/andromede/expression/equality.py +++ b/src/andromede/expression/equality.py @@ -14,29 +14,49 @@ from dataclasses import dataclass from typing import Optional -from andromede.expression import ( +from andromede.expression.expression_efficient import ( AdditionNode, + BinaryOperatorNode, ComparisonNode, DivisionNode, - ExpressionNode, + ExpressionNodeEfficient, + ExpressionRange, + InstancesTimeIndex, LiteralNode, MultiplicationNode, NegationNode, ParameterNode, - SubstractionNode, - VariableNode, -) -from andromede.expression.expression import ( - BinaryOperatorNode, - ExpressionRange, - InstancesTimeIndex, PortFieldAggregatorNode, PortFieldNode, ScenarioOperatorNode, + SubstractionNode, TimeAggregatorNode, TimeOperatorNode, ) +# from andromede.expression import ( +# AdditionNode, +# ComparisonNode, +# DivisionNode, +# ExpressionNode, +# LiteralNode, +# MultiplicationNode, +# NegationNode, +# ParameterNode, +# SubstractionNode, +# VariableNode, +# ) +# from andromede.expression.expression import ( +# BinaryOperatorNode, +# ExpressionRange, +# InstancesTimeIndex, +# PortFieldAggregatorNode, +# PortFieldNode, +# ScenarioOperatorNode, +# TimeAggregatorNode, +# TimeOperatorNode, +# ) + @dataclass(frozen=True) class EqualityVisitor: @@ -53,7 +73,9 @@ def __post_init__(self) -> None: f"Relative comparison tolerance must be >= 0, got {self.rel_tol}" ) - def visit(self, left: ExpressionNode, right: ExpressionNode) -> bool: + def visit( + self, left: ExpressionNodeEfficient, right: ExpressionNodeEfficient + ) -> bool: if left.__class__ != right.__class__: return False if isinstance(left, LiteralNode) and isinstance(right, LiteralNode): @@ -72,8 +94,8 @@ def visit(self, left: ExpressionNode, right: ExpressionNode) -> bool: return self.multiplication(left, right) if isinstance(left, ComparisonNode) and isinstance(right, ComparisonNode): return self.comparison(left, right) - if isinstance(left, VariableNode) and isinstance(right, VariableNode): - return self.variable(left, right) + # if isinstance(left, VariableNode) and isinstance(right, VariableNode): + # return self.variable(left, right) if isinstance(left, ParameterNode) and isinstance(right, ParameterNode): return self.parameter(left, right) if isinstance(left, TimeOperatorNode) and isinstance(right, TimeOperatorNode): @@ -124,8 +146,8 @@ def division(self, left: DivisionNode, right: DivisionNode) -> bool: def comparison(self, left: ComparisonNode, right: ComparisonNode) -> bool: return left.comparator == right.comparator and self._visit_operands(left, right) - def variable(self, left: VariableNode, right: VariableNode) -> bool: - return left.name == right.name + # def variable(self, left: VariableNode, right: VariableNode) -> bool: + # return left.name == right.name def parameter(self, left: ParameterNode, right: ParameterNode) -> bool: return left.name == right.name @@ -183,7 +205,10 @@ def port_field_aggregator( def expressions_equal( - left: ExpressionNode, right: ExpressionNode, abs_tol: float = 0, rel_tol: float = 0 + left: ExpressionNodeEfficient, + right: ExpressionNodeEfficient, + abs_tol: float = 0, + rel_tol: float = 0, ) -> bool: """ True if both expression nodes are equal. Literal values may be compared with absolute or relative tolerance. @@ -192,7 +217,7 @@ def expressions_equal( def expressions_equal_if_present( - lhs: Optional[ExpressionNode], rhs: Optional[ExpressionNode] + lhs: Optional[ExpressionNodeEfficient], rhs: Optional[ExpressionNodeEfficient] ) -> bool: if lhs is None and rhs is None: return True diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py new file mode 100644 index 00000000..2a895f10 --- /dev/null +++ b/src/andromede/expression/expression_efficient.py @@ -0,0 +1,425 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +""" +Defines the model for generic expressions. +""" +import enum +import inspect +from dataclasses import dataclass, field +from typing import Any, Callable, List, Optional, Sequence, Union + +import andromede.expression.port_operator +import andromede.expression.scenario_operator +import andromede.expression.time_operator + + +class Instances(enum.Enum): + SIMPLE = "SIMPLE" + MULTIPLE = "MULTIPLE" + + +@dataclass(frozen=True) +class ExpressionNodeEfficient: + """ + Base class for all nodes of the expression AST. + + Operators overloading is provided to help create expressions + programmatically. + + Examples + >>> expr = -var('x') + 5 / param('p') + """ + + instances: Instances = field(init=False, default=Instances.SIMPLE) + + def __neg__(self) -> "ExpressionNodeEfficient": + return NegationNode(self) + + def __add__(self, rhs: Any) -> "ExpressionNodeEfficient": + return _apply_if_node(rhs, lambda x: AdditionNode(self, x)) + + def __radd__(self, lhs: Any) -> "ExpressionNodeEfficient": + return _apply_if_node(lhs, lambda x: AdditionNode(x, self)) + + def __sub__(self, rhs: Any) -> "ExpressionNodeEfficient": + return _apply_if_node(rhs, lambda x: SubstractionNode(self, x)) + + def __rsub__(self, lhs: Any) -> "ExpressionNodeEfficient": + return _apply_if_node(lhs, lambda x: SubstractionNode(x, self)) + + def __mul__(self, rhs: Any) -> "ExpressionNodeEfficient": + return _apply_if_node(rhs, lambda x: MultiplicationNode(self, x)) + + def __rmul__(self, lhs: Any) -> "ExpressionNodeEfficient": + return _apply_if_node(lhs, lambda x: MultiplicationNode(x, self)) + + def __truediv__(self, rhs: Any) -> "ExpressionNodeEfficient": + return _apply_if_node(rhs, lambda x: DivisionNode(self, x)) + + def __rtruediv__(self, lhs: Any) -> "ExpressionNodeEfficient": + return _apply_if_node(lhs, lambda x: DivisionNode(x, self)) + + def __le__(self, rhs: Any) -> "ExpressionNodeEfficient": + return _apply_if_node( + rhs, lambda x: ComparisonNode(self, x, Comparator.LESS_THAN) + ) + + def __ge__(self, rhs: Any) -> "ExpressionNodeEfficient": + return _apply_if_node( + rhs, lambda x: ComparisonNode(self, x, Comparator.GREATER_THAN) + ) + + def __eq__(self, rhs: Any) -> "ExpressionNodeEfficient": # type: ignore + return _apply_if_node(rhs, lambda x: ComparisonNode(self, x, Comparator.EQUAL)) + + def sum(self) -> "ExpressionNodeEfficient": + if isinstance(self, TimeOperatorNode): + return TimeAggregatorNode(self, "TimeSum", stay_roll=True) + else: + return _apply_if_node( + self, lambda x: TimeAggregatorNode(x, "TimeSum", stay_roll=False) + ) + + def sum_connections(self) -> "ExpressionNodeEfficient": + if isinstance(self, PortFieldNode): + return PortFieldAggregatorNode(self, aggregator="PortSum") + raise ValueError( + f"sum_connections() applies only for PortFieldNode, whereas the current node is of type {type(self)}." + ) + + def shift( + self, + expressions: Union[ + int, + "ExpressionNodeEfficient", + List["ExpressionNodeEfficient"], + "ExpressionRange", + ], + ) -> "ExpressionNodeEfficient": + return _apply_if_node( + self, + lambda x: TimeOperatorNode(x, "TimeShift", InstancesTimeIndex(expressions)), + ) + + def eval( + self, + expressions: Union[ + int, + "ExpressionNodeEfficient", + List["ExpressionNodeEfficient"], + "ExpressionRange", + ], + ) -> "ExpressionNodeEfficient": + return _apply_if_node( + self, + lambda x: TimeOperatorNode( + x, "TimeEvaluation", InstancesTimeIndex(expressions) + ), + ) + + def expec(self) -> "ExpressionNodeEfficient": + return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation")) + + def variance(self) -> "ExpressionNodeEfficient": + return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Variance")) + + +def _wrap_in_node(obj: Any) -> ExpressionNodeEfficient: + if isinstance(obj, ExpressionNodeEfficient): + return obj + elif isinstance(obj, float) or isinstance(obj, int): + return LiteralNode(float(obj)) + raise TypeError(f"Unable to wrap {obj} into an expression node") + + +def _apply_if_node( + obj: Any, func: Callable[["ExpressionNodeEfficient"], "ExpressionNodeEfficient"] +) -> "ExpressionNodeEfficient": + if as_node := _wrap_in_node(obj): + return func(as_node) + else: + return NotImplemented + + +@dataclass(frozen=True, eq=False) +class PortFieldNode(ExpressionNodeEfficient): + """ + References a port field. + """ + + port_name: str + field_name: str + + +def port_field(port_name: str, field_name: str) -> PortFieldNode: + return PortFieldNode(port_name, field_name) + + +@dataclass(frozen=True, eq=False) +class ParameterNode(ExpressionNodeEfficient): + name: str + + +@dataclass(frozen=True, eq=False) +class ComponentParameterNode(ExpressionNodeEfficient): + """ + Represents one parameter of one component. + + When building actual equations for a system, + we need to associated each parameter to its + actual component, at some point. + """ + + component_id: str + name: str + + +@dataclass(frozen=True, eq=False) +class LiteralNode(ExpressionNodeEfficient): + value: float + + +@dataclass(frozen=True, eq=False) +class UnaryOperatorNode(ExpressionNodeEfficient): + operand: ExpressionNodeEfficient + + def __post_init__(self) -> None: + object.__setattr__(self, "instances", self.operand.instances) + + +@dataclass(frozen=True, eq=False) +class PortFieldAggregatorNode(UnaryOperatorNode): + aggregator: str + + def __post_init__(self) -> None: + valid_names = [ + cls.__name__ + for _, cls in inspect.getmembers( + andromede.expression.port_operator, inspect.isclass + ) + if issubclass(cls, andromede.expression.port_operator.PortAggregator) + ] + if self.aggregator not in valid_names: + raise NotImplementedError( + f"{self.aggregator} is not a valid port aggregator, valid port aggregators are {valid_names}" + ) + + +@dataclass(frozen=True, eq=False) +class NegationNode(UnaryOperatorNode): + pass + + +@dataclass(frozen=True, eq=False) +class BinaryOperatorNode(ExpressionNodeEfficient): + left: ExpressionNodeEfficient + right: ExpressionNodeEfficient + + def __post_init__(self) -> None: + binary_operator_post_init(self, "apply binary operation with") + + +def binary_operator_post_init(node: BinaryOperatorNode, operation: str) -> None: + if node.left.instances != node.right.instances: + raise ValueError( + f"Cannot {operation} {node.left} and {node.right} as they do not have the same number of instances." + ) + else: + object.__setattr__(node, "instances", node.left.instances) + + +class Comparator(enum.Enum): + LESS_THAN = "LESS_THAN" + EQUAL = "EQUAL" + GREATER_THAN = "GREATER_THAN" + + +@dataclass(frozen=True, eq=False) +class ComparisonNode(BinaryOperatorNode): + comparator: Comparator + + def __post_init__(self) -> None: + binary_operator_post_init(self, "compare") + + +@dataclass(frozen=True, eq=False) +class AdditionNode(BinaryOperatorNode): + def __post_init__(self) -> None: + binary_operator_post_init(self, "add") + + +@dataclass(frozen=True, eq=False) +class SubstractionNode(BinaryOperatorNode): + def __post_init__(self) -> None: + binary_operator_post_init(self, "substract") + + +@dataclass(frozen=True, eq=False) +class MultiplicationNode(BinaryOperatorNode): + def __post_init__(self) -> None: + binary_operator_post_init(self, "multiply") + + +@dataclass(frozen=True, eq=False) +class DivisionNode(BinaryOperatorNode): + def __post_init__(self) -> None: + binary_operator_post_init(self, "divide") + + +@dataclass(frozen=True, eq=False) +class ExpressionRange: + start: ExpressionNodeEfficient + stop: ExpressionNodeEfficient + step: Optional[ExpressionNodeEfficient] = None + + def __post_init__(self) -> None: + for attribute in self.__dict__: + value = getattr(self, attribute) + object.__setattr__( + self, attribute, _wrap_in_node(value) if value is not None else value + ) + + +IntOrExpr = Union[int, ExpressionNodeEfficient] + + +def expression_range( + start: IntOrExpr, stop: IntOrExpr, step: Optional[IntOrExpr] = None +) -> ExpressionRange: + return ExpressionRange( + start=_wrap_in_node(start), + stop=_wrap_in_node(stop), + step=None if step is None else _wrap_in_node(step), + ) + + +@dataclass +class InstancesTimeIndex: + """ + Defines a set of time indices on which a time operator operates. + + In particular, it defines time indices created by the shift operator. + + The actual indices can either be defined as a time range defined by + 2 expression, or as a list of expressions. + """ + + expressions: Union[List[ExpressionNodeEfficient], ExpressionRange] + + def __init__( + self, + expressions: Union[ + int, ExpressionNodeEfficient, List[ExpressionNodeEfficient], ExpressionRange + ], + ) -> None: + if not isinstance( + expressions, (int, ExpressionNodeEfficient, list, ExpressionRange) + ): + raise TypeError( + f"{expressions} must be of type among {{int, ExpressionNodeEfficient, List[ExpressionNodeEfficient], ExpressionRange}}" + ) + if isinstance(expressions, list) and not all( + isinstance(x, ExpressionNodeEfficient) for x in expressions + ): + raise TypeError( + f"All elements of {expressions} must be of type ExpressionNodeEfficient" + ) + + if isinstance(expressions, (int, ExpressionNodeEfficient)): + self.expressions = [_wrap_in_node(expressions)] + else: + self.expressions = expressions + + def is_simple(self) -> bool: + if isinstance(self.expressions, list): + return len(self.expressions) == 1 + else: + # TODO: We could also check that if a range only includes literal nodes, compute the length of the range, if it's one return True. This is more complicated, I do not know if we want to do this + return False + + +@dataclass(frozen=True, eq=False) +class TimeOperatorNode(UnaryOperatorNode): + name: str + instances_index: InstancesTimeIndex + + def __post_init__(self) -> None: + valid_names = [ + cls.__name__ + for _, cls in inspect.getmembers( + andromede.expression.time_operator, inspect.isclass + ) + if issubclass(cls, andromede.expression.time_operator.TimeOperator) + ] + if self.name not in valid_names: + raise ValueError( + f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" + ) + if self.operand.instances == Instances.SIMPLE: + if self.instances_index.is_simple(): + object.__setattr__(self, "instances", Instances.SIMPLE) + else: + object.__setattr__(self, "instances", Instances.MULTIPLE) + else: + raise ValueError( + "Cannot apply time operator on an expression that already represents multiple instances" + ) + + +@dataclass(frozen=True, eq=False) +class TimeAggregatorNode(UnaryOperatorNode): + name: str + stay_roll: bool + + def __post_init__(self) -> None: + valid_names = [ + cls.__name__ + for _, cls in inspect.getmembers( + andromede.expression.time_operator, inspect.isclass + ) + if issubclass(cls, andromede.expression.time_operator.TimeAggregator) + ] + if self.name not in valid_names: + raise ValueError( + f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" + ) + object.__setattr__(self, "instances", Instances.SIMPLE) + + +@dataclass(frozen=True, eq=False) +class ScenarioOperatorNode(UnaryOperatorNode): + name: str + + def __post_init__(self) -> None: + valid_names = [ + cls.__name__ + for _, cls in inspect.getmembers( + andromede.expression.scenario_operator, inspect.isclass + ) + if issubclass(cls, andromede.expression.scenario_operator.ScenarioOperator) + ] + if self.name not in valid_names: + raise ValueError( + f"{self.name} is not a valid scenario operator, valid scenario operators are {valid_names}" + ) + object.__setattr__(self, "instances", Instances.SIMPLE) + + +def sum_expressions( + expressions: Sequence[ExpressionNodeEfficient], +) -> ExpressionNodeEfficient: + if len(expressions) == 0: + return LiteralNode(0) + if len(expressions) == 1: + return expressions[0] + return expressions[0] + sum_expressions(expressions[1:]) diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py new file mode 100644 index 00000000..6622572e --- /dev/null +++ b/src/andromede/expression/linear_expression_efficient.py @@ -0,0 +1,495 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +""" +Specific modelling for "instantiated" linear expressions, +with only variables and literal coefficients. +""" +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional, TypeVar, Union + +from andromede.expression.equality import expressions_equal +from andromede.expression.evaluate import ValueProvider, evaluate +from andromede.expression.expression_efficient import ( + ComponentParameterNode, + ExpressionNodeEfficient, + LiteralNode, + ParameterNode, +) +from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.print import print_expr +from andromede.expression.scenario_operator import ScenarioOperator +from andromede.expression.time_operator import TimeAggregator, TimeOperator + +T = TypeVar("T") + +EPS = 10 ** (-16) + + +def is_close_abs(value: float, other_value: float, eps: float) -> bool: + return abs(value - other_value) < eps + + +def is_zero(value: ExpressionNodeEfficient) -> bool: + return expressions_equal(value, LiteralNode(0), EPS) + + +def is_one(value: ExpressionNodeEfficient) -> bool: + return expressions_equal(value, LiteralNode(1), EPS) + + +def is_minus_one(value: float) -> bool: + return expressions_equal(value, LiteralNode(-1), EPS) + + +@dataclass(frozen=True) +class TermKeyEfficient: + """ + Utility class to provide key for a term that contains all term information except coefficient + """ + + component_id: str + variable_name: str + time_operator: Optional[TimeOperator] + time_aggregator: Optional[TimeAggregator] + scenario_operator: Optional[ScenarioOperator] + + +@dataclass(frozen=True) +class TermEfficient: + """ + One term in a linear expression: for example the "10x" par in "10x + 5y + 5" + + Args: + coefficient: the coefficient for that term, for example "10" in "10x" + variable_name: the name of the variable, for example "x" in "10x" + """ + + coefficient: ExpressionNodeEfficient + component_id: str + variable_name: str + structure: IndexingStructure = field( + default=IndexingStructure(time=True, scenario=True) + ) + time_operator: Optional[TimeOperator] = None + time_aggregator: Optional[TimeAggregator] = None + scenario_operator: Optional[ScenarioOperator] = None + + # TODO: It may be useful to define __add__, __sub__, etc on terms, which should return a linear expression ? + + def is_zero(self) -> bool: + return is_zero(self.coefficient) + + def str_for_coeff(self) -> str: + str_for_coeff = "" + if is_one(self.coefficient): + str_for_coeff = "+" + elif is_minus_one(self.coefficient): + str_for_coeff = "-" + else: + str_for_coeff = print_expr(self.coefficient) + return str_for_coeff + + def __str__(self) -> str: + # Useful for debugging tests + result = self.str_for_coeff() + str(self.variable_name) + if self.time_operator is not None: + result += f".{str(self.time_operator)}" + if self.time_aggregator is not None: + result += f".{str(self.time_aggregator)}" + if self.scenario_operator is not None: + result += f".{str(self.scenario_operator)}" + return result + + def number_of_instances(self) -> int: + if self.time_aggregator is not None: + return self.time_aggregator.size() + else: + if self.time_operator is not None: + return self.time_operator.size() + else: + return 1 + + def evaluate(self, context: ValueProvider) -> float: + # TODO: Take care of component variables, multiple time scenarios, operators, etc + # Probably very error prone + if self.component_id: + variable_value = context.get_component_variable_value( + self.component_id, self.variable_name + ) + else: + variable_value = context.get_variable_value(self.variable_name) + return evaluate(self.coefficient, context) * variable_value + + +def generate_key(term: TermEfficient) -> TermKeyEfficient: + return TermKeyEfficient( + term.component_id, + term.variable_name, + term.time_operator, + term.time_aggregator, + term.scenario_operator, + ) + + +def _merge_dicts( + lhs: Dict[TermKeyEfficient, TermEfficient], + rhs: Dict[TermKeyEfficient, TermEfficient], + merge_func: Callable[[TermEfficient, TermEfficient], TermEfficient], + neutral: float, +) -> Dict[TermKeyEfficient, TermEfficient]: + res = {} + for k, v in lhs.items(): + res[k] = merge_func( + v, + rhs.get( + k, + TermEfficient( + neutral, + v.component_id, + v.variable_name, + v.structure, + v.time_operator, + v.time_aggregator, + v.scenario_operator, + ), + ), + ) + for k, v in rhs.items(): + if k not in lhs: + res[k] = merge_func( + TermEfficient( + neutral, + v.component_id, + v.variable_name, + v.structure, + v.time_operator, + v.time_aggregator, + v.scenario_operator, + ), + v, + ) + return res + + +def _merge_is_possible(lhs: TermEfficient, rhs: TermEfficient) -> None: + if lhs.component_id != rhs.component_id or lhs.variable_name != rhs.variable_name: + raise ValueError("Cannot merge terms for different variables") + if ( + lhs.time_operator != rhs.time_operator + or lhs.time_aggregator != rhs.time_aggregator + or lhs.scenario_operator != rhs.scenario_operator + ): + raise ValueError("Cannot merge terms with different operators") + if lhs.structure != rhs.structure: + raise ValueError("Cannot merge terms with different structures") + + +def _add_terms(lhs: TermEfficient, rhs: TermEfficient) -> TermEfficient: + _merge_is_possible(lhs, rhs) + return TermEfficient( + lhs.coefficient + rhs.coefficient, + lhs.component_id, + lhs.variable_name, + lhs.structure, + lhs.time_operator, + lhs.time_aggregator, + lhs.scenario_operator, + ) + + +def _substract_terms(lhs: TermEfficient, rhs: TermEfficient) -> TermEfficient: + _merge_is_possible(lhs, rhs) + return TermEfficient( + lhs.coefficient - rhs.coefficient, + lhs.component_id, + lhs.variable_name, + lhs.structure, + lhs.time_operator, + lhs.time_aggregator, + lhs.scenario_operator, + ) + + +class LinearExpressionEfficient: + """ + Represents a linear expression with respect to variable names, for example 10x + 5y + 2. + + Operators may be used for construction. + + Args: + terms: the list of variable terms, for example 10x and 5y in "10x + 5y + 2". + constant: the constant term, for example 2 in "10x + 5y + 2" + + Examples: + Operators may be used for construction: + + >>> LinearExpression([], 10) + LinearExpression([TermEfficient(10, "x")], 0) + LinearExpression([TermEfficient(10, "x")], 10) + """ + + terms: Dict[TermKeyEfficient, TermEfficient] + constant: ExpressionNodeEfficient + + # TODO: We need to check that terms.key is indeed a TermKey and change the tests that this will break + def __init__( + self, + terms: Optional[ + Union[Dict[TermKeyEfficient, TermEfficient], List[TermEfficient]] + ] = None, + constant: Optional[float] = None, + ) -> None: + self.constant = 0 + self.terms = {} + + if constant is not None: + # += b + self.constant = constant + if terms is not None: + # Allows to give two different syntax in the constructor: + # - List[TermEfficient] is natural + # - Dict[str, TermEfficient] is useful when constructing a linear expression from the terms of another expression + if isinstance(terms, dict): + for term_key, term in terms.items(): + if not term.is_zero(): + self.terms[term_key] = term + elif isinstance(terms, list): + for term in terms: + if not term.is_zero(): + self.terms[generate_key(term)] = term + else: + raise TypeError( + f"Terms must be either of type Dict[str, Term] or List[Term], whereas {terms} is of type {type(terms)}" + ) + + def is_zero(self) -> bool: + return len(self.terms) == 0 and is_zero(self.constant) + + def str_for_constant(self) -> str: + if is_zero(self.constant): + return "" + else: + return f" + {print_expr(self.constant)}" + + def __str__(self) -> str: + # Useful for debugging tests + result = "" + if self.is_zero(): + result += "0" + else: + for term in self.terms.values(): + result += str(term) + + result += self.str_for_constant() + + return result + + def __eq__(self, rhs: object) -> bool: + return ( + isinstance(rhs, LinearExpressionEfficient) + and is_close_abs(self.constant, rhs.constant, EPS) + and self.terms + == rhs.terms # /!\ There may be float equality comparison in the terms values + ) + + def __iadd__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient": + if not isinstance(rhs, LinearExpressionEfficient): + return NotImplemented + self.constant += rhs.constant + aggregated_terms = _merge_dicts(self.terms, rhs.terms, _add_terms, 0) + self.terms = aggregated_terms + self.remove_zeros_from_terms() + return self + + def __add__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient": + result = LinearExpressionEfficient() + result += self + result += rhs + return result + + def __isub__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient": + if not isinstance(rhs, LinearExpressionEfficient): + return NotImplemented + self.constant -= rhs.constant + aggregated_terms = _merge_dicts(self.terms, rhs.terms, _substract_terms, 0) + self.terms = aggregated_terms + self.remove_zeros_from_terms() + return self + + def __sub__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient": + result = LinearExpressionEfficient() + result += self + result -= rhs + return result + + def __neg__(self) -> "LinearExpressionEfficient": + result = LinearExpressionEfficient() + result -= self + return result + + def __imul__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient": + if not isinstance(rhs, LinearExpressionEfficient): + return NotImplemented + + if self.terms and rhs.terms: + raise ValueError("Cannot multiply two non constant expression") + else: + if self.terms: + left_expr = self + const_expr = rhs + else: + # It is possible that both expr are constant + left_expr = rhs + const_expr = self + if expressions_equal(const_expr.constant, LiteralNode(0), EPS): + return LinearExpressionEfficient() + elif expressions_equal(const_expr.constant, LiteralNode(1), EPS): + _copy_expression(left_expr, self) + else: + left_expr.constant *= const_expr.constant + for term_key, term in left_expr.terms.items(): + left_expr.terms[term_key] = TermEfficient( + term.coefficient * const_expr.constant, + term.component_id, + term.variable_name, + term.structure, + term.time_operator, + term.time_aggregator, + term.scenario_operator, + ) + _copy_expression(left_expr, self) + return self + + def __mul__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient": + result = LinearExpressionEfficient() + result += self + result *= rhs + return result + + def __itruediv__( + self, rhs: "LinearExpressionEfficient" + ) -> "LinearExpressionEfficient": + if not isinstance(rhs, LinearExpressionEfficient): + return NotImplemented + + if rhs.terms: + raise ValueError("Cannot divide by a non constant expression") + else: + if is_zero(rhs.constant): + raise ZeroDivisionError("Cannot divide expression by zero") + elif is_one(rhs.constant): + return self + else: + self.constant /= rhs.constant + for term_key, term in self.terms.items(): + self.terms[term_key] = TermEfficient( + term.coefficient / rhs.constant, + term.component_id, + term.variable_name, + term.structure, + term.time_operator, + term.time_aggregator, + term.scenario_operator, + ) + return self + + def __truediv__( + self, rhs: "LinearExpressionEfficient" + ) -> "LinearExpressionEfficient": + result = LinearExpressionEfficient() + result += self + result /= rhs + + return result + + def remove_zeros_from_terms(self) -> None: + # TODO: Not optimized, checks could be done directly when doing operations on self.linear_term to avoid copies + for term_key, term in self.terms.copy().items(): + if is_close_abs(term.coefficient, 0, EPS): + del self.terms[term_key] + + def is_valid(self) -> bool: + nb_instances = None + for term in self.terms.values(): + term_instances = term.number_of_instances() + if nb_instances is None: + nb_instances = term_instances + else: + if term_instances != nb_instances: + raise ValueError( + "The terms of the linear expression {self} do not have the same number of instances" + ) + return True + + def number_of_instances(self) -> int: + if self.is_valid(): + # All terms have the same number of instances, just pick one + return self.terms[next(iter(self.terms))].number_of_instances() + else: + raise ValueError(f"{self} is not a valid linear expression") + + def evaluate(self, context: ValueProvider) -> float: + return sum([term.evaluate(context) for term in self.terms.values()]) + evaluate( + self.constant, context + ) + + def is_constant(self) -> bool: + # Constant expr like x-x could be seen as non constant as we do not simplify coefficient tree... + return not self.terms + + +def _copy_expression( + src: LinearExpressionEfficient, dst: LinearExpressionEfficient +) -> None: + dst.terms = src.terms + dst.constant = src.constant + + +def literal(value: float) -> LinearExpressionEfficient: + return LinearExpressionEfficient([], LiteralNode(value)) + + +# TODO : Define shortcuts for "x", is_one etc .... +def var(name: str) -> LinearExpressionEfficient: + return LinearExpressionEfficient( + [ + TermEfficient( + coefficient=LiteralNode(1), component_id="", variable_name=name + ) + ], + LiteralNode(0), + ) + + +def comp_var(component_id: str, name: str) -> LinearExpressionEfficient: + return LinearExpressionEfficient( + [ + TermEfficient( + coefficient=LiteralNode(1), + component_id=component_id, + variable_name=name, + ) + ], + LiteralNode(0), + ) + + +def param(name: str) -> LinearExpressionEfficient: + return LinearExpressionEfficient([], ParameterNode(name)) + + +def comp_param(component_id: str, name: str) -> LinearExpressionEfficient: + return LinearExpressionEfficient([], ComponentParameterNode(component_id, name)) + + +def is_linear(expr: LinearExpressionEfficient) -> bool: + return True diff --git a/src/andromede/expression/parsing/parse_expression.py b/src/andromede/expression/parsing/parse_expression.py index e96a70f1..072419a1 100644 --- a/src/andromede/expression/parsing/parse_expression.py +++ b/src/andromede/expression/parsing/parse_expression.py @@ -15,14 +15,28 @@ from antlr4 import CommonTokenStream, DiagnosticErrorListener, InputStream from antlr4.error.ErrorStrategy import BailErrorStrategy -from andromede.expression import ExpressionNode, literal, param, var +# from andromede.expression import ExpressionNode, literal, param, var from andromede.expression.equality import expressions_equal -from andromede.expression.expression import ( + +# from andromede.expression.expression import ( +# Comparator, +# ComparisonNode, +# ExpressionRange, +# PortFieldNode, +# ) +from andromede.expression.expression_efficient import ( Comparator, ComparisonNode, + ExpressionNodeEfficient, ExpressionRange, PortFieldNode, ) +from andromede.expression.linear_expression_efficient import ( + LinearExpressionEfficient, + literal, + param, + var, +) from andromede.expression.parsing.antlr.ExprLexer import ExprLexer from andromede.expression.parsing.antlr.ExprParser import ExprParser from andromede.expression.parsing.antlr.ExprVisitor import ExprVisitor @@ -52,19 +66,23 @@ class ExpressionNodeBuilderVisitor(ExprVisitor): identifiers: ModelIdentifiers - def visitFullexpr(self, ctx: ExprParser.FullexprContext) -> ExpressionNode: + def visitFullexpr( + self, ctx: ExprParser.FullexprContext + ) -> LinearExpressionEfficient: return ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#number. - def visitNumber(self, ctx: ExprParser.NumberContext) -> ExpressionNode: + def visitNumber(self, ctx: ExprParser.NumberContext) -> LinearExpressionEfficient: return literal(float(ctx.NUMBER().getText())) # type: ignore # Visit a parse tree produced by ExprParser#identifier. - def visitIdentifier(self, ctx: ExprParser.IdentifierContext) -> ExpressionNode: + def visitIdentifier( + self, ctx: ExprParser.IdentifierContext + ) -> LinearExpressionEfficient: return self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore # Visit a parse tree produced by ExprParser#division. - def visitMuldiv(self, ctx: ExprParser.MuldivContext) -> ExpressionNode: + def visitMuldiv(self, ctx: ExprParser.MuldivContext) -> LinearExpressionEfficient: left = ctx.expr(0).accept(self) # type: ignore right = ctx.expr(1).accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -75,7 +93,7 @@ def visitMuldiv(self, ctx: ExprParser.MuldivContext) -> ExpressionNode: raise ValueError(f"Invalid operator {op}") # Visit a parse tree produced by ExprParser#subtraction. - def visitAddsub(self, ctx: ExprParser.AddsubContext) -> ExpressionNode: + def visitAddsub(self, ctx: ExprParser.AddsubContext) -> LinearExpressionEfficient: left = ctx.expr(0).accept(self) # type: ignore right = ctx.expr(1).accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -86,18 +104,24 @@ def visitAddsub(self, ctx: ExprParser.AddsubContext) -> ExpressionNode: raise ValueError(f"Invalid operator {op}") # Visit a parse tree produced by ExprParser#negation. - def visitNegation(self, ctx: ExprParser.NegationContext) -> ExpressionNode: + def visitNegation( + self, ctx: ExprParser.NegationContext + ) -> LinearExpressionEfficient: return -ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#expression. - def visitExpression(self, ctx: ExprParser.ExpressionContext) -> ExpressionNode: + def visitExpression( + self, ctx: ExprParser.ExpressionContext + ) -> LinearExpressionEfficient: return ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#unsignedAtom. - def visitUnsignedAtom(self, ctx: ExprParser.UnsignedAtomContext) -> ExpressionNode: + def visitUnsignedAtom( + self, ctx: ExprParser.UnsignedAtomContext + ) -> LinearExpressionEfficient: return ctx.atom().accept(self) # type: ignore - def _convert_identifier(self, identifier: str) -> ExpressionNode: + def _convert_identifier(self, identifier: str) -> LinearExpressionEfficient: if self.identifiers.is_variable(identifier): return var(identifier) elif self.identifiers.is_parameter(identifier): @@ -105,14 +129,18 @@ def _convert_identifier(self, identifier: str) -> ExpressionNode: raise ValueError(f"{identifier} is not a valid variable or parameter name.") # Visit a parse tree produced by ExprParser#portField. - def visitPortField(self, ctx: ExprParser.PortFieldContext) -> ExpressionNode: + def visitPortField( + self, ctx: ExprParser.PortFieldContext + ) -> LinearExpressionEfficient: return PortFieldNode( port_name=ctx.IDENTIFIER(0).getText(), # type: ignore field_name=ctx.IDENTIFIER(1).getText(), # type: ignore ) # Visit a parse tree produced by ExprParser#comparison. - def visitComparison(self, ctx: ExprParser.ComparisonContext) -> ExpressionNode: + def visitComparison( + self, ctx: ExprParser.ComparisonContext + ) -> LinearExpressionEfficient: op = ctx.COMPARISON().getText() # type: ignore exp1 = ctx.expr(0).accept(self) # type: ignore exp2 = ctx.expr(1).accept(self) # type: ignore @@ -124,18 +152,24 @@ def visitComparison(self, ctx: ExprParser.ComparisonContext) -> ExpressionNode: return ComparisonNode(exp1, exp2, comp) # Visit a parse tree produced by ExprParser#timeShift. - def visitTimeIndex(self, ctx: ExprParser.TimeIndexContext) -> ExpressionNode: + def visitTimeIndex( + self, ctx: ExprParser.TimeIndexContext + ) -> LinearExpressionEfficient: shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore time_shifts = [e.accept(self) for e in ctx.expr()] # type: ignore return shifted_expr.eval(time_shifts) # Visit a parse tree produced by ExprParser#rangeTimeShift. - def visitTimeRange(self, ctx: ExprParser.TimeRangeContext) -> ExpressionNode: + def visitTimeRange( + self, ctx: ExprParser.TimeRangeContext + ) -> LinearExpressionEfficient: shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore expressions = [e.accept(self) for e in ctx.expr()] # type: ignore return shifted_expr.eval(ExpressionRange(expressions[0], expressions[1])) - def visitTimeShift(self, ctx: ExprParser.TimeShiftContext) -> ExpressionNode: + def visitTimeShift( + self, ctx: ExprParser.TimeShiftContext + ) -> LinearExpressionEfficient: shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore time_shifts = [s.accept(self) for s in ctx.shift()] # type: ignore # specifics for x[t] ... @@ -145,30 +179,34 @@ def visitTimeShift(self, ctx: ExprParser.TimeShiftContext) -> ExpressionNode: def visitTimeShiftRange( self, ctx: ExprParser.TimeShiftRangeContext - ) -> ExpressionNode: + ) -> LinearExpressionEfficient: shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore shift1 = ctx.shift1.accept(self) # type: ignore shift2 = ctx.shift2.accept(self) # type: ignore return shifted_expr.shift(ExpressionRange(shift1, shift2)) # Visit a parse tree produced by ExprParser#function. - def visitFunction(self, ctx: ExprParser.FunctionContext) -> ExpressionNode: + def visitFunction( + self, ctx: ExprParser.FunctionContext + ) -> LinearExpressionEfficient: function_name: str = ctx.IDENTIFIER().getText() # type: ignore - operand: ExpressionNode = ctx.expr().accept(self) # type: ignore + operand: LinearExpressionEfficient = ctx.expr().accept(self) # type: ignore fn = _FUNCTIONS.get(function_name, None) if fn is None: raise ValueError(f"Encountered invalid function name {function_name}") return fn(operand) # Visit a parse tree produced by ExprParser#shift. - def visitShift(self, ctx: ExprParser.ShiftContext) -> ExpressionNode: + def visitShift(self, ctx: ExprParser.ShiftContext) -> LinearExpressionEfficient: if ctx.shift_expr() is None: # type: ignore return literal(0) shift = ctx.shift_expr().accept(self) # type: ignore return shift # Visit a parse tree produced by ExprParser#shiftAddsub. - def visitShiftAddsub(self, ctx: ExprParser.ShiftAddsubContext) -> ExpressionNode: + def visitShiftAddsub( + self, ctx: ExprParser.ShiftAddsubContext + ) -> LinearExpressionEfficient: left = ctx.shift_expr().accept(self) # type: ignore right = ctx.right_expr().accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -179,7 +217,9 @@ def visitShiftAddsub(self, ctx: ExprParser.ShiftAddsubContext) -> ExpressionNode raise ValueError(f"Invalid operator {op}") # Visit a parse tree produced by ExprParser#shiftMuldiv. - def visitShiftMuldiv(self, ctx: ExprParser.ShiftMuldivContext) -> ExpressionNode: + def visitShiftMuldiv( + self, ctx: ExprParser.ShiftMuldivContext + ) -> LinearExpressionEfficient: left = ctx.shift_expr().accept(self) # type: ignore right = ctx.right_expr().accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -192,14 +232,16 @@ def visitShiftMuldiv(self, ctx: ExprParser.ShiftMuldivContext) -> ExpressionNode # Visit a parse tree produced by ExprParser#signedExpression. def visitSignedExpression( self, ctx: ExprParser.SignedExpressionContext - ) -> ExpressionNode: + ) -> LinearExpressionEfficient: if ctx.op.text == "-": # type: ignore return -ctx.expr().accept(self) # type: ignore else: return ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#signedAtom. - def visitSignedAtom(self, ctx: ExprParser.SignedAtomContext) -> ExpressionNode: + def visitSignedAtom( + self, ctx: ExprParser.SignedAtomContext + ) -> LinearExpressionEfficient: if ctx.op.text == "-": # type: ignore return -ctx.atom().accept(self) # type: ignore else: @@ -208,11 +250,13 @@ def visitSignedAtom(self, ctx: ExprParser.SignedAtomContext) -> ExpressionNode: # Visit a parse tree produced by ExprParser#rightExpression. def visitRightExpression( self, ctx: ExprParser.RightExpressionContext - ) -> ExpressionNode: + ) -> LinearExpressionEfficient: return ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#rightMuldiv. - def visitRightMuldiv(self, ctx: ExprParser.RightMuldivContext) -> ExpressionNode: + def visitRightMuldiv( + self, ctx: ExprParser.RightMuldivContext + ) -> LinearExpressionEfficient: left = ctx.right_expr(0).accept(self) # type: ignore right = ctx.right_expr(1).accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -223,14 +267,16 @@ def visitRightMuldiv(self, ctx: ExprParser.RightMuldivContext) -> ExpressionNode raise ValueError(f"Invalid operator {op}") # Visit a parse tree produced by ExprParser#rightAtom. - def visitRightAtom(self, ctx: ExprParser.RightAtomContext) -> ExpressionNode: + def visitRightAtom( + self, ctx: ExprParser.RightAtomContext + ) -> LinearExpressionEfficient: return ctx.atom().accept(self) # type: ignore _FUNCTIONS = { - "sum": ExpressionNode.sum, - "sum_connections": ExpressionNode.sum_connections, - "expec": ExpressionNode.expec, + "sum": ExpressionNodeEfficient.sum, + "sum_connections": ExpressionNodeEfficient.sum_connections, + "expec": ExpressionNodeEfficient.expec, } @@ -238,7 +284,9 @@ class AntaresParseException(Exception): pass -def parse_expression(expression: str, identifiers: ModelIdentifiers) -> ExpressionNode: +def parse_expression( + expression: str, identifiers: ModelIdentifiers +) -> LinearExpressionEfficient: """ Parses a string expression to create the corresponding AST representation. """ diff --git a/src/andromede/expression/port_resolver.py b/src/andromede/expression/port_resolver.py index 6f333408..a6728e27 100644 --- a/src/andromede/expression/port_resolver.py +++ b/src/andromede/expression/port_resolver.py @@ -14,14 +14,21 @@ from dataclasses import dataclass from typing import Dict, List -from andromede.expression import CopyVisitor, sum_expressions, visit -from andromede.expression.expression import ( - AdditionNode, - ExpressionNode, - LiteralNode, +from andromede.expression import CopyVisitor, visit + +# from andromede.expression.expression import ( +# AdditionNode, +# ExpressionNode, +# LiteralNode, +# PortFieldAggregatorNode, +# PortFieldNode, +# ) +from andromede.expression.expression_efficient import ( PortFieldAggregatorNode, PortFieldNode, + sum_expressions, ) +from andromede.expression.linear_expression_efficient import LinearExpressionEfficient from andromede.model.model import PortFieldId @@ -43,9 +50,9 @@ class PortResolver(CopyVisitor): """ component_id: str - ports_expressions: Dict[PortFieldKey, List[ExpressionNode]] + ports_expressions: Dict[PortFieldKey, List[LinearExpressionEfficient]] - def port_field(self, node: PortFieldNode) -> ExpressionNode: + def port_field(self, node: PortFieldNode) -> LinearExpressionEfficient: expressions = self.ports_expressions[ PortFieldKey( self.component_id, PortFieldId(node.port_name, node.field_name) @@ -58,7 +65,9 @@ def port_field(self, node: PortFieldNode) -> ExpressionNode: else: return expressions[0] - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> ExpressionNode: + def port_field_aggregator( + self, node: PortFieldAggregatorNode + ) -> LinearExpressionEfficient: if node.aggregator != "PortSum": raise NotImplementedError("Only PortSum is supported.") port_field_node = node.operand @@ -76,8 +85,8 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> ExpressionNode def resolve_port( - expression: ExpressionNode, + expression: LinearExpressionEfficient, component_id: str, - ports_expressions: Dict[PortFieldKey, List[ExpressionNode]], -) -> ExpressionNode: + ports_expressions: Dict[PortFieldKey, List[LinearExpressionEfficient]], +) -> LinearExpressionEfficient: return visit(expression, PortResolver(component_id, ports_expressions)) diff --git a/src/andromede/expression/print.py b/src/andromede/expression/print.py index c01ae76f..86ef42d0 100644 --- a/src/andromede/expression/print.py +++ b/src/andromede/expression/print.py @@ -22,7 +22,22 @@ ) from andromede.expression.visitor import T -from .expression import ( +# from .expression import ( +# AdditionNode, +# Comparator, +# ComparisonNode, +# DivisionNode, +# LiteralNode, +# MultiplicationNode, +# NegationNode, +# ParameterNode, +# ScenarioOperatorNode, +# SubstractionNode, +# TimeAggregatorNode, +# TimeOperatorNode, +# VariableNode, +# ) +from .expression_efficient import ( AdditionNode, Comparator, ComparisonNode, @@ -35,7 +50,6 @@ SubstractionNode, TimeAggregatorNode, TimeOperatorNode, - VariableNode, ) from .visitor import ExpressionVisitor, visit @@ -86,8 +100,8 @@ def comparison(self, node: ComparisonNode) -> str: right_value = visit(node.right, self) return f"{left_value} {op} {right_value}" - def variable(self, node: VariableNode) -> str: - return node.name + # def variable(self, node: VariableNode) -> str: + # return node.name def parameter(self, node: ParameterNode) -> str: return node.name diff --git a/src/andromede/expression/visitor.py b/src/andromede/expression/visitor.py index 25bbfb02..29e95cee 100644 --- a/src/andromede/expression/visitor.py +++ b/src/andromede/expression/visitor.py @@ -13,17 +13,34 @@ """ Defines abstract base class for visitors of expressions. """ -import typing from abc import ABC, abstractmethod from typing import Generic, Protocol, TypeVar -from andromede.expression.expression import ( +# from andromede.expression.expression import ( +# AdditionNode, +# ComparisonNode, +# ComponentParameterNode, +# ComponentVariableNode, +# DivisionNode, +# ExpressionNode, +# LiteralNode, +# MultiplicationNode, +# NegationNode, +# ParameterNode, +# PortFieldAggregatorNode, +# PortFieldNode, +# ScenarioOperatorNode, +# SubstractionNode, +# TimeAggregatorNode, +# TimeOperatorNode, +# VariableNode, +# ) +from andromede.expression.expression_efficient import ( AdditionNode, ComparisonNode, ComponentParameterNode, - ComponentVariableNode, DivisionNode, - ExpressionNode, + ExpressionNodeEfficient, LiteralNode, MultiplicationNode, NegationNode, @@ -34,7 +51,6 @@ SubstractionNode, TimeAggregatorNode, TimeOperatorNode, - VariableNode, ) T = TypeVar("T") @@ -77,9 +93,9 @@ def division(self, node: DivisionNode) -> T: def comparison(self, node: ComparisonNode) -> T: ... - @abstractmethod - def variable(self, node: VariableNode) -> T: - ... + # @abstractmethod + # def variable(self, node: VariableNode) -> T: + # ... @abstractmethod def parameter(self, node: ParameterNode) -> T: @@ -89,9 +105,9 @@ def parameter(self, node: ParameterNode) -> T: def comp_parameter(self, node: ComponentParameterNode) -> T: ... - @abstractmethod - def comp_variable(self, node: ComponentVariableNode) -> T: - ... + # @abstractmethod + # def comp_variable(self, node: ComponentVariableNode) -> T: + # ... @abstractmethod def time_operator(self, node: TimeOperatorNode) -> T: @@ -114,7 +130,7 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> T: ... -def visit(root: ExpressionNode, visitor: ExpressionVisitor[T]) -> T: +def visit(root: ExpressionNodeEfficient, visitor: ExpressionVisitor[T]) -> T: """ Utility method to dispatch calls to the right method of a visitor. """ @@ -122,14 +138,14 @@ def visit(root: ExpressionNode, visitor: ExpressionVisitor[T]) -> T: return visitor.literal(root) elif isinstance(root, NegationNode): return visitor.negation(root) - elif isinstance(root, VariableNode): - return visitor.variable(root) + # elif isinstance(root, VariableNode): + # return visitor.variable(root) elif isinstance(root, ParameterNode): return visitor.parameter(root) elif isinstance(root, ComponentParameterNode): return visitor.comp_parameter(root) - elif isinstance(root, ComponentVariableNode): - return visitor.comp_variable(root) + # elif isinstance(root, ComponentVariableNode): + # return visitor.comp_variable(root) elif isinstance(root, AdditionNode): return visitor.addition(root) elif isinstance(root, MultiplicationNode): diff --git a/src/andromede/model/model.py b/src/andromede/model/model.py index 3997f43d..a293e78f 100644 --- a/src/andromede/model/model.py +++ b/src/andromede/model/model.py @@ -19,38 +19,77 @@ from dataclasses import dataclass, field from typing import Dict, Iterable, Optional -from andromede.expression import ( +# from andromede.expression.expression import ( +# BinaryOperatorNode, +# ComponentParameterNode, +# ComponentVariableNode, +# PortFieldAggregatorNode, +# PortFieldNode, +# ScenarioOperatorNode, +# TimeAggregatorNode, +# TimeOperatorNode, +# ) +from andromede.expression.expression_efficient import ( AdditionNode, + BinaryOperatorNode, ComparisonNode, + ComponentParameterNode, DivisionNode, - ExpressionNode, - ExpressionVisitor, LiteralNode, MultiplicationNode, NegationNode, ParameterNode, - SubstractionNode, - VariableNode, -) -from andromede.expression.degree import is_linear -from andromede.expression.expression import ( - BinaryOperatorNode, - ComponentParameterNode, - ComponentVariableNode, PortFieldAggregatorNode, PortFieldNode, ScenarioOperatorNode, + SubstractionNode, TimeAggregatorNode, TimeOperatorNode, ) from andromede.expression.indexing import IndexingStructureProvider, compute_indexation from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.visitor import T, visit +from andromede.expression.linear_expression_efficient import ( + LinearExpressionEfficient, + is_linear, +) +from andromede.expression.visitor import ExpressionVisitor, visit from andromede.model.constraint import Constraint from andromede.model.parameter import Parameter from andromede.model.port import PortType from andromede.model.variable import Variable +# from andromede.expression import ( +# AdditionNode, +# ComparisonNode, +# DivisionNode, +# ExpressionNode, +# ExpressionVisitor, +# LiteralNode, +# MultiplicationNode, +# NegationNode, +# ParameterNode, +# SubstractionNode, +# VariableNode, +# ) +# from andromede.expression.expression_efficient import ( +# AdditionNode, +# BinaryOperatorNode, +# ComparisonNode, +# ComponentParameterNode, +# DivisionNode, +# ExpressionNodeEfficient, +# LiteralNode, +# MultiplicationNode, +# NegationNode, +# ParameterNode, +# PortFieldAggregatorNode, +# PortFieldNode, +# ScenarioOperatorNode, +# SubstractionNode, +# TimeAggregatorNode, +# TimeOperatorNode, +# ) + # TODO: Introduce bool_variable ? def _make_structure_provider(model: "Model") -> IndexingStructureProvider: @@ -79,7 +118,7 @@ def get_component_variable_structure( def _is_objective_contribution_valid( - model: "Model", objective_contribution: ExpressionNode + model: "Model", objective_contribution: LinearExpressionEfficient ) -> bool: if not is_linear(objective_contribution): raise ValueError("Objective contribution must be a linear expression.") @@ -121,14 +160,14 @@ class PortFieldDefinition: """ port_field: PortFieldId - definition: ExpressionNode + definition: LinearExpressionEfficient def __post_init__(self) -> None: _validate_port_field_expression(self) def port_field_def( - port_name: str, field_name: str, definition: ExpressionNode + port_name: str, field_name: str, definition: LinearExpressionEfficient ) -> PortFieldDefinition: return PortFieldDefinition(PortFieldId(port_name, field_name), definition) @@ -146,8 +185,8 @@ class Model: inter_block_dyn: bool = False parameters: Dict[str, Parameter] = field(default_factory=dict) variables: Dict[str, Variable] = field(default_factory=dict) - objective_operational_contribution: Optional[ExpressionNode] = None - objective_investment_contribution: Optional[ExpressionNode] = None + objective_operational_contribution: Optional[LinearExpressionEfficient] = None + objective_investment_contribution: Optional[LinearExpressionEfficient] = None ports: Dict[str, ModelPort] = field(default_factory=dict) # key = port name port_fields_definitions: Dict[PortFieldId, PortFieldDefinition] = field( default_factory=dict @@ -190,8 +229,8 @@ def model( binding_constraints: Optional[Iterable[Constraint]] = None, parameters: Optional[Iterable[Parameter]] = None, variables: Optional[Iterable[Variable]] = None, - objective_operational_contribution: Optional[ExpressionNode] = None, - objective_investment_contribution: Optional[ExpressionNode] = None, + objective_operational_contribution: Optional[LinearExpressionEfficient] = None, + objective_investment_contribution: Optional[LinearExpressionEfficient] = None, inter_block_dyn: bool = False, ports: Optional[Iterable[ModelPort]] = None, port_fields_definitions: Optional[Iterable[PortFieldDefinition]] = None, @@ -258,8 +297,8 @@ def division(self, node: DivisionNode) -> None: def comparison(self, node: ComparisonNode) -> None: raise ValueError("Port definition cannot contain a comparison operator.") - def variable(self, node: VariableNode) -> None: - pass + # def variable(self, node: VariableNode) -> None: + # pass def parameter(self, node: ParameterNode) -> None: pass @@ -269,10 +308,10 @@ def comp_parameter(self, node: ComponentParameterNode) -> None: "Port definition must not contain a parameter associated to a component." ) - def comp_variable(self, node: ComponentVariableNode) -> None: - raise ValueError( - "Port definition must not contain a variable associated to a component." - ) + # def comp_variable(self, node: ComponentVariableNode) -> None: + # raise ValueError( + # "Port definition must not contain a variable associated to a component." + # ) def time_operator(self, node: TimeOperatorNode) -> None: visit(node.operand, self) @@ -291,4 +330,6 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> None: def _validate_port_field_expression(definition: PortFieldDefinition) -> None: - visit(definition.definition, _PortFieldExpressionChecker()) + for term in definition.definition.terms.values(): + visit(term.coefficient, _PortFieldExpressionChecker()) + visit(definition.definition.constant, _PortFieldExpressionChecker()) diff --git a/src/andromede/model/resolve_library.py b/src/andromede/model/resolve_library.py index 546acdbd..ef117283 100644 --- a/src/andromede/model/resolve_library.py +++ b/src/andromede/model/resolve_library.py @@ -11,8 +11,9 @@ # This file is part of the Antares project. from typing import Dict, List, Optional -from andromede.expression import ExpressionNode +# from andromede.expression import ExpressionNode from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression_efficient import LinearExpressionEfficient from andromede.expression.parsing.parse_expression import ( ModelIdentifiers, parse_expression, @@ -123,7 +124,7 @@ def _to_parameter(param: InputParameter) -> Parameter: def _to_expression_if_present( expr: Optional[str], identifiers: ModelIdentifiers -) -> Optional[ExpressionNode]: +) -> Optional[LinearExpressionEfficient]: if not expr: return None return parse_expression(expr, identifiers) diff --git a/src/andromede/model/variable.py b/src/andromede/model/variable.py index b8eb5407..daaa3a64 100644 --- a/src/andromede/model/variable.py +++ b/src/andromede/model/variable.py @@ -13,13 +13,12 @@ from dataclasses import dataclass from typing import Any, Optional -from andromede.expression import ExpressionNode -from andromede.expression.degree import is_constant from andromede.expression.equality import ( expressions_equal, expressions_equal_if_present, ) from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression_efficient import LinearExpressionEfficient from andromede.model.common import ProblemContext, ValueType @@ -31,15 +30,15 @@ class Variable: name: str data_type: ValueType - lower_bound: Optional[ExpressionNode] - upper_bound: Optional[ExpressionNode] + lower_bound: Optional[LinearExpressionEfficient] + upper_bound: Optional[LinearExpressionEfficient] structure: IndexingStructure context: ProblemContext def __post_init__(self) -> None: - if self.lower_bound and not is_constant(self.lower_bound): + if self.lower_bound and not self.lower_bound.is_constant(): raise ValueError("Lower bounds of variables must be constant") - if self.upper_bound and not is_constant(self.upper_bound): + if self.upper_bound and not self.upper_bound.is_constant(): raise ValueError("Upper bounds of variables must be constant") def __eq__(self, other: Any) -> bool: @@ -56,8 +55,8 @@ def __eq__(self, other: Any) -> bool: def int_variable( name: str, - lower_bound: Optional[ExpressionNode] = None, - upper_bound: Optional[ExpressionNode] = None, + lower_bound: Optional[LinearExpressionEfficient] = None, + upper_bound: Optional[LinearExpressionEfficient] = None, structure: IndexingStructure = IndexingStructure(True, True), context: ProblemContext = ProblemContext.OPERATIONAL, ) -> Variable: @@ -68,8 +67,8 @@ def int_variable( def float_variable( name: str, - lower_bound: Optional[ExpressionNode] = None, - upper_bound: Optional[ExpressionNode] = None, + lower_bound: Optional[LinearExpressionEfficient] = None, + upper_bound: Optional[LinearExpressionEfficient] = None, structure: IndexingStructure = IndexingStructure(True, True), context: ProblemContext = ProblemContext.OPERATIONAL, ) -> Variable: diff --git a/src/andromede/simulation/linear_expression.py b/src/andromede/simulation/linear_expression.py index 344a0988..8167420d 100644 --- a/src/andromede/simulation/linear_expression.py +++ b/src/andromede/simulation/linear_expression.py @@ -45,7 +45,6 @@ def is_minus_one(value: float) -> bool: @dataclass(frozen=True) class TermKey: - """ Utility class to provide key for a term that contains all term information except coefficient """ diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 4ec9cb63..ce45b388 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -22,9 +22,8 @@ import ortools.linear_solver.pywraplp as lp -from andromede.expression import ( +from andromede.expression import ( # ExpressionNode, EvaluationVisitor, - ExpressionNode, ParameterValueProvider, ValueProvider, resolve_parameters, @@ -33,6 +32,7 @@ from andromede.expression.context_adder import add_component_context from andromede.expression.indexing import IndexingStructureProvider, compute_indexation from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression_efficient import LinearExpressionEfficient from andromede.expression.port_resolver import PortFieldKey, resolve_port from andromede.expression.scenario_operator import Expectation from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum @@ -141,7 +141,7 @@ def parameter_is_constant_over_time(self, name: str) -> bool: class ExpressionTimestepValueProvider(TimestepValueProvider): context: "OptimizationContext" component: Component - expression: ExpressionNode + expression: LinearExpressionEfficient # OptimizationContext has knowledge of the block, so that get_value only needs block_timestep and scenario to get the correct data value @@ -215,7 +215,9 @@ class ComponentContext: opt_context: "OptimizationContext" component: Component - def get_values(self, expression: ExpressionNode) -> TimestepValueProvider: + def get_values( + self, expression: LinearExpressionEfficient + ) -> TimestepValueProvider: """ The returned value provider will evaluate the provided expression. """ @@ -249,7 +251,7 @@ def linearize_expression( self, block_timestep: int, scenario: int, - expression: ExpressionNode, + expression: LinearExpressionEfficient, ) -> LinearExpression: parameters_valued_provider = _make_parameter_value_provider( self.opt_context, block_timestep, scenario @@ -314,7 +316,7 @@ def __init__( self._component_variables: Dict[TimestepComponentVariableKey, lp.Variable] = {} self._solver_variables: Dict[lp.Variable, SolverVariableInfo] = {} self._connection_fields_expressions: Dict[ - PortFieldKey, List[ExpressionNode] + PortFieldKey, List[LinearExpressionEfficient] ] = {} @property @@ -329,7 +331,9 @@ def block_length(self) -> int: return len(self._block.timesteps) @property - def connection_fields_expressions(self) -> Dict[PortFieldKey, List[ExpressionNode]]: + def connection_fields_expressions( + self, + ) -> Dict[PortFieldKey, List[LinearExpressionEfficient]]: return self._connection_fields_expressions # TODO: Need to think about data processing when creating blocks with varying or inequal time steps length (aggregation, sum ?, mean of data ?) @@ -405,7 +409,7 @@ def register_connection_fields_expressions( component_id: str, port_name: str, field_name: str, - expression: ExpressionNode, + expression: LinearExpressionEfficient, ) -> None: key = PortFieldKey(component_id, PortFieldId(port_name, field_name)) get_or_add(self._connection_fields_expressions, key, lambda: []).append( @@ -434,10 +438,10 @@ def _compute_indexing_structure( def _instantiate_model_expression( - model_expression: ExpressionNode, + model_expression: LinearExpressionEfficient, component_id: str, optimization_context: OptimizationContext, -) -> ExpressionNode: +) -> LinearExpressionEfficient: """ Performs common operations that are necessary on model expressions before their actual use: 1. add component ID for variables and parameters of THIS component @@ -496,7 +500,7 @@ def _create_objective( opt_context: OptimizationContext, component: Component, component_context: ComponentContext, - objective_contribution: ExpressionNode, + objective_contribution: LinearExpressionEfficient, ) -> None: instantiated_expr = _instantiate_model_expression( objective_contribution, component.id, opt_context diff --git a/src/andromede/simulation/strategy.py b/src/andromede/simulation/strategy.py index 75e34c65..cb30f96a 100644 --- a/src/andromede/simulation/strategy.py +++ b/src/andromede/simulation/strategy.py @@ -13,7 +13,8 @@ from abc import ABC, abstractmethod from typing import Generator, Optional -from andromede.expression import ExpressionNode +# from andromede.expression import ExpressionNode +from andromede.expression.linear_expression_efficient import LinearExpressionEfficient from andromede.model import Constraint, Model, ProblemContext, Variable @@ -43,7 +44,7 @@ def _keep_from_context(self, context: ProblemContext) -> bool: @abstractmethod def get_objectives( self, model: Model - ) -> Generator[Optional[ExpressionNode], None, None]: + ) -> Generator[Optional[LinearExpressionEfficient], None, None]: ... @@ -53,7 +54,7 @@ def _keep_from_context(self, context: ProblemContext) -> bool: def get_objectives( self, model: Model - ) -> Generator[Optional[ExpressionNode], None, None]: + ) -> Generator[Optional[LinearExpressionEfficient], None, None]: yield model.objective_operational_contribution yield model.objective_investment_contribution @@ -66,7 +67,7 @@ def _keep_from_context(self, context: ProblemContext) -> bool: def get_objectives( self, model: Model - ) -> Generator[Optional[ExpressionNode], None, None]: + ) -> Generator[Optional[LinearExpressionEfficient], None, None]: yield model.objective_investment_contribution @@ -78,5 +79,5 @@ def _keep_from_context(self, context: ProblemContext) -> bool: def get_objectives( self, model: Model - ) -> Generator[Optional[ExpressionNode], None, None]: + ) -> Generator[Optional[LinearExpressionEfficient], None, None]: yield model.objective_operational_contribution diff --git a/tests/unittests/expressions/test_expressions.py b/tests/unittests/expressions/test_expressions.py index b990605d..03f3a332 100644 --- a/tests/unittests/expressions/test_expressions.py +++ b/tests/unittests/expressions/test_expressions.py @@ -13,6 +13,7 @@ from dataclasses import dataclass, field from typing import Dict +import pandas as pd import pytest from andromede.expression import ( diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py new file mode 100644 index 00000000..f5a89c7a --- /dev/null +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -0,0 +1,308 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +from dataclasses import dataclass, field +from typing import Dict + +import pytest + +from andromede.expression.evaluate import EvaluationContext, ValueProvider +from andromede.expression.evaluate_parameters import ParameterValueProvider +from andromede.expression.expression_efficient import ( + ComponentParameterNode, + ExpressionRange, + ParameterNode, +) +from andromede.expression.indexing import IndexingStructureProvider +from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression_efficient import ( + LinearExpressionEfficient, + TermEfficient, + param, + var, +) +from andromede.simulation.linearize import linearize_expression + + +@dataclass(frozen=True) +class ComponentValueKey: + component_id: str + variable_name: str + + +def comp_key(component_id: str, variable_name: str) -> ComponentValueKey: + return ComponentValueKey(component_id, variable_name) + + +@dataclass(frozen=True) +class ComponentEvaluationContext(ValueProvider): + """ + Simple value provider relying on dictionaries. + Does not support component variables/parameters. + """ + + variables: Dict[ComponentValueKey, float] = field(default_factory=dict) + parameters: Dict[ComponentValueKey, float] = field(default_factory=dict) + + def get_variable_value(self, name: str) -> float: + raise NotImplementedError() + + def get_parameter_value(self, name: str) -> float: + raise NotImplementedError() + + def get_component_variable_value(self, component_id: str, name: str) -> float: + return self.variables[comp_key(component_id, name)] + + def get_component_parameter_value(self, component_id: str, name: str) -> float: + return self.parameters[comp_key(component_id, name)] + + def parameter_is_constant_over_time(self, name: str) -> bool: + raise NotImplementedError() + + +# TODO: Redundant with add tests in test_linear_expressions_efficient ? +def test_comp_parameter() -> None: + expr1 = LinearExpressionEfficient([], 1) + LinearExpressionEfficient( + [TermEfficient(1, "comp1", "x")] + ) + expr2 = expr1 / LinearExpressionEfficient( + constant=ComponentParameterNode("comp1", "p") + ) + + assert str(expr2) == "(1.0 / comp1.p)x + (1.0 / comp1.p)" + context = ComponentEvaluationContext( + variables={comp_key("comp1", "x"): 3}, parameters={comp_key("comp1", "p"): 4} + ) + assert expr2.evaluate(context) == 1 + + +# TODO: Find a better name +def test_ast() -> None: + expr1 = LinearExpressionEfficient([], 1) + LinearExpressionEfficient( + [TermEfficient(1, "", "x")] + ) + expr2 = expr1 / LinearExpressionEfficient(constant=ParameterNode("p")) + + assert str(expr2) == "(1.0 / p)x + (1.0 / p)" + + context = EvaluationContext(variables={"x": 3}, parameters={"p": 4}) + assert expr2.evaluate(context) == 1 + + +def test_operators() -> None: + x = var("x") + p = param("p") + expr: LinearExpressionEfficient = (5 * x + 3) / p - 2 + + assert str(expr) == "((((5.0 * x) + 3.0) / p) - 2.0)" + + context = EvaluationContext(variables={"x": 3}, parameters={"p": 4}) + assert expr.evaluate(context) == pytest.approx(2.5, 1e-16) + + assert expr.evaluate(context) == pytest.approx(-2.5, 1e-16) + + +def test_degree() -> None: + x = var("x") + p = param("p") + expr = (5 * x + 3) / p + + assert expr.compute_degree() == 1 + + # TODO: Should this be allowed ? If so, how should we represent is ? + expr = x * expr + assert expr.compute_degree() == 2 + + +@pytest.mark.xfail(reason="Degree simplification not implemented") +def test_degree_computation_should_take_into_account_simplifications() -> None: + x = var("x") + expr = x - x + assert expr.compute_degree() == 0 + + expr = 0 * x + assert expr.compute_degree() == 0 + assert expr.is_zero() + + +def test_parameters_resolution() -> None: + class TestParamProvider(ParameterValueProvider): + def get_component_parameter_value(self, component_id: str, name: str) -> float: + raise NotImplementedError() + + def get_parameter_value(self, name: str) -> float: + return 2 + + x = var("x") + p = param("p") + expr = (5 * x + 3) / p + # TODO: We do not want this in the API, but rather expr.get(t, w) + assert expr.resolve_parameters(TestParamProvider()) == (5 * x + 3) / 2 + + +# No real equivalent in the "efficient" formalism +def test_linearization() -> None: + x = comp_var("c", "x") + expr = (5 * x + 3) / 2 + provider = StructureProvider() + + assert expr == LinearExpressionEfficient([TermEfficient(2.5, "c", "x")], 1.5) + + # Does not raise error !!!! + assert param("p") * x == LinearExpressionEfficient( + [TermEfficient(ParameterNode("p"), "c", "x")], 1.5 + ) + + +# TODO: What is the equivalent of this test ? +def test_linearization_of_non_linear_expressions_should_raise_value_error() -> None: + x = var("x") + expr = x.variance() + + provider = StructureProvider() + with pytest.raises(ValueError) as exc: + linearize_expression(expr, provider) + assert ( + str(exc.value) + == "Cannot linearize expression with a non-linear operator: Variance" + ) + + +def test_comparison() -> None: + x = var("x") + p = param("p") + expr: Constraint = ( + 5 * x + 3 + ) >= p - 2 ## Overloading operator to return a constraint object ! + + assert str(expr) == "((5.0 * x) + 3.0) >= (p - 2.0)" + + +class StructureProvider(IndexingStructureProvider): + def get_component_variable_structure( + self, component_id: str, name: str + ) -> IndexingStructure: + return IndexingStructure(True, True) + + def get_component_parameter_structure( + self, component_id: str, name: str + ) -> IndexingStructure: + return IndexingStructure(True, True) + + def get_parameter_structure(self, name: str) -> IndexingStructure: + return IndexingStructure(True, True) + + def get_variable_structure(self, name: str) -> IndexingStructure: + return IndexingStructure(True, True) + + +def test_shift() -> None: + x = var("x") + expr = x.shift(ExpressionRange(1, 4)) + + provider = StructureProvider() + + assert expr.compute_indexation(provider) == IndexingStructure(True, True) + assert expr.instances == Instances.MULTIPLE + + +def test_shifting_sum() -> None: + x = var("x") + expr = x.shift(ExpressionRange(1, 4)).sum() + provider = StructureProvider() + + assert expr.compute_indexation(provider) == IndexingStructure(True, True) + assert expr.instances == Instances.SIMPLE + + +def test_eval() -> None: + x = var("x") + expr = x.eval(ExpressionRange(1, 4)) + provider = StructureProvider() + + assert expr.compute_indexation(provider) == IndexingStructure(False, True) + assert expr.instances == Instances.MULTIPLE + + +def test_eval_sum() -> None: + x = var("x") + expr = x.eval(ExpressionRange(1, 4)).sum() + provider = StructureProvider() + + assert expr.compute_indexation(provider) == IndexingStructure(False, True) + assert expr.instances == Instances.SIMPLE + + +def test_sum_over_whole_block() -> None: + x = var("x") + expr = x.sum() + provider = StructureProvider() + + assert expr.compute_indexation(provider) == IndexingStructure(False, True) + assert expr.instances == Instances.SIMPLE + + +def test_forbidden_composition_should_raise_value_error() -> None: + x = var("x") + with pytest.raises(ValueError): + _ = x.shift(ExpressionRange(1, 4)) + var("y") + + +def test_expectation() -> None: + x = var("x") + expr = x.expec() + provider = StructureProvider() + + assert expr.compute_indexation(provider) == IndexingStructure(True, False) + assert expr.instances == Instances.SIMPLE + + +def test_indexing_structure_comparison() -> None: + free = IndexingStructure(True, True) + constant = IndexingStructure(False, False) + assert free | constant == IndexingStructure(True, True) + + +def test_multiplication_of_differently_indexed_terms() -> None: + x = var("x") + p = param("p") + expr = p * x + + class CustomStructureProvider(IndexingStructureProvider): + def get_component_variable_structure( + self, component_id: str, name: str + ) -> IndexingStructure: + raise NotImplementedError() + + def get_component_parameter_structure( + self, component_id: str, name: str + ) -> IndexingStructure: + raise NotImplementedError() + + def get_parameter_structure(self, name: str) -> IndexingStructure: + return IndexingStructure(False, False) + + def get_variable_structure(self, name: str) -> IndexingStructure: + return IndexingStructure(True, True) + + provider = CustomStructureProvider() + + assert expr.compute_indexation(provider) == IndexingStructure(True, True) + + +def test_sum_expressions() -> None: + assert sum_expressions([]) == literal(0) + assert sum_expressions([literal(1)]) == literal(1) + assert sum_expressions([literal(1), var("x")]) == 1 + var("x") + assert sum_expressions([literal(1), var("x"), param("p")]) == 1 + ( + var("x") + param("p") + ) diff --git a/tests/unittests/expressions/test_linear_expressions_efficient.py b/tests/unittests/expressions/test_linear_expressions_efficient.py new file mode 100644 index 00000000..4da8b486 --- /dev/null +++ b/tests/unittests/expressions/test_linear_expressions_efficient.py @@ -0,0 +1,495 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +from typing import Dict + +import pytest + +from andromede.expression.linear_expression_efficient import ( + LinearExpressionEfficient, + TermEfficient, +) +from andromede.expression.scenario_operator import Expectation +from andromede.expression.time_operator import TimeShift, TimeSum + + +@pytest.mark.parametrize( + "term, expected", + [ + (TermEfficient(1, "c", "x"), "+x"), + (TermEfficient(-1, "c", "x"), "-x"), + (TermEfficient(2.50, "c", "x"), "+2.5x"), + (TermEfficient(-3, "c", "x"), "-3x"), + (TermEfficient(-3, "c", "x", time_operator=TimeShift(-1)), "-3x.shift([-1])"), + (TermEfficient(-3, "c", "x", time_aggregator=TimeSum(True)), "-3x.sum(True)"), + ( + TermEfficient( + -3, + "c", + "x", + time_operator=TimeShift([2, 3]), + time_aggregator=TimeSum(False), + ), + "-3x.shift([2, 3]).sum(False)", + ), + (TermEfficient(-3, "c", "x", scenario_operator=Expectation()), "-3x.expec()"), + ( + TermEfficient( + -3, + "c", + "x", + time_aggregator=TimeSum(True), + scenario_operator=Expectation(), + ), + "-3x.sum(True).expec()", + ), + ], +) +def test_printing_term(term: TermEfficient, expected: str) -> None: + assert str(term) == expected + + +@pytest.mark.parametrize( + "coeff, var_name, constant, expec_str", + [ + (0, "x", 0, "0"), + (1, "x", 0, "+x"), + (1, "x", 1, "+x+1"), + (3.7, "x", 1, "+3.7x+1"), + (0, "x", 1, "+1"), + ], +) +def test_affine_expression_printing_should_reflect_required_formatting( + coeff: float, var_name: str, constant: float, expec_str: str +) -> None: + expr = LinearExpressionEfficient([TermEfficient(coeff, "c", var_name)], constant) + assert str(expr) == expec_str + + +@pytest.mark.parametrize( + "lhs, rhs", + [ + ( + LinearExpressionEfficient([], 1) + LinearExpressionEfficient([], 3), + LinearExpressionEfficient([], 4), + ), + ( + LinearExpressionEfficient([], 4) / LinearExpressionEfficient([], 2), + LinearExpressionEfficient([], 2), + ), + ( + LinearExpressionEfficient([], 4) * LinearExpressionEfficient([], 2), + LinearExpressionEfficient([], 8), + ), + ( + LinearExpressionEfficient([], 4) - LinearExpressionEfficient([], 2), + LinearExpressionEfficient([], 2), + ), + ], +) +def test_constant_expressions( + lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient +) -> None: + assert lhs == rhs + + +@pytest.mark.parametrize( + "terms_dict, constant, exp_terms, exp_constant", + [ + ({"x": TermEfficient(0, "c", "x")}, 1, {}, 1), + ({"x": TermEfficient(1, "c", "x")}, 1, {"x": TermEfficient(1, "c", "x")}, 1), + ], +) +def test_instantiate_linear_expression_from_dict( + terms_dict: Dict[str, TermEfficient], + constant: float, + exp_terms: Dict[str, TermEfficient], + exp_constant: float, +) -> None: + expr = LinearExpressionEfficient(terms_dict, constant) + assert expr.terms == exp_terms + assert expr.constant == exp_constant + + +@pytest.mark.parametrize( + "e1, e2, expected", + [ + ( + LinearExpressionEfficient([TermEfficient(10, "c", "x")], 1), + LinearExpressionEfficient([TermEfficient(5, "c", "x")], 2), + LinearExpressionEfficient([TermEfficient(15, "c", "x")], 3), + ), + ( + LinearExpressionEfficient([TermEfficient(10, "c1", "x")], 1), + LinearExpressionEfficient([TermEfficient(5, "c2", "x")], 2), + LinearExpressionEfficient( + [TermEfficient(10, "c1", "x"), TermEfficient(5, "c2", "x")], 3 + ), + ), + ( + LinearExpressionEfficient([TermEfficient(10, "c", "x")], 0), + LinearExpressionEfficient([TermEfficient(5, "c", "y")], 0), + LinearExpressionEfficient( + [TermEfficient(10, "c", "x"), TermEfficient(5, "c", "y")], 0 + ), + ), + ( + LinearExpressionEfficient(), + LinearExpressionEfficient([TermEfficient(10, "c", "x", TimeShift(-1))]), + LinearExpressionEfficient([TermEfficient(10, "c", "x", TimeShift(-1))]), + ), + ( + LinearExpressionEfficient(), + LinearExpressionEfficient( + [TermEfficient(10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] + ), + LinearExpressionEfficient( + [TermEfficient(10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] + ), + ), + ( + LinearExpressionEfficient([TermEfficient(10, "c", "x")]), + LinearExpressionEfficient( + [TermEfficient(10, "c", "x", time_operator=TimeShift(-1))] + ), + LinearExpressionEfficient( + [ + TermEfficient(10, "c", "x"), + TermEfficient(10, "c", "x", time_operator=TimeShift(-1)), + ] + ), + ), + ( + LinearExpressionEfficient([TermEfficient(10, "c", "x")]), + LinearExpressionEfficient( + [ + TermEfficient( + 10, + "c", + "x", + time_operator=TimeShift(-1), + scenario_operator=Expectation(), + ) + ] + ), + LinearExpressionEfficient( + [ + TermEfficient(10, "c", "x"), + TermEfficient( + 10, + "c", + "x", + time_operator=TimeShift(-1), + scenario_operator=Expectation(), + ), + ] + ), + ), + ], +) +def test_addition( + e1: LinearExpressionEfficient, + e2: LinearExpressionEfficient, + expected: LinearExpressionEfficient, +) -> None: + assert e1 + e2 == expected + + +def test_addition_of_linear_expressions_with_different_number_of_instances_should_raise_value_error() -> ( + None +): + pass + + +def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_from_terms() -> ( + None +): + e1 = LinearExpressionEfficient([TermEfficient(10, "c", "x")], 1) + e2 = LinearExpressionEfficient([TermEfficient(10, "c", "x")], 2) + e3 = e2 - e1 + assert e3.terms == {} + + +@pytest.mark.parametrize( + "e1, e2, expected", + [ + ( + LinearExpressionEfficient([TermEfficient(10, "c", "x")], 3), + LinearExpressionEfficient([], 2), + LinearExpressionEfficient([TermEfficient(20, "c", "x")], 6), + ), + ( + LinearExpressionEfficient([TermEfficient(10, "c", "x")], 3), + LinearExpressionEfficient([], 1), + LinearExpressionEfficient([TermEfficient(10, "c", "x")], 3), + ), + ( + LinearExpressionEfficient([TermEfficient(10, "c", "x")], 3), + LinearExpressionEfficient(), + LinearExpressionEfficient(), + ), + ( + LinearExpressionEfficient( + [ + TermEfficient( + 10, + "c", + "x", + time_operator=TimeShift(-1), + scenario_operator=Expectation(), + ) + ], + 3, + ), + LinearExpressionEfficient([], 2), + LinearExpressionEfficient( + [ + TermEfficient( + 20, + "c", + "x", + time_operator=TimeShift(-1), + scenario_operator=Expectation(), + ) + ], + 6, + ), + ), + ], +) +def test_multiplication( + e1: LinearExpressionEfficient, + e2: LinearExpressionEfficient, + expected: LinearExpressionEfficient, +) -> None: + assert e1 * e2 == expected + assert e2 * e1 == expected + + +def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> None: + e1 = LinearExpressionEfficient([TermEfficient(10, "c", "x")], 0) + e2 = LinearExpressionEfficient([TermEfficient(5, "c", "x")], 0) + with pytest.raises(ValueError) as exc: + _ = e1 * e2 + assert str(exc.value) == "Cannot multiply two non constant expression" + + +@pytest.mark.parametrize( + "e1, expected", + [ + ( + LinearExpressionEfficient([TermEfficient(10, "c", "x")], 5), + LinearExpressionEfficient([TermEfficient(-10, "c", "x")], -5), + ), + ( + LinearExpressionEfficient( + [ + TermEfficient( + 10, + "c", + "x", + time_operator=TimeShift(-1), + time_aggregator=TimeSum(False), + scenario_operator=Expectation(), + ) + ], + 5, + ), + LinearExpressionEfficient( + [ + TermEfficient( + -10, + "c", + "x", + time_operator=TimeShift(-1), + time_aggregator=TimeSum(False), + scenario_operator=Expectation(), + ) + ], + -5, + ), + ), + ], +) +def test_negation( + e1: LinearExpressionEfficient, expected: LinearExpressionEfficient +) -> None: + assert -e1 == expected + + +@pytest.mark.parametrize( + "e1, e2, expected", + [ + ( + LinearExpressionEfficient([TermEfficient(10, "c", "x")], 1), + LinearExpressionEfficient([TermEfficient(5, "c", "x")], 2), + LinearExpressionEfficient([TermEfficient(5, "c", "x")], -1), + ), + ( + LinearExpressionEfficient([TermEfficient(10, "c1", "x")], 1), + LinearExpressionEfficient([TermEfficient(5, "c2", "x")], 2), + LinearExpressionEfficient( + [TermEfficient(10, "c1", "x"), TermEfficient(-5, "c2", "x")], -1 + ), + ), + ( + LinearExpressionEfficient([TermEfficient(10, "c", "x")], 0), + LinearExpressionEfficient([TermEfficient(5, "c", "y")], 0), + LinearExpressionEfficient( + [TermEfficient(10, "c", "x"), TermEfficient(-5, "c", "y")], 0 + ), + ), + ( + LinearExpressionEfficient(), + LinearExpressionEfficient( + [TermEfficient(10, "c", "x", time_operator=TimeShift(-1))] + ), + LinearExpressionEfficient( + [TermEfficient(-10, "c", "x", time_operator=TimeShift(-1))] + ), + ), + ( + LinearExpressionEfficient(), + LinearExpressionEfficient( + [TermEfficient(10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] + ), + LinearExpressionEfficient( + [TermEfficient(-10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] + ), + ), + ( + LinearExpressionEfficient([TermEfficient(10, "c", "x")]), + LinearExpressionEfficient( + [TermEfficient(10, "c", "x", time_operator=TimeShift(-1))] + ), + LinearExpressionEfficient( + [ + TermEfficient(10, "c", "x"), + TermEfficient(-10, "c", "x", time_operator=TimeShift(-1)), + ] + ), + ), + ( + LinearExpressionEfficient([TermEfficient(10, "c", "x")]), + LinearExpressionEfficient( + [ + TermEfficient( + 10, + "c", + "x", + time_operator=TimeShift(-1), + time_aggregator=TimeSum(False), + scenario_operator=Expectation(), + ) + ] + ), + LinearExpressionEfficient( + [ + TermEfficient(10, "c", "x"), + TermEfficient( + -10, + "c", + "x", + time_operator=TimeShift(-1), + time_aggregator=TimeSum(False), + scenario_operator=Expectation(), + ), + ] + ), + ), + ], +) +def test_substraction( + e1: LinearExpressionEfficient, + e2: LinearExpressionEfficient, + expected: LinearExpressionEfficient, +) -> None: + assert e1 - e2 == expected + + +@pytest.mark.parametrize( + "e1, e2, expected", + [ + ( + LinearExpressionEfficient([TermEfficient(10, "c", "x")], 15), + LinearExpressionEfficient([], 5), + LinearExpressionEfficient([TermEfficient(2, "c", "x")], 3), + ), + ( + LinearExpressionEfficient([TermEfficient(10, "c", "x")], 15), + LinearExpressionEfficient([], 1), + LinearExpressionEfficient([TermEfficient(10, "c", "x")], 15), + ), + ( + LinearExpressionEfficient( + [ + TermEfficient( + 10, + "c", + "x", + time_operator=TimeShift(-1), + time_aggregator=TimeSum(False), + scenario_operator=Expectation(), + ) + ], + 15, + ), + LinearExpressionEfficient([], 5), + LinearExpressionEfficient( + [ + TermEfficient( + 2, + "c", + "x", + time_operator=TimeShift(-1), + time_aggregator=TimeSum(False), + scenario_operator=Expectation(), + ) + ], + 3, + ), + ), + ], +) +def test_division( + e1: LinearExpressionEfficient, + e2: LinearExpressionEfficient, + expected: LinearExpressionEfficient, +) -> None: + assert e1 / e2 == expected + + +def test_division_by_zero_sould_raise_zero_division_error() -> None: + e1 = LinearExpressionEfficient([TermEfficient(10, "c", "x")], 15) + e2 = LinearExpressionEfficient() + with pytest.raises(ZeroDivisionError) as exc: + _ = e1 / e2 + assert str(exc.value) == "Cannot divide expression by zero" + + +def test_division_by_non_constant_expr_sould_raise_value_error() -> None: + e1 = LinearExpressionEfficient([TermEfficient(10, "c", "x")], 15) + e2 = LinearExpressionEfficient() + with pytest.raises(ValueError) as exc: + _ = e2 / e1 + assert str(exc.value) == "Cannot divide by a non constant expression" + + +def test_imul_preserve_identity() -> None: + # technical test to check the behaviour of reassigning "self" in imul operator: + # it did not preserve identity, which could lead to weird behaviour + e1 = LinearExpressionEfficient([], 15) + e2 = e1 + e1 *= LinearExpressionEfficient([], 2) + assert e1 == LinearExpressionEfficient([], 30) + assert e2 == e1 + assert e2 is e1 From 0f285c6513b753efd7b4e0af7a6ffcf266f8c161 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Fri, 21 Jun 2024 17:52:07 +0200 Subject: [PATCH 02/51] Operator simplification on construction --- .../expression/expression_efficient.py | 97 +++++++++++++++++-- .../expression/linear_expression_efficient.py | 66 ++++++++----- src/andromede/expression/visitor.py | 6 +- 3 files changed, 134 insertions(+), 35 deletions(-) diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index 2a895f10..fa6399a5 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -16,6 +16,7 @@ import enum import inspect from dataclasses import dataclass, field +import math from typing import Any, Callable, List, Optional, Sequence, Union import andromede.expression.port_operator @@ -43,31 +44,31 @@ class ExpressionNodeEfficient: instances: Instances = field(init=False, default=Instances.SIMPLE) def __neg__(self) -> "ExpressionNodeEfficient": - return NegationNode(self) + return _negate_node(self) def __add__(self, rhs: Any) -> "ExpressionNodeEfficient": - return _apply_if_node(rhs, lambda x: AdditionNode(self, x)) + return _apply_if_node(rhs, lambda x: _add_node(self, x)) def __radd__(self, lhs: Any) -> "ExpressionNodeEfficient": - return _apply_if_node(lhs, lambda x: AdditionNode(x, self)) + return _apply_if_node(lhs, lambda x: _add_node(x, self)) def __sub__(self, rhs: Any) -> "ExpressionNodeEfficient": - return _apply_if_node(rhs, lambda x: SubstractionNode(self, x)) + return _apply_if_node(rhs, lambda x: _substract_node(self, x)) def __rsub__(self, lhs: Any) -> "ExpressionNodeEfficient": - return _apply_if_node(lhs, lambda x: SubstractionNode(x, self)) + return _apply_if_node(lhs, lambda x: _substract_node(x, self)) def __mul__(self, rhs: Any) -> "ExpressionNodeEfficient": - return _apply_if_node(rhs, lambda x: MultiplicationNode(self, x)) + return _apply_if_node(rhs, lambda x: _multiply_node(self, x)) def __rmul__(self, lhs: Any) -> "ExpressionNodeEfficient": - return _apply_if_node(lhs, lambda x: MultiplicationNode(x, self)) + return _apply_if_node(lhs, lambda x: _multiply_node(x, self)) def __truediv__(self, rhs: Any) -> "ExpressionNodeEfficient": - return _apply_if_node(rhs, lambda x: DivisionNode(self, x)) + return _apply_if_node(rhs, lambda x: _divide_node(self, x)) def __rtruediv__(self, lhs: Any) -> "ExpressionNodeEfficient": - return _apply_if_node(lhs, lambda x: DivisionNode(x, self)) + return _apply_if_node(lhs, lambda x: _divide_node(x, self)) def __le__(self, rhs: Any) -> "ExpressionNodeEfficient": return _apply_if_node( @@ -150,6 +151,84 @@ def _apply_if_node( else: return NotImplemented +def is_zero(node: ExpressionNodeEfficient) -> bool: + # Could we use expressions equal + # TODO: Change hard coded 1e-16 abs_tol + return isinstance(node, LiteralNode) and math.isclose(node.value, 0, abs_tol=1e-16) + +def is_one(node: ExpressionNodeEfficient) -> bool: + # Could we use expressions equal + return isinstance(node, LiteralNode) and math.isclose(node.value, 1) + +def is_minus_one(node: ExpressionNodeEfficient) -> bool: + # Could we use expressions equal + return isinstance(node, LiteralNode) and math.isclose(node.value, -1) + +def _negate_node(node: ExpressionNodeEfficient): + if isinstance(node, LiteralNode): + return LiteralNode(-node.value) + else: + return NegationNode(node) + +def _add_node( + lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient +) -> ExpressionNodeEfficient: + if is_zero(lhs): + return rhs + if is_zero(rhs): + return lhs + if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode): + return LiteralNode(lhs.value + rhs.value) + # TODO : Si noeuds gauche droite même clé de param -> 2 * param + else: + return AdditionNode(lhs, rhs) + + +def _substract_node( + lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient +) -> ExpressionNodeEfficient: + if is_zero(lhs): + return -rhs + if is_zero(rhs): + return lhs + if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode): + return LiteralNode(lhs.value - rhs.value) + # TODO : Si noeuds gauche droite même clé de param -> 0 + else: + return SubstractionNode(lhs, rhs) + + +def _multiply_node( + lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient +) -> ExpressionNodeEfficient: + if is_one(lhs): + return rhs + if is_one(rhs): + return lhs + if is_minus_one(lhs): + return -rhs + if is_minus_one(rhs): + return -lhs + if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode): + return LiteralNode(lhs.value * rhs.value) + else: + return MultiplicationNode(lhs, rhs) + + +def _divide_node( + lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient +) -> ExpressionNodeEfficient: + if is_one(rhs): + return lhs + if is_minus_one(rhs): + return -lhs + if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode): + # This could raise division by 0 error + return LiteralNode(lhs.value / rhs.value) + + else: + return DivisionNode(lhs, rhs) + @dataclass(frozen=True, eq=False) class PortFieldNode(ExpressionNodeEfficient): diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 6622572e..f4a455ef 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -15,7 +15,7 @@ with only variables and literal coefficients. """ from dataclasses import dataclass, field -from typing import Callable, Dict, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union from andromede.expression.equality import expressions_equal from andromede.expression.evaluate import ValueProvider, evaluate @@ -35,8 +35,8 @@ EPS = 10 ** (-16) -def is_close_abs(value: float, other_value: float, eps: float) -> bool: - return abs(value - other_value) < eps +# def is_close_abs(value: float, other_value: float, eps: float) -> bool: +# return abs(value - other_value) < eps def is_zero(value: ExpressionNodeEfficient) -> bool: @@ -296,36 +296,37 @@ def __str__(self) -> str: def __eq__(self, rhs: object) -> bool: return ( isinstance(rhs, LinearExpressionEfficient) - and is_close_abs(self.constant, rhs.constant, EPS) + and expressions_equal(self.constant, rhs.constant) and self.terms == rhs.terms # /!\ There may be float equality comparison in the terms values ) - def __iadd__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient": - if not isinstance(rhs, LinearExpressionEfficient): - return NotImplemented + def __iadd__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient": + rhs = _wrap_in_linear_expr(rhs) self.constant += rhs.constant aggregated_terms = _merge_dicts(self.terms, rhs.terms, _add_terms, 0) self.terms = aggregated_terms self.remove_zeros_from_terms() return self - def __add__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient": + def __add__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient": result = LinearExpressionEfficient() result += self result += rhs return result - def __isub__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient": - if not isinstance(rhs, LinearExpressionEfficient): - return NotImplemented + def __radd__(self, rhs: int) -> "LinearExpressionEfficient": + return self.__add__(rhs) + + def __isub__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient": + rhs = _wrap_in_linear_expr(rhs) self.constant -= rhs.constant aggregated_terms = _merge_dicts(self.terms, rhs.terms, _substract_terms, 0) self.terms = aggregated_terms self.remove_zeros_from_terms() return self - def __sub__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient": + def __sub__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient": result = LinearExpressionEfficient() result += self result -= rhs @@ -336,9 +337,8 @@ def __neg__(self) -> "LinearExpressionEfficient": result -= self return result - def __imul__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient": - if not isinstance(rhs, LinearExpressionEfficient): - return NotImplemented + def __imul__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient": + rhs = _wrap_in_linear_expr(rhs) if self.terms and rhs.terms: raise ValueError("Cannot multiply two non constant expression") @@ -350,9 +350,9 @@ def __imul__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficie # It is possible that both expr are constant left_expr = rhs const_expr = self - if expressions_equal(const_expr.constant, LiteralNode(0), EPS): + if is_zero(const_expr.constant): return LinearExpressionEfficient() - elif expressions_equal(const_expr.constant, LiteralNode(1), EPS): + elif is_one(const_expr.constant): _copy_expression(left_expr, self) else: left_expr.constant *= const_expr.constant @@ -369,17 +369,19 @@ def __imul__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficie _copy_expression(left_expr, self) return self - def __mul__(self, rhs: "LinearExpressionEfficient") -> "LinearExpressionEfficient": + def __mul__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient": result = LinearExpressionEfficient() result += self result *= rhs return result + def __rmul__(self, rhs: int) -> "LinearExpressionEfficient": + return self.__mul__(rhs) + def __itruediv__( - self, rhs: "LinearExpressionEfficient" + self, rhs: Union["LinearExpressionEfficient", int, float] ) -> "LinearExpressionEfficient": - if not isinstance(rhs, LinearExpressionEfficient): - return NotImplemented + rhs = _wrap_in_linear_expr(rhs) if rhs.terms: raise ValueError("Cannot divide by a non constant expression") @@ -403,18 +405,21 @@ def __itruediv__( return self def __truediv__( - self, rhs: "LinearExpressionEfficient" + self, rhs: Union["LinearExpressionEfficient", int, float] ) -> "LinearExpressionEfficient": result = LinearExpressionEfficient() result += self result /= rhs return result + + def __rtruediv__(self, rhs: Union[int, float]) -> "LinearExpressionEfficient": + return self.__truediv__(rhs) def remove_zeros_from_terms(self) -> None: # TODO: Not optimized, checks could be done directly when doing operations on self.linear_term to avoid copies for term_key, term in self.terms.copy().items(): - if is_close_abs(term.coefficient, 0, EPS): + if is_zero(term.coefficient): del self.terms[term_key] def is_valid(self) -> bool: @@ -446,6 +451,21 @@ def is_constant(self) -> bool: # Constant expr like x-x could be seen as non constant as we do not simplify coefficient tree... return not self.terms +def _wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient: + if isinstance(obj, LinearExpressionEfficient): + return obj + elif isinstance(obj, float) or isinstance(obj, int): + return LinearExpressionEfficient([], LiteralNode(float(obj))) + raise TypeError(f"Unable to wrap {obj} into a linear expression") + +def _apply_if_node( + obj: Any, func: Callable[[LinearExpressionEfficient], LinearExpressionEfficient] +) -> LinearExpressionEfficient: + if as_linear_expr := _wrap_in_linear_expr(obj): + return func(as_linear_expr) + else: + return NotImplemented + def _copy_expression( src: LinearExpressionEfficient, dst: LinearExpressionEfficient diff --git a/src/andromede/expression/visitor.py b/src/andromede/expression/visitor.py index 29e95cee..54b5f16f 100644 --- a/src/andromede/expression/visitor.py +++ b/src/andromede/expression/visitor.py @@ -213,9 +213,9 @@ def addition(self, node: AdditionNode) -> T_op: return left_value + right_value def substraction(self, node: SubstractionNode) -> T_op: - left_value = visit(node.left, self) - right_value = visit(node.right, self) - return left_value - right_value + left_value = visit(node.left, self) + right_value = visit(node.right, self) + return left_value - right_value def multiplication(self, node: MultiplicationNode) -> T_op: left_value = visit(node.left, self) From ef9c8b6ec5992b29df25aa4bfb5ed66f2d44e55c Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Tue, 25 Jun 2024 14:55:44 +0200 Subject: [PATCH 03/51] Improve common operations on linear expr --- src/andromede/expression/equality.py | 6 + .../expression/expression_efficient.py | 40 ++- .../expression/linear_expression_efficient.py | 80 +++--- .../expressions/test_expressions_efficient.py | 227 ++++++++++++++---- .../test_linear_expressions_efficient.py | 36 --- .../expressions/test_term_efficient.py | 137 +++++++++++ 6 files changed, 402 insertions(+), 124 deletions(-) create mode 100644 tests/unittests/expressions/test_term_efficient.py diff --git a/src/andromede/expression/equality.py b/src/andromede/expression/equality.py index 67eb1e80..c558f346 100644 --- a/src/andromede/expression/equality.py +++ b/src/andromede/expression/equality.py @@ -18,6 +18,7 @@ AdditionNode, BinaryOperatorNode, ComparisonNode, + ComponentParameterNode, DivisionNode, ExpressionNodeEfficient, ExpressionRange, @@ -98,6 +99,8 @@ def visit( # return self.variable(left, right) if isinstance(left, ParameterNode) and isinstance(right, ParameterNode): return self.parameter(left, right) + if isinstance(left, ComponentParameterNode) and isinstance(right, ComponentParameterNode): + return self.comp_parameter(left, right) if isinstance(left, TimeOperatorNode) and isinstance(right, TimeOperatorNode): return self.time_operator(left, right) if isinstance(left, TimeAggregatorNode) and isinstance( @@ -151,6 +154,9 @@ def comparison(self, left: ComparisonNode, right: ComparisonNode) -> bool: def parameter(self, left: ParameterNode, right: ParameterNode) -> bool: return left.name == right.name + + def comp_parameter(self, left: ComponentParameterNode, right: ComponentParameterNode) -> bool: + return left.component_id == right.component_id and left.name == right.name def expression_range(self, left: ExpressionRange, right: ExpressionRange) -> bool: if not self.visit(left.start, right.start): diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index fa6399a5..426e9cd1 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -15,14 +15,16 @@ """ import enum import inspect -from dataclasses import dataclass, field import math +from dataclasses import dataclass, field from typing import Any, Callable, List, Optional, Sequence, Union import andromede.expression.port_operator import andromede.expression.scenario_operator import andromede.expression.time_operator +EPS = 10 ** (-16) + class Instances(enum.Enum): SIMPLE = "SIMPLE" @@ -135,7 +137,7 @@ def variance(self) -> "ExpressionNodeEfficient": return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Variance")) -def _wrap_in_node(obj: Any) -> ExpressionNodeEfficient: +def wrap_in_node(obj: Any) -> ExpressionNodeEfficient: if isinstance(obj, ExpressionNodeEfficient): return obj elif isinstance(obj, float) or isinstance(obj, int): @@ -146,30 +148,34 @@ def _wrap_in_node(obj: Any) -> ExpressionNodeEfficient: def _apply_if_node( obj: Any, func: Callable[["ExpressionNodeEfficient"], "ExpressionNodeEfficient"] ) -> "ExpressionNodeEfficient": - if as_node := _wrap_in_node(obj): + if as_node := wrap_in_node(obj): return func(as_node) else: return NotImplemented + def is_zero(node: ExpressionNodeEfficient) -> bool: - # Could we use expressions equal - # TODO: Change hard coded 1e-16 abs_tol - return isinstance(node, LiteralNode) and math.isclose(node.value, 0, abs_tol=1e-16) + # Faster implementation than expressions equal for this particular cases + return isinstance(node, LiteralNode) and math.isclose(node.value, 0, abs_tol=EPS) + def is_one(node: ExpressionNodeEfficient) -> bool: - # Could we use expressions equal + # Faster implementation than expressions equal for this particular cases return isinstance(node, LiteralNode) and math.isclose(node.value, 1) + def is_minus_one(node: ExpressionNodeEfficient) -> bool: - # Could we use expressions equal + # Faster implementation than expressions equal for this particular cases return isinstance(node, LiteralNode) and math.isclose(node.value, -1) + def _negate_node(node: ExpressionNodeEfficient): if isinstance(node, LiteralNode): return LiteralNode(-node.value) else: return NegationNode(node) + def _add_node( lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient ) -> ExpressionNodeEfficient: @@ -177,6 +183,9 @@ def _add_node( return rhs if is_zero(rhs): return lhs + # TODO: How can we use the equality visitor here (simple import -> circular import), copy code here ? + # if expressions_equal(lhs, -rhs): + # return LiteralNode(0) if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode): return LiteralNode(lhs.value + rhs.value) # TODO : Si noeuds gauche droite même clé de param -> 2 * param @@ -191,6 +200,9 @@ def _substract_node( return -rhs if is_zero(rhs): return lhs + # TODO: How can we use the equality visitor here (simple import -> circular import), copy code here ? + # if expressions_equal(lhs, rhs): + # return LiteralNode(0) if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode): return LiteralNode(lhs.value - rhs.value) # TODO : Si noeuds gauche droite même clé de param -> 0 @@ -201,6 +213,8 @@ def _substract_node( def _multiply_node( lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient ) -> ExpressionNodeEfficient: + if is_zero(lhs) or is_zero(rhs): + return LiteralNode(0) if is_one(lhs): return rhs if is_one(rhs): @@ -365,7 +379,7 @@ def __post_init__(self) -> None: for attribute in self.__dict__: value = getattr(self, attribute) object.__setattr__( - self, attribute, _wrap_in_node(value) if value is not None else value + self, attribute, wrap_in_node(value) if value is not None else value ) @@ -376,9 +390,9 @@ def expression_range( start: IntOrExpr, stop: IntOrExpr, step: Optional[IntOrExpr] = None ) -> ExpressionRange: return ExpressionRange( - start=_wrap_in_node(start), - stop=_wrap_in_node(stop), - step=None if step is None else _wrap_in_node(step), + start=wrap_in_node(start), + stop=wrap_in_node(stop), + step=None if step is None else wrap_in_node(step), ) @@ -415,7 +429,7 @@ def __init__( ) if isinstance(expressions, (int, ExpressionNodeEfficient)): - self.expressions = [_wrap_in_node(expressions)] + self.expressions = [wrap_in_node(expressions)] else: self.expressions = expressions diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index f4a455ef..455cffe1 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -24,6 +24,10 @@ ExpressionNodeEfficient, LiteralNode, ParameterNode, + is_minus_one, + is_one, + is_zero, + wrap_in_node, ) from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.print import print_expr @@ -32,24 +36,6 @@ T = TypeVar("T") -EPS = 10 ** (-16) - - -# def is_close_abs(value: float, other_value: float, eps: float) -> bool: -# return abs(value - other_value) < eps - - -def is_zero(value: ExpressionNodeEfficient) -> bool: - return expressions_equal(value, LiteralNode(0), EPS) - - -def is_one(value: ExpressionNodeEfficient) -> bool: - return expressions_equal(value, LiteralNode(1), EPS) - - -def is_minus_one(value: float) -> bool: - return expressions_equal(value, LiteralNode(-1), EPS) - @dataclass(frozen=True) class TermKeyEfficient: @@ -86,6 +72,21 @@ class TermEfficient: # TODO: It may be useful to define __add__, __sub__, etc on terms, which should return a linear expression ? + def __post_init__(self) -> None: + object.__setattr__(self, "coefficient", wrap_in_node(self.coefficient)) + + def __eq__(self, other: "TermEfficient") -> bool: + return ( + isinstance(other, TermEfficient) + and expressions_equal(self.coefficient, other.coefficient) + and self.component_id == other.component_id + and self.variable_name == other.variable_name + and self.structure == other.structure + and self.time_operator == other.time_operator + and self.time_aggregator == other.time_aggregator + and self.scenario_operator == other.scenario_operator + ) + def is_zero(self) -> bool: return is_zero(self.coefficient) @@ -246,14 +247,15 @@ def __init__( terms: Optional[ Union[Dict[TermKeyEfficient, TermEfficient], List[TermEfficient]] ] = None, - constant: Optional[float] = None, + constant: Optional[Union[float, ExpressionNodeEfficient]] = None, ) -> None: - self.constant = 0 - self.terms = {} - if constant is not None: - # += b - self.constant = constant + if constant is None: + self.constant = LiteralNode(0) + else: + self.constant = wrap_in_node(constant) + + self.terms = {} if terms is not None: # Allows to give two different syntax in the constructor: # - List[TermEfficient] is natural @@ -298,10 +300,12 @@ def __eq__(self, rhs: object) -> bool: isinstance(rhs, LinearExpressionEfficient) and expressions_equal(self.constant, rhs.constant) and self.terms - == rhs.terms # /!\ There may be float equality comparison in the terms values + == rhs.terms ) - def __iadd__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient": + def __iadd__( + self, rhs: Union["LinearExpressionEfficient", int, float] + ) -> "LinearExpressionEfficient": rhs = _wrap_in_linear_expr(rhs) self.constant += rhs.constant aggregated_terms = _merge_dicts(self.terms, rhs.terms, _add_terms, 0) @@ -309,7 +313,9 @@ def __iadd__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "Line self.remove_zeros_from_terms() return self - def __add__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient": + def __add__( + self, rhs: Union["LinearExpressionEfficient", int, float] + ) -> "LinearExpressionEfficient": result = LinearExpressionEfficient() result += self result += rhs @@ -318,7 +324,9 @@ def __add__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "Linea def __radd__(self, rhs: int) -> "LinearExpressionEfficient": return self.__add__(rhs) - def __isub__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient": + def __isub__( + self, rhs: Union["LinearExpressionEfficient", int, float] + ) -> "LinearExpressionEfficient": rhs = _wrap_in_linear_expr(rhs) self.constant -= rhs.constant aggregated_terms = _merge_dicts(self.terms, rhs.terms, _substract_terms, 0) @@ -326,7 +334,9 @@ def __isub__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "Line self.remove_zeros_from_terms() return self - def __sub__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient": + def __sub__( + self, rhs: Union["LinearExpressionEfficient", int, float] + ) -> "LinearExpressionEfficient": result = LinearExpressionEfficient() result += self result -= rhs @@ -337,7 +347,9 @@ def __neg__(self) -> "LinearExpressionEfficient": result -= self return result - def __imul__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient": + def __imul__( + self, rhs: Union["LinearExpressionEfficient", int, float] + ) -> "LinearExpressionEfficient": rhs = _wrap_in_linear_expr(rhs) if self.terms and rhs.terms: @@ -369,7 +381,9 @@ def __imul__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "Line _copy_expression(left_expr, self) return self - def __mul__(self, rhs: Union["LinearExpressionEfficient", int, float]) -> "LinearExpressionEfficient": + def __mul__( + self, rhs: Union["LinearExpressionEfficient", int, float] + ) -> "LinearExpressionEfficient": result = LinearExpressionEfficient() result += self result *= rhs @@ -412,7 +426,7 @@ def __truediv__( result /= rhs return result - + def __rtruediv__(self, rhs: Union[int, float]) -> "LinearExpressionEfficient": return self.__truediv__(rhs) @@ -451,6 +465,7 @@ def is_constant(self) -> bool: # Constant expr like x-x could be seen as non constant as we do not simplify coefficient tree... return not self.terms + def _wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient: if isinstance(obj, LinearExpressionEfficient): return obj @@ -458,6 +473,7 @@ def _wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient: return LinearExpressionEfficient([], LiteralNode(float(obj))) raise TypeError(f"Unable to wrap {obj} into a linear expression") + def _apply_if_node( obj: Any, func: Callable[[LinearExpressionEfficient], LinearExpressionEfficient] ) -> LinearExpressionEfficient: diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index f5a89c7a..a90b2bfb 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -27,9 +27,12 @@ from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, TermEfficient, + comp_param, + comp_var, param, var, ) +from andromede.model.constraint import Constraint from andromede.simulation.linearize import linearize_expression @@ -103,12 +106,12 @@ def test_operators() -> None: p = param("p") expr: LinearExpressionEfficient = (5 * x + 3) / p - 2 - assert str(expr) == "((((5.0 * x) + 3.0) / p) - 2.0)" + assert str(expr) == "(5.0 / p)x + ((3.0 / p) - 2.0)" context = EvaluationContext(variables={"x": 3}, parameters={"p": 4}) assert expr.evaluate(context) == pytest.approx(2.5, 1e-16) - assert expr.evaluate(context) == pytest.approx(-2.5, 1e-16) + assert -expr.evaluate(context) == pytest.approx(-2.5, 1e-16) def test_degree() -> None: @@ -123,58 +126,196 @@ def test_degree() -> None: assert expr.compute_degree() == 2 -@pytest.mark.xfail(reason="Degree simplification not implemented") def test_degree_computation_should_take_into_account_simplifications() -> None: x = var("x") expr = x - x - assert expr.compute_degree() == 0 + assert expr.is_constant() expr = 0 * x - assert expr.compute_degree() == 0 + assert expr.is_constant() assert expr.is_zero() -def test_parameters_resolution() -> None: - class TestParamProvider(ParameterValueProvider): - def get_component_parameter_value(self, component_id: str, name: str) -> float: - raise NotImplementedError() - - def get_parameter_value(self, name: str) -> float: - return 2 - - x = var("x") - p = param("p") - expr = (5 * x + 3) / p - # TODO: We do not want this in the API, but rather expr.get(t, w) - assert expr.resolve_parameters(TestParamProvider()) == (5 * x + 3) / 2 - - -# No real equivalent in the "efficient" formalism -def test_linearization() -> None: - x = comp_var("c", "x") - expr = (5 * x + 3) / 2 - provider = StructureProvider() - - assert expr == LinearExpressionEfficient([TermEfficient(2.5, "c", "x")], 1.5) - - # Does not raise error !!!! - assert param("p") * x == LinearExpressionEfficient( - [TermEfficient(ParameterNode("p"), "c", "x")], 1.5 - ) +# def test_parameters_resolution() -> None: +# class TestParamProvider(ParameterValueProvider): +# def get_component_parameter_value(self, component_id: str, name: str) -> float: +# raise NotImplementedError() + +# def get_parameter_value(self, name: str) -> float: +# return 2 + +# x = var("x") +# p = param("p") +# expr = (5 * x + 3) / p +# # TODO: We do not want this in the API, but rather expr.get(t, w) +# assert expr.resolve_parameters(TestParamProvider()) == (5 * x + 3) / 2 + + +# TODO: Write tests on ExpressionEfficientNodes for tree simplification, do the same for multiplication, substraction, etc +@pytest.mark.parametrize( + "e1, e2, expected", + [ + ( + var("x"), + -var("x"), + LinearExpressionEfficient(), + ), + ( + param("p"), + -param("p"), + LinearExpressionEfficient(), + ), + ( + var("x"), + -var("y"), + var("x") - var("y"), + ), + ( + comp_var("c1", "x"), + var("x"), + comp_var("c1", "x") + var("x"), + ), + ( + comp_var("c1", "x"), + comp_var("c2", "x"), + comp_var("c1", "x") + comp_var("c2", "x"), + ), + ( + comp_param("c1", "p"), + comp_param("c2", "p"), + comp_param("c1", "p") + comp_param("c2", "p"), + ), + ( + comp_var("c1", "x"), + comp_param("c1", "p"), + comp_var("c1", "x") + comp_param("c1", "p"), + ), + ( + param("p1"), + param("p2"), + param("p1") + param("p2"), + ), + ( + var("x"), + var("x"), + 2 * var("x"), + ), + ( + param("p"), + param("p"), + 2 * param("p"), + ), + ], +) +def test_addition( + e1: LinearExpressionEfficient, + e2: LinearExpressionEfficient, + expected: LinearExpressionEfficient, +) -> None: + assert e1 + e2 == expected + +@pytest.mark.parametrize( + "e1, e2, expected", + [ + ( + var("x"), + -var("x"), + 2 * var("x"), + ), + ( + param("p"), + -param("p"), + 2 * param("p"), + ), + ( + var("x"), + -var("y"), + var("x") + var("y"), + ), + ( + comp_var("c1", "x"), + var("x"), + comp_var("c1", "x") - var("x"), + ), + ( + comp_var("c1", "x"), + comp_var("c2", "x"), + comp_var("c1", "x") - comp_var("c2", "x"), + ), + ( + comp_param("c1", "p"), + comp_param("c2", "p"), + comp_param("c1", "p") - comp_param("c2", "p"), + ), + ( + comp_var("c1", "x"), + comp_param("c1", "p"), + comp_var("c1", "x") - comp_param("c1", "p"), + ), + ( + param("p1"), + param("p2"), + param("p1") - param("p2"), + ), + ( + var("x"), + var("x"), + LinearExpressionEfficient(), + ), + ( + param("p"), + param("p"), + LinearExpressionEfficient(), + ), + ], +) +def test_substraction( + e1: LinearExpressionEfficient, + e2: LinearExpressionEfficient, + expected: LinearExpressionEfficient, +) -> None: + assert e1 - e2 == expected + + +@pytest.mark.parametrize( + "lhs, rhs", + [ + ( + (5 * comp_var("c", "x") + 3) / 2, + LinearExpressionEfficient([TermEfficient(2.5, "c", "x")], 1.5), + ), + ( + param("p") * comp_var("c", "x"), + LinearExpressionEfficient( + [TermEfficient(ParameterNode("p"), "c", "x")], + ), + ), + ( + param("p") * comp_var("c", "x"), + LinearExpressionEfficient( + [TermEfficient(ParameterNode("p"), "c", "x")], + ), + ), + ], +) +def test_linear_expression_equality( + lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient +) -> None: + assert lhs == rhs # TODO: What is the equivalent of this test ? -def test_linearization_of_non_linear_expressions_should_raise_value_error() -> None: - x = var("x") - expr = x.variance() - - provider = StructureProvider() - with pytest.raises(ValueError) as exc: - linearize_expression(expr, provider) - assert ( - str(exc.value) - == "Cannot linearize expression with a non-linear operator: Variance" - ) +# def test_linearization_of_non_linear_expressions_should_raise_value_error() -> None: +# x = var("x") +# expr = x.variance() + +# provider = StructureProvider() +# with pytest.raises(ValueError) as exc: +# linearize_expression(expr, provider) +# assert ( +# str(exc.value) +# == "Cannot linearize expression with a non-linear operator: Variance" +# ) def test_comparison() -> None: diff --git a/tests/unittests/expressions/test_linear_expressions_efficient.py b/tests/unittests/expressions/test_linear_expressions_efficient.py index 4da8b486..df698f7f 100644 --- a/tests/unittests/expressions/test_linear_expressions_efficient.py +++ b/tests/unittests/expressions/test_linear_expressions_efficient.py @@ -22,42 +22,6 @@ from andromede.expression.time_operator import TimeShift, TimeSum -@pytest.mark.parametrize( - "term, expected", - [ - (TermEfficient(1, "c", "x"), "+x"), - (TermEfficient(-1, "c", "x"), "-x"), - (TermEfficient(2.50, "c", "x"), "+2.5x"), - (TermEfficient(-3, "c", "x"), "-3x"), - (TermEfficient(-3, "c", "x", time_operator=TimeShift(-1)), "-3x.shift([-1])"), - (TermEfficient(-3, "c", "x", time_aggregator=TimeSum(True)), "-3x.sum(True)"), - ( - TermEfficient( - -3, - "c", - "x", - time_operator=TimeShift([2, 3]), - time_aggregator=TimeSum(False), - ), - "-3x.shift([2, 3]).sum(False)", - ), - (TermEfficient(-3, "c", "x", scenario_operator=Expectation()), "-3x.expec()"), - ( - TermEfficient( - -3, - "c", - "x", - time_aggregator=TimeSum(True), - scenario_operator=Expectation(), - ), - "-3x.sum(True).expec()", - ), - ], -) -def test_printing_term(term: TermEfficient, expected: str) -> None: - assert str(term) == expected - - @pytest.mark.parametrize( "coeff, var_name, constant, expec_str", [ diff --git a/tests/unittests/expressions/test_term_efficient.py b/tests/unittests/expressions/test_term_efficient.py new file mode 100644 index 00000000..aaae5b9f --- /dev/null +++ b/tests/unittests/expressions/test_term_efficient.py @@ -0,0 +1,137 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +import pytest + +from andromede.expression.expression_efficient import LiteralNode +from andromede.expression.linear_expression_efficient import TermEfficient +from andromede.expression.scenario_operator import Expectation, Variance +from andromede.expression.time_operator import TimeShift, TimeSum + + +@pytest.mark.parametrize( + "term, expected", + [ + (TermEfficient(1, "c", "x"), "+x"), + (TermEfficient(-1, "c", "x"), "-x"), + (TermEfficient(2.50, "c", "x"), "+2.5x"), + (TermEfficient(-3, "c", "x"), "-3x"), + (TermEfficient(-3, "c", "x", time_operator=TimeShift(-1)), "-3x.shift([-1])"), + (TermEfficient(-3, "c", "x", time_aggregator=TimeSum(True)), "-3x.sum(True)"), + ( + TermEfficient( + -3, + "c", + "x", + time_operator=TimeShift([2, 3]), + time_aggregator=TimeSum(False), + ), + "-3x.shift([2, 3]).sum(False)", + ), + (TermEfficient(-3, "c", "x", scenario_operator=Expectation()), "-3x.expec()"), + ( + TermEfficient( + -3, + "c", + "x", + time_aggregator=TimeSum(True), + scenario_operator=Expectation(), + ), + "-3x.sum(True).expec()", + ), + ], +) +def test_printing_term(term: TermEfficient, expected: str) -> None: + assert str(term) == expected + + +@pytest.mark.parametrize( + "lhs, rhs, expected", + [ + (TermEfficient(1, "c", "x"), TermEfficient(1, "c", "x"), True), + (TermEfficient(1, "c", "x"), TermEfficient(2, "c", "x"), False), + ( + TermEfficient(LiteralNode(1), "c", "x"), + TermEfficient(LiteralNode(2), "c", "x"), + False, + ), + (TermEfficient(-1, "c", "x"), TermEfficient(-1, "", "x"), False), + (TermEfficient(2.50, "c", "x"), TermEfficient(2.50, "c", ""), False), + (TermEfficient(-3, "c", "x"), TermEfficient(-3, "c", "y"), False), + ( + TermEfficient(-3, "c", "x", time_operator=TimeShift(-1)), + TermEfficient(-3, "c", "x", time_operator=TimeShift(-1)), + True, + ), + ( + TermEfficient(-3, "c", "x", time_operator=TimeShift(-1)), + TermEfficient(-3, "c", "x"), + False, + ), + ( + TermEfficient(-3, "c", "x", time_aggregator=TimeSum(True)), + TermEfficient(-3, "c", "x", time_aggregator=TimeSum(True)), + True, + ), + ( + TermEfficient(-3, "c", "x", time_aggregator=TimeSum(True)), + TermEfficient(-3, "c", "x", time_operator=TimeShift(-1)), + False, + ), + ( + TermEfficient( + -3, + "c", + "x", + time_operator=TimeShift([2, 3]), + time_aggregator=TimeSum(False), + ), + TermEfficient( + -3, + "c", + "x", + time_operator=TimeShift([1, 3]), + time_aggregator=TimeSum(False), + ), + False, + ), + ( + TermEfficient(-3, "c", "x", scenario_operator=Expectation()), + TermEfficient(-3, "c", "x", scenario_operator=Expectation()), + True, + ), + ( + TermEfficient(-3, "c", "x", scenario_operator=Expectation()), + TermEfficient(-3, "c", "x", scenario_operator=Variance()), + False, + ), + ( + TermEfficient( + -3, + "c", + "x", + time_aggregator=TimeSum(True), + scenario_operator=Expectation(), + ), + TermEfficient( + -3, + "c", + "x", + time_aggregator=TimeSum(False), + scenario_operator=Expectation(), + ), + False, + ), + ], +) +def test_term_equality(lhs: TermEfficient, rhs: TermEfficient, expected: bool) -> None: + assert (lhs == rhs) == expected From 4f0542c76bedfb8ead992aad2e68a52d1e95631f Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Tue, 25 Jun 2024 17:27:02 +0200 Subject: [PATCH 04/51] Handle constraints --- .../expression/linear_expression_efficient.py | 58 ++++++++++++++++++- src/andromede/model/constraint.py | 51 ++++++++-------- .../expressions/test_expressions_efficient.py | 16 +++-- 3 files changed, 92 insertions(+), 33 deletions(-) diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 455cffe1..9905d193 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -295,12 +295,28 @@ def __str__(self) -> str: return result + def __le__(self, rhs: Any) -> "StandaloneConstraint": + return StandaloneConstraint( + expression=self - rhs, + lower_bound=literal(-float("inf")), + upper_bound=literal(0), + ) + + def __ge__(self, rhs: Any) -> "ExpressionNodeEfficient": + return StandaloneConstraint( + expression=self - rhs, + lower_bound=literal(0), + upper_bound=literal(float("inf")), + ) + + # def __eq__(self, rhs: Any) -> "ExpressionNodeEfficient": # type: ignore + # return _apply_if_node(rhs, lambda x: ComparisonNode(self, x, Comparator.EQUAL)) + def __eq__(self, rhs: object) -> bool: return ( isinstance(rhs, LinearExpressionEfficient) and expressions_equal(self.constant, rhs.constant) - and self.terms - == rhs.terms + and self.terms == rhs.terms ) def __iadd__( @@ -466,6 +482,44 @@ def is_constant(self) -> bool: return not self.terms +@dataclass +class StandaloneConstraint: + """ + A standalone constraint, with rugid initialization. + """ + + expression: LinearExpressionEfficient + lower_bound: LinearExpressionEfficient + upper_bound: LinearExpressionEfficient + + def __init__( + self, + expression: LinearExpressionEfficient, + lower_bound: LinearExpressionEfficient, + upper_bound: LinearExpressionEfficient, + ) -> None: + + for bound in [lower_bound, upper_bound]: + if bound is not None and not bound.is_constant(): + raise ValueError( + f"The bounds of a constraint should not contain variables, {print_expr(bound)} was given." + ) + + self.expression = expression + if lower_bound is not None: + self.lower_bound = lower_bound + else: + self.lower_bound = literal(-float("inf")) + + if upper_bound is not None: + self.upper_bound = upper_bound + else: + self.upper_bound = literal(float("inf")) + + def __str__(self) -> str: + return f"{str(self.lower_bound)} <= {str(self.expression)} <= {str(self.upper_bound)}" + + def _wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient: if isinstance(obj, LinearExpressionEfficient): return obj diff --git a/src/andromede/model/constraint.py b/src/andromede/model/constraint.py index 9852046e..b8ef3d54 100644 --- a/src/andromede/model/constraint.py +++ b/src/andromede/model/constraint.py @@ -10,17 +10,24 @@ # # This file is part of the Antares project. from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, Optional, Union from andromede.expression.degree import is_constant from andromede.expression.equality import ( expressions_equal, expressions_equal_if_present, ) -from andromede.expression.expression import ( - Comparator, - ComparisonNode, - ExpressionNode, + +# from andromede.expression.expression import ( +# Comparator, +# ComparisonNode, +# ExpressionNode, +# literal, +# ) +from andromede.expression.expression_efficient import Comparator, ComparisonNode +from andromede.expression.linear_expression_efficient import ( + LinearExpressionEfficient, + StandaloneConstraint, literal, ) from andromede.expression.print import print_expr @@ -36,42 +43,31 @@ class Constraint: """ name: str - expression: ExpressionNode - lower_bound: ExpressionNode - upper_bound: ExpressionNode + expression: LinearExpressionEfficient + lower_bound: LinearExpressionEfficient + upper_bound: LinearExpressionEfficient context: ProblemContext def __init__( self, name: str, - expression: ExpressionNode, - lower_bound: Optional[ExpressionNode] = None, - upper_bound: Optional[ExpressionNode] = None, + expression: Union[LinearExpressionEfficient, StandaloneConstraint], + lower_bound: Optional[LinearExpressionEfficient] = None, + upper_bound: Optional[LinearExpressionEfficient] = None, context: ProblemContext = ProblemContext.OPERATIONAL, ) -> None: self.name = name self.context = context - if isinstance(expression, ComparisonNode): + if isinstance(expression, StandaloneConstraint): if lower_bound is not None or upper_bound is not None: raise ValueError( "Both comparison between two expressions and a bound are specfied, set either only a comparison between expressions or a single linear expression with bounds." ) - merged_expr = expression.left - expression.right - self.expression = merged_expr - - if expression.comparator == Comparator.LESS_THAN: - # lhs - rhs <= 0 - self.upper_bound = literal(0) - self.lower_bound = literal(-float("inf")) - elif expression.comparator == Comparator.GREATER_THAN: - # lhs - rhs >= 0 - self.lower_bound = literal(0) - self.upper_bound = literal(float("inf")) - else: # lhs - rhs == 0 - self.lower_bound = literal(0) - self.upper_bound = literal(0) + self.expression = expression.expression + self.lower_bound = expression.lower_bound + self.upper_bound = expression.upper_bound else: for bound in [lower_bound, upper_bound]: if bound is not None and not is_constant(bound): @@ -99,3 +95,6 @@ def __eq__(self, other: Any) -> bool: and expressions_equal_if_present(self.lower_bound, other.lower_bound) and expressions_equal_if_present(self.upper_bound, other.upper_bound) ) + + def __str__(self) -> str: + return f"{str(self.lower_bound)} <= {str(self.expression)} <= {str(self.upper_bound)}" diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index a90b2bfb..213c6d5f 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -26,9 +26,11 @@ from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, + StandaloneConstraint, TermEfficient, comp_param, comp_var, + literal, param, var, ) @@ -214,6 +216,7 @@ def test_addition( ) -> None: assert e1 + e2 == expected + @pytest.mark.parametrize( "e1, e2, expected", [ @@ -318,14 +321,17 @@ def test_linear_expression_equality( # ) +def test_standalone_constraint() -> None: + cst = StandaloneConstraint(var("x"), literal(0), literal(10)) + + assert str(cst) == "0 <= +x <= + 10" + + def test_comparison() -> None: x = var("x") p = param("p") - expr: Constraint = ( - 5 * x + 3 - ) >= p - 2 ## Overloading operator to return a constraint object ! - - assert str(expr) == "((5.0 * x) + 3.0) >= (p - 2.0)" + expr = (5 * x + 3) >= p - 2 + assert str(expr) == "0 <= 5.0x + (3.0 - (p - 2.0)) <= + inf" class StructureProvider(IndexingStructureProvider): From 946c3667dc714a5eabe4ffe4bbd0a8049f545966 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Wed, 26 Jun 2024 16:12:38 +0200 Subject: [PATCH 05/51] Equality comparator for linear expression --- .../expression/linear_expression_efficient.py | 32 +++++++++++++------ .../expressions/test_expressions_efficient.py | 10 ++++-- .../test_linear_expressions_efficient.py | 25 ++++++++------- 3 files changed, 44 insertions(+), 23 deletions(-) diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 9905d193..fcb0ee28 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -280,7 +280,13 @@ def str_for_constant(self) -> str: if is_zero(self.constant): return "" else: - return f" + {print_expr(self.constant)}" + const_str = print_expr(self.constant) + if const_str.startswith("+"): + return f" + {const_str[1:]}" + elif const_str.startswith("-"): + return f" + ({const_str})" + else: + return f" + {print_expr(self.constant)}" def __str__(self) -> str: # Useful for debugging tests @@ -309,14 +315,11 @@ def __ge__(self, rhs: Any) -> "ExpressionNodeEfficient": upper_bound=literal(float("inf")), ) - # def __eq__(self, rhs: Any) -> "ExpressionNodeEfficient": # type: ignore - # return _apply_if_node(rhs, lambda x: ComparisonNode(self, x, Comparator.EQUAL)) - - def __eq__(self, rhs: object) -> bool: - return ( - isinstance(rhs, LinearExpressionEfficient) - and expressions_equal(self.constant, rhs.constant) - and self.terms == rhs.terms + def __eq__(self, rhs: Any) -> "ExpressionNodeEfficient": # type: ignore + return StandaloneConstraint( + expression=self - rhs, + lower_bound=literal(0), + upper_bound=literal(0), ) def __iadd__( @@ -482,6 +485,17 @@ def is_constant(self) -> bool: return not self.terms +def linear_expressions_equal( + lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient +) -> bool: + return ( + isinstance(lhs, LinearExpressionEfficient) + and isinstance(rhs, LinearExpressionEfficient) + and expressions_equal(lhs.constant, rhs.constant) + and lhs.terms == rhs.terms + ) + + @dataclass class StandaloneConstraint: """ diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index 213c6d5f..72f87980 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -330,8 +330,14 @@ def test_standalone_constraint() -> None: def test_comparison() -> None: x = var("x") p = param("p") - expr = (5 * x + 3) >= p - 2 - assert str(expr) == "0 <= 5.0x + (3.0 - (p - 2.0)) <= + inf" + + expr_geq = (5 * x + 3) >= p - 2 + expr_leq = (5 * x + 3) <= p - 2 + expr_eq = (5 * x + 3) == p - 2 + + assert str(expr_geq) == "0 <= 5.0x + (3.0 - (p - 2.0)) <= + inf" + assert str(expr_leq) == " + (-inf) <= 5.0x + (3.0 - (p - 2.0)) <= 0" + assert str(expr_eq) == "0 <= 5.0x + (3.0 - (p - 2.0)) <= 0" class StructureProvider(IndexingStructureProvider): diff --git a/tests/unittests/expressions/test_linear_expressions_efficient.py b/tests/unittests/expressions/test_linear_expressions_efficient.py index df698f7f..f894dd6c 100644 --- a/tests/unittests/expressions/test_linear_expressions_efficient.py +++ b/tests/unittests/expressions/test_linear_expressions_efficient.py @@ -17,6 +17,7 @@ from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, TermEfficient, + linear_expressions_equal, ) from andromede.expression.scenario_operator import Expectation from andromede.expression.time_operator import TimeShift, TimeSum @@ -27,9 +28,9 @@ [ (0, "x", 0, "0"), (1, "x", 0, "+x"), - (1, "x", 1, "+x+1"), - (3.7, "x", 1, "+3.7x+1"), - (0, "x", 1, "+1"), + (1, "x", 1, "+x + 1.0"), + (3.7, "x", 1, "3.7x + 1.0"), + (0, "x", 1, " + 1.0"), ], ) def test_affine_expression_printing_should_reflect_required_formatting( @@ -63,7 +64,7 @@ def test_affine_expression_printing_should_reflect_required_formatting( def test_constant_expressions( lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient ) -> None: - assert lhs == rhs + assert linear_expressions_equal(lhs, rhs) @pytest.mark.parametrize( @@ -165,7 +166,7 @@ def test_addition( e2: LinearExpressionEfficient, expected: LinearExpressionEfficient, ) -> None: - assert e1 + e2 == expected + assert linear_expressions_equal(e1 + e2, expected) def test_addition_of_linear_expressions_with_different_number_of_instances_should_raise_value_error() -> ( @@ -235,8 +236,8 @@ def test_multiplication( e2: LinearExpressionEfficient, expected: LinearExpressionEfficient, ) -> None: - assert e1 * e2 == expected - assert e2 * e1 == expected + assert linear_expressions_equal(e1 * e2, expected) + assert linear_expressions_equal(e2 * e1, expected) def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> None: @@ -287,7 +288,7 @@ def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> def test_negation( e1: LinearExpressionEfficient, expected: LinearExpressionEfficient ) -> None: - assert -e1 == expected + assert linear_expressions_equal(-e1, expected) @pytest.mark.parametrize( @@ -377,7 +378,7 @@ def test_substraction( e2: LinearExpressionEfficient, expected: LinearExpressionEfficient, ) -> None: - assert e1 - e2 == expected + assert linear_expressions_equal(e1 - e2, expected) @pytest.mark.parametrize( @@ -429,7 +430,7 @@ def test_division( e2: LinearExpressionEfficient, expected: LinearExpressionEfficient, ) -> None: - assert e1 / e2 == expected + assert linear_expressions_equal(e1 / e2, expected) def test_division_by_zero_sould_raise_zero_division_error() -> None: @@ -454,6 +455,6 @@ def test_imul_preserve_identity() -> None: e1 = LinearExpressionEfficient([], 15) e2 = e1 e1 *= LinearExpressionEfficient([], 2) - assert e1 == LinearExpressionEfficient([], 30) - assert e2 == e1 + assert linear_expressions_equal(e1, LinearExpressionEfficient([], 30)) + assert linear_expressions_equal(e2, e1) assert e2 is e1 From e1ebb29332d2393c4320e822dbc6a16c835417f9 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Wed, 26 Jun 2024 19:08:05 +0200 Subject: [PATCH 06/51] WIP --- .../expression/expression_efficient.py | 6 +- src/andromede/expression/indexing.py | 16 +- .../expression/linear_expression_efficient.py | 181 +++++++++++++++++- src/andromede/expression/time_operator.py | 16 +- 4 files changed, 192 insertions(+), 27 deletions(-) diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index 426e9cd1..c6828b15 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -396,7 +396,7 @@ def expression_range( ) -@dataclass +@dataclass(frozen=True) class InstancesTimeIndex: """ Defines a set of time indices on which a time operator operates. @@ -429,9 +429,9 @@ def __init__( ) if isinstance(expressions, (int, ExpressionNodeEfficient)): - self.expressions = [wrap_in_node(expressions)] + object.__setattr__(self, "expressions", [wrap_in_node(expressions)]) else: - self.expressions = expressions + object.__setattr__(self, "expressions", expressions) def is_simple(self) -> bool: if isinstance(self.expressions, list): diff --git a/src/andromede/expression/indexing.py b/src/andromede/expression/indexing.py index 11051dd5..102f4c45 100644 --- a/src/andromede/expression/indexing.py +++ b/src/andromede/expression/indexing.py @@ -89,20 +89,20 @@ def division(self, node: DivisionNode) -> IndexingStructure: def comparison(self, node: ComparisonNode) -> IndexingStructure: return visit(node.left, self) | visit(node.right, self) - def variable(self, node: VariableNode) -> IndexingStructure: - time = self.context.get_variable_structure(node.name).time == True - scenario = self.context.get_variable_structure(node.name).scenario == True - return IndexingStructure(time, scenario) + # def variable(self, node: VariableNode) -> IndexingStructure: + # time = self.context.get_variable_structure(node.name).time == True + # scenario = self.context.get_variable_structure(node.name).scenario == True + # return IndexingStructure(time, scenario) def parameter(self, node: ParameterNode) -> IndexingStructure: time = self.context.get_parameter_structure(node.name).time == True scenario = self.context.get_parameter_structure(node.name).scenario == True return IndexingStructure(time, scenario) - def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure: - return self.context.get_component_variable_structure( - node.component_id, node.name - ) + # def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure: + # return self.context.get_component_variable_structure( + # node.component_id, node.name + # ) def comp_parameter(self, node: ComponentParameterNode) -> IndexingStructure: return self.context.get_component_parameter_structure( diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index fcb0ee28..8f6f04ff 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -14,6 +14,7 @@ Specific modelling for "instantiated" linear expressions, with only variables and literal coefficients. """ +import dataclasses from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, TypeVar, Union @@ -22,17 +23,26 @@ from andromede.expression.expression_efficient import ( ComponentParameterNode, ExpressionNodeEfficient, + ExpressionRange, + Instances, + InstancesTimeIndex, LiteralNode, ParameterNode, + TimeOperatorNode, is_minus_one, is_one, is_zero, wrap_in_node, ) +from andromede.expression.indexing import ( + IndexingStructureProvider, + TimeScenarioIndexingVisitor, + compute_indexation, +) from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.print import print_expr from andromede.expression.scenario_operator import ScenarioOperator -from andromede.expression.time_operator import TimeAggregator, TimeOperator +from andromede.expression.time_operator import TimeAggregator, TimeOperator, TimeShift T = TypeVar("T") @@ -70,11 +80,23 @@ class TermEfficient: time_aggregator: Optional[TimeAggregator] = None scenario_operator: Optional[ScenarioOperator] = None - # TODO: It may be useful to define __add__, __sub__, etc on terms, which should return a linear expression ? + # TODO: Try to remove this + instances: Instances = field(init=False, default=Instances.SIMPLE) def __post_init__(self) -> None: object.__setattr__(self, "coefficient", wrap_in_node(self.coefficient)) + if self.time_operator is not None and self.time_aggregator is None: + + # TODO: Make a fuinction in time operator class + time_op_instances = Instances.SIMPLE if self.time_operator.time_ids.is_simple() else Instances.MULTIPLE + + if self.coefficient.instances != time_op_instances: + raise ValueError( + f"Cannot build term with coefficient {self.coefficient} and operator {self.time_operator} as they do not have the same number of instances." + ) + self.instances = self.coefficient.instances + def __eq__(self, other: "TermEfficient") -> bool: return ( isinstance(other, TermEfficient) @@ -131,6 +153,58 @@ def evaluate(self, context: ValueProvider) -> float: variable_value = context.get_variable_value(self.variable_name) return evaluate(self.coefficient, context) * variable_value + def compute_indexation( + self, provider: IndexingStructureProvider + ) -> IndexingStructure: + + # TODO: Improve this if/else structure + if self.component_id: + time = ( + provider.get_component_variable_structure(self.variable_name).time + == True + ) + scenario = ( + provider.get_component_variable_structure(self.variable_name).scenario + == True + ) + else: + time = provider.get_variable_structure(self.variable_name).time == True + scenario = ( + provider.get_variable_structure(self.variable_name).scenario == True + ) + return IndexingStructure(time, scenario) + + def shift( + self, + expressions: Union[ + int, + "ExpressionNodeEfficient", + List["ExpressionNodeEfficient"], + "ExpressionRange", + ], + ) -> "TermEfficient": + """ + Time shift of term + """ + # The behavior is richer/different than the previous implementation (with linear expr as trees) as we can now apply a shift operator on a whole expression, rather than just on the variables of an expression + + # Example : (param("p") * var("x")).shift(1) + # Previous behavior : p[t]x[t-1] + # New behavior : p[t-1]x[t-1] + + if self.time_operator is not None: + raise ValueError( + f"Composition of time operators {self.time_operator} and {TimeShift(InstancesTimeIndex(expressions))} is not allowed" + ) + + return dataclasses.replace( + self, + coefficient=TimeOperatorNode( + self.coefficient, "TimeShift", InstancesTimeIndex(expressions) + ), + time_operator=TimeShift(InstancesTimeIndex(expressions)), + ) + def generate_key(term: TermEfficient) -> TermKeyEfficient: return TermKeyEfficient( @@ -240,6 +314,8 @@ class LinearExpressionEfficient: terms: Dict[TermKeyEfficient, TermEfficient] constant: ExpressionNodeEfficient + # TODO: Probably not necessary, for now we replicate old implementation functioning + instances: Instances # TODO: We need to check that terms.key is indeed a TermKey and change the tests that this will break def __init__( @@ -272,6 +348,9 @@ def __init__( raise TypeError( f"Terms must be either of type Dict[str, Term] or List[Term], whereas {terms} is of type {type(terms)}" ) + + def _compute_instances(self): + def is_zero(self) -> bool: return len(self.terms) == 0 and is_zero(self.constant) @@ -484,6 +563,90 @@ def is_constant(self) -> bool: # Constant expr like x-x could be seen as non constant as we do not simplify coefficient tree... return not self.terms + def compute_indexation( + self, provider: IndexingStructureProvider + ) -> IndexingStructure: + + indexing = compute_indexation(self.constant, provider) + for term in self.terms.values(): + indexing = indexing | term.compute_indexation(provider) + + return indexing + + # def sum(self) -> "ExpressionNode": + # if isinstance(self, TimeOperatorNode): + # return TimeAggregatorNode(self, "TimeSum", stay_roll=True) + # else: + # return _apply_if_node( + # self, lambda x: TimeAggregatorNode(x, "TimeSum", stay_roll=False) + # ) + + # def sum_connections(self) -> "ExpressionNode": + # if isinstance(self, PortFieldNode): + # return PortFieldAggregatorNode(self, aggregator="PortSum") + # raise ValueError( + # f"sum_connections() applies only for PortFieldNode, whereas the current node is of type {type(self)}." + # ) + + def shift( + self, + expressions: Union[ + int, + "ExpressionNodeEfficient", + List["ExpressionNodeEfficient"], + "ExpressionRange", + ], + ) -> "LinearExpressionEfficient": + """ + Time shift of variables + + Examples: + >>> x.shift([1, 2, 4]) represents the vector of variables (x[t+1], x[t+2], x[t+4]) + + No variables allowed in shift argument, but parameter trees are ok + + It is assumed that the shift operator is linear and distributes to all terms and to the constant of the linear expression on which it is applied. + + Examples: + >>> (param("a") * var("x") + param("b")).shift([1, 2, 4]) represents the vector of variables (a[t+1]x[t+1] + b[t+1], a[t+2]x[t+2] + b[t+2], a[t+4]x[t+4] + b[t+4]) + """ + + # The behavior is richer/different than the previous implementation (with linear expr as trees) as we can now apply a shift operator on a whole expression, rather than just on the variables of an expression + + # Example : (param("p") * var("x")).shift(1) + # Previous behavior : p[t]x[t-1] + # New behavior : p[t-1]x[t-1] + + result_terms = {} + for term in self.terms.values(): + term_with_operator = term.shift(expressions) + result_terms[generate_key(term_with_operator)] = term_with_operator + + result_constant = TimeOperatorNode( + self.constant, "TimeShift", InstancesTimeIndex(expressions) + ) + result_expr = LinearExpressionEfficient(result_terms, result_constant) + return result_expr + + # def eval( + # self, + # expressions: Union[ + # int, "ExpressionNode", List["ExpressionNode"], "ExpressionRange" + # ], + # ) -> "ExpressionNode": + # return _apply_if_node( + # self, + # lambda x: TimeOperatorNode( + # x, "TimeEvaluation", InstancesTimeIndex(expressions) + # ), + # ) + + # def expec(self) -> "ExpressionNode": + # return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation")) + + # def variance(self) -> "ExpressionNode": + # return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Variance")) + def linear_expressions_equal( lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient @@ -542,13 +705,13 @@ def _wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient: raise TypeError(f"Unable to wrap {obj} into a linear expression") -def _apply_if_node( - obj: Any, func: Callable[[LinearExpressionEfficient], LinearExpressionEfficient] -) -> LinearExpressionEfficient: - if as_linear_expr := _wrap_in_linear_expr(obj): - return func(as_linear_expr) - else: - return NotImplemented +# def _apply_if_node( +# obj: Any, func: Callable[[LinearExpressionEfficient], LinearExpressionEfficient] +# ) -> LinearExpressionEfficient: +# if as_linear_expr := _wrap_in_linear_expr(obj): +# return func(as_linear_expr) +# else: +# return NotImplemented def _copy_expression( diff --git a/src/andromede/expression/time_operator.py b/src/andromede/expression/time_operator.py index 63059528..4d4fc676 100644 --- a/src/andromede/expression/time_operator.py +++ b/src/andromede/expression/time_operator.py @@ -18,6 +18,8 @@ from dataclasses import dataclass from typing import Any, List, Tuple +from andromede.expression.expression_efficient import InstancesTimeIndex + @dataclass(frozen=True) class TimeOperator(ABC): @@ -27,21 +29,21 @@ class TimeOperator(ABC): - is_rolling: bool, if true, this means that the time_ids are to be understood relatively to the current timestep of the context AND that the represented expression will have to be instanciated for all timesteps. Otherwise, the time_ids are "absolute" times and the expression only has to be instantiated once. """ - time_ids: List[int] + time_ids: InstancesTimeIndex @classmethod @abstractmethod def rolling(cls) -> bool: raise NotImplementedError - def __post_init__(self) -> None: - if isinstance(self.time_ids, int): - object.__setattr__(self, "time_ids", [self.time_ids]) - elif isinstance(self.time_ids, range): - object.__setattr__(self, "time_ids", list(self.time_ids)) + # def __post_init__(self) -> None: + # if isinstance(self.time_ids, int): + # object.__setattr__(self, "time_ids", [self.time_ids]) + # elif isinstance(self.time_ids, range): + # object.__setattr__(self, "time_ids", list(self.time_ids)) def key(self) -> Tuple[int, ...]: - return tuple(self.time_ids) + return self.time_ids def size(self) -> int: return len(self.time_ids) From 10c206678ef780dea6f611044fa1665039e0e236 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Thu, 27 Jun 2024 17:28:28 +0200 Subject: [PATCH 07/51] Sum shift --- .../expression/linear_expression_efficient.py | 189 +++++++++++------- .../expressions/test_expressions_efficient.py | 16 +- 2 files changed, 128 insertions(+), 77 deletions(-) diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 8f6f04ff..a4c43a74 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -28,6 +28,7 @@ InstancesTimeIndex, LiteralNode, ParameterNode, + TimeAggregatorNode, TimeOperatorNode, is_minus_one, is_one, @@ -80,23 +81,9 @@ class TermEfficient: time_aggregator: Optional[TimeAggregator] = None scenario_operator: Optional[ScenarioOperator] = None - # TODO: Try to remove this - instances: Instances = field(init=False, default=Instances.SIMPLE) - def __post_init__(self) -> None: object.__setattr__(self, "coefficient", wrap_in_node(self.coefficient)) - if self.time_operator is not None and self.time_aggregator is None: - - # TODO: Make a fuinction in time operator class - time_op_instances = Instances.SIMPLE if self.time_operator.time_ids.is_simple() else Instances.MULTIPLE - - if self.coefficient.instances != time_op_instances: - raise ValueError( - f"Cannot build term with coefficient {self.coefficient} and operator {self.time_operator} as they do not have the same number of instances." - ) - self.instances = self.coefficient.instances - def __eq__(self, other: "TermEfficient") -> bool: return ( isinstance(other, TermEfficient) @@ -133,14 +120,14 @@ def __str__(self) -> str: result += f".{str(self.scenario_operator)}" return result - def number_of_instances(self) -> int: - if self.time_aggregator is not None: - return self.time_aggregator.size() - else: - if self.time_operator is not None: - return self.time_operator.size() - else: - return 1 + # def number_of_instances(self) -> int: + # if self.time_aggregator is not None: + # return self.time_aggregator.size() + # else: + # if self.time_operator is not None: + # return self.time_operator.size() + # else: + # return 1 def evaluate(self, context: ValueProvider) -> float: # TODO: Take care of component variables, multiple time scenarios, operators, etc @@ -314,8 +301,6 @@ class LinearExpressionEfficient: terms: Dict[TermKeyEfficient, TermEfficient] constant: ExpressionNodeEfficient - # TODO: Probably not necessary, for now we replicate old implementation functioning - instances: Instances # TODO: We need to check that terms.key is indeed a TermKey and change the tests that this will break def __init__( @@ -348,9 +333,6 @@ def __init__( raise TypeError( f"Terms must be either of type Dict[str, Term] or List[Term], whereas {terms} is of type {type(terms)}" ) - - def _compute_instances(self): - def is_zero(self) -> bool: return len(self.terms) == 0 and is_zero(self.constant) @@ -534,25 +516,25 @@ def remove_zeros_from_terms(self) -> None: if is_zero(term.coefficient): del self.terms[term_key] - def is_valid(self) -> bool: - nb_instances = None - for term in self.terms.values(): - term_instances = term.number_of_instances() - if nb_instances is None: - nb_instances = term_instances - else: - if term_instances != nb_instances: - raise ValueError( - "The terms of the linear expression {self} do not have the same number of instances" - ) - return True - - def number_of_instances(self) -> int: - if self.is_valid(): - # All terms have the same number of instances, just pick one - return self.terms[next(iter(self.terms))].number_of_instances() - else: - raise ValueError(f"{self} is not a valid linear expression") + # def is_valid(self) -> bool: + # nb_instances = None + # for term in self.terms.values(): + # term_instances = term.number_of_instances() + # if nb_instances is None: + # nb_instances = term_instances + # else: + # if term_instances != nb_instances: + # raise ValueError( + # "The terms of the linear expression {self} do not have the same number of instances" + # ) + # return True + + # def number_of_instances(self) -> int: + # if self.is_valid(): + # # All terms have the same number of instances, just pick one + # return self.terms[next(iter(self.terms))].number_of_instances() + # else: + # raise ValueError(f"{self} is not a valid linear expression") def evaluate(self, context: ValueProvider) -> float: return sum([term.evaluate(context) for term in self.terms.values()]) + evaluate( @@ -573,13 +555,88 @@ def compute_indexation( return indexing - # def sum(self) -> "ExpressionNode": - # if isinstance(self, TimeOperatorNode): - # return TimeAggregatorNode(self, "TimeSum", stay_roll=True) - # else: - # return _apply_if_node( - # self, lambda x: TimeAggregatorNode(x, "TimeSum", stay_roll=False) - # ) + def sum( + self, + shift: Union[ + int, + "ExpressionNodeEfficient", + List["ExpressionNodeEfficient"], + "ExpressionRange", + None, + ] = None, + eval: Union[ + int, + "ExpressionNodeEfficient", + List["ExpressionNodeEfficient"], + "ExpressionRange", + None, + ] = None, + ) -> "LinearExpressionEfficient": + """ + Examples: + >>> x.sum(shift=[1, 2, 4]) represents x[t+1] + x[t+2] + x[t+4] + + No variables allowed in shift argument, but parameter trees are ok + + It is assumed that the shift operator is linear and distributes to all terms and to the constant of the linear expression on which it is applied. + + Examples: + >>> (param("a") * var("x") + param("b")).sum(shift=[1, 2, 4]) represents a[t+1]x[t+1] + b[t+1] + a[t+2]x[t+2] + b[t+2] + a[t+4]x[t+4] + b[t+4] + """ + + # if isinstance(self, TimeOperatorNode): + # return TimeAggregatorNode(self, "TimeSum", stay_roll=True) + # else: + # return _apply_if_node( + # self, lambda x: TimeAggregatorNode(x, "TimeSum", stay_roll=False) + # ) + + if shift is not None and eval is not None: + raise ValueError("Only shift or eval arguments should specified, not both.") + + if shift is not None: + result_terms = {} + for term in self.terms.values(): + term_with_operator = term.sum(shift=shift) + result_terms[generate_key(term_with_operator)] = term_with_operator + + result_constant = TimeAggregatorNode( + TimeOperatorNode(self.constant, "TimeShift", InstancesTimeIndex(shift)), + "TimeSum", + stay_roll=True, + ) + result_expr = LinearExpressionEfficient(result_terms, result_constant) + return result_expr + + if eval is not None: + result_terms = {} + for term in self.terms.values(): + term_with_operator = term.sum(eval=eval) + result_terms[generate_key(term_with_operator)] = term_with_operator + + result_constant = TimeAggregatorNode( + TimeOperatorNode( + self.constant, "TimeEvaluation", InstancesTimeIndex(eval) + ), + "TimeSum", + stay_roll=False, + ) + result_expr = LinearExpressionEfficient(result_terms, result_constant) + return result_expr + + else: # x.sum() -> Sum over all time block + result_terms = {} + for term in self.terms.values(): + term_with_operator = term.sum() + result_terms[generate_key(term_with_operator)] = term_with_operator + + result_constant = TimeAggregatorNode( + self.constant, + "TimeSum", + stay_roll=False, + ) + result_expr = LinearExpressionEfficient(result_terms, result_constant) + return result_expr # def sum_connections(self) -> "ExpressionNode": # if isinstance(self, PortFieldNode): @@ -598,17 +655,12 @@ def shift( ], ) -> "LinearExpressionEfficient": """ - Time shift of variables + Shorthand for shift on a single time step - Examples: - >>> x.shift([1, 2, 4]) represents the vector of variables (x[t+1], x[t+2], x[t+4]) + To refer to x[t-1], it is more natural to write x.shift(-1) than x.sum(shift=-1). - No variables allowed in shift argument, but parameter trees are ok + This function provides the shorthand x.sum(shift=expr), valid only in the case when expr refers to a single time step. - It is assumed that the shift operator is linear and distributes to all terms and to the constant of the linear expression on which it is applied. - - Examples: - >>> (param("a") * var("x") + param("b")).shift([1, 2, 4]) represents the vector of variables (a[t+1]x[t+1] + b[t+1], a[t+2]x[t+2] + b[t+2], a[t+4]x[t+4] + b[t+4]) """ # The behavior is richer/different than the previous implementation (with linear expr as trees) as we can now apply a shift operator on a whole expression, rather than just on the variables of an expression @@ -617,16 +669,13 @@ def shift( # Previous behavior : p[t]x[t-1] # New behavior : p[t-1]x[t-1] - result_terms = {} - for term in self.terms.values(): - term_with_operator = term.shift(expressions) - result_terms[generate_key(term_with_operator)] = term_with_operator + if not InstancesTimeIndex(expressions).is_simple(): + raise ValueError( + "The shift operator can only be applied on expressions refering to a single time step. To apply a shifting sum on multiple time indices on an expression x, you should use x.sum(shift=...)" + ) - result_constant = TimeOperatorNode( - self.constant, "TimeShift", InstancesTimeIndex(expressions) - ) - result_expr = LinearExpressionEfficient(result_terms, result_constant) - return result_expr + else: + return self.sum(shift=expressions) # def eval( # self, diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index 72f87980..f9b9d527 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -358,23 +358,25 @@ def get_variable_structure(self, name: str) -> IndexingStructure: return IndexingStructure(True, True) -def test_shift() -> None: +def test_shift_on_time_step_list_raises_value_error() -> None: x = var("x") - expr = x.shift(ExpressionRange(1, 4)) + with pytest.raises(ValueError): + _ = x.shift(ExpressionRange(1, 4)) +def test_shift_on_single_time_step() -> None: + x = var("x") + expr = x.shift(1) + provider = StructureProvider() - assert expr.compute_indexation(provider) == IndexingStructure(True, True) - assert expr.instances == Instances.MULTIPLE def test_shifting_sum() -> None: x = var("x") - expr = x.shift(ExpressionRange(1, 4)).sum() + expr = x.sum(shift=ExpressionRange(1, 4)) + provider = StructureProvider() - assert expr.compute_indexation(provider) == IndexingStructure(True, True) - assert expr.instances == Instances.SIMPLE def test_eval() -> None: From d13a64904a82ab81325812b7c9722093cf353905 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Thu, 11 Jul 2024 16:54:43 +0200 Subject: [PATCH 08/51] WIP --- .../expression/linear_expression_efficient.py | 161 +++++++++--------- 1 file changed, 83 insertions(+), 78 deletions(-) diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index a4c43a74..72b4067c 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -24,7 +24,6 @@ ComponentParameterNode, ExpressionNodeEfficient, ExpressionRange, - Instances, InstancesTimeIndex, LiteralNode, ParameterNode, @@ -35,11 +34,7 @@ is_zero, wrap_in_node, ) -from andromede.expression.indexing import ( - IndexingStructureProvider, - TimeScenarioIndexingVisitor, - compute_indexation, -) +from andromede.expression.indexing import IndexingStructureProvider, compute_indexation from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.print import print_expr from andromede.expression.scenario_operator import ScenarioOperator @@ -120,15 +115,6 @@ def __str__(self) -> str: result += f".{str(self.scenario_operator)}" return result - # def number_of_instances(self) -> int: - # if self.time_aggregator is not None: - # return self.time_aggregator.size() - # else: - # if self.time_operator is not None: - # return self.time_operator.size() - # else: - # return 1 - def evaluate(self, context: ValueProvider) -> float: # TODO: Take care of component variables, multiple time scenarios, operators, etc # Probably very error prone @@ -161,6 +147,44 @@ def compute_indexation( ) return IndexingStructure(time, scenario) + def sum( + self, + shift: Union[ + int, + "ExpressionNodeEfficient", + List["ExpressionNodeEfficient"], + "ExpressionRange", + None, + ] = None, + eval: Union[ + int, + "ExpressionNodeEfficient", + List["ExpressionNodeEfficient"], + "ExpressionRange", + None, + ] = None, + ) -> "TermEfficient": + + if shift is not None and eval is not None: + raise ValueError("Only shift or eval arguments should specified, not both.") + + if shift is not None: + return dataclasses.replace( + self, + coefficient=TimeOperatorNode( + self.coefficient, "TimeShift", InstancesTimeIndex(shift) + ), + time_operator=TimeShift(InstancesTimeIndex(shift)), + ) + sum_args = {"shift": shift} + stay_roll = True + elif eval is not None: + sum_args = {"eval": eval} + stay_roll = True + else: # x.sum() -> Sum over all time block + sum_args = {} + stay_roll = False + def shift( self, expressions: Union[ @@ -186,9 +210,13 @@ def shift( return dataclasses.replace( self, - coefficient=TimeOperatorNode( - self.coefficient, "TimeShift", InstancesTimeIndex(expressions) + coefficient=TimeAggregatorNode( + TimeOperatorNode( + self.constant, "TimeShift", InstancesTimeIndex(sum_args.values()[0]) ), + "TimeSum", + stay_roll=stay_roll, + ), time_operator=TimeShift(InstancesTimeIndex(expressions)), ) @@ -369,14 +397,14 @@ def __le__(self, rhs: Any) -> "StandaloneConstraint": upper_bound=literal(0), ) - def __ge__(self, rhs: Any) -> "ExpressionNodeEfficient": + def __ge__(self, rhs: Any) -> "StandaloneConstraint": return StandaloneConstraint( expression=self - rhs, lower_bound=literal(0), upper_bound=literal(float("inf")), ) - def __eq__(self, rhs: Any) -> "ExpressionNodeEfficient": # type: ignore + def __eq__(self, rhs: Any) -> "StandaloneConstraint": # type: ignore return StandaloneConstraint( expression=self - rhs, lower_bound=literal(0), @@ -516,26 +544,6 @@ def remove_zeros_from_terms(self) -> None: if is_zero(term.coefficient): del self.terms[term_key] - # def is_valid(self) -> bool: - # nb_instances = None - # for term in self.terms.values(): - # term_instances = term.number_of_instances() - # if nb_instances is None: - # nb_instances = term_instances - # else: - # if term_instances != nb_instances: - # raise ValueError( - # "The terms of the linear expression {self} do not have the same number of instances" - # ) - # return True - - # def number_of_instances(self) -> int: - # if self.is_valid(): - # # All terms have the same number of instances, just pick one - # return self.terms[next(iter(self.terms))].number_of_instances() - # else: - # raise ValueError(f"{self} is not a valid linear expression") - def evaluate(self, context: ValueProvider) -> float: return sum([term.evaluate(context) for term in self.terms.values()]) + evaluate( self.constant, context @@ -595,48 +603,45 @@ def sum( raise ValueError("Only shift or eval arguments should specified, not both.") if shift is not None: - result_terms = {} - for term in self.terms.values(): - term_with_operator = term.sum(shift=shift) - result_terms[generate_key(term_with_operator)] = term_with_operator - - result_constant = TimeAggregatorNode( - TimeOperatorNode(self.constant, "TimeShift", InstancesTimeIndex(shift)), - "TimeSum", - stay_roll=True, - ) - result_expr = LinearExpressionEfficient(result_terms, result_constant) - return result_expr - - if eval is not None: - result_terms = {} - for term in self.terms.values(): - term_with_operator = term.sum(eval=eval) - result_terms[generate_key(term_with_operator)] = term_with_operator + sum_args = {"shift": shift} + stay_roll = True + elif eval is not None: + sum_args = {"eval": eval} + stay_roll = True + else: # x.sum() -> Sum over all time block + sum_args = {} + stay_roll = False - result_constant = TimeAggregatorNode( - TimeOperatorNode( - self.constant, "TimeEvaluation", InstancesTimeIndex(eval) - ), - "TimeSum", - stay_roll=False, - ) - result_expr = LinearExpressionEfficient(result_terms, result_constant) - return result_expr + return self._apply_operator(sum_args, stay_roll) - else: # x.sum() -> Sum over all time block - result_terms = {} - for term in self.terms.values(): - term_with_operator = term.sum() - result_terms[generate_key(term_with_operator)] = term_with_operator + def _apply_operator( + self, + sum_args: Dict[ + str, + Union[ + int, + "ExpressionNodeEfficient", + List["ExpressionNodeEfficient"], + "ExpressionRange", + None, + ], + ], + stay_roll: bool, + ): + result_terms = {} + for term in self.terms.values(): + term_with_operator = term.sum(**sum_args) + result_terms[generate_key(term_with_operator)] = term_with_operator - result_constant = TimeAggregatorNode( - self.constant, - "TimeSum", - stay_roll=False, - ) - result_expr = LinearExpressionEfficient(result_terms, result_constant) - return result_expr + result_constant = TimeAggregatorNode( + TimeOperatorNode( + self.constant, "TimeShift", InstancesTimeIndex(sum_args.values()[0]) + ), + "TimeSum", + stay_roll=stay_roll, + ) + result_expr = LinearExpressionEfficient(result_terms, result_constant) + return result_expr # def sum_connections(self) -> "ExpressionNode": # if isinstance(self, PortFieldNode): From 75a588125e015688323b4b6b8be9465c3c8f05b5 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Fri, 12 Jul 2024 17:53:13 +0200 Subject: [PATCH 09/51] Shift and eval implementation in progress --- .../expression/linear_expression_efficient.py | 159 +++++++++++------- .../functional/test_performance_efficient.py | 51 ++++++ .../expressions/test_expressions_efficient.py | 24 ++- 3 files changed, 172 insertions(+), 62 deletions(-) create mode 100644 tests/functional/test_performance_efficient.py diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 72b4067c..8deff23a 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -38,7 +38,13 @@ from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.print import print_expr from andromede.expression.scenario_operator import ScenarioOperator -from andromede.expression.time_operator import TimeAggregator, TimeOperator, TimeShift +from andromede.expression.time_operator import ( + TimeAggregator, + TimeEvaluation, + TimeOperator, + TimeShift, + TimeSum, +) T = TypeVar("T") @@ -129,7 +135,6 @@ def evaluate(self, context: ValueProvider) -> float: def compute_indexation( self, provider: IndexingStructureProvider ) -> IndexingStructure: - # TODO: Improve this if/else structure if self.component_id: time = ( @@ -164,26 +169,32 @@ def sum( None, ] = None, ) -> "TermEfficient": - if shift is not None and eval is not None: raise ValueError("Only shift or eval arguments should specified, not both.") + # The shift or eval operators distribute over the coefficients whereas the sum only applies to the whole as (param("a") * var("x")).shift([1,5]) represents: a[t+1]x[t+1] + ... + a[t+5]x[t+5] + # And (param("a") * var("x")).eval([1,5]) represents: a[1]x[1] + ... + a[5]x[5] + if shift is not None: return dataclasses.replace( - self, - coefficient=TimeOperatorNode( - self.coefficient, "TimeShift", InstancesTimeIndex(shift) - ), - time_operator=TimeShift(InstancesTimeIndex(shift)), - ) - sum_args = {"shift": shift} - stay_roll = True + self, + coefficient=TimeOperatorNode( + self.coefficient, "TimeShift", InstancesTimeIndex(shift) + ), + time_operator=TimeShift(InstancesTimeIndex(shift)), + time_aggregator=TimeSum(stay_roll=True), + ) elif eval is not None: - sum_args = {"eval": eval} - stay_roll = True + return dataclasses.replace( + self, + coefficient=TimeOperatorNode( + self.coefficient, "TimeEvaluation", InstancesTimeIndex(eval) + ), + time_operator=TimeEvaluation(InstancesTimeIndex(eval)), + time_aggregator=TimeSum(stay_roll=True), + ) else: # x.sum() -> Sum over all time block - sum_args = {} - stay_roll = False + return dataclasses.replace(self, time_aggregator=TimeSum(stay_roll=False)) def shift( self, @@ -193,32 +204,61 @@ def shift( List["ExpressionNodeEfficient"], "ExpressionRange", ], - ) -> "TermEfficient": + ) -> "LinearExpressionEfficient": """ - Time shift of term + Shorthand for shift on a single time step + + To refer to x[t-1], it is more natural to write x.shift(-1) than x.sum(shift=-1). + + This function provides the shorthand x.sum(shift=expr), valid only in the case when expr refers to a single time step. + """ + # The behavior is richer/different than the previous implementation (with linear expr as trees) as we can now apply a shift operator on a whole expression, rather than just on the variables of an expression # Example : (param("p") * var("x")).shift(1) # Previous behavior : p[t]x[t-1] # New behavior : p[t-1]x[t-1] - if self.time_operator is not None: + if not InstancesTimeIndex(expressions).is_simple(): raise ValueError( - f"Composition of time operators {self.time_operator} and {TimeShift(InstancesTimeIndex(expressions))} is not allowed" + "The shift operator can only be applied on expressions refering to a single time step. To apply a shifting sum on multiple time indices on an expression x, you should use x.sum(shift=...)" ) - return dataclasses.replace( - self, - coefficient=TimeAggregatorNode( - TimeOperatorNode( - self.constant, "TimeShift", InstancesTimeIndex(sum_args.values()[0]) - ), - "TimeSum", - stay_roll=stay_roll, - ), - time_operator=TimeShift(InstancesTimeIndex(expressions)), - ) + else: + return self.sum(shift=expressions) + + def eval( + self, + expressions: Union[ + int, + "ExpressionNodeEfficient", + List["ExpressionNodeEfficient"], + "ExpressionRange", + ], + ) -> "LinearExpressionEfficient": + """ + Shorthand for eval on a single time step + + To refer to x[1], it is more natural to write x.eval(1) than x.sum(eval=1). + + This function provides the shorthand x.sum(eval=expr), valid only in the case when expr refers to a single time step. + + """ + + # The behavior is richer/different than the previous implementation (with linear expr as trees) as we can now apply a eval operator on a whole expression, rather than just on the variables of an expression + + # Example : (param("p") * var("x")).eval(1) + # Previous behavior : p[t]x[1] + # New behavior : p[1]x[1] + + if not InstancesTimeIndex(expressions).is_simple(): + raise ValueError( + "The eval operator can only be applied on expressions refering to a single time step. To apply a evaluating sum on multiple time indices on an expression x, you should use x.sum(eval=...)" + ) + + else: + return self.sum(eval=expressions) def generate_key(term: TermEfficient) -> TermKeyEfficient: @@ -338,7 +378,6 @@ def __init__( ] = None, constant: Optional[Union[float, ExpressionNodeEfficient]] = None, ) -> None: - if constant is None: self.constant = LiteralNode(0) else: @@ -556,7 +595,6 @@ def is_constant(self) -> bool: def compute_indexation( self, provider: IndexingStructureProvider ) -> IndexingStructure: - indexing = compute_indexation(self.constant, provider) for term in self.terms.values(): indexing = indexing | term.compute_indexation(provider) @@ -592,13 +630,6 @@ def sum( >>> (param("a") * var("x") + param("b")).sum(shift=[1, 2, 4]) represents a[t+1]x[t+1] + b[t+1] + a[t+2]x[t+2] + b[t+2] + a[t+4]x[t+4] + b[t+4] """ - # if isinstance(self, TimeOperatorNode): - # return TimeAggregatorNode(self, "TimeSum", stay_roll=True) - # else: - # return _apply_if_node( - # self, lambda x: TimeAggregatorNode(x, "TimeSum", stay_roll=False) - # ) - if shift is not None and eval is not None: raise ValueError("Only shift or eval arguments should specified, not both.") @@ -635,7 +666,11 @@ def _apply_operator( result_constant = TimeAggregatorNode( TimeOperatorNode( - self.constant, "TimeShift", InstancesTimeIndex(sum_args.values()[0]) + self.constant, + "TimeShift", + InstancesTimeIndex( + sum_args.popitem()[1] + ), # Dangerous as it modifies sum_args ? ), "TimeSum", stay_roll=stay_roll, @@ -668,12 +703,6 @@ def shift( """ - # The behavior is richer/different than the previous implementation (with linear expr as trees) as we can now apply a shift operator on a whole expression, rather than just on the variables of an expression - - # Example : (param("p") * var("x")).shift(1) - # Previous behavior : p[t]x[t-1] - # New behavior : p[t-1]x[t-1] - if not InstancesTimeIndex(expressions).is_simple(): raise ValueError( "The shift operator can only be applied on expressions refering to a single time step. To apply a shifting sum on multiple time indices on an expression x, you should use x.sum(shift=...)" @@ -682,18 +711,31 @@ def shift( else: return self.sum(shift=expressions) - # def eval( - # self, - # expressions: Union[ - # int, "ExpressionNode", List["ExpressionNode"], "ExpressionRange" - # ], - # ) -> "ExpressionNode": - # return _apply_if_node( - # self, - # lambda x: TimeOperatorNode( - # x, "TimeEvaluation", InstancesTimeIndex(expressions) - # ), - # ) + def eval( + self, + expressions: Union[ + int, + "ExpressionNodeEfficient", + List["ExpressionNodeEfficient"], + "ExpressionRange", + ], + ) -> "LinearExpressionEfficient": + """ + Shorthand for eval on a single time step + + To refer to x[1], it is more natural to write x.eval(1) than x.sum(eval=1). + + This function provides the shorthand x.sum(eval=expr), valid only in the case when expr refers to a single time step. + + """ + + if not InstancesTimeIndex(expressions).is_simple(): + raise ValueError( + "The eval operator can only be applied on expressions refering to a single time step. To apply a evaluation sum on multiple time indices on an expression x, you should use x.sum(eval=...)" + ) + + else: + return self.sum(eval=expressions) # def expec(self) -> "ExpressionNode": # return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation")) @@ -729,7 +771,6 @@ def __init__( lower_bound: LinearExpressionEfficient, upper_bound: LinearExpressionEfficient, ) -> None: - for bound in [lower_bound, upper_bound]: if bound is not None and not bound.is_constant(): raise ValueError( diff --git a/tests/functional/test_performance_efficient.py b/tests/functional/test_performance_efficient.py new file mode 100644 index 00000000..03cb0534 --- /dev/null +++ b/tests/functional/test_performance_efficient.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + + +import pytest + +from andromede.expression.evaluate import EvaluationContext +from andromede.expression.linear_expression_efficient import param, var + + +def test_large_number_of_parameters_sum() -> None: + """ + Test performance when the problem involves an expression with a high number of terms. + + This test pass with 476 terms but fails with 477 locally due to recursion depth, and even less terms are possible with Jenkins... + """ + nb_terms = 500 + + parameters_value = {} + for i in range(1, nb_terms): + parameters_value[f"cost_{i}"] = 1 / i + + # Still the recursion depth error with parameters + with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): + expr = sum(param(f"cost_{i}") for i in range(1, nb_terms)) + expr.evaluate(EvaluationContext(parameters=parameters_value)) + + +def test_large_number_of_variables_sum() -> None: + """ + Test performance when the problem involves an expression with a high number of terms. No problem when there is a large number of variables as this is derecusified. + """ + nb_terms = 500 + + variables_value = {} + for i in range(1, nb_terms): + variables_value[f"cost_{i}"] = 1 / i + + expr = sum(var(f"cost_{i}") for i in range(1, nb_terms)) + assert expr.evaluate(EvaluationContext(variables=variables_value)) == sum( + 1 / i for i in range(1, nb_terms) + ) diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index f9b9d527..58841e5e 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -10,6 +10,7 @@ # # This file is part of the Antares project. +import re from dataclasses import dataclass, field from typing import Dict @@ -360,13 +361,30 @@ def get_variable_structure(self, name: str) -> IndexingStructure: def test_shift_on_time_step_list_raises_value_error() -> None: x = var("x") - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match=re.escape( + "The shift operator can only be applied on expressions refering to a single time step. To apply a shifting sum on multiple time indices on an expression x, you should use x.sum(shift=...)" + ), + ): _ = x.shift(ExpressionRange(1, 4)) + +def test_eval_on_time_step_list_raises_value_error() -> None: + x = var("x") + with pytest.raises( + ValueError, + match=re.escape( + "The eval operator can only be applied on expressions refering to a single time step. To apply a evaluation sum on multiple time indices on an expression x, you should use x.sum(eval=...)" + ), + ): + _ = x.eval(ExpressionRange(1, 4)) + + def test_shift_on_single_time_step() -> None: x = var("x") expr = x.shift(1) - + provider = StructureProvider() assert expr.compute_indexation(provider) == IndexingStructure(True, True) @@ -374,7 +392,7 @@ def test_shift_on_single_time_step() -> None: def test_shifting_sum() -> None: x = var("x") expr = x.sum(shift=ExpressionRange(1, 4)) - + provider = StructureProvider() assert expr.compute_indexation(provider) == IndexingStructure(True, True) From 2860242aef63f65a87796513407e5d6ad5c67d83 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Fri, 12 Jul 2024 18:36:02 +0200 Subject: [PATCH 10/51] TimeShift hashing --- src/andromede/expression/expression_efficient.py | 7 +++++++ src/andromede/expression/time_operator.py | 6 ------ tests/unittests/expressions/test_expressions_efficient.py | 5 +---- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index c6828b15..8356c225 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -433,6 +433,13 @@ def __init__( else: object.__setattr__(self, "expressions", expressions) + def __hash__(self) -> int: + # Maybe if/else not needed and always using the tuple works ? + if isinstance(self.expressions, list): + return hash(tuple(self.expressions)) + else: + return hash(self.expressions) + def is_simple(self) -> bool: if isinstance(self.expressions, list): return len(self.expressions) == 1 diff --git a/src/andromede/expression/time_operator.py b/src/andromede/expression/time_operator.py index 4d4fc676..fa45b599 100644 --- a/src/andromede/expression/time_operator.py +++ b/src/andromede/expression/time_operator.py @@ -36,12 +36,6 @@ class TimeOperator(ABC): def rolling(cls) -> bool: raise NotImplementedError - # def __post_init__(self) -> None: - # if isinstance(self.time_ids, int): - # object.__setattr__(self, "time_ids", [self.time_ids]) - # elif isinstance(self.time_ids, range): - # object.__setattr__(self, "time_ids", list(self.time_ids)) - def key(self) -> Tuple[int, ...]: return self.time_ids diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index 58841e5e..c4cdaa91 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -399,11 +399,10 @@ def test_shifting_sum() -> None: def test_eval() -> None: x = var("x") - expr = x.eval(ExpressionRange(1, 4)) + expr = x.eval(1) provider = StructureProvider() assert expr.compute_indexation(provider) == IndexingStructure(False, True) - assert expr.instances == Instances.MULTIPLE def test_eval_sum() -> None: @@ -412,7 +411,6 @@ def test_eval_sum() -> None: provider = StructureProvider() assert expr.compute_indexation(provider) == IndexingStructure(False, True) - assert expr.instances == Instances.SIMPLE def test_sum_over_whole_block() -> None: @@ -421,7 +419,6 @@ def test_sum_over_whole_block() -> None: provider = StructureProvider() assert expr.compute_indexation(provider) == IndexingStructure(False, True) - assert expr.instances == Instances.SIMPLE def test_forbidden_composition_should_raise_value_error() -> None: From b640b22f8f092d2fd4119b990e653b0b02b8be01 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Mon, 15 Jul 2024 18:46:14 +0200 Subject: [PATCH 11/51] Implement shift, eval and time sum of linear expressions --- .../expression/expression_efficient.py | 76 +++--- .../expression/linear_expression_efficient.py | 128 ++++++--- .../expressions/test_expressions_efficient.py | 242 ++++++++++++++---- 3 files changed, 325 insertions(+), 121 deletions(-) diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index 8356c225..a8f76ac9 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -26,9 +26,9 @@ EPS = 10 ** (-16) -class Instances(enum.Enum): - SIMPLE = "SIMPLE" - MULTIPLE = "MULTIPLE" +# class Instances(enum.Enum): +# SIMPLE = "SIMPLE" +# MULTIPLE = "MULTIPLE" @dataclass(frozen=True) @@ -43,7 +43,7 @@ class ExpressionNodeEfficient: >>> expr = -var('x') + 5 / param('p') """ - instances: Instances = field(init=False, default=Instances.SIMPLE) + # instances: Instances = field(init=False, default=Instances.SIMPLE) def __neg__(self) -> "ExpressionNodeEfficient": return _negate_node(self) @@ -286,8 +286,8 @@ class LiteralNode(ExpressionNodeEfficient): class UnaryOperatorNode(ExpressionNodeEfficient): operand: ExpressionNodeEfficient - def __post_init__(self) -> None: - object.__setattr__(self, "instances", self.operand.instances) + # def __post_init__(self) -> None: + # object.__setattr__(self, "instances", self.operand.instances) @dataclass(frozen=True, eq=False) @@ -318,17 +318,17 @@ class BinaryOperatorNode(ExpressionNodeEfficient): left: ExpressionNodeEfficient right: ExpressionNodeEfficient - def __post_init__(self) -> None: - binary_operator_post_init(self, "apply binary operation with") + # def __post_init__(self) -> None: + # binary_operator_post_init(self, "apply binary operation with") -def binary_operator_post_init(node: BinaryOperatorNode, operation: str) -> None: - if node.left.instances != node.right.instances: - raise ValueError( - f"Cannot {operation} {node.left} and {node.right} as they do not have the same number of instances." - ) - else: - object.__setattr__(node, "instances", node.left.instances) +# def binary_operator_post_init(node: BinaryOperatorNode, operation: str) -> None: +# if node.left.instances != node.right.instances: +# raise ValueError( +# f"Cannot {operation} {node.left} and {node.right} as they do not have the same number of instances." +# ) +# else: +# object.__setattr__(node, "instances", node.left.instances) class Comparator(enum.Enum): @@ -341,32 +341,36 @@ class Comparator(enum.Enum): class ComparisonNode(BinaryOperatorNode): comparator: Comparator - def __post_init__(self) -> None: - binary_operator_post_init(self, "compare") + # def __post_init__(self) -> None: + # binary_operator_post_init(self, "compare") @dataclass(frozen=True, eq=False) class AdditionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "add") + pass + # def __post_init__(self) -> None: + # binary_operator_post_init(self, "add") @dataclass(frozen=True, eq=False) class SubstractionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "substract") + pass + # def __post_init__(self) -> None: + # binary_operator_post_init(self, "substract") @dataclass(frozen=True, eq=False) class MultiplicationNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "multiply") + pass + # def __post_init__(self) -> None: + # binary_operator_post_init(self, "multiply") @dataclass(frozen=True, eq=False) class DivisionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "divide") + pass + # def __post_init__(self) -> None: + # binary_operator_post_init(self, "divide") @dataclass(frozen=True, eq=False) @@ -465,15 +469,15 @@ def __post_init__(self) -> None: raise ValueError( f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" ) - if self.operand.instances == Instances.SIMPLE: - if self.instances_index.is_simple(): - object.__setattr__(self, "instances", Instances.SIMPLE) - else: - object.__setattr__(self, "instances", Instances.MULTIPLE) - else: - raise ValueError( - "Cannot apply time operator on an expression that already represents multiple instances" - ) + # if self.operand.instances == Instances.SIMPLE: + # if self.instances_index.is_simple(): + # object.__setattr__(self, "instances", Instances.SIMPLE) + # else: + # object.__setattr__(self, "instances", Instances.MULTIPLE) + # else: + # raise ValueError( + # "Cannot apply time operator on an expression that already represents multiple instances" + # ) @dataclass(frozen=True, eq=False) @@ -493,7 +497,7 @@ def __post_init__(self) -> None: raise ValueError( f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" ) - object.__setattr__(self, "instances", Instances.SIMPLE) + # object.__setattr__(self, "instances", Instances.SIMPLE) @dataclass(frozen=True, eq=False) @@ -512,7 +516,7 @@ def __post_init__(self) -> None: raise ValueError( f"{self.name} is not a valid scenario operator, valid scenario operators are {valid_names}" ) - object.__setattr__(self, "instances", Instances.SIMPLE) + # object.__setattr__(self, "instances", Instances.SIMPLE) def sum_expressions( diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 8deff23a..f7e10267 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -27,6 +27,7 @@ InstancesTimeIndex, LiteralNode, ParameterNode, + ScenarioOperatorNode, TimeAggregatorNode, TimeOperatorNode, is_minus_one, @@ -37,7 +38,7 @@ from andromede.expression.indexing import IndexingStructureProvider, compute_indexation from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.print import print_expr -from andromede.expression.scenario_operator import ScenarioOperator +from andromede.expression.scenario_operator import Expectation, ScenarioOperator from andromede.expression.time_operator import ( TimeAggregator, TimeEvaluation, @@ -135,22 +136,39 @@ def evaluate(self, context: ValueProvider) -> float: def compute_indexation( self, provider: IndexingStructureProvider ) -> IndexingStructure: - # TODO: Improve this if/else structure - if self.component_id: - time = ( - provider.get_component_variable_structure(self.variable_name).time - == True - ) - scenario = ( - provider.get_component_variable_structure(self.variable_name).scenario - == True - ) + + return IndexingStructure( + self._compute_time_indexing(provider), + self._compute_scenario_indexing(provider), + ) + + def _compute_time_indexing(self, provider: IndexingStructureProvider) -> bool: + if (self.time_aggregator and not self.time_aggregator.stay_roll) or ( + self.time_operator and not self.time_operator.rolling() + ): + time = False else: - time = provider.get_variable_structure(self.variable_name).time == True - scenario = ( - provider.get_variable_structure(self.variable_name).scenario == True - ) - return IndexingStructure(time, scenario) + if self.component_id: + time = provider.get_component_variable_structure( + self.component_id, self.variable_name + ).time + else: + time = provider.get_variable_structure(self.variable_name).time + return time + + def _compute_scenario_indexing(self, provider: IndexingStructureProvider) -> bool: + if self.scenario_operator: + scenario = False + else: + # TODO: Improve this if/else structure, probably simplify IndexingStructureProvider + if self.component_id: + scenario = provider.get_component_variable_structure( + self.component_id, self.variable_name + ).scenario + + else: + scenario = provider.get_variable_structure(self.variable_name).scenario + return scenario def sum( self, @@ -204,7 +222,7 @@ def shift( List["ExpressionNodeEfficient"], "ExpressionRange", ], - ) -> "LinearExpressionEfficient": + ) -> "TermEfficient": """ Shorthand for shift on a single time step @@ -236,7 +254,7 @@ def eval( List["ExpressionNodeEfficient"], "ExpressionRange", ], - ) -> "LinearExpressionEfficient": + ) -> "TermEfficient": """ Shorthand for eval on a single time step @@ -260,6 +278,10 @@ def eval( else: return self.sum(eval=expressions) + def expec(self) -> "TermEfficient": + # TODO: Do we need checks, in case a scenario operator is already specified ? + return dataclasses.replace(self, scenario_operator=Expectation()) + def generate_key(term: TermEfficient) -> TermKeyEfficient: return TermKeyEfficient( @@ -595,7 +617,12 @@ def is_constant(self) -> bool: def compute_indexation( self, provider: IndexingStructureProvider ) -> IndexingStructure: - indexing = compute_indexation(self.constant, provider) + """ + Computes the (time, scenario) indexing of a linear expression. + + Time and scenario indexation is driven by the indexation of variables in the expression. If a single term is indexed by time (resp. scenario), then the linear expression is indexed by time (resp. scenario). + """ + indexing = IndexingStructure(False, False) for term in self.terms.values(): indexing = indexing | term.compute_indexation(provider) @@ -635,15 +662,40 @@ def sum( if shift is not None: sum_args = {"shift": shift} - stay_roll = True + + result_constant = TimeAggregatorNode( + TimeOperatorNode( + self.constant, + "TimeShift", + InstancesTimeIndex(shift), + ), + "TimeSum", + stay_roll=True, + ) elif eval is not None: sum_args = {"eval": eval} - stay_roll = True + + result_constant = TimeAggregatorNode( + TimeOperatorNode( + self.constant, + "TimeEvaluation", + InstancesTimeIndex(eval), + ), + "TimeSum", + stay_roll=True, + ) else: # x.sum() -> Sum over all time block sum_args = {} - stay_roll = False - return self._apply_operator(sum_args, stay_roll) + result_constant = TimeAggregatorNode( + self.constant, + "TimeSum", + stay_roll=False, + ) + + return LinearExpressionEfficient( + self._apply_operator(sum_args), result_constant + ) def _apply_operator( self, @@ -657,26 +709,13 @@ def _apply_operator( None, ], ], - stay_roll: bool, ): result_terms = {} for term in self.terms.values(): term_with_operator = term.sum(**sum_args) result_terms[generate_key(term_with_operator)] = term_with_operator - result_constant = TimeAggregatorNode( - TimeOperatorNode( - self.constant, - "TimeShift", - InstancesTimeIndex( - sum_args.popitem()[1] - ), # Dangerous as it modifies sum_args ? - ), - "TimeSum", - stay_roll=stay_roll, - ) - result_expr = LinearExpressionEfficient(result_terms, result_constant) - return result_expr + return result_terms # def sum_connections(self) -> "ExpressionNode": # if isinstance(self, PortFieldNode): @@ -737,8 +776,19 @@ def eval( else: return self.sum(eval=expressions) - # def expec(self) -> "ExpressionNode": - # return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation")) + def expec(self) -> "LinearExpressionEfficient": + """ + Expectation of linear expression. As the operator is linear, it distributes over all terms and the constant + """ + + result_terms = {} + for term in self.terms.values(): + term_with_operator = term.expec() + result_terms[generate_key(term_with_operator)] = term_with_operator + + result_constant = ScenarioOperatorNode(self.constant, "Expectation") + result_expr = LinearExpressionEfficient(result_terms, result_constant) + return result_expr # def variance(self) -> "ExpressionNode": # return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Variance")) diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index c4cdaa91..8705d342 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -16,12 +16,18 @@ import pytest +from andromede.expression.equality import expressions_equal from andromede.expression.evaluate import EvaluationContext, ValueProvider from andromede.expression.evaluate_parameters import ParameterValueProvider from andromede.expression.expression_efficient import ( ComponentParameterNode, + ExpressionNodeEfficient, ExpressionRange, + InstancesTimeIndex, + LiteralNode, ParameterNode, + TimeAggregatorNode, + TimeOperatorNode, ) from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure @@ -29,12 +35,14 @@ LinearExpressionEfficient, StandaloneConstraint, TermEfficient, + TermKeyEfficient, comp_param, comp_var, literal, param, var, ) +from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum from andromede.model.constraint import Constraint from andromede.simulation.linearize import linearize_expression @@ -341,6 +349,149 @@ def test_comparison() -> None: assert str(expr_eq) == "0 <= 5.0x + (3.0 - (p - 2.0)) <= 0" +# TODO: Maybe imagine other use cases, that should be forbidden (composition of operators...) +@pytest.mark.parametrize( + "expr, expec_terms, expec_constant", + [ + ( + (var("x") + var("y") + literal(1)).shift(1), + { + TermKeyEfficient( + "", + "x", + TimeShift(InstancesTimeIndex(1)), + time_aggregator=TimeSum( + stay_roll=True + ), # The internal representation of shift(1) is sum(shift=1) + scenario_operator=None, + ): TermEfficient( + TimeOperatorNode( + LiteralNode(1), "TimeShift", InstancesTimeIndex(1) + ), + "", + "x", + time_operator=TimeShift( + InstancesTimeIndex(1), + ), + time_aggregator=TimeSum(stay_roll=True), + ), + TermKeyEfficient( + "", + "y", + TimeShift( + InstancesTimeIndex(1), + ), + time_aggregator=TimeSum(stay_roll=True), + scenario_operator=None, + ): TermEfficient( + TimeOperatorNode( + LiteralNode(1), "TimeShift", InstancesTimeIndex(1) + ), + "", + "y", + time_operator=TimeShift(InstancesTimeIndex(1)), + time_aggregator=TimeSum(stay_roll=True), + ), + }, + TimeAggregatorNode( + TimeOperatorNode(LiteralNode(1), "TimeShift", InstancesTimeIndex(1)), + "TimeSum", + stay_roll=True, + ), # TODO: Could it be simplified online ? + ), + ( + (var("x") + var("y") + literal(1)).eval(1), + { + TermKeyEfficient( + "", + "x", + TimeEvaluation(InstancesTimeIndex(1)), + time_aggregator=TimeSum( + stay_roll=True + ), # The internal representation of eval(1) is sum(eval=1) + scenario_operator=None, + ): TermEfficient( + TimeOperatorNode( + LiteralNode(1), "TimeEvaluation", InstancesTimeIndex(1) + ), + "", + "x", + time_operator=TimeEvaluation( + InstancesTimeIndex(1), + ), + time_aggregator=TimeSum(stay_roll=True), + ), + TermKeyEfficient( + "", + "y", + TimeEvaluation( + InstancesTimeIndex(1), + ), + time_aggregator=TimeSum(stay_roll=True), + scenario_operator=None, + ): TermEfficient( + TimeOperatorNode( + LiteralNode(1), "TimeEvaluation", InstancesTimeIndex(1) + ), + "", + "y", + time_operator=TimeEvaluation(InstancesTimeIndex(1)), + time_aggregator=TimeSum(stay_roll=True), + ), + }, + TimeAggregatorNode( + TimeOperatorNode( + LiteralNode(1), "TimeEvaluation", InstancesTimeIndex(1) + ), + "TimeSum", + stay_roll=True, + ), # TODO: Could it be simplified online ? + ), + ( + (var("x") + var("y") + literal(1)).sum(), + { + TermKeyEfficient( + "", + "x", + time_operator=None, + time_aggregator=TimeSum(stay_roll=False), + scenario_operator=None, + ): TermEfficient( + LiteralNode(1), # Sum is not distributed to coeff + "", + "x", + time_operator=None, + time_aggregator=TimeSum(stay_roll=False), + ), + TermKeyEfficient( + "", + "y", + time_operator=None, + time_aggregator=TimeSum(stay_roll=False), + scenario_operator=None, + ): TermEfficient( + LiteralNode(1), # Sum is not distributed to coeff + "", + "y", + time_operator=None, + time_aggregator=TimeSum(stay_roll=False), + ), + }, + TimeAggregatorNode( + LiteralNode(1), "TimeSum", stay_roll=False + ), # TODO: Could it be simplified online ? + ), + ], +) +def test_operators_are_correctly_distributed_over_terms( + expr: LinearExpressionEfficient, + expec_terms: Dict[TermKeyEfficient, TermEfficient], + expec_constant: ExpressionNodeEfficient, +) -> None: + assert expr.terms == expec_terms + assert expressions_equal(expr.constant, expec_constant) + + class StructureProvider(IndexingStructureProvider): def get_component_variable_structure( self, component_id: str, name: str @@ -381,44 +532,52 @@ def test_eval_on_time_step_list_raises_value_error() -> None: _ = x.eval(ExpressionRange(1, 4)) -def test_shift_on_single_time_step() -> None: - x = var("x") - expr = x.shift(1) - - provider = StructureProvider() - assert expr.compute_indexation(provider) == IndexingStructure(True, True) - - -def test_shifting_sum() -> None: - x = var("x") - expr = x.sum(shift=ExpressionRange(1, 4)) - - provider = StructureProvider() - assert expr.compute_indexation(provider) == IndexingStructure(True, True) - - -def test_eval() -> None: - x = var("x") - expr = x.eval(1) - provider = StructureProvider() - - assert expr.compute_indexation(provider) == IndexingStructure(False, True) - - -def test_eval_sum() -> None: - x = var("x") - expr = x.eval(ExpressionRange(1, 4)).sum() - provider = StructureProvider() - - assert expr.compute_indexation(provider) == IndexingStructure(False, True) - - -def test_sum_over_whole_block() -> None: - x = var("x") - expr = x.sum() +@pytest.mark.parametrize( + "linear_expr, expected_indexation", + [ + ( + var("x").shift(1), + IndexingStructure(True, True), + ), + ( + var("x").sum(shift=ExpressionRange(1, 4)), + IndexingStructure(True, True), + ), + ( + var("x").eval(1), + IndexingStructure(False, True), + ), + ( + var("x").sum(eval=ExpressionRange(1, 4)), + IndexingStructure(False, True), + ), + ( + var("x").sum(), + IndexingStructure(False, True), + ), + ( + var("x").expec(), + IndexingStructure(True, False), + ), + ( + var("x").sum().expec(), + IndexingStructure(False, False), + ), + ( + var("x").shift(1).expec(), + IndexingStructure(True, False), + ), + ( + var("x").eval(1).expec(), + IndexingStructure(False, False), + ), + ], +) +def test_compute_indexation( + linear_expr: LinearExpressionEfficient, expected_indexation: IndexingStructure +) -> None: provider = StructureProvider() - - assert expr.compute_indexation(provider) == IndexingStructure(False, True) + assert linear_expr.compute_indexation(provider) == expected_indexation def test_forbidden_composition_should_raise_value_error() -> None: @@ -427,15 +586,6 @@ def test_forbidden_composition_should_raise_value_error() -> None: _ = x.shift(ExpressionRange(1, 4)) + var("y") -def test_expectation() -> None: - x = var("x") - expr = x.expec() - provider = StructureProvider() - - assert expr.compute_indexation(provider) == IndexingStructure(True, False) - assert expr.instances == Instances.SIMPLE - - def test_indexing_structure_comparison() -> None: free = IndexingStructure(True, True) constant = IndexingStructure(False, False) From 5f2823fdb264b4f0a6715a56acf79d3ae4a01d40 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Mon, 15 Jul 2024 20:07:27 +0200 Subject: [PATCH 12/51] Test sum of linear expressions --- .../expressions/test_expressions_efficient.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index 8705d342..c79a5b1c 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -38,6 +38,7 @@ TermKeyEfficient, comp_param, comp_var, + linear_expressions_equal, literal, param, var, @@ -620,9 +621,10 @@ def get_variable_structure(self, name: str) -> IndexingStructure: def test_sum_expressions() -> None: - assert sum_expressions([]) == literal(0) - assert sum_expressions([literal(1)]) == literal(1) - assert sum_expressions([literal(1), var("x")]) == 1 + var("x") - assert sum_expressions([literal(1), var("x"), param("p")]) == 1 + ( - var("x") + param("p") + + + assert linear_expressions_equal(sum([literal(1)]), literal(1)) + assert linear_expressions_equal(sum([literal(1), var("x")]), 1 + var("x")) + assert linear_expressions_equal( + sum([literal(1), var("x"), param("p")]), (1 + var("x")) + param("p") ) From e90254e5f7bed240187a7784c6b144a31f1a96e5 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Tue, 16 Jul 2024 09:14:35 +0200 Subject: [PATCH 13/51] Fix printing term tests --- .../expressions/test_expressions_efficient.py | 18 +++++++++--------- .../expressions/test_linear_expressions.py | 2 +- .../expressions/test_term_efficient.py | 14 +++++++------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index c79a5b1c..26f4ed0e 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -126,16 +126,16 @@ def test_operators() -> None: assert -expr.evaluate(context) == pytest.approx(-2.5, 1e-16) -def test_degree() -> None: - x = var("x") - p = param("p") - expr = (5 * x + 3) / p +# def test_degree() -> None: +# x = var("x") +# p = param("p") +# expr = (5 * x + 3) / p - assert expr.compute_degree() == 1 +# assert expr.compute_degree() == 1 - # TODO: Should this be allowed ? If so, how should we represent is ? - expr = x * expr - assert expr.compute_degree() == 2 +# # TODO: Should this be allowed ? If so, how should we represent is ? +# expr = x * expr +# assert expr.compute_degree() == 2 def test_degree_computation_should_take_into_account_simplifications() -> None: @@ -622,7 +622,7 @@ def get_variable_structure(self, name: str) -> IndexingStructure: def test_sum_expressions() -> None: - + # TODO: Sum of an empty list ? How to return a null LinearExpression object if the list is supposed to contain LinearExpression objects ? assert linear_expressions_equal(sum([literal(1)]), literal(1)) assert linear_expressions_equal(sum([literal(1), var("x")]), 1 + var("x")) assert linear_expressions_equal( diff --git a/tests/unittests/expressions/test_linear_expressions.py b/tests/unittests/expressions/test_linear_expressions.py index 54a11e95..d9225b95 100644 --- a/tests/unittests/expressions/test_linear_expressions.py +++ b/tests/unittests/expressions/test_linear_expressions.py @@ -26,7 +26,7 @@ (Term(-1, "c", "x"), "-x"), (Term(2.50, "c", "x"), "+2.5x"), (Term(-3, "c", "x"), "-3x"), - (Term(-3, "c", "x", time_operator=TimeShift(-1)), "-3x.shift([-1])"), + (Term(-3, "c", "x", time_operator=TimeShift(-1)), "-3x.shift(-1)"), (Term(-3, "c", "x", time_aggregator=TimeSum(True)), "-3x.sum(True)"), ( Term( diff --git a/tests/unittests/expressions/test_term_efficient.py b/tests/unittests/expressions/test_term_efficient.py index aaae5b9f..d30f9310 100644 --- a/tests/unittests/expressions/test_term_efficient.py +++ b/tests/unittests/expressions/test_term_efficient.py @@ -23,10 +23,10 @@ [ (TermEfficient(1, "c", "x"), "+x"), (TermEfficient(-1, "c", "x"), "-x"), - (TermEfficient(2.50, "c", "x"), "+2.5x"), - (TermEfficient(-3, "c", "x"), "-3x"), - (TermEfficient(-3, "c", "x", time_operator=TimeShift(-1)), "-3x.shift([-1])"), - (TermEfficient(-3, "c", "x", time_aggregator=TimeSum(True)), "-3x.sum(True)"), + (TermEfficient(2.50, "c", "x"), "2.5x"), + (TermEfficient(-3, "c", "x"), "-3.0x"), + (TermEfficient(-3, "c", "x", time_operator=TimeShift(-1)), "-3.0x.shift(-1)"), + (TermEfficient(-3, "c", "x", time_aggregator=TimeSum(True)), "-3.0x.sum(True)"), ( TermEfficient( -3, @@ -35,9 +35,9 @@ time_operator=TimeShift([2, 3]), time_aggregator=TimeSum(False), ), - "-3x.shift([2, 3]).sum(False)", + "-3.0x.shift([2, 3]).sum(False)", ), - (TermEfficient(-3, "c", "x", scenario_operator=Expectation()), "-3x.expec()"), + (TermEfficient(-3, "c", "x", scenario_operator=Expectation()), "-3.0x.expec()"), ( TermEfficient( -3, @@ -46,7 +46,7 @@ time_aggregator=TimeSum(True), scenario_operator=Expectation(), ), - "-3x.sum(True).expec()", + "-3.0x.sum(True).expec()", ), ], ) From c5214ebdef4b543020753e525e83a4eb069ae39c Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Tue, 16 Jul 2024 18:43:24 +0200 Subject: [PATCH 14/51] More online simplifications, start handling constraints and ports --- .../expression/expression_efficient.py | 340 ++++++++++++++---- .../expression/linear_expression_efficient.py | 70 ++-- src/andromede/expression/port_operator.py | 4 +- src/andromede/model/constraint.py | 56 ++- src/andromede/model/model.py | 3 + .../functional/test_performance_efficient.py | 28 +- .../expressions/test_expressions_efficient.py | 57 ++- tests/unittests/test_model.py | 91 +++-- tests/unittests/test_port.py | 2 +- 9 files changed, 488 insertions(+), 163 deletions(-) diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index a8f76ac9..4d26331f 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -16,8 +16,8 @@ import enum import inspect import math -from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional, Sequence, Union +from dataclasses import dataclass +from typing import Any, Callable, List, Optional, Union import andromede.expression.port_operator import andromede.expression.scenario_operator @@ -26,11 +26,6 @@ EPS = 10 ** (-16) -# class Instances(enum.Enum): -# SIMPLE = "SIMPLE" -# MULTIPLE = "MULTIPLE" - - @dataclass(frozen=True) class ExpressionNodeEfficient: """ @@ -43,8 +38,6 @@ class ExpressionNodeEfficient: >>> expr = -var('x') + 5 / param('p') """ - # instances: Instances = field(init=False, default=Instances.SIMPLE) - def __neg__(self) -> "ExpressionNodeEfficient": return _negate_node(self) @@ -172,6 +165,8 @@ def is_minus_one(node: ExpressionNodeEfficient) -> bool: def _negate_node(node: ExpressionNodeEfficient): if isinstance(node, LiteralNode): return LiteralNode(-node.value) + elif isinstance(node, NegationNode): + return node.operand else: return NegationNode(node) @@ -184,15 +179,60 @@ def _add_node( if is_zero(rhs): return lhs # TODO: How can we use the equality visitor here (simple import -> circular import), copy code here ? - # if expressions_equal(lhs, -rhs): - # return LiteralNode(0) + if expressions_equal(lhs, -rhs): + return LiteralNode(0) if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode): return LiteralNode(lhs.value + rhs.value) - # TODO : Si noeuds gauche droite même clé de param -> 2 * param + if _are_parameter_nodes_equal(lhs, rhs): + return MultiplicationNode(LiteralNode(2), lhs) + if (lhs_is_param := isinstance(lhs, ParameterNode)) or ( + rhs_is_param := isinstance(rhs, ParameterNode) + ): + if lhs_is_param: + param_node = lhs + other = rhs + elif rhs_is_param: + param_node = rhs + other = lhs + + if isinstance(other, MultiplicationNode): + if _are_parameter_nodes_equal(param_node, other.left): + return MultiplicationNode( + _add_node(LiteralNode(1), other.right), param_node + ) + elif _are_parameter_nodes_equal(param_node, other.right): + return MultiplicationNode( + _add_node(LiteralNode(1), other.left), param_node + ) + + if isinstance(lhs, MultiplicationNode) and isinstance(rhs, MultiplicationNode): + if _are_parameter_nodes_equal(lhs.left, rhs.left): + return MultiplicationNode(_add_node(lhs.right, rhs.right), lhs.left) + elif _are_parameter_nodes_equal(lhs.left, rhs.right): + return MultiplicationNode(_add_node(lhs.right, rhs.left), lhs.left) + elif _are_parameter_nodes_equal(lhs.right, rhs.left): + return MultiplicationNode(_add_node(lhs.left, rhs.right), lhs.right) + elif _are_parameter_nodes_equal(lhs.right, rhs.right): + return MultiplicationNode(_add_node(lhs.left, rhs.left), lhs.right) else: return AdditionNode(lhs, rhs) +# Better if we could use equality visitor +def _are_parameter_nodes_equal( + lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient +) -> bool: + return ( + isinstance(lhs, ParameterNode) + and isinstance(rhs, ParameterNode) + and lhs.name == rhs.name + ) + + +# def _is_parameter_multiplication(node: ExpressionNodeEfficient, name: str): +# return isinstance(node, MultiplicationNode) and ((isinstance(node.left, ParameterNode) and node.left.name == name) or + + def _substract_node( lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient ) -> ExpressionNodeEfficient: @@ -201,11 +241,51 @@ def _substract_node( if is_zero(rhs): return lhs # TODO: How can we use the equality visitor here (simple import -> circular import), copy code here ? - # if expressions_equal(lhs, rhs): - # return LiteralNode(0) + if expressions_equal(lhs, rhs): + return LiteralNode(0) if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode): return LiteralNode(lhs.value - rhs.value) - # TODO : Si noeuds gauche droite même clé de param -> 0 + if _are_parameter_nodes_equal(lhs, -rhs): + return MultiplicationNode(LiteralNode(2), lhs) + if (lhs_is_param := isinstance(lhs, ParameterNode)) or ( + rhs_is_param := isinstance(rhs, ParameterNode) + ): + if lhs_is_param: + param_node = lhs + other = rhs + elif rhs_is_param: + param_node = rhs + other = lhs + + if isinstance(other, MultiplicationNode): + if _are_parameter_nodes_equal(param_node, other.left): + if lhs_is_param: + return MultiplicationNode( + _substract_node(LiteralNode(1), other.right), param_node + ) + elif rhs_is_param: + return MultiplicationNode( + _substract_node(other.right, LiteralNode(1)), param_node + ) + elif _are_parameter_nodes_equal(param_node, other.right): + if lhs_is_param: + return MultiplicationNode( + _substract_node(LiteralNode(1), other.left), param_node + ) + elif rhs_is_param: + return MultiplicationNode( + _substract_node(other.left, LiteralNode(1)), param_node + ) + + if isinstance(lhs, MultiplicationNode) and isinstance(rhs, MultiplicationNode): + if _are_parameter_nodes_equal(lhs.left, rhs.left): + return MultiplicationNode(_substract_node(lhs.right, rhs.right), lhs.left) + elif _are_parameter_nodes_equal(lhs.left, rhs.right): + return MultiplicationNode(_substract_node(lhs.right, rhs.left), lhs.left) + elif _are_parameter_nodes_equal(lhs.right, rhs.left): + return MultiplicationNode(_substract_node(lhs.left, rhs.right), lhs.right) + elif _are_parameter_nodes_equal(lhs.right, rhs.right): + return MultiplicationNode(_substract_node(lhs.left, rhs.left), lhs.right) else: return SubstractionNode(lhs, rhs) @@ -254,10 +334,6 @@ class PortFieldNode(ExpressionNodeEfficient): field_name: str -def port_field(port_name: str, field_name: str) -> PortFieldNode: - return PortFieldNode(port_name, field_name) - - @dataclass(frozen=True, eq=False) class ParameterNode(ExpressionNodeEfficient): name: str @@ -282,13 +358,14 @@ class LiteralNode(ExpressionNodeEfficient): value: float +def is_unbound(expr: ExpressionNodeEfficient) -> bool: + return isinstance(expr, LiteralNode) and (abs(expr.value) == float("inf")) + + @dataclass(frozen=True, eq=False) class UnaryOperatorNode(ExpressionNodeEfficient): operand: ExpressionNodeEfficient - # def __post_init__(self) -> None: - # object.__setattr__(self, "instances", self.operand.instances) - @dataclass(frozen=True, eq=False) class PortFieldAggregatorNode(UnaryOperatorNode): @@ -318,18 +395,6 @@ class BinaryOperatorNode(ExpressionNodeEfficient): left: ExpressionNodeEfficient right: ExpressionNodeEfficient - # def __post_init__(self) -> None: - # binary_operator_post_init(self, "apply binary operation with") - - -# def binary_operator_post_init(node: BinaryOperatorNode, operation: str) -> None: -# if node.left.instances != node.right.instances: -# raise ValueError( -# f"Cannot {operation} {node.left} and {node.right} as they do not have the same number of instances." -# ) -# else: -# object.__setattr__(node, "instances", node.left.instances) - class Comparator(enum.Enum): LESS_THAN = "LESS_THAN" @@ -341,36 +406,25 @@ class Comparator(enum.Enum): class ComparisonNode(BinaryOperatorNode): comparator: Comparator - # def __post_init__(self) -> None: - # binary_operator_post_init(self, "compare") - @dataclass(frozen=True, eq=False) class AdditionNode(BinaryOperatorNode): pass - # def __post_init__(self) -> None: - # binary_operator_post_init(self, "add") @dataclass(frozen=True, eq=False) class SubstractionNode(BinaryOperatorNode): pass - # def __post_init__(self) -> None: - # binary_operator_post_init(self, "substract") @dataclass(frozen=True, eq=False) class MultiplicationNode(BinaryOperatorNode): pass - # def __post_init__(self) -> None: - # binary_operator_post_init(self, "multiply") @dataclass(frozen=True, eq=False) class DivisionNode(BinaryOperatorNode): pass - # def __post_init__(self) -> None: - # binary_operator_post_init(self, "divide") @dataclass(frozen=True, eq=False) @@ -469,15 +523,6 @@ def __post_init__(self) -> None: raise ValueError( f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" ) - # if self.operand.instances == Instances.SIMPLE: - # if self.instances_index.is_simple(): - # object.__setattr__(self, "instances", Instances.SIMPLE) - # else: - # object.__setattr__(self, "instances", Instances.MULTIPLE) - # else: - # raise ValueError( - # "Cannot apply time operator on an expression that already represents multiple instances" - # ) @dataclass(frozen=True, eq=False) @@ -497,7 +542,6 @@ def __post_init__(self) -> None: raise ValueError( f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" ) - # object.__setattr__(self, "instances", Instances.SIMPLE) @dataclass(frozen=True, eq=False) @@ -516,14 +560,182 @@ def __post_init__(self) -> None: raise ValueError( f"{self.name} is not a valid scenario operator, valid scenario operators are {valid_names}" ) - # object.__setattr__(self, "instances", Instances.SIMPLE) -def sum_expressions( - expressions: Sequence[ExpressionNodeEfficient], -) -> ExpressionNodeEfficient: - if len(expressions) == 0: - return LiteralNode(0) - if len(expressions) == 1: - return expressions[0] - return expressions[0] + sum_expressions(expressions[1:]) +@dataclass(frozen=True) +class EqualityVisitor: + abs_tol: float = 0 + rel_tol: float = 0 + + def __post_init__(self) -> None: + if self.abs_tol < 0: + raise ValueError( + f"Absolute comparison tolerance must be >= 0, got {self.abs_tol}" + ) + if self.rel_tol < 0: + raise ValueError( + f"Relative comparison tolerance must be >= 0, got {self.rel_tol}" + ) + + def visit( + self, left: ExpressionNodeEfficient, right: ExpressionNodeEfficient + ) -> bool: + if left.__class__ != right.__class__: + return False + if isinstance(left, LiteralNode) and isinstance(right, LiteralNode): + return self.literal(left, right) + if isinstance(left, NegationNode) and isinstance(right, NegationNode): + return self.negation(left, right) + if isinstance(left, AdditionNode) and isinstance(right, AdditionNode): + return self.addition(left, right) + if isinstance(left, SubstractionNode) and isinstance(right, SubstractionNode): + return self.substraction(left, right) + if isinstance(left, DivisionNode) and isinstance(right, DivisionNode): + return self.division(left, right) + if isinstance(left, MultiplicationNode) and isinstance( + right, MultiplicationNode + ): + return self.multiplication(left, right) + if isinstance(left, ComparisonNode) and isinstance(right, ComparisonNode): + return self.comparison(left, right) + # if isinstance(left, VariableNode) and isinstance(right, VariableNode): + # return self.variable(left, right) + if isinstance(left, ParameterNode) and isinstance(right, ParameterNode): + return self.parameter(left, right) + if isinstance(left, ComponentParameterNode) and isinstance( + right, ComponentParameterNode + ): + return self.comp_parameter(left, right) + if isinstance(left, TimeOperatorNode) and isinstance(right, TimeOperatorNode): + return self.time_operator(left, right) + if isinstance(left, TimeAggregatorNode) and isinstance( + right, TimeAggregatorNode + ): + return self.time_aggregator(left, right) + if isinstance(left, ScenarioOperatorNode) and isinstance( + right, ScenarioOperatorNode + ): + return self.scenario_operator(left, right) + if isinstance(left, PortFieldNode) and isinstance(right, PortFieldNode): + return self.port_field(left, right) + if isinstance(left, PortFieldAggregatorNode) and isinstance( + right, PortFieldAggregatorNode + ): + return self.port_field_aggregator(left, right) + raise NotImplementedError(f"Equality not implemented for {left.__class__}") + + def literal(self, left: LiteralNode, right: LiteralNode) -> bool: + return math.isclose( + left.value, right.value, abs_tol=self.abs_tol, rel_tol=self.rel_tol + ) + + def _visit_operands( + self, left: BinaryOperatorNode, right: BinaryOperatorNode + ) -> bool: + return self.visit(left.left, right.left) and self.visit(left.right, right.right) + + def negation(self, left: NegationNode, right: NegationNode) -> bool: + return self.visit(left.operand, right.operand) + + def addition(self, left: AdditionNode, right: AdditionNode) -> bool: + # TODO: Commutativty ??? Cannot detect that a+b == b+a + return self._visit_operands(left, right) + + def substraction(self, left: SubstractionNode, right: SubstractionNode) -> bool: + return self._visit_operands(left, right) + + def multiplication( + self, left: MultiplicationNode, right: MultiplicationNode + ) -> bool: + return self._visit_operands(left, right) + + def division(self, left: DivisionNode, right: DivisionNode) -> bool: + return self._visit_operands(left, right) + + def comparison(self, left: ComparisonNode, right: ComparisonNode) -> bool: + return left.comparator == right.comparator and self._visit_operands(left, right) + + # def variable(self, left: VariableNode, right: VariableNode) -> bool: + # return left.name == right.name + + def parameter(self, left: ParameterNode, right: ParameterNode) -> bool: + return left.name == right.name + + def comp_parameter( + self, left: ComponentParameterNode, right: ComponentParameterNode + ) -> bool: + return left.component_id == right.component_id and left.name == right.name + + def expression_range(self, left: ExpressionRange, right: ExpressionRange) -> bool: + if not self.visit(left.start, right.start): + return False + if not self.visit(left.stop, right.stop): + return False + if left.step is not None and right.step is not None: + return self.visit(left.step, right.step) + return left.step is None and right.step is None + + def instances_index(self, lhs: InstancesTimeIndex, rhs: InstancesTimeIndex) -> bool: + if isinstance(lhs.expressions, ExpressionRange) and isinstance( + rhs.expressions, ExpressionRange + ): + return self.expression_range(lhs.expressions, rhs.expressions) + if isinstance(lhs.expressions, list) and isinstance(rhs.expressions, list): + return len(lhs.expressions) == len(rhs.expressions) and all( + self.visit(l, r) for l, r in zip(lhs.expressions, rhs.expressions) + ) + return False + + def time_operator(self, left: TimeOperatorNode, right: TimeOperatorNode) -> bool: + return ( + left.name == right.name + and self.instances_index(left.instances_index, right.instances_index) + and self.visit(left.operand, right.operand) + ) + + def time_aggregator( + self, left: TimeAggregatorNode, right: TimeAggregatorNode + ) -> bool: + return ( + left.name == right.name + and left.stay_roll == right.stay_roll + and self.visit(left.operand, right.operand) + ) + + def scenario_operator( + self, left: ScenarioOperatorNode, right: ScenarioOperatorNode + ) -> bool: + return left.name == right.name and self.visit(left.operand, right.operand) + + def port_field(self, left: PortFieldNode, right: PortFieldNode) -> bool: + return left.port_name == right.port_name and left.field_name == right.field_name + + def port_field_aggregator( + self, left: PortFieldAggregatorNode, right: PortFieldAggregatorNode + ) -> bool: + return left.aggregator == right.aggregator and self.visit( + left.operand, right.operand + ) + + +def expressions_equal( + left: ExpressionNodeEfficient, + right: ExpressionNodeEfficient, + abs_tol: float = 0, + rel_tol: float = 0, +) -> bool: + """ + True if both expression nodes are equal. Literal values may be compared with absolute or relative tolerance. + """ + return EqualityVisitor(abs_tol, rel_tol).visit(left, right) + + +def expressions_equal_if_present( + lhs: Optional[ExpressionNodeEfficient], rhs: Optional[ExpressionNodeEfficient] +) -> bool: + if lhs is None and rhs is None: + return True + elif lhs is None or rhs is None: + return False + else: + return expressions_equal(lhs, rhs) diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index f7e10267..c9f0210b 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -16,7 +16,7 @@ """ import dataclasses from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union from andromede.expression.equality import expressions_equal from andromede.expression.evaluate import ValueProvider, evaluate @@ -32,11 +32,13 @@ TimeOperatorNode, is_minus_one, is_one, + is_unbound, is_zero, wrap_in_node, ) from andromede.expression.indexing import IndexingStructureProvider, compute_indexation from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.port_operator import PortAggregator, PortSum from andromede.expression.print import print_expr from andromede.expression.scenario_operator import Expectation, ScenarioOperator from andromede.expression.time_operator import ( @@ -511,6 +513,9 @@ def __sub__( result -= rhs return result + def __rsub__(self, rhs: int) -> "LinearExpressionEfficient": + return -self + rhs + def __neg__(self) -> "LinearExpressionEfficient": result = LinearExpressionEfficient() result -= self @@ -614,6 +619,9 @@ def is_constant(self) -> bool: # Constant expr like x-x could be seen as non constant as we do not simplify coefficient tree... return not self.terms + def is_unbound(self) -> bool: + return is_unbound(self.constant) + def compute_indexation( self, provider: IndexingStructureProvider ) -> IndexingStructure: @@ -709,7 +717,7 @@ def _apply_operator( None, ], ], - ): + ) -> Dict[TermKeyEfficient, TermEfficient]: result_terms = {} for term in self.terms.values(): term_with_operator = term.sum(**sum_args) @@ -805,39 +813,35 @@ def linear_expressions_equal( ) +# TODO: Is this function useful ? Could we just rely on the sum operator overloading ? Only the case with an empty list may make the function useful +def sum_expressions( + expressions: Sequence[LinearExpressionEfficient], +) -> LinearExpressionEfficient: + if len(expressions) == 0: + return literal(0) + else: + return sum(expressions) + + @dataclass class StandaloneConstraint: """ - A standalone constraint, with rugid initialization. + A standalone constraint, with rigid initialization. """ expression: LinearExpressionEfficient lower_bound: LinearExpressionEfficient upper_bound: LinearExpressionEfficient - def __init__( + def __post_init__( self, - expression: LinearExpressionEfficient, - lower_bound: LinearExpressionEfficient, - upper_bound: LinearExpressionEfficient, ) -> None: - for bound in [lower_bound, upper_bound]: - if bound is not None and not bound.is_constant(): + for bound in [self.lower_bound, self.upper_bound]: + if not bound.is_constant(): raise ValueError( f"The bounds of a constraint should not contain variables, {print_expr(bound)} was given." ) - self.expression = expression - if lower_bound is not None: - self.lower_bound = lower_bound - else: - self.lower_bound = literal(-float("inf")) - - if upper_bound is not None: - self.upper_bound = upper_bound - else: - self.upper_bound = literal(float("inf")) - def __str__(self) -> str: return f"{str(self.lower_bound)} <= {str(self.expression)} <= {str(self.upper_bound)}" @@ -905,3 +909,29 @@ def comp_param(component_id: str, name: str) -> LinearExpressionEfficient: def is_linear(expr: LinearExpressionEfficient) -> bool: return True + + +@dataclass(frozen=True) +class PortFieldTerm: + port_name: str + field_name: str + aggregator: Optional[PortAggregator] = None + + def __str__(self) -> str: + result = f"{self.port_name}.{self.field_name}" + if self.aggregator is not None: + result += f".{str(self.aggregator)}" + return result + + def sum_connections(self) -> "PortFieldTerm": + if self.aggregator is not None: + raise ValueError(f"Port field {str(self)} already has a port aggregator") + return dataclasses.replace(self, aggregator=PortSum()) + + +class PortFieldExpr: + terms: List[PortFieldTerm] + + +def port_field(port_name: str, field_name: str) -> PortFieldTerm: + return PortFieldTerm(port_name, field_name) diff --git a/src/andromede/expression/port_operator.py b/src/andromede/expression/port_operator.py index 875d4f32..c15245a9 100644 --- a/src/andromede/expression/port_operator.py +++ b/src/andromede/expression/port_operator.py @@ -30,4 +30,6 @@ class PortAggregator: @dataclass(frozen=True) class PortSum(PortAggregator): - pass + + def __str__(self): + return "PortSum" diff --git a/src/andromede/model/constraint.py b/src/andromede/model/constraint.py index b8ef3d54..3c055e77 100644 --- a/src/andromede/model/constraint.py +++ b/src/andromede/model/constraint.py @@ -9,7 +9,7 @@ # SPDX-License-Identifier: MPL-2.0 # # This file is part of the Antares project. -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Optional, Union from andromede.expression.degree import is_constant @@ -24,10 +24,11 @@ # ExpressionNode, # literal, # ) -from andromede.expression.expression_efficient import Comparator, ComparisonNode +from andromede.expression.expression_efficient import Comparator, ComparisonNode, is_unbound from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, StandaloneConstraint, + linear_expressions_equal, literal, ) from andromede.expression.print import print_expr @@ -44,56 +45,39 @@ class Constraint: name: str expression: LinearExpressionEfficient - lower_bound: LinearExpressionEfficient - upper_bound: LinearExpressionEfficient - context: ProblemContext + lower_bound: LinearExpressionEfficient = field(default=literal(-float("inf"))) + upper_bound: LinearExpressionEfficient = field(default=literal(float("inf"))) + context: ProblemContext = field(default=ProblemContext.OPERATIONAL) - def __init__( + def __post_init__( self, - name: str, - expression: Union[LinearExpressionEfficient, StandaloneConstraint], - lower_bound: Optional[LinearExpressionEfficient] = None, - upper_bound: Optional[LinearExpressionEfficient] = None, - context: ProblemContext = ProblemContext.OPERATIONAL, ) -> None: - self.name = name - self.context = context - - if isinstance(expression, StandaloneConstraint): - if lower_bound is not None or upper_bound is not None: + if isinstance(self.expression, StandaloneConstraint): + # Case where constraint is initialized with something like Constraint(var("x") <= var("y")) + if not self.lower_bound.is_unbound() or not self.upper_bound.is_unbound(): raise ValueError( "Both comparison between two expressions and a bound are specfied, set either only a comparison between expressions or a single linear expression with bounds." ) - self.expression = expression.expression - self.lower_bound = expression.lower_bound - self.upper_bound = expression.upper_bound + self.lower_bound = self.expression.lower_bound + self.upper_bound = self.expression.upper_bound + self.expression = self.expression.expression + else: - for bound in [lower_bound, upper_bound]: - if bound is not None and not is_constant(bound): + for bound in [self.lower_bound, self.upper_bound]: + if not bound.is_constant(): raise ValueError( - f"The bounds of a constraint should not contain variables, {print_expr(bound)} was given." + f"The bounds of a constraint should not contain variables, {str(bound)} was given." ) - self.expression = expression - if lower_bound is not None: - self.lower_bound = lower_bound - else: - self.lower_bound = literal(-float("inf")) - - if upper_bound is not None: - self.upper_bound = upper_bound - else: - self.upper_bound = literal(float("inf")) - def __eq__(self, other: Any) -> bool: if not isinstance(other, Constraint): return False return ( self.name == other.name - and expressions_equal(self.expression, other.expression) - and expressions_equal_if_present(self.lower_bound, other.lower_bound) - and expressions_equal_if_present(self.upper_bound, other.upper_bound) + and linear_expressions_equal(self.expression, other.expression) + and linear_expressions_equal(self.lower_bound, other.lower_bound) + and linear_expressions_equal(self.upper_bound, other.upper_bound) ) def __str__(self) -> str: diff --git a/src/andromede/model/model.py b/src/andromede/model/model.py index a293e78f..a09d3ff9 100644 --- a/src/andromede/model/model.py +++ b/src/andromede/model/model.py @@ -330,6 +330,9 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> None: def _validate_port_field_expression(definition: PortFieldDefinition) -> None: + if not isinstance(definition.definition, LinearExpressionEfficient): + raise TypeError(f"Port field definition should be a LinearExpression, not a {type(definition.definition)}") + for term in definition.definition.terms.values(): visit(term.coefficient, _PortFieldExpressionChecker()) visit(definition.definition.constant, _PortFieldExpressionChecker()) diff --git a/tests/functional/test_performance_efficient.py b/tests/functional/test_performance_efficient.py index 03cb0534..43d1987f 100644 --- a/tests/functional/test_performance_efficient.py +++ b/tests/functional/test_performance_efficient.py @@ -14,7 +14,7 @@ import pytest from andromede.expression.evaluate import EvaluationContext -from andromede.expression.linear_expression_efficient import param, var +from andromede.expression.linear_expression_efficient import literal, param, var def test_large_number_of_parameters_sum() -> None: @@ -35,6 +35,32 @@ def test_large_number_of_parameters_sum() -> None: expr.evaluate(EvaluationContext(parameters=parameters_value)) +def test_large_number_of_identical_parameters_sum() -> None: + """ + With identical parameters sum, a simplification is performed online to avoid the recursivity. + """ + nb_terms = 500 + + parameters_value = {"cost": 1.0} + + # Still the recursion depth error with parameters + # with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): + expr = sum(param("cost") for _ in range(nb_terms)) + assert expr.evaluate(EvaluationContext(parameters=parameters_value)) == nb_terms + + +def test_large_number_of_literal_sum() -> None: + """ + Literal sums are computed online to avoid recursivity + """ + nb_terms = 500 + + # # Still the recursion depth error with parameters + # with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): + expr = sum(literal(1) for _ in range(nb_terms)) + assert expr.evaluate(EvaluationContext()) == nb_terms + + def test_large_number_of_variables_sum() -> None: """ Test performance when the problem involves an expression with a high number of terms. No problem when there is a large number of variables as this is derecusified. diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index 26f4ed0e..182786d4 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -18,7 +18,6 @@ from andromede.expression.equality import expressions_equal from andromede.expression.evaluate import EvaluationContext, ValueProvider -from andromede.expression.evaluate_parameters import ParameterValueProvider from andromede.expression.expression_efficient import ( ComponentParameterNode, ExpressionNodeEfficient, @@ -41,11 +40,10 @@ linear_expressions_equal, literal, param, + sum_expressions, var, ) from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum -from andromede.model.constraint import Constraint -from andromede.simulation.linearize import linearize_expression @dataclass(frozen=True) @@ -217,6 +215,16 @@ def test_degree_computation_should_take_into_account_simplifications() -> None: param("p"), 2 * param("p"), ), + ( + literal(4) * param("p"), + param("p"), + 5 * param("p"), + ), + ( + param("p"), + param("p") * param("q"), + (1 + param("q")) * param("p"), # Equality visitor not able to handle commutativity + ), ], ) def test_addition( @@ -224,7 +232,7 @@ def test_addition( e2: LinearExpressionEfficient, expected: LinearExpressionEfficient, ) -> None: - assert e1 + e2 == expected + assert linear_expressions_equal(e1 + e2, expected) @pytest.mark.parametrize( @@ -280,6 +288,16 @@ def test_addition( param("p"), LinearExpressionEfficient(), ), + ( + literal(4) * param("p"), + param("p"), + 3 * param("p"), + ), + ( + param("p"), + param("p") * param("q"), + param("p") * (1 - param("q")), # Equality visitor not able to handle commutativity + ), ], ) def test_substraction( @@ -287,7 +305,10 @@ def test_substraction( e2: LinearExpressionEfficient, expected: LinearExpressionEfficient, ) -> None: - assert e1 - e2 == expected + print() + print(e1 - e2) + print(expected) + assert linear_expressions_equal(e1 - e2, expected) @pytest.mark.parametrize( @@ -314,7 +335,7 @@ def test_substraction( def test_linear_expression_equality( lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient ) -> None: - assert lhs == rhs + assert linear_expressions_equal(lhs, rhs) # TODO: What is the equivalent of this test ? @@ -622,9 +643,25 @@ def get_variable_structure(self, name: str) -> IndexingStructure: def test_sum_expressions() -> None: - # TODO: Sum of an empty list ? How to return a null LinearExpression object if the list is supposed to contain LinearExpression objects ? - assert linear_expressions_equal(sum([literal(1)]), literal(1)) - assert linear_expressions_equal(sum([literal(1), var("x")]), 1 + var("x")) + assert linear_expressions_equal(sum_expressions([]), literal(0)) + assert linear_expressions_equal(sum_expressions([literal(1)]), literal(1)) assert linear_expressions_equal( - sum([literal(1), var("x"), param("p")]), (1 + var("x")) + param("p") + sum_expressions([literal(1), var("x")]), 1 + var("x") ) + assert linear_expressions_equal( + sum_expressions([literal(1), var("x"), param("p")]), (1 + var("x")) + param("p") + ) + + +@pytest.mark.parametrize( + "expr, unbound", + [ + (literal(float("inf")), True), + (literal(float("-inf")), True), + (literal(-float("inf")), True), + (var("x") + literal(float("-inf")), True), + (var("x") + literal(4), False), + ], +) +def test_is_unbound(expr: LinearExpressionEfficient, unbound: bool) -> None: + assert expr.is_unbound() == unbound diff --git a/tests/unittests/test_model.py b/tests/unittests/test_model.py index c99f7abb..c5a1da16 100644 --- a/tests/unittests/test_model.py +++ b/tests/unittests/test_model.py @@ -10,16 +10,18 @@ # # This file is part of the Antares project. +from typing import Optional, Type + import pytest -from andromede.expression.expression import ( - ExpressionNode, - ExpressionRange, +from andromede.expression.expression_efficient import ExpressionRange, port_field +from andromede.expression.linear_expression_efficient import ( + LinearExpressionEfficient, comp_param, comp_var, + linear_expressions_equal, literal, param, - port_field, var, ) from andromede.model import Constraint, float_parameter, float_variable, model @@ -39,6 +41,26 @@ literal(5), literal(10), ), + ( + "my_constraint", + 2 * var("my_var"), + None, + literal(10), + "my_constraint", + 2 * var("my_var"), + literal(-float("inf")), + literal(10), + ), + ( + "my_constraint", + 2 * var("my_var"), + literal(5), + None, + "my_constraint", + 2 * var("my_var"), + literal(5), + literal(float("inf")), + ), ( "my_constraint", 2 * var("my_var"), @@ -103,19 +125,27 @@ ) def test_constraint_instantiation( name: str, - expression: ExpressionNode, - lb: ExpressionNode, - ub: ExpressionNode, + expression: LinearExpressionEfficient, + lb: Optional[LinearExpressionEfficient], + ub: Optional[LinearExpressionEfficient], exp_name: str, - exp_expr: ExpressionNode, - exp_lb: ExpressionNode, - exp_ub: ExpressionNode, + exp_expr: LinearExpressionEfficient, + exp_lb: LinearExpressionEfficient, + exp_ub: LinearExpressionEfficient, ) -> None: - constraint = Constraint(name, expression, lb, ub) + if lb is None and ub is None: + constraint = Constraint(name, expression) + elif lb is None: + constraint = Constraint(name, expression, upper_bound=ub) + elif ub is None: + constraint = Constraint(name, expression, lower_bound=lb) + else: + constraint = Constraint(name, expression, lower_bound=lb, upper_bound=ub) + assert constraint.name == exp_name - assert constraint.expression == exp_expr - assert constraint.lower_bound == exp_lb - assert constraint.upper_bound == exp_ub + assert linear_expressions_equal(constraint.expression, exp_expr) + assert linear_expressions_equal(constraint.lower_bound, exp_lb) + assert linear_expressions_equal(constraint.upper_bound, exp_ub) def test_if_both_comparison_expression_and_bound_given_for_constraint_init_then_it_should_raise_a_value_error() -> ( @@ -134,13 +164,11 @@ def test_if_a_bound_is_not_constant_then_it_should_raise_a_value_error() -> None Constraint("my_constraint", 2 * var("my_var"), var("x")) assert ( str(exc.value) - == "The bounds of a constraint should not contain variables, x was given." + == "The bounds of a constraint should not contain variables, +x was given." ) -def test_writing_p_min_max_constraint_should_represent_all_expected_constraints() -> ( - None -): +def test_writing_p_min_max_constraint_should_not_raise_exception() -> None: """ Aim at representing the following mathematical constraints: For all t, p_min <= p[t] <= p_max * alpha[t] where p_min, p_max are literal paramters and alpha is an input timeseries @@ -160,19 +188,19 @@ def test_writing_p_min_max_constraint_should_represent_all_expected_constraints( assert False, f"Writing p_min and p_max constraints raises an exception: {exc}" -def test_writing_min_up_constraint_should_represent_all_expected_constraints() -> None: +def test_writing_min_up_constraint_should_not_raise_exception() -> None: """ Aim at representing the following mathematical constraints: For all t, for all t' in [t+1, t+d_min_up], off_on[k,t,w] <= on[k,t',w] """ try: - d_min_up = literal(3) + d_min_up = 3 off_on = var("off_on") on = var("on") _ = Constraint( "min_up_time", - off_on <= on.shift(ExpressionRange(literal(1), d_min_up)).sum(), + off_on <= on.sum(shift=ExpressionRange(1, d_min_up)), ) # Later on, the goal is to assert that when this constraint is sent to the solver, it correctly builds: for all t, for all t' in [t+1, t+d_min_up], off_on[k,t,w] <= on[k,t',w] @@ -181,6 +209,7 @@ def test_writing_min_up_constraint_should_represent_all_expected_constraints() - assert False, f"Writing min_up constraints raises an exception: {exc}" +@pytest.mark.skip(reason="Variance not implemented") def test_instantiating_a_model_with_non_linear_scenario_operator_in_the_objective_should_raise_type_error() -> ( None ): @@ -194,21 +223,23 @@ def test_instantiating_a_model_with_non_linear_scenario_operator_in_the_objectiv @pytest.mark.parametrize( - "expression", + "expression, error_type", [ - var("x") <= 0, - comp_var("c", "x"), - comp_param("c", "x"), - port_field("p", "f"), - port_field("p", "f").sum_connections(), + (var("x") <= 0, TypeError), + (comp_var("c", "x"), ValueError), + (comp_param("c", "x"), ValueError), + (port_field("p", "f"), ValueError), + (port_field("p", "f").sum_connections(), ValueError) ], ) -def test_invalid_port_field_definition_should_raise(expression: ExpressionNode) -> None: - with pytest.raises(ValueError) as exc: +def test_invalid_port_field_definition_should_raise( + expression: LinearExpressionEfficient, error_type: Type +) -> None: + with pytest.raises(error_type): port_field_def(port_name="p", field_name="f", definition=expression) -def test_constraint_equals(): +def test_constraint_equals() -> None: # checks in particular that expressions are correctly compared assert Constraint(name="c", expression=var("x") <= param("p")) == Constraint( name="c", expression=var("x") <= param("p") diff --git a/tests/unittests/test_port.py b/tests/unittests/test_port.py index 2ff4547f..8b471a31 100644 --- a/tests/unittests/test_port.py +++ b/tests/unittests/test_port.py @@ -12,8 +12,8 @@ import pytest -from andromede.expression import literal from andromede.expression.expression import port_field +from andromede.expression.linear_expression_efficient import literal from andromede.libs.standard import DEMAND_MODEL from andromede.model import Constraint, ModelPort, PortType, model from andromede.study import Node, PortRef, PortsConnection, create_component From 9ca614ab4a876a8d6f2abe27d4c2827fd48efd0d Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Mon, 22 Jul 2024 18:43:59 +0200 Subject: [PATCH 15/51] Make param and literal return expression node to later handle ExpressionRange --- .../expression/expression_efficient.py | 16 +- .../expression/linear_expression_efficient.py | 145 ++++++++++-------- src/andromede/expression/print.py | 36 +---- src/andromede/libs/standard.py | 29 ++-- src/andromede/model/common.py | 8 + src/andromede/model/constraint.py | 31 ++-- src/andromede/model/model.py | 38 +++-- src/andromede/model/variable.py | 40 +++-- .../expressions/test_expressions_efficient.py | 61 +++++--- tests/unittests/test_port.py | 10 +- 10 files changed, 241 insertions(+), 173 deletions(-) diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index 4d26331f..ffd95795 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -135,7 +135,8 @@ def wrap_in_node(obj: Any) -> ExpressionNodeEfficient: return obj elif isinstance(obj, float) or isinstance(obj, int): return LiteralNode(float(obj)) - raise TypeError(f"Unable to wrap {obj} into an expression node") + # Do not raise excpetion so that we can return NotImplemented in _apply_if_node + # raise TypeError(f"Unable to wrap {obj} into an expression node") def _apply_if_node( @@ -353,11 +354,23 @@ class ComponentParameterNode(ExpressionNodeEfficient): name: str +def param(name: str) -> ExpressionNodeEfficient: + return ParameterNode(name) + + +def comp_param(component_id: str, name: str) -> ExpressionNodeEfficient: + return ComponentParameterNode(component_id, name) + + @dataclass(frozen=True, eq=False) class LiteralNode(ExpressionNodeEfficient): value: float +def literal(value: float) -> ExpressionNodeEfficient: + return LiteralNode(value) + + def is_unbound(expr: ExpressionNodeEfficient) -> bool: return isinstance(expr, LiteralNode) and (abs(expr.value) == float("inf")) @@ -429,6 +442,7 @@ class DivisionNode(BinaryOperatorNode): @dataclass(frozen=True, eq=False) class ExpressionRange: + start: ExpressionNodeEfficient stop: ExpressionNodeEfficient step: Optional[ExpressionNodeEfficient] = None diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index c9f0210b..00b5f2a6 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -21,12 +21,10 @@ from andromede.expression.equality import expressions_equal from andromede.expression.evaluate import ValueProvider, evaluate from andromede.expression.expression_efficient import ( - ComponentParameterNode, ExpressionNodeEfficient, ExpressionRange, InstancesTimeIndex, LiteralNode, - ParameterNode, ScenarioOperatorNode, TimeAggregatorNode, TimeOperatorNode, @@ -34,9 +32,10 @@ is_one, is_unbound, is_zero, + literal, wrap_in_node, ) -from andromede.expression.indexing import IndexingStructureProvider, compute_indexation +from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.port_operator import PortAggregator, PortSum from andromede.expression.print import print_expr @@ -88,7 +87,7 @@ class TermEfficient: def __post_init__(self) -> None: object.__setattr__(self, "coefficient", wrap_in_node(self.coefficient)) - def __eq__(self, other: "TermEfficient") -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, TermEfficient) and expressions_equal(self.coefficient, other.coefficient) @@ -374,6 +373,32 @@ def _substract_terms(lhs: TermEfficient, rhs: TermEfficient) -> TermEfficient: ) +# TODO: Try to use PortField Id which is exactly the same ? +@dataclass(frozen=True) +class PortFieldKey: + port_name: str + field_name: str + + +@dataclass(frozen=True) +class PortFieldTerm: + coefficient: ExpressionNodeEfficient + port_name: str + field_name: str + aggregator: Optional[PortAggregator] = None + + def __str__(self) -> str: + result = f"{self.port_name}.{self.field_name}" + if self.aggregator is not None: + result += f".{str(self.aggregator)}" + return result + + def sum_connections(self) -> "LinearExpressionEfficient": + if self.aggregator is not None: + raise ValueError(f"Port field {str(self)} already has a port aggregator") + return dataclasses.replace(self, aggregator=PortSum()) + + class LinearExpressionEfficient: """ Represents a linear expression with respect to variable names, for example 10x + 5y + 2. @@ -393,6 +418,7 @@ class LinearExpressionEfficient: terms: Dict[TermKeyEfficient, TermEfficient] constant: ExpressionNodeEfficient + port_field_terms: Dict[PortFieldKey, PortFieldTerm] # TODO: We need to check that terms.key is indeed a TermKey and change the tests that this will break def __init__( @@ -401,6 +427,9 @@ def __init__( Union[Dict[TermKeyEfficient, TermEfficient], List[TermEfficient]] ] = None, constant: Optional[Union[float, ExpressionNodeEfficient]] = None, + port_field_terms: Optional[ + Union[Dict[PortFieldKey, PortFieldTerm], List[PortFieldTerm]] + ] = None, ) -> None: if constant is None: self.constant = LiteralNode(0) @@ -422,10 +451,28 @@ def __init__( self.terms[generate_key(term)] = term else: raise TypeError( - f"Terms must be either of type Dict[str, Term] or List[Term], whereas {terms} is of type {type(terms)}" + f"Terms must be either of type Dict[TermKeyEfficient, Term] or List[Term], whereas {terms} is of type {type(terms)}" + ) + + self.port_field_terms = {} + if port_field_terms is not None: + if isinstance(port_field_terms, dict): + for port_field_term_key, port_field_term in port_field_terms.items(): + self.port_field_terms[port_field_term_key] = port_field_term + elif isinstance(port_field_terms, list): + for port_field_term in port_field_terms: + self.port_field_terms[ + PortFieldKey( + port_field_term.port_name, port_field_term.field_name + ) + ] = port_field_term + else: + raise TypeError( + f"Port field terms must be either of type Dict[PortFieldKey, PortFieldTerm] or List[PortFieldTerm], whereas {port_field_terms} is of type {type(port_field_terms)}" ) def is_zero(self) -> bool: + # TODO : Contribution of portfield ? return len(self.terms) == 0 and is_zero(self.constant) def str_for_constant(self) -> str: @@ -456,28 +503,28 @@ def __str__(self) -> str: def __le__(self, rhs: Any) -> "StandaloneConstraint": return StandaloneConstraint( expression=self - rhs, - lower_bound=literal(-float("inf")), - upper_bound=literal(0), + lower_bound=wrap_in_linear_expr(literal(-float("inf"))), + upper_bound=wrap_in_linear_expr(literal(0)), ) def __ge__(self, rhs: Any) -> "StandaloneConstraint": return StandaloneConstraint( expression=self - rhs, - lower_bound=literal(0), - upper_bound=literal(float("inf")), + lower_bound=wrap_in_linear_expr(literal(0)), + upper_bound=wrap_in_linear_expr(literal(float("inf"))), ) def __eq__(self, rhs: Any) -> "StandaloneConstraint": # type: ignore return StandaloneConstraint( expression=self - rhs, - lower_bound=literal(0), - upper_bound=literal(0), + lower_bound=wrap_in_linear_expr(literal(0)), + upper_bound=wrap_in_linear_expr(literal(0)), ) def __iadd__( self, rhs: Union["LinearExpressionEfficient", int, float] ) -> "LinearExpressionEfficient": - rhs = _wrap_in_linear_expr(rhs) + rhs = wrap_in_linear_expr(rhs) self.constant += rhs.constant aggregated_terms = _merge_dicts(self.terms, rhs.terms, _add_terms, 0) self.terms = aggregated_terms @@ -498,7 +545,7 @@ def __radd__(self, rhs: int) -> "LinearExpressionEfficient": def __isub__( self, rhs: Union["LinearExpressionEfficient", int, float] ) -> "LinearExpressionEfficient": - rhs = _wrap_in_linear_expr(rhs) + rhs = wrap_in_linear_expr(rhs) self.constant -= rhs.constant aggregated_terms = _merge_dicts(self.terms, rhs.terms, _substract_terms, 0) self.terms = aggregated_terms @@ -524,7 +571,7 @@ def __neg__(self) -> "LinearExpressionEfficient": def __imul__( self, rhs: Union["LinearExpressionEfficient", int, float] ) -> "LinearExpressionEfficient": - rhs = _wrap_in_linear_expr(rhs) + rhs = wrap_in_linear_expr(rhs) if self.terms and rhs.terms: raise ValueError("Cannot multiply two non constant expression") @@ -569,7 +616,7 @@ def __rmul__(self, rhs: int) -> "LinearExpressionEfficient": def __itruediv__( self, rhs: Union["LinearExpressionEfficient", int, float] ) -> "LinearExpressionEfficient": - rhs = _wrap_in_linear_expr(rhs) + rhs = wrap_in_linear_expr(rhs) if rhs.terms: raise ValueError("Cannot divide by a non constant expression") @@ -801,6 +848,16 @@ def expec(self) -> "LinearExpressionEfficient": # def variance(self) -> "ExpressionNode": # return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Variance")) + def sum_connections(self) -> "LinearExpressionEfficient": + if not self.is_zero(): + raise ValueError( + "sum_connections only after an expression created with port_field" + ) + port_field_terms = {} + for port_field_key, port_field_value in self.port_field_terms.items(): + port_field_terms[port_field_key] = port_field_value.sum_connections() + return LinearExpressionEfficient(port_field_terms=port_field_terms) + def linear_expressions_equal( lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient @@ -818,7 +875,7 @@ def sum_expressions( expressions: Sequence[LinearExpressionEfficient], ) -> LinearExpressionEfficient: if len(expressions) == 0: - return literal(0) + return wrap_in_linear_expr(literal(0)) else: return sum(expressions) @@ -846,21 +903,21 @@ def __str__(self) -> str: return f"{str(self.lower_bound)} <= {str(self.expression)} <= {str(self.upper_bound)}" -def _wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient: +def wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient: if isinstance(obj, LinearExpressionEfficient): return obj elif isinstance(obj, float) or isinstance(obj, int): return LinearExpressionEfficient([], LiteralNode(float(obj))) + elif isinstance(obj, ExpressionNodeEfficient): + return LinearExpressionEfficient([], obj) raise TypeError(f"Unable to wrap {obj} into a linear expression") -# def _apply_if_node( -# obj: Any, func: Callable[[LinearExpressionEfficient], LinearExpressionEfficient] -# ) -> LinearExpressionEfficient: -# if as_linear_expr := _wrap_in_linear_expr(obj): -# return func(as_linear_expr) -# else: -# return NotImplemented +def wrap_in_linear_expr_if_present(obj: Any) -> Union[None, LinearExpressionEfficient]: + if obj is None: + return None + else: + return wrap_in_linear_expr(obj) def _copy_expression( @@ -870,10 +927,6 @@ def _copy_expression( dst.constant = src.constant -def literal(value: float) -> LinearExpressionEfficient: - return LinearExpressionEfficient([], LiteralNode(value)) - - # TODO : Define shortcuts for "x", is_one etc .... def var(name: str) -> LinearExpressionEfficient: return LinearExpressionEfficient( @@ -899,39 +952,11 @@ def comp_var(component_id: str, name: str) -> LinearExpressionEfficient: ) -def param(name: str) -> LinearExpressionEfficient: - return LinearExpressionEfficient([], ParameterNode(name)) - - -def comp_param(component_id: str, name: str) -> LinearExpressionEfficient: - return LinearExpressionEfficient([], ComponentParameterNode(component_id, name)) +def port_field(port_name: str, field_name: str) -> LinearExpressionEfficient: + return LinearExpressionEfficient( + port_field_terms=[PortFieldTerm(literal(1), port_name, field_name)] + ) def is_linear(expr: LinearExpressionEfficient) -> bool: return True - - -@dataclass(frozen=True) -class PortFieldTerm: - port_name: str - field_name: str - aggregator: Optional[PortAggregator] = None - - def __str__(self) -> str: - result = f"{self.port_name}.{self.field_name}" - if self.aggregator is not None: - result += f".{str(self.aggregator)}" - return result - - def sum_connections(self) -> "PortFieldTerm": - if self.aggregator is not None: - raise ValueError(f"Port field {str(self)} already has a port aggregator") - return dataclasses.replace(self, aggregator=PortSum()) - - -class PortFieldExpr: - terms: List[PortFieldTerm] - - -def port_field(port_name: str, field_name: str) -> PortFieldTerm: - return PortFieldTerm(port_name, field_name) diff --git a/src/andromede/expression/print.py b/src/andromede/expression/print.py index 86ef42d0..8fd8c53a 100644 --- a/src/andromede/expression/print.py +++ b/src/andromede/expression/print.py @@ -13,39 +13,19 @@ from dataclasses import dataclass from typing import Dict -from andromede.expression.expression import ( - ComponentParameterNode, - ComponentVariableNode, - ExpressionNode, - PortFieldAggregatorNode, - PortFieldNode, -) -from andromede.expression.visitor import T - -# from .expression import ( -# AdditionNode, -# Comparator, -# ComparisonNode, -# DivisionNode, -# LiteralNode, -# MultiplicationNode, -# NegationNode, -# ParameterNode, -# ScenarioOperatorNode, -# SubstractionNode, -# TimeAggregatorNode, -# TimeOperatorNode, -# VariableNode, -# ) from .expression_efficient import ( AdditionNode, Comparator, ComparisonNode, + ComponentParameterNode, DivisionNode, + ExpressionNodeEfficient, LiteralNode, MultiplicationNode, NegationNode, ParameterNode, + PortFieldAggregatorNode, + PortFieldNode, ScenarioOperatorNode, SubstractionNode, TimeAggregatorNode, @@ -100,15 +80,9 @@ def comparison(self, node: ComparisonNode) -> str: right_value = visit(node.right, self) return f"{left_value} {op} {right_value}" - # def variable(self, node: VariableNode) -> str: - # return node.name - def parameter(self, node: ParameterNode) -> str: return node.name - def comp_variable(self, node: ComponentVariableNode) -> str: - return f"{node.component_id}.{node.name}" - def comp_parameter(self, node: ComponentParameterNode) -> str: return f"{node.component_id}.{node.name}" @@ -129,5 +103,5 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> str: return f"({visit(node.operand, self)}.{node.aggregator})" -def print_expr(expression: ExpressionNode) -> str: +def print_expr(expression: ExpressionNodeEfficient) -> str: return visit(expression, PrinterVisitor()) diff --git a/src/andromede/libs/standard.py b/src/andromede/libs/standard.py index 9a6df24a..e5cdaf1f 100644 --- a/src/andromede/libs/standard.py +++ b/src/andromede/libs/standard.py @@ -13,9 +13,10 @@ """ The standard module contains the definition of standard models. """ -from andromede.expression import literal, param, var -from andromede.expression.expression import ExpressionRange, port_field + +from andromede.expression.expression_efficient import ExpressionRange, literal, param from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression_efficient import port_field, var from andromede.model.constraint import Constraint from andromede.model.model import ModelPort, PortFieldDefinition, PortFieldId, model from andromede.model.parameter import float_parameter, int_parameter @@ -254,16 +255,16 @@ ), Constraint( "Min up time", - var("nb_start") - .shift(ExpressionRange(-param("d_min_up") + 1, literal(0))) - .sum() + var("nb_start").sum( + shift=ExpressionRange(-param("d_min_up") + 1, literal(0)) + ) <= var("nb_on"), ), Constraint( "Min down time", - var("nb_stop") - .shift(ExpressionRange(-param("d_min_down") + 1, literal(0))) - .sum() + var("nb_stop").sum( + shift=ExpressionRange(-param("d_min_down") + 1, literal(0)) + ) <= param("nb_units_max").shift(-param("d_min_down")) - var("nb_on"), ), # It also works by writing ExpressionRange(-param("d_min_down") + 1, 0) as ExpressionRange's __post_init__ wraps integers to literal nodes. However, MyPy does not seem to infer that ExpressionRange's attributes are necessarily of ExpressionNode type and raises an error if the arguments in the constructor are integer (whereas it runs correctly), this why we specify it here with literal(0) instead of 0. @@ -331,16 +332,16 @@ ), Constraint( "Min up time", - var("nb_start") - .shift(ExpressionRange(-param("d_min_up") + 1, literal(0))) - .sum() + var("nb_start").sum( + shift=ExpressionRange(-param("d_min_up") + 1, literal(0)) + ) <= var("nb_on"), ), Constraint( "Min down time", - var("nb_stop") - .shift(ExpressionRange(-param("d_min_down") + 1, literal(0))) - .sum() + var("nb_stop").sum( + shift=ExpressionRange(-param("d_min_down") + 1, literal(0)) + ) <= param("nb_units_max").shift(-param("d_min_down")) - var("nb_on"), ), ], diff --git a/src/andromede/model/common.py b/src/andromede/model/common.py index 07ebdcb4..359951ce 100644 --- a/src/andromede/model/common.py +++ b/src/andromede/model/common.py @@ -14,6 +14,14 @@ Module for common classes used in models. """ from enum import Enum +from typing import Union + +from andromede.expression.expression_efficient import ExpressionNodeEfficient +from andromede.expression.linear_expression_efficient import LinearExpressionEfficient + +ValueOrExprNodeOrLinearExpr = Union[ + int, float, ExpressionNodeEfficient, LinearExpressionEfficient +] class ValueType(Enum): diff --git a/src/andromede/model/constraint.py b/src/andromede/model/constraint.py index 3c055e77..f1cb27b8 100644 --- a/src/andromede/model/constraint.py +++ b/src/andromede/model/constraint.py @@ -10,28 +10,15 @@ # # This file is part of the Antares project. from dataclasses import dataclass, field -from typing import Any, Optional, Union +from typing import Any -from andromede.expression.degree import is_constant -from andromede.expression.equality import ( - expressions_equal, - expressions_equal_if_present, -) - -# from andromede.expression.expression import ( -# Comparator, -# ComparisonNode, -# ExpressionNode, -# literal, -# ) -from andromede.expression.expression_efficient import Comparator, ComparisonNode, is_unbound +from andromede.expression.expression_efficient import literal from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, StandaloneConstraint, linear_expressions_equal, - literal, + wrap_in_linear_expr, ) -from andromede.expression.print import print_expr from andromede.model.common import ProblemContext @@ -45,13 +32,20 @@ class Constraint: name: str expression: LinearExpressionEfficient - lower_bound: LinearExpressionEfficient = field(default=literal(-float("inf"))) - upper_bound: LinearExpressionEfficient = field(default=literal(float("inf"))) + lower_bound: LinearExpressionEfficient = field( + default=wrap_in_linear_expr(literal(-float("inf"))) + ) + upper_bound: LinearExpressionEfficient = field( + default=wrap_in_linear_expr(literal(float("inf"))) + ) context: ProblemContext = field(default=ProblemContext.OPERATIONAL) def __post_init__( self, ) -> None: + self.lower_bound = wrap_in_linear_expr(self.lower_bound) + self.upper_bound = wrap_in_linear_expr(self.upper_bound) + if isinstance(self.expression, StandaloneConstraint): # Case where constraint is initialized with something like Constraint(var("x") <= var("y")) if not self.lower_bound.is_unbound() or not self.upper_bound.is_unbound(): @@ -64,6 +58,7 @@ def __post_init__( self.expression = self.expression.expression else: + self.expression = wrap_in_linear_expr(self.expression) for bound in [self.lower_bound, self.upper_bound]: if not bound.is_constant(): raise ValueError( diff --git a/src/andromede/model/model.py b/src/andromede/model/model.py index a09d3ff9..62f79a65 100644 --- a/src/andromede/model/model.py +++ b/src/andromede/model/model.py @@ -51,8 +51,11 @@ from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, is_linear, + wrap_in_linear_expr, + wrap_in_linear_expr_if_present, ) from andromede.expression.visitor import ExpressionVisitor, visit +from andromede.model.common import ValueOrExprNodeOrLinearExpr from andromede.model.constraint import Constraint from andromede.model.parameter import Parameter from andromede.model.port import PortType @@ -124,8 +127,8 @@ def _is_objective_contribution_valid( raise ValueError("Objective contribution must be a linear expression.") data_structure_provider = _make_structure_provider(model) - objective_structure = compute_indexation( - objective_contribution, data_structure_provider + objective_structure = objective_contribution.compute_indexation( + data_structure_provider ) if objective_structure != IndexingStructure(time=False, scenario=False): @@ -163,6 +166,7 @@ class PortFieldDefinition: definition: LinearExpressionEfficient def __post_init__(self) -> None: + object.__setattr__(self, "definition", wrap_in_linear_expr(self.definition)) _validate_port_field_expression(self) @@ -229,8 +233,8 @@ def model( binding_constraints: Optional[Iterable[Constraint]] = None, parameters: Optional[Iterable[Parameter]] = None, variables: Optional[Iterable[Variable]] = None, - objective_operational_contribution: Optional[LinearExpressionEfficient] = None, - objective_investment_contribution: Optional[LinearExpressionEfficient] = None, + objective_operational_contribution: Optional[ValueOrExprNodeOrLinearExpr] = None, + objective_investment_contribution: Optional[ValueOrExprNodeOrLinearExpr] = None, inter_block_dyn: bool = False, ports: Optional[Iterable[ModelPort]] = None, port_fields_definitions: Optional[Iterable[PortFieldDefinition]] = None, @@ -251,18 +255,24 @@ def model( return Model( id=id, constraints={c.name: c for c in constraints} if constraints else {}, - binding_constraints={c.name: c for c in binding_constraints} - if binding_constraints - else {}, + binding_constraints=( + {c.name: c for c in binding_constraints} if binding_constraints else {} + ), parameters={p.name: p for p in parameters} if parameters else {}, variables={v.name: v for v in variables} if variables else {}, - objective_operational_contribution=objective_operational_contribution, - objective_investment_contribution=objective_investment_contribution, + objective_operational_contribution=wrap_in_linear_expr_if_present( + objective_operational_contribution + ), + objective_investment_contribution=wrap_in_linear_expr_if_present( + objective_investment_contribution + ), inter_block_dyn=inter_block_dyn, ports=existing_port_names, - port_fields_definitions={d.port_field: d for d in port_fields_definitions} - if port_fields_definitions - else {}, + port_fields_definitions=( + {d.port_field: d for d in port_fields_definitions} + if port_fields_definitions + else {} + ), ) @@ -331,7 +341,9 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> None: def _validate_port_field_expression(definition: PortFieldDefinition) -> None: if not isinstance(definition.definition, LinearExpressionEfficient): - raise TypeError(f"Port field definition should be a LinearExpression, not a {type(definition.definition)}") + raise TypeError( + f"Port field definition should be a LinearExpression, not a {type(definition.definition)}" + ) for term in definition.definition.terms.values(): visit(term.coefficient, _PortFieldExpressionChecker()) diff --git a/src/andromede/model/variable.py b/src/andromede/model/variable.py index daaa3a64..c8d26280 100644 --- a/src/andromede/model/variable.py +++ b/src/andromede/model/variable.py @@ -13,13 +13,17 @@ from dataclasses import dataclass from typing import Any, Optional -from andromede.expression.equality import ( - expressions_equal, - expressions_equal_if_present, -) +from andromede.expression.equality import expressions_equal_if_present from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression_efficient import LinearExpressionEfficient -from andromede.model.common import ProblemContext, ValueType +from andromede.expression.linear_expression_efficient import ( + LinearExpressionEfficient, + wrap_in_linear_expr_if_present, +) +from andromede.model.common import ( + ProblemContext, + ValueOrExprNodeOrLinearExpr, + ValueType, +) @dataclass @@ -55,21 +59,33 @@ def __eq__(self, other: Any) -> bool: def int_variable( name: str, - lower_bound: Optional[LinearExpressionEfficient] = None, - upper_bound: Optional[LinearExpressionEfficient] = None, + lower_bound: Optional[ValueOrExprNodeOrLinearExpr] = None, + upper_bound: Optional[ValueOrExprNodeOrLinearExpr] = None, structure: IndexingStructure = IndexingStructure(True, True), context: ProblemContext = ProblemContext.OPERATIONAL, ) -> Variable: return Variable( - name, ValueType.INTEGER, lower_bound, upper_bound, structure, context + name, + ValueType.INTEGER, + wrap_in_linear_expr_if_present(lower_bound), + wrap_in_linear_expr_if_present(upper_bound), + structure, + context, ) def float_variable( name: str, - lower_bound: Optional[LinearExpressionEfficient] = None, - upper_bound: Optional[LinearExpressionEfficient] = None, + lower_bound: Optional[ValueOrExprNodeOrLinearExpr] = None, + upper_bound: Optional[ValueOrExprNodeOrLinearExpr] = None, structure: IndexingStructure = IndexingStructure(True, True), context: ProblemContext = ProblemContext.OPERATIONAL, ) -> Variable: - return Variable(name, ValueType.FLOAT, lower_bound, upper_bound, structure, context) + return Variable( + name, + ValueType.FLOAT, + wrap_in_linear_expr_if_present(lower_bound), + wrap_in_linear_expr_if_present(upper_bound), + structure, + context, + ) diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index 182786d4..cdf99703 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -27,6 +27,9 @@ ParameterNode, TimeAggregatorNode, TimeOperatorNode, + comp_param, + literal, + param, ) from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure @@ -35,13 +38,11 @@ StandaloneConstraint, TermEfficient, TermKeyEfficient, - comp_param, comp_var, linear_expressions_equal, - literal, - param, sum_expressions, var, + wrap_in_linear_expr, ) from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum @@ -223,7 +224,8 @@ def test_degree_computation_should_take_into_account_simplifications() -> None: ( param("p"), param("p") * param("q"), - (1 + param("q")) * param("p"), # Equality visitor not able to handle commutativity + (1 + param("q")) + * param("p"), # Equality visitor not able to handle commutativity ), ], ) @@ -232,7 +234,9 @@ def test_addition( e2: LinearExpressionEfficient, expected: LinearExpressionEfficient, ) -> None: - assert linear_expressions_equal(e1 + e2, expected) + assert linear_expressions_equal( + wrap_in_linear_expr(e1) + wrap_in_linear_expr(e2), wrap_in_linear_expr(expected) + ) @pytest.mark.parametrize( @@ -296,7 +300,8 @@ def test_addition( ( param("p"), param("p") * param("q"), - param("p") * (1 - param("q")), # Equality visitor not able to handle commutativity + (1 - param("q")) + * param("p"), # Equality visitor not able to handle commutativity ), ], ) @@ -305,10 +310,9 @@ def test_substraction( e2: LinearExpressionEfficient, expected: LinearExpressionEfficient, ) -> None: - print() - print(e1 - e2) - print(expected) - assert linear_expressions_equal(e1 - e2, expected) + assert linear_expressions_equal( + wrap_in_linear_expr(e1) - wrap_in_linear_expr(e2), wrap_in_linear_expr(expected) + ) @pytest.mark.parametrize( @@ -353,7 +357,9 @@ def test_linear_expression_equality( def test_standalone_constraint() -> None: - cst = StandaloneConstraint(var("x"), literal(0), literal(10)) + cst = StandaloneConstraint( + var("x"), wrap_in_linear_expr(literal(0)), wrap_in_linear_expr(literal(10)) + ) assert str(cst) == "0 <= +x <= + 10" @@ -641,16 +647,29 @@ def get_variable_structure(self, name: str) -> IndexingStructure: assert expr.compute_indexation(provider) == IndexingStructure(True, True) -def test_sum_expressions() -> None: +@pytest.mark.parametrize( + "sum_expr, expected", + [ + (sum_expressions([]), literal(0)), + (sum_expressions([wrap_in_linear_expr(literal(1))]), literal(1)), + (sum_expressions([wrap_in_linear_expr(literal(1)), var("x")]), 1 + var("x")), + ( + sum_expressions( + [ + wrap_in_linear_expr(literal(1)), + var("x"), + wrap_in_linear_expr(param("p")), + ] + ), + (1 + var("x")) + param("p"), + ), + ], +) +def test_sum_expressions( + sum_expr: LinearExpressionEfficient, expected: LinearExpressionEfficient +) -> None: - assert linear_expressions_equal(sum_expressions([]), literal(0)) - assert linear_expressions_equal(sum_expressions([literal(1)]), literal(1)) - assert linear_expressions_equal( - sum_expressions([literal(1), var("x")]), 1 + var("x") - ) - assert linear_expressions_equal( - sum_expressions([literal(1), var("x"), param("p")]), (1 + var("x")) + param("p") - ) + assert linear_expressions_equal(sum_expr, wrap_in_linear_expr(expected)) @pytest.mark.parametrize( @@ -664,4 +683,4 @@ def test_sum_expressions() -> None: ], ) def test_is_unbound(expr: LinearExpressionEfficient, unbound: bool) -> None: - assert expr.is_unbound() == unbound + assert wrap_in_linear_expr(expr).is_unbound() == unbound diff --git a/tests/unittests/test_port.py b/tests/unittests/test_port.py index 8b471a31..3e443b76 100644 --- a/tests/unittests/test_port.py +++ b/tests/unittests/test_port.py @@ -12,10 +12,13 @@ import pytest -from andromede.expression.expression import port_field -from andromede.expression.linear_expression_efficient import literal +from andromede.expression.expression_efficient import literal +from andromede.expression.linear_expression_efficient import port_field from andromede.libs.standard import DEMAND_MODEL from andromede.model import Constraint, ModelPort, PortType, model +from andromede.model.constraint import Constraint +from andromede.model.model import ModelPort, model +from andromede.model.port import PortType from andromede.study import Node, PortRef, PortsConnection, create_component @@ -28,7 +31,8 @@ def test_port_type_compatibility_ko() -> None: constraints=[ Constraint( name="Balance", - expression=port_field("balance_port", "flow").sum() == literal(0), + expression=port_field("balance_port", "flow").sum_connections() + == literal(0), ) ], ) From 3d03d86a81ba6cb33d59972a146b5bea437c9917 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Mon, 22 Jul 2024 18:59:05 +0200 Subject: [PATCH 16/51] Fix model test --- tests/unittests/test_model.py | 48 +++++++++++++++++------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/unittests/test_model.py b/tests/unittests/test_model.py index c5a1da16..8ae4664c 100644 --- a/tests/unittests/test_model.py +++ b/tests/unittests/test_model.py @@ -14,18 +14,18 @@ import pytest -from andromede.expression.expression_efficient import ExpressionRange, port_field +from andromede.expression.expression_efficient import ExpressionRange, comp_param, param from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, - comp_param, comp_var, linear_expressions_equal, literal, - param, + port_field, var, + wrap_in_linear_expr, ) -from andromede.model import Constraint, float_parameter, float_variable, model -from andromede.model.model import PortFieldDefinition, port_field_def +from andromede.model import Constraint, float_variable, model +from andromede.model.model import port_field_def @pytest.mark.parametrize( @@ -38,8 +38,8 @@ literal(10), "my_constraint", 2 * var("my_var"), - literal(5), - literal(10), + wrap_in_linear_expr(literal(5)), + wrap_in_linear_expr(literal(10)), ), ( "my_constraint", @@ -48,8 +48,8 @@ literal(10), "my_constraint", 2 * var("my_var"), - literal(-float("inf")), - literal(10), + wrap_in_linear_expr(literal(-float("inf"))), + wrap_in_linear_expr(literal(10)), ), ( "my_constraint", @@ -58,8 +58,8 @@ None, "my_constraint", 2 * var("my_var"), - literal(5), - literal(float("inf")), + wrap_in_linear_expr(literal(5)), + wrap_in_linear_expr(literal(float("inf"))), ), ( "my_constraint", @@ -68,8 +68,8 @@ None, "my_constraint", 2 * var("my_var"), - literal(-float("inf")), - literal(float("inf")), + wrap_in_linear_expr(literal(-float("inf"))), + wrap_in_linear_expr(literal(float("inf"))), ), ( "my_constraint", @@ -78,8 +78,8 @@ None, "my_constraint", 2 * var("my_var") - param("p"), - literal(-float("inf")), - literal(0), + wrap_in_linear_expr(literal(-float("inf"))), + wrap_in_linear_expr(literal(0)), ), ( "my_constraint", @@ -88,8 +88,8 @@ None, "my_constraint", 2 * var("my_var") - param("p"), - literal(0), - literal(float("inf")), + wrap_in_linear_expr(literal(0)), + wrap_in_linear_expr(literal(float("inf"))), ), ( "my_constraint", @@ -98,8 +98,8 @@ None, "my_constraint", 2 * var("my_var") - param("p"), - literal(0), - literal(0), + wrap_in_linear_expr(literal(0)), + wrap_in_linear_expr(literal(0)), ), ( "my_constraint", @@ -108,8 +108,8 @@ None, "my_constraint", 2 * var("my_var").expec() - param("p"), - literal(0), - literal(0), + wrap_in_linear_expr(literal(0)), + wrap_in_linear_expr(literal(0)), ), ( "my_constraint", @@ -118,8 +118,8 @@ None, "my_constraint", 2 * var("my_var").shift(-1) - param("p"), - literal(0), - literal(0), + wrap_in_linear_expr(literal(0)), + wrap_in_linear_expr(literal(0)), ), ], ) @@ -229,7 +229,7 @@ def test_instantiating_a_model_with_non_linear_scenario_operator_in_the_objectiv (comp_var("c", "x"), ValueError), (comp_param("c", "x"), ValueError), (port_field("p", "f"), ValueError), - (port_field("p", "f").sum_connections(), ValueError) + (port_field("p", "f").sum_connections(), ValueError), ], ) def test_invalid_port_field_definition_should_raise( From 32514959e6261fac6aab17c6f120b388fab7d923 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Tue, 23 Jul 2024 11:55:03 +0200 Subject: [PATCH 17/51] Fix test imports and port definition validation --- .../expression/parsing/parse_expression.py | 10 +---- src/andromede/expression/port_resolver.py | 12 ++---- src/andromede/libs/standard_sc.py | 6 ++- src/andromede/model/model.py | 25 ++++++++++++- tests/functional/test_andromede.py | 7 +--- tests/functional/test_andromede_yml.py | 5 --- .../functional/test_performance_efficient.py | 3 +- tests/functional/test_xpansion.py | 5 +-- tests/integration/test_benders_decomposed.py | 3 +- tests/models/test_electrolyzer.py | 3 +- .../parsing/test_expression_parsing.py | 6 ++- tests/unittests/expressions/test_equality.py | 11 +++--- tests/unittests/model/test_model_parsing.py | 4 +- tests/unittests/test_data.py | 3 +- tests/unittests/test_model.py | 37 +++++++++++++++---- 15 files changed, 85 insertions(+), 55 deletions(-) diff --git a/src/andromede/expression/parsing/parse_expression.py b/src/andromede/expression/parsing/parse_expression.py index 072419a1..b7b79704 100644 --- a/src/andromede/expression/parsing/parse_expression.py +++ b/src/andromede/expression/parsing/parse_expression.py @@ -15,26 +15,18 @@ from antlr4 import CommonTokenStream, DiagnosticErrorListener, InputStream from antlr4.error.ErrorStrategy import BailErrorStrategy -# from andromede.expression import ExpressionNode, literal, param, var from andromede.expression.equality import expressions_equal - -# from andromede.expression.expression import ( -# Comparator, -# ComparisonNode, -# ExpressionRange, -# PortFieldNode, -# ) from andromede.expression.expression_efficient import ( Comparator, ComparisonNode, ExpressionNodeEfficient, ExpressionRange, PortFieldNode, + param, ) from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, literal, - param, var, ) from andromede.expression.parsing.antlr.ExprLexer import ExprLexer diff --git a/src/andromede/expression/port_resolver.py b/src/andromede/expression/port_resolver.py index a6728e27..02755488 100644 --- a/src/andromede/expression/port_resolver.py +++ b/src/andromede/expression/port_resolver.py @@ -15,20 +15,14 @@ from typing import Dict, List from andromede.expression import CopyVisitor, visit - -# from andromede.expression.expression import ( -# AdditionNode, -# ExpressionNode, -# LiteralNode, -# PortFieldAggregatorNode, -# PortFieldNode, -# ) from andromede.expression.expression_efficient import ( PortFieldAggregatorNode, PortFieldNode, +) +from andromede.expression.linear_expression_efficient import ( + LinearExpressionEfficient, sum_expressions, ) -from andromede.expression.linear_expression_efficient import LinearExpressionEfficient from andromede.model.model import PortFieldId diff --git a/src/andromede/libs/standard_sc.py b/src/andromede/libs/standard_sc.py index 25200638..7864adf8 100644 --- a/src/andromede/libs/standard_sc.py +++ b/src/andromede/libs/standard_sc.py @@ -10,9 +10,11 @@ # # This file is part of the Antares project. -from andromede.expression import literal, param, var + from andromede.expression.expression import port_field -from andromede.libs.standard import BALANCE_PORT_TYPE, CONSTANT, TIME_AND_SCENARIO_FREE +from andromede.expression.expression_efficient import literal, param +from andromede.expression.linear_expression_efficient import var +from andromede.libs.standard import BALANCE_PORT_TYPE, CONSTANT from andromede.model import ( Constraint, ModelPort, diff --git a/src/andromede/model/model.py b/src/andromede/model/model.py index 62f79a65..6a856e05 100644 --- a/src/andromede/model/model.py +++ b/src/andromede/model/model.py @@ -279,7 +279,7 @@ def model( class _PortFieldExpressionChecker(ExpressionVisitor[None]): """ Visits the whole expression to check there is no: - comparison, other port field, component-associated parametrs or variables... + comparison, other port field, component-associated parameters or variables... """ def literal(self, node: LiteralNode) -> None: @@ -340,11 +340,34 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> None: def _validate_port_field_expression(definition: PortFieldDefinition) -> None: + """ + Check there is no: + comparison, other port field, component-associated parameters or variables... + """ + _check_port_field_expression_type(definition) + _check_no_reference_to_other_port_field(definition) + _check_no_component_associated_variable_or_parameter(definition) + + +def _check_port_field_expression_type(definition: PortFieldDefinition) -> None: if not isinstance(definition.definition, LinearExpressionEfficient): raise TypeError( f"Port field definition should be a LinearExpression, not a {type(definition.definition)}" ) + +def _check_no_reference_to_other_port_field(definition: PortFieldDefinition) -> None: + if definition.definition.port_field_terms: + raise ValueError("Port definition cannot reference another port field.") + + +def _check_no_component_associated_variable_or_parameter( + definition: PortFieldDefinition, +) -> None: for term in definition.definition.terms.values(): + if term.component_id: + raise ValueError( + "Port definition must not contain a variable associated to a component." + ) visit(term.coefficient, _PortFieldExpressionChecker()) visit(definition.definition.constant, _PortFieldExpressionChecker()) diff --git a/tests/functional/test_andromede.py b/tests/functional/test_andromede.py index 6e87b01a..52db4c2b 100644 --- a/tests/functional/test_andromede.py +++ b/tests/functional/test_andromede.py @@ -13,25 +13,22 @@ import pandas as pd import pytest -from andromede.expression import literal, param, var +from andromede.expression.expression_efficient import literal, param from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression_efficient import var from andromede.libs.standard import ( BALANCE_PORT_TYPE, DEMAND_MODEL, GENERATOR_MODEL, - GENERATOR_MODEL_WITH_PMIN, - LINK_MODEL, NODE_BALANCE_MODEL, SHORT_TERM_STORAGE_SIMPLE, SPILLAGE_MODEL, - THERMAL_CLUSTER_MODEL_HD, UNSUPPLIED_ENERGY_MODEL, ) from andromede.model import Model, ModelPort, float_parameter, float_variable, model from andromede.model.model import PortFieldDefinition, PortFieldId from andromede.simulation import ( BlockBorderManagement, - OutputValues, TimeBlock, build_problem, ) diff --git a/tests/functional/test_andromede_yml.py b/tests/functional/test_andromede_yml.py index c959f0f4..07a2d1dd 100644 --- a/tests/functional/test_andromede_yml.py +++ b/tests/functional/test_andromede_yml.py @@ -1,10 +1,6 @@ import pandas as pd import pytest -from andromede.expression import literal, param, var -from andromede.expression.indexing_structure import IndexingStructure -from andromede.model import Model, ModelPort, float_parameter, float_variable, model -from andromede.model.model import PortFieldDefinition, PortFieldId from andromede.simulation import ( BlockBorderManagement, OutputValues, @@ -17,7 +13,6 @@ Network, Node, PortRef, - TimeScenarioIndex, TimeScenarioSeriesData, create_component, ) diff --git a/tests/functional/test_performance_efficient.py b/tests/functional/test_performance_efficient.py index 43d1987f..412f092c 100644 --- a/tests/functional/test_performance_efficient.py +++ b/tests/functional/test_performance_efficient.py @@ -14,7 +14,8 @@ import pytest from andromede.expression.evaluate import EvaluationContext -from andromede.expression.linear_expression_efficient import literal, param, var +from andromede.expression.expression_efficient import param +from andromede.expression.linear_expression_efficient import literal, var def test_large_number_of_parameters_sum() -> None: diff --git a/tests/functional/test_xpansion.py b/tests/functional/test_xpansion.py index cba44bca..4fbba8fe 100644 --- a/tests/functional/test_xpansion.py +++ b/tests/functional/test_xpansion.py @@ -13,15 +13,15 @@ import pandas as pd import pytest -from andromede.expression.expression import literal, param, port_field, var +from andromede.expression.expression_efficient import literal, param from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression_efficient import var from andromede.libs.standard import ( BALANCE_PORT_TYPE, CONSTANT, DEMAND_MODEL, GENERATOR_MODEL, NODE_BALANCE_MODEL, - NODE_WITH_SPILL_AND_ENS_MODEL, ) from andromede.model import ( Constraint, @@ -38,7 +38,6 @@ MergedProblemStrategy, OutputValues, TimeBlock, - build_benders_decomposed_problem, build_problem, ) from andromede.study import ( diff --git a/tests/integration/test_benders_decomposed.py b/tests/integration/test_benders_decomposed.py index 38cc06b0..a80ea3b4 100644 --- a/tests/integration/test_benders_decomposed.py +++ b/tests/integration/test_benders_decomposed.py @@ -12,8 +12,9 @@ import pytest -from andromede.expression.expression import literal, param, var +from andromede.expression.expression_efficient import literal, param from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression_efficient import var from andromede.libs.standard import ( BALANCE_PORT_TYPE, CONSTANT, diff --git a/tests/models/test_electrolyzer.py b/tests/models/test_electrolyzer.py index f6ea1d7c..afea65fc 100644 --- a/tests/models/test_electrolyzer.py +++ b/tests/models/test_electrolyzer.py @@ -10,8 +10,9 @@ # # This file is part of the Antares project. -from andromede.expression import literal, param, var from andromede.expression.expression import port_field +from andromede.expression.expression_efficient import literal, param +from andromede.expression.linear_expression_efficient import var from andromede.libs.standard import CONSTANT, TIME_AND_SCENARIO_FREE from andromede.model import ( Constraint, diff --git a/tests/unittests/expressions/parsing/test_expression_parsing.py b/tests/unittests/expressions/parsing/test_expression_parsing.py index 13fb04a3..eb31ab3f 100644 --- a/tests/unittests/expressions/parsing/test_expression_parsing.py +++ b/tests/unittests/expressions/parsing/test_expression_parsing.py @@ -13,14 +13,16 @@ import pytest -from andromede.expression import ExpressionNode, literal, param, print_expr, var from andromede.expression.equality import expressions_equal from andromede.expression.expression import ExpressionRange, port_field +from andromede.expression.expression_efficient import literal, param +from andromede.expression.linear_expression_efficient import LinearExpressionEfficient, var from andromede.expression.parsing.parse_expression import ( AntaresParseException, ModelIdentifiers, parse_expression, ) +from andromede.expression.print import print_expr @pytest.mark.parametrize( @@ -151,7 +153,7 @@ def test_parsing_visitor( variables: Set[str], parameters: Set[str], expression_str: str, - expected: ExpressionNode, + expected: LinearExpressionEfficient, ): identifiers = ModelIdentifiers(variables, parameters) expr = parse_expression(expression_str, identifiers) diff --git a/tests/unittests/expressions/test_equality.py b/tests/unittests/expressions/test_equality.py index c042fdde..8b9f0360 100644 --- a/tests/unittests/expressions/test_equality.py +++ b/tests/unittests/expressions/test_equality.py @@ -10,17 +10,18 @@ # # This file is part of the Antares project. -import math import pytest -from andromede.expression import ExpressionNode, copy_expression, literal, param, var + +from andromede.expression.copy import copy_expression from andromede.expression.equality import expressions_equal from andromede.expression.expression import ( - ExpressionRange, TimeAggregatorNode, expression_range, ) +from andromede.expression.expression_efficient import literal, param +from andromede.expression.linear_expression_efficient import LinearExpressionEfficient, var def shifted_x(): @@ -44,7 +45,7 @@ def shifted_x(): var("x").expec(), ], ) -def test_equals(expr: ExpressionNode) -> None: +def test_equals(expr: LinearExpressionEfficient) -> None: copy = copy_expression(expr) assert expressions_equal(expr, copy) @@ -74,7 +75,7 @@ def test_equals(expr: ExpressionNode) -> None: (var("x").expec(), var("y").expec()), ], ) -def test_not_equals(lhs: ExpressionNode, rhs: ExpressionNode) -> None: +def test_not_equals(lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient) -> None: assert not expressions_equal(lhs, rhs) diff --git a/tests/unittests/model/test_model_parsing.py b/tests/unittests/model/test_model_parsing.py index a9234839..931deaee 100644 --- a/tests/unittests/model/test_model_parsing.py +++ b/tests/unittests/model/test_model_parsing.py @@ -9,13 +9,13 @@ # SPDX-License-Identifier: MPL-2.0 # # This file is part of the Antares project. -import io from pathlib import Path import pytest -from andromede.expression import literal, param, var from andromede.expression.expression import port_field +from andromede.expression.expression_efficient import literal, param +from andromede.expression.linear_expression_efficient import var from andromede.expression.parsing.parse_expression import AntaresParseException from andromede.libs.standard import CONSTANT from andromede.model import ( diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index 02f9d979..c0e014b4 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -15,8 +15,9 @@ import pandas as pd import pytest -from andromede.expression import param, var +from andromede.expression.expression_efficient import param from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression_efficient import var from andromede.libs.standard import ( BALANCE_PORT_TYPE, CONSTANT, diff --git a/tests/unittests/test_model.py b/tests/unittests/test_model.py index 8ae4664c..43da1f8c 100644 --- a/tests/unittests/test_model.py +++ b/tests/unittests/test_model.py @@ -10,6 +10,7 @@ # # This file is part of the Antares project. +import re from typing import Optional, Type import pytest @@ -223,19 +224,39 @@ def test_instantiating_a_model_with_non_linear_scenario_operator_in_the_objectiv @pytest.mark.parametrize( - "expression, error_type", + "expression, error_type, error_msg", [ - (var("x") <= 0, TypeError), - (comp_var("c", "x"), ValueError), - (comp_param("c", "x"), ValueError), - (port_field("p", "f"), ValueError), - (port_field("p", "f").sum_connections(), ValueError), + ( + var("x") <= 0, + TypeError, + "Unable to wrap + (-inf) <= +x <= 0 into a linear expression", + ), + ( + comp_var("c", "x"), + ValueError, + "Port definition must not contain a variable associated to a component.", + ), + ( + comp_param("c", "x"), + ValueError, + "Port definition must not contain a parameter associated to a component.", + ), + ( + port_field("p", "f"), + ValueError, + "Port definition cannot reference another port field.", + ), + ( + port_field("p", "f").sum_connections(), + ValueError, + "Port definition cannot reference another port field.", + ), ], ) def test_invalid_port_field_definition_should_raise( - expression: LinearExpressionEfficient, error_type: Type + expression: LinearExpressionEfficient, error_type: Type, error_msg: str ) -> None: - with pytest.raises(error_type): + with pytest.raises(error_type, match=re.escape(error_msg)): port_field_def(port_name="p", field_name="f", definition=expression) From 6f85db070e03dc2eb15e01ed9b76443ad4b77801 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Tue, 23 Jul 2024 18:24:29 +0200 Subject: [PATCH 18/51] Design problem between time operator / expression node --- .../expression/linear_expression_efficient.py | 257 +++++++++++------- src/andromede/expression/port_resolver.py | 95 +++---- src/andromede/expression/time_operator.py | 2 +- .../functional/test_performance_efficient.py | 12 +- tests/unittests/expressions/test_equality.py | 114 +++++--- .../test_linear_expressions_efficient.py | 56 +++- .../expressions/test_port_resolver.py | 47 +++- 7 files changed, 371 insertions(+), 212 deletions(-) diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 00b5f2a6..29c8924c 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -16,7 +16,17 @@ """ import dataclasses from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + TypeVar, + Union, + overload, +) from andromede.expression.equality import expressions_equal from andromede.expression.evaluate import ValueProvider, evaluate @@ -48,8 +58,6 @@ TimeSum, ) -T = TypeVar("T") - @dataclass(frozen=True) class TermKeyEfficient: @@ -294,47 +302,86 @@ def generate_key(term: TermEfficient) -> TermKeyEfficient: ) +@dataclass(frozen=True) +class PortFieldId: + port_name: str + field_name: str + + +@dataclass(eq=True, frozen=True) +class PortFieldKey: + """ + Identifies the expression node for one component and one port variable. + """ + + component_id: str + port_variable_id: PortFieldId + + +@dataclass(frozen=True) +class PortFieldTerm: + coefficient: ExpressionNodeEfficient + port_name: str + field_name: str + aggregator: Optional[PortAggregator] = None + + def __str__(self) -> str: + result = f"{self.port_name}.{self.field_name}" + if self.aggregator is not None: + result += f".{str(self.aggregator)}" + return result + + def sum_connections(self) -> "LinearExpressionEfficient": + if self.aggregator is not None: + raise ValueError(f"Port field {str(self)} already has a port aggregator") + return dataclasses.replace(self, aggregator=PortSum()) + + +T_val = TypeVar("T_val", bound=Union[TermEfficient, PortFieldTerm]) + + +@overload def _merge_dicts( lhs: Dict[TermKeyEfficient, TermEfficient], rhs: Dict[TermKeyEfficient, TermEfficient], merge_func: Callable[[TermEfficient, TermEfficient], TermEfficient], neutral: float, -) -> Dict[TermKeyEfficient, TermEfficient]: +) -> Dict[TermKeyEfficient, TermEfficient]: ... + + +@overload +def _merge_dicts( + lhs: Dict[PortFieldId, PortFieldTerm], + rhs: Dict[PortFieldId, PortFieldTerm], + merge_func: Callable[[PortFieldTerm, PortFieldTerm], PortFieldTerm], + neutral: float, +) -> Dict[PortFieldId, PortFieldTerm]: ... + + +def _get_neutral_term(term: T_val, neutral: float) -> T_val: + return dataclasses.replace(term, coefficient=neutral) + + +def _merge_dicts(lhs, rhs, merge_func, neutral): res = {} for k, v in lhs.items(): - res[k] = merge_func( - v, - rhs.get( - k, - TermEfficient( - neutral, - v.component_id, - v.variable_name, - v.structure, - v.time_operator, - v.time_aggregator, - v.scenario_operator, - ), - ), - ) + res[k] = merge_func(v, rhs.get(k, _get_neutral_term(v, neutral))) for k, v in rhs.items(): if k not in lhs: - res[k] = merge_func( - TermEfficient( - neutral, - v.component_id, - v.variable_name, - v.structure, - v.time_operator, - v.time_aggregator, - v.scenario_operator, - ), - v, - ) + res[k] = merge_func(_get_neutral_term(v, neutral), v) return res -def _merge_is_possible(lhs: TermEfficient, rhs: TermEfficient) -> None: +def _merge_is_possible(lhs: T_val, rhs: T_val) -> None: + if isinstance(lhs, TermEfficient) and isinstance(rhs, TermEfficient): + _merge_term_is_possible(lhs, rhs) + elif isinstance(lhs, PortFieldTerm) and isinstance(rhs, PortFieldTerm): + _merge_port_terms_is_possible(lhs, rhs) + else: + raise TypeError("Cannot merge terms of different types") + + +def _merge_term_is_possible(lhs: TermEfficient, rhs: TermEfficient) -> None: if lhs.component_id != rhs.component_id or lhs.variable_name != rhs.variable_name: raise ValueError("Cannot merge terms for different variables") if ( @@ -347,56 +394,21 @@ def _merge_is_possible(lhs: TermEfficient, rhs: TermEfficient) -> None: raise ValueError("Cannot merge terms with different structures") -def _add_terms(lhs: TermEfficient, rhs: TermEfficient) -> TermEfficient: - _merge_is_possible(lhs, rhs) - return TermEfficient( - lhs.coefficient + rhs.coefficient, - lhs.component_id, - lhs.variable_name, - lhs.structure, - lhs.time_operator, - lhs.time_aggregator, - lhs.scenario_operator, - ) +def _merge_port_terms_is_possible(lhs: PortFieldTerm, rhs: PortFieldTerm) -> None: + if lhs.port_name != rhs.port_name or lhs.field_name != rhs.field_name: + raise ValueError("Cannot merge terms for different ports") + if lhs.aggregator != rhs.aggregator: + raise ValueError("Cannot merge port terms with different aggregators") -def _substract_terms(lhs: TermEfficient, rhs: TermEfficient) -> TermEfficient: +def _add_terms(lhs: T_val, rhs: T_val) -> T_val: _merge_is_possible(lhs, rhs) - return TermEfficient( - lhs.coefficient - rhs.coefficient, - lhs.component_id, - lhs.variable_name, - lhs.structure, - lhs.time_operator, - lhs.time_aggregator, - lhs.scenario_operator, - ) - - -# TODO: Try to use PortField Id which is exactly the same ? -@dataclass(frozen=True) -class PortFieldKey: - port_name: str - field_name: str - - -@dataclass(frozen=True) -class PortFieldTerm: - coefficient: ExpressionNodeEfficient - port_name: str - field_name: str - aggregator: Optional[PortAggregator] = None + return dataclasses.replace(lhs, coefficient=lhs.coefficient + rhs.coefficient) - def __str__(self) -> str: - result = f"{self.port_name}.{self.field_name}" - if self.aggregator is not None: - result += f".{str(self.aggregator)}" - return result - def sum_connections(self) -> "LinearExpressionEfficient": - if self.aggregator is not None: - raise ValueError(f"Port field {str(self)} already has a port aggregator") - return dataclasses.replace(self, aggregator=PortSum()) +def _substract_terms(lhs: T_val, rhs: T_val) -> T_val: + _merge_is_possible(lhs, rhs) + return dataclasses.replace(lhs, coefficient=lhs.coefficient - rhs.coefficient) class LinearExpressionEfficient: @@ -418,7 +430,7 @@ class LinearExpressionEfficient: terms: Dict[TermKeyEfficient, TermEfficient] constant: ExpressionNodeEfficient - port_field_terms: Dict[PortFieldKey, PortFieldTerm] + port_field_terms: Dict[PortFieldId, PortFieldTerm] # TODO: We need to check that terms.key is indeed a TermKey and change the tests that this will break def __init__( @@ -428,7 +440,7 @@ def __init__( ] = None, constant: Optional[Union[float, ExpressionNodeEfficient]] = None, port_field_terms: Optional[ - Union[Dict[PortFieldKey, PortFieldTerm], List[PortFieldTerm]] + Union[Dict[PortFieldId, PortFieldTerm], List[PortFieldTerm]] ] = None, ) -> None: if constant is None: @@ -462,7 +474,7 @@ def __init__( elif isinstance(port_field_terms, list): for port_field_term in port_field_terms: self.port_field_terms[ - PortFieldKey( + PortFieldId( port_field_term.port_name, port_field_term.field_name ) ] = port_field_term @@ -526,8 +538,15 @@ def __iadd__( ) -> "LinearExpressionEfficient": rhs = wrap_in_linear_expr(rhs) self.constant += rhs.constant + aggregated_terms = _merge_dicts(self.terms, rhs.terms, _add_terms, 0) self.terms = aggregated_terms + + aggregated_port_terms = _merge_dicts( + self.port_field_terms, rhs.port_field_terms, _add_terms, 0 + ) + self.port_field_terms = aggregated_port_terms + self.remove_zeros_from_terms() return self @@ -547,8 +566,15 @@ def __isub__( ) -> "LinearExpressionEfficient": rhs = wrap_in_linear_expr(rhs) self.constant -= rhs.constant + aggregated_terms = _merge_dicts(self.terms, rhs.terms, _substract_terms, 0) self.terms = aggregated_terms + + aggregated_port_terms = _merge_dicts( + self.port_field_terms, rhs.port_field_terms, _substract_terms, 0 + ) + self.port_field_terms = aggregated_port_terms + self.remove_zeros_from_terms() return self @@ -573,14 +599,13 @@ def __imul__( ) -> "LinearExpressionEfficient": rhs = wrap_in_linear_expr(rhs) - if self.terms and rhs.terms: + if not (self.is_constant() or rhs.is_constant()): raise ValueError("Cannot multiply two non constant expression") else: - if self.terms: + if rhs.is_constant(): left_expr = self const_expr = rhs - else: - # It is possible that both expr are constant + else: # self is constant left_expr = rhs const_expr = self if is_zero(const_expr.constant): @@ -590,14 +615,13 @@ def __imul__( else: left_expr.constant *= const_expr.constant for term_key, term in left_expr.terms.items(): - left_expr.terms[term_key] = TermEfficient( - term.coefficient * const_expr.constant, - term.component_id, - term.variable_name, - term.structure, - term.time_operator, - term.time_aggregator, - term.scenario_operator, + left_expr.terms[term_key] = dataclasses.replace( + term, coefficient=term.coefficient * const_expr.constant + ) + for port_term_key, port_term in left_expr.port_field_terms.items(): + left_expr.port_field_terms[port_term_key] = dataclasses.replace( + port_term, + coefficient=port_term.coefficient * const_expr.constant, ) _copy_expression(left_expr, self) return self @@ -618,7 +642,7 @@ def __itruediv__( ) -> "LinearExpressionEfficient": rhs = wrap_in_linear_expr(rhs) - if rhs.terms: + if not rhs.is_constant(): raise ValueError("Cannot divide by a non constant expression") else: if is_zero(rhs.constant): @@ -628,14 +652,12 @@ def __itruediv__( else: self.constant /= rhs.constant for term_key, term in self.terms.items(): - self.terms[term_key] = TermEfficient( - term.coefficient / rhs.constant, - term.component_id, - term.variable_name, - term.structure, - term.time_operator, - term.time_aggregator, - term.scenario_operator, + self.terms[term_key] = dataclasses.replace( + term, coefficient=term.coefficient / rhs.constant + ) + for port_term_key, port_term in self.port_field_terms.items(): + self.port_field_terms[port_term_key] = dataclasses.replace( + port_term, coefficient=port_term.coefficient / rhs.constant ) return self @@ -656,6 +678,9 @@ def remove_zeros_from_terms(self) -> None: for term_key, term in self.terms.copy().items(): if is_zero(term.coefficient): del self.terms[term_key] + for port_term_key, port_term in self.port_field_terms.copy().items(): + if is_zero(port_term.coefficient): + del self.port_field_terms[port_term_key] def evaluate(self, context: ValueProvider) -> float: return sum([term.evaluate(context) for term in self.terms.values()]) + evaluate( @@ -664,7 +689,7 @@ def evaluate(self, context: ValueProvider) -> float: def is_constant(self) -> bool: # Constant expr like x-x could be seen as non constant as we do not simplify coefficient tree... - return not self.terms + return not self.terms and not self.port_field_terms def is_unbound(self) -> bool: return is_unbound(self.constant) @@ -858,6 +883,34 @@ def sum_connections(self) -> "LinearExpressionEfficient": port_field_terms[port_field_key] = port_field_value.sum_connections() return LinearExpressionEfficient(port_field_terms=port_field_terms) + def resolve_port( + self, + component_id: str, + ports_expressions: Dict[PortFieldKey, List["LinearExpressionEfficient"]], + ) -> "LinearExpressionEfficient": + port_expr = LinearExpressionEfficient() + for port_term in self.port_field_terms.values(): + expressions = ports_expressions.get( + PortFieldKey( + component_id, + PortFieldId(port_term.port_name, port_term.field_name), + ), + [], + ) + if port_term.aggregator is None: + if len(expressions) != 1: + raise ValueError( + f"Invalid number of expression for port : {port_term.port_name}" + ) + else: + if port_term.aggregator != PortSum(): + raise NotImplementedError("Only PortSum is supported.") + + port_expr += sum_expressions( + [port_term.coefficient * expression for expression in expressions] + ) + return self + port_expr + def linear_expressions_equal( lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient diff --git a/src/andromede/expression/port_resolver.py b/src/andromede/expression/port_resolver.py index 02755488..806626f6 100644 --- a/src/andromede/expression/port_resolver.py +++ b/src/andromede/expression/port_resolver.py @@ -10,19 +10,8 @@ # # This file is part of the Antares project. -from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, List -from andromede.expression import CopyVisitor, visit -from andromede.expression.expression_efficient import ( - PortFieldAggregatorNode, - PortFieldNode, -) -from andromede.expression.linear_expression_efficient import ( - LinearExpressionEfficient, - sum_expressions, -) from andromede.model.model import PortFieldId @@ -36,51 +25,51 @@ class PortFieldKey: port_variable_id: PortFieldId -@dataclass(frozen=True) -class PortResolver(CopyVisitor): - """ - Duplicates the AST with replacement of port field nodes by - their corresponding expression. - """ +# @dataclass(frozen=True) +# class PortResolver(CopyVisitor): +# """ +# Duplicates the AST with replacement of port field nodes by +# their corresponding expression. +# """ - component_id: str - ports_expressions: Dict[PortFieldKey, List[LinearExpressionEfficient]] +# component_id: str +# ports_expressions: Dict[PortFieldKey, List[LinearExpressionEfficient]] - def port_field(self, node: PortFieldNode) -> LinearExpressionEfficient: - expressions = self.ports_expressions[ - PortFieldKey( - self.component_id, PortFieldId(node.port_name, node.field_name) - ) - ] - if len(expressions) != 1: - raise ValueError( - f"Invalid number of expression for port : {node.port_name}" - ) - else: - return expressions[0] +# def port_field(self, node: PortFieldNode) -> LinearExpressionEfficient: +# expressions = self.ports_expressions[ +# PortFieldKey( +# self.component_id, PortFieldId(node.port_name, node.field_name) +# ) +# ] +# if len(expressions) != 1: +# raise ValueError( +# f"Invalid number of expression for port : {node.port_name}" +# ) +# else: +# return expressions[0] - def port_field_aggregator( - self, node: PortFieldAggregatorNode - ) -> LinearExpressionEfficient: - if node.aggregator != "PortSum": - raise NotImplementedError("Only PortSum is supported.") - port_field_node = node.operand - if not isinstance(port_field_node, PortFieldNode): - raise ValueError(f"Should be a portFieldNode : {port_field_node}") +# def port_field_aggregator( +# self, node: PortFieldAggregatorNode +# ) -> LinearExpressionEfficient: +# if node.aggregator != "PortSum": +# raise NotImplementedError("Only PortSum is supported.") +# port_field_node = node.operand +# if not isinstance(port_field_node, PortFieldNode): +# raise ValueError(f"Should be a portFieldNode : {port_field_node}") - expressions = self.ports_expressions.get( - PortFieldKey( - self.component_id, - PortFieldId(port_field_node.port_name, port_field_node.field_name), - ), - [], - ) - return sum_expressions(expressions) +# expressions = self.ports_expressions.get( +# PortFieldKey( +# self.component_id, +# PortFieldId(port_field_node.port_name, port_field_node.field_name), +# ), +# [], +# ) +# return sum_expressions(expressions) -def resolve_port( - expression: LinearExpressionEfficient, - component_id: str, - ports_expressions: Dict[PortFieldKey, List[LinearExpressionEfficient]], -) -> LinearExpressionEfficient: - return visit(expression, PortResolver(component_id, ports_expressions)) +# def resolve_port( +# expression: LinearExpressionEfficient, +# component_id: str, +# ports_expressions: Dict[PortFieldKey, List[LinearExpressionEfficient]], +# ) -> LinearExpressionEfficient: +# return visit(expression, PortResolver(component_id, ports_expressions)) diff --git a/src/andromede/expression/time_operator.py b/src/andromede/expression/time_operator.py index fa45b599..3b8e7bce 100644 --- a/src/andromede/expression/time_operator.py +++ b/src/andromede/expression/time_operator.py @@ -40,7 +40,7 @@ def key(self) -> Tuple[int, ...]: return self.time_ids def size(self) -> int: - return len(self.time_ids) + return len(self.time_ids.expressions) @dataclass(frozen=True) diff --git a/tests/functional/test_performance_efficient.py b/tests/functional/test_performance_efficient.py index 412f092c..6b0a1b4a 100644 --- a/tests/functional/test_performance_efficient.py +++ b/tests/functional/test_performance_efficient.py @@ -15,7 +15,11 @@ from andromede.expression.evaluate import EvaluationContext from andromede.expression.expression_efficient import param -from andromede.expression.linear_expression_efficient import literal, var +from andromede.expression.linear_expression_efficient import ( + literal, + var, + wrap_in_linear_expr, +) def test_large_number_of_parameters_sum() -> None: @@ -32,7 +36,7 @@ def test_large_number_of_parameters_sum() -> None: # Still the recursion depth error with parameters with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): - expr = sum(param(f"cost_{i}") for i in range(1, nb_terms)) + expr = sum(wrap_in_linear_expr(param(f"cost_{i}")) for i in range(1, nb_terms)) expr.evaluate(EvaluationContext(parameters=parameters_value)) @@ -46,7 +50,7 @@ def test_large_number_of_identical_parameters_sum() -> None: # Still the recursion depth error with parameters # with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): - expr = sum(param("cost") for _ in range(nb_terms)) + expr = sum(wrap_in_linear_expr(param("cost")) for _ in range(nb_terms)) assert expr.evaluate(EvaluationContext(parameters=parameters_value)) == nb_terms @@ -58,7 +62,7 @@ def test_large_number_of_literal_sum() -> None: # # Still the recursion depth error with parameters # with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): - expr = sum(literal(1) for _ in range(nb_terms)) + expr = sum(wrap_in_linear_expr(literal(1)) for _ in range(nb_terms)) assert expr.evaluate(EvaluationContext()) == nb_terms diff --git a/tests/unittests/expressions/test_equality.py b/tests/unittests/expressions/test_equality.py index 8b9f0360..1b738999 100644 --- a/tests/unittests/expressions/test_equality.py +++ b/tests/unittests/expressions/test_equality.py @@ -13,39 +13,57 @@ import pytest - from andromede.expression.copy import copy_expression from andromede.expression.equality import expressions_equal -from andromede.expression.expression import ( +from andromede.expression.expression_efficient import ( + ExpressionNodeEfficient, + InstancesTimeIndex, TimeAggregatorNode, + TimeOperatorNode, expression_range, + literal, + param, ) -from andromede.expression.expression_efficient import literal, param -from andromede.expression.linear_expression_efficient import LinearExpressionEfficient, var -def shifted_x(): - return var("x").shift(expression_range(0, 2)) +def shifted_param() -> ExpressionNodeEfficient: + return TimeOperatorNode( + param("q"), "TimeShift", InstancesTimeIndex(expression_range(0, 2)) + ) @pytest.mark.parametrize( "expr", [ - var("x"), + param("q"), param("p"), - var("x") + 1, - var("x") - 1, - var("x") / 2, - var("x") * 3, - var("x").shift(expression_range(1, 10, 2)).sum(), - var("x").shift(expression_range(1, param("p"))).sum(), - TimeAggregatorNode(shifted_x(), name="TimeSum", stay_roll=True), - TimeAggregatorNode(shifted_x(), name="TimeAggregator", stay_roll=True), - var("x") + 5 <= 2, - var("x").expec(), + param("q") + 1, + param("q") - 1, + param("q") / 2, + param("q") * 3, + TimeAggregatorNode( + TimeOperatorNode( + param("q"), "TimeShift", InstancesTimeIndex(expression_range(1, 10, 2)) + ), + "TimeSum", + stay_roll=True, + ), + TimeAggregatorNode( + TimeOperatorNode( + param("q"), + "TimeShift", + InstancesTimeIndex(expression_range(1, param("p"))), + ), + "TimeSum", + stay_roll=True, + ), + TimeAggregatorNode(shifted_param(), name="TimeSum", stay_roll=True), + TimeAggregatorNode(shifted_param(), name="TimeAggregator", stay_roll=True), + param("q") + 5 <= 2, + param("q").expec(), ], ) -def test_equals(expr: LinearExpressionEfficient) -> None: +def test_equals(expr: ExpressionNodeEfficient) -> None: copy = copy_expression(expr) assert expressions_equal(expr, copy) @@ -53,33 +71,67 @@ def test_equals(expr: LinearExpressionEfficient) -> None: @pytest.mark.parametrize( "rhs, lhs", [ - (var("x"), var("y")), + (param("q"), param("y")), (literal(1), literal(2)), - (var("x") + 1, var("x")), + (param("q") + 1, param("q")), ( - var("x").shift(expression_range(1, param("p"))).sum(), - var("x").shift(expression_range(1, param("q"))).sum(), + TimeAggregatorNode( + TimeOperatorNode( + param("q"), + "TimeShift", + InstancesTimeIndex(expression_range(1, param("p"))), + ), + "TimeSum", + stay_roll=True, + ), + TimeAggregatorNode( + TimeOperatorNode( + param("q"), + "TimeShift", + InstancesTimeIndex(expression_range(1, param("q"))), + ), + "TimeSum", + stay_roll=True, + ), ), ( - var("x").shift(expression_range(1, 10, 2)).sum(), - var("x").shift(expression_range(1, 10, 3)).sum(), + TimeAggregatorNode( + TimeOperatorNode( + param("q"), + "TimeShift", + InstancesTimeIndex(expression_range(1, 10, 2)), + ), + "TimeSum", + stay_roll=True, + ), + TimeAggregatorNode( + TimeOperatorNode( + param("q"), + "TimeShift", + InstancesTimeIndex(expression_range(1, 10, 3)), + ), + "TimeSum", + stay_roll=True, + ), ), ( - TimeAggregatorNode(shifted_x(), name="TimeSum", stay_roll=True), - TimeAggregatorNode(shifted_x(), name="TimeSum", stay_roll=False), + TimeAggregatorNode(shifted_param(), name="TimeSum", stay_roll=True), + TimeAggregatorNode(shifted_param(), name="TimeSum", stay_roll=False), ), ( - TimeAggregatorNode(shifted_x(), name="TimeSum", stay_roll=True), - TimeAggregatorNode(shifted_x(), name="TimeAggregator", stay_roll=True), + TimeAggregatorNode(shifted_param(), name="TimeSum", stay_roll=True), + TimeAggregatorNode(shifted_param(), name="TimeAggregator", stay_roll=True), ), - (var("x").expec(), var("y").expec()), + (param("q").expec(), param("y").expec()), ], ) -def test_not_equals(lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient) -> None: +def test_not_equals( + lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient +) -> None: assert not expressions_equal(lhs, rhs) -def test_tolerance(): +def test_tolerance() -> None: assert expressions_equal(literal(10), literal(10.09), abs_tol=0.1) assert not expressions_equal(literal(10), literal(10.11), abs_tol=0.1) assert expressions_equal(literal(10), literal(10.9), rel_tol=0.1) diff --git a/tests/unittests/expressions/test_linear_expressions_efficient.py b/tests/unittests/expressions/test_linear_expressions_efficient.py index f894dd6c..36fdad96 100644 --- a/tests/unittests/expressions/test_linear_expressions_efficient.py +++ b/tests/unittests/expressions/test_linear_expressions_efficient.py @@ -14,10 +14,20 @@ import pytest +from andromede.expression.expression_efficient import ( + TimeAggregatorNode, + expression_range, + param, +) from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, + PortFieldId, + PortFieldTerm, TermEfficient, + _copy_expression, linear_expressions_equal, + var, + wrap_in_linear_expr, ) from andromede.expression.scenario_operator import Expectation from andromede.expression.time_operator import TimeShift, TimeSum @@ -40,6 +50,26 @@ def test_affine_expression_printing_should_reflect_required_formatting( assert str(expr) == expec_str +@pytest.mark.parametrize( + "expr", + [ + var("x"), + wrap_in_linear_expr(param("p")), + var("x") + 1, + var("x") - 1, + var("x") / 2, + var("x") * 3, + var("x").sum(shift=expression_range(1, 10, 2)), + var("x").sum(shift=expression_range(1, param("p"))), + var("x").expec(), + ], +) +def test_linear_expressions_equal(expr: LinearExpressionEfficient) -> None: + copy = LinearExpressionEfficient() + _copy_expression(expr, copy) + assert linear_expressions_equal(expr, copy) + + @pytest.mark.parametrize( "lhs, rhs", [ @@ -85,6 +115,26 @@ def test_instantiate_linear_expression_from_dict( assert expr.constant == exp_constant +@pytest.mark.parametrize( + "expr, expected", + [ + (LinearExpressionEfficient(), True), + (LinearExpressionEfficient([]), True), + (LinearExpressionEfficient([], 0, {}), True), + (LinearExpressionEfficient([TermEfficient(1, "c", "x")], 0, {}), False), + (LinearExpressionEfficient([], 1, {}), False), + ( + LinearExpressionEfficient( + [], 1, {PortFieldId("p", "f"): PortFieldTerm(1, "p", "f")} + ), + False, + ), + ], +) +def test_is_zero(expr: LinearExpressionEfficient, expected: bool) -> None: + assert expr.is_zero() == expected + + @pytest.mark.parametrize( "e1, e2, expected", [ @@ -169,12 +219,6 @@ def test_addition( assert linear_expressions_equal(e1 + e2, expected) -def test_addition_of_linear_expressions_with_different_number_of_instances_should_raise_value_error() -> ( - None -): - pass - - def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_from_terms() -> ( None ): diff --git a/tests/unittests/expressions/test_port_resolver.py b/tests/unittests/expressions/test_port_resolver.py index e38f20cb..64855220 100644 --- a/tests/unittests/expressions/test_port_resolver.py +++ b/tests/unittests/expressions/test_port_resolver.py @@ -12,37 +12,54 @@ from typing import Dict, List -from andromede.expression import ExpressionNode, var -from andromede.expression.equality import expressions_equal -from andromede.expression.expression import port_field -from andromede.expression.port_resolver import PortFieldKey, resolve_port -from andromede.model.model import PortFieldId - +import pytest -def test_port_field_resolution(): - ports_expressions: Dict[PortFieldKey, List[ExpressionNode]] = {} +from andromede.expression.equality import expressions_equal +from andromede.expression.linear_expression_efficient import ( + LinearExpressionEfficient, + PortFieldId, + PortFieldKey, + linear_expressions_equal, + port_field, + var, +) +@pytest.mark.parametrize( + "port_expr, expected", + [ + (port_field("port", "field") + 2, var("flow") + 2), + (port_field("port", "field") - 2, var("flow") - 2), + (port_field("port", "field") * 2, 2 * var("flow")), + (port_field("port", "field") / 2, var("flow") / 2), + (port_field("port", "field") * 0, LinearExpressionEfficient()), + ] +) +def test_port_field_resolution(port_expr: LinearExpressionEfficient, expected: LinearExpressionEfficient) -> None: + ports_expressions: Dict[PortFieldKey, List[LinearExpressionEfficient]] = {} + key = PortFieldKey("com_id", PortFieldId(field_name="field", port_name="port")) expression = var("flow") ports_expressions[key] = [expression] - expression_2 = port_field("port", "field") + 2 + print() + print(port_expr.resolve_port("com_id", ports_expressions)) + print(expected) - assert expressions_equal( - resolve_port(expression_2, "com_id", ports_expressions), var("flow") + 2 + assert linear_expressions_equal( + port_expr.resolve_port("com_id", ports_expressions), expected ) -def test_port_field_resolution_sum(): - ports_expressions: Dict[PortFieldKey, List[ExpressionNode]] = {} +def test_port_field_resolution_sum() -> None: + ports_expressions: Dict[PortFieldKey, List[LinearExpressionEfficient]] = {} key = PortFieldKey("com_id", PortFieldId(field_name="field", port_name="port")) ports_expressions[key] = [var("flow1"), var("flow2")] expression_2 = port_field("port", "field").sum_connections() - assert expressions_equal( - resolve_port(expression_2, "com_id", ports_expressions), + assert linear_expressions_equal( + expression_2.resolve_port("com_id", ports_expressions), var("flow1") + var("flow2"), ) From 13ea7de665d010624a8fe08ee62994fef157a3bf Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Thu, 25 Jul 2024 11:48:58 +0200 Subject: [PATCH 19/51] Fix circular imports --- src/andromede/expression/copy.py | 57 ++++----- src/andromede/expression/degree.py | 38 +++--- src/andromede/expression/equality.py | 10 +- src/andromede/expression/evaluate.py | 32 ++--- .../expression/expression_efficient.py | 115 +++++++++--------- .../expression/linear_expression_efficient.py | 34 ++++-- src/andromede/expression/port_operator.py | 1 - src/andromede/expression/port_resolver.py | 2 +- src/andromede/expression/visitor.py | 6 +- src/andromede/libs/standard_sc.py | 3 +- tests/models/test_electrolyzer.py | 3 +- .../parsing/test_expression_parsing.py | 31 ++--- tests/unittests/expressions/test_equality.py | 54 ++++---- .../expressions/test_expressions_efficient.py | 27 ++-- .../expressions/test_port_resolver.py | 23 ++-- tests/unittests/model/test_model_parsing.py | 3 +- 16 files changed, 227 insertions(+), 212 deletions(-) diff --git a/src/andromede/expression/copy.py b/src/andromede/expression/copy.py index c135ee59..97f29a09 100644 --- a/src/andromede/expression/copy.py +++ b/src/andromede/expression/copy.py @@ -11,56 +11,49 @@ # This file is part of the Antares project. from dataclasses import dataclass -from typing import List, Union, cast +from typing import List, cast -from .expression import ( - AdditionNode, +from .expression_efficient import ( ComparisonNode, ComponentParameterNode, - ComponentVariableNode, - DivisionNode, - ExpressionNode, + ExpressionNodeEfficient, ExpressionRange, InstancesTimeIndex, LiteralNode, - MultiplicationNode, - NegationNode, ParameterNode, PortFieldAggregatorNode, PortFieldNode, ScenarioOperatorNode, - SubstractionNode, TimeAggregatorNode, TimeOperatorNode, - VariableNode, ) -from .visitor import ExpressionVisitor, ExpressionVisitorOperations, T, visit +from .visitor import ExpressionVisitorOperations, visit @dataclass(frozen=True) -class CopyVisitor(ExpressionVisitorOperations[ExpressionNode]): +class CopyVisitor(ExpressionVisitorOperations[ExpressionNodeEfficient]): """ Simply copies the whole AST. """ - def literal(self, node: LiteralNode) -> ExpressionNode: + def literal(self, node: LiteralNode) -> ExpressionNodeEfficient: return LiteralNode(node.value) - def comparison(self, node: ComparisonNode) -> ExpressionNode: + def comparison(self, node: ComparisonNode) -> ExpressionNodeEfficient: return ComparisonNode( visit(node.left, self), visit(node.right, self), node.comparator ) - def variable(self, node: VariableNode) -> ExpressionNode: - return VariableNode(node.name) + # def variable(self, node: VariableNode) -> ExpressionNodeEfficient: + # return VariableNode(node.name) - def parameter(self, node: ParameterNode) -> ExpressionNode: + def parameter(self, node: ParameterNode) -> ExpressionNodeEfficient: return ParameterNode(node.name) - def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode: - return ComponentVariableNode(node.component_id, node.name) + # def comp_variable(self, node: ComponentVariableNode) -> ExpressionNodeEfficient: + # return ComponentVariableNode(node.component_id, node.name) - def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode: + def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNodeEfficient: return ComponentParameterNode(node.component_id, node.name) def copy_expression_range( @@ -69,9 +62,11 @@ def copy_expression_range( return ExpressionRange( start=visit(expression_range.start, self), stop=visit(expression_range.stop, self), - step=visit(expression_range.step, self) - if expression_range.step is not None - else None, + step=( + visit(expression_range.step, self) + if expression_range.step is not None + else None + ), ) def copy_instances_index( @@ -81,30 +76,32 @@ def copy_instances_index( if isinstance(expressions, ExpressionRange): return InstancesTimeIndex(self.copy_expression_range(expressions)) if isinstance(expressions, list): - expressions_list = cast(List[ExpressionNode], expressions) + expressions_list = cast(List[ExpressionNodeEfficient], expressions) copy = [visit(e, self) for e in expressions_list] return InstancesTimeIndex(copy) raise ValueError("Unexpected type in instances index") - def time_operator(self, node: TimeOperatorNode) -> ExpressionNode: + def time_operator(self, node: TimeOperatorNode) -> ExpressionNodeEfficient: return TimeOperatorNode( visit(node.operand, self), node.name, self.copy_instances_index(node.instances_index), ) - def time_aggregator(self, node: TimeAggregatorNode) -> ExpressionNode: + def time_aggregator(self, node: TimeAggregatorNode) -> ExpressionNodeEfficient: return TimeAggregatorNode(visit(node.operand, self), node.name, node.stay_roll) - def scenario_operator(self, node: ScenarioOperatorNode) -> ExpressionNode: + def scenario_operator(self, node: ScenarioOperatorNode) -> ExpressionNodeEfficient: return ScenarioOperatorNode(visit(node.operand, self), node.name) - def port_field(self, node: PortFieldNode) -> ExpressionNode: + def port_field(self, node: PortFieldNode) -> ExpressionNodeEfficient: return PortFieldNode(node.port_name, node.field_name) - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> ExpressionNode: + def port_field_aggregator( + self, node: PortFieldAggregatorNode + ) -> ExpressionNodeEfficient: return PortFieldAggregatorNode(visit(node.operand, self), node.aggregator) -def copy_expression(expression: ExpressionNode) -> ExpressionNode: +def copy_expression(expression: ExpressionNodeEfficient) -> ExpressionNodeEfficient: return visit(expression, CopyVisitor()) diff --git a/src/andromede/expression/degree.py b/src/andromede/expression/degree.py index cfd175cd..572a58b6 100644 --- a/src/andromede/expression/degree.py +++ b/src/andromede/expression/degree.py @@ -11,28 +11,26 @@ # This file is part of the Antares project. import andromede.expression.scenario_operator -from andromede.expression.expression import ( - ComponentParameterNode, - ComponentVariableNode, - PortFieldAggregatorNode, - PortFieldNode, - TimeOperatorNode, -) - -from .expression import ( +from andromede.expression.expression_efficient import ( AdditionNode, ComparisonNode, + ComponentParameterNode, DivisionNode, - ExpressionNode, + ExpressionNodeEfficient, LiteralNode, MultiplicationNode, NegationNode, ParameterNode, + PortFieldAggregatorNode, + PortFieldNode, ScenarioOperatorNode, SubstractionNode, + TimeAggregatorName, TimeAggregatorNode, - VariableNode, + TimeOperatorName, + TimeOperatorNode, ) + from .visitor import ExpressionVisitor, T, visit @@ -66,26 +64,26 @@ def division(self, node: DivisionNode) -> int: def comparison(self, node: ComparisonNode) -> int: return max(visit(node.left, self), visit(node.right, self)) - def variable(self, node: VariableNode) -> int: - return 1 + # def variable(self, node: VariableNode) -> int: + # return 1 def parameter(self, node: ParameterNode) -> int: return 0 - def comp_variable(self, node: ComponentVariableNode) -> int: - return 1 + # def comp_variable(self, node: ComponentVariableNode) -> int: + # return 1 def comp_parameter(self, node: ComponentParameterNode) -> int: return 0 def time_operator(self, node: TimeOperatorNode) -> int: - if node.name in ["TimeShift", "TimeEvaluation"]: + if node.name in [TimeOperatorName.SHIFT, TimeOperatorName.EVALUATION]: return visit(node.operand, self) else: return NotImplemented def time_aggregator(self, node: TimeAggregatorNode) -> int: - if node.name in ["TimeSum"]: + if node.name in [TimeAggregatorName.TIME_SUM]: return visit(node.operand, self) else: return NotImplemented @@ -104,18 +102,18 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> int: return visit(node.operand, self) -def compute_degree(expression: ExpressionNode) -> int: +def compute_degree(expression: ExpressionNodeEfficient) -> int: return visit(expression, ExpressionDegreeVisitor()) -def is_constant(expr: ExpressionNode) -> bool: +def is_constant(expr: ExpressionNodeEfficient) -> bool: """ True if the expression has no variable. """ return compute_degree(expr) == 0 -def is_linear(expr: ExpressionNode) -> bool: +def is_linear(expr: ExpressionNodeEfficient) -> bool: """ True if the expression is linear with respect to variables. """ diff --git a/src/andromede/expression/equality.py b/src/andromede/expression/equality.py index c558f346..42b8155f 100644 --- a/src/andromede/expression/equality.py +++ b/src/andromede/expression/equality.py @@ -99,7 +99,9 @@ def visit( # return self.variable(left, right) if isinstance(left, ParameterNode) and isinstance(right, ParameterNode): return self.parameter(left, right) - if isinstance(left, ComponentParameterNode) and isinstance(right, ComponentParameterNode): + if isinstance(left, ComponentParameterNode) and isinstance( + right, ComponentParameterNode + ): return self.comp_parameter(left, right) if isinstance(left, TimeOperatorNode) and isinstance(right, TimeOperatorNode): return self.time_operator(left, right) @@ -154,8 +156,10 @@ def comparison(self, left: ComparisonNode, right: ComparisonNode) -> bool: def parameter(self, left: ParameterNode, right: ParameterNode) -> bool: return left.name == right.name - - def comp_parameter(self, left: ComponentParameterNode, right: ComponentParameterNode) -> bool: + + def comp_parameter( + self, left: ComponentParameterNode, right: ComponentParameterNode + ) -> bool: return left.component_id == right.component_id and left.name == right.name def expression_range(self, left: ExpressionRange, right: ExpressionRange) -> bool: diff --git a/src/andromede/expression/evaluate.py b/src/andromede/expression/evaluate.py index b51c0e86..389e4197 100644 --- a/src/andromede/expression/evaluate.py +++ b/src/andromede/expression/evaluate.py @@ -14,28 +14,20 @@ from dataclasses import dataclass, field from typing import Dict -from andromede.expression.expression import ( - ComponentParameterNode, - ComponentVariableNode, - PortFieldAggregatorNode, - PortFieldNode, - TimeOperatorNode, -) - -from .expression import ( - AdditionNode, +from andromede.expression.expression import VariableNode +from andromede.expression.expression_efficient import ( ComparisonNode, - DivisionNode, - ExpressionNode, + ComponentParameterNode, + ExpressionNodeEfficient, LiteralNode, - MultiplicationNode, - NegationNode, ParameterNode, + PortFieldAggregatorNode, + PortFieldNode, ScenarioOperatorNode, - SubstractionNode, TimeAggregatorNode, - VariableNode, + TimeOperatorNode, ) + from .visitor import ExpressionVisitor, ExpressionVisitorOperations, T, visit @@ -117,8 +109,8 @@ def parameter(self, node: ParameterNode) -> float: def comp_parameter(self, node: ComponentParameterNode) -> float: return self.context.get_component_parameter_value(node.component_id, node.name) - def comp_variable(self, node: ComponentVariableNode) -> float: - return self.context.get_component_variable_value(node.component_id, node.name) + # def comp_variable(self, node: ComponentVariableNode) -> float: + # return self.context.get_component_variable_value(node.component_id, node.name) def time_operator(self, node: TimeOperatorNode) -> float: raise NotImplementedError() @@ -136,7 +128,9 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> float: raise NotImplementedError() -def evaluate(expression: ExpressionNode, value_provider: ValueProvider) -> float: +def evaluate( + expression: ExpressionNodeEfficient, value_provider: ValueProvider +) -> float: return visit(expression, EvaluationVisitor(value_provider)) diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index ffd95795..7d516467 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -14,15 +14,10 @@ Defines the model for generic expressions. """ import enum -import inspect import math from dataclasses import dataclass from typing import Any, Callable, List, Optional, Union -import andromede.expression.port_operator -import andromede.expression.scenario_operator -import andromede.expression.time_operator - EPS = 10 ** (-16) @@ -80,15 +75,20 @@ def __eq__(self, rhs: Any) -> "ExpressionNodeEfficient": # type: ignore def sum(self) -> "ExpressionNodeEfficient": if isinstance(self, TimeOperatorNode): - return TimeAggregatorNode(self, "TimeSum", stay_roll=True) + return TimeAggregatorNode(self, TimeAggregatorName.TIME_SUM, stay_roll=True) else: return _apply_if_node( - self, lambda x: TimeAggregatorNode(x, "TimeSum", stay_roll=False) + self, + lambda x: TimeAggregatorNode( + x, TimeAggregatorName.TIME_SUM, stay_roll=False + ), ) def sum_connections(self) -> "ExpressionNodeEfficient": if isinstance(self, PortFieldNode): - return PortFieldAggregatorNode(self, aggregator="PortSum") + return PortFieldAggregatorNode( + self, aggregator=PortFieldAggregatorName.PORT_SUM + ) raise ValueError( f"sum_connections() applies only for PortFieldNode, whereas the current node is of type {type(self)}." ) @@ -104,7 +104,9 @@ def shift( ) -> "ExpressionNodeEfficient": return _apply_if_node( self, - lambda x: TimeOperatorNode(x, "TimeShift", InstancesTimeIndex(expressions)), + lambda x: TimeOperatorNode( + x, TimeOperatorName.SHIFT, InstancesTimeIndex(expressions) + ), ) def eval( @@ -119,15 +121,19 @@ def eval( return _apply_if_node( self, lambda x: TimeOperatorNode( - x, "TimeEvaluation", InstancesTimeIndex(expressions) + x, TimeOperatorName.EVALUATION, InstancesTimeIndex(expressions) ), ) def expec(self) -> "ExpressionNodeEfficient": - return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation")) + return _apply_if_node( + self, lambda x: ScenarioOperatorNode(x, ScenarioOperatorName.EXPECTATION) + ) def variance(self) -> "ExpressionNodeEfficient": - return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Variance")) + return _apply_if_node( + self, lambda x: ScenarioOperatorNode(x, ScenarioOperatorName.Variance) + ) def wrap_in_node(obj: Any) -> ExpressionNodeEfficient: @@ -163,7 +169,7 @@ def is_minus_one(node: ExpressionNodeEfficient) -> bool: return isinstance(node, LiteralNode) and math.isclose(node.value, -1) -def _negate_node(node: ExpressionNodeEfficient): +def _negate_node(node: ExpressionNodeEfficient) -> ExpressionNodeEfficient: if isinstance(node, LiteralNode): return LiteralNode(-node.value) elif isinstance(node, NegationNode): @@ -380,21 +386,19 @@ class UnaryOperatorNode(ExpressionNodeEfficient): operand: ExpressionNodeEfficient +class PortFieldAggregatorName(enum.Enum): + # String value of enum must match the name of the PortAggregator class in port_operator.py + PORT_SUM = "PortSum" + + @dataclass(frozen=True, eq=False) class PortFieldAggregatorNode(UnaryOperatorNode): - aggregator: str + aggregator: PortFieldAggregatorName def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.port_operator, inspect.isclass - ) - if issubclass(cls, andromede.expression.port_operator.PortAggregator) - ] - if self.aggregator not in valid_names: - raise NotImplementedError( - f"{self.aggregator} is not a valid port aggregator, valid port aggregators are {valid_names}" + if not isinstance(self.aggregator, PortFieldAggregatorName): + raise TypeError( + f"PortFieldAggregatorNode.name should of class PortFieldAggregatorName, but {self.aggregator} of type {type(self.aggregator)} was given" ) @@ -442,7 +446,6 @@ class DivisionNode(BinaryOperatorNode): @dataclass(frozen=True, eq=False) class ExpressionRange: - start: ExpressionNodeEfficient stop: ExpressionNodeEfficient step: Optional[ExpressionNodeEfficient] = None @@ -520,59 +523,55 @@ def is_simple(self) -> bool: return False +class TimeOperatorName(enum.Enum): + # String value of enum must match the name of the TimeOperator class in time_operator.py + SHIFT = "TimeShift" + EVALUATION = "TimeEvaluation" + + +class TimeAggregatorName(enum.Enum): + # String value of enum must match the name of the TimeAggregator class in time_operator.py + TIME_SUM = "TimeSum" + + @dataclass(frozen=True, eq=False) class TimeOperatorNode(UnaryOperatorNode): - name: str + name: TimeOperatorName instances_index: InstancesTimeIndex def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.time_operator, inspect.isclass - ) - if issubclass(cls, andromede.expression.time_operator.TimeOperator) - ] - if self.name not in valid_names: - raise ValueError( - f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" + if not isinstance(self.name, TimeOperatorName): + raise TypeError( + f"TimeOperatorNode.name should of class TimeOperatorName, but {self.name} of type {type(self.name)} was given" ) @dataclass(frozen=True, eq=False) class TimeAggregatorNode(UnaryOperatorNode): - name: str + name: TimeAggregatorName stay_roll: bool def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.time_operator, inspect.isclass - ) - if issubclass(cls, andromede.expression.time_operator.TimeAggregator) - ] - if self.name not in valid_names: - raise ValueError( - f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" + if not isinstance(self.name, TimeAggregatorName): + raise TypeError( + f"TimeAggregatorNode.name should of class TimeAggregatorName, but {self.name} of type {type(self.name)} was given" ) +class ScenarioOperatorName(enum.Enum): + # String value of enum must match the name of the ScenarioOperator class in scenario_operator.py + EXPECTATION = "Expectation" + VARIANCE = "Variance" + + @dataclass(frozen=True, eq=False) class ScenarioOperatorNode(UnaryOperatorNode): - name: str + name: ScenarioOperatorName def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.scenario_operator, inspect.isclass - ) - if issubclass(cls, andromede.expression.scenario_operator.ScenarioOperator) - ] - if self.name not in valid_names: - raise ValueError( - f"{self.name} is not a valid scenario operator, valid scenario operators are {valid_names}" + if not isinstance(self.name, ScenarioOperatorName): + raise TypeError( + f"ScenarioOperatorNode.name should of class ScenarioOperatorName, but {self.name} of type {type(self.name)} was given" ) diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 29c8924c..823a9914 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -35,8 +35,11 @@ ExpressionRange, InstancesTimeIndex, LiteralNode, + ScenarioOperatorName, ScenarioOperatorNode, + TimeAggregatorName, TimeAggregatorNode, + TimeOperatorName, TimeOperatorNode, is_minus_one, is_one, @@ -145,7 +148,6 @@ def evaluate(self, context: ValueProvider) -> float: def compute_indexation( self, provider: IndexingStructureProvider ) -> IndexingStructure: - return IndexingStructure( self._compute_time_indexing(provider), self._compute_scenario_indexing(provider), @@ -206,7 +208,7 @@ def sum( return dataclasses.replace( self, coefficient=TimeOperatorNode( - self.coefficient, "TimeShift", InstancesTimeIndex(shift) + self.coefficient, TimeOperatorName.SHIFT, InstancesTimeIndex(shift) ), time_operator=TimeShift(InstancesTimeIndex(shift)), time_aggregator=TimeSum(stay_roll=True), @@ -215,7 +217,9 @@ def sum( return dataclasses.replace( self, coefficient=TimeOperatorNode( - self.coefficient, "TimeEvaluation", InstancesTimeIndex(eval) + self.coefficient, + TimeOperatorName.EVALUATION, + InstancesTimeIndex(eval), ), time_operator=TimeEvaluation(InstancesTimeIndex(eval)), time_aggregator=TimeSum(stay_roll=True), @@ -346,7 +350,8 @@ def _merge_dicts( rhs: Dict[TermKeyEfficient, TermEfficient], merge_func: Callable[[TermEfficient, TermEfficient], TermEfficient], neutral: float, -) -> Dict[TermKeyEfficient, TermEfficient]: ... +) -> Dict[TermKeyEfficient, TermEfficient]: + ... @overload @@ -355,7 +360,8 @@ def _merge_dicts( rhs: Dict[PortFieldId, PortFieldTerm], merge_func: Callable[[PortFieldTerm, PortFieldTerm], PortFieldTerm], neutral: float, -) -> Dict[PortFieldId, PortFieldTerm]: ... +) -> Dict[PortFieldId, PortFieldTerm]: + ... def _get_neutral_term(term: T_val, neutral: float) -> T_val: @@ -746,10 +752,10 @@ def sum( result_constant = TimeAggregatorNode( TimeOperatorNode( self.constant, - "TimeShift", + TimeOperatorName.SHIFT, InstancesTimeIndex(shift), ), - "TimeSum", + TimeAggregatorName.TIME_SUM, stay_roll=True, ) elif eval is not None: @@ -758,10 +764,10 @@ def sum( result_constant = TimeAggregatorNode( TimeOperatorNode( self.constant, - "TimeEvaluation", + TimeOperatorName.EVALUATION, InstancesTimeIndex(eval), ), - "TimeSum", + TimeAggregatorName.TIME_SUM, stay_roll=True, ) else: # x.sum() -> Sum over all time block @@ -769,7 +775,7 @@ def sum( result_constant = TimeAggregatorNode( self.constant, - "TimeSum", + TimeAggregatorName.TIME_SUM, stay_roll=False, ) @@ -799,7 +805,7 @@ def _apply_operator( # def sum_connections(self) -> "ExpressionNode": # if isinstance(self, PortFieldNode): - # return PortFieldAggregatorNode(self, aggregator="PortSum") + # return PortFieldAggregatorNode(self, aggregator=PortFieldAggregatorName.PORT_SUM) # raise ValueError( # f"sum_connections() applies only for PortFieldNode, whereas the current node is of type {type(self)}." # ) @@ -866,12 +872,14 @@ def expec(self) -> "LinearExpressionEfficient": term_with_operator = term.expec() result_terms[generate_key(term_with_operator)] = term_with_operator - result_constant = ScenarioOperatorNode(self.constant, "Expectation") + result_constant = ScenarioOperatorNode( + self.constant, ScenarioOperatorName.EXPECTATION + ) result_expr = LinearExpressionEfficient(result_terms, result_constant) return result_expr # def variance(self) -> "ExpressionNode": - # return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Variance")) + # return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, ScenarioOperatorName.Variance)) def sum_connections(self) -> "LinearExpressionEfficient": if not self.is_zero(): diff --git a/src/andromede/expression/port_operator.py b/src/andromede/expression/port_operator.py index c15245a9..56f18322 100644 --- a/src/andromede/expression/port_operator.py +++ b/src/andromede/expression/port_operator.py @@ -30,6 +30,5 @@ class PortAggregator: @dataclass(frozen=True) class PortSum(PortAggregator): - def __str__(self): return "PortSum" diff --git a/src/andromede/expression/port_resolver.py b/src/andromede/expression/port_resolver.py index 806626f6..54432748 100644 --- a/src/andromede/expression/port_resolver.py +++ b/src/andromede/expression/port_resolver.py @@ -51,7 +51,7 @@ class PortFieldKey: # def port_field_aggregator( # self, node: PortFieldAggregatorNode # ) -> LinearExpressionEfficient: -# if node.aggregator != "PortSum": +# if node.aggregator != PortFieldAggregatorName.PORT_SUM: # raise NotImplementedError("Only PortSum is supported.") # port_field_node = node.operand # if not isinstance(port_field_node, PortFieldNode): diff --git a/src/andromede/expression/visitor.py b/src/andromede/expression/visitor.py index 54b5f16f..29e95cee 100644 --- a/src/andromede/expression/visitor.py +++ b/src/andromede/expression/visitor.py @@ -213,9 +213,9 @@ def addition(self, node: AdditionNode) -> T_op: return left_value + right_value def substraction(self, node: SubstractionNode) -> T_op: - left_value = visit(node.left, self) - right_value = visit(node.right, self) - return left_value - right_value + left_value = visit(node.left, self) + right_value = visit(node.right, self) + return left_value - right_value def multiplication(self, node: MultiplicationNode) -> T_op: left_value = visit(node.left, self) diff --git a/src/andromede/libs/standard_sc.py b/src/andromede/libs/standard_sc.py index 7864adf8..b9f9a38d 100644 --- a/src/andromede/libs/standard_sc.py +++ b/src/andromede/libs/standard_sc.py @@ -11,9 +11,8 @@ # This file is part of the Antares project. -from andromede.expression.expression import port_field from andromede.expression.expression_efficient import literal, param -from andromede.expression.linear_expression_efficient import var +from andromede.expression.linear_expression_efficient import port_field, var from andromede.libs.standard import BALANCE_PORT_TYPE, CONSTANT from andromede.model import ( Constraint, diff --git a/tests/models/test_electrolyzer.py b/tests/models/test_electrolyzer.py index afea65fc..bc7d20c1 100644 --- a/tests/models/test_electrolyzer.py +++ b/tests/models/test_electrolyzer.py @@ -10,9 +10,8 @@ # # This file is part of the Antares project. -from andromede.expression.expression import port_field from andromede.expression.expression_efficient import literal, param -from andromede.expression.linear_expression_efficient import var +from andromede.expression.linear_expression_efficient import port_field, var from andromede.libs.standard import CONSTANT, TIME_AND_SCENARIO_FREE from andromede.model import ( Constraint, diff --git a/tests/unittests/expressions/parsing/test_expression_parsing.py b/tests/unittests/expressions/parsing/test_expression_parsing.py index eb31ab3f..95b1b93d 100644 --- a/tests/unittests/expressions/parsing/test_expression_parsing.py +++ b/tests/unittests/expressions/parsing/test_expression_parsing.py @@ -14,9 +14,12 @@ import pytest from andromede.expression.equality import expressions_equal -from andromede.expression.expression import ExpressionRange, port_field -from andromede.expression.expression_efficient import literal, param -from andromede.expression.linear_expression_efficient import LinearExpressionEfficient, var +from andromede.expression.expression_efficient import ExpressionRange, literal, param +from andromede.expression.linear_expression_efficient import ( + LinearExpressionEfficient, + port_field, + var, +) from andromede.expression.parsing.parse_expression import ( AntaresParseException, ModelIdentifiers, @@ -49,7 +52,7 @@ {"x"}, {}, "x[-1..5]", - var("x").eval(ExpressionRange(-literal(1), literal(5))), + var("x").sum(eval=ExpressionRange(-literal(1), literal(5))), ), ({"x"}, {}, "x[1]", var("x").eval(1)), ({"x"}, {}, "x[t-1]", var("x").shift(-literal(1))), @@ -57,13 +60,13 @@ {"x"}, {}, "x[t-1, t+4]", - var("x").shift([-literal(1), literal(4)]), + var("x").sum(shift=[-literal(1), literal(4)]), ), ( {"x"}, {}, "x[t-1+1]", - var("x").shift(-literal(1) + literal(1)), + var("x").sum(shift=-literal(1) + literal(1)), ), ( {"x"}, @@ -93,25 +96,25 @@ {"x"}, {}, "x[t-1, t, t+4]", - var("x").shift([-literal(1), literal(0), literal(4)]), + var("x").sum(shift=[-literal(1), literal(0), literal(4)]), ), ( {"x"}, {}, "x[t-1..t+5]", - var("x").shift(ExpressionRange(-literal(1), literal(5))), + var("x").sum(shift=ExpressionRange(-literal(1), literal(5))), ), ( {"x"}, {}, "x[t-1..t]", - var("x").shift(ExpressionRange(-literal(1), literal(0))), + var("x").sum(shift=ExpressionRange(-literal(1), literal(0))), ), ( {"x"}, {}, "x[t..t+5]", - var("x").shift(ExpressionRange(literal(0), literal(5))), + var("x").sum(shift=ExpressionRange(literal(0), literal(5))), ), ({"x"}, {}, "x[t]", var("x")), ({"x"}, {"p"}, "x[t+p]", var("x").shift(param("p"))), @@ -119,7 +122,7 @@ {"x"}, {}, "sum(x[-1..5])", - var("x").eval(ExpressionRange(-literal(1), literal(5))).sum(), + var("x").sum(eval=ExpressionRange(-literal(1), literal(5))).sum(), ), ({}, {}, "sum_connections(port.f)", port_field("port", "f").sum_connections()), ( @@ -136,9 +139,9 @@ {"nb_start", "nb_on"}, {"d_min_up"}, "sum(nb_start[-d_min_up + 1 .. 0]) <= nb_on", - var("nb_start") - .eval(ExpressionRange(-param("d_min_up") + 1, literal(0))) - .sum() + var("nb_start").sum( + eval=(ExpressionRange(-param("d_min_up") + 1, literal(0))) + ) <= var("nb_on"), ), ( diff --git a/tests/unittests/expressions/test_equality.py b/tests/unittests/expressions/test_equality.py index 1b738999..5952f28b 100644 --- a/tests/unittests/expressions/test_equality.py +++ b/tests/unittests/expressions/test_equality.py @@ -18,7 +18,9 @@ from andromede.expression.expression_efficient import ( ExpressionNodeEfficient, InstancesTimeIndex, + TimeAggregatorName, TimeAggregatorNode, + TimeOperatorName, TimeOperatorNode, expression_range, literal, @@ -28,7 +30,7 @@ def shifted_param() -> ExpressionNodeEfficient: return TimeOperatorNode( - param("q"), "TimeShift", InstancesTimeIndex(expression_range(0, 2)) + param("q"), TimeOperatorName.SHIFT, InstancesTimeIndex(expression_range(0, 2)) ) @@ -43,22 +45,28 @@ def shifted_param() -> ExpressionNodeEfficient: param("q") * 3, TimeAggregatorNode( TimeOperatorNode( - param("q"), "TimeShift", InstancesTimeIndex(expression_range(1, 10, 2)) + param("q"), + TimeOperatorName.SHIFT, + InstancesTimeIndex(expression_range(1, 10, 2)), ), - "TimeSum", + TimeAggregatorName.TIME_SUM, stay_roll=True, ), TimeAggregatorNode( TimeOperatorNode( param("q"), - "TimeShift", + TimeOperatorName.SHIFT, InstancesTimeIndex(expression_range(1, param("p"))), ), - "TimeSum", + TimeAggregatorName.TIME_SUM, stay_roll=True, ), - TimeAggregatorNode(shifted_param(), name="TimeSum", stay_roll=True), - TimeAggregatorNode(shifted_param(), name="TimeAggregator", stay_roll=True), + TimeAggregatorNode( + shifted_param(), name=TimeAggregatorName.TIME_SUM, stay_roll=True + ), + TimeAggregatorNode( + shifted_param(), name=TimeAggregatorName.TIME_SUM, stay_roll=True + ), param("q") + 5 <= 2, param("q").expec(), ], @@ -78,19 +86,19 @@ def test_equals(expr: ExpressionNodeEfficient) -> None: TimeAggregatorNode( TimeOperatorNode( param("q"), - "TimeShift", + TimeOperatorName.SHIFT, InstancesTimeIndex(expression_range(1, param("p"))), ), - "TimeSum", + TimeAggregatorName.TIME_SUM, stay_roll=True, ), TimeAggregatorNode( TimeOperatorNode( param("q"), - "TimeShift", + TimeOperatorName.SHIFT, InstancesTimeIndex(expression_range(1, param("q"))), ), - "TimeSum", + TimeAggregatorName.TIME_SUM, stay_roll=True, ), ), @@ -98,36 +106,34 @@ def test_equals(expr: ExpressionNodeEfficient) -> None: TimeAggregatorNode( TimeOperatorNode( param("q"), - "TimeShift", + TimeOperatorName.SHIFT, InstancesTimeIndex(expression_range(1, 10, 2)), ), - "TimeSum", + TimeAggregatorName.TIME_SUM, stay_roll=True, ), TimeAggregatorNode( TimeOperatorNode( param("q"), - "TimeShift", + TimeOperatorName.SHIFT, InstancesTimeIndex(expression_range(1, 10, 3)), ), - "TimeSum", + TimeAggregatorName.TIME_SUM, stay_roll=True, ), ), ( - TimeAggregatorNode(shifted_param(), name="TimeSum", stay_roll=True), - TimeAggregatorNode(shifted_param(), name="TimeSum", stay_roll=False), - ), - ( - TimeAggregatorNode(shifted_param(), name="TimeSum", stay_roll=True), - TimeAggregatorNode(shifted_param(), name="TimeAggregator", stay_roll=True), + TimeAggregatorNode( + shifted_param(), name=TimeAggregatorName.TIME_SUM, stay_roll=True + ), + TimeAggregatorNode( + shifted_param(), name=TimeAggregatorName.TIME_SUM, stay_roll=False + ), ), (param("q").expec(), param("y").expec()), ], ) -def test_not_equals( - lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient -) -> None: +def test_not_equals(lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient) -> None: assert not expressions_equal(lhs, rhs) diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index cdf99703..a1304f61 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -25,7 +25,9 @@ InstancesTimeIndex, LiteralNode, ParameterNode, + TimeAggregatorName, TimeAggregatorNode, + TimeOperatorName, TimeOperatorNode, comp_param, literal, @@ -394,7 +396,7 @@ def test_comparison() -> None: scenario_operator=None, ): TermEfficient( TimeOperatorNode( - LiteralNode(1), "TimeShift", InstancesTimeIndex(1) + LiteralNode(1), TimeOperatorName.SHIFT, InstancesTimeIndex(1) ), "", "x", @@ -413,7 +415,7 @@ def test_comparison() -> None: scenario_operator=None, ): TermEfficient( TimeOperatorNode( - LiteralNode(1), "TimeShift", InstancesTimeIndex(1) + LiteralNode(1), TimeOperatorName.SHIFT, InstancesTimeIndex(1) ), "", "y", @@ -422,8 +424,10 @@ def test_comparison() -> None: ), }, TimeAggregatorNode( - TimeOperatorNode(LiteralNode(1), "TimeShift", InstancesTimeIndex(1)), - "TimeSum", + TimeOperatorNode( + LiteralNode(1), TimeOperatorName.SHIFT, InstancesTimeIndex(1) + ), + TimeAggregatorName.TIME_SUM, stay_roll=True, ), # TODO: Could it be simplified online ? ), @@ -440,7 +444,9 @@ def test_comparison() -> None: scenario_operator=None, ): TermEfficient( TimeOperatorNode( - LiteralNode(1), "TimeEvaluation", InstancesTimeIndex(1) + LiteralNode(1), + TimeOperatorName.EVALUATION, + InstancesTimeIndex(1), ), "", "x", @@ -459,7 +465,9 @@ def test_comparison() -> None: scenario_operator=None, ): TermEfficient( TimeOperatorNode( - LiteralNode(1), "TimeEvaluation", InstancesTimeIndex(1) + LiteralNode(1), + TimeOperatorName.EVALUATION, + InstancesTimeIndex(1), ), "", "y", @@ -469,9 +477,9 @@ def test_comparison() -> None: }, TimeAggregatorNode( TimeOperatorNode( - LiteralNode(1), "TimeEvaluation", InstancesTimeIndex(1) + LiteralNode(1), TimeOperatorName.EVALUATION, InstancesTimeIndex(1) ), - "TimeSum", + TimeAggregatorName.TIME_SUM, stay_roll=True, ), # TODO: Could it be simplified online ? ), @@ -506,7 +514,7 @@ def test_comparison() -> None: ), }, TimeAggregatorNode( - LiteralNode(1), "TimeSum", stay_roll=False + LiteralNode(1), TimeAggregatorName.TIME_SUM, stay_roll=False ), # TODO: Could it be simplified online ? ), ], @@ -668,7 +676,6 @@ def get_variable_structure(self, name: str) -> IndexingStructure: def test_sum_expressions( sum_expr: LinearExpressionEfficient, expected: LinearExpressionEfficient ) -> None: - assert linear_expressions_equal(sum_expr, wrap_in_linear_expr(expected)) diff --git a/tests/unittests/expressions/test_port_resolver.py b/tests/unittests/expressions/test_port_resolver.py index 64855220..5ca982d4 100644 --- a/tests/unittests/expressions/test_port_resolver.py +++ b/tests/unittests/expressions/test_port_resolver.py @@ -24,19 +24,22 @@ var, ) + @pytest.mark.parametrize( - "port_expr, expected", - [ - (port_field("port", "field") + 2, var("flow") + 2), - (port_field("port", "field") - 2, var("flow") - 2), - (port_field("port", "field") * 2, 2 * var("flow")), - (port_field("port", "field") / 2, var("flow") / 2), - (port_field("port", "field") * 0, LinearExpressionEfficient()), - ] + "port_expr, expected", + [ + (port_field("port", "field") + 2, var("flow") + 2), + (port_field("port", "field") - 2, var("flow") - 2), + (port_field("port", "field") * 2, 2 * var("flow")), + (port_field("port", "field") / 2, var("flow") / 2), + (port_field("port", "field") * 0, LinearExpressionEfficient()), + ], ) -def test_port_field_resolution(port_expr: LinearExpressionEfficient, expected: LinearExpressionEfficient) -> None: +def test_port_field_resolution( + port_expr: LinearExpressionEfficient, expected: LinearExpressionEfficient +) -> None: ports_expressions: Dict[PortFieldKey, List[LinearExpressionEfficient]] = {} - + key = PortFieldKey("com_id", PortFieldId(field_name="field", port_name="port")) expression = var("flow") diff --git a/tests/unittests/model/test_model_parsing.py b/tests/unittests/model/test_model_parsing.py index 931deaee..947c182f 100644 --- a/tests/unittests/model/test_model_parsing.py +++ b/tests/unittests/model/test_model_parsing.py @@ -13,9 +13,8 @@ import pytest -from andromede.expression.expression import port_field from andromede.expression.expression_efficient import literal, param -from andromede.expression.linear_expression_efficient import var +from andromede.expression.linear_expression_efficient import port_field, var from andromede.expression.parsing.parse_expression import AntaresParseException from andromede.libs.standard import CONSTANT from andromede.model import ( From d092b1ba2b5a7d83accf92a9135c55d47308eabe Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Fri, 26 Jul 2024 10:47:38 +0200 Subject: [PATCH 20/51] Fix syntax --- src/andromede/expression/context_adder.py | 26 ++++++------- .../expression/linear_expression_efficient.py | 25 ++++++++++-- src/andromede/simulation/optimization.py | 39 +++++++++---------- 3 files changed, 53 insertions(+), 37 deletions(-) diff --git a/src/andromede/expression/context_adder.py b/src/andromede/expression/context_adder.py index 812e95f7..397197da 100644 --- a/src/andromede/expression/context_adder.py +++ b/src/andromede/expression/context_adder.py @@ -13,12 +13,10 @@ from dataclasses import dataclass from . import CopyVisitor -from .expression import ( +from .expression_efficient import ( ComponentParameterNode, - ComponentVariableNode, - ExpressionNode, + ExpressionNodeEfficient, ParameterNode, - VariableNode, ) from .visitor import visit @@ -32,22 +30,24 @@ class ContextAdder(CopyVisitor): component_id: str - def variable(self, node: VariableNode) -> ExpressionNode: - return ComponentVariableNode(self.component_id, node.name) + # def variable(self, node: VariableNode) -> ExpressionNodeEfficient: + # return ComponentVariableNode(self.component_id, node.name) - def parameter(self, node: ParameterNode) -> ExpressionNode: + def parameter(self, node: ParameterNode) -> ExpressionNodeEfficient: return ComponentParameterNode(self.component_id, node.name) - def comp_variable(self, node: ComponentVariableNode) -> ExpressionNode: - raise ValueError( - "This expression has already been associated to another component." - ) + # def comp_variable(self, node: ComponentVariableNode) -> ExpressionNodeEfficient: + # raise ValueError( + # "This expression has already been associated to another component." + # ) - def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode: + def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNodeEfficient: raise ValueError( "This expression has already been associated to another component." ) -def add_component_context(id: str, expression: ExpressionNode) -> ExpressionNode: +def add_component_context( + id: str, expression: ExpressionNodeEfficient +) -> ExpressionNodeEfficient: return visit(expression, ContextAdder(id)) diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 823a9914..ca48e73a 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -28,6 +28,7 @@ overload, ) +from andromede.expression.context_adder import add_component_context from andromede.expression.equality import expressions_equal from andromede.expression.evaluate import ValueProvider, evaluate from andromede.expression.expression_efficient import ( @@ -350,8 +351,7 @@ def _merge_dicts( rhs: Dict[TermKeyEfficient, TermEfficient], merge_func: Callable[[TermEfficient, TermEfficient], TermEfficient], neutral: float, -) -> Dict[TermKeyEfficient, TermEfficient]: - ... +) -> Dict[TermKeyEfficient, TermEfficient]: ... @overload @@ -360,8 +360,7 @@ def _merge_dicts( rhs: Dict[PortFieldId, PortFieldTerm], merge_func: Callable[[PortFieldTerm, PortFieldTerm], PortFieldTerm], neutral: float, -) -> Dict[PortFieldId, PortFieldTerm]: - ... +) -> Dict[PortFieldId, PortFieldTerm]: ... def _get_neutral_term(term: T_val, neutral: float) -> T_val: @@ -919,6 +918,24 @@ def resolve_port( ) return self + port_expr + def add_component_context(self, component_id: str) -> "LinearExpressionEfficient": + result_terms = {} + for term in self.terms.values(): + if term.component_id: + raise ValueError( + "This expression has already been associated to another component." + ) + result_term = dataclasses.replace( + term, + component_id=component_id, + coefficient=add_component_context(component_id, term.coefficient), + ) + result_terms[generate_key(result_term)] = result_term + result_constant = add_component_context(component_id, self.constant) + return LinearExpressionEfficient( + result_terms, result_constant, self.port_field_terms + ) + def linear_expressions_equal( lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index ce45b388..b8b6f033 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -18,22 +18,21 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Dict, Iterable, List, Optional, Type +from typing import Dict, Iterable, List, Optional import ortools.linear_solver.pywraplp as lp from andromede.expression import ( # ExpressionNode, - EvaluationVisitor, ParameterValueProvider, ValueProvider, resolve_parameters, - visit, ) -from andromede.expression.context_adder import add_component_context -from andromede.expression.indexing import IndexingStructureProvider, compute_indexation +from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression_efficient import LinearExpressionEfficient -from andromede.expression.port_resolver import PortFieldKey, resolve_port +from andromede.expression.linear_expression_efficient import ( + LinearExpressionEfficient, + PortFieldKey, +) from andromede.expression.scenario_operator import Expectation from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum from andromede.model.common import ValueType @@ -149,8 +148,7 @@ def get_value(self, block_timestep: int, scenario: int) -> float: param_value_provider = _make_value_provider( self.context, block_timestep, scenario, self.component ) - visitor = EvaluationVisitor(param_value_provider) - return visit(self.expression, visitor) + return self.expression.evaluate(param_value_provider) def _make_parameter_value_provider( @@ -421,9 +419,9 @@ def _get_indexing( constraint: Constraint, provider: IndexingStructureProvider ) -> IndexingStructure: return ( - compute_indexation(constraint.expression, provider) - or compute_indexation(constraint.lower_bound, provider) - or compute_indexation(constraint.upper_bound, provider) + constraint.expression.compute_indexation(provider) + or constraint.lower_bound.compute_indexation(provider) + or constraint.upper_bound.compute_indexation(provider) ) @@ -447,9 +445,9 @@ def _instantiate_model_expression( 1. add component ID for variables and parameters of THIS component 2. replace port fields by their definition """ - with_component = add_component_context(component_id, model_expression) - with_component_and_ports = resolve_port( - with_component, component_id, optimization_context.connection_fields_expressions + with_component = model_expression.add_component_context(component_id) + with_component_and_ports = with_component.resolve_port( + component_id, optimization_context.connection_fields_expressions ) return with_component_and_ports @@ -465,9 +463,10 @@ def _create_constraint( constraint_indexing = _compute_indexing_structure(context, constraint) # Perf: Perform linearization (tree traversing) without timesteps so that we can get the number of instances for the expression (from the time_ids of operators) - linear_expr = context.linearize_expression(0, 0, constraint.expression) - # Will there be cases where instances > 1 ? If not, maybe just a check that get_number_of_instances == 1 is sufficient ? Anyway, the function should be implemented - instances_per_time_step = linear_expr.number_of_instances() + # linear_expr = context.linearize_expression(0, 0, constraint.expression) + # # Will there be cases where instances > 1 ? If not, maybe just a check that get_number_of_instances == 1 is sufficient ? Anyway, the function should be implemented + # instances_per_time_step = linear_expr.number_of_instances() + instances_per_time_step = 1 for block_timestep in context.opt_context.get_time_indices(constraint_indexing): for scenario in context.opt_context.get_scenario_indices(constraint_indexing): @@ -703,8 +702,8 @@ def _register_connection_fields_definitions(self) -> None: ) ) expression_node = port_definition.definition # type: ignore - instantiated_expression = add_component_context( - master_port.component.id, expression_node + instantiated_expression = expression_node.add_component_context( + master_port.component.id ) self.context.register_connection_fields_expressions( component_id=cnx.port1.component.id, From 48c96ffce45cb0dc19a33a3d434840d4a9790b60 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Fri, 26 Jul 2024 17:12:02 +0200 Subject: [PATCH 21/51] Resolve linear expression --- src/andromede/expression/evaluate.py | 12 +- .../evaluate_parameters_efficient.py | 201 ++++++++++++++++++ .../expression/linear_expression_efficient.py | 26 +++ src/andromede/simulation/optimization.py | 120 ++++++----- .../simulation/resolved_linear_expression.py | 41 ++++ 5 files changed, 342 insertions(+), 58 deletions(-) create mode 100644 src/andromede/expression/evaluate_parameters_efficient.py create mode 100644 src/andromede/simulation/resolved_linear_expression.py diff --git a/src/andromede/expression/evaluate.py b/src/andromede/expression/evaluate.py index 389e4197..f53d7d53 100644 --- a/src/andromede/expression/evaluate.py +++ b/src/andromede/expression/evaluate.py @@ -37,17 +37,17 @@ class ValueProvider(ABC): Depending on the implementation, evaluation may require a component id or not. """ - @abstractmethod - def get_variable_value(self, name: str) -> float: - ... + # @abstractmethod + # def get_variable_value(self, name: str) -> float: + # ... @abstractmethod def get_parameter_value(self, name: str) -> float: ... - @abstractmethod - def get_component_variable_value(self, component_id: str, name: str) -> float: - ... + # @abstractmethod + # def get_component_variable_value(self, component_id: str, name: str) -> float: + # ... @abstractmethod def get_component_parameter_value(self, component_id: str, name: str) -> float: diff --git a/src/andromede/expression/evaluate_parameters_efficient.py b/src/andromede/expression/evaluate_parameters_efficient.py new file mode 100644 index 00000000..9bb42cf1 --- /dev/null +++ b/src/andromede/expression/evaluate_parameters_efficient.py @@ -0,0 +1,201 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Dict, List + +from andromede.expression.expression import VariableNode +from andromede.expression.expression_efficient import ( + ComparisonNode, + ComponentParameterNode, + ExpressionNodeEfficient, + ExpressionRange, + InstancesTimeIndex, + LiteralNode, + ParameterNode, + PortFieldAggregatorNode, + PortFieldNode, + ScenarioOperatorName, + ScenarioOperatorNode, + TimeAggregatorName, + TimeAggregatorNode, + TimeOperatorName, + TimeOperatorNode, +) +from andromede.expression.linear_expression_efficient import RowIndex + +from .visitor import ExpressionVisitor, ExpressionVisitorOperations, T, visit + + +@dataclass +class TimeScenarioIndices: + time_indices: List[int] + scenario_indices: List[int] + + +class ValueProvider(ABC): + """ + Implementations are in charge of mapping parameters and variables to their values. + Depending on the implementation, evaluation may require a component id or not. + """ + + # @abstractmethod + # def get_variable_value(self, name: str) -> float: ... + + @abstractmethod + def get_parameter_value( + self, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> List[float]: ... + + # @abstractmethod + # def get_component_variable_value(self, component_id: str, name: str) -> float: ... + + @abstractmethod + def get_component_parameter_value( + self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> List[float]: ... + + # TODO: Should this really be an abstract method ? Or maybe, only the Provider in _make_value_provider should implement it. And the context attribute in the InstancesIndexVisitor is a ValueProvider that implements the parameter_is_constant_over_time method. Maybe create a child class of ValueProvider like TimeValueProvider ? + @abstractmethod + def parameter_is_constant_over_time(self, name: str) -> bool: ... + + +@dataclass(frozen=True) +class ParameterEvaluationVisitor(ExpressionVisitorOperations[float]): + """ + Evaluates the expression with respect to the provided context + (variables and parameters values). + """ + + context: ValueProvider + row_id: RowIndex # TODO to be included in ValueProvider ? + time_scenario_indices: TimeScenarioIndices + + def literal(self, node: LiteralNode) -> float: + return [node.value] + + # def comparison(self, node: ComparisonNode) -> float: + # raise ValueError("Cannot evaluate comparison operator.") + + # def variable(self, node: VariableNode) -> float: + # return self.context.get_variable_value(node.name) + + def parameter(self, node: ParameterNode) -> float: + return self.context.get_parameter_value(node.name, self.time_scenario_indices) + + def comp_parameter(self, node: ComponentParameterNode) -> float: + return self.context.get_component_parameter_value( + node.component_id, node.name, self.time_scenario_indices + ) + + # def comp_variable(self, node: ComponentVariableNode) -> float: + # return self.context.get_component_variable_value(node.component_id, node.name) + + def time_operator(self, node: TimeOperatorNode) -> float: + self.time_scenario_indices.time_indices = get_time_ids_from_instances_index( + node.instances_index, self.context + ) + if node.name == TimeOperatorName.SHIFT: + self.time_scenario_indices.time_indices = [ + self.row_id.time + op_id + for op_id in self.time_scenario_indices.time_indices + ] + elif node.name != TimeOperatorName.EVALUATION: + return NotImplemented + return visit(node.operand, self) + + def time_aggregator(self, node: TimeAggregatorNode) -> float: + if node.name in [TimeAggregatorName.SUM]: + # TODO: Where is the sum ? + return visit(node.operand, self) + else: + return NotImplemented + + def scenario_operator(self, node: ScenarioOperatorNode) -> float: + if node.name in [ScenarioOperatorName.EXPECTATION]: + return visit(node.operand, self) + else: + return NotImplemented + + def port_field(self, node: PortFieldNode) -> float: + raise ValueError("Port fields must be resolved before evaluating parameters") + + def port_field_aggregator(self, node: PortFieldAggregatorNode) -> float: + raise ValueError("Port fields must be resolved before evaluating parameters") + + +def resolve_coefficient( + expression: ExpressionNodeEfficient, value_provider: ValueProvider, row_id: RowIndex +) -> float: + return visit(expression, ParameterEvaluationVisitor(value_provider, row_id)) + + +@dataclass(frozen=True) +class InstancesIndexVisitor(ParameterEvaluationVisitor): + """ + Evaluates an expression given as instances index which should have no variable and constant parameter values. + """ + + # def variable(self, node: VariableNode) -> float: + # raise ValueError("An instance index expression cannot contain variable") + + def parameter(self, node: ParameterNode) -> float: + if not self.context.parameter_is_constant_over_time(node.name): + raise ValueError( + "Parameter given in an instance index expression must be constant over time" + ) + return self.context.get_parameter_value(node.name) + + def time_operator(self, node: TimeOperatorNode) -> float: + raise ValueError("An instance index expression cannot contain time operator") + + def time_aggregator(self, node: TimeAggregatorNode) -> float: + raise ValueError("An instance index expression cannot contain time aggregator") + + +def float_to_int(value: float) -> int: + if isinstance(value, int) or value.is_integer(): + return int(value) + else: + raise ValueError(f"{value} is not an integer.") + + +def evaluate_time_id( + expr: ExpressionNodeEfficient, value_provider: ValueProvider +) -> int: + float_time_id = visit(expr, InstancesIndexVisitor(value_provider)) + try: + time_id = float_to_int(float_time_id) + except ValueError: + print(f"{expr} does not represent an integer time index.") + return time_id + + +def get_time_ids_from_instances_index( + instances_index: InstancesTimeIndex, value_provider: ValueProvider +) -> List[int]: + time_ids = [] + if isinstance(instances_index.expressions, list): # List[ExpressionNode] + for expr in instances_index.expressions: + time_ids.append(evaluate_time_id(expr, value_provider)) + + elif isinstance(instances_index.expressions, ExpressionRange): # ExpressionRange + start_id = evaluate_time_id(instances_index.expressions.start, value_provider) + stop_id = evaluate_time_id(instances_index.expressions.stop, value_provider) + step_id = 1 + if instances_index.expressions.step is not None: + step_id = evaluate_time_id(instances_index.expressions.step, value_provider) + # ExpressionRange includes stop_id whereas range excludes it + time_ids = list(range(start_id, stop_id + 1, step_id)) + + return time_ids diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index ca48e73a..8fae646b 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -31,6 +31,7 @@ from andromede.expression.context_adder import add_component_context from andromede.expression.equality import expressions_equal from andromede.expression.evaluate import ValueProvider, evaluate +from andromede.expression.evaluate_parameters_efficient import resolve_coefficient from andromede.expression.expression_efficient import ( ExpressionNodeEfficient, ExpressionRange, @@ -61,6 +62,10 @@ TimeShift, TimeSum, ) +from andromede.simulation.resolved_linear_expression import ( + ResolvedLinearExpression, + ResolvedTerm, +) @dataclass(frozen=True) @@ -416,6 +421,12 @@ def _substract_terms(lhs: T_val, rhs: T_val) -> T_val: return dataclasses.replace(lhs, coefficient=lhs.coefficient - rhs.coefficient) +@dataclass(frozen=True) +class RowIndex: + time: int + scenario: int + + class LinearExpressionEfficient: """ Represents a linear expression with respect to variable names, for example 10x + 5y + 2. @@ -936,6 +947,21 @@ def add_component_context(self, component_id: str) -> "LinearExpressionEfficient result_terms, result_constant, self.port_field_terms ) + def resolve_coefficient( + self, value_provider: ValueProvider, row_id: RowIndex + ) -> ResolvedLinearExpression: + + resolved_terms = [] + for term in self.terms.values(): + resolved_coeff = resolve_coefficient( + term.coefficient, value_provider, row_id + ) + resolved_variable = ... + resolved_terms.append(ResolvedTerm(resolved_coeff, resolved_variable)) + + resolved_constant = resolve_coefficient(self.constant, value_provider, row_id) + return ResolvedLinearExpression(resolved_terms, resolved_constant) + def linear_expressions_equal( lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index b8b6f033..2774c122 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -32,6 +32,7 @@ from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, PortFieldKey, + RowIndex, ) from andromede.expression.scenario_operator import Expectation from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum @@ -40,6 +41,7 @@ from andromede.model.model import PortFieldId from andromede.simulation.linear_expression import LinearExpression, Term from andromede.simulation.linearize import linearize_expression +from andromede.simulation.resolved_linear_expression import ResolvedLinearExpression from andromede.simulation.strategy import MergedProblemStrategy, ModelSelectionStrategy from andromede.simulation.time_block import TimeBlock from andromede.study.data import DataBase @@ -71,21 +73,6 @@ def _get_parameter_value( return data.get_value(absolute_timestep, scenario) -# TODO: Maybe add the notion of constant parameter in the model -# TODO : And constant over scenarios ? -def _parameter_is_constant_over_time( - component: Component, - name: str, - context: "OptimizationContext", - block_timestep: int, - scenario: int, -) -> bool: - data = context.database.get_data(component.id, name) - return data.get_value(block_timestep, scenario) == IndexingStructure( - time=False, scenario=False - ) - - class TimestepValueProvider(ABC): """ Interface which provides numerical values for individual timesteps. @@ -98,8 +85,6 @@ def get_value(self, block_timestep: int, scenario: int) -> float: def _make_value_provider( context: "OptimizationContext", - block_timestep: int, - scenario: int, component: Component, ) -> ValueProvider: """ @@ -110,22 +95,34 @@ def _make_value_provider( """ class Provider(ValueProvider): - def get_component_variable_value(self, component_id: str, name: str) -> float: - raise NotImplementedError( - "Cannot provide variable value at problem build time." - ) - - def get_component_parameter_value(self, component_id: str, name: str) -> float: - return _get_parameter_value( - context, block_timestep, scenario, component_id, name - ) - - def get_variable_value(self, name: str) -> float: - raise NotImplementedError( - "Cannot provide variable value at problem build time." - ) - - def get_parameter_value(self, name: str) -> float: + # def get_component_variable_value(self, component_id: str, name: str) -> float: + # raise NotImplementedError( + # "Cannot provide variable value at problem build time." + # ) + + def get_component_parameter_value( + self, + component_id: str, + name: str, + time_ids: List[int], + scenario_ids: List[int], + ) -> List[float]: + return [ + _get_parameter_value( + context, block_timestep, scenario, component_id, name + ) + for block_timestep in time_ids + for scenario in scenario_ids + ] + + # def get_variable_value(self, name: str) -> float: + # raise NotImplementedError( + # "Cannot provide variable value at problem build time." + # ) + + def get_parameter_value( + self, name: str, time_ids: List[int], scenario_ids: List[int] + ) -> List[float]: raise ValueError( "Parameter must be associated to its component before resolution." ) @@ -468,12 +465,20 @@ def _create_constraint( # instances_per_time_step = linear_expr.number_of_instances() instances_per_time_step = 1 + value_provider = _make_value_provider(context.opt_context, context.component) + for block_timestep in context.opt_context.get_time_indices(constraint_indexing): for scenario in context.opt_context.get_scenario_indices(constraint_indexing): - linear_expr_at_t = context.linearize_expression( - block_timestep, scenario, constraint.expression - ) - # What happens if there is some time_operator in the bounds ? + # linear_expr_at_t = context.linearize_expression( + # block_timestep, scenario, constraint.expression + # ) + row_id = RowIndex(block_timestep, scenario) + + resolved_expr = constraint.expression.resolve_coefficient(value_provider, row_id) + resolved_lb = constraint.lower_bound.resolve_coefficient(value_provider, row_id) + resolved_ub = constraint.upper_bound.resolve_coefficient(value_provider, row_id) + + # What happens if there is some time_operator in the bounds ? -> Pb réglé avec le nouveau design ! constraint_data = ConstraintData( name=constraint.name, lower_bound=context.get_values(constraint.lower_bound).get_value( @@ -540,9 +545,9 @@ def _create_objective( @dataclass class ConstraintData: name: str - lower_bound: float - upper_bound: float - expression: LinearExpression + lower_bound: ResolvedLinearExpression # Or a float ? + upper_bound: ResolvedLinearExpression # Or a float ? + expression: ResolvedLinearExpression def _get_solver_vars( @@ -636,25 +641,36 @@ def make_constraint( """ solver_constraints = {} constraint_name = f"{data.name}_t{block_timestep}_s{scenario}" + + # TODO : Check if instance can be removed for instance in range(instances): if instances > 1: constraint_name += f"_{instance}" solver_constraint: lp.Constraint = solver.Constraint(constraint_name) constant: float = 0 - for term in data.expression.terms.values(): - solver_vars = _get_solver_vars( - term, - context, - block_timestep, - scenario, - instance, + + for term in data.expression.terms: + solver_constraint.SetCoefficient( + term.variable, + term.coefficient + solver_constraint.GetCoefficient(term.variable), ) - for solver_var in solver_vars: - coefficient = term.coefficient + solver_constraint.GetCoefficient( - solver_var - ) - solver_constraint.SetCoefficient(solver_var, coefficient) + + # TODO : To be done in linear expression resolution coeff + # for term in data.expression.terms.values(): + # # Move this to resolve coefficient + # solver_vars = _get_solver_vars( + # term, + # context, + # block_timestep, + # scenario, + # instance, + # ) + # for solver_var in solver_vars: + # coefficient = term.coefficient + solver_constraint.GetCoefficient( + # solver_var + # ) + # solver_constraint.SetCoefficient(solver_var, coefficient) # TODO: On pourrait aussi faire que l'objet Constraint n'ait pas de terme constant dans son expression et que les constantes soit déjà prises en compte dans les bornes, ça simplifierait le traitement ici constant += data.expression.constant diff --git a/src/andromede/simulation/resolved_linear_expression.py b/src/andromede/simulation/resolved_linear_expression.py new file mode 100644 index 00000000..d3ed9fac --- /dev/null +++ b/src/andromede/simulation/resolved_linear_expression.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +""" +Specific modelling for "resolved" linear expressions, +with only variables and literal coefficients. +""" + +from dataclasses import dataclass, field +from typing import List + +import ortools.linear_solver.pywraplp as lp + + +@dataclass(frozen=True) +class ResolvedTerm: + """ + Represents a term where parameters and variables id have been resolved, in the form of couple (coefficient, variable_id) + """ + + coefficient: float + variable: lp.Variable + + +@dataclass +class ResolvedLinearExpression: + """ + Represents a linear expression where parameters and variables id have been resolved, in the form of couple (coefficient, variable_id) and a constant + """ + + terms: List[ResolvedTerm] = field(default_factory=list) + constant: float = field(default=0) From 9de9cae508c6c28d4fa22e35cf2fefdc8225ece9 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Wed, 14 Aug 2024 11:58:55 +0200 Subject: [PATCH 22/51] Fix circular imports --- src/andromede/expression/__init__.py | 16 ++++++++-------- .../expression/evaluate_parameters_efficient.py | 2 +- src/andromede/expression/indexing_structure.py | 11 +++++++++++ .../expression/linear_expression_efficient.py | 14 ++++---------- .../resolved_linear_expression.py | 0 src/andromede/simulation/optimization.py | 4 ++-- 6 files changed, 26 insertions(+), 21 deletions(-) rename src/andromede/{simulation => expression}/resolved_linear_expression.py (100%) diff --git a/src/andromede/expression/__init__.py b/src/andromede/expression/__init__.py index 07825dee..8e481cce 100644 --- a/src/andromede/expression/__init__.py +++ b/src/andromede/expression/__init__.py @@ -19,23 +19,23 @@ resolve_parameters, ) -# from .expression import ( +from .expression import ( # AdditionNode, # Comparator, # ComparisonNode, # DivisionNode, -# ExpressionNode, + ExpressionNode, # LiteralNode, # MultiplicationNode, # NegationNode, # ParameterNode, # SubstractionNode, -# VariableNode, -# literal, -# param, -# sum_expressions, -# var, -# ) + VariableNode, + literal, + param, + sum_expressions, + var, +) from .expression_efficient import ( AdditionNode, Comparator, diff --git a/src/andromede/expression/evaluate_parameters_efficient.py b/src/andromede/expression/evaluate_parameters_efficient.py index 9bb42cf1..7588e58a 100644 --- a/src/andromede/expression/evaluate_parameters_efficient.py +++ b/src/andromede/expression/evaluate_parameters_efficient.py @@ -32,7 +32,7 @@ TimeOperatorName, TimeOperatorNode, ) -from andromede.expression.linear_expression_efficient import RowIndex +from andromede.expression.indexing_structure import RowIndex from .visitor import ExpressionVisitor, ExpressionVisitorOperations, T, visit diff --git a/src/andromede/expression/indexing_structure.py b/src/andromede/expression/indexing_structure.py index e3edc0f5..fa733e79 100644 --- a/src/andromede/expression/indexing_structure.py +++ b/src/andromede/expression/indexing_structure.py @@ -26,3 +26,14 @@ def __or__(self, other: "IndexingStructure") -> "IndexingStructure": time = self.time or other.time scenario = self.scenario or other.scenario return IndexingStructure(time, scenario) + + +# Contrary to IndexingStructure, time and scenario are integers to "count/identify" the constraint whereas IndexingStructure is used used to know whether or not an expression is indexed by time or scenario. +@dataclass(frozen=True) +class RowIndex: + """ + Indexing of rows in a problem. + """ + + time: int + scenario: int diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 8fae646b..915e6795 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -30,8 +30,8 @@ from andromede.expression.context_adder import add_component_context from andromede.expression.equality import expressions_equal -from andromede.expression.evaluate import ValueProvider, evaluate -from andromede.expression.evaluate_parameters_efficient import resolve_coefficient +from andromede.expression.evaluate import evaluate +from andromede.expression.evaluate_parameters_efficient import ValueProvider, resolve_coefficient from andromede.expression.expression_efficient import ( ExpressionNodeEfficient, ExpressionRange, @@ -51,7 +51,7 @@ wrap_in_node, ) from andromede.expression.indexing import IndexingStructureProvider -from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.indexing_structure import IndexingStructure, RowIndex from andromede.expression.port_operator import PortAggregator, PortSum from andromede.expression.print import print_expr from andromede.expression.scenario_operator import Expectation, ScenarioOperator @@ -62,7 +62,7 @@ TimeShift, TimeSum, ) -from andromede.simulation.resolved_linear_expression import ( +from andromede.expression.resolved_linear_expression import ( ResolvedLinearExpression, ResolvedTerm, ) @@ -421,12 +421,6 @@ def _substract_terms(lhs: T_val, rhs: T_val) -> T_val: return dataclasses.replace(lhs, coefficient=lhs.coefficient - rhs.coefficient) -@dataclass(frozen=True) -class RowIndex: - time: int - scenario: int - - class LinearExpressionEfficient: """ Represents a linear expression with respect to variable names, for example 10x + 5y + 2. diff --git a/src/andromede/simulation/resolved_linear_expression.py b/src/andromede/expression/resolved_linear_expression.py similarity index 100% rename from src/andromede/simulation/resolved_linear_expression.py rename to src/andromede/expression/resolved_linear_expression.py diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 2774c122..f17bfa60 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -24,9 +24,9 @@ from andromede.expression import ( # ExpressionNode, ParameterValueProvider, - ValueProvider, resolve_parameters, ) +from andromede.expression.evaluate_parameters_efficient import ValueProvider from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.linear_expression_efficient import ( @@ -41,7 +41,7 @@ from andromede.model.model import PortFieldId from andromede.simulation.linear_expression import LinearExpression, Term from andromede.simulation.linearize import linearize_expression -from andromede.simulation.resolved_linear_expression import ResolvedLinearExpression +from andromede.expression.resolved_linear_expression import ResolvedLinearExpression from andromede.simulation.strategy import MergedProblemStrategy, ModelSelectionStrategy from andromede.simulation.time_block import TimeBlock from andromede.study.data import DataBase From 3a4004c6589cb08e9c9e5ca208007aea680b7bd9 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Wed, 14 Aug 2024 18:41:38 +0200 Subject: [PATCH 23/51] Parameter evaluation visitor implemented --- src/andromede/expression/evaluate.py | 32 +-- .../evaluate_parameters_efficient.py | 201 ++++++++++++------ .../expression/expression_efficient.py | 2 +- .../expression/indexing_structure.py | 3 + .../expression/linear_expression_efficient.py | 3 +- .../expression/resolved_linear_expression.py | 4 + src/andromede/expression/value_provider.py | 69 ++++++ src/andromede/simulation/optimization.py | 170 ++++++++------- 8 files changed, 309 insertions(+), 175 deletions(-) create mode 100644 src/andromede/expression/value_provider.py diff --git a/src/andromede/expression/evaluate.py b/src/andromede/expression/evaluate.py index f53d7d53..bcb66a4d 100644 --- a/src/andromede/expression/evaluate.py +++ b/src/andromede/expression/evaluate.py @@ -10,7 +10,6 @@ # # This file is part of the Antares project. -from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Dict @@ -27,36 +26,9 @@ TimeAggregatorNode, TimeOperatorNode, ) +from andromede.expression.value_provider import ValueProvider -from .visitor import ExpressionVisitor, ExpressionVisitorOperations, T, visit - - -class ValueProvider(ABC): - """ - Implementations are in charge of mapping parameters and variables to their values. - Depending on the implementation, evaluation may require a component id or not. - """ - - # @abstractmethod - # def get_variable_value(self, name: str) -> float: - # ... - - @abstractmethod - def get_parameter_value(self, name: str) -> float: - ... - - # @abstractmethod - # def get_component_variable_value(self, component_id: str, name: str) -> float: - # ... - - @abstractmethod - def get_component_parameter_value(self, component_id: str, name: str) -> float: - ... - - # TODO: Should this really be an abstract method ? Or maybe, only the Provider in _make_value_provider should implement it. And the context attribute in the InstancesIndexVisitor is a ValueProvider that implements the parameter_is_constant_over_time method. Maybe create a child class of ValueProvider like TimeValueProvider ? - @abstractmethod - def parameter_is_constant_over_time(self, name: str) -> bool: - ... +from .visitor import ExpressionVisitorOperations, visit @dataclass(frozen=True) diff --git a/src/andromede/expression/evaluate_parameters_efficient.py b/src/andromede/expression/evaluate_parameters_efficient.py index 7588e58a..053ec33e 100644 --- a/src/andromede/expression/evaluate_parameters_efficient.py +++ b/src/andromede/expression/evaluate_parameters_efficient.py @@ -10,98 +10,129 @@ # # This file is part of the Antares project. -from abc import ABC, abstractmethod +import operator from dataclasses import dataclass, field -from typing import Dict, List +from typing import Callable, Dict, List -from andromede.expression.expression import VariableNode from andromede.expression.expression_efficient import ( + AdditionNode, ComparisonNode, ComponentParameterNode, + DivisionNode, ExpressionNodeEfficient, ExpressionRange, InstancesTimeIndex, LiteralNode, + MultiplicationNode, + NegationNode, ParameterNode, PortFieldAggregatorNode, PortFieldNode, ScenarioOperatorName, ScenarioOperatorNode, + SubstractionNode, TimeAggregatorName, TimeAggregatorNode, TimeOperatorName, TimeOperatorNode, ) from andromede.expression.indexing_structure import RowIndex +from andromede.expression.value_provider import ( + TimeScenarioIndex, + TimeScenarioIndices, + ValueProvider, +) -from .visitor import ExpressionVisitor, ExpressionVisitorOperations, T, visit - - -@dataclass -class TimeScenarioIndices: - time_indices: List[int] - scenario_indices: List[int] - - -class ValueProvider(ABC): - """ - Implementations are in charge of mapping parameters and variables to their values. - Depending on the implementation, evaluation may require a component id or not. - """ - - # @abstractmethod - # def get_variable_value(self, name: str) -> float: ... - - @abstractmethod - def get_parameter_value( - self, name: str, time_scenarios_indices: TimeScenarioIndices - ) -> List[float]: ... - - # @abstractmethod - # def get_component_variable_value(self, component_id: str, name: str) -> float: ... +from .visitor import ExpressionVisitor, visit - @abstractmethod - def get_component_parameter_value( - self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices - ) -> List[float]: ... - # TODO: Should this really be an abstract method ? Or maybe, only the Provider in _make_value_provider should implement it. And the context attribute in the InstancesIndexVisitor is a ValueProvider that implements the parameter_is_constant_over_time method. Maybe create a child class of ValueProvider like TimeValueProvider ? - @abstractmethod - def parameter_is_constant_over_time(self, name: str) -> bool: ... +# TODO: (almost) same function as in linear_expression _merge_dicts +def _merge_dicts( + lhs: Dict[TimeScenarioIndex, float], + rhs: Dict[TimeScenarioIndex, float], + merge_func: Callable[[float, float], float], + neutral: float, +) -> Dict[TimeScenarioIndex, float]: + res = {} + for k, v in lhs.items(): + res[k] = merge_func(v, rhs.get(k, neutral)) + for k, v in rhs.items(): + if k not in lhs: + res[k] = merge_func(neutral, v) + return res +# It is better to return a list of float than a single float to minimize the number of calls to the visit function. We access values of the parameters at different time steps with a single visit of the tree @dataclass(frozen=True) -class ParameterEvaluationVisitor(ExpressionVisitorOperations[float]): +class ParameterEvaluationVisitor(ExpressionVisitor[Dict[TimeScenarioIndex, float]]): """ Evaluates the expression with respect to the provided context (variables and parameters values). """ context: ValueProvider + # Useful to keep track of row id for time shift row_id: RowIndex # TODO to be included in ValueProvider ? - time_scenario_indices: TimeScenarioIndices - - def literal(self, node: LiteralNode) -> float: - return [node.value] - - # def comparison(self, node: ComparisonNode) -> float: - # raise ValueError("Cannot evaluate comparison operator.") + time_scenario_indices: TimeScenarioIndices = field(init=False) - # def variable(self, node: VariableNode) -> float: - # return self.context.get_variable_value(node.name) + def __post_init__(self) -> None: + object.__setattr__( + self, + "time_scenario_indices", + TimeScenarioIndices([self.row_id.time], [self.row_id.scenario]), + ) - def parameter(self, node: ParameterNode) -> float: + # Redefine common operations so that it works as expected on lists (i.e. summing element wise rather than appending to it) + def negation(self, node: NegationNode) -> Dict[TimeScenarioIndex, float]: + return {k: -v for k, v in visit(node.operand, self).items()} + + def addition(self, node: AdditionNode) -> Dict[TimeScenarioIndex, float]: + left_value = visit(node.left, self) + right_value = visit(node.right, self) + result = _merge_dicts(left_value, right_value, operator.add, 0) + return result + + def substraction(self, node: SubstractionNode) -> Dict[TimeScenarioIndex, float]: + left_value = visit(node.left, self) + right_value = visit(node.right, self) + result = _merge_dicts(left_value, right_value, operator.sub, 0) + return result + + def multiplication( + self, node: MultiplicationNode + ) -> Dict[TimeScenarioIndex, float]: + left_value = visit(node.left, self) + right_value = visit(node.right, self) + result = _merge_dicts(left_value, right_value, operator.mul, 1) + return result + + def division(self, node: DivisionNode) -> Dict[TimeScenarioIndex, float]: + left_value = visit(node.left, self) + right_value = visit(node.right, self) + result = _merge_dicts(left_value, right_value, operator.truediv, 1) + return result + + def literal(self, node: LiteralNode) -> Dict[TimeScenarioIndex, float]: + result = {} + for time in self.time_scenario_indices.time_indices: + for scenario in self.time_scenario_indices.scenario_indices: + result[TimeScenarioIndex(time, scenario)] = node.value + return result + + def comparison(self, node: ComparisonNode) -> Dict[TimeScenarioIndex, float]: + raise ValueError("Cannot evaluate comparison operator.") + + def parameter(self, node: ParameterNode) -> Dict[TimeScenarioIndex, float]: return self.context.get_parameter_value(node.name, self.time_scenario_indices) - def comp_parameter(self, node: ComponentParameterNode) -> float: + def comp_parameter( + self, node: ComponentParameterNode + ) -> Dict[TimeScenarioIndex, float]: return self.context.get_component_parameter_value( node.component_id, node.name, self.time_scenario_indices ) - # def comp_variable(self, node: ComponentVariableNode) -> float: - # return self.context.get_component_variable_value(node.component_id, node.name) - - def time_operator(self, node: TimeOperatorNode) -> float: + def time_operator(self, node: TimeOperatorNode) -> Dict[TimeScenarioIndex, float]: self.time_scenario_indices.time_indices = get_time_ids_from_instances_index( node.instances_index, self.context ) @@ -114,30 +145,78 @@ def time_operator(self, node: TimeOperatorNode) -> float: return NotImplemented return visit(node.operand, self) - def time_aggregator(self, node: TimeAggregatorNode) -> float: - if node.name in [TimeAggregatorName.SUM]: - # TODO: Where is the sum ? - return visit(node.operand, self) + def time_aggregator( + self, node: TimeAggregatorNode + ) -> Dict[TimeScenarioIndex, float]: + if node.name == TimeAggregatorName.TIME_SUM: + if not isinstance(node.operand, TimeOperatorNode): + # Sum over all time block + self.time_scenario_indices.time_indices = list( + range(self.context.block_length()) + ) + # Time indices for the case where node.operand is a TimeOperatorNode are handled in time_operator function directly + operand_dict = visit(node.operand, self) + result = {} + for scenario in self.time_scenario_indices.scenario_indices: + result[TimeScenarioIndex(self.row_id.time, scenario)] = sum( + operand_dict[k] + for k in operand_dict.keys() + if k.scenario == scenario + ) + return result else: return NotImplemented - def scenario_operator(self, node: ScenarioOperatorNode) -> float: - if node.name in [ScenarioOperatorName.EXPECTATION]: - return visit(node.operand, self) + def scenario_operator( + self, node: ScenarioOperatorNode + ) -> Dict[TimeScenarioIndex, float]: + if node.name == ScenarioOperatorName.EXPECTATION: + self.time_scenario_indices.scenario_indices = list( + range(self.context.scenarios()) + ) + operand_dict = visit(node.operand, self) + result = {} + for time in self.time_scenario_indices.time_indices: + # TODO: Make this more general to consider weighted expectations + result[TimeScenarioIndex(time, self.row_id.scenario)] = ( + 1 + / self.context.scenarios() + * sum( + operand_dict[k] for k in operand_dict.keys() if k.time == time + ) + ) + return result + else: return NotImplemented - def port_field(self, node: PortFieldNode) -> float: + def port_field(self, node: PortFieldNode) -> Dict[TimeScenarioIndex, float]: raise ValueError("Port fields must be resolved before evaluating parameters") - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> float: + def port_field_aggregator( + self, node: PortFieldAggregatorNode + ) -> Dict[TimeScenarioIndex, float]: raise ValueError("Port fields must be resolved before evaluating parameters") +def is_valid_resolved_expr( + resolved_expr: Dict[TimeScenarioIndex, float], row_id: RowIndex +) -> bool: + # Check that the resolved expression has been correctly time and scenario aggregated so that only a float is left + return ( + len(resolved_expr) == 1 + and TimeScenarioIndex(row_id.time, row_id.scenario) in resolved_expr + ) + + def resolve_coefficient( expression: ExpressionNodeEfficient, value_provider: ValueProvider, row_id: RowIndex ) -> float: - return visit(expression, ParameterEvaluationVisitor(value_provider, row_id)) + result = visit(expression, ParameterEvaluationVisitor(value_provider, row_id)) + if is_valid_resolved_expr(result, row_id): + return result[TimeScenarioIndex(row_id.time, row_id.scenario)] + else: + raise ValueError("Evaluation of expression cannot be reduced to a float value") @dataclass(frozen=True) diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index 7d516467..575e4494 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -549,7 +549,7 @@ def __post_init__(self) -> None: @dataclass(frozen=True, eq=False) class TimeAggregatorNode(UnaryOperatorNode): name: TimeAggregatorName - stay_roll: bool + stay_roll: bool # TODO: Is it still useful ? def __post_init__(self) -> None: if not isinstance(self.name, TimeAggregatorName): diff --git a/src/andromede/expression/indexing_structure.py b/src/andromede/expression/indexing_structure.py index fa733e79..97184b00 100644 --- a/src/andromede/expression/indexing_structure.py +++ b/src/andromede/expression/indexing_structure.py @@ -37,3 +37,6 @@ class RowIndex: time: int scenario: int + + def __str__(self) -> str: + return f"t{self.time}_s{self.scenario}" diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 915e6795..308c3d65 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -31,7 +31,7 @@ from andromede.expression.context_adder import add_component_context from andromede.expression.equality import expressions_equal from andromede.expression.evaluate import evaluate -from andromede.expression.evaluate_parameters_efficient import ValueProvider, resolve_coefficient +from andromede.expression.evaluate_parameters_efficient import resolve_coefficient from andromede.expression.expression_efficient import ( ExpressionNodeEfficient, ExpressionRange, @@ -66,6 +66,7 @@ ResolvedLinearExpression, ResolvedTerm, ) +from andromede.expression.value_provider import ValueProvider @dataclass(frozen=True) diff --git a/src/andromede/expression/resolved_linear_expression.py b/src/andromede/expression/resolved_linear_expression.py index d3ed9fac..acc82d7d 100644 --- a/src/andromede/expression/resolved_linear_expression.py +++ b/src/andromede/expression/resolved_linear_expression.py @@ -39,3 +39,7 @@ class ResolvedLinearExpression: terms: List[ResolvedTerm] = field(default_factory=list) constant: float = field(default=0) + + def is_constant(self) -> bool: + # Constant expr like x-x could be seen as non constant as we do not simplify coefficient tree... + return not self.terms diff --git a/src/andromede/expression/value_provider.py b/src/andromede/expression/value_provider.py new file mode 100644 index 00000000..cde69aab --- /dev/null +++ b/src/andromede/expression/value_provider.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, List + + +@dataclass +class TimeScenarioIndices: + time_indices: List[int] + scenario_indices: List[int] + +# TODO: Already define in study module, factorize this +@dataclass(frozen=True) +class TimeScenarioIndex: + time: int + scenario: int + +# Given a list of time_indices and of scenario_indices, the value provider will get the parameter value for all couple (time, scenario) for time in time_indices and scenario in scenario_indices +class ValueProvider(ABC): + """ + Implementations are in charge of mapping parameters and variables to their values. + Depending on the implementation, evaluation may require a component id or not. + """ + + @abstractmethod + def get_variable_value( + self, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: ... + + # Need to have time_scenarios_indices as function argument as we do not want to create a Provider each time we have to get the value of a parameter at a different (time, scenario) index + @abstractmethod + def get_parameter_value( + self, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: ... + + @abstractmethod + def get_component_variable_value( + self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: ... + + @abstractmethod + def get_component_parameter_value( + self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: ... + + # TODO: Should this really be an abstract method ? Or maybe, only the Provider in _make_value_provider should implement it. And the context attribute in the InstancesIndexVisitor is a ValueProvider that implements the parameter_is_constant_over_time method. Maybe create a child class of ValueProvider like TimeValueProvider ? + @abstractmethod + def parameter_is_constant_over_time(self, name: str) -> bool: ... + + # There is probably a better place to put this.. + # Which is useful when evaluating the TimeSum operator over the whole block, to know which time steps to look for to get the parameter values + @staticmethod + @abstractmethod + def block_length() -> int: ... + + @staticmethod + @abstractmethod + def scenarios() -> int: ... diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index f17bfa60..11e2dbe1 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -34,14 +34,15 @@ PortFieldKey, RowIndex, ) +from andromede.expression.resolved_linear_expression import ResolvedLinearExpression from andromede.expression.scenario_operator import Expectation from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum +from andromede.expression.value_provider import TimeScenarioIndex, TimeScenarioIndices from andromede.model.common import ValueType from andromede.model.constraint import Constraint from andromede.model.model import PortFieldId from andromede.simulation.linear_expression import LinearExpression, Term from andromede.simulation.linearize import linearize_expression -from andromede.expression.resolved_linear_expression import ResolvedLinearExpression from andromede.simulation.strategy import MergedProblemStrategy, ModelSelectionStrategy from andromede.simulation.time_block import TimeBlock from andromede.study.data import DataBase @@ -95,34 +96,46 @@ def _make_value_provider( """ class Provider(ValueProvider): - # def get_component_variable_value(self, component_id: str, name: str) -> float: - # raise NotImplementedError( - # "Cannot provide variable value at problem build time." - # ) + def get_component_variable_value( + self, + component_id: str, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + raise NotImplementedError( + "Cannot provide variable value at problem build time." + ) def get_component_parameter_value( self, component_id: str, name: str, - time_ids: List[int], - scenario_ids: List[int], - ) -> List[float]: - return [ - _get_parameter_value( - context, block_timestep, scenario, component_id, name - ) - for block_timestep in time_ids - for scenario in scenario_ids - ] + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + result = {} + for block_timestep in time_scenarios_indices.time_indices: + for scenario in time_scenarios_indices.scenario_indices: + result[TimeScenarioIndex(block_timestep, scenario)] = ( + _get_parameter_value( + context, block_timestep, scenario, component_id, name + ) + ) + return result - # def get_variable_value(self, name: str) -> float: - # raise NotImplementedError( - # "Cannot provide variable value at problem build time." - # ) + def get_variable_value( + self, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + raise NotImplementedError( + "Cannot provide variable value at problem build time." + ) def get_parameter_value( - self, name: str, time_ids: List[int], scenario_ids: List[int] - ) -> List[float]: + self, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: raise ValueError( "Parameter must be associated to its component before resolution." ) @@ -130,6 +143,14 @@ def get_parameter_value( def parameter_is_constant_over_time(self, name: str) -> bool: return not component.model.parameters[name].structure.time + @staticmethod + def block_length() -> int: + return context.block_length() + + @staticmethod + def scenarios() -> int: + return context.scenarios + return Provider() @@ -463,7 +484,7 @@ def _create_constraint( # linear_expr = context.linearize_expression(0, 0, constraint.expression) # # Will there be cases where instances > 1 ? If not, maybe just a check that get_number_of_instances == 1 is sufficient ? Anyway, the function should be implemented # instances_per_time_step = linear_expr.number_of_instances() - instances_per_time_step = 1 + # instances_per_time_step = 1 value_provider = _make_value_provider(context.opt_context, context.component) @@ -474,29 +495,24 @@ def _create_constraint( # ) row_id = RowIndex(block_timestep, scenario) - resolved_expr = constraint.expression.resolve_coefficient(value_provider, row_id) - resolved_lb = constraint.lower_bound.resolve_coefficient(value_provider, row_id) - resolved_ub = constraint.upper_bound.resolve_coefficient(value_provider, row_id) + resolved_expr = constraint.expression.resolve_coefficient( + value_provider, row_id + ) + resolved_lb = constraint.lower_bound.resolve_coefficient( + value_provider, row_id + ) + resolved_ub = constraint.upper_bound.resolve_coefficient( + value_provider, row_id + ) # What happens if there is some time_operator in the bounds ? -> Pb réglé avec le nouveau design ! constraint_data = ConstraintData( name=constraint.name, - lower_bound=context.get_values(constraint.lower_bound).get_value( - block_timestep, scenario - ), - upper_bound=context.get_values(constraint.upper_bound).get_value( - block_timestep, scenario - ), - expression=linear_expr_at_t, - ) - make_constraint( - solver, - context.opt_context, - block_timestep, - scenario, - constraint_data, - instances_per_time_step, + lower_bound=resolved_lb, + upper_bound=resolved_ub, + expression=resolved_expr, ) + make_constraint(solver, row_id, constraint_data) def _create_objective( @@ -630,57 +646,47 @@ def _get_solver_vars( def make_constraint( solver: lp.Solver, - context: OptimizationContext, - block_timestep: int, - scenario: int, + row_id: RowIndex, data: ConstraintData, - instances: int, -) -> Dict[str, lp.Constraint]: +) -> lp.Constraint: """ Adds constraint to the solver. """ - solver_constraints = {} - constraint_name = f"{data.name}_t{block_timestep}_s{scenario}" - # TODO : Check if instance can be removed - for instance in range(instances): - if instances > 1: - constraint_name += f"_{instance}" + constraint_name = f"{data.name}_{str(row_id)}" - solver_constraint: lp.Constraint = solver.Constraint(constraint_name) - constant: float = 0 + solver_constraint: lp.Constraint = solver.Constraint(constraint_name) + constant: float = 0 - for term in data.expression.terms: - solver_constraint.SetCoefficient( - term.variable, - term.coefficient + solver_constraint.GetCoefficient(term.variable), - ) - - # TODO : To be done in linear expression resolution coeff - # for term in data.expression.terms.values(): - # # Move this to resolve coefficient - # solver_vars = _get_solver_vars( - # term, - # context, - # block_timestep, - # scenario, - # instance, - # ) - # for solver_var in solver_vars: - # coefficient = term.coefficient + solver_constraint.GetCoefficient( - # solver_var - # ) - # solver_constraint.SetCoefficient(solver_var, coefficient) - # TODO: On pourrait aussi faire que l'objet Constraint n'ait pas de terme constant dans son expression et que les constantes soit déjà prises en compte dans les bornes, ça simplifierait le traitement ici - constant += data.expression.constant - - solver_constraint.SetBounds( - data.lower_bound - constant, data.upper_bound - constant + for term in data.expression.terms: + solver_constraint.SetCoefficient( + term.variable, + term.coefficient + solver_constraint.GetCoefficient(term.variable), ) - # TODO: this dictionary does not make sense, we override the content when there are multiple instances - solver_constraints[constraint_name] = solver_constraint - return solver_constraints + # TODO : To be done in linear expression resolution coeff + # for term in data.expression.terms.values(): + # # Move this to resolve coefficient + # solver_vars = _get_solver_vars( + # term, + # context, + # block_timestep, + # scenario, + # instance, + # ) + # for solver_var in solver_vars: + # coefficient = term.coefficient + solver_constraint.GetCoefficient( + # solver_var + # ) + # solver_constraint.SetCoefficient(solver_var, coefficient) + # TODO: On pourrait aussi faire que l'objet Constraint n'ait pas de terme constant dans son expression et que les constantes soit déjà prises en compte dans les bornes, ça simplifierait le traitement ici + constant += data.expression.constant + + solver_constraint.SetBounds( + data.lower_bound - constant, data.upper_bound - constant + ) + + return solver_constraint class OptimizationProblem: From 80073067e673f86c577377b50c9ee37ea6fb1486 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Wed, 14 Aug 2024 18:58:30 +0200 Subject: [PATCH 24/51] Be able to create variables --- src/andromede/simulation/optimization.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 11e2dbe1..4e13a3db 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -745,6 +745,8 @@ def _create_variables(self) -> None: component_context = self.context.get_component_context(component) model = component.model + value_provider = _make_value_provider(self.context, component) + for model_var in self.strategy.get_variables(model): var_indexing = IndexingStructure( model_var.structure.time, model_var.structure.scenario @@ -764,13 +766,22 @@ def _create_variables(self) -> None: lower_bound = -self.solver.infinity() upper_bound = self.solver.infinity() if instantiated_lb_expr: - lower_bound = component_context.get_values( - instantiated_lb_expr - ).get_value(block_timestep, scenario) + if instantiated_lb_expr.is_constant(): + # TODO: Improve API + lower_bound = instantiated_lb_expr.resolve_coefficient( + value_provider, RowIndex(block_timestep, scenario) + ).constant + # lower_bound = component_context.get_values( + # instantiated_lb_expr + # ).get_value(block_timestep, scenario) if instantiated_ub_expr: - upper_bound = component_context.get_values( - instantiated_ub_expr - ).get_value(block_timestep, scenario) + if instantiated_ub_expr.is_constant(): + upper_bound = instantiated_ub_expr.resolve_coefficient( + value_provider, RowIndex(block_timestep, scenario) + ).constant + # upper_bound = component_context.get_values( + # instantiated_ub_expr + # ).get_value(block_timestep, scenario) # TODO: Add BoolVar or IntVar if the variable is specified to be integer or bool # Externally, for the Solver, this variable will have a full name From 1e1ac18e203102dedaaa088af7f38dde93e0c05d Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Fri, 16 Aug 2024 18:21:13 +0200 Subject: [PATCH 25/51] Test resolve coefficients, update test evaluation context --- src/andromede/expression/evaluate.py | 36 +- .../evaluate_parameters_efficient.py | 58 ++-- .../expression/linear_expression_efficient.py | 49 ++- src/andromede/expression/value_provider.py | 3 +- src/andromede/simulation/optimization.py | 19 +- .../expressions/test_expressions_efficient.py | 41 ++- .../expressions/test_resolve_coefficients.py | 314 ++++++++++++++++++ 7 files changed, 467 insertions(+), 53 deletions(-) create mode 100644 tests/unittests/expressions/test_resolve_coefficients.py diff --git a/src/andromede/expression/evaluate.py b/src/andromede/expression/evaluate.py index bcb66a4d..a56f912d 100644 --- a/src/andromede/expression/evaluate.py +++ b/src/andromede/expression/evaluate.py @@ -26,11 +26,15 @@ TimeAggregatorNode, TimeOperatorNode, ) -from andromede.expression.value_provider import ValueProvider +from andromede.expression.value_provider import ( + TimeScenarioIndex, + TimeScenarioIndices, + ValueProvider, +) from .visitor import ExpressionVisitorOperations, visit - +# Used only for tests @dataclass(frozen=True) class EvaluationContext(ValueProvider): """ @@ -41,21 +45,37 @@ class EvaluationContext(ValueProvider): variables: Dict[str, float] = field(default_factory=dict) parameters: Dict[str, float] = field(default_factory=dict) - def get_variable_value(self, name: str) -> float: - return self.variables[name] + def get_variable_value( + self, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: + return {TimeScenarioIndex(0, 0): self.variables[name]} - def get_parameter_value(self, name: str) -> float: - return self.parameters[name] + def get_parameter_value( + self, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: + return {TimeScenarioIndex(0, 0): self.parameters[name]} - def get_component_variable_value(self, component_id: str, name: str) -> float: + def get_component_variable_value( + self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: raise NotImplementedError() - def get_component_parameter_value(self, component_id: str, name: str) -> float: + def get_component_parameter_value( + self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: raise NotImplementedError() def parameter_is_constant_over_time(self, name: str) -> bool: raise NotImplementedError() + @staticmethod + def block_length() -> int: + raise NotImplementedError() + + @staticmethod + def scenarios() -> int: + raise NotImplementedError() + @dataclass(frozen=True) class EvaluationVisitor(ExpressionVisitorOperations[float]): diff --git a/src/andromede/expression/evaluate_parameters_efficient.py b/src/andromede/expression/evaluate_parameters_efficient.py index 053ec33e..c3eb5402 100644 --- a/src/andromede/expression/evaluate_parameters_efficient.py +++ b/src/andromede/expression/evaluate_parameters_efficient.py @@ -134,7 +134,7 @@ def comp_parameter( def time_operator(self, node: TimeOperatorNode) -> Dict[TimeScenarioIndex, float]: self.time_scenario_indices.time_indices = get_time_ids_from_instances_index( - node.instances_index, self.context + node.instances_index, self.context, self.row_id ) if node.name == TimeOperatorName.SHIFT: self.time_scenario_indices.time_indices = [ @@ -199,24 +199,22 @@ def port_field_aggregator( raise ValueError("Port fields must be resolved before evaluating parameters") -def is_valid_resolved_expr( +def check_resolved_expr( resolved_expr: Dict[TimeScenarioIndex, float], row_id: RowIndex ) -> bool: # Check that the resolved expression has been correctly time and scenario aggregated so that only a float is left - return ( - len(resolved_expr) == 1 - and TimeScenarioIndex(row_id.time, row_id.scenario) in resolved_expr - ) + if len(resolved_expr) != 1: + raise ValueError("Evaluation of expression cannot be reduced to a float value") + if TimeScenarioIndex(row_id.time, row_id.scenario) not in resolved_expr: + raise ValueError("Expression has a time operator but not time aggregator, maybe you are missing a sum(), necessary even on one element") def resolve_coefficient( expression: ExpressionNodeEfficient, value_provider: ValueProvider, row_id: RowIndex ) -> float: result = visit(expression, ParameterEvaluationVisitor(value_provider, row_id)) - if is_valid_resolved_expr(result, row_id): - return result[TimeScenarioIndex(row_id.time, row_id.scenario)] - else: - raise ValueError("Evaluation of expression cannot be reduced to a float value") + check_resolved_expr(result, row_id) + return result[TimeScenarioIndex(row_id.time, row_id.scenario)] @dataclass(frozen=True) @@ -228,12 +226,25 @@ class InstancesIndexVisitor(ParameterEvaluationVisitor): # def variable(self, node: VariableNode) -> float: # raise ValueError("An instance index expression cannot contain variable") + # Probably useless as parameter nodes should have already be replaced by component parameter nodes ? def parameter(self, node: ParameterNode) -> float: if not self.context.parameter_is_constant_over_time(node.name): raise ValueError( "Parameter given in an instance index expression must be constant over time" ) - return self.context.get_parameter_value(node.name) + + return self.context.get_parameter_value(node.name, self.time_scenario_indices) + + def comp_parameter( + self, node: ComponentParameterNode + ) -> Dict[TimeScenarioIndex, float]: + if not self.context.parameter_is_constant_over_time(node.name): + raise ValueError( + "Parameter given in an instance index expression must be constant over time" + ) + return self.context.get_component_parameter_value( + node.component_id, node.name, self.time_scenario_indices + ) def time_operator(self, node: TimeOperatorNode) -> float: raise ValueError("An instance index expression cannot contain time operator") @@ -250,30 +261,39 @@ def float_to_int(value: float) -> int: def evaluate_time_id( - expr: ExpressionNodeEfficient, value_provider: ValueProvider + expr: ExpressionNodeEfficient, value_provider: ValueProvider, row_id: RowIndex ) -> int: - float_time_id = visit(expr, InstancesIndexVisitor(value_provider)) + float_time_id_in_list = visit(expr, InstancesIndexVisitor(value_provider, row_id)) + check_resolved_expr(float_time_id_in_list, row_id) try: - time_id = float_to_int(float_time_id) + time_id = float_to_int( + float_time_id_in_list[TimeScenarioIndex(row_id.time, row_id.scenario)] + ) except ValueError: print(f"{expr} does not represent an integer time index.") return time_id def get_time_ids_from_instances_index( - instances_index: InstancesTimeIndex, value_provider: ValueProvider + instances_index: InstancesTimeIndex, value_provider: ValueProvider, row_id: RowIndex ) -> List[int]: time_ids = [] if isinstance(instances_index.expressions, list): # List[ExpressionNode] for expr in instances_index.expressions: - time_ids.append(evaluate_time_id(expr, value_provider)) + time_ids.append(evaluate_time_id(expr, value_provider, row_id)) elif isinstance(instances_index.expressions, ExpressionRange): # ExpressionRange - start_id = evaluate_time_id(instances_index.expressions.start, value_provider) - stop_id = evaluate_time_id(instances_index.expressions.stop, value_provider) + start_id = evaluate_time_id( + instances_index.expressions.start, value_provider, row_id + ) + stop_id = evaluate_time_id( + instances_index.expressions.stop, value_provider, row_id + ) step_id = 1 if instances_index.expressions.step is not None: - step_id = evaluate_time_id(instances_index.expressions.step, value_provider) + step_id = evaluate_time_id( + instances_index.expressions.step, value_provider, row_id + ) # ExpressionRange includes stop_id whereas range excludes it time_ids = list(range(start_id, stop_id + 1, step_id)) diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 308c3d65..9c7634f1 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -31,7 +31,10 @@ from andromede.expression.context_adder import add_component_context from andromede.expression.equality import expressions_equal from andromede.expression.evaluate import evaluate -from andromede.expression.evaluate_parameters_efficient import resolve_coefficient +from andromede.expression.evaluate_parameters_efficient import ( + check_resolved_expr, + resolve_coefficient, +) from andromede.expression.expression_efficient import ( ExpressionNodeEfficient, ExpressionRange, @@ -54,6 +57,10 @@ from andromede.expression.indexing_structure import IndexingStructure, RowIndex from andromede.expression.port_operator import PortAggregator, PortSum from andromede.expression.print import print_expr +from andromede.expression.resolved_linear_expression import ( + ResolvedLinearExpression, + ResolvedTerm, +) from andromede.expression.scenario_operator import Expectation, ScenarioOperator from andromede.expression.time_operator import ( TimeAggregator, @@ -62,11 +69,11 @@ TimeShift, TimeSum, ) -from andromede.expression.resolved_linear_expression import ( - ResolvedLinearExpression, - ResolvedTerm, +from andromede.expression.value_provider import ( + TimeScenarioIndex, + TimeScenarioIndices, + ValueProvider, ) -from andromede.expression.value_provider import ValueProvider @dataclass(frozen=True) @@ -141,16 +148,29 @@ def __str__(self) -> str: result += f".{str(self.scenario_operator)}" return result - def evaluate(self, context: ValueProvider) -> float: + def evaluate(self, context: ValueProvider, time_scenario_index: RowIndex) -> float: # TODO: Take care of component variables, multiple time scenarios, operators, etc + time_scenario_indices = TimeScenarioIndices( + [time_scenario_index.time], [time_scenario_index.scenario] + ) # Probably very error prone if self.component_id: variable_value = context.get_component_variable_value( - self.component_id, self.variable_name + self.component_id, self.variable_name, time_scenario_indices ) else: - variable_value = context.get_variable_value(self.variable_name) - return evaluate(self.coefficient, context) * variable_value + variable_value = context.get_variable_value( + self.variable_name, time_scenario_indices + ) + check_resolved_expr(variable_value, time_scenario_index) + return ( + resolve_coefficient(self.coefficient, context, time_scenario_index) + * variable_value[ + TimeScenarioIndex( + time_scenario_index.time, time_scenario_index.scenario + ) + ] + ) def compute_indexation( self, provider: IndexingStructureProvider @@ -693,10 +713,13 @@ def remove_zeros_from_terms(self) -> None: if is_zero(port_term.coefficient): del self.port_field_terms[port_term_key] - def evaluate(self, context: ValueProvider) -> float: - return sum([term.evaluate(context) for term in self.terms.values()]) + evaluate( - self.constant, context - ) + def evaluate(self, context: ValueProvider, time_scenario_index: RowIndex) -> float: + return sum( + [ + term.evaluate(context, time_scenario_index) + for term in self.terms.values() + ] + ) + resolve_coefficient(self.constant, context, time_scenario_index) def is_constant(self) -> bool: # Constant expr like x-x could be seen as non constant as we do not simplify coefficient tree... diff --git a/src/andromede/expression/value_provider.py b/src/andromede/expression/value_provider.py index cde69aab..e555c4ed 100644 --- a/src/andromede/expression/value_provider.py +++ b/src/andromede/expression/value_provider.py @@ -32,7 +32,7 @@ class ValueProvider(ABC): Implementations are in charge of mapping parameters and variables to their values. Depending on the implementation, evaluation may require a component id or not. """ - + #TODO: To be removed, or should we keep it to evaluate solutions ? @abstractmethod def get_variable_value( self, name: str, time_scenarios_indices: TimeScenarioIndices @@ -44,6 +44,7 @@ def get_parameter_value( self, name: str, time_scenarios_indices: TimeScenarioIndices ) -> Dict[TimeScenarioIndex, float]: ... + #TODO: To be removed ? @abstractmethod def get_component_variable_value( self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index c4a63522..d44977f3 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -85,6 +85,14 @@ def get_value(self, block_timestep: int, scenario: int) -> float: raise NotImplementedError() +def _get_data_time_key(block_timestep: int, data_indexing: IndexingStructure) -> int: + return block_timestep if data_indexing.time else 0 + + +def _get_data_scenario_key(scenario: int, data_indexing: IndexingStructure) -> int: + return scenario if data_indexing.scenario else 0 + + def _make_value_provider( context: "OptimizationContext", component: Component, @@ -114,11 +122,20 @@ def get_component_parameter_value( time_scenarios_indices: TimeScenarioIndices, ) -> Dict[TimeScenarioIndex, float]: result = {} + param_index = ( + context.network.get_component(component_id) + .model.parameters[name] + .structure + ) for block_timestep in time_scenarios_indices.time_indices: for scenario in time_scenarios_indices.scenario_indices: result[TimeScenarioIndex(block_timestep, scenario)] = ( _get_parameter_value( - context, block_timestep, scenario, component_id, name + context, + _get_data_time_key(block_timestep, param_index), + _get_data_scenario_key(scenario, param_index), + component_id, + name, ) ) return result diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index a1304f61..75fba9fe 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -34,7 +34,7 @@ param, ) from andromede.expression.indexing import IndexingStructureProvider -from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.indexing_structure import IndexingStructure, RowIndex from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, StandaloneConstraint, @@ -47,6 +47,7 @@ wrap_in_linear_expr, ) from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum +from andromede.expression.value_provider import TimeScenarioIndex, TimeScenarioIndices @dataclass(frozen=True) @@ -69,21 +70,37 @@ class ComponentEvaluationContext(ValueProvider): variables: Dict[ComponentValueKey, float] = field(default_factory=dict) parameters: Dict[ComponentValueKey, float] = field(default_factory=dict) - def get_variable_value(self, name: str) -> float: + def get_variable_value( + self, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: raise NotImplementedError() - def get_parameter_value(self, name: str) -> float: + def get_parameter_value( + self, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: raise NotImplementedError() - def get_component_variable_value(self, component_id: str, name: str) -> float: - return self.variables[comp_key(component_id, name)] + def get_component_variable_value( + self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: + return {TimeScenarioIndex(0, 0): self.variables[comp_key(component_id, name)]} - def get_component_parameter_value(self, component_id: str, name: str) -> float: - return self.parameters[comp_key(component_id, name)] + def get_component_parameter_value( + self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices + ) -> Dict[TimeScenarioIndex, float]: + return {TimeScenarioIndex(0, 0): self.parameters[comp_key(component_id, name)]} def parameter_is_constant_over_time(self, name: str) -> bool: raise NotImplementedError() + @staticmethod + def block_length() -> int: + raise NotImplementedError() + + @staticmethod + def scenarios() -> int: + raise NotImplementedError() + # TODO: Redundant with add tests in test_linear_expressions_efficient ? def test_comp_parameter() -> None: @@ -98,7 +115,8 @@ def test_comp_parameter() -> None: context = ComponentEvaluationContext( variables={comp_key("comp1", "x"): 3}, parameters={comp_key("comp1", "p"): 4} ) - assert expr2.evaluate(context) == 1 + # Need to specify at which (t, w) to evaluate as the information is not contained anymore within the value provider + assert expr2.evaluate(context, RowIndex(0, 0)) == 1 # TODO: Find a better name @@ -111,7 +129,7 @@ def test_ast() -> None: assert str(expr2) == "(1.0 / p)x + (1.0 / p)" context = EvaluationContext(variables={"x": 3}, parameters={"p": 4}) - assert expr2.evaluate(context) == 1 + assert expr2.evaluate(context, RowIndex(0, 0)) == 1 def test_operators() -> None: @@ -122,9 +140,9 @@ def test_operators() -> None: assert str(expr) == "(5.0 / p)x + ((3.0 / p) - 2.0)" context = EvaluationContext(variables={"x": 3}, parameters={"p": 4}) - assert expr.evaluate(context) == pytest.approx(2.5, 1e-16) + assert expr.evaluate(context, RowIndex(0, 0)) == pytest.approx(2.5, 1e-16) - assert -expr.evaluate(context) == pytest.approx(-2.5, 1e-16) + assert -expr.evaluate(context, RowIndex(0, 0)) == pytest.approx(-2.5, 1e-16) # def test_degree() -> None: @@ -568,6 +586,7 @@ def test_eval_on_time_step_list_raises_value_error() -> None: _ = x.eval(ExpressionRange(1, 4)) +# TODO: Shoudl be moved to test_linear_expression_efficient @pytest.mark.parametrize( "linear_expr, expected_indexation", [ diff --git a/tests/unittests/expressions/test_resolve_coefficients.py b/tests/unittests/expressions/test_resolve_coefficients.py new file mode 100644 index 00000000..914703a7 --- /dev/null +++ b/tests/unittests/expressions/test_resolve_coefficients.py @@ -0,0 +1,314 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +import math +import re +from typing import Dict + +import pytest + +from andromede.expression.evaluate_parameters_efficient import resolve_coefficient +from andromede.expression.expression_efficient import ( + Comparator, + ComparisonNode, + ExpressionNodeEfficient, + ExpressionRange, + InstancesTimeIndex, + LiteralNode, + PortFieldAggregatorName, + PortFieldAggregatorNode, + PortFieldNode, + TimeOperatorName, + TimeOperatorNode, + comp_param, + param, +) +from andromede.expression.indexing_structure import IndexingStructure, RowIndex +from andromede.expression.value_provider import ( + TimeScenarioIndex, + TimeScenarioIndices, + ValueProvider, +) + +p_values = { + TimeScenarioIndex(0, 0): 1.0, + TimeScenarioIndex(1, 0): 2.0, + TimeScenarioIndex(2, 0): 3.0, + TimeScenarioIndex(3, 0): 7.0, + TimeScenarioIndex(0, 1): 4.0, + TimeScenarioIndex(1, 1): 5.0, + TimeScenarioIndex(2, 1): 6.0, + TimeScenarioIndex(3, 1): 8.0, +} + +# A time constant parameter that can be put as TimeShift arg +comp_q_values = { + TimeScenarioIndex(0, 0): 2.0, + TimeScenarioIndex(0, 1): 1.0, +} + + +def _get_data_time_key(block_timestep: int, data_indexing: IndexingStructure) -> int: + return block_timestep if data_indexing.time else 0 + + +def _get_data_scenario_key(scenario: int, data_indexing: IndexingStructure) -> int: + return scenario if data_indexing.scenario else 0 + + +class CustomValueProvider(ValueProvider): + def get_component_variable_value( + self, + component_id: str, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + raise NotImplementedError( + "Cannot provide variable value at problem build time." + ) + + def get_component_parameter_value( + self, + component_id: str, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + result = {} + param_indexing = IndexingStructure(False, True) + for block_timestep in time_scenarios_indices.time_indices: + for scenario in time_scenarios_indices.scenario_indices: + result[TimeScenarioIndex(block_timestep, scenario)] = comp_q_values[ + TimeScenarioIndex( + _get_data_time_key(block_timestep, param_indexing), + _get_data_scenario_key(scenario, param_indexing), + ) + ] + return result + + def get_variable_value( + self, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + raise NotImplementedError( + "Cannot provide variable value at problem build time." + ) + + def get_parameter_value( + self, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + result = {} + param_indexing = IndexingStructure(True, True) + for block_timestep in time_scenarios_indices.time_indices: + for scenario in time_scenarios_indices.scenario_indices: + result[TimeScenarioIndex(block_timestep, scenario)] = p_values[ + TimeScenarioIndex( + _get_data_time_key(block_timestep, param_indexing), + _get_data_scenario_key(scenario, param_indexing), + ) + ] + return result + + def parameter_is_constant_over_time(self, name: str) -> bool: + return True + + @staticmethod + def block_length() -> int: + return 4 + + @staticmethod + def scenarios() -> int: + return 2 + + +@pytest.fixture +def provider() -> CustomValueProvider: + return CustomValueProvider() + + +@pytest.mark.parametrize( + "port_node", + [ + (PortFieldNode("port", "field")), + ( + PortFieldAggregatorNode( + PortFieldNode("port", "field"), PortFieldAggregatorName.PORT_SUM + ) + ), + ], +) +def test_resolve_coefficient_raises_value_error_on_port_field_node( + port_node: ExpressionNodeEfficient, provider: CustomValueProvider +) -> None: + with pytest.raises( + ValueError, match="Port fields must be resolved before evaluating parameters" + ): + resolve_coefficient(port_node, provider, RowIndex(0, 0)) + + +def test_resolve_coefficient_raises_value_error_on_comparison_node( + provider: CustomValueProvider, +) -> None: + expr = ComparisonNode(LiteralNode(0), param("p"), Comparator.EQUAL) + with pytest.raises(ValueError, match="Cannot evaluate comparison operator."): + resolve_coefficient(expr, provider, RowIndex(0, 0)) + + +@pytest.mark.parametrize( + "expr", + [ + ( + TimeOperatorNode( + param("p"), + TimeOperatorName.SHIFT, + InstancesTimeIndex([LiteralNode(1), LiteralNode(2)]), + ) + ), + ( + TimeOperatorNode( + param("p"), + TimeOperatorName.EVALUATION, + InstancesTimeIndex([LiteralNode(1), LiteralNode(2)]), + ) + ), + ], +) +def test_resolve_coefficient_raises_value_error_on_expressions_that_are_not_aggregated_on_a_single_time_and_scenario( + expr: ExpressionNodeEfficient, provider: CustomValueProvider +) -> None: + with pytest.raises( + ValueError, match="Evaluation of expression cannot be reduced to a float value" + ): + resolve_coefficient(expr, provider, RowIndex(0, 0)) + + +@pytest.mark.parametrize( + "expr", + [ + (param("p").shift(2)), + (param("p").eval(2)), + param("p").shift(comp_param("c", "q")), + ], +) +def test_resolve_coefficient_on_expression_with_shift_but_without_sum_raises_value_error( + expr: ExpressionNodeEfficient, + provider: CustomValueProvider, +) -> None: + with pytest.raises( + ValueError, + match=re.escape( + "Expression has a time operator but not time aggregator, maybe you are missing a sum(), necessary even on one element" + ), + ): + resolve_coefficient(expr, provider, RowIndex(0, 0)) + + +@pytest.mark.parametrize( + "expr", + [ + ( + TimeOperatorNode( + param("p"), + TimeOperatorName.EVALUATION, + InstancesTimeIndex([comp_param("c", "q")]), + ) + ), + ( + TimeOperatorNode( + param("p"), + TimeOperatorName.SHIFT, + InstancesTimeIndex([param("q")]), + ) + ), + ], +) +def test_resolve_coefficient_with_no_time_varying_parameter_in_time_operator_argument_raises_value_error( + expr: ExpressionNodeEfficient, +) -> None: + + class TimeVaryingParameterValueProvider(CustomValueProvider): + def parameter_is_constant_over_time(self, name: str) -> bool: + return False + + provider = TimeVaryingParameterValueProvider() + + with pytest.raises( + ValueError, + match="Parameter given in an instance index expression must be constant over time", + ): + resolve_coefficient(expr, provider, RowIndex(0, 0)) + + +@pytest.mark.parametrize( + "expr, row_id, expected", + [ + (param("p"), RowIndex(0, 0), 1.0), + (comp_param("c", "q"), RowIndex(0, 0), 2.0), + (-comp_param("c", "q"), RowIndex(0, 0), -2.0), + (param("p") + comp_param("c", "q"), RowIndex(0, 0), 3.0), + (param("p") - comp_param("c", "q"), RowIndex(0, 0), -1.0), + (param("p") * LiteralNode(2), RowIndex(0, 0), 2.0), + (param("p") / LiteralNode(2), RowIndex(0, 0), 0.5), + ], +) +def test_resolve_coefficient_on_elementary_operations( + expr: ExpressionNodeEfficient, + row_id: RowIndex, + expected: float, + provider: CustomValueProvider, +) -> None: + assert math.isclose(resolve_coefficient(expr, provider, row_id), expected) + + +@pytest.mark.parametrize( + "expr, row_id, expected", + [ + (param("p").shift(2).sum(), RowIndex(0, 0), 3.0), + (param("p").shift(-1).sum(), RowIndex(2, 1), 5.0), + (param("p").eval(2).sum(), RowIndex(0, 0), 3.0), + (param("p").eval(2).sum(), RowIndex(2, 0), 3.0), + (param("p").shift(ExpressionRange(0, 3)).sum(), RowIndex(0, 0), 13.0), + (param("p").eval(ExpressionRange(1, 2)).sum(), RowIndex(0, 0), 5.0), + (param("p").eval(ExpressionRange(0, 3, 2)).sum(), RowIndex(0, 0), 4.0), + (param("p").shift(comp_param("c", "q")).sum(), RowIndex(1, 0), 7.0), + (param("p").shift(comp_param("c", "q")).sum(), RowIndex(1, 1), 6.0), + (param("p").sum(), RowIndex(0, 0), 13.0), + (param("p").sum(), RowIndex(2, 1), 23.0), + (comp_param("c", "q").sum(), RowIndex(0, 0), 2 * 4.0), + ], +) +def test_resolve_coefficient_on_time_shift_and_sum( + expr: ExpressionNodeEfficient, + row_id: RowIndex, + expected: float, + provider: CustomValueProvider, +) -> None: + assert math.isclose(resolve_coefficient(expr, provider, row_id), expected) + + +@pytest.mark.parametrize( + "expr, row_id, expected", + [ + (param("p").expec(), RowIndex(0, 0), 2.5), + (param("p").expec(), RowIndex(1, 1), 3.5), + (comp_param("c", "q").expec(), RowIndex(1, 1), 1.5), + ], +) +def test_resolve_coefficient_on_expectation( + expr: ExpressionNodeEfficient, + row_id: RowIndex, + expected: float, + provider: CustomValueProvider, +) -> None: + assert math.isclose(resolve_coefficient(expr, provider, row_id), expected) From 1be9d8496c1078cb06cc16d3b54c1829dd460d09 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Fri, 16 Aug 2024 19:01:22 +0200 Subject: [PATCH 26/51] Improve resolve coefficient API --- .../expression/indexing_structure.py | 23 ++++++++++--------- .../expression/linear_expression_efficient.py | 8 +++++++ src/andromede/simulation/optimization.py | 23 ++++++++----------- 3 files changed, 30 insertions(+), 24 deletions(-) diff --git a/src/andromede/expression/indexing_structure.py b/src/andromede/expression/indexing_structure.py index 28147377..2c81e708 100644 --- a/src/andromede/expression/indexing_structure.py +++ b/src/andromede/expression/indexing_structure.py @@ -27,6 +27,18 @@ def __or__(self, other: "IndexingStructure") -> "IndexingStructure": scenario = self.scenario or other.scenario return IndexingStructure(time, scenario) + def is_time_varying(self) -> bool: + return self.time + + def is_scenario_varying(self) -> bool: + return self.scenario + + def is_time_scenario_varying(self) -> bool: + return self.is_time_varying() and self.is_scenario_varying() + + def is_constant(self) -> bool: + return (not self.is_time_varying()) and (not self.is_scenario_varying()) + # Contrary to IndexingStructure, time and scenario are integers to "count/identify" the constraint whereas IndexingStructure is used used to know whether or not an expression is indexed by time or scenario. @dataclass(frozen=True) @@ -40,14 +52,3 @@ class RowIndex: def __str__(self) -> str: return f"t{self.time}_s{self.scenario}" - def is_time_varying(self) -> bool: - return self.time - - def is_scenario_varying(self) -> bool: - return self.scenario - - def is_time_scenario_varying(self) -> bool: - return self.is_time_varying() and self.is_scenario_varying() - - def is_constant(self) -> bool: - return (not self.is_time_varying()) and (not self.is_scenario_varying()) diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 9c7634f1..f568ba60 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -980,6 +980,14 @@ def resolve_coefficient( resolved_constant = resolve_coefficient(self.constant, value_provider, row_id) return ResolvedLinearExpression(resolved_terms, resolved_constant) + def resolve_constant_expr( + self, value_provider: ValueProvider, row_id: RowIndex + ) -> float: + if not self.is_constant(): + raise ValueError(f"{str(self)} is not a constant expression") + resolved_expr = self.resolve_coefficient(value_provider, row_id) + return resolved_expr.constant + def linear_expressions_equal( lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index d44977f3..a9db5226 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -516,10 +516,10 @@ def _create_constraint( resolved_expr = constraint.expression.resolve_coefficient( value_provider, row_id ) - resolved_lb = constraint.lower_bound.resolve_coefficient( + resolved_lb = constraint.lower_bound.resolve_constant_expr( value_provider, row_id ) - resolved_ub = constraint.upper_bound.resolve_coefficient( + resolved_ub = constraint.upper_bound.resolve_constant_expr( value_provider, row_id ) @@ -579,8 +579,8 @@ def _create_objective( @dataclass class ConstraintData: name: str - lower_bound: ResolvedLinearExpression # Or a float ? - upper_bound: ResolvedLinearExpression # Or a float ? + lower_bound: float + upper_bound: float expression: ResolvedLinearExpression @@ -794,19 +794,16 @@ def _create_variables(self) -> None: lower_bound = -self.solver.infinity() upper_bound = self.solver.infinity() if instantiated_lb_expr: - if instantiated_lb_expr.is_constant(): - # TODO: Improve API - lower_bound = instantiated_lb_expr.resolve_coefficient( - value_provider, RowIndex(block_timestep, scenario) - ).constant + lower_bound = instantiated_lb_expr.resolve_constant_expr( + value_provider, RowIndex(block_timestep, scenario) + ) # lower_bound = component_context.get_values( # instantiated_lb_expr # ).get_value(block_timestep, scenario) if instantiated_ub_expr: - if instantiated_ub_expr.is_constant(): - upper_bound = instantiated_ub_expr.resolve_coefficient( - value_provider, RowIndex(block_timestep, scenario) - ).constant + upper_bound = instantiated_ub_expr.resolve_constant_expr( + value_provider, RowIndex(block_timestep, scenario) + ) # upper_bound = component_context.get_values( # instantiated_ub_expr # ).get_value(block_timestep, scenario) From b532a4173ae0a4fd07e48560d73a1aded12ba610 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Mon, 19 Aug 2024 17:10:45 +0200 Subject: [PATCH 27/51] Start resolve variables, separate optimization context from optimization, rename scenario operator in scenario aggregator --- src/andromede/expression/__init__.py | 18 +- src/andromede/expression/evaluate.py | 1 + .../evaluate_parameters_efficient.py | 16 +- src/andromede/expression/expression.py | 4 +- .../expression/expression_efficient.py | 2 +- .../expression/linear_expression_efficient.py | 52 +- src/andromede/expression/scenario_operator.py | 6 +- src/andromede/expression/value_provider.py | 28 +- src/andromede/simulation/__init__.py | 3 +- .../simulation/benders_decomposed.py | 7 +- src/andromede/simulation/linear_expression.py | 26 +- .../simulation/linear_expression_resolver.py | 122 +++++ src/andromede/simulation/linearize.py | 2 +- src/andromede/simulation/optimization.py | 453 +----------------- .../simulation/optimization_context.py | 435 +++++++++++++++++ tests/functional/test_andromede.py | 7 +- tests/functional/test_andromede_yml.py | 8 +- .../functional/test_performance_efficient.py | 16 +- .../models/test_short_term_storage_complex.py | 3 +- .../expressions/test_expressions_efficient.py | 12 +- .../expressions/test_linear_expressions.py | 24 +- .../test_linear_expressions_efficient.py | 20 +- .../expressions/test_resolve_coefficients.py | 1 - .../expressions/test_term_efficient.py | 19 +- .../study/test_components_parsing.py | 3 +- tests/unittests/test_output_values.py | 7 +- 26 files changed, 719 insertions(+), 576 deletions(-) create mode 100644 src/andromede/simulation/linear_expression_resolver.py create mode 100644 src/andromede/simulation/optimization_context.py diff --git a/src/andromede/expression/__init__.py b/src/andromede/expression/__init__.py index 8e481cce..55b51967 100644 --- a/src/andromede/expression/__init__.py +++ b/src/andromede/expression/__init__.py @@ -20,16 +20,16 @@ ) from .expression import ( -# AdditionNode, -# Comparator, -# ComparisonNode, -# DivisionNode, + # AdditionNode, + # Comparator, + # ComparisonNode, + # DivisionNode, ExpressionNode, -# LiteralNode, -# MultiplicationNode, -# NegationNode, -# ParameterNode, -# SubstractionNode, + # LiteralNode, + # MultiplicationNode, + # NegationNode, + # ParameterNode, + # SubstractionNode, VariableNode, literal, param, diff --git a/src/andromede/expression/evaluate.py b/src/andromede/expression/evaluate.py index a56f912d..08477070 100644 --- a/src/andromede/expression/evaluate.py +++ b/src/andromede/expression/evaluate.py @@ -34,6 +34,7 @@ from .visitor import ExpressionVisitorOperations, visit + # Used only for tests @dataclass(frozen=True) class EvaluationContext(ValueProvider): diff --git a/src/andromede/expression/evaluate_parameters_efficient.py b/src/andromede/expression/evaluate_parameters_efficient.py index c3eb5402..3c5d3393 100644 --- a/src/andromede/expression/evaluate_parameters_efficient.py +++ b/src/andromede/expression/evaluate_parameters_efficient.py @@ -201,12 +201,14 @@ def port_field_aggregator( def check_resolved_expr( resolved_expr: Dict[TimeScenarioIndex, float], row_id: RowIndex -) -> bool: +) -> None: # Check that the resolved expression has been correctly time and scenario aggregated so that only a float is left if len(resolved_expr) != 1: raise ValueError("Evaluation of expression cannot be reduced to a float value") if TimeScenarioIndex(row_id.time, row_id.scenario) not in resolved_expr: - raise ValueError("Expression has a time operator but not time aggregator, maybe you are missing a sum(), necessary even on one element") + raise ValueError( + "Expression has a time operator but not time aggregator, maybe you are missing a sum(), necessary even on one element" + ) def resolve_coefficient( @@ -227,12 +229,12 @@ class InstancesIndexVisitor(ParameterEvaluationVisitor): # raise ValueError("An instance index expression cannot contain variable") # Probably useless as parameter nodes should have already be replaced by component parameter nodes ? - def parameter(self, node: ParameterNode) -> float: + def parameter(self, node: ParameterNode) -> Dict[TimeScenarioIndex, float]: if not self.context.parameter_is_constant_over_time(node.name): raise ValueError( "Parameter given in an instance index expression must be constant over time" ) - + return self.context.get_parameter_value(node.name, self.time_scenario_indices) def comp_parameter( @@ -246,10 +248,12 @@ def comp_parameter( node.component_id, node.name, self.time_scenario_indices ) - def time_operator(self, node: TimeOperatorNode) -> float: + def time_operator(self, node: TimeOperatorNode) -> Dict[TimeScenarioIndex, float]: raise ValueError("An instance index expression cannot contain time operator") - def time_aggregator(self, node: TimeAggregatorNode) -> float: + def time_aggregator( + self, node: TimeAggregatorNode + ) -> Dict[TimeScenarioIndex, float]: raise ValueError("An instance index expression cannot contain time aggregator") diff --git a/src/andromede/expression/expression.py b/src/andromede/expression/expression.py index eb03dd5b..01e8136b 100644 --- a/src/andromede/expression/expression.py +++ b/src/andromede/expression/expression.py @@ -435,7 +435,9 @@ def __post_init__(self) -> None: for _, cls in inspect.getmembers( andromede.expression.scenario_operator, inspect.isclass ) - if issubclass(cls, andromede.expression.scenario_operator.ScenarioOperator) + if issubclass( + cls, andromede.expression.scenario_operator.ScenarioAggregator + ) ] if self.name not in valid_names: raise ValueError( diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index 575e4494..a0fd86e7 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -549,7 +549,7 @@ def __post_init__(self) -> None: @dataclass(frozen=True, eq=False) class TimeAggregatorNode(UnaryOperatorNode): name: TimeAggregatorName - stay_roll: bool # TODO: Is it still useful ? + stay_roll: bool # TODO: Is it still useful ? def __post_init__(self) -> None: if not isinstance(self.name, TimeAggregatorName): diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index f568ba60..1802ddd3 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -28,11 +28,14 @@ overload, ) +import ortools.linear_solver.pywraplp as lp + from andromede.expression.context_adder import add_component_context from andromede.expression.equality import expressions_equal from andromede.expression.evaluate import evaluate from andromede.expression.evaluate_parameters_efficient import ( check_resolved_expr, + get_time_ids_from_instances_index, resolve_coefficient, ) from andromede.expression.expression_efficient import ( @@ -61,7 +64,7 @@ ResolvedLinearExpression, ResolvedTerm, ) -from andromede.expression.scenario_operator import Expectation, ScenarioOperator +from andromede.expression.scenario_operator import Expectation, ScenarioAggregator from andromede.expression.time_operator import ( TimeAggregator, TimeEvaluation, @@ -86,7 +89,7 @@ class TermKeyEfficient: variable_name: str time_operator: Optional[TimeOperator] time_aggregator: Optional[TimeAggregator] - scenario_operator: Optional[ScenarioOperator] + scenario_aggregator: Optional[ScenarioAggregator] @dataclass(frozen=True) @@ -107,7 +110,7 @@ class TermEfficient: ) time_operator: Optional[TimeOperator] = None time_aggregator: Optional[TimeAggregator] = None - scenario_operator: Optional[ScenarioOperator] = None + scenario_aggregator: Optional[ScenarioAggregator] = None def __post_init__(self) -> None: object.__setattr__(self, "coefficient", wrap_in_node(self.coefficient)) @@ -121,7 +124,7 @@ def __eq__(self, other: object) -> bool: and self.structure == other.structure and self.time_operator == other.time_operator and self.time_aggregator == other.time_aggregator - and self.scenario_operator == other.scenario_operator + and self.scenario_aggregator == other.scenario_aggregator ) def is_zero(self) -> bool: @@ -144,8 +147,8 @@ def __str__(self) -> str: result += f".{str(self.time_operator)}" if self.time_aggregator is not None: result += f".{str(self.time_aggregator)}" - if self.scenario_operator is not None: - result += f".{str(self.scenario_operator)}" + if self.scenario_aggregator is not None: + result += f".{str(self.scenario_aggregator)}" return result def evaluate(self, context: ValueProvider, time_scenario_index: RowIndex) -> float: @@ -195,7 +198,7 @@ def _compute_time_indexing(self, provider: IndexingStructureProvider) -> bool: return time def _compute_scenario_indexing(self, provider: IndexingStructureProvider) -> bool: - if self.scenario_operator: + if self.scenario_aggregator: scenario = False else: # TODO: Improve this if/else structure, probably simplify IndexingStructureProvider @@ -320,7 +323,7 @@ def eval( def expec(self) -> "TermEfficient": # TODO: Do we need checks, in case a scenario operator is already specified ? - return dataclasses.replace(self, scenario_operator=Expectation()) + return dataclasses.replace(self, scenario_aggregator=Expectation()) def generate_key(term: TermEfficient) -> TermKeyEfficient: @@ -329,7 +332,7 @@ def generate_key(term: TermEfficient) -> TermKeyEfficient: term.variable_name, term.time_operator, term.time_aggregator, - term.scenario_operator, + term.scenario_aggregator, ) @@ -377,7 +380,8 @@ def _merge_dicts( rhs: Dict[TermKeyEfficient, TermEfficient], merge_func: Callable[[TermEfficient, TermEfficient], TermEfficient], neutral: float, -) -> Dict[TermKeyEfficient, TermEfficient]: ... +) -> Dict[TermKeyEfficient, TermEfficient]: + ... @overload @@ -386,7 +390,8 @@ def _merge_dicts( rhs: Dict[PortFieldId, PortFieldTerm], merge_func: Callable[[PortFieldTerm, PortFieldTerm], PortFieldTerm], neutral: float, -) -> Dict[PortFieldId, PortFieldTerm]: ... +) -> Dict[PortFieldId, PortFieldTerm]: + ... def _get_neutral_term(term: T_val, neutral: float) -> T_val: @@ -418,7 +423,7 @@ def _merge_term_is_possible(lhs: TermEfficient, rhs: TermEfficient) -> None: if ( lhs.time_operator != rhs.time_operator or lhs.time_aggregator != rhs.time_aggregator - or lhs.scenario_operator != rhs.scenario_operator + or lhs.scenario_aggregator != rhs.scenario_aggregator ): raise ValueError("Cannot merge terms with different operators") if lhs.structure != rhs.structure: @@ -965,29 +970,6 @@ def add_component_context(self, component_id: str) -> "LinearExpressionEfficient result_terms, result_constant, self.port_field_terms ) - def resolve_coefficient( - self, value_provider: ValueProvider, row_id: RowIndex - ) -> ResolvedLinearExpression: - - resolved_terms = [] - for term in self.terms.values(): - resolved_coeff = resolve_coefficient( - term.coefficient, value_provider, row_id - ) - resolved_variable = ... - resolved_terms.append(ResolvedTerm(resolved_coeff, resolved_variable)) - - resolved_constant = resolve_coefficient(self.constant, value_provider, row_id) - return ResolvedLinearExpression(resolved_terms, resolved_constant) - - def resolve_constant_expr( - self, value_provider: ValueProvider, row_id: RowIndex - ) -> float: - if not self.is_constant(): - raise ValueError(f"{str(self)} is not a constant expression") - resolved_expr = self.resolve_coefficient(value_provider, row_id) - return resolved_expr.constant - def linear_expressions_equal( lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient diff --git a/src/andromede/expression/scenario_operator.py b/src/andromede/expression/scenario_operator.py index 77f164e1..6eee08df 100644 --- a/src/andromede/expression/scenario_operator.py +++ b/src/andromede/expression/scenario_operator.py @@ -19,7 +19,7 @@ @dataclass(frozen=True) -class ScenarioOperator(ABC): +class ScenarioAggregator(ABC): def __str__(self) -> str: return NotImplemented @@ -30,7 +30,7 @@ def degree(cls) -> int: @dataclass(frozen=True) -class Expectation(ScenarioOperator): +class Expectation(ScenarioAggregator): def __str__(self) -> str: return "expec()" @@ -40,7 +40,7 @@ def degree(cls) -> int: @dataclass(frozen=True) -class Variance(ScenarioOperator): +class Variance(ScenarioAggregator): def __str__(self) -> str: return "variance()" diff --git a/src/andromede/expression/value_provider.py b/src/andromede/expression/value_provider.py index e555c4ed..f15b943a 100644 --- a/src/andromede/expression/value_provider.py +++ b/src/andromede/expression/value_provider.py @@ -20,51 +20,61 @@ class TimeScenarioIndices: time_indices: List[int] scenario_indices: List[int] + # TODO: Already define in study module, factorize this @dataclass(frozen=True) class TimeScenarioIndex: time: int scenario: int + # Given a list of time_indices and of scenario_indices, the value provider will get the parameter value for all couple (time, scenario) for time in time_indices and scenario in scenario_indices class ValueProvider(ABC): """ Implementations are in charge of mapping parameters and variables to their values. Depending on the implementation, evaluation may require a component id or not. """ - #TODO: To be removed, or should we keep it to evaluate solutions ? + + # TODO: To be removed, or should we keep it to evaluate solutions ? @abstractmethod def get_variable_value( self, name: str, time_scenarios_indices: TimeScenarioIndices - ) -> Dict[TimeScenarioIndex, float]: ... + ) -> Dict[TimeScenarioIndex, float]: + ... # Need to have time_scenarios_indices as function argument as we do not want to create a Provider each time we have to get the value of a parameter at a different (time, scenario) index @abstractmethod def get_parameter_value( self, name: str, time_scenarios_indices: TimeScenarioIndices - ) -> Dict[TimeScenarioIndex, float]: ... + ) -> Dict[TimeScenarioIndex, float]: + ... - #TODO: To be removed ? + # TODO: To be removed ? @abstractmethod def get_component_variable_value( self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices - ) -> Dict[TimeScenarioIndex, float]: ... + ) -> Dict[TimeScenarioIndex, float]: + ... @abstractmethod def get_component_parameter_value( self, component_id: str, name: str, time_scenarios_indices: TimeScenarioIndices - ) -> Dict[TimeScenarioIndex, float]: ... + ) -> Dict[TimeScenarioIndex, float]: + ... # TODO: Should this really be an abstract method ? Or maybe, only the Provider in _make_value_provider should implement it. And the context attribute in the InstancesIndexVisitor is a ValueProvider that implements the parameter_is_constant_over_time method. Maybe create a child class of ValueProvider like TimeValueProvider ? @abstractmethod - def parameter_is_constant_over_time(self, name: str) -> bool: ... + def parameter_is_constant_over_time(self, name: str) -> bool: + ... # There is probably a better place to put this.. # Which is useful when evaluating the TimeSum operator over the whole block, to know which time steps to look for to get the parameter values @staticmethod @abstractmethod - def block_length() -> int: ... + def block_length() -> int: + ... @staticmethod @abstractmethod - def scenarios() -> int: ... + def scenarios() -> int: + ... diff --git a/src/andromede/simulation/__init__.py b/src/andromede/simulation/__init__.py index 23e0a855..94cac7d3 100644 --- a/src/andromede/simulation/__init__.py +++ b/src/andromede/simulation/__init__.py @@ -14,7 +14,8 @@ BendersDecomposedProblem, build_benders_decomposed_problem, ) -from .optimization import BlockBorderManagement, OptimizationProblem, build_problem +from .optimization import OptimizationProblem, build_problem +from .optimization_context import BlockBorderManagement from .output_values import BendersSolution, OutputValues from .runner import BendersRunner, MergeMPSRunner from .strategy import MergedProblemStrategy, ModelSelectionStrategy diff --git a/src/andromede/simulation/benders_decomposed.py b/src/andromede/simulation/benders_decomposed.py index 08d718f1..9ab44139 100644 --- a/src/andromede/simulation/benders_decomposed.py +++ b/src/andromede/simulation/benders_decomposed.py @@ -18,11 +18,8 @@ import pathlib from typing import Any, Dict, List, Optional -from andromede.simulation.optimization import ( - BlockBorderManagement, - OptimizationProblem, - build_problem, -) +from andromede.simulation.optimization import OptimizationProblem, build_problem +from andromede.simulation.optimization_context import BlockBorderManagement from andromede.simulation.output_values import ( BendersDecomposedSolution, BendersMergedSolution, diff --git a/src/andromede/simulation/linear_expression.py b/src/andromede/simulation/linear_expression.py index 8167420d..1f3c9359 100644 --- a/src/andromede/simulation/linear_expression.py +++ b/src/andromede/simulation/linear_expression.py @@ -18,7 +18,7 @@ from typing import Callable, Dict, List, Optional, TypeVar, Union from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.scenario_operator import ScenarioOperator +from andromede.expression.scenario_operator import ScenarioAggregator from andromede.expression.time_operator import TimeAggregator, TimeOperator from andromede.model.model import PortFieldId @@ -53,7 +53,7 @@ class TermKey: variable_name: str time_operator: Optional[TimeOperator] time_aggregator: Optional[TimeAggregator] - scenario_operator: Optional[ScenarioOperator] + scenario_aggregator: Optional[ScenarioAggregator] @dataclass(frozen=True) @@ -74,7 +74,7 @@ class Term: ) time_operator: Optional[TimeOperator] = None time_aggregator: Optional[TimeAggregator] = None - scenario_operator: Optional[ScenarioOperator] = None + scenario_aggregator: Optional[ScenarioAggregator] = None # TODO: It may be useful to define __add__, __sub__, etc on terms, which should return a linear expression ? @@ -98,8 +98,8 @@ def __str__(self) -> str: result += f".{str(self.time_operator)}" if self.time_aggregator is not None: result += f".{str(self.time_aggregator)}" - if self.scenario_operator is not None: - result += f".{str(self.scenario_operator)}" + if self.scenario_aggregator is not None: + result += f".{str(self.scenario_aggregator)}" return result def number_of_instances(self) -> int: @@ -118,7 +118,7 @@ def generate_key(term: Term) -> TermKey: term.variable_name, term.time_operator, term.time_aggregator, - term.scenario_operator, + term.scenario_aggregator, ) @@ -141,7 +141,7 @@ def _merge_dicts( v.structure, v.time_operator, v.time_aggregator, - v.scenario_operator, + v.scenario_aggregator, ), ), ) @@ -155,7 +155,7 @@ def _merge_dicts( v.structure, v.time_operator, v.time_aggregator, - v.scenario_operator, + v.scenario_aggregator, ), v, ) @@ -168,7 +168,7 @@ def _merge_is_possible(lhs: Term, rhs: Term) -> None: if ( lhs.time_operator != rhs.time_operator or lhs.time_aggregator != rhs.time_aggregator - or lhs.scenario_operator != rhs.scenario_operator + or lhs.scenario_aggregator != rhs.scenario_aggregator ): raise ValueError("Cannot merge terms with different operators") if lhs.structure != rhs.structure: @@ -184,7 +184,7 @@ def _add_terms(lhs: Term, rhs: Term) -> Term: lhs.structure, lhs.time_operator, lhs.time_aggregator, - lhs.scenario_operator, + lhs.scenario_aggregator, ) @@ -197,7 +197,7 @@ def _substract_terms(lhs: Term, rhs: Term) -> Term: lhs.structure, lhs.time_operator, lhs.time_aggregator, - lhs.scenario_operator, + lhs.scenario_aggregator, ) @@ -342,7 +342,7 @@ def __imul__(self, rhs: "LinearExpression") -> "LinearExpression": term.structure, term.time_operator, term.time_aggregator, - term.scenario_operator, + term.scenario_aggregator, ) _copy_expression(left_expr, self) return self @@ -374,7 +374,7 @@ def __itruediv__(self, rhs: "LinearExpression") -> "LinearExpression": term.structure, term.time_operator, term.time_aggregator, - term.scenario_operator, + term.scenario_aggregator, ) return self diff --git a/src/andromede/simulation/linear_expression_resolver.py b/src/andromede/simulation/linear_expression_resolver.py new file mode 100644 index 00000000..bc4c1074 --- /dev/null +++ b/src/andromede/simulation/linear_expression_resolver.py @@ -0,0 +1,122 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +from dataclasses import dataclass +from typing import Dict, List + +import ortools.linear_solver.pywraplp as lp + +from andromede.expression.evaluate_parameters_efficient import ( + get_time_ids_from_instances_index, + resolve_coefficient, +) +from andromede.expression.indexing_structure import RowIndex +from andromede.expression.linear_expression_efficient import ( + LinearExpressionEfficient, + TermEfficient, +) +from andromede.expression.resolved_linear_expression import ( + ResolvedLinearExpression, + ResolvedTerm, +) +from andromede.expression.time_operator import TimeShift +from andromede.expression.value_provider import ( + TimeScenarioIndex, + TimeScenarioIndices, + ValueProvider, +) +from andromede.simulation.optimization_context import OptimizationContext + + +@dataclass +class LinearExpressionResolver: + context: OptimizationContext + value_provider: ValueProvider + + def resolve( + self, expression: LinearExpressionEfficient, row_id: RowIndex + ) -> ResolvedLinearExpression: + resolved_terms = [] + for term in expression.terms.values(): + # Here, the value provide is used only to evaluate possible time operator args if the term has one + resolved_variables = self.resolve_variables(term, row_id) + + for ts_id, lp_variable in resolved_variables.items(): + # TODO: Where is key going to play a role ? + resolved_coeff = resolve_coefficient( + term.coefficient, self.value_provider, row_id + ) + resolved_terms.append(ResolvedTerm(resolved_coeff, lp_variable)) + + resolved_constant = resolve_coefficient( + expression.constant, self.value_provider, row_id + ) + return ResolvedLinearExpression(resolved_terms, resolved_constant) + + def resolve_constant_expr( + self, expression: LinearExpressionEfficient, row_id: RowIndex + ) -> float: + if not expression.is_constant(): + raise ValueError(f"{str(self)} is not a constant expression") + return resolve_coefficient(expression.constant, self.value_provider, row_id) + + def resolve_variables( + self, term: TermEfficient, row_id: RowIndex + ) -> Dict[TimeScenarioIndex, lp.Variable]: + solver_vars = {} + operator_ts_ids = self._row_id_to_term_time_scenario_id(term, row_id) + for time in operator_ts_ids.time_indices: + for scenario in operator_ts_ids.scenario_indices: + solver_vars[ + TimeScenarioIndex(time, scenario) + ] = self.context.get_component_variable( + time, + scenario, + term.component_id, + term.variable_name, + term.structure, + ) + return solver_vars + + def _row_id_to_term_time_scenario_id( + self, term: TermEfficient, row_id: RowIndex + ) -> TimeScenarioIndices: + operator_time_ids = self._compute_operator_time_ids(term, row_id) + + operator_scenario_ids = self._compute_operator_scenario_ids(term, row_id) + return TimeScenarioIndices(operator_time_ids, operator_scenario_ids) + + def _compute_operator_scenario_ids( + self, term: TermEfficient, row_id: RowIndex + ) -> List[int]: + if term.scenario_aggregator: + operator_scenario_ids = list(range(self.context.scenarios)) + else: + operator_scenario_ids = [row_id.scenario] + return operator_scenario_ids + + def _compute_operator_time_ids( + self, term: TermEfficient, row_id: RowIndex + ) -> List[int]: + if not term.time_operator and not term.time_aggregator: + operator_time_ids = [row_id.time] + elif term.time_operator: + operator_time_ids = get_time_ids_from_instances_index( + term.time_operator.time_ids, self.value_provider, row_id + ) + if isinstance(term.time_operator, TimeShift): + operator_time_ids = [ + row_id.time + time_id for time_id in operator_time_ids + ] + else: # Case time_aggregator but no time_operator i.e. sum over whole block + operator_time_ids = list(range(self.context.block_length())) + return operator_time_ids diff --git a/src/andromede/simulation/linearize.py b/src/andromede/simulation/linearize.py index 9fe6738a..dc3cb2ef 100644 --- a/src/andromede/simulation/linearize.py +++ b/src/andromede/simulation/linearize.py @@ -130,7 +130,7 @@ def scenario_operator(self, node: ScenarioOperatorNode) -> LinearExpression: result_terms = {} for term in operand_expr.terms.values(): term_with_operator = dataclasses.replace( - term, scenario_operator=scenario_operator_cls() + term, scenario_aggregator=scenario_operator_cls() ) result_terms[generate_key(term_with_operator)] = term_with_operator diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index a9db5226..011b27cd 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -16,439 +16,36 @@ """ import math -from abc import ABC, abstractmethod from dataclasses import dataclass -from enum import Enum -from typing import Dict, Iterable, List, Optional +from typing import List, Optional import ortools.linear_solver.pywraplp as lp -from andromede.expression import ( # ExpressionNode, - ParameterValueProvider, - resolve_parameters, -) -from andromede.expression.evaluate_parameters_efficient import ValueProvider from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, - PortFieldKey, RowIndex, ) from andromede.expression.resolved_linear_expression import ResolvedLinearExpression from andromede.expression.scenario_operator import Expectation from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum -from andromede.expression.value_provider import TimeScenarioIndex, TimeScenarioIndices from andromede.model.common import ValueType from andromede.model.constraint import Constraint from andromede.model.model import PortFieldId -from andromede.simulation.linear_expression import LinearExpression, Term -from andromede.simulation.linearize import linearize_expression +from andromede.simulation.linear_expression import Term +from andromede.simulation.linear_expression_resolver import LinearExpressionResolver +from andromede.simulation.optimization_context import ( + BlockBorderManagement, + ComponentContext, + OptimizationContext, + _make_data_structure_provider, + _make_value_provider, +) from andromede.simulation.strategy import MergedProblemStrategy, ModelSelectionStrategy from andromede.simulation.time_block import TimeBlock from andromede.study.data import DataBase from andromede.study.network import Component, Network -from andromede.utils import get_or_add - - -@dataclass(eq=True, frozen=True) -class TimestepComponentVariableKey: - """ - Identifies the solver variable for one timestep and one component variable. - """ - - component_id: str - variable_name: str - block_timestep: Optional[int] = None - scenario: Optional[int] = None - - -def _get_parameter_value( - context: "OptimizationContext", - block_timestep: int, - scenario: int, - component_id: str, - name: str, -) -> float: - data = context.database.get_data(component_id, name) - absolute_timestep = context.block_timestep_to_absolute_timestep(block_timestep) - return data.get_value(absolute_timestep, scenario) - - -class TimestepValueProvider(ABC): - """ - Interface which provides numerical values for individual timesteps. - """ - - @abstractmethod - def get_value(self, block_timestep: int, scenario: int) -> float: - raise NotImplementedError() - - -def _get_data_time_key(block_timestep: int, data_indexing: IndexingStructure) -> int: - return block_timestep if data_indexing.time else 0 - - -def _get_data_scenario_key(scenario: int, data_indexing: IndexingStructure) -> int: - return scenario if data_indexing.scenario else 0 - - -def _make_value_provider( - context: "OptimizationContext", - component: Component, -) -> ValueProvider: - """ - Create a value provider which takes its values from - the parameter values as defined in the network data. - - Cannot evaluate expressions which contain variables. - """ - - class Provider(ValueProvider): - def get_component_variable_value( - self, - component_id: str, - name: str, - time_scenarios_indices: TimeScenarioIndices, - ) -> Dict[TimeScenarioIndex, float]: - raise NotImplementedError( - "Cannot provide variable value at problem build time." - ) - - def get_component_parameter_value( - self, - component_id: str, - name: str, - time_scenarios_indices: TimeScenarioIndices, - ) -> Dict[TimeScenarioIndex, float]: - result = {} - param_index = ( - context.network.get_component(component_id) - .model.parameters[name] - .structure - ) - for block_timestep in time_scenarios_indices.time_indices: - for scenario in time_scenarios_indices.scenario_indices: - result[TimeScenarioIndex(block_timestep, scenario)] = ( - _get_parameter_value( - context, - _get_data_time_key(block_timestep, param_index), - _get_data_scenario_key(scenario, param_index), - component_id, - name, - ) - ) - return result - - def get_variable_value( - self, - name: str, - time_scenarios_indices: TimeScenarioIndices, - ) -> Dict[TimeScenarioIndex, float]: - raise NotImplementedError( - "Cannot provide variable value at problem build time." - ) - - def get_parameter_value( - self, - name: str, - time_scenarios_indices: TimeScenarioIndices, - ) -> Dict[TimeScenarioIndex, float]: - raise ValueError( - "Parameter must be associated to its component before resolution." - ) - - def parameter_is_constant_over_time(self, name: str) -> bool: - return not component.model.parameters[name].structure.time - - @staticmethod - def block_length() -> int: - return context.block_length() - - @staticmethod - def scenarios() -> int: - return context.scenarios - - return Provider() - - -@dataclass(frozen=True) -class ExpressionTimestepValueProvider(TimestepValueProvider): - context: "OptimizationContext" - component: Component - expression: LinearExpressionEfficient - - # OptimizationContext has knowledge of the block, so that get_value only needs block_timestep and scenario to get the correct data value - - def get_value(self, block_timestep: int, scenario: int) -> float: - param_value_provider = _make_value_provider( - self.context, block_timestep, scenario, self.component - ) - return self.expression.evaluate(param_value_provider) - - -def _make_parameter_value_provider( - context: "OptimizationContext", - block_timestep: int, - scenario: int, -) -> ParameterValueProvider: - """ - A value provider which takes its values from - the parameter values as defined in the network data. - - Cannot evaluate expressions which contain variables. - """ - - class Provider(ParameterValueProvider): - def get_component_parameter_value(self, component_id: str, name: str) -> float: - return _get_parameter_value( - context, block_timestep, scenario, component_id, name - ) - - def get_parameter_value(self, name: str) -> float: - raise ValueError( - "Parameters should have been associated with their component before resolution." - ) - - return Provider() - - -def _make_data_structure_provider( - network: Network, component: Component -) -> IndexingStructureProvider: - """ - Retrieve information in data structure (parameter and variable) from the model - """ - - class Provider(IndexingStructureProvider): - def get_component_variable_structure( - self, component_id: str, name: str - ) -> IndexingStructure: - return network.get_component(component_id).model.variables[name].structure - - def get_component_parameter_structure( - self, component_id: str, name: str - ) -> IndexingStructure: - return network.get_component(component_id).model.parameters[name].structure - - def get_parameter_structure(self, name: str) -> IndexingStructure: - return component.model.parameters[name].structure - - def get_variable_structure(self, name: str) -> IndexingStructure: - return component.model.variables[name].structure - - return Provider() - - -@dataclass(frozen=True) -class ComponentContext: - """ - Helper class to fill the optimization problem with component-related equations and variables. - """ - - opt_context: "OptimizationContext" - component: Component - - def get_values( - self, expression: LinearExpressionEfficient - ) -> TimestepValueProvider: - """ - The returned value provider will evaluate the provided expression. - """ - return ExpressionTimestepValueProvider( - self.opt_context, self.component, expression - ) - - def add_variable( - self, - block_timestep: int, - scenario: int, - model_var_name: str, - variable: lp.Variable, - ) -> None: - self.opt_context.register_component_variable( - block_timestep, scenario, self.component.id, model_var_name, variable - ) - - def get_variable( - self, block_timestep: int, scenario: int, variable_name: str - ) -> lp.Variable: - return self.opt_context.get_component_variable( - block_timestep, - scenario, - self.component.id, - variable_name, - self.component.model.variables[variable_name].structure, - ) - - def linearize_expression( - self, - block_timestep: int, - scenario: int, - expression: LinearExpressionEfficient, - ) -> LinearExpression: - parameters_valued_provider = _make_parameter_value_provider( - self.opt_context, block_timestep, scenario - ) - evaluated_expr = resolve_parameters(expression, parameters_valued_provider) - - value_provider = _make_value_provider( - self.opt_context, block_timestep, scenario, self.component - ) - structure_provider = _make_data_structure_provider( - self.opt_context.network, self.component - ) - - return linearize_expression(evaluated_expr, structure_provider, value_provider) - - -class BlockBorderManagement(Enum): - """ - Class to specify the way of handling the time horizon (or time block) border. - - IGNORE_OUT_OF_FRAME: Ignore terms in constraints that lead to out of horizon data - - CYCLE: Consider all timesteps to be specified modulo the horizon length, this is the actual functioning of Antares - """ - - IGNORE_OUT_OF_FRAME = "IGNORE" - CYCLE = "CYCLE" - - -@dataclass -class SolverVariableInfo: - """ - Helper class for constructing the structure file - for Benders solver. It keeps track of the corresponding - column of the variable in the MPS format as well as if it is - present in the objective function or not - """ - - name: str - column_id: int - is_in_objective: bool - - -class OptimizationContext: - """ - Helper class to build the optimization problem. - Maintains some mappings between model and solver objects. - Also provides navigation method in the model (components by node ...). - """ - - def __init__( - self, - network: Network, - database: DataBase, - block: TimeBlock, - scenarios: int, - border_management: BlockBorderManagement, - ): - self._network = network - self._database = database - self._block = block - self._scenarios = scenarios - self._border_management = border_management - self._component_variables: Dict[TimestepComponentVariableKey, lp.Variable] = {} - self._solver_variables: Dict[lp.Variable, SolverVariableInfo] = {} - self._connection_fields_expressions: Dict[ - PortFieldKey, List[LinearExpressionEfficient] - ] = {} - - @property - def network(self) -> Network: - return self._network - - @property - def scenarios(self) -> int: - return self._scenarios - - def block_length(self) -> int: - return len(self._block.timesteps) - - @property - def connection_fields_expressions( - self, - ) -> Dict[PortFieldKey, List[LinearExpressionEfficient]]: - return self._connection_fields_expressions - - # TODO: Need to think about data processing when creating blocks with varying or inequal time steps length (aggregation, sum ?, mean of data ?) - def block_timestep_to_absolute_timestep(self, block_timestep: int) -> int: - return self._block.timesteps[block_timestep] - - @property - def database(self) -> DataBase: - return self._database - - def _manage_border_timesteps(self, timestep: int) -> int: - if self._border_management == BlockBorderManagement.CYCLE: - return timestep % self.block_length() - else: - raise NotImplementedError - - def get_time_indices(self, index_structure: IndexingStructure) -> Iterable[int]: - return range(self.block_length()) if index_structure.time else range(1) - - def get_scenario_indices(self, index_structure: IndexingStructure) -> Iterable[int]: - return range(self.scenarios) if index_structure.scenario else range(1) - - # TODO: API to improve, variable_structure guides which of the indices block_timestep and scenario should be used - def get_component_variable( - self, - block_timestep: int, - scenario: int, - component_id: str, - variable_name: str, - variable_structure: IndexingStructure, - ) -> lp.Variable: - block_timestep = self._manage_border_timesteps(block_timestep) - - # TODO: Improve design, variable_structure defines indexing - if variable_structure.time == False: - block_timestep = 0 - if variable_structure.scenario == False: - scenario = 0 - - return self._component_variables[ - TimestepComponentVariableKey( - component_id, variable_name, block_timestep, scenario - ) - ] - - def get_all_component_variables( - self, - ) -> Dict[TimestepComponentVariableKey, lp.Variable]: - return self._component_variables - - def register_component_variable( - self, - block_timestep: int, - scenario: int, - component_id: str, - variable_name: str, - variable: lp.Variable, - ) -> None: - key = TimestepComponentVariableKey( - component_id, variable_name, block_timestep, scenario - ) - if key not in self._component_variables: - self._solver_variables[variable] = SolverVariableInfo( - variable.name(), len(self._solver_variables), False - ) - self._component_variables[key] = variable - - def get_component_context(self, component: Component) -> ComponentContext: - return ComponentContext(self, component) - - def register_connection_fields_expressions( - self, - component_id: str, - port_name: str, - field_name: str, - expression: LinearExpressionEfficient, - ) -> None: - key = PortFieldKey(component_id, PortFieldId(port_name, field_name)) - get_or_add(self._connection_fields_expressions, key, lambda: []).append( - expression - ) def _get_indexing( @@ -505,6 +102,7 @@ def _create_constraint( # instances_per_time_step = 1 value_provider = _make_value_provider(context.opt_context, context.component) + expression_resolver = LinearExpressionResolver(context.opt_context, value_provider) for block_timestep in context.opt_context.get_time_indices(constraint_indexing): for scenario in context.opt_context.get_scenario_indices(constraint_indexing): @@ -513,14 +111,12 @@ def _create_constraint( # ) row_id = RowIndex(block_timestep, scenario) - resolved_expr = constraint.expression.resolve_coefficient( - value_provider, row_id - ) - resolved_lb = constraint.lower_bound.resolve_constant_expr( - value_provider, row_id + resolved_expr = expression_resolver.resolve(constraint.expression, row_id) + resolved_lb = expression_resolver.resolve_constant_expr( + constraint.lower_bound, row_id ) - resolved_ub = constraint.upper_bound.resolve_constant_expr( - value_provider, row_id + resolved_ub = expression_resolver.resolve_constant_expr( + constraint.upper_bound, row_id ) # What happens if there is some time_operator in the bounds ? -> Pb réglé avec le nouveau design ! @@ -549,7 +145,7 @@ def _create_objective( obj: lp.Objective = solver.Objective() for term in linear_expr.terms.values(): # TODO : How to handle the scenario operator in a general manner ? - if isinstance(term.scenario_operator, Expectation): + if isinstance(term.scenario_aggregator, Expectation): weight = 1 / opt_context.scenarios scenario_ids = range(opt_context.scenarios) else: @@ -764,6 +360,7 @@ def _create_variables(self) -> None: model = component.model value_provider = _make_value_provider(self.context, component) + expression_resolver = LinearExpressionResolver(self.context, value_provider) for model_var in self.strategy.get_variables(model): var_indexing = model_var.structure @@ -794,19 +391,13 @@ def _create_variables(self) -> None: lower_bound = -self.solver.infinity() upper_bound = self.solver.infinity() if instantiated_lb_expr: - lower_bound = instantiated_lb_expr.resolve_constant_expr( - value_provider, RowIndex(block_timestep, scenario) + lower_bound = expression_resolver.resolve_constant_expr( + instantiated_lb_expr, RowIndex(block_timestep, scenario) ) - # lower_bound = component_context.get_values( - # instantiated_lb_expr - # ).get_value(block_timestep, scenario) if instantiated_ub_expr: - upper_bound = instantiated_ub_expr.resolve_constant_expr( - value_provider, RowIndex(block_timestep, scenario) + upper_bound = expression_resolver.resolve_constant_expr( + instantiated_ub_expr, RowIndex(block_timestep, scenario) ) - # upper_bound = component_context.get_values( - # instantiated_ub_expr - # ).get_value(block_timestep, scenario) scenario_suffix = ( f"_s{scenario}" diff --git a/src/andromede/simulation/optimization_context.py b/src/andromede/simulation/optimization_context.py new file mode 100644 index 00000000..03b11ccc --- /dev/null +++ b/src/andromede/simulation/optimization_context.py @@ -0,0 +1,435 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# +# See AUTHORS.txt +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# SPDX-License-Identifier: MPL-2.0 +# +# This file is part of the Antares project. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Dict, Iterable, List, Optional + +import ortools.linear_solver.pywraplp as lp + +from andromede.expression import ParameterValueProvider, resolve_parameters +from andromede.expression.evaluate_parameters_efficient import ValueProvider +from andromede.expression.indexing import IndexingStructureProvider +from andromede.expression.indexing_structure import IndexingStructure +from andromede.expression.linear_expression_efficient import ( + LinearExpressionEfficient, + PortFieldId, + PortFieldKey, +) +from andromede.expression.value_provider import TimeScenarioIndex, TimeScenarioIndices +from andromede.simulation.linear_expression import LinearExpression +from andromede.simulation.linearize import linearize_expression +from andromede.simulation.time_block import TimeBlock +from andromede.study.data import DataBase +from andromede.study.network import Component, Network +from andromede.utils import get_or_add + + +@dataclass(eq=True, frozen=True) +class TimestepComponentVariableKey: + """ + Identifies the solver variable for one timestep and one component variable. + """ + + component_id: str + variable_name: str + block_timestep: Optional[int] = None + scenario: Optional[int] = None + + +@dataclass +class SolverVariableInfo: + """ + Helper class for constructing the structure file + for Benders solver. It keeps track of the corresponding + column of the variable in the MPS format as well as if it is + present in the objective function or not + """ + + name: str + column_id: int + is_in_objective: bool + + +class BlockBorderManagement(Enum): + """ + Class to specify the way of handling the time horizon (or time block) border. + - IGNORE_OUT_OF_FRAME: Ignore terms in constraints that lead to out of horizon data + - CYCLE: Consider all timesteps to be specified modulo the horizon length, this is the actual functioning of Antares + """ + + IGNORE_OUT_OF_FRAME = "IGNORE" + CYCLE = "CYCLE" + + +class OptimizationContext: + """ + Helper class to build the optimization problem. + Maintains some mappings between model and solver objects. + Also provides navigation method in the model (components by node ...). + """ + + def __init__( + self, + network: Network, + database: DataBase, + block: TimeBlock, + scenarios: int, + border_management: BlockBorderManagement, + ): + self._network = network + self._database = database + self._block = block + self._scenarios = scenarios + self._border_management = border_management + self._component_variables: Dict[TimestepComponentVariableKey, lp.Variable] = {} + self._solver_variables: Dict[lp.Variable, SolverVariableInfo] = {} + self._connection_fields_expressions: Dict[ + PortFieldKey, List[LinearExpressionEfficient] + ] = {} + + @property + def network(self) -> Network: + return self._network + + @property + def scenarios(self) -> int: + return self._scenarios + + def block_length(self) -> int: + return len(self._block.timesteps) + + @property + def connection_fields_expressions( + self, + ) -> Dict[PortFieldKey, List[LinearExpressionEfficient]]: + return self._connection_fields_expressions + + # TODO: Need to think about data processing when creating blocks with varying or inequal time steps length (aggregation, sum ?, mean of data ?) + def block_timestep_to_absolute_timestep(self, block_timestep: int) -> int: + return self._block.timesteps[block_timestep] + + @property + def database(self) -> DataBase: + return self._database + + def _manage_border_timesteps(self, timestep: int) -> int: + if self._border_management == BlockBorderManagement.CYCLE: + return timestep % self.block_length() + else: + raise NotImplementedError + + def get_time_indices(self, index_structure: IndexingStructure) -> Iterable[int]: + return range(self.block_length()) if index_structure.time else range(1) + + def get_scenario_indices(self, index_structure: IndexingStructure) -> Iterable[int]: + return range(self.scenarios) if index_structure.scenario else range(1) + + # TODO: API to improve, variable_structure guides which of the indices block_timestep and scenario should be used + def get_component_variable( + self, + block_timestep: int, + scenario: int, + component_id: str, + variable_name: str, + variable_structure: IndexingStructure, + ) -> lp.Variable: + block_timestep = self._manage_border_timesteps(block_timestep) + + # TODO: Improve design, variable_structure defines indexing + if variable_structure.time == False: + block_timestep = 0 + if variable_structure.scenario == False: + scenario = 0 + + return self._component_variables[ + TimestepComponentVariableKey( + component_id, variable_name, block_timestep, scenario + ) + ] + + def get_all_component_variables( + self, + ) -> Dict[TimestepComponentVariableKey, lp.Variable]: + return self._component_variables + + def register_component_variable( + self, + block_timestep: int, + scenario: int, + component_id: str, + variable_name: str, + variable: lp.Variable, + ) -> None: + key = TimestepComponentVariableKey( + component_id, variable_name, block_timestep, scenario + ) + if key not in self._component_variables: + self._solver_variables[variable] = SolverVariableInfo( + variable.name(), len(self._solver_variables), False + ) + self._component_variables[key] = variable + + def get_component_context(self, component: Component) -> "ComponentContext": + return ComponentContext(self, component) + + def register_connection_fields_expressions( + self, + component_id: str, + port_name: str, + field_name: str, + expression: LinearExpressionEfficient, + ) -> None: + key = PortFieldKey(component_id, PortFieldId(port_name, field_name)) + get_or_add(self._connection_fields_expressions, key, lambda: []).append( + expression + ) + + +class TimestepValueProvider(ABC): + """ + Interface which provides numerical values for individual timesteps. + """ + + @abstractmethod + def get_value(self, block_timestep: int, scenario: int) -> float: + raise NotImplementedError() + + +def _get_parameter_value( + context: OptimizationContext, + block_timestep: int, + scenario: int, + component_id: str, + name: str, +) -> float: + data = context.database.get_data(component_id, name) + absolute_timestep = context.block_timestep_to_absolute_timestep(block_timestep) + return data.get_value(absolute_timestep, scenario) + + +def _make_value_provider( + context: "OptimizationContext", + component: Component, +) -> ValueProvider: + """ + Create a value provider which takes its values from + the parameter values as defined in the network data. + + Cannot evaluate expressions which contain variables. + """ + + class Provider(ValueProvider): + def get_component_variable_value( + self, + component_id: str, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + raise NotImplementedError( + "Cannot provide variable value at problem build time." + ) + + def get_component_parameter_value( + self, + component_id: str, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + result = {} + param_index = ( + context.network.get_component(component_id) + .model.parameters[name] + .structure + ) + for block_timestep in time_scenarios_indices.time_indices: + for scenario in time_scenarios_indices.scenario_indices: + result[ + TimeScenarioIndex(block_timestep, scenario) + ] = _get_parameter_value( + context, + _get_data_time_key(block_timestep, param_index), + _get_data_scenario_key(scenario, param_index), + component_id, + name, + ) + return result + + def get_variable_value( + self, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + raise NotImplementedError( + "Cannot provide variable value at problem build time." + ) + + def get_parameter_value( + self, + name: str, + time_scenarios_indices: TimeScenarioIndices, + ) -> Dict[TimeScenarioIndex, float]: + raise ValueError( + "Parameter must be associated to its component before resolution." + ) + + def parameter_is_constant_over_time(self, name: str) -> bool: + return not component.model.parameters[name].structure.time + + @staticmethod + def block_length() -> int: + return context.block_length() + + @staticmethod + def scenarios() -> int: + return context.scenarios + + return Provider() + + +@dataclass(frozen=True) +class ExpressionTimestepValueProvider(TimestepValueProvider): + context: "OptimizationContext" + component: Component + expression: LinearExpressionEfficient + + # OptimizationContext has knowledge of the block, so that get_value only needs block_timestep and scenario to get the correct data value + + def get_value(self, block_timestep: int, scenario: int) -> float: + param_value_provider = _make_value_provider( + self.context, block_timestep, scenario, self.component + ) + return self.expression.evaluate(param_value_provider) + + +def _make_parameter_value_provider( + context: "OptimizationContext", + block_timestep: int, + scenario: int, +) -> ParameterValueProvider: + """ + A value provider which takes its values from + the parameter values as defined in the network data. + + Cannot evaluate expressions which contain variables. + """ + + class Provider(ParameterValueProvider): + def get_component_parameter_value(self, component_id: str, name: str) -> float: + return _get_parameter_value( + context, block_timestep, scenario, component_id, name + ) + + def get_parameter_value(self, name: str) -> float: + raise ValueError( + "Parameters should have been associated with their component before resolution." + ) + + return Provider() + + +def _make_data_structure_provider( + network: Network, component: Component +) -> IndexingStructureProvider: + """ + Retrieve information in data structure (parameter and variable) from the model + """ + + class Provider(IndexingStructureProvider): + def get_component_variable_structure( + self, component_id: str, name: str + ) -> IndexingStructure: + return network.get_component(component_id).model.variables[name].structure + + def get_component_parameter_structure( + self, component_id: str, name: str + ) -> IndexingStructure: + return network.get_component(component_id).model.parameters[name].structure + + def get_parameter_structure(self, name: str) -> IndexingStructure: + return component.model.parameters[name].structure + + def get_variable_structure(self, name: str) -> IndexingStructure: + return component.model.variables[name].structure + + return Provider() + + +@dataclass(frozen=True) +class ComponentContext: + """ + Helper class to fill the optimization problem with component-related equations and variables. + """ + + opt_context: OptimizationContext + component: Component + + def get_values( + self, expression: LinearExpressionEfficient + ) -> TimestepValueProvider: + """ + The returned value provider will evaluate the provided expression. + """ + return ExpressionTimestepValueProvider( + self.opt_context, self.component, expression + ) + + def add_variable( + self, + block_timestep: int, + scenario: int, + model_var_name: str, + variable: lp.Variable, + ) -> None: + self.opt_context.register_component_variable( + block_timestep, scenario, self.component.id, model_var_name, variable + ) + + def get_variable( + self, block_timestep: int, scenario: int, variable_name: str + ) -> lp.Variable: + return self.opt_context.get_component_variable( + block_timestep, + scenario, + self.component.id, + variable_name, + self.component.model.variables[variable_name].structure, + ) + + def linearize_expression( + self, + block_timestep: int, + scenario: int, + expression: LinearExpressionEfficient, + ) -> LinearExpression: + parameters_valued_provider = _make_parameter_value_provider( + self.opt_context, block_timestep, scenario + ) + evaluated_expr = resolve_parameters(expression, parameters_valued_provider) + + value_provider = _make_value_provider( + self.opt_context, block_timestep, scenario, self.component + ) + structure_provider = _make_data_structure_provider( + self.opt_context.network, self.component + ) + + return linearize_expression(evaluated_expr, structure_provider, value_provider) + + +def _get_data_time_key(block_timestep: int, data_indexing: IndexingStructure) -> int: + return block_timestep if data_indexing.time else 0 + + +def _get_data_scenario_key(scenario: int, data_indexing: IndexingStructure) -> int: + return scenario if data_indexing.scenario else 0 diff --git a/tests/functional/test_andromede.py b/tests/functional/test_andromede.py index 38a72354..099c7706 100644 --- a/tests/functional/test_andromede.py +++ b/tests/functional/test_andromede.py @@ -27,11 +27,8 @@ ) from andromede.model import Model, ModelPort, float_parameter, float_variable, model from andromede.model.model import PortFieldDefinition, PortFieldId -from andromede.simulation import ( - BlockBorderManagement, - TimeBlock, - build_problem, -) +from andromede.simulation import TimeBlock, build_problem +from andromede.simulation.optimization_context import BlockBorderManagement from andromede.study import ( ConstantData, DataBase, diff --git a/tests/functional/test_andromede_yml.py b/tests/functional/test_andromede_yml.py index 36dc2e13..c9326d65 100644 --- a/tests/functional/test_andromede_yml.py +++ b/tests/functional/test_andromede_yml.py @@ -2,12 +2,8 @@ import pytest from andromede.model.library import Library -from andromede.simulation import ( - BlockBorderManagement, - OutputValues, - TimeBlock, - build_problem, -) +from andromede.simulation import OutputValues, TimeBlock, build_problem +from andromede.simulation.optimization_context import BlockBorderManagement from andromede.study import ( ConstantData, DataBase, diff --git a/tests/functional/test_performance_efficient.py b/tests/functional/test_performance_efficient.py index 6b0a1b4a..55ef9913 100644 --- a/tests/functional/test_performance_efficient.py +++ b/tests/functional/test_performance_efficient.py @@ -15,6 +15,7 @@ from andromede.expression.evaluate import EvaluationContext from andromede.expression.expression_efficient import param +from andromede.expression.indexing_structure import RowIndex from andromede.expression.linear_expression_efficient import ( literal, var, @@ -37,7 +38,7 @@ def test_large_number_of_parameters_sum() -> None: # Still the recursion depth error with parameters with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): expr = sum(wrap_in_linear_expr(param(f"cost_{i}")) for i in range(1, nb_terms)) - expr.evaluate(EvaluationContext(parameters=parameters_value)) + expr.evaluate(EvaluationContext(parameters=parameters_value), RowIndex(0, 0)) def test_large_number_of_identical_parameters_sum() -> None: @@ -51,7 +52,10 @@ def test_large_number_of_identical_parameters_sum() -> None: # Still the recursion depth error with parameters # with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): expr = sum(wrap_in_linear_expr(param("cost")) for _ in range(nb_terms)) - assert expr.evaluate(EvaluationContext(parameters=parameters_value)) == nb_terms + assert ( + expr.evaluate(EvaluationContext(parameters=parameters_value), RowIndex(0, 0)) + == nb_terms + ) def test_large_number_of_literal_sum() -> None: @@ -63,7 +67,7 @@ def test_large_number_of_literal_sum() -> None: # # Still the recursion depth error with parameters # with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): expr = sum(wrap_in_linear_expr(literal(1)) for _ in range(nb_terms)) - assert expr.evaluate(EvaluationContext()) == nb_terms + assert expr.evaluate(EvaluationContext(), RowIndex(0, 0)) == nb_terms def test_large_number_of_variables_sum() -> None: @@ -77,6 +81,6 @@ def test_large_number_of_variables_sum() -> None: variables_value[f"cost_{i}"] = 1 / i expr = sum(var(f"cost_{i}") for i in range(1, nb_terms)) - assert expr.evaluate(EvaluationContext(variables=variables_value)) == sum( - 1 / i for i in range(1, nb_terms) - ) + assert expr.evaluate( + EvaluationContext(variables=variables_value), RowIndex(0, 0) + ) == sum(1 / i for i in range(1, nb_terms)) diff --git a/tests/models/test_short_term_storage_complex.py b/tests/models/test_short_term_storage_complex.py index 0d23c050..ef28c7e3 100644 --- a/tests/models/test_short_term_storage_complex.py +++ b/tests/models/test_short_term_storage_complex.py @@ -14,7 +14,8 @@ UNSUPPLIED_ENERGY_MODEL, ) from andromede.libs.standard_sc import SHORT_TERM_STORAGE_COMPLEX -from andromede.simulation import BlockBorderManagement, TimeBlock, build_problem +from andromede.simulation import TimeBlock, build_problem +from andromede.simulation.optimization_context import BlockBorderManagement from andromede.study import ( ConstantData, DataBase, diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index 75fba9fe..36d7075a 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -411,7 +411,7 @@ def test_comparison() -> None: time_aggregator=TimeSum( stay_roll=True ), # The internal representation of shift(1) is sum(shift=1) - scenario_operator=None, + scenario_aggregator=None, ): TermEfficient( TimeOperatorNode( LiteralNode(1), TimeOperatorName.SHIFT, InstancesTimeIndex(1) @@ -430,7 +430,7 @@ def test_comparison() -> None: InstancesTimeIndex(1), ), time_aggregator=TimeSum(stay_roll=True), - scenario_operator=None, + scenario_aggregator=None, ): TermEfficient( TimeOperatorNode( LiteralNode(1), TimeOperatorName.SHIFT, InstancesTimeIndex(1) @@ -459,7 +459,7 @@ def test_comparison() -> None: time_aggregator=TimeSum( stay_roll=True ), # The internal representation of eval(1) is sum(eval=1) - scenario_operator=None, + scenario_aggregator=None, ): TermEfficient( TimeOperatorNode( LiteralNode(1), @@ -480,7 +480,7 @@ def test_comparison() -> None: InstancesTimeIndex(1), ), time_aggregator=TimeSum(stay_roll=True), - scenario_operator=None, + scenario_aggregator=None, ): TermEfficient( TimeOperatorNode( LiteralNode(1), @@ -509,7 +509,7 @@ def test_comparison() -> None: "x", time_operator=None, time_aggregator=TimeSum(stay_roll=False), - scenario_operator=None, + scenario_aggregator=None, ): TermEfficient( LiteralNode(1), # Sum is not distributed to coeff "", @@ -522,7 +522,7 @@ def test_comparison() -> None: "y", time_operator=None, time_aggregator=TimeSum(stay_roll=False), - scenario_operator=None, + scenario_aggregator=None, ): TermEfficient( LiteralNode(1), # Sum is not distributed to coeff "", diff --git a/tests/unittests/expressions/test_linear_expressions.py b/tests/unittests/expressions/test_linear_expressions.py index dcedc28d..21f916a7 100644 --- a/tests/unittests/expressions/test_linear_expressions.py +++ b/tests/unittests/expressions/test_linear_expressions.py @@ -38,14 +38,14 @@ ), "-3x.shift([2, 3]).sum(False)", ), - (Term(-3, "c", "x", scenario_operator=Expectation()), "-3x.expec()"), + (Term(-3, "c", "x", scenario_aggregator=Expectation()), "-3x.expec()"), ( Term( -3, "c", "x", time_aggregator=TimeSum(True), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ), "-3x.sum(True).expec()", ), @@ -151,7 +151,7 @@ def test_instantiate_linear_expression_from_dict( "c", "x", time_operator=TimeShift(-1), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ] ), @@ -163,7 +163,7 @@ def test_instantiate_linear_expression_from_dict( "c", "x", time_operator=TimeShift(-1), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ), ] ), @@ -217,7 +217,7 @@ def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_fr "c", "x", time_operator=TimeShift(-1), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], 3, @@ -230,7 +230,7 @@ def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_fr "c", "x", time_operator=TimeShift(-1), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], 6, @@ -269,7 +269,7 @@ def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> "x", time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], 5, @@ -282,7 +282,7 @@ def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> "x", time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], -5, @@ -343,7 +343,7 @@ def test_negation(e1: LinearExpression, expected: LinearExpression) -> None: "x", time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ] ), @@ -356,7 +356,7 @@ def test_negation(e1: LinearExpression, expected: LinearExpression) -> None: "x", time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ), ] ), @@ -391,7 +391,7 @@ def test_substraction( "x", time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], 15, @@ -405,7 +405,7 @@ def test_substraction( "x", time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], 3, diff --git a/tests/unittests/expressions/test_linear_expressions_efficient.py b/tests/unittests/expressions/test_linear_expressions_efficient.py index 36fdad96..4d703500 100644 --- a/tests/unittests/expressions/test_linear_expressions_efficient.py +++ b/tests/unittests/expressions/test_linear_expressions_efficient.py @@ -192,7 +192,7 @@ def test_is_zero(expr: LinearExpressionEfficient, expected: bool) -> None: "c", "x", time_operator=TimeShift(-1), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ] ), @@ -204,7 +204,7 @@ def test_is_zero(expr: LinearExpressionEfficient, expected: bool) -> None: "c", "x", time_operator=TimeShift(-1), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ), ] ), @@ -254,7 +254,7 @@ def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_fr "c", "x", time_operator=TimeShift(-1), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], 3, @@ -267,7 +267,7 @@ def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_fr "c", "x", time_operator=TimeShift(-1), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], 6, @@ -308,7 +308,7 @@ def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> "x", time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], 5, @@ -321,7 +321,7 @@ def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> "x", time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], -5, @@ -397,7 +397,7 @@ def test_negation( "x", time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ] ), @@ -410,7 +410,7 @@ def test_negation( "x", time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ), ] ), @@ -447,7 +447,7 @@ def test_substraction( "x", time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], 15, @@ -461,7 +461,7 @@ def test_substraction( "x", time_operator=TimeShift(-1), time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ) ], 3, diff --git a/tests/unittests/expressions/test_resolve_coefficients.py b/tests/unittests/expressions/test_resolve_coefficients.py index 914703a7..9eb31619 100644 --- a/tests/unittests/expressions/test_resolve_coefficients.py +++ b/tests/unittests/expressions/test_resolve_coefficients.py @@ -236,7 +236,6 @@ def test_resolve_coefficient_on_expression_with_shift_but_without_sum_raises_val def test_resolve_coefficient_with_no_time_varying_parameter_in_time_operator_argument_raises_value_error( expr: ExpressionNodeEfficient, ) -> None: - class TimeVaryingParameterValueProvider(CustomValueProvider): def parameter_is_constant_over_time(self, name: str) -> bool: return False diff --git a/tests/unittests/expressions/test_term_efficient.py b/tests/unittests/expressions/test_term_efficient.py index d30f9310..96b25262 100644 --- a/tests/unittests/expressions/test_term_efficient.py +++ b/tests/unittests/expressions/test_term_efficient.py @@ -37,14 +37,17 @@ ), "-3.0x.shift([2, 3]).sum(False)", ), - (TermEfficient(-3, "c", "x", scenario_operator=Expectation()), "-3.0x.expec()"), + ( + TermEfficient(-3, "c", "x", scenario_aggregator=Expectation()), + "-3.0x.expec()", + ), ( TermEfficient( -3, "c", "x", time_aggregator=TimeSum(True), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ), "-3.0x.sum(True).expec()", ), @@ -105,13 +108,13 @@ def test_printing_term(term: TermEfficient, expected: str) -> None: False, ), ( - TermEfficient(-3, "c", "x", scenario_operator=Expectation()), - TermEfficient(-3, "c", "x", scenario_operator=Expectation()), + TermEfficient(-3, "c", "x", scenario_aggregator=Expectation()), + TermEfficient(-3, "c", "x", scenario_aggregator=Expectation()), True, ), ( - TermEfficient(-3, "c", "x", scenario_operator=Expectation()), - TermEfficient(-3, "c", "x", scenario_operator=Variance()), + TermEfficient(-3, "c", "x", scenario_aggregator=Expectation()), + TermEfficient(-3, "c", "x", scenario_aggregator=Variance()), False, ), ( @@ -120,14 +123,14 @@ def test_printing_term(term: TermEfficient, expected: str) -> None: "c", "x", time_aggregator=TimeSum(True), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ), TermEfficient( -3, "c", "x", time_aggregator=TimeSum(False), - scenario_operator=Expectation(), + scenario_aggregator=Expectation(), ), False, ), diff --git a/tests/unittests/study/test_components_parsing.py b/tests/unittests/study/test_components_parsing.py index f72c5c2f..f5ad10c6 100644 --- a/tests/unittests/study/test_components_parsing.py +++ b/tests/unittests/study/test_components_parsing.py @@ -5,7 +5,8 @@ from andromede.model.parsing import InputLibrary, parse_yaml_library from andromede.model.resolve_library import resolve_library -from andromede.simulation import BlockBorderManagement, TimeBlock, build_problem +from andromede.simulation import TimeBlock, build_problem +from andromede.simulation.optimization_context import BlockBorderManagement from andromede.study import TimeScenarioIndex, TimeScenarioSeriesData from andromede.study.parsing import InputComponents, parse_yaml_components from andromede.study.resolve_components import ( diff --git a/tests/unittests/test_output_values.py b/tests/unittests/test_output_values.py index 89b3637e..f3d01c13 100644 --- a/tests/unittests/test_output_values.py +++ b/tests/unittests/test_output_values.py @@ -15,11 +15,8 @@ import ortools.linear_solver.pywraplp as lp from andromede.simulation import OutputValues -from andromede.simulation.optimization import ( - OptimizationContext, - OptimizationProblem, - TimestepComponentVariableKey, -) +from andromede.simulation.optimization import OptimizationContext, OptimizationProblem +from andromede.simulation.optimization_context import TimestepComponentVariableKey def test_component_and_flow_output_object() -> None: From b42978266788374804a6096ef280d732d9a567f4 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Mon, 19 Aug 2024 17:37:54 +0200 Subject: [PATCH 28/51] Set objective --- .../simulation/linear_expression_resolver.py | 12 +- src/andromede/simulation/optimization.py | 139 +++--------------- 2 files changed, 25 insertions(+), 126 deletions(-) diff --git a/src/andromede/simulation/linear_expression_resolver.py b/src/andromede/simulation/linear_expression_resolver.py index bc4c1074..d104dc41 100644 --- a/src/andromede/simulation/linear_expression_resolver.py +++ b/src/andromede/simulation/linear_expression_resolver.py @@ -50,11 +50,15 @@ def resolve( # Here, the value provide is used only to evaluate possible time operator args if the term has one resolved_variables = self.resolve_variables(term, row_id) + # TODO: For now all coefficients are the same for a given "variable_name", we are not able to represent things like sum(a_t * x_t)... but everything else is ok: + # sum(a_t') * x_t + # a_t * sum(x_t') + # a_t * x_t + # TODO: Next line is to be moved inside the for loop once we have figured out how to represent sum(a_t * x_t) + resolved_coeff = resolve_coefficient( + term.coefficient, self.value_provider, row_id + ) for ts_id, lp_variable in resolved_variables.items(): - # TODO: Where is key going to play a role ? - resolved_coeff = resolve_coefficient( - term.coefficient, self.value_provider, row_id - ) resolved_terms.append(ResolvedTerm(resolved_coeff, lp_variable)) resolved_constant = resolve_coefficient( diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 011b27cd..7538b9e7 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -133,43 +133,34 @@ def _create_objective( solver: lp.Solver, opt_context: OptimizationContext, component: Component, - component_context: ComponentContext, objective_contribution: LinearExpressionEfficient, ) -> None: instantiated_expr = _instantiate_model_expression( objective_contribution, component.id, opt_context ) # We have already checked in the model creation that the objective contribution is neither indexed by time nor by scenario - linear_expr = component_context.linearize_expression(0, 0, instantiated_expr) + + value_provider = _make_value_provider(opt_context, component) + expression_resolver = LinearExpressionResolver(opt_context, value_provider) + resolved_expr = expression_resolver.resolve(instantiated_expr, RowIndex(0, 0)) obj: lp.Objective = solver.Objective() - for term in linear_expr.terms.values(): + for term in resolved_expr.terms: # TODO : How to handle the scenario operator in a general manner ? - if isinstance(term.scenario_aggregator, Expectation): - weight = 1 / opt_context.scenarios - scenario_ids = range(opt_context.scenarios) - else: - weight = 1 - scenario_ids = range(1) - - for scenario in scenario_ids: - solver_vars = _get_solver_vars( - term, - opt_context, - 0, - scenario, - 0, - ) - - for solver_var in solver_vars: - opt_context._solver_variables[solver_var].is_in_objective = True - obj.SetCoefficient( - solver_var, - obj.GetCoefficient(solver_var) + weight * term.coefficient, - ) + # if isinstance(term.scenario_aggregator, Expectation): + # weight = 1 / opt_context.scenarios + # scenario_ids = range(opt_context.scenarios) + # else: + # weight = 1 + # scenario_ids = range(1) + opt_context._solver_variables[term.variable].is_in_objective = True + obj.SetCoefficient( + term.variable, + obj.GetCoefficient(term.variable) + term.coefficient, + ) # This should have no effect on the optimization - obj.SetOffset(linear_expr.constant + obj.offset()) + obj.SetOffset(resolved_expr.constant + obj.offset()) @dataclass @@ -180,84 +171,6 @@ class ConstraintData: expression: ResolvedLinearExpression -def _get_solver_vars( - term: Term, - context: OptimizationContext, - block_timestep: int, - scenario: int, - instance: int, -) -> List[lp.Variable]: - solver_vars = [] - if isinstance(term.time_aggregator, TimeSum): - if isinstance(term.time_operator, TimeShift): - for time_id in term.time_operator.time_ids: - solver_vars.append( - context.get_component_variable( - block_timestep + time_id, - scenario, - term.component_id, - term.variable_name, - term.structure, - ) - ) - elif isinstance(term.time_operator, TimeEvaluation): - for time_id in term.time_operator.time_ids: - solver_vars.append( - context.get_component_variable( - time_id, - scenario, - term.component_id, - term.variable_name, - term.structure, - ) - ) - else: # time_operator is None, retrieve variable for each time step of the block. What happens if we do x.sum() with x not being indexed by time ? Is there a check that it is a valid expression ? - for time_id in range(context.block_length()): - solver_vars.append( - context.get_component_variable( - block_timestep + time_id, - scenario, - term.component_id, - term.variable_name, - term.structure, - ) - ) - - else: # time_aggregator is None - if isinstance(term.time_operator, TimeShift): - solver_vars.append( - context.get_component_variable( - block_timestep + term.time_operator.time_ids[instance], - scenario, - term.component_id, - term.variable_name, - term.structure, - ) - ) - elif isinstance(term.time_operator, TimeEvaluation): - solver_vars.append( - context.get_component_variable( - term.time_operator.time_ids[instance], - scenario, - term.component_id, - term.variable_name, - term.structure, - ) - ) - else: # time_operator is None - # TODO: horrible tous ces if/else - solver_vars.append( - context.get_component_variable( - block_timestep, - scenario, - term.component_id, - term.variable_name, - term.structure, - ) - ) - return solver_vars - - def make_constraint( solver: lp.Solver, row_id: RowIndex, @@ -277,23 +190,6 @@ def make_constraint( term.variable, term.coefficient + solver_constraint.GetCoefficient(term.variable), ) - - # TODO : To be done in linear expression resolution coeff - # for term in data.expression.terms.values(): - # # Move this to resolve coefficient - # solver_vars = _get_solver_vars( - # term, - # context, - # block_timestep, - # scenario, - # instance, - # ) - # for solver_var in solver_vars: - # coefficient = term.coefficient + solver_constraint.GetCoefficient( - # solver_var - # ) - # solver_constraint.SetCoefficient(solver_var, coefficient) - # TODO: On pourrait aussi faire que l'objet Constraint n'ait pas de terme constant dans son expression et que les constantes soit déjà prises en compte dans les bornes, ça simplifierait le traitement ici constant += data.expression.constant solver_constraint.SetBounds( @@ -479,7 +375,6 @@ def _create_objectives(self) -> None: self.solver, self.context, component, - component_context, objective, ) From bda088c2f7b563723759dc9f95b08c7e4c8eb8fa Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Tue, 20 Aug 2024 20:15:50 +0200 Subject: [PATCH 29/51] Fix shift distribution and add component context for time operators --- .../evaluate_parameters_efficient.py | 4 ++ .../expression/linear_expression_efficient.py | 66 +++++++++++++++---- .../simulation/linear_expression_resolver.py | 26 ++++---- .../expressions/test_expressions_efficient.py | 20 ++---- .../expressions/test_resolve_coefficients.py | 21 ++++++ 5 files changed, 96 insertions(+), 41 deletions(-) diff --git a/src/andromede/expression/evaluate_parameters_efficient.py b/src/andromede/expression/evaluate_parameters_efficient.py index 3c5d3393..70fa2716 100644 --- a/src/andromede/expression/evaluate_parameters_efficient.py +++ b/src/andromede/expression/evaluate_parameters_efficient.py @@ -163,6 +163,8 @@ def time_aggregator( for k in operand_dict.keys() if k.scenario == scenario ) + # As the sum aggregates on time, time indices on which to evaluate parent expression collapses on row_id.time + self.time_scenario_indices.time_indices = [self.row_id.time] return result else: return NotImplemented @@ -185,6 +187,8 @@ def scenario_operator( operand_dict[k] for k in operand_dict.keys() if k.time == time ) ) + # As the expectation aggregates on scenario, scenario indices on which to evaluate parent expression collapses on row_id.scenario + self.time_scenario_indices.scenario_indices = [self.row_id.scenario] return result else: diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 1802ddd3..059920d6 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -25,6 +25,7 @@ Sequence, TypeVar, Union, + cast, overload, ) @@ -231,26 +232,25 @@ def sum( if shift is not None and eval is not None: raise ValueError("Only shift or eval arguments should specified, not both.") - # The shift or eval operators distribute over the coefficients whereas the sum only applies to the whole as (param("a") * var("x")).shift([1,5]) represents: a[t+1]x[t+1] + ... + a[t+5]x[t+5] - # And (param("a") * var("x")).eval([1,5]) represents: a[1]x[1] + ... + a[5]x[5] + # The shift or eval operators applies on the variable, then it will define at which time step the term coefficient * variable will be evaluated if shift is not None: return dataclasses.replace( self, - coefficient=TimeOperatorNode( - self.coefficient, TimeOperatorName.SHIFT, InstancesTimeIndex(shift) - ), + # coefficient=TimeOperatorNode( + # self.coefficient, TimeOperatorName.SHIFT, InstancesTimeIndex(shift) + # ), time_operator=TimeShift(InstancesTimeIndex(shift)), time_aggregator=TimeSum(stay_roll=True), ) elif eval is not None: return dataclasses.replace( self, - coefficient=TimeOperatorNode( - self.coefficient, - TimeOperatorName.EVALUATION, - InstancesTimeIndex(eval), - ), + # coefficient=TimeOperatorNode( + # self.coefficient, + # TimeOperatorName.EVALUATION, + # InstancesTimeIndex(eval), + # ), time_operator=TimeEvaluation(InstancesTimeIndex(eval)), time_aggregator=TimeSum(stay_roll=True), ) @@ -380,8 +380,7 @@ def _merge_dicts( rhs: Dict[TermKeyEfficient, TermEfficient], merge_func: Callable[[TermEfficient, TermEfficient], TermEfficient], neutral: float, -) -> Dict[TermKeyEfficient, TermEfficient]: - ... +) -> Dict[TermKeyEfficient, TermEfficient]: ... @overload @@ -390,8 +389,7 @@ def _merge_dicts( rhs: Dict[PortFieldId, PortFieldTerm], merge_func: Callable[[PortFieldTerm, PortFieldTerm], PortFieldTerm], neutral: float, -) -> Dict[PortFieldId, PortFieldTerm]: - ... +) -> Dict[PortFieldId, PortFieldTerm]: ... def _get_neutral_term(term: T_val, neutral: float) -> T_val: @@ -959,10 +957,21 @@ def add_component_context(self, component_id: str) -> "LinearExpressionEfficient raise ValueError( "This expression has already been associated to another component." ) + result_term = dataclasses.replace( term, component_id=component_id, coefficient=add_component_context(component_id, term.coefficient), + time_operator=( + dataclasses.replace( + term.time_operator, + time_ids=_add_component_context_to_instances_index( + component_id, term.time_operator.time_ids + ), + ) + if term.time_operator + else None + ), ) result_terms[generate_key(result_term)] = result_term result_constant = add_component_context(component_id, self.constant) @@ -971,6 +980,35 @@ def add_component_context(self, component_id: str) -> "LinearExpressionEfficient ) +def _add_component_context_to_expression_range( + component_id: str, expression_range: ExpressionRange +) -> ExpressionRange: + return ExpressionRange( + start=add_component_context(component_id, expression_range.start), + stop=add_component_context(component_id, expression_range.stop), + step=( + add_component_context(component_id, expression_range.step) + if expression_range.step is not None + else None + ), + ) + + +def _add_component_context_to_instances_index( + component_id: str, instances_index: InstancesTimeIndex +) -> InstancesTimeIndex: + expressions = instances_index.expressions + if isinstance(expressions, ExpressionRange): + return InstancesTimeIndex( + _add_component_context_to_expression_range(component_id, expressions) + ) + if isinstance(expressions, list): + expressions_list = cast(List[ExpressionNodeEfficient], expressions) + copy = [add_component_context(component_id, e) for e in expressions_list] + return InstancesTimeIndex(copy) + raise ValueError("Unexpected type in instances index") + + def linear_expressions_equal( lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient ) -> bool: diff --git a/src/andromede/simulation/linear_expression_resolver.py b/src/andromede/simulation/linear_expression_resolver.py index d104dc41..7691b53d 100644 --- a/src/andromede/simulation/linear_expression_resolver.py +++ b/src/andromede/simulation/linear_expression_resolver.py @@ -55,10 +55,14 @@ def resolve( # a_t * sum(x_t') # a_t * x_t # TODO: Next line is to be moved inside the for loop once we have figured out how to represent sum(a_t * x_t) - resolved_coeff = resolve_coefficient( - term.coefficient, self.value_provider, row_id - ) + for ts_id, lp_variable in resolved_variables.items(): + # TODO: Could we check in which case coeff resolution leads to the same result for each element in the for loop ? When there is only a literal, etc, etc ? + resolved_coeff = resolve_coefficient( + term.coefficient, + self.value_provider, + RowIndex(ts_id.time, ts_id.scenario), + ) resolved_terms.append(ResolvedTerm(resolved_coeff, lp_variable)) resolved_constant = resolve_coefficient( @@ -80,14 +84,14 @@ def resolve_variables( operator_ts_ids = self._row_id_to_term_time_scenario_id(term, row_id) for time in operator_ts_ids.time_indices: for scenario in operator_ts_ids.scenario_indices: - solver_vars[ - TimeScenarioIndex(time, scenario) - ] = self.context.get_component_variable( - time, - scenario, - term.component_id, - term.variable_name, - term.structure, + solver_vars[TimeScenarioIndex(time, scenario)] = ( + self.context.get_component_variable( + time, + scenario, + term.component_id, + term.variable_name, + term.structure, + ) ) return solver_vars diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index 36d7075a..a2f73401 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -413,9 +413,7 @@ def test_comparison() -> None: ), # The internal representation of shift(1) is sum(shift=1) scenario_aggregator=None, ): TermEfficient( - TimeOperatorNode( - LiteralNode(1), TimeOperatorName.SHIFT, InstancesTimeIndex(1) - ), + LiteralNode(1), "", "x", time_operator=TimeShift( @@ -432,9 +430,7 @@ def test_comparison() -> None: time_aggregator=TimeSum(stay_roll=True), scenario_aggregator=None, ): TermEfficient( - TimeOperatorNode( - LiteralNode(1), TimeOperatorName.SHIFT, InstancesTimeIndex(1) - ), + LiteralNode(1), "", "y", time_operator=TimeShift(InstancesTimeIndex(1)), @@ -461,11 +457,7 @@ def test_comparison() -> None: ), # The internal representation of eval(1) is sum(eval=1) scenario_aggregator=None, ): TermEfficient( - TimeOperatorNode( - LiteralNode(1), - TimeOperatorName.EVALUATION, - InstancesTimeIndex(1), - ), + LiteralNode(1), "", "x", time_operator=TimeEvaluation( @@ -482,11 +474,7 @@ def test_comparison() -> None: time_aggregator=TimeSum(stay_roll=True), scenario_aggregator=None, ): TermEfficient( - TimeOperatorNode( - LiteralNode(1), - TimeOperatorName.EVALUATION, - InstancesTimeIndex(1), - ), + LiteralNode(1), "", "y", time_operator=TimeEvaluation(InstancesTimeIndex(1)), diff --git a/tests/unittests/expressions/test_resolve_coefficients.py b/tests/unittests/expressions/test_resolve_coefficients.py index 9eb31619..2211a617 100644 --- a/tests/unittests/expressions/test_resolve_coefficients.py +++ b/tests/unittests/expressions/test_resolve_coefficients.py @@ -30,6 +30,7 @@ TimeOperatorName, TimeOperatorNode, comp_param, + literal, param, ) from andromede.expression.indexing_structure import IndexingStructure, RowIndex @@ -275,6 +276,7 @@ def test_resolve_coefficient_on_elementary_operations( [ (param("p").shift(2).sum(), RowIndex(0, 0), 3.0), (param("p").shift(-1).sum(), RowIndex(2, 1), 5.0), + (literal(0).shift(-1).sum(), RowIndex(0, 0), 0.0), (param("p").eval(2).sum(), RowIndex(0, 0), 3.0), (param("p").eval(2).sum(), RowIndex(2, 0), 3.0), (param("p").shift(ExpressionRange(0, 3)).sum(), RowIndex(0, 0), 13.0), @@ -311,3 +313,22 @@ def test_resolve_coefficient_on_expectation( provider: CustomValueProvider, ) -> None: assert math.isclose(resolve_coefficient(expr, provider, row_id), expected) + + +@pytest.mark.parametrize( + "expr, row_id, expected", + [ + (param("p").expec().sum(), RowIndex(0, 0), 18.0), + (param("p").sum().expec(), RowIndex(0, 0), 18.0), + (param("p").shift(comp_param("c", "q")).sum().expec(), RowIndex(1, 0), 6.5), + (param("p").expec().shift(comp_param("c", "q")).sum(), RowIndex(1, 0), 7.5), + (param("p").shift(comp_param("c", "q")).expec().sum(), RowIndex(1, 0), 6.5), + ], +) +def test_resolve_coefficient_on_sum_and_expectation( + expr: ExpressionNodeEfficient, + row_id: RowIndex, + expected: float, + provider: CustomValueProvider, +) -> None: + assert math.isclose(resolve_coefficient(expr, provider, row_id), expected) From 2c723cd9162f8e38fbff381a89c7ac2e700add54 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Wed, 21 Aug 2024 09:40:19 +0200 Subject: [PATCH 30/51] Temporary API for single shift over ExpressionNodeEfficient --- src/andromede/libs/standard.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/andromede/libs/standard.py b/src/andromede/libs/standard.py index e5cdaf1f..d476aab4 100644 --- a/src/andromede/libs/standard.py +++ b/src/andromede/libs/standard.py @@ -260,12 +260,13 @@ ) <= var("nb_on"), ), + # TODO : Improve API so that we are not forced to use sum() on one shifted element for ExpressionNodeEfficient Constraint( "Min down time", var("nb_stop").sum( shift=ExpressionRange(-param("d_min_down") + 1, literal(0)) ) - <= param("nb_units_max").shift(-param("d_min_down")) - var("nb_on"), + <= param("nb_units_max").shift(-param("d_min_down")).sum() - var("nb_on"), ), # It also works by writing ExpressionRange(-param("d_min_down") + 1, 0) as ExpressionRange's __post_init__ wraps integers to literal nodes. However, MyPy does not seem to infer that ExpressionRange's attributes are necessarily of ExpressionNode type and raises an error if the arguments in the constructor are integer (whereas it runs correctly), this why we specify it here with literal(0) instead of 0. ], From 27f8ad8c063c99ff1ed931382ee901a52f6f9b2c Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Wed, 21 Aug 2024 10:15:53 +0200 Subject: [PATCH 31/51] Fix variable get structure --- src/andromede/expression/linear_expression_efficient.py | 1 + src/andromede/libs/standard.py | 2 +- src/andromede/simulation/linear_expression_resolver.py | 5 ++++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 059920d6..4ea7a2e6 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -1079,6 +1079,7 @@ def _copy_expression( # TODO : Define shortcuts for "x", is_one etc .... def var(name: str) -> LinearExpressionEfficient: + # TODO: At term build time, no information on the variable structure is known, we use a default time, scenario varying, maybe discard structure as term attribute ? return LinearExpressionEfficient( [ TermEfficient( diff --git a/src/andromede/libs/standard.py b/src/andromede/libs/standard.py index d476aab4..a28ba905 100644 --- a/src/andromede/libs/standard.py +++ b/src/andromede/libs/standard.py @@ -343,7 +343,7 @@ var("nb_stop").sum( shift=ExpressionRange(-param("d_min_down") + 1, literal(0)) ) - <= param("nb_units_max").shift(-param("d_min_down")) - var("nb_on"), + <= param("nb_units_max").shift(-param("d_min_down")).sum() - var("nb_on"), ), ], objective_operational_contribution=(param("cost") * var("generation")) diff --git a/src/andromede/simulation/linear_expression_resolver.py b/src/andromede/simulation/linear_expression_resolver.py index 7691b53d..fa479191 100644 --- a/src/andromede/simulation/linear_expression_resolver.py +++ b/src/andromede/simulation/linear_expression_resolver.py @@ -90,7 +90,10 @@ def resolve_variables( scenario, term.component_id, term.variable_name, - term.structure, + # At term build time, no information on the variable structure is known, we use it now + self.context.network.get_component(term.component_id) + .model.variables[term.variable_name] + .structure, ) ) return solver_vars From bcafb95b1c0bde9952e51e74d92bd5ac69e45674 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Wed, 21 Aug 2024 10:52:14 +0200 Subject: [PATCH 32/51] Fix expectation computation --- .../expression/linear_expression_efficient.py | 6 ++-- .../simulation/linear_expression_resolver.py | 36 ++++++++++--------- tests/functional/test_xpansion.py | 8 ++--- 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 4ea7a2e6..2472c909 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -380,7 +380,8 @@ def _merge_dicts( rhs: Dict[TermKeyEfficient, TermEfficient], merge_func: Callable[[TermEfficient, TermEfficient], TermEfficient], neutral: float, -) -> Dict[TermKeyEfficient, TermEfficient]: ... +) -> Dict[TermKeyEfficient, TermEfficient]: + ... @overload @@ -389,7 +390,8 @@ def _merge_dicts( rhs: Dict[PortFieldId, PortFieldTerm], merge_func: Callable[[PortFieldTerm, PortFieldTerm], PortFieldTerm], neutral: float, -) -> Dict[PortFieldId, PortFieldTerm]: ... +) -> Dict[PortFieldId, PortFieldTerm]: + ... def _get_neutral_term(term: T_val, neutral: float) -> T_val: diff --git a/src/andromede/simulation/linear_expression_resolver.py b/src/andromede/simulation/linear_expression_resolver.py index fa479191..892c701f 100644 --- a/src/andromede/simulation/linear_expression_resolver.py +++ b/src/andromede/simulation/linear_expression_resolver.py @@ -28,6 +28,7 @@ ResolvedLinearExpression, ResolvedTerm, ) +from andromede.expression.scenario_operator import Expectation from andromede.expression.time_operator import TimeShift from andromede.expression.value_provider import ( TimeScenarioIndex, @@ -50,11 +51,10 @@ def resolve( # Here, the value provide is used only to evaluate possible time operator args if the term has one resolved_variables = self.resolve_variables(term, row_id) - # TODO: For now all coefficients are the same for a given "variable_name", we are not able to represent things like sum(a_t * x_t)... but everything else is ok: - # sum(a_t') * x_t - # a_t * sum(x_t') - # a_t * x_t - # TODO: Next line is to be moved inside the for loop once we have figured out how to represent sum(a_t * x_t) + # TODO: Contrary to the time aggregator that does a sum which is the default behaviour when append resolved terms, expectation performs an averaging, so weights must be included in coefficients. We feel here that we could generalize time and scenario aggregation over variables with more general operators, the following lines are very specific to expectation with same weights over all scenarios + weight = 1 + if isinstance(term.scenario_aggregator, Expectation): + weight = 1 / self.value_provider.scenarios() for ts_id, lp_variable in resolved_variables.items(): # TODO: Could we check in which case coeff resolution leads to the same result for each element in the for loop ? When there is only a literal, etc, etc ? @@ -63,7 +63,9 @@ def resolve( self.value_provider, RowIndex(ts_id.time, ts_id.scenario), ) - resolved_terms.append(ResolvedTerm(resolved_coeff, lp_variable)) + resolved_terms.append( + ResolvedTerm(weight * resolved_coeff, lp_variable) + ) resolved_constant = resolve_coefficient( expression.constant, self.value_provider, row_id @@ -84,17 +86,17 @@ def resolve_variables( operator_ts_ids = self._row_id_to_term_time_scenario_id(term, row_id) for time in operator_ts_ids.time_indices: for scenario in operator_ts_ids.scenario_indices: - solver_vars[TimeScenarioIndex(time, scenario)] = ( - self.context.get_component_variable( - time, - scenario, - term.component_id, - term.variable_name, - # At term build time, no information on the variable structure is known, we use it now - self.context.network.get_component(term.component_id) - .model.variables[term.variable_name] - .structure, - ) + solver_vars[ + TimeScenarioIndex(time, scenario) + ] = self.context.get_component_variable( + time, + scenario, + term.component_id, + term.variable_name, + # At term build time, no information on the variable structure is known, we use it now + self.context.network.get_component(term.component_id) + .model.variables[term.variable_name] + .structure, ) return solver_vars diff --git a/tests/functional/test_xpansion.py b/tests/functional/test_xpansion.py index 32545eca..316ede8f 100644 --- a/tests/functional/test_xpansion.py +++ b/tests/functional/test_xpansion.py @@ -375,10 +375,10 @@ def test_generation_xpansion_two_time_steps_two_scenarios( status = problem.solver.Solve() assert status == problem.solver.OPTIMAL - # assert problem.solver.NumVariables() == 2 * scenarios * horizon + 1 - # assert ( - # problem.solver.NumConstraints() == 3 * scenarios * horizon - # ) # Flow balance, Max generation for each cluster + assert problem.solver.NumVariables() == 2 * scenarios * horizon + 1 + assert ( + problem.solver.NumConstraints() == 3 * scenarios * horizon + ) # Flow balance, Max generation for each cluster assert problem.solver.Objective().Value() == pytest.approx( 490 * 300 + 0.5 * (10 * 300 + 10 * 300 + 40 * 200) From b0197da0e579634b648da613bba648871738cc14 Mon Sep 17 00:00:00 2001 From: tbittar Date: Wed, 21 Aug 2024 18:40:39 +0200 Subject: [PATCH 33/51] Feature/update yaml parsing (#51) * Fix literal parsing * Parse models * Fix shift and eval parsing (temporary) * Remove useless code (#50) --- src/andromede/expression/__init__.py | 37 +- src/andromede/expression/evaluate.py | 10 - .../expression/evaluate_parameters.py | 99 ---- src/andromede/expression/expression.py | 454 ------------------ .../expression/expression_efficient.py | 33 +- src/andromede/expression/indexing.py | 105 ---- .../expression/linear_expression_efficient.py | 38 +- .../expression/parsing/parse_expression.py | 31 +- src/andromede/model/model.py | 44 +- src/andromede/model/resolve_library.py | 32 +- src/andromede/model/variable.py | 6 +- src/andromede/simulation/optimization.py | 16 +- .../simulation/optimization_context.py | 55 +-- tests/functional/test_performance.py | 280 ----------- .../functional/test_performance_efficient.py | 132 +++++ .../parsing/test_expression_parsing.py | 41 +- .../unittests/expressions/test_expressions.py | 314 ------------ 17 files changed, 287 insertions(+), 1440 deletions(-) delete mode 100644 src/andromede/expression/evaluate_parameters.py delete mode 100644 src/andromede/expression/expression.py delete mode 100644 tests/functional/test_performance.py delete mode 100644 tests/unittests/expressions/test_expressions.py diff --git a/src/andromede/expression/__init__.py b/src/andromede/expression/__init__.py index 55b51967..70c8c8ac 100644 --- a/src/andromede/expression/__init__.py +++ b/src/andromede/expression/__init__.py @@ -12,41 +12,28 @@ from .copy import CopyVisitor, copy_expression from .degree import ExpressionDegreeVisitor, compute_degree -from .evaluate import EvaluationContext, EvaluationVisitor, ValueProvider, evaluate -from .evaluate_parameters import ( - ParameterResolver, - ParameterValueProvider, - resolve_parameters, -) - -from .expression import ( - # AdditionNode, - # Comparator, - # ComparisonNode, - # DivisionNode, - ExpressionNode, - # LiteralNode, - # MultiplicationNode, - # NegationNode, - # ParameterNode, - # SubstractionNode, - VariableNode, - literal, - param, - sum_expressions, - var, -) +from .evaluate_parameters_efficient import ValueProvider from .expression_efficient import ( AdditionNode, - Comparator, ComparisonNode, + ComponentParameterNode, DivisionNode, ExpressionNodeEfficient, + ExpressionRange, + InstancesTimeIndex, LiteralNode, MultiplicationNode, NegationNode, ParameterNode, + PortFieldAggregatorNode, + PortFieldNode, + ScenarioOperatorName, + ScenarioOperatorNode, SubstractionNode, + TimeAggregatorName, + TimeAggregatorNode, + TimeOperatorName, + TimeOperatorNode, ) from .print import PrinterVisitor, print_expr from .visitor import ExpressionVisitor, visit diff --git a/src/andromede/expression/evaluate.py b/src/andromede/expression/evaluate.py index 08477070..e09f033c 100644 --- a/src/andromede/expression/evaluate.py +++ b/src/andromede/expression/evaluate.py @@ -13,7 +13,6 @@ from dataclasses import dataclass, field from typing import Dict -from andromede.expression.expression import VariableNode from andromede.expression.expression_efficient import ( ComparisonNode, ComponentParameterNode, @@ -93,18 +92,12 @@ def literal(self, node: LiteralNode) -> float: def comparison(self, node: ComparisonNode) -> float: raise ValueError("Cannot evaluate comparison operator.") - def variable(self, node: VariableNode) -> float: - return self.context.get_variable_value(node.name) - def parameter(self, node: ParameterNode) -> float: return self.context.get_parameter_value(node.name) def comp_parameter(self, node: ComponentParameterNode) -> float: return self.context.get_component_parameter_value(node.component_id, node.name) - # def comp_variable(self, node: ComponentVariableNode) -> float: - # return self.context.get_component_variable_value(node.component_id, node.name) - def time_operator(self, node: TimeOperatorNode) -> float: raise NotImplementedError() @@ -133,9 +126,6 @@ class InstancesIndexVisitor(EvaluationVisitor): Evaluates an expression given as instances index which should have no variable and constant parameter values. """ - def variable(self, node: VariableNode) -> float: - raise ValueError("An instance index expression cannot contain variable") - def parameter(self, node: ParameterNode) -> float: if not self.context.parameter_is_constant_over_time(node.name): raise ValueError( diff --git a/src/andromede/expression/evaluate_parameters.py b/src/andromede/expression/evaluate_parameters.py deleted file mode 100644 index 7c734260..00000000 --- a/src/andromede/expression/evaluate_parameters.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import List - -from andromede.expression.evaluate import InstancesIndexVisitor, ValueProvider - -from .copy import CopyVisitor -from .expression import ( - ComponentParameterNode, - ExpressionNode, - ExpressionRange, - InstancesTimeIndex, - LiteralNode, - ParameterNode, -) -from .visitor import visit - - -class ParameterValueProvider(ABC): - @abstractmethod - def get_parameter_value(self, name: str) -> float: - ... - - @abstractmethod - def get_component_parameter_value(self, component_id: str, name: str) -> float: - ... - - -@dataclass(frozen=True) -class ParameterResolver(CopyVisitor): - """ - Duplicates the AST with replacement of parameter nodes by literal nodes. - """ - - context: ParameterValueProvider - - def parameter(self, node: ParameterNode) -> ExpressionNode: - value: float = self.context.get_parameter_value(node.name) - return LiteralNode(value) - - def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode: - value: float = self.context.get_component_parameter_value( - node.component_id, node.name - ) - return LiteralNode(value) - - -def resolve_parameters( - expression: ExpressionNode, parameter_provider: ParameterValueProvider -) -> ExpressionNode: - return visit(expression, ParameterResolver(parameter_provider)) - - -def float_to_int(value: float) -> int: - if isinstance(value, int) or value.is_integer(): - return int(value) - else: - raise ValueError(f"{value} is not an integer.") - - -def evaluate_time_id(expr: ExpressionNode, value_provider: ValueProvider) -> int: - float_time_id = visit(expr, InstancesIndexVisitor(value_provider)) - try: - time_id = float_to_int(float_time_id) - except ValueError: - print(f"{expr} does not represent an integer time index.") - return time_id - - -def get_time_ids_from_instances_index( - instances_index: InstancesTimeIndex, value_provider: ValueProvider -) -> List[int]: - time_ids = [] - if isinstance(instances_index.expressions, list): # List[ExpressionNode] - for expr in instances_index.expressions: - time_ids.append(evaluate_time_id(expr, value_provider)) - - elif isinstance(instances_index.expressions, ExpressionRange): # ExpressionRange - start_id = evaluate_time_id(instances_index.expressions.start, value_provider) - stop_id = evaluate_time_id(instances_index.expressions.stop, value_provider) - step_id = 1 - if instances_index.expressions.step is not None: - step_id = evaluate_time_id(instances_index.expressions.step, value_provider) - # ExpressionRange includes stop_id whereas range excludes it - time_ids = list(range(start_id, stop_id + 1, step_id)) - - return time_ids diff --git a/src/andromede/expression/expression.py b/src/andromede/expression/expression.py deleted file mode 100644 index 01e8136b..00000000 --- a/src/andromede/expression/expression.py +++ /dev/null @@ -1,454 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -""" -Defines the model for generic expressions. -""" -import enum -import inspect -from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional, Sequence, Union - -import andromede.expression.port_operator -import andromede.expression.scenario_operator -import andromede.expression.time_operator - - -class Instances(enum.Enum): - SIMPLE = "SIMPLE" - MULTIPLE = "MULTIPLE" - - -@dataclass(frozen=True) -class ExpressionNode: - """ - Base class for all nodes of the expression AST. - - Operators overloading is provided to help create expressions - programmatically. - - Examples - >>> expr = -var('x') + 5 / param('p') - """ - - instances: Instances = field(init=False, default=Instances.SIMPLE) - - def __neg__(self) -> "ExpressionNode": - return NegationNode(self) - - def __add__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node(rhs, lambda x: AdditionNode(self, x)) - - def __radd__(self, lhs: Any) -> "ExpressionNode": - return _apply_if_node(lhs, lambda x: AdditionNode(x, self)) - - def __sub__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node(rhs, lambda x: SubstractionNode(self, x)) - - def __rsub__(self, lhs: Any) -> "ExpressionNode": - return _apply_if_node(lhs, lambda x: SubstractionNode(x, self)) - - def __mul__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node(rhs, lambda x: MultiplicationNode(self, x)) - - def __rmul__(self, lhs: Any) -> "ExpressionNode": - return _apply_if_node(lhs, lambda x: MultiplicationNode(x, self)) - - def __truediv__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node(rhs, lambda x: DivisionNode(self, x)) - - def __rtruediv__(self, lhs: Any) -> "ExpressionNode": - return _apply_if_node(lhs, lambda x: DivisionNode(x, self)) - - def __le__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node( - rhs, lambda x: ComparisonNode(self, x, Comparator.LESS_THAN) - ) - - def __ge__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node( - rhs, lambda x: ComparisonNode(self, x, Comparator.GREATER_THAN) - ) - - def __eq__(self, rhs: Any) -> "ExpressionNode": # type: ignore - return _apply_if_node(rhs, lambda x: ComparisonNode(self, x, Comparator.EQUAL)) - - def sum(self) -> "ExpressionNode": - if isinstance(self, TimeOperatorNode): - return TimeAggregatorNode(self, "TimeSum", stay_roll=True) - else: - return _apply_if_node( - self, lambda x: TimeAggregatorNode(x, "TimeSum", stay_roll=False) - ) - - def sum_connections(self) -> "ExpressionNode": - if isinstance(self, PortFieldNode): - return PortFieldAggregatorNode(self, aggregator="PortSum") - raise ValueError( - f"sum_connections() applies only for PortFieldNode, whereas the current node is of type {type(self)}." - ) - - def shift( - self, - expressions: Union[ - int, "ExpressionNode", List["ExpressionNode"], "ExpressionRange" - ], - ) -> "ExpressionNode": - return _apply_if_node( - self, - lambda x: TimeOperatorNode(x, "TimeShift", InstancesTimeIndex(expressions)), - ) - - def eval( - self, - expressions: Union[ - int, "ExpressionNode", List["ExpressionNode"], "ExpressionRange" - ], - ) -> "ExpressionNode": - return _apply_if_node( - self, - lambda x: TimeOperatorNode( - x, "TimeEvaluation", InstancesTimeIndex(expressions) - ), - ) - - def expec(self) -> "ExpressionNode": - return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation")) - - def variance(self) -> "ExpressionNode": - return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Variance")) - - -def _wrap_in_node(obj: Any) -> ExpressionNode: - if isinstance(obj, ExpressionNode): - return obj - elif isinstance(obj, float) or isinstance(obj, int): - return LiteralNode(float(obj)) - raise TypeError(f"Unable to wrap {obj} into an expression node") - - -def _apply_if_node( - obj: Any, func: Callable[["ExpressionNode"], "ExpressionNode"] -) -> "ExpressionNode": - if as_node := _wrap_in_node(obj): - return func(as_node) - else: - return NotImplemented - - -@dataclass(frozen=True, eq=False) -class VariableNode(ExpressionNode): - name: str - - -def var(name: str) -> VariableNode: - return VariableNode(name) - - -@dataclass(frozen=True, eq=False) -class PortFieldNode(ExpressionNode): - """ - References a port field. - """ - - port_name: str - field_name: str - - -def port_field(port_name: str, field_name: str) -> PortFieldNode: - return PortFieldNode(port_name, field_name) - - -@dataclass(frozen=True, eq=False) -class ParameterNode(ExpressionNode): - name: str - - -def param(name: str) -> ParameterNode: - return ParameterNode(name) - - -@dataclass(frozen=True, eq=False) -class ComponentParameterNode(ExpressionNode): - """ - Represents one parameter of one component. - - When building actual equations for a system, - we need to associated each parameter to its - actual component, at some point. - """ - - component_id: str - name: str - - -def comp_param(component_id: str, name: str) -> ComponentParameterNode: - return ComponentParameterNode(component_id, name) - - -@dataclass(frozen=True, eq=False) -class ComponentVariableNode(ExpressionNode): - """ - Represents one variable of one component. - - When building actual equations for a system, - we need to associated each variable to its - actual component, at some point. - """ - - component_id: str - name: str - - -def comp_var(component_id: str, name: str) -> ComponentVariableNode: - return ComponentVariableNode(component_id, name) - - -@dataclass(frozen=True, eq=False) -class LiteralNode(ExpressionNode): - value: float - - -def literal(value: float) -> LiteralNode: - return LiteralNode(value) - - -@dataclass(frozen=True, eq=False) -class UnaryOperatorNode(ExpressionNode): - operand: ExpressionNode - - def __post_init__(self) -> None: - object.__setattr__(self, "instances", self.operand.instances) - - -@dataclass(frozen=True, eq=False) -class PortFieldAggregatorNode(UnaryOperatorNode): - aggregator: str - - def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.port_operator, inspect.isclass - ) - if issubclass(cls, andromede.expression.port_operator.PortAggregator) - ] - if self.aggregator not in valid_names: - raise NotImplementedError( - f"{self.aggregator} is not a valid port aggregator, valid port aggregators are {valid_names}" - ) - - -@dataclass(frozen=True, eq=False) -class NegationNode(UnaryOperatorNode): - pass - - -@dataclass(frozen=True, eq=False) -class BinaryOperatorNode(ExpressionNode): - left: ExpressionNode - right: ExpressionNode - - def __post_init__(self) -> None: - binary_operator_post_init(self, "apply binary operation with") - - -def binary_operator_post_init(node: BinaryOperatorNode, operation: str) -> None: - if node.left.instances != node.right.instances: - raise ValueError( - f"Cannot {operation} {node.left} and {node.right} as they do not have the same number of instances." - ) - else: - object.__setattr__(node, "instances", node.left.instances) - - -class Comparator(enum.Enum): - LESS_THAN = "LESS_THAN" - EQUAL = "EQUAL" - GREATER_THAN = "GREATER_THAN" - - -@dataclass(frozen=True, eq=False) -class ComparisonNode(BinaryOperatorNode): - comparator: Comparator - - def __post_init__(self) -> None: - binary_operator_post_init(self, "compare") - - -@dataclass(frozen=True, eq=False) -class AdditionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "add") - - -@dataclass(frozen=True, eq=False) -class SubstractionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "substract") - - -@dataclass(frozen=True, eq=False) -class MultiplicationNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "multiply") - - -@dataclass(frozen=True, eq=False) -class DivisionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "divide") - - -@dataclass(frozen=True, eq=False) -class ExpressionRange: - start: ExpressionNode - stop: ExpressionNode - step: Optional[ExpressionNode] = None - - def __post_init__(self) -> None: - for attribute in self.__dict__: - value = getattr(self, attribute) - object.__setattr__( - self, attribute, _wrap_in_node(value) if value is not None else value - ) - - -IntOrExpr = Union[int, ExpressionNode] - - -def expression_range( - start: IntOrExpr, stop: IntOrExpr, step: Optional[IntOrExpr] = None -) -> ExpressionRange: - return ExpressionRange( - start=_wrap_in_node(start), - stop=_wrap_in_node(stop), - step=None if step is None else _wrap_in_node(step), - ) - - -@dataclass -class InstancesTimeIndex: - """ - Defines a set of time indices on which a time operator operates. - - In particular, it defines time indices created by the shift operator. - - The actual indices can either be defined as a time range defined by - 2 expression, or as a list of expressions. - """ - - expressions: Union[List[ExpressionNode], ExpressionRange] - - def __init__( - self, - expressions: Union[int, ExpressionNode, List[ExpressionNode], ExpressionRange], - ) -> None: - if not isinstance(expressions, (int, ExpressionNode, list, ExpressionRange)): - raise TypeError( - f"{expressions} must be of type among {{int, ExpressionNode, List[ExpressionNode], ExpressionRange}}" - ) - if isinstance(expressions, list) and not all( - isinstance(x, ExpressionNode) for x in expressions - ): - raise TypeError( - f"All elements of {expressions} must be of type ExpressionNode" - ) - - if isinstance(expressions, (int, ExpressionNode)): - self.expressions = [_wrap_in_node(expressions)] - else: - self.expressions = expressions - - def is_simple(self) -> bool: - if isinstance(self.expressions, list): - return len(self.expressions) == 1 - else: - # TODO: We could also check that if a range only includes literal nodes, compute the length of the range, if it's one return True. This is more complicated, I do not know if we want to do this - return False - - -@dataclass(frozen=True, eq=False) -class TimeOperatorNode(UnaryOperatorNode): - name: str - instances_index: InstancesTimeIndex - - def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.time_operator, inspect.isclass - ) - if issubclass(cls, andromede.expression.time_operator.TimeOperator) - ] - if self.name not in valid_names: - raise ValueError( - f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" - ) - if self.operand.instances == Instances.SIMPLE: - if self.instances_index.is_simple(): - object.__setattr__(self, "instances", Instances.SIMPLE) - else: - object.__setattr__(self, "instances", Instances.MULTIPLE) - else: - raise ValueError( - "Cannot apply time operator on an expression that already represents multiple instances" - ) - - -@dataclass(frozen=True, eq=False) -class TimeAggregatorNode(UnaryOperatorNode): - name: str - stay_roll: bool - - def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.time_operator, inspect.isclass - ) - if issubclass(cls, andromede.expression.time_operator.TimeAggregator) - ] - if self.name not in valid_names: - raise ValueError( - f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" - ) - object.__setattr__(self, "instances", Instances.SIMPLE) - - -@dataclass(frozen=True, eq=False) -class ScenarioOperatorNode(UnaryOperatorNode): - name: str - - def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.scenario_operator, inspect.isclass - ) - if issubclass( - cls, andromede.expression.scenario_operator.ScenarioAggregator - ) - ] - if self.name not in valid_names: - raise ValueError( - f"{self.name} is not a valid scenario operator, valid scenario operators are {valid_names}" - ) - object.__setattr__(self, "instances", Instances.SIMPLE) - - -def sum_expressions(expressions: Sequence[ExpressionNode]) -> ExpressionNode: - if len(expressions) == 0: - return LiteralNode(0) - if len(expressions) == 1: - return expressions[0] - return expressions[0] + sum_expressions(expressions[1:]) diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index a0fd86e7..29bbe87b 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -444,7 +444,7 @@ class DivisionNode(BinaryOperatorNode): pass -@dataclass(frozen=True, eq=False) +@dataclass(frozen=True) class ExpressionRange: start: ExpressionNodeEfficient stop: ExpressionNodeEfficient @@ -457,6 +457,14 @@ def __post_init__(self) -> None: self, attribute, wrap_in_node(value) if value is not None else value ) + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, ExpressionRange) + and expressions_equal(self.start, other.start) + and expressions_equal(self.stop, other.stop) + and expressions_equal_if_present(self.step, other.step) + ) + IntOrExpr = Union[int, ExpressionNodeEfficient] @@ -515,6 +523,29 @@ def __hash__(self) -> int: else: return hash(self.expressions) + def __eq__(self, other: Any) -> bool: + if isinstance(other, InstancesTimeIndex): + if isinstance(self.expressions, list) and all( + isinstance(x, ExpressionNodeEfficient) for x in self.expressions + ): + return ( + isinstance(other.expressions, list) + and all( + isinstance(x, ExpressionNodeEfficient) + for x in other.expressions + ) + and all( + expressions_equal(left_expr, right_expr) + for left_expr, right_expr in zip( + self.expressions, other.expressions + ) + ) + ) + elif isinstance(self.expressions, ExpressionRange): + return self.expressions == other.expressions + else: + return False + def is_simple(self) -> bool: if isinstance(self.expressions, list): return len(self.expressions) == 1 diff --git a/src/andromede/expression/indexing.py b/src/andromede/expression/indexing.py index 102f4c45..73d43ff5 100644 --- a/src/andromede/expression/indexing.py +++ b/src/andromede/expression/indexing.py @@ -11,32 +11,9 @@ # This file is part of the Antares project. from abc import ABC, abstractmethod -from dataclasses import dataclass -import andromede.expression.time_operator from andromede.expression.indexing_structure import IndexingStructure -from .expression import ( - AdditionNode, - ComparisonNode, - ComponentParameterNode, - ComponentVariableNode, - DivisionNode, - ExpressionNode, - LiteralNode, - MultiplicationNode, - NegationNode, - ParameterNode, - PortFieldAggregatorNode, - PortFieldNode, - ScenarioOperatorNode, - SubstractionNode, - TimeAggregatorNode, - TimeOperatorNode, - VariableNode, -) -from .visitor import ExpressionVisitor, T, visit - class IndexingStructureProvider(ABC): @abstractmethod @@ -58,85 +35,3 @@ def get_component_parameter_structure( self, component_id: str, name: str ) -> IndexingStructure: ... - - -@dataclass(frozen=True) -class TimeScenarioIndexingVisitor(ExpressionVisitor[IndexingStructure]): - """ - Determines if the expression represents a single expression or an expression that should be instantiated for all time steps. - """ - - context: IndexingStructureProvider - - def literal(self, node: LiteralNode) -> IndexingStructure: - return IndexingStructure(False, False) - - def negation(self, node: NegationNode) -> IndexingStructure: - return visit(node.operand, self) - - def addition(self, node: AdditionNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - def substraction(self, node: SubstractionNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - def multiplication(self, node: MultiplicationNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - def division(self, node: DivisionNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - def comparison(self, node: ComparisonNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - # def variable(self, node: VariableNode) -> IndexingStructure: - # time = self.context.get_variable_structure(node.name).time == True - # scenario = self.context.get_variable_structure(node.name).scenario == True - # return IndexingStructure(time, scenario) - - def parameter(self, node: ParameterNode) -> IndexingStructure: - time = self.context.get_parameter_structure(node.name).time == True - scenario = self.context.get_parameter_structure(node.name).scenario == True - return IndexingStructure(time, scenario) - - # def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure: - # return self.context.get_component_variable_structure( - # node.component_id, node.name - # ) - - def comp_parameter(self, node: ComponentParameterNode) -> IndexingStructure: - return self.context.get_component_parameter_structure( - node.component_id, node.name - ) - - def time_operator(self, node: TimeOperatorNode) -> IndexingStructure: - time_operator_cls = getattr(andromede.expression.time_operator, node.name) - if time_operator_cls.rolling(): - return visit(node.operand, self) - else: - return IndexingStructure(False, visit(node.operand, self).scenario) - - def time_aggregator(self, node: TimeAggregatorNode) -> IndexingStructure: - if node.stay_roll: - return visit(node.operand, self) - else: - return IndexingStructure(False, visit(node.operand, self).scenario) - - def scenario_operator(self, node: ScenarioOperatorNode) -> IndexingStructure: - return IndexingStructure(visit(node.operand, self).time, False) - - def port_field(self, node: PortFieldNode) -> IndexingStructure: - raise ValueError( - "Port fields must be resolved before computing indexing structure." - ) - - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> IndexingStructure: - raise ValueError( - "Port fields aggregators must be resolved before computing indexing structure." - ) - - -def compute_indexation( - expression: ExpressionNode, provider: IndexingStructureProvider -) -> IndexingStructure: - return visit(expression, TimeScenarioIndexingVisitor(provider)) diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 2472c909..87766bc5 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -29,14 +29,10 @@ overload, ) -import ortools.linear_solver.pywraplp as lp - from andromede.expression.context_adder import add_component_context from andromede.expression.equality import expressions_equal -from andromede.expression.evaluate import evaluate from andromede.expression.evaluate_parameters_efficient import ( check_resolved_expr, - get_time_ids_from_instances_index, resolve_coefficient, ) from andromede.expression.expression_efficient import ( @@ -61,10 +57,6 @@ from andromede.expression.indexing_structure import IndexingStructure, RowIndex from andromede.expression.port_operator import PortAggregator, PortSum from andromede.expression.print import print_expr -from andromede.expression.resolved_linear_expression import ( - ResolvedLinearExpression, - ResolvedTerm, -) from andromede.expression.scenario_operator import Expectation, ScenarioAggregator from andromede.expression.time_operator import ( TimeAggregator, @@ -92,6 +84,17 @@ class TermKeyEfficient: time_aggregator: Optional[TimeAggregator] scenario_aggregator: Optional[ScenarioAggregator] + # Used for test_expression_parsing + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, TermKeyEfficient) + and self.component_id == other.component_id + and self.variable_name == other.variable_name + and self.time_operator == other.time_operator + and self.time_aggregator == other.time_aggregator + and self.scenario_aggregator == other.scenario_aggregator + ) + @dataclass(frozen=True) class TermEfficient: @@ -1022,6 +1025,17 @@ def linear_expressions_equal( ) +def linear_expressions_equal_if_present( + lhs: Optional[LinearExpressionEfficient], rhs: Optional[LinearExpressionEfficient] +) -> bool: + if lhs is None and rhs is None: + return True + elif lhs is None or rhs is None: + return False + else: + return linear_expressions_equal(lhs, rhs) + + # TODO: Is this function useful ? Could we just rely on the sum operator overloading ? Only the case with an empty list may make the function useful def sum_expressions( expressions: Sequence[LinearExpressionEfficient], @@ -1054,6 +1068,14 @@ def __post_init__( def __str__(self) -> str: return f"{str(self.lower_bound)} <= {str(self.expression)} <= {str(self.upper_bound)}" + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, StandaloneConstraint) + and linear_expressions_equal_if_present(self.expression, other.expression) + and linear_expressions_equal_if_present(self.lower_bound, other.lower_bound) + and linear_expressions_equal_if_present(self.upper_bound, other.upper_bound) + ) + def wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient: if isinstance(obj, LinearExpressionEfficient): diff --git a/src/andromede/expression/parsing/parse_expression.py b/src/andromede/expression/parsing/parse_expression.py index b7b79704..bb20c66c 100644 --- a/src/andromede/expression/parsing/parse_expression.py +++ b/src/andromede/expression/parsing/parse_expression.py @@ -12,22 +12,22 @@ from dataclasses import dataclass from typing import Set -from antlr4 import CommonTokenStream, DiagnosticErrorListener, InputStream +from antlr4 import CommonTokenStream, InputStream from antlr4.error.ErrorStrategy import BailErrorStrategy from andromede.expression.equality import expressions_equal from andromede.expression.expression_efficient import ( Comparator, ComparisonNode, - ExpressionNodeEfficient, ExpressionRange, - PortFieldNode, + literal, param, ) from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, - literal, + port_field, var, + wrap_in_linear_expr, ) from andromede.expression.parsing.antlr.ExprLexer import ExprLexer from andromede.expression.parsing.antlr.ExprParser import ExprParser @@ -124,7 +124,7 @@ def _convert_identifier(self, identifier: str) -> LinearExpressionEfficient: def visitPortField( self, ctx: ExprParser.PortFieldContext ) -> LinearExpressionEfficient: - return PortFieldNode( + return port_field( port_name=ctx.IDENTIFIER(0).getText(), # type: ignore field_name=ctx.IDENTIFIER(1).getText(), # type: ignore ) @@ -137,11 +137,11 @@ def visitComparison( exp1 = ctx.expr(0).accept(self) # type: ignore exp2 = ctx.expr(1).accept(self) # type: ignore comp = { - "=": Comparator.EQUAL, - "<=": Comparator.LESS_THAN, - ">=": Comparator.GREATER_THAN, + "=": LinearExpressionEfficient.__eq__, + "<=": LinearExpressionEfficient.__le__, + ">=": LinearExpressionEfficient.__ge__, }[op] - return ComparisonNode(exp1, exp2, comp) + return comp(exp1, exp2) # Visit a parse tree produced by ExprParser#timeShift. def visitTimeIndex( @@ -157,7 +157,8 @@ def visitTimeRange( ) -> LinearExpressionEfficient: shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore expressions = [e.accept(self) for e in ctx.expr()] # type: ignore - return shifted_expr.eval(ExpressionRange(expressions[0], expressions[1])) + # TODO: Is there a visitSum somewhere that is not needed ? Are the correct symbol parsed (sum(...) ?) ? + return shifted_expr.sum(eval=ExpressionRange(expressions[0], expressions[1])) def visitTimeShift( self, ctx: ExprParser.TimeShiftContext @@ -167,7 +168,7 @@ def visitTimeShift( # specifics for x[t] ... if len(time_shifts) == 1 and expressions_equal(time_shifts[0], literal(0)): return shifted_expr - return shifted_expr.shift(time_shifts) + return shifted_expr.sum(shift=time_shifts) def visitTimeShiftRange( self, ctx: ExprParser.TimeShiftRangeContext @@ -175,7 +176,7 @@ def visitTimeShiftRange( shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore shift1 = ctx.shift1.accept(self) # type: ignore shift2 = ctx.shift2.accept(self) # type: ignore - return shifted_expr.shift(ExpressionRange(shift1, shift2)) + return shifted_expr.sum(shift=ExpressionRange(shift1, shift2)) # Visit a parse tree produced by ExprParser#function. def visitFunction( @@ -266,9 +267,9 @@ def visitRightAtom( _FUNCTIONS = { - "sum": ExpressionNodeEfficient.sum, - "sum_connections": ExpressionNodeEfficient.sum_connections, - "expec": ExpressionNodeEfficient.expec, + "sum": LinearExpressionEfficient.sum, + "sum_connections": LinearExpressionEfficient.sum_connections, + "expec": LinearExpressionEfficient.expec, } diff --git a/src/andromede/model/model.py b/src/andromede/model/model.py index 6a856e05..e151825c 100644 --- a/src/andromede/model/model.py +++ b/src/andromede/model/model.py @@ -19,16 +19,6 @@ from dataclasses import dataclass, field from typing import Dict, Iterable, Optional -# from andromede.expression.expression import ( -# BinaryOperatorNode, -# ComponentParameterNode, -# ComponentVariableNode, -# PortFieldAggregatorNode, -# PortFieldNode, -# ScenarioOperatorNode, -# TimeAggregatorNode, -# TimeOperatorNode, -# ) from andromede.expression.expression_efficient import ( AdditionNode, BinaryOperatorNode, @@ -46,7 +36,7 @@ TimeAggregatorNode, TimeOperatorNode, ) -from andromede.expression.indexing import IndexingStructureProvider, compute_indexation +from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, @@ -61,38 +51,6 @@ from andromede.model.port import PortType from andromede.model.variable import Variable -# from andromede.expression import ( -# AdditionNode, -# ComparisonNode, -# DivisionNode, -# ExpressionNode, -# ExpressionVisitor, -# LiteralNode, -# MultiplicationNode, -# NegationNode, -# ParameterNode, -# SubstractionNode, -# VariableNode, -# ) -# from andromede.expression.expression_efficient import ( -# AdditionNode, -# BinaryOperatorNode, -# ComparisonNode, -# ComponentParameterNode, -# DivisionNode, -# ExpressionNodeEfficient, -# LiteralNode, -# MultiplicationNode, -# NegationNode, -# ParameterNode, -# PortFieldAggregatorNode, -# PortFieldNode, -# ScenarioOperatorNode, -# SubstractionNode, -# TimeAggregatorNode, -# TimeOperatorNode, -# ) - # TODO: Introduce bool_variable ? def _make_structure_provider(model: "Model") -> IndexingStructureProvider: diff --git a/src/andromede/model/resolve_library.py b/src/andromede/model/resolve_library.py index ef117283..5cf2cc89 100644 --- a/src/andromede/model/resolve_library.py +++ b/src/andromede/model/resolve_library.py @@ -13,7 +13,10 @@ # from andromede.expression import ExpressionNode from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression_efficient import LinearExpressionEfficient +from andromede.expression.linear_expression_efficient import ( + LinearExpressionEfficient, + wrap_in_linear_expr_if_present, +) from andromede.expression.parsing.parse_expression import ( ModelIdentifiers, parse_expression, @@ -137,8 +140,12 @@ def _to_variable(var: InputVariable, identifiers: ModelIdentifiers) -> Variable: var.variable_type ], structure=IndexingStructure(var.time_dependent, var.scenario_dependent), - lower_bound=_to_expression_if_present(var.lower_bound, identifiers), - upper_bound=_to_expression_if_present(var.upper_bound, identifiers), + lower_bound=wrap_in_linear_expr_if_present( + _to_expression_if_present(var.lower_bound, identifiers) + ), + upper_bound=wrap_in_linear_expr_if_present( + _to_expression_if_present(var.upper_bound, identifiers) + ), context=ProblemContext.OPERATIONAL, ) @@ -146,9 +153,18 @@ def _to_variable(var: InputVariable, identifiers: ModelIdentifiers) -> Variable: def _to_constraint( constraint: InputConstraint, identifiers: ModelIdentifiers ) -> Constraint: - return Constraint( - name=constraint.name, - expression=parse_expression(constraint.expression, identifiers), - lower_bound=_to_expression_if_present(constraint.lower_bound, identifiers), - upper_bound=_to_expression_if_present(constraint.upper_bound, identifiers), + kwargs = { + "name": constraint.name, + "expression": parse_expression(constraint.expression, identifiers), + } + lb = wrap_in_linear_expr_if_present( + _to_expression_if_present(constraint.lower_bound, identifiers) + ) + ub = wrap_in_linear_expr_if_present( + _to_expression_if_present(constraint.upper_bound, identifiers) ) + if lb is not None: + kwargs["lower_bound"] = lb + if ub is not None: + kwargs["upper_bound"] = ub + return Constraint(**kwargs) diff --git a/src/andromede/model/variable.py b/src/andromede/model/variable.py index f880b902..e418d8a3 100644 --- a/src/andromede/model/variable.py +++ b/src/andromede/model/variable.py @@ -13,11 +13,11 @@ from dataclasses import dataclass from typing import Any, Optional -from andromede.expression.equality import expressions_equal_if_present from andromede.expression.expression_efficient import literal from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, + linear_expressions_equal_if_present, wrap_in_linear_expr_if_present, ) from andromede.model.common import ( @@ -52,8 +52,8 @@ def __eq__(self, other: Any) -> bool: return ( self.name == other.name and self.data_type == other.data_type - and expressions_equal_if_present(self.lower_bound, other.lower_bound) - and expressions_equal_if_present(self.upper_bound, other.upper_bound) + and linear_expressions_equal_if_present(self.lower_bound, other.lower_bound) + and linear_expressions_equal_if_present(self.upper_bound, other.upper_bound) and self.structure == other.structure ) diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 7538b9e7..58e505d1 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -17,7 +17,6 @@ import math from dataclasses import dataclass -from typing import List, Optional import ortools.linear_solver.pywraplp as lp @@ -28,19 +27,16 @@ RowIndex, ) from andromede.expression.resolved_linear_expression import ResolvedLinearExpression -from andromede.expression.scenario_operator import Expectation -from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum from andromede.model.common import ValueType from andromede.model.constraint import Constraint from andromede.model.model import PortFieldId -from andromede.simulation.linear_expression import Term from andromede.simulation.linear_expression_resolver import LinearExpressionResolver from andromede.simulation.optimization_context import ( BlockBorderManagement, ComponentContext, OptimizationContext, - _make_data_structure_provider, - _make_value_provider, + make_data_structure_provider, + make_value_provider, ) from andromede.simulation.strategy import MergedProblemStrategy, ModelSelectionStrategy from andromede.simulation.time_block import TimeBlock @@ -61,7 +57,7 @@ def _get_indexing( def _compute_indexing_structure( context: ComponentContext, constraint: Constraint ) -> IndexingStructure: - data_structure_provider = _make_data_structure_provider( + data_structure_provider = make_data_structure_provider( context.opt_context.network, context.component ) constraint_indexing = _get_indexing(constraint, data_structure_provider) @@ -101,7 +97,7 @@ def _create_constraint( # instances_per_time_step = linear_expr.number_of_instances() # instances_per_time_step = 1 - value_provider = _make_value_provider(context.opt_context, context.component) + value_provider = make_value_provider(context.opt_context, context.component) expression_resolver = LinearExpressionResolver(context.opt_context, value_provider) for block_timestep in context.opt_context.get_time_indices(constraint_indexing): @@ -140,7 +136,7 @@ def _create_objective( ) # We have already checked in the model creation that the objective contribution is neither indexed by time nor by scenario - value_provider = _make_value_provider(opt_context, component) + value_provider = make_value_provider(opt_context, component) expression_resolver = LinearExpressionResolver(opt_context, value_provider) resolved_expr = expression_resolver.resolve(instantiated_expr, RowIndex(0, 0)) @@ -255,7 +251,7 @@ def _create_variables(self) -> None: component_context = self.context.get_component_context(component) model = component.model - value_provider = _make_value_provider(self.context, component) + value_provider = make_value_provider(self.context, component) expression_resolver = LinearExpressionResolver(self.context, value_provider) for model_var in self.strategy.get_variables(model): diff --git a/src/andromede/simulation/optimization_context.py b/src/andromede/simulation/optimization_context.py index 03b11ccc..85be98be 100644 --- a/src/andromede/simulation/optimization_context.py +++ b/src/andromede/simulation/optimization_context.py @@ -17,7 +17,6 @@ import ortools.linear_solver.pywraplp as lp -from andromede.expression import ParameterValueProvider, resolve_parameters from andromede.expression.evaluate_parameters_efficient import ValueProvider from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure @@ -27,8 +26,6 @@ PortFieldKey, ) from andromede.expression.value_provider import TimeScenarioIndex, TimeScenarioIndices -from andromede.simulation.linear_expression import LinearExpression -from andromede.simulation.linearize import linearize_expression from andromede.simulation.time_block import TimeBlock from andromede.study.data import DataBase from andromede.study.network import Component, Network @@ -218,7 +215,7 @@ def _get_parameter_value( return data.get_value(absolute_timestep, scenario) -def _make_value_provider( +def make_value_provider( context: "OptimizationContext", component: Component, ) -> ValueProvider: @@ -306,39 +303,13 @@ class ExpressionTimestepValueProvider(TimestepValueProvider): # OptimizationContext has knowledge of the block, so that get_value only needs block_timestep and scenario to get the correct data value def get_value(self, block_timestep: int, scenario: int) -> float: - param_value_provider = _make_value_provider( + param_value_provider = make_value_provider( self.context, block_timestep, scenario, self.component ) return self.expression.evaluate(param_value_provider) -def _make_parameter_value_provider( - context: "OptimizationContext", - block_timestep: int, - scenario: int, -) -> ParameterValueProvider: - """ - A value provider which takes its values from - the parameter values as defined in the network data. - - Cannot evaluate expressions which contain variables. - """ - - class Provider(ParameterValueProvider): - def get_component_parameter_value(self, component_id: str, name: str) -> float: - return _get_parameter_value( - context, block_timestep, scenario, component_id, name - ) - - def get_parameter_value(self, name: str) -> float: - raise ValueError( - "Parameters should have been associated with their component before resolution." - ) - - return Provider() - - -def _make_data_structure_provider( +def make_data_structure_provider( network: Network, component: Component ) -> IndexingStructureProvider: """ @@ -406,26 +377,6 @@ def get_variable( self.component.model.variables[variable_name].structure, ) - def linearize_expression( - self, - block_timestep: int, - scenario: int, - expression: LinearExpressionEfficient, - ) -> LinearExpression: - parameters_valued_provider = _make_parameter_value_provider( - self.opt_context, block_timestep, scenario - ) - evaluated_expr = resolve_parameters(expression, parameters_valued_provider) - - value_provider = _make_value_provider( - self.opt_context, block_timestep, scenario, self.component - ) - structure_provider = _make_data_structure_provider( - self.opt_context.network, self.component - ) - - return linearize_expression(evaluated_expr, structure_provider, value_provider) - def _get_data_time_key(block_timestep: int, data_indexing: IndexingStructure) -> int: return block_timestep if data_indexing.time else 0 diff --git a/tests/functional/test_performance.py b/tests/functional/test_performance.py deleted file mode 100644 index 1c50af1c..00000000 --- a/tests/functional/test_performance.py +++ /dev/null @@ -1,280 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -from typing import cast - -import pytest - -from andromede.expression.expression import ExpressionNode, literal, param, var -from andromede.expression.indexing_structure import IndexingStructure -from andromede.libs.standard import ( - BALANCE_PORT_TYPE, - DEMAND_MODEL, - GENERATOR_MODEL, - GENERATOR_MODEL_WITH_STORAGE, - NODE_BALANCE_MODEL, -) -from andromede.model import float_parameter, float_variable, model -from andromede.simulation import TimeBlock, build_problem -from andromede.study import ( - ConstantData, - DataBase, - Network, - Node, - PortRef, - create_component, -) -from tests.unittests.test_utils import generate_scalar_matrix_data - - -def test_large_sum_inside_model_with_loop() -> None: - """ - Test performance when the problem involves an expression with a high number of terms. - Here the objective function is the sum over nb_terms terms on a for-loop inside the model - - This test pass with 476 terms but fails with 477 locally due to recursion depth, - and even less terms are possible with Jenkins... - """ - nb_terms = 500 - - time_blocks = [TimeBlock(0, [0])] - scenarios = 1 - database = DataBase() - - for i in range(1, nb_terms): - database.add_data("simple_cost", f"cost_{i}", ConstantData(1 / i)) - - with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): - SIMPLE_COST_MODEL = model( - id="SIMPLE_COST", - parameters=[ - float_parameter(f"cost_{i}", IndexingStructure(False, False)) - for i in range(1, nb_terms) - ], - objective_operational_contribution=cast( - ExpressionNode, sum(param(f"cost_{i}") for i in range(1, nb_terms)) - ), - ) - - # Won't run because last statement will raise the error - network = Network("test") - cost_model = create_component(model=SIMPLE_COST_MODEL, id="simple_cost") - network.add_component(cost_model) - - problem = build_problem(network, database, time_blocks[0], scenarios) - status = problem.solver.Solve() - - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == sum( - [1 / i for i in range(1, nb_terms)] - ) - - -def test_large_sum_outside_model_with_loop() -> None: - """ - Test performance when the problem involves an expression with a high number of terms. - Here the objective function is the sum over nb_terms terms on a for-loop outside the model - """ - nb_terms = 10_000 - - time_blocks = [TimeBlock(0, [0])] - scenarios = 1 - database = DataBase() - - obj_coeff = sum([1 / i for i in range(1, nb_terms)]) - - SIMPLE_COST_MODEL = model( - id="SIMPLE_COST", - parameters=[], - objective_operational_contribution=literal(obj_coeff), - ) - - network = Network("test") - - simple_model = create_component( - model=SIMPLE_COST_MODEL, - id="simple_cost", - ) - network.add_component(simple_model) - - problem = build_problem(network, database, time_blocks[0], scenarios) - status = problem.solver.Solve() - - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == obj_coeff - - -def test_large_sum_inside_model_with_sum_operator() -> None: - """ - Test performance when the problem involves an expression with a high number of terms. - Here the objective function is the sum over nb_terms terms with the sum() operator inside the model - """ - nb_terms = 10_000 - - scenarios = 1 - time_blocks = [TimeBlock(0, list(range(nb_terms)))] - database = DataBase() - - # Weird values when the "cost" varies over time and we use the sum() operator: - # For testing purposes, will use a const value since the problem seems to come when - # we try to linearize nb_terms variables with nb_terms distinct parameters - # TODO check the sum() operator for time-variable parameters - database.add_data("simple_cost", "cost", ConstantData(3)) - - SIMPLE_COST_MODEL = model( - id="SIMPLE_COST", - parameters=[ - float_parameter("cost", IndexingStructure(False, False)), - ], - variables=[ - float_variable( - "var", - lower_bound=literal(1), - upper_bound=literal(2), - structure=IndexingStructure(True, False), - ), - ], - objective_operational_contribution=(param("cost") * var("var")).sum(), - ) - - network = Network("test") - - cost_model = create_component(model=SIMPLE_COST_MODEL, id="simple_cost") - network.add_component(cost_model) - - problem = build_problem(network, database, time_blocks[0], scenarios) - status = problem.solver.Solve() - - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == 3 * nb_terms - - -def test_large_sum_of_port_connections() -> None: - """ - Test performance when the problem involves a model where several generators are connected to a node. - - This test pass with 470 terms but fails with 471 locally due to recursion depth, - and possibly even less terms are possible with Jenkins... - """ - nb_generators = 500 - - time_block = TimeBlock(0, [0]) - scenarios = 1 - - database = DataBase() - database.add_data("D", "demand", ConstantData(nb_generators)) - - for gen_id in range(nb_generators): - database.add_data(f"G_{gen_id}", "p_max", ConstantData(1)) - database.add_data(f"G_{gen_id}", "cost", ConstantData(5)) - - node = Node(model=NODE_BALANCE_MODEL, id="N") - demand = create_component(model=DEMAND_MODEL, id="D") - generators = [ - create_component(model=GENERATOR_MODEL, id=f"G_{gen_id}") - for gen_id in range(nb_generators) - ] - - network = Network("test") - network.add_node(node) - - network.add_component(demand) - network.connect(PortRef(demand, "balance_port"), PortRef(node, "balance_port")) - - for gen_id in range(nb_generators): - network.add_component(generators[gen_id]) - network.connect( - PortRef(generators[gen_id], "balance_port"), PortRef(node, "balance_port") - ) - - with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): - problem = build_problem(network, database, time_block, scenarios) - - # Won't run because last statement will raise the error - status = problem.solver.Solve() - - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == 5 * nb_generators - - -def test_basic_balance_on_whole_year() -> None: - """ - Balance on one node with one fixed demand and one generation, on 8760 timestep. - """ - - scenarios = 1 - horizon = 8760 - time_block = TimeBlock(1, list(range(horizon))) - - database = DataBase() - database.add_data( - "D", "demand", generate_scalar_matrix_data(100, horizon, scenarios) - ) - - database.add_data("G", "p_max", ConstantData(100)) - database.add_data("G", "cost", ConstantData(30)) - - node = Node(model=NODE_BALANCE_MODEL, id="N") - demand = create_component(model=DEMAND_MODEL, id="D") - - gen = create_component(model=GENERATOR_MODEL, id="G") - - network = Network("test") - network.add_node(node) - network.add_component(demand) - network.add_component(gen) - network.connect(PortRef(demand, "balance_port"), PortRef(node, "balance_port")) - network.connect(PortRef(gen, "balance_port"), PortRef(node, "balance_port")) - - problem = build_problem(network, database, time_block, scenarios) - status = problem.solver.Solve() - - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == 30 * 100 * horizon - - -def test_basic_balance_on_whole_year_with_large_sum() -> None: - """ - Balance on one node with one fixed demand and one generation with storage, on 8760 timestep. - """ - - scenarios = 1 - horizon = 8760 - time_block = TimeBlock(1, list(range(horizon))) - - database = DataBase() - database.add_data( - "D", "demand", generate_scalar_matrix_data(100, horizon, scenarios) - ) - - database.add_data("G", "p_max", ConstantData(100)) - database.add_data("G", "cost", ConstantData(30)) - database.add_data("G", "full_storage", ConstantData(100 * horizon)) - - node = Node(model=NODE_BALANCE_MODEL, id="N") - demand = create_component(model=DEMAND_MODEL, id="D") - gen = create_component( - model=GENERATOR_MODEL_WITH_STORAGE, id="G" - ) # Limits the total generation inside a TimeBlock - - network = Network("test") - network.add_node(node) - network.add_component(demand) - network.add_component(gen) - network.connect(PortRef(demand, "balance_port"), PortRef(node, "balance_port")) - network.connect(PortRef(gen, "balance_port"), PortRef(node, "balance_port")) - - problem = build_problem(network, database, time_block, scenarios) - status = problem.solver.Solve() - - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == 30 * 100 * horizon diff --git a/tests/functional/test_performance_efficient.py b/tests/functional/test_performance_efficient.py index 55ef9913..ed00bd02 100644 --- a/tests/functional/test_performance_efficient.py +++ b/tests/functional/test_performance_efficient.py @@ -21,6 +21,17 @@ var, wrap_in_linear_expr, ) +from andromede.libs.standard import ( + DEMAND_MODEL, + GENERATOR_MODEL, + GENERATOR_MODEL_WITH_STORAGE, + NODE_BALANCE_MODEL, +) +from andromede.simulation.optimization import build_problem +from andromede.simulation.time_block import TimeBlock +from andromede.study.data import ConstantData, DataBase +from andromede.study.network import Network, Node, PortRef, create_component +from tests.unittests.test_utils import generate_scalar_matrix_data def test_large_number_of_parameters_sum() -> None: @@ -84,3 +95,124 @@ def test_large_number_of_variables_sum() -> None: assert expr.evaluate( EvaluationContext(variables=variables_value), RowIndex(0, 0) ) == sum(1 / i for i in range(1, nb_terms)) + + +def test_large_sum_of_port_connections() -> None: + """ + Test performance when the problem involves a model where several generators are connected to a node. + + This test pass with 470 terms but fails with 471 locally due to recursion depth, + and possibly even less terms are possible with Jenkins... + """ + nb_generators = 500 + + time_block = TimeBlock(0, [0]) + scenarios = 1 + + database = DataBase() + database.add_data("D", "demand", ConstantData(nb_generators)) + + for gen_id in range(nb_generators): + database.add_data(f"G_{gen_id}", "p_max", ConstantData(1)) + database.add_data(f"G_{gen_id}", "cost", ConstantData(5)) + + node = Node(model=NODE_BALANCE_MODEL, id="N") + demand = create_component(model=DEMAND_MODEL, id="D") + generators = [ + create_component(model=GENERATOR_MODEL, id=f"G_{gen_id}") + for gen_id in range(nb_generators) + ] + + network = Network("test") + network.add_node(node) + + network.add_component(demand) + network.connect(PortRef(demand, "balance_port"), PortRef(node, "balance_port")) + + for gen_id in range(nb_generators): + network.add_component(generators[gen_id]) + network.connect( + PortRef(generators[gen_id], "balance_port"), PortRef(node, "balance_port") + ) + + # Raised recursion error with previous implementation + problem = build_problem(network, database, time_block, scenarios) + + status = problem.solver.Solve() + + assert status == problem.solver.OPTIMAL + assert problem.solver.Objective().Value() == 5 * nb_generators + + +def test_basic_balance_on_whole_year() -> None: + """ + Balance on one node with one fixed demand and one generation, on 8760 timestep. + """ + + scenarios = 1 + horizon = 8760 + time_block = TimeBlock(1, list(range(horizon))) + + database = DataBase() + database.add_data( + "D", "demand", generate_scalar_matrix_data(100, horizon, scenarios) + ) + + database.add_data("G", "p_max", ConstantData(100)) + database.add_data("G", "cost", ConstantData(30)) + + node = Node(model=NODE_BALANCE_MODEL, id="N") + demand = create_component(model=DEMAND_MODEL, id="D") + + gen = create_component(model=GENERATOR_MODEL, id="G") + + network = Network("test") + network.add_node(node) + network.add_component(demand) + network.add_component(gen) + network.connect(PortRef(demand, "balance_port"), PortRef(node, "balance_port")) + network.connect(PortRef(gen, "balance_port"), PortRef(node, "balance_port")) + + problem = build_problem(network, database, time_block, scenarios) + status = problem.solver.Solve() + + assert status == problem.solver.OPTIMAL + assert problem.solver.Objective().Value() == 30 * 100 * horizon + + +def test_basic_balance_on_whole_year_with_large_sum() -> None: + """ + Balance on one node with one fixed demand and one generation with storage, on 8760 timestep. + """ + + scenarios = 1 + horizon = 8760 + time_block = TimeBlock(1, list(range(horizon))) + + database = DataBase() + database.add_data( + "D", "demand", generate_scalar_matrix_data(100, horizon, scenarios) + ) + + database.add_data("G", "p_max", ConstantData(100)) + database.add_data("G", "cost", ConstantData(30)) + database.add_data("G", "full_storage", ConstantData(100 * horizon)) + + node = Node(model=NODE_BALANCE_MODEL, id="N") + demand = create_component(model=DEMAND_MODEL, id="D") + gen = create_component( + model=GENERATOR_MODEL_WITH_STORAGE, id="G" + ) # Limits the total generation inside a TimeBlock + + network = Network("test") + network.add_node(node) + network.add_component(demand) + network.add_component(gen) + network.connect(PortRef(demand, "balance_port"), PortRef(node, "balance_port")) + network.connect(PortRef(gen, "balance_port"), PortRef(node, "balance_port")) + + problem = build_problem(network, database, time_block, scenarios) + status = problem.solver.Solve() + + assert status == problem.solver.OPTIMAL + assert problem.solver.Objective().Value() == 30 * 100 * horizon diff --git a/tests/unittests/expressions/parsing/test_expression_parsing.py b/tests/unittests/expressions/parsing/test_expression_parsing.py index 79922415..cec5b363 100644 --- a/tests/unittests/expressions/parsing/test_expression_parsing.py +++ b/tests/unittests/expressions/parsing/test_expression_parsing.py @@ -9,14 +9,21 @@ # SPDX-License-Identifier: MPL-2.0 # # This file is part of the Antares project. -from typing import Set +from typing import Set, Union import pytest from andromede.expression.equality import expressions_equal -from andromede.expression.expression_efficient import ExpressionRange, literal, param +from andromede.expression.expression_efficient import ( + ExpressionNodeEfficient, + ExpressionRange, + literal, + param, +) from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, + StandaloneConstraint, + linear_expressions_equal, port_field, var, ) @@ -59,14 +66,14 @@ ( {"x"}, {}, - "x[t-1, t+4]", + "x[t-1, t+4]", # TODO: Should raise ValueError: shift always with sum var("x").sum(shift=[-literal(1), literal(4)]), ), ( {"x"}, {}, "x[t-1+1]", - var("x").sum(shift=-literal(1) + literal(1)), + var("x"), # Simplifications are applied very early in parsing !!!! ), ( {"x"}, @@ -95,25 +102,25 @@ ( {"x"}, {}, - "x[t-1, t, t+4]", + "x[t-1, t, t+4]", # TODO: Should raise ValueError: shift always with sum var("x").sum(shift=[-literal(1), literal(0), literal(4)]), ), ( {"x"}, {}, - "x[t-1..t+5]", + "x[t-1..t+5]", # TODO: Should raise ValueError: shift always with sum var("x").sum(shift=ExpressionRange(-literal(1), literal(5))), ), ( {"x"}, {}, - "x[t-1..t]", + "x[t-1..t]", # TODO: Should raise ValueError: shift always with sum var("x").sum(shift=ExpressionRange(-literal(1), literal(0))), ), ( {"x"}, {}, - "x[t..t+5]", + "x[t..t+5]", # TODO: Should raise ValueError: shift always with sum var("x").sum(shift=ExpressionRange(literal(0), literal(5))), ), ({"x"}, {}, "x[t]", var("x")), @@ -122,7 +129,7 @@ {"x"}, {}, "sum(x[-1..5])", - var("x").sum(eval=ExpressionRange(-literal(1), literal(5))).sum(), + var("x").sum(eval=ExpressionRange(-literal(1), literal(5))), ), ({}, {}, "sum_connections(port.f)", port_field("port", "f").sum_connections()), ( @@ -156,13 +163,21 @@ def test_parsing_visitor( variables: Set[str], parameters: Set[str], expression_str: str, - expected: LinearExpressionEfficient, -): + expected: Union[ + ExpressionNodeEfficient, LinearExpressionEfficient, StandaloneConstraint + ], +) -> None: identifiers = ModelIdentifiers(variables, parameters) expr = parse_expression(expression_str, identifiers) print() - print(print_expr(expr)) - assert expressions_equal(expr, expected) + print(f"Expected: \n {str(expected)}") + print(f"Parsed: \n {str(expr)}") + if isinstance(expected, ExpressionNodeEfficient): + assert expressions_equal(expr, expected) + elif isinstance(expected, LinearExpressionEfficient): + assert linear_expressions_equal(expr, expected) + elif isinstance(expected, StandaloneConstraint): + assert expected == expr @pytest.mark.parametrize( diff --git a/tests/unittests/expressions/test_expressions.py b/tests/unittests/expressions/test_expressions.py deleted file mode 100644 index 9c415c18..00000000 --- a/tests/unittests/expressions/test_expressions.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -from dataclasses import dataclass, field -from typing import Dict - -import pandas as pd -import pytest - -from andromede.expression import ( - AdditionNode, - DivisionNode, - EvaluationContext, - EvaluationVisitor, - ExpressionDegreeVisitor, - ExpressionNode, - LiteralNode, - ParameterNode, - ParameterValueProvider, - PrinterVisitor, - ValueProvider, - VariableNode, - literal, - param, - resolve_parameters, - sum_expressions, - var, - visit, -) -from andromede.expression.equality import expressions_equal -from andromede.expression.expression import ( - ComponentParameterNode, - ComponentVariableNode, - ExpressionRange, - Instances, - comp_param, - comp_var, - port_field, -) -from andromede.expression.indexing import IndexingStructureProvider, compute_indexation -from andromede.expression.indexing_structure import IndexingStructure -from andromede.model.model import PortFieldId -from andromede.simulation.linear_expression import LinearExpression, Term -from andromede.simulation.linearize import linearize_expression - - -@dataclass(frozen=True) -class ComponentValueKey: - component_id: str - variable_name: str - - -def comp_key(component_id: str, variable_name: str) -> ComponentValueKey: - return ComponentValueKey(component_id, variable_name) - - -@dataclass(frozen=True) -class ComponentEvaluationContext(ValueProvider): - """ - Simple value provider relying on dictionaries. - Does not support component variables/parameters. - """ - - variables: Dict[ComponentValueKey, float] = field(default_factory=dict) - parameters: Dict[ComponentValueKey, float] = field(default_factory=dict) - - def get_variable_value(self, name: str) -> float: - raise NotImplementedError() - - def get_parameter_value(self, name: str) -> float: - raise NotImplementedError() - - def get_component_variable_value(self, component_id: str, name: str) -> float: - return self.variables[comp_key(component_id, name)] - - def get_component_parameter_value(self, component_id: str, name: str) -> float: - return self.parameters[comp_key(component_id, name)] - - def parameter_is_constant_over_time(self, name: str) -> bool: - raise NotImplementedError() - - -def test_comp_parameter() -> None: - add_node = AdditionNode(LiteralNode(1), ComponentVariableNode("comp1", "x")) - expr = DivisionNode(add_node, ComponentParameterNode("comp1", "p")) - - assert visit(expr, PrinterVisitor()) == "((1 + comp1.x) / comp1.p)" - - context = ComponentEvaluationContext( - variables={comp_key("comp1", "x"): 3}, parameters={comp_key("comp1", "p"): 4} - ) - assert visit(expr, EvaluationVisitor(context)) == 1 - - -def test_ast() -> None: - add_node = AdditionNode(LiteralNode(1), VariableNode("x")) - expr = DivisionNode(add_node, ParameterNode("p")) - - assert visit(expr, PrinterVisitor()) == "((1 + x) / p)" - - context = EvaluationContext(variables={"x": 3}, parameters={"p": 4}) - assert visit(expr, EvaluationVisitor(context)) == 1 - - -def test_operators() -> None: - x = var("x") - p = param("p") - expr: ExpressionNode = (5 * x + 3) / p - 2 - - assert visit(expr, PrinterVisitor()) == "((((5.0 * x) + 3.0) / p) - 2.0)" - - context = EvaluationContext(variables={"x": 3}, parameters={"p": 4}) - assert visit(expr, EvaluationVisitor(context)) == pytest.approx(2.5, 1e-16) - - assert visit(-expr, EvaluationVisitor(context)) == pytest.approx(-2.5, 1e-16) - - -def test_degree() -> None: - x = var("x") - p = param("p") - expr = (5 * x + 3) / p - - assert visit(expr, ExpressionDegreeVisitor()) == 1 - - expr = x * expr - assert visit(expr, ExpressionDegreeVisitor()) == 2 - - -@pytest.mark.xfail(reason="Degree simplification not implemented") -def test_degree_computation_should_take_into_account_simplifications() -> None: - x = var("x") - expr = x - x - assert visit(expr, ExpressionDegreeVisitor()) == 0 - - expr = LiteralNode(0) * x - assert visit(expr, ExpressionDegreeVisitor()) == 0 - - -def test_parameters_resolution() -> None: - class TestParamProvider(ParameterValueProvider): - def get_component_parameter_value(self, component_id: str, name: str) -> float: - raise NotImplementedError() - - def get_parameter_value(self, name: str) -> float: - return 2 - - x = var("x") - p = param("p") - expr = (5 * x + 3) / p - assert resolve_parameters(expr, TestParamProvider()) == (5 * x + 3) / 2 - - -def test_linearization() -> None: - x = comp_var("c", "x") - expr = (5 * x + 3) / 2 - provider = StructureProvider() - - assert linearize_expression(expr, provider) == LinearExpression( - [Term(2.5, "c", "x")], 1.5 - ) - - with pytest.raises(ValueError): - linearize_expression(param("p") * x, provider) - - -def test_linearization_of_non_linear_expressions_should_raise_value_error() -> None: - x = var("x") - expr = x.variance() - - provider = StructureProvider() - with pytest.raises(ValueError) as exc: - linearize_expression(expr, provider) - assert ( - str(exc.value) - == "Cannot linearize expression with a non-linear operator: Variance" - ) - - -def test_comparison() -> None: - x = var("x") - p = param("p") - expr: ExpressionNode = (5 * x + 3) >= p - 2 - - assert visit(expr, PrinterVisitor()) == "((5.0 * x) + 3.0) >= (p - 2.0)" - - -class StructureProvider(IndexingStructureProvider): - def get_component_variable_structure( - self, component_id: str, name: str - ) -> IndexingStructure: - return IndexingStructure(True, True) - - def get_component_parameter_structure( - self, component_id: str, name: str - ) -> IndexingStructure: - return IndexingStructure(True, True) - - def get_parameter_structure(self, name: str) -> IndexingStructure: - return IndexingStructure(True, True) - - def get_variable_structure(self, name: str) -> IndexingStructure: - return IndexingStructure(True, True) - - -def test_shift() -> None: - x = var("x") - expr = x.shift(ExpressionRange(literal(1), literal(4))) - - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(True, True) - assert expr.instances == Instances.MULTIPLE - - -def test_shifting_sum() -> None: - x = var("x") - expr = x.shift(ExpressionRange(literal(1), literal(4))).sum() - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(True, True) - assert expr.instances == Instances.SIMPLE - - -def test_eval() -> None: - x = var("x") - expr = x.eval(ExpressionRange(literal(1), literal(4))) - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(False, True) - assert expr.instances == Instances.MULTIPLE - - -def test_eval_sum() -> None: - x = var("x") - expr = x.eval(ExpressionRange(literal(1), literal(4))).sum() - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(False, True) - assert expr.instances == Instances.SIMPLE - - -def test_sum_over_whole_block() -> None: - x = var("x") - expr = x.sum() - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(False, True) - assert expr.instances == Instances.SIMPLE - - -def test_forbidden_composition_should_raise_value_error() -> None: - x = var("x") - with pytest.raises(ValueError): - _ = x.shift(ExpressionRange(literal(1), literal(4))) + var("y") - - -def test_expectation() -> None: - x = var("x") - expr = x.expec() - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(True, False) - assert expr.instances == Instances.SIMPLE - - -def test_indexing_structure_comparison() -> None: - free = IndexingStructure(True, True) - constant = IndexingStructure(False, False) - assert free | constant == IndexingStructure(True, True) - - -def test_multiplication_of_differently_indexed_terms() -> None: - x = var("x") - p = param("p") - expr = p * x - - class CustomStructureProvider(IndexingStructureProvider): - def get_component_variable_structure( - self, component_id: str, name: str - ) -> IndexingStructure: - raise NotImplementedError() - - def get_component_parameter_structure( - self, component_id: str, name: str - ) -> IndexingStructure: - raise NotImplementedError() - - def get_parameter_structure(self, name: str) -> IndexingStructure: - return IndexingStructure(False, False) - - def get_variable_structure(self, name: str) -> IndexingStructure: - return IndexingStructure(True, True) - - provider = CustomStructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(True, True) - - -def test_sum_expressions() -> None: - assert expressions_equal(sum_expressions([]), literal(0)) - assert expressions_equal(sum_expressions([literal(1)]), literal(1)) - assert expressions_equal(sum_expressions([literal(1), var("x")]), 1 + var("x")) - assert expressions_equal( - sum_expressions([literal(1), var("x"), param("p")]), 1 + (var("x") + param("p")) - ) From 1911e296da1fd5eb093aeba1493310ba8a0ce8ca Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Wed, 21 Aug 2024 19:22:44 +0200 Subject: [PATCH 34/51] Fix interaction resolve ports / add component context --- src/andromede/expression/context_adder.py | 13 +------------ .../expression/linear_expression_efficient.py | 11 ++++------- src/andromede/simulation/optimization.py | 5 +++-- 3 files changed, 8 insertions(+), 21 deletions(-) diff --git a/src/andromede/expression/context_adder.py b/src/andromede/expression/context_adder.py index 397197da..a5fe4d52 100644 --- a/src/andromede/expression/context_adder.py +++ b/src/andromede/expression/context_adder.py @@ -30,21 +30,10 @@ class ContextAdder(CopyVisitor): component_id: str - # def variable(self, node: VariableNode) -> ExpressionNodeEfficient: - # return ComponentVariableNode(self.component_id, node.name) - def parameter(self, node: ParameterNode) -> ExpressionNodeEfficient: return ComponentParameterNode(self.component_id, node.name) - # def comp_variable(self, node: ComponentVariableNode) -> ExpressionNodeEfficient: - # raise ValueError( - # "This expression has already been associated to another component." - # ) - - def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNodeEfficient: - raise ValueError( - "This expression has already been associated to another component." - ) + # Nothing is done is a component parameter node is encountered as it may have been generated from port resolution def add_component_context( diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 87766bc5..af1f76a3 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -953,19 +953,16 @@ def resolve_port( port_expr += sum_expressions( [port_term.coefficient * expression for expression in expressions] ) - return self + port_expr + self_without_ports = LinearExpressionEfficient(self.terms, self.constant) + return self_without_ports + port_expr def add_component_context(self, component_id: str) -> "LinearExpressionEfficient": result_terms = {} for term in self.terms.values(): - if term.component_id: - raise ValueError( - "This expression has already been associated to another component." - ) - + # Some terms may involve variable from other component if they arise from previous port resolution result_term = dataclasses.replace( term, - component_id=component_id, + component_id=term.component_id if term.component_id else component_id, coefficient=add_component_context(component_id, term.coefficient), time_operator=( dataclasses.replace( diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 58e505d1..7037af2e 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -74,10 +74,11 @@ def _instantiate_model_expression( 1. add component ID for variables and parameters of THIS component 2. replace port fields by their definition """ - with_component = model_expression.add_component_context(component_id) - with_component_and_ports = with_component.resolve_port( + # We need to resolve ports before adding component context as binding constraints with ports may involve parameters from the current component + with_ports = model_expression.resolve_port( component_id, optimization_context.connection_fields_expressions ) + with_component_and_ports = with_ports.add_component_context(component_id) return with_component_and_ports From a904d9d73ef373716c546c9ad8f74a36832f9f8f Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Wed, 21 Aug 2024 19:37:07 +0200 Subject: [PATCH 35/51] Uniformize imports --- src/andromede/expression/degree.py | 4 +- src/andromede/expression/evaluate.py | 9 +- .../evaluate_parameters_efficient.py | 11 +- src/andromede/expression/indexing.py | 2 +- .../expression/linear_expression_efficient.py | 22 +-- src/andromede/expression/port_resolver.py | 75 --------- src/andromede/expression/visitor.py | 21 +-- .../simulation/benders_decomposed.py | 22 ++- src/andromede/simulation/linear_expression.py | 1 - .../simulation/linear_expression_resolver.py | 4 +- src/andromede/simulation/linearize.py | 156 ------------------ src/andromede/simulation/optimization.py | 10 +- .../simulation/optimization_context.py | 2 +- src/andromede/simulation/output_values.py | 2 +- .../resolved_linear_expression.py | 0 src/andromede/simulation/strategy.py | 1 - src/andromede/study/data.py | 2 +- src/andromede/study/resolve_components.py | 22 +-- 18 files changed, 45 insertions(+), 321 deletions(-) delete mode 100644 src/andromede/expression/port_resolver.py delete mode 100644 src/andromede/simulation/linearize.py rename src/andromede/{expression => simulation}/resolved_linear_expression.py (100%) diff --git a/src/andromede/expression/degree.py b/src/andromede/expression/degree.py index 572a58b6..3a5119ac 100644 --- a/src/andromede/expression/degree.py +++ b/src/andromede/expression/degree.py @@ -11,7 +11,8 @@ # This file is part of the Antares project. import andromede.expression.scenario_operator -from andromede.expression.expression_efficient import ( + +from .expression_efficient import ( AdditionNode, ComparisonNode, ComponentParameterNode, @@ -30,7 +31,6 @@ TimeOperatorName, TimeOperatorNode, ) - from .visitor import ExpressionVisitor, T, visit diff --git a/src/andromede/expression/evaluate.py b/src/andromede/expression/evaluate.py index e09f033c..9f091350 100644 --- a/src/andromede/expression/evaluate.py +++ b/src/andromede/expression/evaluate.py @@ -13,7 +13,7 @@ from dataclasses import dataclass, field from typing import Dict -from andromede.expression.expression_efficient import ( +from .expression_efficient import ( ComparisonNode, ComponentParameterNode, ExpressionNodeEfficient, @@ -25,12 +25,7 @@ TimeAggregatorNode, TimeOperatorNode, ) -from andromede.expression.value_provider import ( - TimeScenarioIndex, - TimeScenarioIndices, - ValueProvider, -) - +from .value_provider import TimeScenarioIndex, TimeScenarioIndices, ValueProvider from .visitor import ExpressionVisitorOperations, visit diff --git a/src/andromede/expression/evaluate_parameters_efficient.py b/src/andromede/expression/evaluate_parameters_efficient.py index 70fa2716..728b7eb3 100644 --- a/src/andromede/expression/evaluate_parameters_efficient.py +++ b/src/andromede/expression/evaluate_parameters_efficient.py @@ -14,7 +14,7 @@ from dataclasses import dataclass, field from typing import Callable, Dict, List -from andromede.expression.expression_efficient import ( +from .expression_efficient import ( AdditionNode, ComparisonNode, ComponentParameterNode, @@ -36,13 +36,8 @@ TimeOperatorName, TimeOperatorNode, ) -from andromede.expression.indexing_structure import RowIndex -from andromede.expression.value_provider import ( - TimeScenarioIndex, - TimeScenarioIndices, - ValueProvider, -) - +from .indexing_structure import RowIndex +from .value_provider import TimeScenarioIndex, TimeScenarioIndices, ValueProvider from .visitor import ExpressionVisitor, visit diff --git a/src/andromede/expression/indexing.py b/src/andromede/expression/indexing.py index 73d43ff5..aaad0881 100644 --- a/src/andromede/expression/indexing.py +++ b/src/andromede/expression/indexing.py @@ -12,7 +12,7 @@ from abc import ABC, abstractmethod -from andromede.expression.indexing_structure import IndexingStructure +from .indexing_structure import IndexingStructure class IndexingStructureProvider(ABC): diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index af1f76a3..e2e267e9 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -29,13 +29,13 @@ overload, ) -from andromede.expression.context_adder import add_component_context -from andromede.expression.equality import expressions_equal -from andromede.expression.evaluate_parameters_efficient import ( +from .context_adder import add_component_context +from .equality import expressions_equal +from .evaluate_parameters_efficient import ( check_resolved_expr, resolve_coefficient, ) -from andromede.expression.expression_efficient import ( +from .expression_efficient import ( ExpressionNodeEfficient, ExpressionRange, InstancesTimeIndex, @@ -53,19 +53,19 @@ literal, wrap_in_node, ) -from andromede.expression.indexing import IndexingStructureProvider -from andromede.expression.indexing_structure import IndexingStructure, RowIndex -from andromede.expression.port_operator import PortAggregator, PortSum -from andromede.expression.print import print_expr -from andromede.expression.scenario_operator import Expectation, ScenarioAggregator -from andromede.expression.time_operator import ( +from .indexing import IndexingStructureProvider +from .indexing_structure import IndexingStructure, RowIndex +from .port_operator import PortAggregator, PortSum +from .print import print_expr +from .scenario_operator import Expectation, ScenarioAggregator +from .time_operator import ( TimeAggregator, TimeEvaluation, TimeOperator, TimeShift, TimeSum, ) -from andromede.expression.value_provider import ( +from .value_provider import ( TimeScenarioIndex, TimeScenarioIndices, ValueProvider, diff --git a/src/andromede/expression/port_resolver.py b/src/andromede/expression/port_resolver.py deleted file mode 100644 index 54432748..00000000 --- a/src/andromede/expression/port_resolver.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -from dataclasses import dataclass - -from andromede.model.model import PortFieldId - - -@dataclass(eq=True, frozen=True) -class PortFieldKey: - """ - Identifies the expression node for one component and one port variable. - """ - - component_id: str - port_variable_id: PortFieldId - - -# @dataclass(frozen=True) -# class PortResolver(CopyVisitor): -# """ -# Duplicates the AST with replacement of port field nodes by -# their corresponding expression. -# """ - -# component_id: str -# ports_expressions: Dict[PortFieldKey, List[LinearExpressionEfficient]] - -# def port_field(self, node: PortFieldNode) -> LinearExpressionEfficient: -# expressions = self.ports_expressions[ -# PortFieldKey( -# self.component_id, PortFieldId(node.port_name, node.field_name) -# ) -# ] -# if len(expressions) != 1: -# raise ValueError( -# f"Invalid number of expression for port : {node.port_name}" -# ) -# else: -# return expressions[0] - -# def port_field_aggregator( -# self, node: PortFieldAggregatorNode -# ) -> LinearExpressionEfficient: -# if node.aggregator != PortFieldAggregatorName.PORT_SUM: -# raise NotImplementedError("Only PortSum is supported.") -# port_field_node = node.operand -# if not isinstance(port_field_node, PortFieldNode): -# raise ValueError(f"Should be a portFieldNode : {port_field_node}") - -# expressions = self.ports_expressions.get( -# PortFieldKey( -# self.component_id, -# PortFieldId(port_field_node.port_name, port_field_node.field_name), -# ), -# [], -# ) -# return sum_expressions(expressions) - - -# def resolve_port( -# expression: LinearExpressionEfficient, -# component_id: str, -# ports_expressions: Dict[PortFieldKey, List[LinearExpressionEfficient]], -# ) -> LinearExpressionEfficient: -# return visit(expression, PortResolver(component_id, ports_expressions)) diff --git a/src/andromede/expression/visitor.py b/src/andromede/expression/visitor.py index 29e95cee..3e91a645 100644 --- a/src/andromede/expression/visitor.py +++ b/src/andromede/expression/visitor.py @@ -16,26 +16,7 @@ from abc import ABC, abstractmethod from typing import Generic, Protocol, TypeVar -# from andromede.expression.expression import ( -# AdditionNode, -# ComparisonNode, -# ComponentParameterNode, -# ComponentVariableNode, -# DivisionNode, -# ExpressionNode, -# LiteralNode, -# MultiplicationNode, -# NegationNode, -# ParameterNode, -# PortFieldAggregatorNode, -# PortFieldNode, -# ScenarioOperatorNode, -# SubstractionNode, -# TimeAggregatorNode, -# TimeOperatorNode, -# VariableNode, -# ) -from andromede.expression.expression_efficient import ( +from .expression_efficient import ( AdditionNode, ComparisonNode, ComponentParameterNode, diff --git a/src/andromede/simulation/benders_decomposed.py b/src/andromede/simulation/benders_decomposed.py index 9ab44139..f5ff23a0 100644 --- a/src/andromede/simulation/benders_decomposed.py +++ b/src/andromede/simulation/benders_decomposed.py @@ -18,22 +18,20 @@ import pathlib from typing import Any, Dict, List, Optional -from andromede.simulation.optimization import OptimizationProblem, build_problem -from andromede.simulation.optimization_context import BlockBorderManagement -from andromede.simulation.output_values import ( +from andromede.study.data import DataBase +from andromede.study.network import Network +from andromede.utils import read_json, serialize, serialize_json + +from .optimization import OptimizationProblem, build_problem +from .optimization_context import BlockBorderManagement +from .output_values import ( BendersDecomposedSolution, BendersMergedSolution, BendersSolution, ) -from andromede.simulation.runner import BendersRunner, MergeMPSRunner -from andromede.simulation.strategy import ( - InvestmentProblemStrategy, - OperationalProblemStrategy, -) -from andromede.simulation.time_block import TimeBlock -from andromede.study.data import DataBase -from andromede.study.network import Network -from andromede.utils import read_json, serialize, serialize_json +from .runner import BendersRunner, MergeMPSRunner +from .strategy import InvestmentProblemStrategy, OperationalProblemStrategy +from .time_block import TimeBlock class BendersDecomposedProblem: diff --git a/src/andromede/simulation/linear_expression.py b/src/andromede/simulation/linear_expression.py index 1f3c9359..c491f6ce 100644 --- a/src/andromede/simulation/linear_expression.py +++ b/src/andromede/simulation/linear_expression.py @@ -20,7 +20,6 @@ from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.scenario_operator import ScenarioAggregator from andromede.expression.time_operator import TimeAggregator, TimeOperator -from andromede.model.model import PortFieldId T = TypeVar("T") diff --git a/src/andromede/simulation/linear_expression_resolver.py b/src/andromede/simulation/linear_expression_resolver.py index 892c701f..ed28339f 100644 --- a/src/andromede/simulation/linear_expression_resolver.py +++ b/src/andromede/simulation/linear_expression_resolver.py @@ -24,7 +24,7 @@ LinearExpressionEfficient, TermEfficient, ) -from andromede.expression.resolved_linear_expression import ( +from .resolved_linear_expression import ( ResolvedLinearExpression, ResolvedTerm, ) @@ -35,7 +35,7 @@ TimeScenarioIndices, ValueProvider, ) -from andromede.simulation.optimization_context import OptimizationContext +from .optimization_context import OptimizationContext @dataclass diff --git a/src/andromede/simulation/linearize.py b/src/andromede/simulation/linearize.py deleted file mode 100644 index dc3cb2ef..00000000 --- a/src/andromede/simulation/linearize.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -import dataclasses -from dataclasses import dataclass -from typing import Optional - -import andromede.expression.scenario_operator -import andromede.expression.time_operator -from andromede.expression.evaluate import ValueProvider -from andromede.expression.evaluate_parameters import get_time_ids_from_instances_index -from andromede.expression.expression import ( - ComparisonNode, - ComponentParameterNode, - ComponentVariableNode, - ExpressionNode, - LiteralNode, - ParameterNode, - PortFieldAggregatorNode, - PortFieldNode, - ScenarioOperatorNode, - TimeAggregatorNode, - TimeOperatorNode, - VariableNode, -) -from andromede.expression.indexing import IndexingStructureProvider -from andromede.expression.visitor import ExpressionVisitorOperations, T, visit -from andromede.simulation.linear_expression import LinearExpression, Term, generate_key - - -@dataclass(frozen=True) -class LinearExpressionBuilder(ExpressionVisitorOperations[LinearExpression]): - """ - Reduces a generic expression to a linear expression. - - Parameters should have been evaluated first. - """ - - structure_provider: IndexingStructureProvider - value_provider: Optional[ValueProvider] = None - - def literal(self, node: LiteralNode) -> LinearExpression: - return LinearExpression([], node.value) - - def comparison(self, node: ComparisonNode) -> LinearExpression: - raise ValueError("Linear expression cannot contain a comparison operator.") - - def variable(self, node: VariableNode) -> LinearExpression: - raise ValueError( - "Variables need to be associated with their component ID before linearization." - ) - - def parameter(self, node: ParameterNode) -> LinearExpression: - raise ValueError("Parameters must be evaluated before linearization.") - - def comp_variable(self, node: ComponentVariableNode) -> LinearExpression: - return LinearExpression( - [ - Term( - 1, - node.component_id, - node.name, - self.structure_provider.get_component_variable_structure( - node.component_id, node.name - ), - ) - ], - 0, - ) - - def comp_parameter(self, node: ComponentParameterNode) -> LinearExpression: - raise ValueError("Parameters must be evaluated before linearization.") - - def time_operator(self, node: TimeOperatorNode) -> LinearExpression: - if self.value_provider is None: - raise ValueError( - "A value provider must be specified to linearize a time operator node. This is required in order to evaluate the value of potential parameters used to specified the time ids on which the time operator applies." - ) - - operand_expr = visit(node.operand, self) - time_operator_cls = getattr(andromede.expression.time_operator, node.name) - time_ids = get_time_ids_from_instances_index( - node.instances_index, self.value_provider - ) - - result_terms = {} - for term in operand_expr.terms.values(): - term_with_operator = dataclasses.replace( - term, time_operator=time_operator_cls(time_ids) - ) - result_terms[generate_key(term_with_operator)] = term_with_operator - - # TODO: How can we apply a shift on a parameter ? It seems impossible for now as parameters must already be evaluated... - result_expr = LinearExpression(result_terms, operand_expr.constant) - return result_expr - - def time_aggregator(self, node: TimeAggregatorNode) -> LinearExpression: - # TODO: Very similar to time_operator, may be factorized - operand_expr = visit(node.operand, self) - time_aggregator_cls = getattr(andromede.expression.time_operator, node.name) - result_terms = {} - for term in operand_expr.terms.values(): - term_with_operator = dataclasses.replace( - term, time_aggregator=time_aggregator_cls(node.stay_roll) - ) - result_terms[generate_key(term_with_operator)] = term_with_operator - - result_expr = LinearExpression(result_terms, operand_expr.constant) - return result_expr - - def scenario_operator(self, node: ScenarioOperatorNode) -> LinearExpression: - scenario_operator_cls = getattr( - andromede.expression.scenario_operator, node.name - ) - if scenario_operator_cls.degree() > 1: - raise ValueError( - f"Cannot linearize expression with a non-linear operator: {scenario_operator_cls.__name__}" - ) - - operand_expr = visit(node.operand, self) - result_terms = {} - for term in operand_expr.terms.values(): - term_with_operator = dataclasses.replace( - term, scenario_aggregator=scenario_operator_cls() - ) - result_terms[generate_key(term_with_operator)] = term_with_operator - - result_expr = LinearExpression(result_terms, operand_expr.constant) - return result_expr - - def port_field(self, node: PortFieldNode) -> LinearExpression: - raise ValueError("Port fields must be replaced before linearization.") - - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> LinearExpression: - raise ValueError( - "Port fields aggregators must be replaced before linearization." - ) - - -def linearize_expression( - expression: ExpressionNode, - structure_provider: IndexingStructureProvider, - value_provider: Optional[ValueProvider] = None, -) -> LinearExpression: - return visit( - expression, LinearExpressionBuilder(structure_provider, value_provider) - ) diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 7037af2e..9f90378b 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -26,20 +26,20 @@ LinearExpressionEfficient, RowIndex, ) -from andromede.expression.resolved_linear_expression import ResolvedLinearExpression +from .resolved_linear_expression import ResolvedLinearExpression from andromede.model.common import ValueType from andromede.model.constraint import Constraint from andromede.model.model import PortFieldId -from andromede.simulation.linear_expression_resolver import LinearExpressionResolver -from andromede.simulation.optimization_context import ( +from .linear_expression_resolver import LinearExpressionResolver +from .optimization_context import ( BlockBorderManagement, ComponentContext, OptimizationContext, make_data_structure_provider, make_value_provider, ) -from andromede.simulation.strategy import MergedProblemStrategy, ModelSelectionStrategy -from andromede.simulation.time_block import TimeBlock +from .strategy import MergedProblemStrategy, ModelSelectionStrategy +from .time_block import TimeBlock from andromede.study.data import DataBase from andromede.study.network import Component, Network diff --git a/src/andromede/simulation/optimization_context.py b/src/andromede/simulation/optimization_context.py index 85be98be..be7893cb 100644 --- a/src/andromede/simulation/optimization_context.py +++ b/src/andromede/simulation/optimization_context.py @@ -26,7 +26,7 @@ PortFieldKey, ) from andromede.expression.value_provider import TimeScenarioIndex, TimeScenarioIndices -from andromede.simulation.time_block import TimeBlock +from .time_block import TimeBlock from andromede.study.data import DataBase from andromede.study.network import Component, Network from andromede.utils import get_or_add diff --git a/src/andromede/simulation/output_values.py b/src/andromede/simulation/output_values.py index cf94e7c0..d7977358 100644 --- a/src/andromede/simulation/output_values.py +++ b/src/andromede/simulation/output_values.py @@ -17,7 +17,7 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union, cast -from andromede.simulation.optimization import OptimizationProblem +from .optimization import OptimizationProblem from andromede.study.data import TimeScenarioIndex diff --git a/src/andromede/expression/resolved_linear_expression.py b/src/andromede/simulation/resolved_linear_expression.py similarity index 100% rename from src/andromede/expression/resolved_linear_expression.py rename to src/andromede/simulation/resolved_linear_expression.py diff --git a/src/andromede/simulation/strategy.py b/src/andromede/simulation/strategy.py index cb30f96a..0f148bda 100644 --- a/src/andromede/simulation/strategy.py +++ b/src/andromede/simulation/strategy.py @@ -13,7 +13,6 @@ from abc import ABC, abstractmethod from typing import Generator, Optional -# from andromede.expression import ExpressionNode from andromede.expression.linear_expression_efficient import LinearExpressionEfficient from andromede.model import Constraint, Model, ProblemContext, Variable diff --git a/src/andromede/study/data.py b/src/andromede/study/data.py index 50d4acda..7fa7b384 100644 --- a/src/andromede/study/data.py +++ b/src/andromede/study/data.py @@ -17,7 +17,7 @@ import pandas as pd -from andromede.study.network import Network +from .network import Network @dataclass(frozen=True) diff --git a/src/andromede/study/resolve_components.py b/src/andromede/study/resolve_components.py index d6650e78..327d2c45 100644 --- a/src/andromede/study/resolve_components.py +++ b/src/andromede/study/resolve_components.py @@ -13,30 +13,18 @@ from pathlib import Path from typing import Dict, Iterable, List, Optional -import pandas as pd - from andromede.model import Model from andromede.model.library import Library -from andromede.study import ( - Component, + +from .data import ( + AbstractDataStructure, ConstantData, DataBase, - Network, - Node, - PortRef, - PortsConnection, -) -from andromede.study.data import ( - AbstractDataStructure, - TimeScenarioIndex, TimeScenarioSeriesData, load_ts_from_txt, ) -from andromede.study.parsing import ( - InputComponent, - InputComponents, - InputPortConnections, -) +from .network import Component, Network, Node, PortRef, PortsConnection +from .parsing import InputComponent, InputComponents, InputPortConnections @dataclass(frozen=True) From 8b7a2449793f3f63eddf4a0a414bd6d868777c62 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Wed, 21 Aug 2024 20:32:02 +0200 Subject: [PATCH 36/51] Fix some type checking issues, remove useless code --- requirements-dev.txt | 1 + src/andromede/expression/__init__.py | 1 - src/andromede/expression/degree.py | 120 ----- src/andromede/expression/evaluate.py | 76 --- .../expression/expression_efficient.py | 4 +- .../expression/linear_expression_efficient.py | 36 +- src/andromede/expression/port_operator.py | 2 +- src/andromede/expression/time_operator.py | 5 +- src/andromede/model/probability_law.py | 55 --- src/andromede/model/variable.py | 10 +- src/andromede/simulation/linear_expression.py | 416 ---------------- .../simulation/linear_expression_resolver.py | 2 +- .../simulation/optimization_context.py | 35 -- .../expressions/test_linear_expressions.py | 446 ------------------ 14 files changed, 28 insertions(+), 1181 deletions(-) delete mode 100644 src/andromede/expression/degree.py delete mode 100644 src/andromede/model/probability_law.py delete mode 100644 src/andromede/simulation/linear_expression.py delete mode 100644 tests/unittests/expressions/test_linear_expressions.py diff --git a/requirements-dev.txt b/requirements-dev.txt index 647280ff..4f090b4c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,3 +8,4 @@ types-PyYAML~=6.0.12.12 antlr4-tools~=0.2.1 pandas~=2.0.3 pandas-stubs<=2.0.3 +types-PyYAML~=6.0.12 diff --git a/src/andromede/expression/__init__.py b/src/andromede/expression/__init__.py index 70c8c8ac..5b62b67e 100644 --- a/src/andromede/expression/__init__.py +++ b/src/andromede/expression/__init__.py @@ -11,7 +11,6 @@ # This file is part of the Antares project. from .copy import CopyVisitor, copy_expression -from .degree import ExpressionDegreeVisitor, compute_degree from .evaluate_parameters_efficient import ValueProvider from .expression_efficient import ( AdditionNode, diff --git a/src/andromede/expression/degree.py b/src/andromede/expression/degree.py deleted file mode 100644 index 3a5119ac..00000000 --- a/src/andromede/expression/degree.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -import andromede.expression.scenario_operator - -from .expression_efficient import ( - AdditionNode, - ComparisonNode, - ComponentParameterNode, - DivisionNode, - ExpressionNodeEfficient, - LiteralNode, - MultiplicationNode, - NegationNode, - ParameterNode, - PortFieldAggregatorNode, - PortFieldNode, - ScenarioOperatorNode, - SubstractionNode, - TimeAggregatorName, - TimeAggregatorNode, - TimeOperatorName, - TimeOperatorNode, -) -from .visitor import ExpressionVisitor, T, visit - - -class ExpressionDegreeVisitor(ExpressionVisitor[int]): - """ - Computes degree of expression with respect to variables. - """ - - def literal(self, node: LiteralNode) -> int: - return 0 - - def negation(self, node: NegationNode) -> int: - return visit(node.operand, self) - - # TODO: Take into account simplification that can occur with literal coefficient for add, sub, mult, div - def addition(self, node: AdditionNode) -> int: - return max(visit(node.left, self), visit(node.right, self)) - - def substraction(self, node: SubstractionNode) -> int: - return max(visit(node.left, self), visit(node.right, self)) - - def multiplication(self, node: MultiplicationNode) -> int: - return visit(node.left, self) + visit(node.right, self) - - def division(self, node: DivisionNode) -> int: - right_degree = visit(node.right, self) - if right_degree != 0: - raise ValueError("Degree computation not implemented for divisions.") - return visit(node.left, self) - - def comparison(self, node: ComparisonNode) -> int: - return max(visit(node.left, self), visit(node.right, self)) - - # def variable(self, node: VariableNode) -> int: - # return 1 - - def parameter(self, node: ParameterNode) -> int: - return 0 - - # def comp_variable(self, node: ComponentVariableNode) -> int: - # return 1 - - def comp_parameter(self, node: ComponentParameterNode) -> int: - return 0 - - def time_operator(self, node: TimeOperatorNode) -> int: - if node.name in [TimeOperatorName.SHIFT, TimeOperatorName.EVALUATION]: - return visit(node.operand, self) - else: - return NotImplemented - - def time_aggregator(self, node: TimeAggregatorNode) -> int: - if node.name in [TimeAggregatorName.TIME_SUM]: - return visit(node.operand, self) - else: - return NotImplemented - - def scenario_operator(self, node: ScenarioOperatorNode) -> int: - scenario_operator_cls = getattr( - andromede.expression.scenario_operator, node.name - ) - # TODO: Carefully check if this formula is correct - return scenario_operator_cls.degree() * visit(node.operand, self) - - def port_field(self, node: PortFieldNode) -> int: - return 1 - - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> int: - return visit(node.operand, self) - - -def compute_degree(expression: ExpressionNodeEfficient) -> int: - return visit(expression, ExpressionDegreeVisitor()) - - -def is_constant(expr: ExpressionNodeEfficient) -> bool: - """ - True if the expression has no variable. - """ - return compute_degree(expr) == 0 - - -def is_linear(expr: ExpressionNodeEfficient) -> bool: - """ - True if the expression is linear with respect to variables. - """ - return compute_degree(expr) <= 1 diff --git a/src/andromede/expression/evaluate.py b/src/andromede/expression/evaluate.py index 9f091350..94c9eb1d 100644 --- a/src/andromede/expression/evaluate.py +++ b/src/andromede/expression/evaluate.py @@ -13,20 +13,7 @@ from dataclasses import dataclass, field from typing import Dict -from .expression_efficient import ( - ComparisonNode, - ComponentParameterNode, - ExpressionNodeEfficient, - LiteralNode, - ParameterNode, - PortFieldAggregatorNode, - PortFieldNode, - ScenarioOperatorNode, - TimeAggregatorNode, - TimeOperatorNode, -) from .value_provider import TimeScenarioIndex, TimeScenarioIndices, ValueProvider -from .visitor import ExpressionVisitorOperations, visit # Used only for tests @@ -70,66 +57,3 @@ def block_length() -> int: @staticmethod def scenarios() -> int: raise NotImplementedError() - - -@dataclass(frozen=True) -class EvaluationVisitor(ExpressionVisitorOperations[float]): - """ - Evaluates the expression with respect to the provided context - (variables and parameters values). - """ - - context: ValueProvider - - def literal(self, node: LiteralNode) -> float: - return node.value - - def comparison(self, node: ComparisonNode) -> float: - raise ValueError("Cannot evaluate comparison operator.") - - def parameter(self, node: ParameterNode) -> float: - return self.context.get_parameter_value(node.name) - - def comp_parameter(self, node: ComponentParameterNode) -> float: - return self.context.get_component_parameter_value(node.component_id, node.name) - - def time_operator(self, node: TimeOperatorNode) -> float: - raise NotImplementedError() - - def time_aggregator(self, node: TimeAggregatorNode) -> float: - raise NotImplementedError() - - def scenario_operator(self, node: ScenarioOperatorNode) -> float: - raise NotImplementedError() - - def port_field(self, node: PortFieldNode) -> float: - raise NotImplementedError() - - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> float: - raise NotImplementedError() - - -def evaluate( - expression: ExpressionNodeEfficient, value_provider: ValueProvider -) -> float: - return visit(expression, EvaluationVisitor(value_provider)) - - -@dataclass(frozen=True) -class InstancesIndexVisitor(EvaluationVisitor): - """ - Evaluates an expression given as instances index which should have no variable and constant parameter values. - """ - - def parameter(self, node: ParameterNode) -> float: - if not self.context.parameter_is_constant_over_time(node.name): - raise ValueError( - "Parameter given in an instance index expression must be constant over time" - ) - return self.context.get_parameter_value(node.name) - - def time_operator(self, node: TimeOperatorNode) -> float: - raise ValueError("An instance index expression cannot contain time operator") - - def time_aggregator(self, node: TimeAggregatorNode) -> float: - raise ValueError("An instance index expression cannot contain time aggregator") diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index 29bbe87b..4386b6d6 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -132,7 +132,7 @@ def expec(self) -> "ExpressionNodeEfficient": def variance(self) -> "ExpressionNodeEfficient": return _apply_if_node( - self, lambda x: ScenarioOperatorNode(x, ScenarioOperatorName.Variance) + self, lambda x: ScenarioOperatorNode(x, ScenarioOperatorName.VARIANCE) ) @@ -141,6 +141,8 @@ def wrap_in_node(obj: Any) -> ExpressionNodeEfficient: return obj elif isinstance(obj, float) or isinstance(obj, int): return LiteralNode(float(obj)) + # else: + # return None # Do not raise excpetion so that we can return NotImplemented in _apply_if_node # raise TypeError(f"Unable to wrap {obj} into an expression node") diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index e2e267e9..3ecc136c 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -21,6 +21,8 @@ Callable, Dict, List, + Literal, + Mapping, Optional, Sequence, TypeVar, @@ -31,10 +33,7 @@ from .context_adder import add_component_context from .equality import expressions_equal -from .evaluate_parameters_efficient import ( - check_resolved_expr, - resolve_coefficient, -) +from .evaluate_parameters_efficient import check_resolved_expr, resolve_coefficient from .expression_efficient import ( ExpressionNodeEfficient, ExpressionRange, @@ -65,11 +64,7 @@ TimeShift, TimeSum, ) -from .value_provider import ( - TimeScenarioIndex, - TimeScenarioIndices, - ValueProvider, -) +from .value_provider import TimeScenarioIndex, TimeScenarioIndices, ValueProvider @dataclass(frozen=True) @@ -368,7 +363,7 @@ def __str__(self) -> str: result += f".{str(self.aggregator)}" return result - def sum_connections(self) -> "LinearExpressionEfficient": + def sum_connections(self) -> "PortFieldTerm": if self.aggregator is not None: raise ValueError(f"Port field {str(self)} already has a port aggregator") return dataclasses.replace(self, aggregator=PortSum()) @@ -377,6 +372,10 @@ def sum_connections(self) -> "LinearExpressionEfficient": T_val = TypeVar("T_val", bound=Union[TermEfficient, PortFieldTerm]) +def _get_neutral_term(term: T_val, neutral: float) -> T_val: + return dataclasses.replace(term, coefficient=wrap_in_node(neutral)) + + @overload def _merge_dicts( lhs: Dict[TermKeyEfficient, TermEfficient], @@ -397,10 +396,6 @@ def _merge_dicts( ... -def _get_neutral_term(term: T_val, neutral: float) -> T_val: - return dataclasses.replace(term, coefficient=neutral) - - def _merge_dicts(lhs, rhs, merge_func, neutral): res = {} for k, v in lhs.items(): @@ -821,7 +816,7 @@ def sum( def _apply_operator( self, - sum_args: Dict[ + sum_args: Mapping[ str, Union[ int, @@ -839,13 +834,6 @@ def _apply_operator( return result_terms - # def sum_connections(self) -> "ExpressionNode": - # if isinstance(self, PortFieldNode): - # return PortFieldAggregatorNode(self, aggregator=PortFieldAggregatorName.PORT_SUM) - # raise ValueError( - # f"sum_connections() applies only for PortFieldNode, whereas the current node is of type {type(self)}." - # ) - def shift( self, expressions: Union[ @@ -1036,7 +1024,7 @@ def linear_expressions_equal_if_present( # TODO: Is this function useful ? Could we just rely on the sum operator overloading ? Only the case with an empty list may make the function useful def sum_expressions( expressions: Sequence[LinearExpressionEfficient], -) -> LinearExpressionEfficient: +) -> Union[LinearExpressionEfficient, Literal[0]]: if len(expressions) == 0: return wrap_in_linear_expr(literal(0)) else: @@ -1059,7 +1047,7 @@ def __post_init__( for bound in [self.lower_bound, self.upper_bound]: if not bound.is_constant(): raise ValueError( - f"The bounds of a constraint should not contain variables, {print_expr(bound)} was given." + f"The bounds of a constraint should not contain variables, {str(bound)} was given." ) def __str__(self) -> str: diff --git a/src/andromede/expression/port_operator.py b/src/andromede/expression/port_operator.py index 56f18322..845ae693 100644 --- a/src/andromede/expression/port_operator.py +++ b/src/andromede/expression/port_operator.py @@ -30,5 +30,5 @@ class PortAggregator: @dataclass(frozen=True) class PortSum(PortAggregator): - def __str__(self): + def __str__(self) -> str: return "PortSum" diff --git a/src/andromede/expression/time_operator.py b/src/andromede/expression/time_operator.py index 3b8e7bce..1332b2dc 100644 --- a/src/andromede/expression/time_operator.py +++ b/src/andromede/expression/time_operator.py @@ -36,12 +36,9 @@ class TimeOperator(ABC): def rolling(cls) -> bool: raise NotImplementedError - def key(self) -> Tuple[int, ...]: + def key(self) -> InstancesTimeIndex: return self.time_ids - def size(self) -> int: - return len(self.time_ids.expressions) - @dataclass(frozen=True) class TimeShift(TimeOperator): diff --git a/src/andromede/model/probability_law.py b/src/andromede/model/probability_law.py deleted file mode 100644 index 62e3dc6a..00000000 --- a/src/andromede/model/probability_law.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -""" -Describes probability distributions used in the models -""" - -from abc import ABC -from dataclasses import dataclass -from typing import List - -import numpy as np - -from andromede.expression.expression import ExpressionNode - - -class AbstractProbabilityLaw(ABC): - def get_sample(self, size: int) -> List[float]: - return NotImplemented - - -@dataclass(frozen=True) -class Normal(AbstractProbabilityLaw): - mean: ExpressionNode - standard_deviation: ExpressionNode - - def get_sample(self, size: int) -> List[float]: - return NotImplemented - - -@dataclass(frozen=True) -class Uniform(AbstractProbabilityLaw): - lower_bound: ExpressionNode - upper_bound: ExpressionNode - - def get_sample(self, size: int) -> List[float]: - return NotImplemented - - -@dataclass(frozen=True) -class UniformIntegers(AbstractProbabilityLaw): - lower_bound: ExpressionNode - upper_bound: ExpressionNode - - def get_sample(self, size: int) -> List[float]: - return NotImplemented diff --git a/src/andromede/model/variable.py b/src/andromede/model/variable.py index e418d8a3..28343e42 100644 --- a/src/andromede/model/variable.py +++ b/src/andromede/model/variable.py @@ -18,6 +18,7 @@ from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, linear_expressions_equal_if_present, + wrap_in_linear_expr, wrap_in_linear_expr_if_present, ) from andromede.model.common import ( @@ -80,7 +81,14 @@ def bool_var( structure: IndexingStructure = IndexingStructure(True, True), context: ProblemContext = ProblemContext.OPERATIONAL, ) -> Variable: - return Variable(name, ValueType.BOOL, literal(0), literal(1), structure, context) + return Variable( + name, + ValueType.BOOL, + wrap_in_linear_expr(literal(0)), + wrap_in_linear_expr(literal(1)), + structure, + context, + ) def float_variable( diff --git a/src/andromede/simulation/linear_expression.py b/src/andromede/simulation/linear_expression.py deleted file mode 100644 index c491f6ce..00000000 --- a/src/andromede/simulation/linear_expression.py +++ /dev/null @@ -1,416 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -""" -Specific modelling for "instantiated" linear expressions, -with only variables and literal coefficients. -""" -from dataclasses import dataclass, field -from typing import Callable, Dict, List, Optional, TypeVar, Union - -from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.scenario_operator import ScenarioAggregator -from andromede.expression.time_operator import TimeAggregator, TimeOperator - -T = TypeVar("T") - -EPS = 10 ** (-16) - - -def is_close_abs(value: float, other_value: float, eps: float) -> bool: - return abs(value - other_value) < eps - - -def is_zero(value: float) -> bool: - return is_close_abs(value, 0, EPS) - - -def is_one(value: float) -> bool: - return is_close_abs(value, 1, EPS) - - -def is_minus_one(value: float) -> bool: - return is_close_abs(value, -1, EPS) - - -@dataclass(frozen=True) -class TermKey: - """ - Utility class to provide key for a term that contains all term information except coefficient - """ - - component_id: str - variable_name: str - time_operator: Optional[TimeOperator] - time_aggregator: Optional[TimeAggregator] - scenario_aggregator: Optional[ScenarioAggregator] - - -@dataclass(frozen=True) -class Term: - """ - One term in a linear expression: for example the "10x" par in "10x + 5y + 5" - - Args: - coefficient: the coefficient for that term, for example "10" in "10x" - variable_name: the name of the variable, for example "x" in "10x" - """ - - coefficient: float - component_id: str - variable_name: str - structure: IndexingStructure = field( - default=IndexingStructure(time=True, scenario=True) - ) - time_operator: Optional[TimeOperator] = None - time_aggregator: Optional[TimeAggregator] = None - scenario_aggregator: Optional[ScenarioAggregator] = None - - # TODO: It may be useful to define __add__, __sub__, etc on terms, which should return a linear expression ? - - def is_zero(self) -> bool: - return is_zero(self.coefficient) - - def str_for_coeff(self) -> str: - str_for_coeff = "" - if is_one(self.coefficient): - str_for_coeff = "+" - elif is_minus_one(self.coefficient): - str_for_coeff = "-" - else: - str_for_coeff = "{:+g}".format(self.coefficient) - return str_for_coeff - - def __str__(self) -> str: - # Useful for debugging tests - result = self.str_for_coeff() + str(self.variable_name) - if self.time_operator is not None: - result += f".{str(self.time_operator)}" - if self.time_aggregator is not None: - result += f".{str(self.time_aggregator)}" - if self.scenario_aggregator is not None: - result += f".{str(self.scenario_aggregator)}" - return result - - def number_of_instances(self) -> int: - if self.time_aggregator is not None: - return self.time_aggregator.size() - else: - if self.time_operator is not None: - return self.time_operator.size() - else: - return 1 - - -def generate_key(term: Term) -> TermKey: - return TermKey( - term.component_id, - term.variable_name, - term.time_operator, - term.time_aggregator, - term.scenario_aggregator, - ) - - -def _merge_dicts( - lhs: Dict[TermKey, Term], - rhs: Dict[TermKey, Term], - merge_func: Callable[[Term, Term], Term], - neutral: float, -) -> Dict[TermKey, Term]: - res = {} - for k, v in lhs.items(): - res[k] = merge_func( - v, - rhs.get( - k, - Term( - neutral, - v.component_id, - v.variable_name, - v.structure, - v.time_operator, - v.time_aggregator, - v.scenario_aggregator, - ), - ), - ) - for k, v in rhs.items(): - if k not in lhs: - res[k] = merge_func( - Term( - neutral, - v.component_id, - v.variable_name, - v.structure, - v.time_operator, - v.time_aggregator, - v.scenario_aggregator, - ), - v, - ) - return res - - -def _merge_is_possible(lhs: Term, rhs: Term) -> None: - if lhs.component_id != rhs.component_id or lhs.variable_name != rhs.variable_name: - raise ValueError("Cannot merge terms for different variables") - if ( - lhs.time_operator != rhs.time_operator - or lhs.time_aggregator != rhs.time_aggregator - or lhs.scenario_aggregator != rhs.scenario_aggregator - ): - raise ValueError("Cannot merge terms with different operators") - if lhs.structure != rhs.structure: - raise ValueError("Cannot merge terms with different structures") - - -def _add_terms(lhs: Term, rhs: Term) -> Term: - _merge_is_possible(lhs, rhs) - return Term( - lhs.coefficient + rhs.coefficient, - lhs.component_id, - lhs.variable_name, - lhs.structure, - lhs.time_operator, - lhs.time_aggregator, - lhs.scenario_aggregator, - ) - - -def _substract_terms(lhs: Term, rhs: Term) -> Term: - _merge_is_possible(lhs, rhs) - return Term( - lhs.coefficient - rhs.coefficient, - lhs.component_id, - lhs.variable_name, - lhs.structure, - lhs.time_operator, - lhs.time_aggregator, - lhs.scenario_aggregator, - ) - - -class LinearExpression: - """ - Represents a linear expression with respect to variable names, for example 10x + 5y + 2. - - Operators may be used for construction. - - Args: - terms: the list of variable terms, for example 10x and 5y in "10x + 5y + 2". - constant: the constant term, for example 2 in "10x + 5y + 2" - - Examples: - Operators may be used for construction: - - >>> LinearExpression([], 10) + LinearExpression([Term(10, "x")], 0) - LinearExpression([Term(10, "x")], 10) - """ - - terms: Dict[TermKey, Term] - constant: float - - def __init__( - self, - terms: Optional[Union[Dict[TermKey, Term], List[Term]]] = None, - constant: Optional[float] = None, - ) -> None: - self.constant = 0 - self.terms = {} - - if constant is not None: - # += b - self.constant = constant - if terms is not None: - # Allows to give two different syntax in the constructor: - # - List[Term] is natural - # - Dict[str, Term] is useful when constructing a linear expression from the terms of another expression - if isinstance(terms, dict): - for term_key, term in terms.items(): - if not term.is_zero(): - self.terms[term_key] = term - elif isinstance(terms, list): - for term in terms: - if not term.is_zero(): - self.terms[generate_key(term)] = term - else: - raise TypeError( - f"Terms must be either of type Dict[str, Term] or List[Term], whereas {terms} is of type {type(terms)}" - ) - - def is_zero(self) -> bool: - return len(self.terms) == 0 and is_zero(self.constant) - - def str_for_constant(self) -> str: - if is_zero(self.constant): - return "" - else: - return "{:+g}".format(self.constant) - - def __str__(self) -> str: - # Useful for debugging tests - result = "" - if self.is_zero(): - result += "0" - else: - for term in self.terms.values(): - result += str(term) - - result += self.str_for_constant() - - return result - - def __eq__(self, rhs: object) -> bool: - return ( - isinstance(rhs, LinearExpression) - and is_close_abs(self.constant, rhs.constant, EPS) - and self.terms - == rhs.terms # /!\ There may be float equality comparison in the terms values - ) - - def __iadd__(self, rhs: "LinearExpression") -> "LinearExpression": - if not isinstance(rhs, LinearExpression): - return NotImplemented - self.constant += rhs.constant - aggregated_terms = _merge_dicts(self.terms, rhs.terms, _add_terms, 0) - self.terms = aggregated_terms - self.remove_zeros_from_terms() - return self - - def __add__(self, rhs: "LinearExpression") -> "LinearExpression": - result = LinearExpression() - result += self - result += rhs - return result - - def __isub__(self, rhs: "LinearExpression") -> "LinearExpression": - if not isinstance(rhs, LinearExpression): - return NotImplemented - self.constant -= rhs.constant - aggregated_terms = _merge_dicts(self.terms, rhs.terms, _substract_terms, 0) - self.terms = aggregated_terms - self.remove_zeros_from_terms() - return self - - def __sub__(self, rhs: "LinearExpression") -> "LinearExpression": - result = LinearExpression() - result += self - result -= rhs - return result - - def __neg__(self) -> "LinearExpression": - result = LinearExpression() - result -= self - return result - - def __imul__(self, rhs: "LinearExpression") -> "LinearExpression": - if not isinstance(rhs, LinearExpression): - return NotImplemented - - if self.terms and rhs.terms: - raise ValueError("Cannot multiply two non constant expression") - else: - if self.terms: - left_expr = self - const_expr = rhs - else: - # It is possible that both expr are constant - left_expr = rhs - const_expr = self - if is_close_abs(const_expr.constant, 0, EPS): - return LinearExpression() - elif is_close_abs(const_expr.constant, 1, EPS): - _copy_expression(left_expr, self) - else: - left_expr.constant *= const_expr.constant - for term_key, term in left_expr.terms.items(): - left_expr.terms[term_key] = Term( - term.coefficient * const_expr.constant, - term.component_id, - term.variable_name, - term.structure, - term.time_operator, - term.time_aggregator, - term.scenario_aggregator, - ) - _copy_expression(left_expr, self) - return self - - def __mul__(self, rhs: "LinearExpression") -> "LinearExpression": - result = LinearExpression() - result += self - result *= rhs - return result - - def __itruediv__(self, rhs: "LinearExpression") -> "LinearExpression": - if not isinstance(rhs, LinearExpression): - return NotImplemented - - if rhs.terms: - raise ValueError("Cannot divide by a non constant expression") - else: - if is_close_abs(rhs.constant, 0, EPS): - raise ZeroDivisionError("Cannot divide expression by zero") - elif is_close_abs(rhs.constant, 1, EPS): - return self - else: - self.constant /= rhs.constant - for term_key, term in self.terms.items(): - self.terms[term_key] = Term( - term.coefficient / rhs.constant, - term.component_id, - term.variable_name, - term.structure, - term.time_operator, - term.time_aggregator, - term.scenario_aggregator, - ) - return self - - def __truediv__(self, rhs: "LinearExpression") -> "LinearExpression": - result = LinearExpression() - result += self - result /= rhs - - return result - - def remove_zeros_from_terms(self) -> None: - # TODO: Not optimized, checks could be done directly when doing operations on self.linear_term to avoid copies - for term_key, term in self.terms.copy().items(): - if is_close_abs(term.coefficient, 0, EPS): - del self.terms[term_key] - - def is_valid(self) -> bool: - nb_instances = None - for term in self.terms.values(): - term_instances = term.number_of_instances() - if nb_instances is None: - nb_instances = term_instances - else: - if term_instances != nb_instances: - raise ValueError( - "The terms of the linear expression {self} do not have the same number of instances" - ) - return True - - def number_of_instances(self) -> int: - if self.is_valid(): - # All terms have the same number of instances, just pick one - return self.terms[next(iter(self.terms))].number_of_instances() - else: - raise ValueError(f"{self} is not a valid linear expression") - - -def _copy_expression(src: LinearExpression, dst: LinearExpression) -> None: - dst.terms = src.terms - dst.constant = src.constant diff --git a/src/andromede/simulation/linear_expression_resolver.py b/src/andromede/simulation/linear_expression_resolver.py index ed28339f..0ae9c736 100644 --- a/src/andromede/simulation/linear_expression_resolver.py +++ b/src/andromede/simulation/linear_expression_resolver.py @@ -52,7 +52,7 @@ def resolve( resolved_variables = self.resolve_variables(term, row_id) # TODO: Contrary to the time aggregator that does a sum which is the default behaviour when append resolved terms, expectation performs an averaging, so weights must be included in coefficients. We feel here that we could generalize time and scenario aggregation over variables with more general operators, the following lines are very specific to expectation with same weights over all scenarios - weight = 1 + weight: float = 1 if isinstance(term.scenario_aggregator, Expectation): weight = 1 / self.value_provider.scenarios() diff --git a/src/andromede/simulation/optimization_context.py b/src/andromede/simulation/optimization_context.py index be7893cb..4c046e03 100644 --- a/src/andromede/simulation/optimization_context.py +++ b/src/andromede/simulation/optimization_context.py @@ -193,16 +193,6 @@ def register_connection_fields_expressions( ) -class TimestepValueProvider(ABC): - """ - Interface which provides numerical values for individual timesteps. - """ - - @abstractmethod - def get_value(self, block_timestep: int, scenario: int) -> float: - raise NotImplementedError() - - def _get_parameter_value( context: OptimizationContext, block_timestep: int, @@ -294,21 +284,6 @@ def scenarios() -> int: return Provider() -@dataclass(frozen=True) -class ExpressionTimestepValueProvider(TimestepValueProvider): - context: "OptimizationContext" - component: Component - expression: LinearExpressionEfficient - - # OptimizationContext has knowledge of the block, so that get_value only needs block_timestep and scenario to get the correct data value - - def get_value(self, block_timestep: int, scenario: int) -> float: - param_value_provider = make_value_provider( - self.context, block_timestep, scenario, self.component - ) - return self.expression.evaluate(param_value_provider) - - def make_data_structure_provider( network: Network, component: Component ) -> IndexingStructureProvider: @@ -345,16 +320,6 @@ class ComponentContext: opt_context: OptimizationContext component: Component - def get_values( - self, expression: LinearExpressionEfficient - ) -> TimestepValueProvider: - """ - The returned value provider will evaluate the provided expression. - """ - return ExpressionTimestepValueProvider( - self.opt_context, self.component, expression - ) - def add_variable( self, block_timestep: int, diff --git a/tests/unittests/expressions/test_linear_expressions.py b/tests/unittests/expressions/test_linear_expressions.py deleted file mode 100644 index 21f916a7..00000000 --- a/tests/unittests/expressions/test_linear_expressions.py +++ /dev/null @@ -1,446 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -from typing import Dict - -import pytest - -from andromede.expression.scenario_operator import Expectation -from andromede.expression.time_operator import TimeShift, TimeSum -from andromede.simulation.linear_expression import LinearExpression, Term, TermKey - - -@pytest.mark.parametrize( - "term, expected", - [ - (Term(1, "c", "x"), "+x"), - (Term(-1, "c", "x"), "-x"), - (Term(2.50, "c", "x"), "+2.5x"), - (Term(-3, "c", "x"), "-3x"), - (Term(-3, "c", "x", time_operator=TimeShift(-1)), "-3x.shift(-1)"), - (Term(-3, "c", "x", time_aggregator=TimeSum(True)), "-3x.sum(True)"), - ( - Term( - -3, - "c", - "x", - time_operator=TimeShift([2, 3]), - time_aggregator=TimeSum(False), - ), - "-3x.shift([2, 3]).sum(False)", - ), - (Term(-3, "c", "x", scenario_aggregator=Expectation()), "-3x.expec()"), - ( - Term( - -3, - "c", - "x", - time_aggregator=TimeSum(True), - scenario_aggregator=Expectation(), - ), - "-3x.sum(True).expec()", - ), - ], -) -def test_printing_term(term: Term, expected: str) -> None: - assert str(term) == expected - - -@pytest.mark.parametrize( - "coeff, var_name, constant, expec_str", - [ - (0, "x", 0, "0"), - (1, "x", 0, "+x"), - (1, "x", 1, "+x+1"), - (3.7, "x", 1, "+3.7x+1"), - (0, "x", 1, "+1"), - ], -) -def test_affine_expression_printing_should_reflect_required_formatting( - coeff: float, var_name: str, constant: float, expec_str: str -) -> None: - expr = LinearExpression([Term(coeff, "c", var_name)], constant) - assert str(expr) == expec_str - - -@pytest.mark.parametrize( - "lhs, rhs", - [ - (LinearExpression([], 1) + LinearExpression([], 3), LinearExpression([], 4)), - (LinearExpression([], 4) / LinearExpression([], 2), LinearExpression([], 2)), - (LinearExpression([], 4) * LinearExpression([], 2), LinearExpression([], 8)), - (LinearExpression([], 4) - LinearExpression([], 2), LinearExpression([], 2)), - ], -) -def test_constant_expressions(lhs: LinearExpression, rhs: LinearExpression) -> None: - assert lhs == rhs - - -@pytest.mark.parametrize( - "terms_dict, constant, exp_terms, exp_constant", - [ - ({"x": Term(0, "c", "x")}, 1, {}, 1), - ({"x": Term(1, "c", "x")}, 1, {"x": Term(1, "c", "x")}, 1), - ], -) -def test_instantiate_linear_expression_from_dict( - terms_dict: Dict[TermKey, Term], - constant: float, - exp_terms: Dict[str, Term], - exp_constant: float, -) -> None: - expr = LinearExpression(terms_dict, constant) - assert expr.terms == exp_terms - assert expr.constant == exp_constant - - -@pytest.mark.parametrize( - "e1, e2, expected", - [ - ( - LinearExpression([Term(10, "c", "x")], 1), - LinearExpression([Term(5, "c", "x")], 2), - LinearExpression([Term(15, "c", "x")], 3), - ), - ( - LinearExpression([Term(10, "c1", "x")], 1), - LinearExpression([Term(5, "c2", "x")], 2), - LinearExpression([Term(10, "c1", "x"), Term(5, "c2", "x")], 3), - ), - ( - LinearExpression([Term(10, "c", "x")], 0), - LinearExpression([Term(5, "c", "y")], 0), - LinearExpression([Term(10, "c", "x"), Term(5, "c", "y")], 0), - ), - ( - LinearExpression(), - LinearExpression([Term(10, "c", "x", TimeShift(-1))]), - LinearExpression([Term(10, "c", "x", TimeShift(-1))]), - ), - ( - LinearExpression(), - LinearExpression( - [Term(10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] - ), - LinearExpression( - [Term(10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] - ), - ), - ( - LinearExpression([Term(10, "c", "x")]), - LinearExpression([Term(10, "c", "x", time_operator=TimeShift(-1))]), - LinearExpression( - [Term(10, "c", "x"), Term(10, "c", "x", time_operator=TimeShift(-1))] - ), - ), - ( - LinearExpression([Term(10, "c", "x")]), - LinearExpression( - [ - Term( - 10, - "c", - "x", - time_operator=TimeShift(-1), - scenario_aggregator=Expectation(), - ) - ] - ), - LinearExpression( - [ - Term(10, "c", "x"), - Term( - 10, - "c", - "x", - time_operator=TimeShift(-1), - scenario_aggregator=Expectation(), - ), - ] - ), - ), - ], -) -def test_addition( - e1: LinearExpression, e2: LinearExpression, expected: LinearExpression -) -> None: - assert e1 + e2 == expected - - -def test_addition_of_linear_expressions_with_different_number_of_instances_should_raise_value_error() -> ( - None -): - pass - - -def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_from_terms() -> ( - None -): - e1 = LinearExpression([Term(10, "c", "x")], 1) - e2 = LinearExpression([Term(10, "c", "x")], 2) - e3 = e2 - e1 - assert e3.terms == {} - - -@pytest.mark.parametrize( - "e1, e2, expected", - [ - ( - LinearExpression([Term(10, "c", "x")], 3), - LinearExpression([], 2), - LinearExpression([Term(20, "c", "x")], 6), - ), - ( - LinearExpression([Term(10, "c", "x")], 3), - LinearExpression([], 1), - LinearExpression([Term(10, "c", "x")], 3), - ), - ( - LinearExpression([Term(10, "c", "x")], 3), - LinearExpression(), - LinearExpression(), - ), - ( - LinearExpression( - [ - Term( - 10, - "c", - "x", - time_operator=TimeShift(-1), - scenario_aggregator=Expectation(), - ) - ], - 3, - ), - LinearExpression([], 2), - LinearExpression( - [ - Term( - 20, - "c", - "x", - time_operator=TimeShift(-1), - scenario_aggregator=Expectation(), - ) - ], - 6, - ), - ), - ], -) -def test_multiplication( - e1: LinearExpression, e2: LinearExpression, expected: LinearExpression -) -> None: - assert e1 * e2 == expected - assert e2 * e1 == expected - - -def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> None: - e1 = LinearExpression([Term(10, "c", "x")], 0) - e2 = LinearExpression([Term(5, "c", "x")], 0) - with pytest.raises(ValueError) as exc: - _ = e1 * e2 - assert str(exc.value) == "Cannot multiply two non constant expression" - - -@pytest.mark.parametrize( - "e1, expected", - [ - ( - LinearExpression([Term(10, "c", "x")], 5), - LinearExpression([Term(-10, "c", "x")], -5), - ), - ( - LinearExpression( - [ - Term( - 10, - "c", - "x", - time_operator=TimeShift(-1), - time_aggregator=TimeSum(False), - scenario_aggregator=Expectation(), - ) - ], - 5, - ), - LinearExpression( - [ - Term( - -10, - "c", - "x", - time_operator=TimeShift(-1), - time_aggregator=TimeSum(False), - scenario_aggregator=Expectation(), - ) - ], - -5, - ), - ), - ], -) -def test_negation(e1: LinearExpression, expected: LinearExpression) -> None: - assert -e1 == expected - - -@pytest.mark.parametrize( - "e1, e2, expected", - [ - ( - LinearExpression([Term(10, "c", "x")], 1), - LinearExpression([Term(5, "c", "x")], 2), - LinearExpression([Term(5, "c", "x")], -1), - ), - ( - LinearExpression([Term(10, "c1", "x")], 1), - LinearExpression([Term(5, "c2", "x")], 2), - LinearExpression([Term(10, "c1", "x"), Term(-5, "c2", "x")], -1), - ), - ( - LinearExpression([Term(10, "c", "x")], 0), - LinearExpression([Term(5, "c", "y")], 0), - LinearExpression([Term(10, "c", "x"), Term(-5, "c", "y")], 0), - ), - ( - LinearExpression(), - LinearExpression([Term(10, "c", "x", time_operator=TimeShift(-1))]), - LinearExpression([Term(-10, "c", "x", time_operator=TimeShift(-1))]), - ), - ( - LinearExpression(), - LinearExpression( - [Term(10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] - ), - LinearExpression( - [Term(-10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] - ), - ), - ( - LinearExpression([Term(10, "c", "x")]), - LinearExpression([Term(10, "c", "x", time_operator=TimeShift(-1))]), - LinearExpression( - [Term(10, "c", "x"), Term(-10, "c", "x", time_operator=TimeShift(-1))] - ), - ), - ( - LinearExpression([Term(10, "c", "x")]), - LinearExpression( - [ - Term( - 10, - "c", - "x", - time_operator=TimeShift(-1), - time_aggregator=TimeSum(False), - scenario_aggregator=Expectation(), - ) - ] - ), - LinearExpression( - [ - Term(10, "c", "x"), - Term( - -10, - "c", - "x", - time_operator=TimeShift(-1), - time_aggregator=TimeSum(False), - scenario_aggregator=Expectation(), - ), - ] - ), - ), - ], -) -def test_substraction( - e1: LinearExpression, e2: LinearExpression, expected: LinearExpression -) -> None: - assert e1 - e2 == expected - - -@pytest.mark.parametrize( - "e1, e2, expected", - [ - ( - LinearExpression([Term(10, "c", "x")], 15), - LinearExpression([], 5), - LinearExpression([Term(2, "c", "x")], 3), - ), - ( - LinearExpression([Term(10, "c", "x")], 15), - LinearExpression([], 1), - LinearExpression([Term(10, "c", "x")], 15), - ), - ( - LinearExpression( - [ - Term( - 10, - "c", - "x", - time_operator=TimeShift(-1), - time_aggregator=TimeSum(False), - scenario_aggregator=Expectation(), - ) - ], - 15, - ), - LinearExpression([], 5), - LinearExpression( - [ - Term( - 2, - "c", - "x", - time_operator=TimeShift(-1), - time_aggregator=TimeSum(False), - scenario_aggregator=Expectation(), - ) - ], - 3, - ), - ), - ], -) -def test_division( - e1: LinearExpression, e2: LinearExpression, expected: LinearExpression -) -> None: - assert e1 / e2 == expected - - -def test_division_by_zero_sould_raise_zero_division_error() -> None: - e1 = LinearExpression([Term(10, "c", "x")], 15) - e2 = LinearExpression() - with pytest.raises(ZeroDivisionError) as exc: - _ = e1 / e2 - assert str(exc.value) == "Cannot divide expression by zero" - - -def test_division_by_non_constant_expr_sould_raise_value_error() -> None: - e1 = LinearExpression([Term(10, "c", "x")], 15) - e2 = LinearExpression() - with pytest.raises(ValueError) as exc: - _ = e2 / e1 - assert str(exc.value) == "Cannot divide by a non constant expression" - - -def test_imul_preserve_identity() -> None: - # technical test to check the behaviour of reassigning "self" in imul operator: - # it did not preserve identity, which could lead to weird behaviour - e1 = LinearExpression([], 15) - e2 = e1 - e1 *= LinearExpression([], 2) - assert e1 == LinearExpression([], 30) - assert e2 == e1 - assert e2 is e1 From 85d162839a82262843a4d664e590c51fc1f3336a Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Wed, 21 Aug 2024 20:39:59 +0200 Subject: [PATCH 37/51] Remove useless commented code --- src/andromede/expression/equality.py | 28 ------------------- .../evaluate_parameters_efficient.py | 3 -- .../expression/expression_efficient.py | 11 -------- .../expression/linear_expression_efficient.py | 11 -------- src/andromede/expression/visitor.py | 12 -------- src/andromede/model/model.py | 8 ------ src/andromede/simulation/optimization.py | 23 +++------------ 7 files changed, 4 insertions(+), 92 deletions(-) diff --git a/src/andromede/expression/equality.py b/src/andromede/expression/equality.py index 42b8155f..8f46b036 100644 --- a/src/andromede/expression/equality.py +++ b/src/andromede/expression/equality.py @@ -35,29 +35,6 @@ TimeOperatorNode, ) -# from andromede.expression import ( -# AdditionNode, -# ComparisonNode, -# DivisionNode, -# ExpressionNode, -# LiteralNode, -# MultiplicationNode, -# NegationNode, -# ParameterNode, -# SubstractionNode, -# VariableNode, -# ) -# from andromede.expression.expression import ( -# BinaryOperatorNode, -# ExpressionRange, -# InstancesTimeIndex, -# PortFieldAggregatorNode, -# PortFieldNode, -# ScenarioOperatorNode, -# TimeAggregatorNode, -# TimeOperatorNode, -# ) - @dataclass(frozen=True) class EqualityVisitor: @@ -95,8 +72,6 @@ def visit( return self.multiplication(left, right) if isinstance(left, ComparisonNode) and isinstance(right, ComparisonNode): return self.comparison(left, right) - # if isinstance(left, VariableNode) and isinstance(right, VariableNode): - # return self.variable(left, right) if isinstance(left, ParameterNode) and isinstance(right, ParameterNode): return self.parameter(left, right) if isinstance(left, ComponentParameterNode) and isinstance( @@ -151,9 +126,6 @@ def division(self, left: DivisionNode, right: DivisionNode) -> bool: def comparison(self, left: ComparisonNode, right: ComparisonNode) -> bool: return left.comparator == right.comparator and self._visit_operands(left, right) - # def variable(self, left: VariableNode, right: VariableNode) -> bool: - # return left.name == right.name - def parameter(self, left: ParameterNode, right: ParameterNode) -> bool: return left.name == right.name diff --git a/src/andromede/expression/evaluate_parameters_efficient.py b/src/andromede/expression/evaluate_parameters_efficient.py index 728b7eb3..e169c007 100644 --- a/src/andromede/expression/evaluate_parameters_efficient.py +++ b/src/andromede/expression/evaluate_parameters_efficient.py @@ -224,9 +224,6 @@ class InstancesIndexVisitor(ParameterEvaluationVisitor): Evaluates an expression given as instances index which should have no variable and constant parameter values. """ - # def variable(self, node: VariableNode) -> float: - # raise ValueError("An instance index expression cannot contain variable") - # Probably useless as parameter nodes should have already be replaced by component parameter nodes ? def parameter(self, node: ParameterNode) -> Dict[TimeScenarioIndex, float]: if not self.context.parameter_is_constant_over_time(node.name): diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index 4386b6d6..37d89b58 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -141,8 +141,6 @@ def wrap_in_node(obj: Any) -> ExpressionNodeEfficient: return obj elif isinstance(obj, float) or isinstance(obj, int): return LiteralNode(float(obj)) - # else: - # return None # Do not raise excpetion so that we can return NotImplemented in _apply_if_node # raise TypeError(f"Unable to wrap {obj} into an expression node") @@ -238,10 +236,6 @@ def _are_parameter_nodes_equal( ) -# def _is_parameter_multiplication(node: ExpressionNodeEfficient, name: str): -# return isinstance(node, MultiplicationNode) and ((isinstance(node.left, ParameterNode) and node.left.name == name) or - - def _substract_node( lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient ) -> ExpressionNodeEfficient: @@ -644,8 +638,6 @@ def visit( return self.multiplication(left, right) if isinstance(left, ComparisonNode) and isinstance(right, ComparisonNode): return self.comparison(left, right) - # if isinstance(left, VariableNode) and isinstance(right, VariableNode): - # return self.variable(left, right) if isinstance(left, ParameterNode) and isinstance(right, ParameterNode): return self.parameter(left, right) if isinstance(left, ComponentParameterNode) and isinstance( @@ -701,9 +693,6 @@ def division(self, left: DivisionNode, right: DivisionNode) -> bool: def comparison(self, left: ComparisonNode, right: ComparisonNode) -> bool: return left.comparator == right.comparator and self._visit_operands(left, right) - # def variable(self, left: VariableNode, right: VariableNode) -> bool: - # return left.name == right.name - def parameter(self, left: ParameterNode, right: ParameterNode) -> bool: return left.name == right.name diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 3ecc136c..220371e0 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -235,20 +235,12 @@ def sum( if shift is not None: return dataclasses.replace( self, - # coefficient=TimeOperatorNode( - # self.coefficient, TimeOperatorName.SHIFT, InstancesTimeIndex(shift) - # ), time_operator=TimeShift(InstancesTimeIndex(shift)), time_aggregator=TimeSum(stay_roll=True), ) elif eval is not None: return dataclasses.replace( self, - # coefficient=TimeOperatorNode( - # self.coefficient, - # TimeOperatorName.EVALUATION, - # InstancesTimeIndex(eval), - # ), time_operator=TimeEvaluation(InstancesTimeIndex(eval)), time_aggregator=TimeSum(stay_roll=True), ) @@ -902,9 +894,6 @@ def expec(self) -> "LinearExpressionEfficient": result_expr = LinearExpressionEfficient(result_terms, result_constant) return result_expr - # def variance(self) -> "ExpressionNode": - # return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, ScenarioOperatorName.Variance)) - def sum_connections(self) -> "LinearExpressionEfficient": if not self.is_zero(): raise ValueError( diff --git a/src/andromede/expression/visitor.py b/src/andromede/expression/visitor.py index 3e91a645..2267b52d 100644 --- a/src/andromede/expression/visitor.py +++ b/src/andromede/expression/visitor.py @@ -74,10 +74,6 @@ def division(self, node: DivisionNode) -> T: def comparison(self, node: ComparisonNode) -> T: ... - # @abstractmethod - # def variable(self, node: VariableNode) -> T: - # ... - @abstractmethod def parameter(self, node: ParameterNode) -> T: ... @@ -86,10 +82,6 @@ def parameter(self, node: ParameterNode) -> T: def comp_parameter(self, node: ComponentParameterNode) -> T: ... - # @abstractmethod - # def comp_variable(self, node: ComponentVariableNode) -> T: - # ... - @abstractmethod def time_operator(self, node: TimeOperatorNode) -> T: ... @@ -119,14 +111,10 @@ def visit(root: ExpressionNodeEfficient, visitor: ExpressionVisitor[T]) -> T: return visitor.literal(root) elif isinstance(root, NegationNode): return visitor.negation(root) - # elif isinstance(root, VariableNode): - # return visitor.variable(root) elif isinstance(root, ParameterNode): return visitor.parameter(root) elif isinstance(root, ComponentParameterNode): return visitor.comp_parameter(root) - # elif isinstance(root, ComponentVariableNode): - # return visitor.comp_variable(root) elif isinstance(root, AdditionNode): return visitor.addition(root) elif isinstance(root, MultiplicationNode): diff --git a/src/andromede/model/model.py b/src/andromede/model/model.py index e151825c..22832105 100644 --- a/src/andromede/model/model.py +++ b/src/andromede/model/model.py @@ -265,9 +265,6 @@ def division(self, node: DivisionNode) -> None: def comparison(self, node: ComparisonNode) -> None: raise ValueError("Port definition cannot contain a comparison operator.") - # def variable(self, node: VariableNode) -> None: - # pass - def parameter(self, node: ParameterNode) -> None: pass @@ -276,11 +273,6 @@ def comp_parameter(self, node: ComponentParameterNode) -> None: "Port definition must not contain a parameter associated to a component." ) - # def comp_variable(self, node: ComponentVariableNode) -> None: - # raise ValueError( - # "Port definition must not contain a variable associated to a component." - # ) - def time_operator(self, node: TimeOperatorNode) -> None: visit(node.operand, self) diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 9f90378b..d6459681 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -26,10 +26,12 @@ LinearExpressionEfficient, RowIndex, ) -from .resolved_linear_expression import ResolvedLinearExpression from andromede.model.common import ValueType from andromede.model.constraint import Constraint from andromede.model.model import PortFieldId +from andromede.study.data import DataBase +from andromede.study.network import Component, Network + from .linear_expression_resolver import LinearExpressionResolver from .optimization_context import ( BlockBorderManagement, @@ -38,10 +40,9 @@ make_data_structure_provider, make_value_provider, ) +from .resolved_linear_expression import ResolvedLinearExpression from .strategy import MergedProblemStrategy, ModelSelectionStrategy from .time_block import TimeBlock -from andromede.study.data import DataBase -from andromede.study.network import Component, Network def _get_indexing( @@ -92,20 +93,11 @@ def _create_constraint( """ constraint_indexing = _compute_indexing_structure(context, constraint) - # Perf: Perform linearization (tree traversing) without timesteps so that we can get the number of instances for the expression (from the time_ids of operators) - # linear_expr = context.linearize_expression(0, 0, constraint.expression) - # # Will there be cases where instances > 1 ? If not, maybe just a check that get_number_of_instances == 1 is sufficient ? Anyway, the function should be implemented - # instances_per_time_step = linear_expr.number_of_instances() - # instances_per_time_step = 1 - value_provider = make_value_provider(context.opt_context, context.component) expression_resolver = LinearExpressionResolver(context.opt_context, value_provider) for block_timestep in context.opt_context.get_time_indices(constraint_indexing): for scenario in context.opt_context.get_scenario_indices(constraint_indexing): - # linear_expr_at_t = context.linearize_expression( - # block_timestep, scenario, constraint.expression - # ) row_id = RowIndex(block_timestep, scenario) resolved_expr = expression_resolver.resolve(constraint.expression, row_id) @@ -143,13 +135,6 @@ def _create_objective( obj: lp.Objective = solver.Objective() for term in resolved_expr.terms: - # TODO : How to handle the scenario operator in a general manner ? - # if isinstance(term.scenario_aggregator, Expectation): - # weight = 1 / opt_context.scenarios - # scenario_ids = range(opt_context.scenarios) - # else: - # weight = 1 - # scenario_ids = range(1) opt_context._solver_variables[term.variable].is_in_objective = True obj.SetCoefficient( term.variable, From b5c3328b6bd574c1933abaf4d7ed660ef9b9c9bb Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Wed, 21 Aug 2024 20:41:50 +0200 Subject: [PATCH 38/51] Remove useless commented code --- src/andromede/expression/copy.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/andromede/expression/copy.py b/src/andromede/expression/copy.py index 97f29a09..1f97109f 100644 --- a/src/andromede/expression/copy.py +++ b/src/andromede/expression/copy.py @@ -44,15 +44,9 @@ def comparison(self, node: ComparisonNode) -> ExpressionNodeEfficient: visit(node.left, self), visit(node.right, self), node.comparator ) - # def variable(self, node: VariableNode) -> ExpressionNodeEfficient: - # return VariableNode(node.name) - def parameter(self, node: ParameterNode) -> ExpressionNodeEfficient: return ParameterNode(node.name) - # def comp_variable(self, node: ComponentVariableNode) -> ExpressionNodeEfficient: - # return ComponentVariableNode(node.component_id, node.name) - def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNodeEfficient: return ComponentParameterNode(node.component_id, node.name) From 344ac6547a7a491a9c40a62f215fe51d93e8178c Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Thu, 22 Aug 2024 10:32:53 +0200 Subject: [PATCH 39/51] Improve type checking for constraint --- src/andromede/libs/standard.py | 16 ++++++------- src/andromede/libs/standard_sc.py | 10 ++++---- src/andromede/model/constraint.py | 25 +++++++++++++------- src/andromede/model/resolve_library.py | 2 +- src/andromede/simulation/optimization.py | 2 +- tests/functional/test_xpansion.py | 6 ++--- tests/integration/test_benders_decomposed.py | 6 ++--- tests/models/test_electrolyzer.py | 6 ++--- tests/unittests/model/test_model_parsing.py | 4 ++-- tests/unittests/test_data.py | 4 ++-- tests/unittests/test_model.py | 8 +++---- tests/unittests/test_port.py | 2 +- 12 files changed, 49 insertions(+), 42 deletions(-) diff --git a/src/andromede/libs/standard.py b/src/andromede/libs/standard.py index a28ba905..7712ce69 100644 --- a/src/andromede/libs/standard.py +++ b/src/andromede/libs/standard.py @@ -36,7 +36,7 @@ binding_constraints=[ Constraint( name="Balance", - expression=port_field("balance_port", "flow").sum_connections() + expression_init=port_field("balance_port", "flow").sum_connections() == literal(0), ) ], @@ -53,7 +53,7 @@ binding_constraints=[ Constraint( name="Balance", - expression=port_field("balance_port", "flow").sum_connections() + expression_init=port_field("balance_port", "flow").sum_connections() == var("spillage") - var("unsupplied_energy"), ) ], @@ -126,7 +126,7 @@ ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= param("p_max") + name="Max generation", expression_init=var("generation") <= param("p_max") ), ], objective_operational_contribution=(param("cost") * var("generation")) @@ -151,11 +151,11 @@ ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= param("p_max") + name="Max generation", expression_init=var("generation") <= param("p_max") ), Constraint( name="Min generation", - expression=var("generation") - param("p_min"), + expression_init=var("generation") - param("p_min"), lower_bound=literal(0), ), # To test both ways of setting constraints ], @@ -185,11 +185,11 @@ ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= param("p_max") + name="Max generation", expression_init=var("generation") <= param("p_max") ), Constraint( name="Total storage", - expression=var("generation").sum() <= param("full_storage"), + expression_init=var("generation").sum() <= param("full_storage"), ), ], objective_operational_contribution=(param("cost") * var("generation")) @@ -418,7 +418,7 @@ constraints=[ Constraint( name="Level", - expression=var("level") + expression_init=var("level") - var("level").shift(-1) - param("efficiency") * var("injection") + var("withdrawal") diff --git a/src/andromede/libs/standard_sc.py b/src/andromede/libs/standard_sc.py index b9f9a38d..7fb45687 100644 --- a/src/andromede/libs/standard_sc.py +++ b/src/andromede/libs/standard_sc.py @@ -106,7 +106,7 @@ binding_constraints=[ Constraint( name="Conversion", - expression=var("input1") + var("input2") + expression_init=var("input1") + var("input2") == port_field("FlowDO", "flow").sum_connections(), ) ], @@ -131,7 +131,7 @@ binding_constraints=[ Constraint( name="Conversion", - expression=var("input") == port_field("FlowDI", "flow").sum_connections(), + expression_init=var("input") == port_field("FlowDI", "flow").sum_connections(), ) ], ) @@ -185,7 +185,7 @@ binding_constraints=[ Constraint( name="Bound CO2", - expression=port_field("emissionCO2", "Q").sum_connections() + expression_init=port_field("emissionCO2", "Q").sum_connections() <= param("quota"), ) ], @@ -207,7 +207,7 @@ binding_constraints=[ Constraint( name="Balance", - expression=var("p") + expression_init=var("p") == port_field("balance_port_n", "flow").sum_connections(), ) ], @@ -256,7 +256,7 @@ constraints=[ Constraint( name="Level", - expression=var("level") + expression_init=var("level") - var("level").shift(-1) - param("efficiency") * var("injection") + var("withdrawal") diff --git a/src/andromede/model/constraint.py b/src/andromede/model/constraint.py index f1cb27b8..f04ab1f0 100644 --- a/src/andromede/model/constraint.py +++ b/src/andromede/model/constraint.py @@ -9,10 +9,10 @@ # SPDX-License-Identifier: MPL-2.0 # # This file is part of the Antares project. -from dataclasses import dataclass, field -from typing import Any +from dataclasses import InitVar, dataclass, field +from typing import Any, Union -from andromede.expression.expression_efficient import literal +from andromede.expression.expression_efficient import ExpressionNodeEfficient, literal from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, StandaloneConstraint, @@ -31,7 +31,11 @@ class Constraint: """ name: str - expression: LinearExpressionEfficient + # Used only for mypy type checking, we could have done the same by using only the attribute expression + expression_init: InitVar[ + Union[ExpressionNodeEfficient, LinearExpressionEfficient, StandaloneConstraint] + ] + expression: LinearExpressionEfficient = field(init=False) lower_bound: LinearExpressionEfficient = field( default=wrap_in_linear_expr(literal(-float("inf"))) ) @@ -42,23 +46,26 @@ class Constraint: def __post_init__( self, + expression_init: Union[ + ExpressionNodeEfficient, LinearExpressionEfficient, StandaloneConstraint + ], ) -> None: self.lower_bound = wrap_in_linear_expr(self.lower_bound) self.upper_bound = wrap_in_linear_expr(self.upper_bound) - if isinstance(self.expression, StandaloneConstraint): + if isinstance(expression_init, StandaloneConstraint): # Case where constraint is initialized with something like Constraint(var("x") <= var("y")) if not self.lower_bound.is_unbound() or not self.upper_bound.is_unbound(): raise ValueError( "Both comparison between two expressions and a bound are specfied, set either only a comparison between expressions or a single linear expression with bounds." ) - self.lower_bound = self.expression.lower_bound - self.upper_bound = self.expression.upper_bound - self.expression = self.expression.expression + self.lower_bound = expression_init.lower_bound + self.upper_bound = expression_init.upper_bound + self.expression = expression_init.expression else: - self.expression = wrap_in_linear_expr(self.expression) + self.expression = wrap_in_linear_expr(expression_init) for bound in [self.lower_bound, self.upper_bound]: if not bound.is_constant(): raise ValueError( diff --git a/src/andromede/model/resolve_library.py b/src/andromede/model/resolve_library.py index 5cf2cc89..60e5775d 100644 --- a/src/andromede/model/resolve_library.py +++ b/src/andromede/model/resolve_library.py @@ -155,7 +155,7 @@ def _to_constraint( ) -> Constraint: kwargs = { "name": constraint.name, - "expression": parse_expression(constraint.expression, identifiers), + "expression_init": parse_expression(constraint.expression, identifiers), } lb = wrap_in_linear_expr_if_present( _to_expression_if_present(constraint.lower_bound, identifiers) diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index d6459681..078dfa35 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -336,7 +336,7 @@ def _create_constraints(self) -> None: instantiated_constraint = Constraint( name=f"{component.id}_{constraint.name}", - expression=instantiated_expr, + expression_init=instantiated_expr, lower_bound=instantiated_lb, upper_bound=instantiated_ub, ) diff --git a/tests/functional/test_xpansion.py b/tests/functional/test_xpansion.py index 316ede8f..75d69ce9 100644 --- a/tests/functional/test_xpansion.py +++ b/tests/functional/test_xpansion.py @@ -86,7 +86,7 @@ def thermal_candidate() -> Model: ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= var("p_max") + name="Max generation", expression_init=var("generation") <= var("p_max") ) ], objective_operational_contribution=(param("op_cost") * var("generation")) @@ -131,11 +131,11 @@ def discrete_candidate() -> Model: ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= var("p_max") + name="Max generation", expression_init=var("generation") <= var("p_max") ), Constraint( name="Max investment", - expression=var("p_max") == param("p_max_per_unit") * var("nb_units"), + expression_init=var("p_max") == param("p_max_per_unit") * var("nb_units"), context=INVESTMENT, ), ], diff --git a/tests/integration/test_benders_decomposed.py b/tests/integration/test_benders_decomposed.py index 6123553f..cf599f67 100644 --- a/tests/integration/test_benders_decomposed.py +++ b/tests/integration/test_benders_decomposed.py @@ -88,7 +88,7 @@ def thermal_candidate() -> Model: ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= var("p_max") + name="Max generation", expression_init=var("generation") <= var("p_max") ) ], objective_operational_contribution=(param("op_cost") * var("generation")) @@ -134,11 +134,11 @@ def discrete_candidate() -> Model: ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= var("p_max") + name="Max generation", expression_init=var("generation") <= var("p_max") ), Constraint( name="Max investment", - expression=var("p_max") == param("p_max_per_unit") * var("nb_units"), + expression_init=var("p_max") == param("p_max_per_unit") * var("nb_units"), context=INVESTMENT, ), ], diff --git a/tests/models/test_electrolyzer.py b/tests/models/test_electrolyzer.py index f116921a..9e52b3f6 100644 --- a/tests/models/test_electrolyzer.py +++ b/tests/models/test_electrolyzer.py @@ -41,7 +41,7 @@ binding_constraints=[ Constraint( name="Balance", - expression=port_field("electrical_port", "flow").sum_connections() + expression_init=port_field("electrical_port", "flow").sum_connections() == literal(0), ) ], @@ -76,7 +76,7 @@ binding_constraints=[ Constraint( name="Balance", - expression=port_field("h2_port", "flow").sum_connections() == literal(0), + expression_init=port_field("h2_port", "flow").sum_connections() == literal(0), ) ], ) @@ -121,7 +121,7 @@ constraints=[ Constraint( name="Conversion", - expression=var("h2_output") + expression_init=var("h2_output") == var("electrical_input") * param("efficiency"), ) ], diff --git a/tests/unittests/model/test_model_parsing.py b/tests/unittests/model/test_model_parsing.py index 5665b490..922f9b2c 100644 --- a/tests/unittests/model/test_model_parsing.py +++ b/tests/unittests/model/test_model_parsing.py @@ -106,7 +106,7 @@ def test_library_parsing(data_dir: Path) -> None: constraints=[ Constraint( name="Level equation", - expression=var("level") + expression_init=var("level") - var("level").shift(-literal(1)) - param("efficiency") * var("injection") + var("withdrawal") @@ -162,7 +162,7 @@ def test_library_port_model_ok_parsing(data_dir: Path) -> None: constraints=[ Constraint( name="Level equation", - expression=port_field("injection_port", "flow") == var("withdrawal"), + expression_init=port_field("injection_port", "flow") == var("withdrawal"), ) ], ) diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index b1290b74..873a4109 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -86,7 +86,7 @@ def mock_generator_with_fixed_scenario_time_varying_param() -> Model: ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= param("p_max") + name="Max generation", expression_init=var("generation") <= param("p_max") ) ], objective_operational_contribution=(param("cost") * var("generation")) @@ -114,7 +114,7 @@ def mock_generator_with_scenario_varying_fixed_time_param() -> Model: ], constraints=[ Constraint( - name="Max generation", expression=var("generation") <= param("p_max") + name="Max generation", expression_init=var("generation") <= param("p_max") ) ], objective_operational_contribution=(param("cost") * var("generation")) diff --git a/tests/unittests/test_model.py b/tests/unittests/test_model.py index 43da1f8c..38e66a1c 100644 --- a/tests/unittests/test_model.py +++ b/tests/unittests/test_model.py @@ -262,9 +262,9 @@ def test_invalid_port_field_definition_should_raise( def test_constraint_equals() -> None: # checks in particular that expressions are correctly compared - assert Constraint(name="c", expression=var("x") <= param("p")) == Constraint( - name="c", expression=var("x") <= param("p") + assert Constraint(name="c", expression_init=var("x") <= param("p")) == Constraint( + name="c", expression_init=var("x") <= param("p") ) - assert Constraint(name="c", expression=var("x") <= param("p")) != Constraint( - name="c", expression=var("y") <= param("p") + assert Constraint(name="c", expression_init=var("x") <= param("p")) != Constraint( + name="c", expression_init=var("y") <= param("p") ) diff --git a/tests/unittests/test_port.py b/tests/unittests/test_port.py index 3e443b76..bca3da1f 100644 --- a/tests/unittests/test_port.py +++ b/tests/unittests/test_port.py @@ -31,7 +31,7 @@ def test_port_type_compatibility_ko() -> None: constraints=[ Constraint( name="Balance", - expression=port_field("balance_port", "flow").sum_connections() + expression_init=port_field("balance_port", "flow").sum_connections() == literal(0), ) ], From 600d1bc612385bb2437fe6840e33476f70cf14ac Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Thu, 22 Aug 2024 10:42:38 +0200 Subject: [PATCH 40/51] Improve type checking for port field def --- src/andromede/libs/standard.py | 22 +++++++++--------- src/andromede/libs/standard_sc.py | 24 ++++++++++---------- src/andromede/model/model.py | 21 +++++++++++------ tests/functional/test_andromede.py | 2 +- tests/functional/test_xpansion.py | 4 ++-- tests/integration/test_benders_decomposed.py | 4 ++-- tests/models/test_electrolyzer.py | 8 +++---- tests/unittests/model/test_model_parsing.py | 4 ++-- tests/unittests/test_data.py | 4 ++-- 9 files changed, 50 insertions(+), 43 deletions(-) diff --git a/src/andromede/libs/standard.py b/src/andromede/libs/standard.py index 7712ce69..110002ff 100644 --- a/src/andromede/libs/standard.py +++ b/src/andromede/libs/standard.py @@ -81,11 +81,11 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port_from", "flow"), - definition=-var("flow"), + definition_init=-var("flow"), ), PortFieldDefinition( port_field=PortFieldId("balance_port_to", "flow"), - definition=var("flow"), + definition_init=var("flow"), ), ], ) @@ -102,7 +102,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=-param("demand"), + definition_init=-param("demand"), ) ], ) @@ -121,7 +121,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ @@ -146,7 +146,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ @@ -180,7 +180,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ @@ -237,7 +237,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ @@ -315,7 +315,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ @@ -359,7 +359,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=-var("spillage"), + definition_init=-var("spillage"), ) ], objective_operational_contribution=(param("cost") * var("spillage")).sum().expec(), @@ -373,7 +373,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("unsupplied_energy"), + definition_init=var("unsupplied_energy"), ) ], objective_operational_contribution=(param("cost") * var("unsupplied_energy")) @@ -412,7 +412,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("withdrawal") - var("injection"), + definition_init=var("withdrawal") - var("injection"), ) ], constraints=[ diff --git a/src/andromede/libs/standard_sc.py b/src/andromede/libs/standard_sc.py index 7fb45687..b7025487 100644 --- a/src/andromede/libs/standard_sc.py +++ b/src/andromede/libs/standard_sc.py @@ -41,11 +41,11 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("FlowDI", "flow"), - definition=-var("input"), + definition_init=-var("input"), ), PortFieldDefinition( port_field=PortFieldId("FlowDO", "flow"), - definition=var("input") * param("alpha"), + definition_init=var("input") * param("alpha"), ), ], ) @@ -68,15 +68,15 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("FlowDI1", "flow"), - definition=-var("input1"), + definition_init=-var("input1"), ), PortFieldDefinition( port_field=PortFieldId("FlowDI2", "flow"), - definition=-var("input2"), + definition_init=-var("input2"), ), PortFieldDefinition( port_field=PortFieldId("FlowDO", "flow"), - definition=var("input1") * param("alpha1") + definition_init=var("input1") * param("alpha1") + var("input2") * param("alpha2"), ), ], @@ -96,11 +96,11 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("FlowDI1", "flow"), - definition=var("input1"), + definition_init=var("input1"), ), PortFieldDefinition( port_field=PortFieldId("FlowDI2", "flow"), - definition=var("input2"), + definition_init=var("input2"), ), ], binding_constraints=[ @@ -125,7 +125,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("FlowDO", "flow"), - definition=var("input") * param("alpha"), + definition_init=var("input") * param("alpha"), ), ], binding_constraints=[ @@ -164,11 +164,11 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("FlowP", "flow"), - definition=var("p"), + definition_init=var("p"), ), PortFieldDefinition( port_field=PortFieldId("OutCO2", "Q"), - definition=var("p") * param("emission_rate"), + definition_init=var("p") * param("emission_rate"), ), ], objective_operational_contribution=(param("cost") * var("p")).sum().expec(), @@ -201,7 +201,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port_e", "flow"), - definition=var("p"), + definition_init=var("p"), ) ], binding_constraints=[ @@ -250,7 +250,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("withdrawal") - var("injection"), + definition_init=var("withdrawal") - var("injection"), ) ], constraints=[ diff --git a/src/andromede/model/model.py b/src/andromede/model/model.py index 22832105..50e87ef3 100644 --- a/src/andromede/model/model.py +++ b/src/andromede/model/model.py @@ -16,8 +16,8 @@ defining parameters, variables, and equations. """ import itertools -from dataclasses import dataclass, field -from typing import Dict, Iterable, Optional +from dataclasses import InitVar, dataclass, field +from typing import Dict, Iterable, Optional, Union from andromede.expression.expression_efficient import ( AdditionNode, @@ -25,6 +25,7 @@ ComparisonNode, ComponentParameterNode, DivisionNode, + ExpressionNodeEfficient, LiteralNode, MultiplicationNode, NegationNode, @@ -121,15 +122,21 @@ class PortFieldDefinition: """ port_field: PortFieldId - definition: LinearExpressionEfficient - - def __post_init__(self) -> None: - object.__setattr__(self, "definition", wrap_in_linear_expr(self.definition)) + # Used only for type checking... + definition_init: InitVar[Union[ExpressionNodeEfficient, LinearExpressionEfficient]] + definition: LinearExpressionEfficient = field(init=False) + + def __post_init__( + self, definition_init: Union[ExpressionNodeEfficient, LinearExpressionEfficient] + ) -> None: + object.__setattr__(self, "definition", wrap_in_linear_expr(definition_init)) _validate_port_field_expression(self) def port_field_def( - port_name: str, field_name: str, definition: LinearExpressionEfficient + port_name: str, + field_name: str, + definition: Union[ExpressionNodeEfficient, LinearExpressionEfficient], ) -> PortFieldDefinition: return PortFieldDefinition(PortFieldId(port_name, field_name), definition) diff --git a/tests/functional/test_andromede.py b/tests/functional/test_andromede.py index 099c7706..a9f6fd03 100644 --- a/tests/functional/test_andromede.py +++ b/tests/functional/test_andromede.py @@ -145,7 +145,7 @@ def test_variable_bound() -> None: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], objective_operational_contribution=(param("cost") * var("generation")) diff --git a/tests/functional/test_xpansion.py b/tests/functional/test_xpansion.py index 75d69ce9..b32cfe7b 100644 --- a/tests/functional/test_xpansion.py +++ b/tests/functional/test_xpansion.py @@ -81,7 +81,7 @@ def thermal_candidate() -> Model: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ @@ -126,7 +126,7 @@ def discrete_candidate() -> Model: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ diff --git a/tests/integration/test_benders_decomposed.py b/tests/integration/test_benders_decomposed.py index cf599f67..484dbf8f 100644 --- a/tests/integration/test_benders_decomposed.py +++ b/tests/integration/test_benders_decomposed.py @@ -83,7 +83,7 @@ def thermal_candidate() -> Model: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ @@ -129,7 +129,7 @@ def discrete_candidate() -> Model: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ diff --git a/tests/models/test_electrolyzer.py b/tests/models/test_electrolyzer.py index 9e52b3f6..d2b632de 100644 --- a/tests/models/test_electrolyzer.py +++ b/tests/models/test_electrolyzer.py @@ -60,7 +60,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("electrical_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], objective_operational_contribution=(param("cost") * var("generation")) @@ -90,7 +90,7 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("h2_port", "flow"), - definition=-param("demand"), + definition_init=-param("demand"), ) ], ) @@ -111,11 +111,11 @@ port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("electrical_port", "flow"), - definition=-var("electrical_input"), + definition_init=-var("electrical_input"), ), PortFieldDefinition( port_field=PortFieldId("h2_port", "flow"), - definition=var("h2_output"), + definition_init=var("h2_output"), ), ], constraints=[ diff --git a/tests/unittests/model/test_model_parsing.py b/tests/unittests/model/test_model_parsing.py index 922f9b2c..9878613f 100644 --- a/tests/unittests/model/test_model_parsing.py +++ b/tests/unittests/model/test_model_parsing.py @@ -61,7 +61,7 @@ def test_library_parsing(data_dir: Path) -> None: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId(port_name="injection_port", field_name="flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], objective_operational_contribution=(param("cost") * var("generation")) @@ -100,7 +100,7 @@ def test_library_parsing(data_dir: Path) -> None: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId(port_name="injection_port", field_name="flow"), - definition=var("injection") - var("withdrawal"), + definition_init=var("injection") - var("withdrawal"), ) ], constraints=[ diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index 873a4109..0a5e381e 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -81,7 +81,7 @@ def mock_generator_with_fixed_scenario_time_varying_param() -> Model: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ @@ -109,7 +109,7 @@ def mock_generator_with_scenario_varying_fixed_time_param() -> Model: port_fields_definitions=[ PortFieldDefinition( port_field=PortFieldId("balance_port", "flow"), - definition=var("generation"), + definition_init=var("generation"), ) ], constraints=[ From d3317ef055b47f0db33e600904e4d6cfa9765a9c Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Thu, 22 Aug 2024 11:05:37 +0200 Subject: [PATCH 41/51] Improve type checking for **kwargs --- src/andromede/model/resolve_library.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/andromede/model/resolve_library.py b/src/andromede/model/resolve_library.py index 60e5775d..0d63c924 100644 --- a/src/andromede/model/resolve_library.py +++ b/src/andromede/model/resolve_library.py @@ -9,12 +9,14 @@ # SPDX-License-Identifier: MPL-2.0 # # This file is part of the Antares project. -from typing import Dict, List, Optional +from typing import Dict, List, Optional, TypedDict, Union # from andromede.expression import ExpressionNode +from andromede.expression.expression_efficient import ExpressionNodeEfficient from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, + StandaloneConstraint, wrap_in_linear_expr_if_present, ) from andromede.expression.parsing.parse_expression import ( @@ -150,10 +152,20 @@ def _to_variable(var: InputVariable, identifiers: ModelIdentifiers) -> Variable: ) +# Used only for mypy +class ConstraintKwargs(TypedDict, total=False): + name: str + expression_init: Union[ + ExpressionNodeEfficient, LinearExpressionEfficient, StandaloneConstraint + ] + lower_bound: LinearExpressionEfficient + upper_bound: LinearExpressionEfficient + + def _to_constraint( constraint: InputConstraint, identifiers: ModelIdentifiers ) -> Constraint: - kwargs = { + kwargs: ConstraintKwargs = { "name": constraint.name, "expression_init": parse_expression(constraint.expression, identifiers), } From 7c9d42767143e509ab75c75de620ed817e56bed8 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Thu, 22 Aug 2024 11:10:17 +0200 Subject: [PATCH 42/51] Type checking and reformatting --- src/andromede/libs/standard.py | 10 ++++++++-- src/andromede/libs/standard_sc.py | 3 ++- tests/functional/test_xpansion.py | 3 ++- tests/integration/test_benders_decomposed.py | 3 ++- tests/models/test_electrolyzer.py | 3 ++- tests/unittests/model/test_model_parsing.py | 3 ++- tests/unittests/test_data.py | 6 ++++-- 7 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/andromede/libs/standard.py b/src/andromede/libs/standard.py index 110002ff..95fc9473 100644 --- a/src/andromede/libs/standard.py +++ b/src/andromede/libs/standard.py @@ -16,7 +16,11 @@ from andromede.expression.expression_efficient import ExpressionRange, literal, param from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression_efficient import port_field, var +from andromede.expression.linear_expression_efficient import ( + port_field, + var, + wrap_in_linear_expr, +) from andromede.model.constraint import Constraint from andromede.model.model import ModelPort, PortFieldDefinition, PortFieldId, model from andromede.model.parameter import float_parameter, int_parameter @@ -156,7 +160,9 @@ Constraint( name="Min generation", expression_init=var("generation") - param("p_min"), - lower_bound=literal(0), + lower_bound=wrap_in_linear_expr( + literal(0) + ), # wrap_in_linear_expr is not needed as done in __post_init__, it is used here only for type checking... ), # To test both ways of setting constraints ], objective_operational_contribution=(param("cost") * var("generation")) diff --git a/src/andromede/libs/standard_sc.py b/src/andromede/libs/standard_sc.py index b7025487..675b727d 100644 --- a/src/andromede/libs/standard_sc.py +++ b/src/andromede/libs/standard_sc.py @@ -131,7 +131,8 @@ binding_constraints=[ Constraint( name="Conversion", - expression_init=var("input") == port_field("FlowDI", "flow").sum_connections(), + expression_init=var("input") + == port_field("FlowDI", "flow").sum_connections(), ) ], ) diff --git a/tests/functional/test_xpansion.py b/tests/functional/test_xpansion.py index b32cfe7b..96875e7c 100644 --- a/tests/functional/test_xpansion.py +++ b/tests/functional/test_xpansion.py @@ -135,7 +135,8 @@ def discrete_candidate() -> Model: ), Constraint( name="Max investment", - expression_init=var("p_max") == param("p_max_per_unit") * var("nb_units"), + expression_init=var("p_max") + == param("p_max_per_unit") * var("nb_units"), context=INVESTMENT, ), ], diff --git a/tests/integration/test_benders_decomposed.py b/tests/integration/test_benders_decomposed.py index 484dbf8f..3cc8ce32 100644 --- a/tests/integration/test_benders_decomposed.py +++ b/tests/integration/test_benders_decomposed.py @@ -138,7 +138,8 @@ def discrete_candidate() -> Model: ), Constraint( name="Max investment", - expression_init=var("p_max") == param("p_max_per_unit") * var("nb_units"), + expression_init=var("p_max") + == param("p_max_per_unit") * var("nb_units"), context=INVESTMENT, ), ], diff --git a/tests/models/test_electrolyzer.py b/tests/models/test_electrolyzer.py index d2b632de..8e9f5504 100644 --- a/tests/models/test_electrolyzer.py +++ b/tests/models/test_electrolyzer.py @@ -76,7 +76,8 @@ binding_constraints=[ Constraint( name="Balance", - expression_init=port_field("h2_port", "flow").sum_connections() == literal(0), + expression_init=port_field("h2_port", "flow").sum_connections() + == literal(0), ) ], ) diff --git a/tests/unittests/model/test_model_parsing.py b/tests/unittests/model/test_model_parsing.py index 9878613f..4132d964 100644 --- a/tests/unittests/model/test_model_parsing.py +++ b/tests/unittests/model/test_model_parsing.py @@ -162,7 +162,8 @@ def test_library_port_model_ok_parsing(data_dir: Path) -> None: constraints=[ Constraint( name="Level equation", - expression_init=port_field("injection_port", "flow") == var("withdrawal"), + expression_init=port_field("injection_port", "flow") + == var("withdrawal"), ) ], ) diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index 0a5e381e..7783fc9c 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -86,7 +86,8 @@ def mock_generator_with_fixed_scenario_time_varying_param() -> Model: ], constraints=[ Constraint( - name="Max generation", expression_init=var("generation") <= param("p_max") + name="Max generation", + expression_init=var("generation") <= param("p_max"), ) ], objective_operational_contribution=(param("cost") * var("generation")) @@ -114,7 +115,8 @@ def mock_generator_with_scenario_varying_fixed_time_param() -> Model: ], constraints=[ Constraint( - name="Max generation", expression_init=var("generation") <= param("p_max") + name="Max generation", + expression_init=var("generation") <= param("p_max"), ) ], objective_operational_contribution=(param("cost") * var("generation")) From 60227811d10d1ac6a213dcdb3f83d52f5a878ee9 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Fri, 23 Aug 2024 14:31:50 +0200 Subject: [PATCH 43/51] Remove useless code --- src/andromede/simulation/optimization.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 078dfa35..55bc5919 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -348,7 +348,6 @@ def _create_constraints(self) -> None: def _create_objectives(self) -> None: for component in self.context.network.all_components: - component_context = self.context.get_component_context(component) model = component.model for objective in self.strategy.get_objectives(model): From 93eda7eb2ae99d8009d6b9714572f504cb779cf3 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Tue, 27 Aug 2024 16:00:48 +0200 Subject: [PATCH 44/51] Rename files --- src/andromede/expression/__init__.py | 4 ++-- src/andromede/expression/context_adder.py | 2 +- src/andromede/expression/copy.py | 2 +- src/andromede/expression/equality.py | 2 +- ...valuate_parameters_efficient.py => evaluate_parameters.py} | 2 +- .../expression/{expression_efficient.py => expression.py} | 0 .../{linear_expression_efficient.py => linear_expression.py} | 4 ++-- src/andromede/expression/parsing/parse_expression.py | 4 ++-- src/andromede/expression/print.py | 2 +- src/andromede/expression/time_operator.py | 2 +- src/andromede/expression/visitor.py | 2 +- src/andromede/libs/standard.py | 4 ++-- src/andromede/libs/standard_sc.py | 4 ++-- src/andromede/model/common.py | 4 ++-- src/andromede/model/constraint.py | 4 ++-- src/andromede/model/model.py | 4 ++-- src/andromede/model/resolve_library.py | 4 ++-- src/andromede/model/variable.py | 4 ++-- src/andromede/simulation/linear_expression_resolver.py | 4 ++-- src/andromede/simulation/optimization.py | 2 +- src/andromede/simulation/optimization_context.py | 4 ++-- src/andromede/simulation/strategy.py | 2 +- tests/functional/test_andromede.py | 4 ++-- tests/functional/test_performance_efficient.py | 4 ++-- tests/functional/test_xpansion.py | 4 ++-- tests/integration/test_benders_decomposed.py | 4 ++-- tests/models/test_electrolyzer.py | 4 ++-- .../unittests/expressions/parsing/test_expression_parsing.py | 4 ++-- tests/unittests/expressions/test_equality.py | 2 +- tests/unittests/expressions/test_expressions_efficient.py | 4 ++-- .../expressions/test_linear_expressions_efficient.py | 4 ++-- tests/unittests/expressions/test_port_resolver.py | 2 +- tests/unittests/expressions/test_resolve_coefficients.py | 4 ++-- tests/unittests/expressions/test_term_efficient.py | 4 ++-- tests/unittests/model/test_model_parsing.py | 4 ++-- tests/unittests/test_data.py | 4 ++-- tests/unittests/test_model.py | 4 ++-- tests/unittests/test_port.py | 4 ++-- 38 files changed, 63 insertions(+), 63 deletions(-) rename src/andromede/expression/{evaluate_parameters_efficient.py => evaluate_parameters.py} (99%) rename src/andromede/expression/{expression_efficient.py => expression.py} (100%) rename src/andromede/expression/{linear_expression_efficient.py => linear_expression.py} (99%) diff --git a/src/andromede/expression/__init__.py b/src/andromede/expression/__init__.py index 5b62b67e..e207690d 100644 --- a/src/andromede/expression/__init__.py +++ b/src/andromede/expression/__init__.py @@ -11,8 +11,8 @@ # This file is part of the Antares project. from .copy import CopyVisitor, copy_expression -from .evaluate_parameters_efficient import ValueProvider -from .expression_efficient import ( +from .evaluate_parameters import ValueProvider +from .expression import ( AdditionNode, ComparisonNode, ComponentParameterNode, diff --git a/src/andromede/expression/context_adder.py b/src/andromede/expression/context_adder.py index a5fe4d52..b43d1442 100644 --- a/src/andromede/expression/context_adder.py +++ b/src/andromede/expression/context_adder.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from . import CopyVisitor -from .expression_efficient import ( +from .expression import ( ComponentParameterNode, ExpressionNodeEfficient, ParameterNode, diff --git a/src/andromede/expression/copy.py b/src/andromede/expression/copy.py index 1f97109f..09c64ca4 100644 --- a/src/andromede/expression/copy.py +++ b/src/andromede/expression/copy.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from typing import List, cast -from .expression_efficient import ( +from .expression import ( ComparisonNode, ComponentParameterNode, ExpressionNodeEfficient, diff --git a/src/andromede/expression/equality.py b/src/andromede/expression/equality.py index 8f46b036..760cb51c 100644 --- a/src/andromede/expression/equality.py +++ b/src/andromede/expression/equality.py @@ -14,7 +14,7 @@ from dataclasses import dataclass from typing import Optional -from andromede.expression.expression_efficient import ( +from andromede.expression.expression import ( AdditionNode, BinaryOperatorNode, ComparisonNode, diff --git a/src/andromede/expression/evaluate_parameters_efficient.py b/src/andromede/expression/evaluate_parameters.py similarity index 99% rename from src/andromede/expression/evaluate_parameters_efficient.py rename to src/andromede/expression/evaluate_parameters.py index e169c007..deb84575 100644 --- a/src/andromede/expression/evaluate_parameters_efficient.py +++ b/src/andromede/expression/evaluate_parameters.py @@ -14,7 +14,7 @@ from dataclasses import dataclass, field from typing import Callable, Dict, List -from .expression_efficient import ( +from .expression import ( AdditionNode, ComparisonNode, ComponentParameterNode, diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression.py similarity index 100% rename from src/andromede/expression/expression_efficient.py rename to src/andromede/expression/expression.py diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression.py similarity index 99% rename from src/andromede/expression/linear_expression_efficient.py rename to src/andromede/expression/linear_expression.py index 220371e0..9d81bbb3 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression.py @@ -33,8 +33,8 @@ from .context_adder import add_component_context from .equality import expressions_equal -from .evaluate_parameters_efficient import check_resolved_expr, resolve_coefficient -from .expression_efficient import ( +from .evaluate_parameters import check_resolved_expr, resolve_coefficient +from .expression import ( ExpressionNodeEfficient, ExpressionRange, InstancesTimeIndex, diff --git a/src/andromede/expression/parsing/parse_expression.py b/src/andromede/expression/parsing/parse_expression.py index bb20c66c..fcc37459 100644 --- a/src/andromede/expression/parsing/parse_expression.py +++ b/src/andromede/expression/parsing/parse_expression.py @@ -16,14 +16,14 @@ from antlr4.error.ErrorStrategy import BailErrorStrategy from andromede.expression.equality import expressions_equal -from andromede.expression.expression_efficient import ( +from andromede.expression.expression import ( Comparator, ComparisonNode, ExpressionRange, literal, param, ) -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.linear_expression import ( LinearExpressionEfficient, port_field, var, diff --git a/src/andromede/expression/print.py b/src/andromede/expression/print.py index 8fd8c53a..7fb5a5de 100644 --- a/src/andromede/expression/print.py +++ b/src/andromede/expression/print.py @@ -13,7 +13,7 @@ from dataclasses import dataclass from typing import Dict -from .expression_efficient import ( +from .expression import ( AdditionNode, Comparator, ComparisonNode, diff --git a/src/andromede/expression/time_operator.py b/src/andromede/expression/time_operator.py index 1332b2dc..3d078920 100644 --- a/src/andromede/expression/time_operator.py +++ b/src/andromede/expression/time_operator.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from typing import Any, List, Tuple -from andromede.expression.expression_efficient import InstancesTimeIndex +from andromede.expression.expression import InstancesTimeIndex @dataclass(frozen=True) diff --git a/src/andromede/expression/visitor.py b/src/andromede/expression/visitor.py index 2267b52d..67f56470 100644 --- a/src/andromede/expression/visitor.py +++ b/src/andromede/expression/visitor.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from typing import Generic, Protocol, TypeVar -from .expression_efficient import ( +from .expression import ( AdditionNode, ComparisonNode, ComponentParameterNode, diff --git a/src/andromede/libs/standard.py b/src/andromede/libs/standard.py index 95fc9473..5e6c8c20 100644 --- a/src/andromede/libs/standard.py +++ b/src/andromede/libs/standard.py @@ -14,9 +14,9 @@ The standard module contains the definition of standard models. """ -from andromede.expression.expression_efficient import ExpressionRange, literal, param +from andromede.expression.expression import ExpressionRange, literal, param from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.linear_expression import ( port_field, var, wrap_in_linear_expr, diff --git a/src/andromede/libs/standard_sc.py b/src/andromede/libs/standard_sc.py index 675b727d..12243bdc 100644 --- a/src/andromede/libs/standard_sc.py +++ b/src/andromede/libs/standard_sc.py @@ -11,8 +11,8 @@ # This file is part of the Antares project. -from andromede.expression.expression_efficient import literal, param -from andromede.expression.linear_expression_efficient import port_field, var +from andromede.expression.expression import literal, param +from andromede.expression.linear_expression import port_field, var from andromede.libs.standard import BALANCE_PORT_TYPE, CONSTANT from andromede.model import ( Constraint, diff --git a/src/andromede/model/common.py b/src/andromede/model/common.py index 444abe69..180db628 100644 --- a/src/andromede/model/common.py +++ b/src/andromede/model/common.py @@ -16,8 +16,8 @@ from enum import Enum from typing import Union -from andromede.expression.expression_efficient import ExpressionNodeEfficient -from andromede.expression.linear_expression_efficient import LinearExpressionEfficient +from andromede.expression.expression import ExpressionNodeEfficient +from andromede.expression.linear_expression import LinearExpressionEfficient ValueOrExprNodeOrLinearExpr = Union[ int, float, ExpressionNodeEfficient, LinearExpressionEfficient diff --git a/src/andromede/model/constraint.py b/src/andromede/model/constraint.py index f04ab1f0..02a03897 100644 --- a/src/andromede/model/constraint.py +++ b/src/andromede/model/constraint.py @@ -12,8 +12,8 @@ from dataclasses import InitVar, dataclass, field from typing import Any, Union -from andromede.expression.expression_efficient import ExpressionNodeEfficient, literal -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.expression import ExpressionNodeEfficient, literal +from andromede.expression.linear_expression import ( LinearExpressionEfficient, StandaloneConstraint, linear_expressions_equal, diff --git a/src/andromede/model/model.py b/src/andromede/model/model.py index 50e87ef3..d19eeb84 100644 --- a/src/andromede/model/model.py +++ b/src/andromede/model/model.py @@ -19,7 +19,7 @@ from dataclasses import InitVar, dataclass, field from typing import Dict, Iterable, Optional, Union -from andromede.expression.expression_efficient import ( +from andromede.expression.expression import ( AdditionNode, BinaryOperatorNode, ComparisonNode, @@ -39,7 +39,7 @@ ) from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.linear_expression import ( LinearExpressionEfficient, is_linear, wrap_in_linear_expr, diff --git a/src/andromede/model/resolve_library.py b/src/andromede/model/resolve_library.py index 0d63c924..fc08feaa 100644 --- a/src/andromede/model/resolve_library.py +++ b/src/andromede/model/resolve_library.py @@ -12,9 +12,9 @@ from typing import Dict, List, Optional, TypedDict, Union # from andromede.expression import ExpressionNode -from andromede.expression.expression_efficient import ExpressionNodeEfficient +from andromede.expression.expression import ExpressionNodeEfficient from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.linear_expression import ( LinearExpressionEfficient, StandaloneConstraint, wrap_in_linear_expr_if_present, diff --git a/src/andromede/model/variable.py b/src/andromede/model/variable.py index 28343e42..6900cf88 100644 --- a/src/andromede/model/variable.py +++ b/src/andromede/model/variable.py @@ -13,9 +13,9 @@ from dataclasses import dataclass from typing import Any, Optional -from andromede.expression.expression_efficient import literal +from andromede.expression.expression import literal from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.linear_expression import ( LinearExpressionEfficient, linear_expressions_equal_if_present, wrap_in_linear_expr, diff --git a/src/andromede/simulation/linear_expression_resolver.py b/src/andromede/simulation/linear_expression_resolver.py index 0ae9c736..5e15d7e4 100644 --- a/src/andromede/simulation/linear_expression_resolver.py +++ b/src/andromede/simulation/linear_expression_resolver.py @@ -15,12 +15,12 @@ import ortools.linear_solver.pywraplp as lp -from andromede.expression.evaluate_parameters_efficient import ( +from andromede.expression.evaluate_parameters import ( get_time_ids_from_instances_index, resolve_coefficient, ) from andromede.expression.indexing_structure import RowIndex -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.linear_expression import ( LinearExpressionEfficient, TermEfficient, ) diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 55bc5919..28e347c3 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -22,7 +22,7 @@ from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.linear_expression import ( LinearExpressionEfficient, RowIndex, ) diff --git a/src/andromede/simulation/optimization_context.py b/src/andromede/simulation/optimization_context.py index 4c046e03..530cd013 100644 --- a/src/andromede/simulation/optimization_context.py +++ b/src/andromede/simulation/optimization_context.py @@ -17,10 +17,10 @@ import ortools.linear_solver.pywraplp as lp -from andromede.expression.evaluate_parameters_efficient import ValueProvider +from andromede.expression.evaluate_parameters import ValueProvider from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.linear_expression import ( LinearExpressionEfficient, PortFieldId, PortFieldKey, diff --git a/src/andromede/simulation/strategy.py b/src/andromede/simulation/strategy.py index 0f148bda..6a5326c7 100644 --- a/src/andromede/simulation/strategy.py +++ b/src/andromede/simulation/strategy.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from typing import Generator, Optional -from andromede.expression.linear_expression_efficient import LinearExpressionEfficient +from andromede.expression.linear_expression import LinearExpressionEfficient from andromede.model import Constraint, Model, ProblemContext, Variable diff --git a/tests/functional/test_andromede.py b/tests/functional/test_andromede.py index a9f6fd03..16837352 100644 --- a/tests/functional/test_andromede.py +++ b/tests/functional/test_andromede.py @@ -13,9 +13,9 @@ import pandas as pd import pytest -from andromede.expression.expression_efficient import literal, param +from andromede.expression.expression import literal, param from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression_efficient import var +from andromede.expression.linear_expression import var from andromede.libs.standard import ( BALANCE_PORT_TYPE, DEMAND_MODEL, diff --git a/tests/functional/test_performance_efficient.py b/tests/functional/test_performance_efficient.py index ed00bd02..d6b74b20 100644 --- a/tests/functional/test_performance_efficient.py +++ b/tests/functional/test_performance_efficient.py @@ -14,9 +14,9 @@ import pytest from andromede.expression.evaluate import EvaluationContext -from andromede.expression.expression_efficient import param +from andromede.expression.expression import param from andromede.expression.indexing_structure import RowIndex -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.linear_expression import ( literal, var, wrap_in_linear_expr, diff --git a/tests/functional/test_xpansion.py b/tests/functional/test_xpansion.py index 96875e7c..3757b8d2 100644 --- a/tests/functional/test_xpansion.py +++ b/tests/functional/test_xpansion.py @@ -13,9 +13,9 @@ import pandas as pd import pytest -from andromede.expression.expression_efficient import literal, param +from andromede.expression.expression import literal, param from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression_efficient import var +from andromede.expression.linear_expression import var from andromede.libs.standard import ( BALANCE_PORT_TYPE, CONSTANT, diff --git a/tests/integration/test_benders_decomposed.py b/tests/integration/test_benders_decomposed.py index 3cc8ce32..c9a0db33 100644 --- a/tests/integration/test_benders_decomposed.py +++ b/tests/integration/test_benders_decomposed.py @@ -13,9 +13,9 @@ import pandas as pd import pytest -from andromede.expression.expression_efficient import literal, param +from andromede.expression.expression import literal, param from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression_efficient import var +from andromede.expression.linear_expression import var from andromede.libs.standard import ( BALANCE_PORT_TYPE, CONSTANT, diff --git a/tests/models/test_electrolyzer.py b/tests/models/test_electrolyzer.py index 8e9f5504..d6ef3808 100644 --- a/tests/models/test_electrolyzer.py +++ b/tests/models/test_electrolyzer.py @@ -10,8 +10,8 @@ # # This file is part of the Antares project. -from andromede.expression.expression_efficient import literal, param -from andromede.expression.linear_expression_efficient import port_field, var +from andromede.expression.expression import literal, param +from andromede.expression.linear_expression import port_field, var from andromede.libs.standard import CONSTANT, TIME_AND_SCENARIO_FREE from andromede.model import ( Constraint, diff --git a/tests/unittests/expressions/parsing/test_expression_parsing.py b/tests/unittests/expressions/parsing/test_expression_parsing.py index cec5b363..1aede958 100644 --- a/tests/unittests/expressions/parsing/test_expression_parsing.py +++ b/tests/unittests/expressions/parsing/test_expression_parsing.py @@ -14,13 +14,13 @@ import pytest from andromede.expression.equality import expressions_equal -from andromede.expression.expression_efficient import ( +from andromede.expression.expression import ( ExpressionNodeEfficient, ExpressionRange, literal, param, ) -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.linear_expression import ( LinearExpressionEfficient, StandaloneConstraint, linear_expressions_equal, diff --git a/tests/unittests/expressions/test_equality.py b/tests/unittests/expressions/test_equality.py index 5952f28b..d654bdf1 100644 --- a/tests/unittests/expressions/test_equality.py +++ b/tests/unittests/expressions/test_equality.py @@ -15,7 +15,7 @@ from andromede.expression.copy import copy_expression from andromede.expression.equality import expressions_equal -from andromede.expression.expression_efficient import ( +from andromede.expression.expression import ( ExpressionNodeEfficient, InstancesTimeIndex, TimeAggregatorName, diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index a2f73401..0262fc17 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -18,7 +18,7 @@ from andromede.expression.equality import expressions_equal from andromede.expression.evaluate import EvaluationContext, ValueProvider -from andromede.expression.expression_efficient import ( +from andromede.expression.expression import ( ComponentParameterNode, ExpressionNodeEfficient, ExpressionRange, @@ -35,7 +35,7 @@ ) from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure, RowIndex -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.linear_expression import ( LinearExpressionEfficient, StandaloneConstraint, TermEfficient, diff --git a/tests/unittests/expressions/test_linear_expressions_efficient.py b/tests/unittests/expressions/test_linear_expressions_efficient.py index 4d703500..eba57c14 100644 --- a/tests/unittests/expressions/test_linear_expressions_efficient.py +++ b/tests/unittests/expressions/test_linear_expressions_efficient.py @@ -14,12 +14,12 @@ import pytest -from andromede.expression.expression_efficient import ( +from andromede.expression.expression import ( TimeAggregatorNode, expression_range, param, ) -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.linear_expression import ( LinearExpressionEfficient, PortFieldId, PortFieldTerm, diff --git a/tests/unittests/expressions/test_port_resolver.py b/tests/unittests/expressions/test_port_resolver.py index 5ca982d4..4008fba9 100644 --- a/tests/unittests/expressions/test_port_resolver.py +++ b/tests/unittests/expressions/test_port_resolver.py @@ -15,7 +15,7 @@ import pytest from andromede.expression.equality import expressions_equal -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.linear_expression import ( LinearExpressionEfficient, PortFieldId, PortFieldKey, diff --git a/tests/unittests/expressions/test_resolve_coefficients.py b/tests/unittests/expressions/test_resolve_coefficients.py index 2211a617..08daa4ab 100644 --- a/tests/unittests/expressions/test_resolve_coefficients.py +++ b/tests/unittests/expressions/test_resolve_coefficients.py @@ -16,8 +16,8 @@ import pytest -from andromede.expression.evaluate_parameters_efficient import resolve_coefficient -from andromede.expression.expression_efficient import ( +from andromede.expression.evaluate_parameters import resolve_coefficient +from andromede.expression.expression import ( Comparator, ComparisonNode, ExpressionNodeEfficient, diff --git a/tests/unittests/expressions/test_term_efficient.py b/tests/unittests/expressions/test_term_efficient.py index 96b25262..45bb66e4 100644 --- a/tests/unittests/expressions/test_term_efficient.py +++ b/tests/unittests/expressions/test_term_efficient.py @@ -12,8 +12,8 @@ import pytest -from andromede.expression.expression_efficient import LiteralNode -from andromede.expression.linear_expression_efficient import TermEfficient +from andromede.expression.expression import LiteralNode +from andromede.expression.linear_expression import TermEfficient from andromede.expression.scenario_operator import Expectation, Variance from andromede.expression.time_operator import TimeShift, TimeSum diff --git a/tests/unittests/model/test_model_parsing.py b/tests/unittests/model/test_model_parsing.py index 4132d964..a79fcec5 100644 --- a/tests/unittests/model/test_model_parsing.py +++ b/tests/unittests/model/test_model_parsing.py @@ -13,8 +13,8 @@ import pytest -from andromede.expression.expression_efficient import literal, param -from andromede.expression.linear_expression_efficient import port_field, var +from andromede.expression.expression import literal, param +from andromede.expression.linear_expression import port_field, var from andromede.expression.parsing.parse_expression import AntaresParseException from andromede.libs.standard import CONSTANT from andromede.model import ( diff --git a/tests/unittests/test_data.py b/tests/unittests/test_data.py index 7783fc9c..d30f4509 100644 --- a/tests/unittests/test_data.py +++ b/tests/unittests/test_data.py @@ -15,9 +15,9 @@ import pandas as pd import pytest -from andromede.expression.expression_efficient import param +from andromede.expression.expression import param from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression_efficient import var +from andromede.expression.linear_expression import var from andromede.libs.standard import ( BALANCE_PORT_TYPE, CONSTANT, diff --git a/tests/unittests/test_model.py b/tests/unittests/test_model.py index 38e66a1c..6d61caae 100644 --- a/tests/unittests/test_model.py +++ b/tests/unittests/test_model.py @@ -15,8 +15,8 @@ import pytest -from andromede.expression.expression_efficient import ExpressionRange, comp_param, param -from andromede.expression.linear_expression_efficient import ( +from andromede.expression.expression import ExpressionRange, comp_param, param +from andromede.expression.linear_expression import ( LinearExpressionEfficient, comp_var, linear_expressions_equal, diff --git a/tests/unittests/test_port.py b/tests/unittests/test_port.py index bca3da1f..6f173dee 100644 --- a/tests/unittests/test_port.py +++ b/tests/unittests/test_port.py @@ -12,8 +12,8 @@ import pytest -from andromede.expression.expression_efficient import literal -from andromede.expression.linear_expression_efficient import port_field +from andromede.expression.expression import literal +from andromede.expression.linear_expression import port_field from andromede.libs.standard import DEMAND_MODEL from andromede.model import Constraint, ModelPort, PortType, model from andromede.model.constraint import Constraint From 952bce4dabd277a45bf7c6b46ff0a09c5d969975 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Tue, 27 Aug 2024 16:10:35 +0200 Subject: [PATCH 45/51] Rename file --- tests/functional/test_performance_efficient.py | 8 ++------ tests/unittests/expressions/test_expressions_efficient.py | 2 +- src/andromede/expression/evaluate.py => tests/utils.py | 6 +++++- 3 files changed, 8 insertions(+), 8 deletions(-) rename src/andromede/expression/evaluate.py => tests/utils.py (93%) diff --git a/tests/functional/test_performance_efficient.py b/tests/functional/test_performance_efficient.py index d6b74b20..5bf0a0e5 100644 --- a/tests/functional/test_performance_efficient.py +++ b/tests/functional/test_performance_efficient.py @@ -13,14 +13,9 @@ import pytest -from andromede.expression.evaluate import EvaluationContext from andromede.expression.expression import param from andromede.expression.indexing_structure import RowIndex -from andromede.expression.linear_expression import ( - literal, - var, - wrap_in_linear_expr, -) +from andromede.expression.linear_expression import literal, var, wrap_in_linear_expr from andromede.libs.standard import ( DEMAND_MODEL, GENERATOR_MODEL, @@ -32,6 +27,7 @@ from andromede.study.data import ConstantData, DataBase from andromede.study.network import Network, Node, PortRef, create_component from tests.unittests.test_utils import generate_scalar_matrix_data +from tests.utils import EvaluationContext def test_large_number_of_parameters_sum() -> None: diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index 0262fc17..4a690d4d 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -17,7 +17,6 @@ import pytest from andromede.expression.equality import expressions_equal -from andromede.expression.evaluate import EvaluationContext, ValueProvider from andromede.expression.expression import ( ComponentParameterNode, ExpressionNodeEfficient, @@ -48,6 +47,7 @@ ) from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum from andromede.expression.value_provider import TimeScenarioIndex, TimeScenarioIndices +from tests.utils import EvaluationContext, ValueProvider @dataclass(frozen=True) diff --git a/src/andromede/expression/evaluate.py b/tests/utils.py similarity index 93% rename from src/andromede/expression/evaluate.py rename to tests/utils.py index 94c9eb1d..cf5ef16c 100644 --- a/src/andromede/expression/evaluate.py +++ b/tests/utils.py @@ -13,7 +13,11 @@ from dataclasses import dataclass, field from typing import Dict -from .value_provider import TimeScenarioIndex, TimeScenarioIndices, ValueProvider +from andromede.expression.value_provider import ( + TimeScenarioIndex, + TimeScenarioIndices, + ValueProvider, +) # Used only for tests From 85f3953803da6714454adbae3a07232521ee3dd4 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Tue, 27 Aug 2024 16:16:05 +0200 Subject: [PATCH 46/51] Remove 'efficient' suffix --- src/andromede/expression/__init__.py | 2 +- src/andromede/expression/context_adder.py | 12 +- src/andromede/expression/copy.py | 28 +- src/andromede/expression/equality.py | 12 +- .../expression/evaluate_parameters.py | 6 +- src/andromede/expression/expression.py | 151 ++++----- src/andromede/expression/linear_expression.py | 250 +++++++-------- .../expression/parsing/parse_expression.py | 96 ++---- src/andromede/expression/print.py | 4 +- src/andromede/expression/visitor.py | 46 +-- src/andromede/libs/standard.py | 8 +- src/andromede/model/common.py | 8 +- src/andromede/model/constraint.py | 16 +- src/andromede/model/model.py | 20 +- src/andromede/model/resolve_library.py | 14 +- src/andromede/model/variable.py | 6 +- .../simulation/linear_expression_resolver.py | 49 ++- src/andromede/simulation/optimization.py | 11 +- .../simulation/optimization_context.py | 27 +- src/andromede/simulation/strategy.py | 14 +- .../parsing/test_expression_parsing.py | 12 +- tests/unittests/expressions/test_equality.py | 8 +- .../expressions/test_expressions_efficient.py | 94 +++--- .../test_linear_expressions_efficient.py | 300 ++++++++---------- .../expressions/test_port_resolver.py | 10 +- .../expressions/test_resolve_coefficients.py | 18 +- .../expressions/test_term_efficient.py | 70 ++-- tests/unittests/test_model.py | 16 +- 28 files changed, 579 insertions(+), 729 deletions(-) diff --git a/src/andromede/expression/__init__.py b/src/andromede/expression/__init__.py index e207690d..f0723175 100644 --- a/src/andromede/expression/__init__.py +++ b/src/andromede/expression/__init__.py @@ -17,7 +17,7 @@ ComparisonNode, ComponentParameterNode, DivisionNode, - ExpressionNodeEfficient, + ExpressionNode, ExpressionRange, InstancesTimeIndex, LiteralNode, diff --git a/src/andromede/expression/context_adder.py b/src/andromede/expression/context_adder.py index b43d1442..5480a632 100644 --- a/src/andromede/expression/context_adder.py +++ b/src/andromede/expression/context_adder.py @@ -13,11 +13,7 @@ from dataclasses import dataclass from . import CopyVisitor -from .expression import ( - ComponentParameterNode, - ExpressionNodeEfficient, - ParameterNode, -) +from .expression import ComponentParameterNode, ExpressionNode, ParameterNode from .visitor import visit @@ -30,13 +26,11 @@ class ContextAdder(CopyVisitor): component_id: str - def parameter(self, node: ParameterNode) -> ExpressionNodeEfficient: + def parameter(self, node: ParameterNode) -> ExpressionNode: return ComponentParameterNode(self.component_id, node.name) # Nothing is done is a component parameter node is encountered as it may have been generated from port resolution -def add_component_context( - id: str, expression: ExpressionNodeEfficient -) -> ExpressionNodeEfficient: +def add_component_context(id: str, expression: ExpressionNode) -> ExpressionNode: return visit(expression, ContextAdder(id)) diff --git a/src/andromede/expression/copy.py b/src/andromede/expression/copy.py index 09c64ca4..677aaf15 100644 --- a/src/andromede/expression/copy.py +++ b/src/andromede/expression/copy.py @@ -16,7 +16,7 @@ from .expression import ( ComparisonNode, ComponentParameterNode, - ExpressionNodeEfficient, + ExpressionNode, ExpressionRange, InstancesTimeIndex, LiteralNode, @@ -31,23 +31,23 @@ @dataclass(frozen=True) -class CopyVisitor(ExpressionVisitorOperations[ExpressionNodeEfficient]): +class CopyVisitor(ExpressionVisitorOperations[ExpressionNode]): """ Simply copies the whole AST. """ - def literal(self, node: LiteralNode) -> ExpressionNodeEfficient: + def literal(self, node: LiteralNode) -> ExpressionNode: return LiteralNode(node.value) - def comparison(self, node: ComparisonNode) -> ExpressionNodeEfficient: + def comparison(self, node: ComparisonNode) -> ExpressionNode: return ComparisonNode( visit(node.left, self), visit(node.right, self), node.comparator ) - def parameter(self, node: ParameterNode) -> ExpressionNodeEfficient: + def parameter(self, node: ParameterNode) -> ExpressionNode: return ParameterNode(node.name) - def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNodeEfficient: + def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode: return ComponentParameterNode(node.component_id, node.name) def copy_expression_range( @@ -70,32 +70,30 @@ def copy_instances_index( if isinstance(expressions, ExpressionRange): return InstancesTimeIndex(self.copy_expression_range(expressions)) if isinstance(expressions, list): - expressions_list = cast(List[ExpressionNodeEfficient], expressions) + expressions_list = cast(List[ExpressionNode], expressions) copy = [visit(e, self) for e in expressions_list] return InstancesTimeIndex(copy) raise ValueError("Unexpected type in instances index") - def time_operator(self, node: TimeOperatorNode) -> ExpressionNodeEfficient: + def time_operator(self, node: TimeOperatorNode) -> ExpressionNode: return TimeOperatorNode( visit(node.operand, self), node.name, self.copy_instances_index(node.instances_index), ) - def time_aggregator(self, node: TimeAggregatorNode) -> ExpressionNodeEfficient: + def time_aggregator(self, node: TimeAggregatorNode) -> ExpressionNode: return TimeAggregatorNode(visit(node.operand, self), node.name, node.stay_roll) - def scenario_operator(self, node: ScenarioOperatorNode) -> ExpressionNodeEfficient: + def scenario_operator(self, node: ScenarioOperatorNode) -> ExpressionNode: return ScenarioOperatorNode(visit(node.operand, self), node.name) - def port_field(self, node: PortFieldNode) -> ExpressionNodeEfficient: + def port_field(self, node: PortFieldNode) -> ExpressionNode: return PortFieldNode(node.port_name, node.field_name) - def port_field_aggregator( - self, node: PortFieldAggregatorNode - ) -> ExpressionNodeEfficient: + def port_field_aggregator(self, node: PortFieldAggregatorNode) -> ExpressionNode: return PortFieldAggregatorNode(visit(node.operand, self), node.aggregator) -def copy_expression(expression: ExpressionNodeEfficient) -> ExpressionNodeEfficient: +def copy_expression(expression: ExpressionNode) -> ExpressionNode: return visit(expression, CopyVisitor()) diff --git a/src/andromede/expression/equality.py b/src/andromede/expression/equality.py index 760cb51c..b6efebc0 100644 --- a/src/andromede/expression/equality.py +++ b/src/andromede/expression/equality.py @@ -20,7 +20,7 @@ ComparisonNode, ComponentParameterNode, DivisionNode, - ExpressionNodeEfficient, + ExpressionNode, ExpressionRange, InstancesTimeIndex, LiteralNode, @@ -51,9 +51,7 @@ def __post_init__(self) -> None: f"Relative comparison tolerance must be >= 0, got {self.rel_tol}" ) - def visit( - self, left: ExpressionNodeEfficient, right: ExpressionNodeEfficient - ) -> bool: + def visit(self, left: ExpressionNode, right: ExpressionNode) -> bool: if left.__class__ != right.__class__: return False if isinstance(left, LiteralNode) and isinstance(right, LiteralNode): @@ -187,8 +185,8 @@ def port_field_aggregator( def expressions_equal( - left: ExpressionNodeEfficient, - right: ExpressionNodeEfficient, + left: ExpressionNode, + right: ExpressionNode, abs_tol: float = 0, rel_tol: float = 0, ) -> bool: @@ -199,7 +197,7 @@ def expressions_equal( def expressions_equal_if_present( - lhs: Optional[ExpressionNodeEfficient], rhs: Optional[ExpressionNodeEfficient] + lhs: Optional[ExpressionNode], rhs: Optional[ExpressionNode] ) -> bool: if lhs is None and rhs is None: return True diff --git a/src/andromede/expression/evaluate_parameters.py b/src/andromede/expression/evaluate_parameters.py index deb84575..d202ee2d 100644 --- a/src/andromede/expression/evaluate_parameters.py +++ b/src/andromede/expression/evaluate_parameters.py @@ -19,7 +19,7 @@ ComparisonNode, ComponentParameterNode, DivisionNode, - ExpressionNodeEfficient, + ExpressionNode, ExpressionRange, InstancesTimeIndex, LiteralNode, @@ -211,7 +211,7 @@ def check_resolved_expr( def resolve_coefficient( - expression: ExpressionNodeEfficient, value_provider: ValueProvider, row_id: RowIndex + expression: ExpressionNode, value_provider: ValueProvider, row_id: RowIndex ) -> float: result = visit(expression, ParameterEvaluationVisitor(value_provider, row_id)) check_resolved_expr(result, row_id) @@ -261,7 +261,7 @@ def float_to_int(value: float) -> int: def evaluate_time_id( - expr: ExpressionNodeEfficient, value_provider: ValueProvider, row_id: RowIndex + expr: ExpressionNode, value_provider: ValueProvider, row_id: RowIndex ) -> int: float_time_id_in_list = visit(expr, InstancesIndexVisitor(value_provider, row_id)) check_resolved_expr(float_time_id_in_list, row_id) diff --git a/src/andromede/expression/expression.py b/src/andromede/expression/expression.py index 37d89b58..bbcb2aef 100644 --- a/src/andromede/expression/expression.py +++ b/src/andromede/expression/expression.py @@ -22,7 +22,7 @@ @dataclass(frozen=True) -class ExpressionNodeEfficient: +class ExpressionNode: """ Base class for all nodes of the expression AST. @@ -33,47 +33,47 @@ class ExpressionNodeEfficient: >>> expr = -var('x') + 5 / param('p') """ - def __neg__(self) -> "ExpressionNodeEfficient": + def __neg__(self) -> "ExpressionNode": return _negate_node(self) - def __add__(self, rhs: Any) -> "ExpressionNodeEfficient": + def __add__(self, rhs: Any) -> "ExpressionNode": return _apply_if_node(rhs, lambda x: _add_node(self, x)) - def __radd__(self, lhs: Any) -> "ExpressionNodeEfficient": + def __radd__(self, lhs: Any) -> "ExpressionNode": return _apply_if_node(lhs, lambda x: _add_node(x, self)) - def __sub__(self, rhs: Any) -> "ExpressionNodeEfficient": + def __sub__(self, rhs: Any) -> "ExpressionNode": return _apply_if_node(rhs, lambda x: _substract_node(self, x)) - def __rsub__(self, lhs: Any) -> "ExpressionNodeEfficient": + def __rsub__(self, lhs: Any) -> "ExpressionNode": return _apply_if_node(lhs, lambda x: _substract_node(x, self)) - def __mul__(self, rhs: Any) -> "ExpressionNodeEfficient": + def __mul__(self, rhs: Any) -> "ExpressionNode": return _apply_if_node(rhs, lambda x: _multiply_node(self, x)) - def __rmul__(self, lhs: Any) -> "ExpressionNodeEfficient": + def __rmul__(self, lhs: Any) -> "ExpressionNode": return _apply_if_node(lhs, lambda x: _multiply_node(x, self)) - def __truediv__(self, rhs: Any) -> "ExpressionNodeEfficient": + def __truediv__(self, rhs: Any) -> "ExpressionNode": return _apply_if_node(rhs, lambda x: _divide_node(self, x)) - def __rtruediv__(self, lhs: Any) -> "ExpressionNodeEfficient": + def __rtruediv__(self, lhs: Any) -> "ExpressionNode": return _apply_if_node(lhs, lambda x: _divide_node(x, self)) - def __le__(self, rhs: Any) -> "ExpressionNodeEfficient": + def __le__(self, rhs: Any) -> "ExpressionNode": return _apply_if_node( rhs, lambda x: ComparisonNode(self, x, Comparator.LESS_THAN) ) - def __ge__(self, rhs: Any) -> "ExpressionNodeEfficient": + def __ge__(self, rhs: Any) -> "ExpressionNode": return _apply_if_node( rhs, lambda x: ComparisonNode(self, x, Comparator.GREATER_THAN) ) - def __eq__(self, rhs: Any) -> "ExpressionNodeEfficient": # type: ignore + def __eq__(self, rhs: Any) -> "ExpressionNode": # type: ignore return _apply_if_node(rhs, lambda x: ComparisonNode(self, x, Comparator.EQUAL)) - def sum(self) -> "ExpressionNodeEfficient": + def sum(self) -> "ExpressionNode": if isinstance(self, TimeOperatorNode): return TimeAggregatorNode(self, TimeAggregatorName.TIME_SUM, stay_roll=True) else: @@ -84,7 +84,7 @@ def sum(self) -> "ExpressionNodeEfficient": ), ) - def sum_connections(self) -> "ExpressionNodeEfficient": + def sum_connections(self) -> "ExpressionNode": if isinstance(self, PortFieldNode): return PortFieldAggregatorNode( self, aggregator=PortFieldAggregatorName.PORT_SUM @@ -97,11 +97,11 @@ def shift( self, expressions: Union[ int, - "ExpressionNodeEfficient", - List["ExpressionNodeEfficient"], + "ExpressionNode", + List["ExpressionNode"], "ExpressionRange", ], - ) -> "ExpressionNodeEfficient": + ) -> "ExpressionNode": return _apply_if_node( self, lambda x: TimeOperatorNode( @@ -113,11 +113,11 @@ def eval( self, expressions: Union[ int, - "ExpressionNodeEfficient", - List["ExpressionNodeEfficient"], + "ExpressionNode", + List["ExpressionNode"], "ExpressionRange", ], - ) -> "ExpressionNodeEfficient": + ) -> "ExpressionNode": return _apply_if_node( self, lambda x: TimeOperatorNode( @@ -125,19 +125,19 @@ def eval( ), ) - def expec(self) -> "ExpressionNodeEfficient": + def expec(self) -> "ExpressionNode": return _apply_if_node( self, lambda x: ScenarioOperatorNode(x, ScenarioOperatorName.EXPECTATION) ) - def variance(self) -> "ExpressionNodeEfficient": + def variance(self) -> "ExpressionNode": return _apply_if_node( self, lambda x: ScenarioOperatorNode(x, ScenarioOperatorName.VARIANCE) ) -def wrap_in_node(obj: Any) -> ExpressionNodeEfficient: - if isinstance(obj, ExpressionNodeEfficient): +def wrap_in_node(obj: Any) -> ExpressionNode: + if isinstance(obj, ExpressionNode): return obj elif isinstance(obj, float) or isinstance(obj, int): return LiteralNode(float(obj)) @@ -146,30 +146,30 @@ def wrap_in_node(obj: Any) -> ExpressionNodeEfficient: def _apply_if_node( - obj: Any, func: Callable[["ExpressionNodeEfficient"], "ExpressionNodeEfficient"] -) -> "ExpressionNodeEfficient": + obj: Any, func: Callable[["ExpressionNode"], "ExpressionNode"] +) -> "ExpressionNode": if as_node := wrap_in_node(obj): return func(as_node) else: return NotImplemented -def is_zero(node: ExpressionNodeEfficient) -> bool: +def is_zero(node: ExpressionNode) -> bool: # Faster implementation than expressions equal for this particular cases return isinstance(node, LiteralNode) and math.isclose(node.value, 0, abs_tol=EPS) -def is_one(node: ExpressionNodeEfficient) -> bool: +def is_one(node: ExpressionNode) -> bool: # Faster implementation than expressions equal for this particular cases return isinstance(node, LiteralNode) and math.isclose(node.value, 1) -def is_minus_one(node: ExpressionNodeEfficient) -> bool: +def is_minus_one(node: ExpressionNode) -> bool: # Faster implementation than expressions equal for this particular cases return isinstance(node, LiteralNode) and math.isclose(node.value, -1) -def _negate_node(node: ExpressionNodeEfficient) -> ExpressionNodeEfficient: +def _negate_node(node: ExpressionNode) -> ExpressionNode: if isinstance(node, LiteralNode): return LiteralNode(-node.value) elif isinstance(node, NegationNode): @@ -178,9 +178,7 @@ def _negate_node(node: ExpressionNodeEfficient) -> ExpressionNodeEfficient: return NegationNode(node) -def _add_node( - lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient -) -> ExpressionNodeEfficient: +def _add_node(lhs: ExpressionNode, rhs: ExpressionNode) -> ExpressionNode: if is_zero(lhs): return rhs if is_zero(rhs): @@ -226,9 +224,7 @@ def _add_node( # Better if we could use equality visitor -def _are_parameter_nodes_equal( - lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient -) -> bool: +def _are_parameter_nodes_equal(lhs: ExpressionNode, rhs: ExpressionNode) -> bool: return ( isinstance(lhs, ParameterNode) and isinstance(rhs, ParameterNode) @@ -236,9 +232,7 @@ def _are_parameter_nodes_equal( ) -def _substract_node( - lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient -) -> ExpressionNodeEfficient: +def _substract_node(lhs: ExpressionNode, rhs: ExpressionNode) -> ExpressionNode: if is_zero(lhs): return -rhs if is_zero(rhs): @@ -293,9 +287,7 @@ def _substract_node( return SubstractionNode(lhs, rhs) -def _multiply_node( - lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient -) -> ExpressionNodeEfficient: +def _multiply_node(lhs: ExpressionNode, rhs: ExpressionNode) -> ExpressionNode: if is_zero(lhs) or is_zero(rhs): return LiteralNode(0) if is_one(lhs): @@ -312,9 +304,7 @@ def _multiply_node( return MultiplicationNode(lhs, rhs) -def _divide_node( - lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient -) -> ExpressionNodeEfficient: +def _divide_node(lhs: ExpressionNode, rhs: ExpressionNode) -> ExpressionNode: if is_one(rhs): return lhs if is_minus_one(rhs): @@ -328,7 +318,7 @@ def _divide_node( @dataclass(frozen=True, eq=False) -class PortFieldNode(ExpressionNodeEfficient): +class PortFieldNode(ExpressionNode): """ References a port field. """ @@ -338,12 +328,12 @@ class PortFieldNode(ExpressionNodeEfficient): @dataclass(frozen=True, eq=False) -class ParameterNode(ExpressionNodeEfficient): +class ParameterNode(ExpressionNode): name: str @dataclass(frozen=True, eq=False) -class ComponentParameterNode(ExpressionNodeEfficient): +class ComponentParameterNode(ExpressionNode): """ Represents one parameter of one component. @@ -356,30 +346,30 @@ class ComponentParameterNode(ExpressionNodeEfficient): name: str -def param(name: str) -> ExpressionNodeEfficient: +def param(name: str) -> ExpressionNode: return ParameterNode(name) -def comp_param(component_id: str, name: str) -> ExpressionNodeEfficient: +def comp_param(component_id: str, name: str) -> ExpressionNode: return ComponentParameterNode(component_id, name) @dataclass(frozen=True, eq=False) -class LiteralNode(ExpressionNodeEfficient): +class LiteralNode(ExpressionNode): value: float -def literal(value: float) -> ExpressionNodeEfficient: +def literal(value: float) -> ExpressionNode: return LiteralNode(value) -def is_unbound(expr: ExpressionNodeEfficient) -> bool: +def is_unbound(expr: ExpressionNode) -> bool: return isinstance(expr, LiteralNode) and (abs(expr.value) == float("inf")) @dataclass(frozen=True, eq=False) -class UnaryOperatorNode(ExpressionNodeEfficient): - operand: ExpressionNodeEfficient +class UnaryOperatorNode(ExpressionNode): + operand: ExpressionNode class PortFieldAggregatorName(enum.Enum): @@ -404,9 +394,9 @@ class NegationNode(UnaryOperatorNode): @dataclass(frozen=True, eq=False) -class BinaryOperatorNode(ExpressionNodeEfficient): - left: ExpressionNodeEfficient - right: ExpressionNodeEfficient +class BinaryOperatorNode(ExpressionNode): + left: ExpressionNode + right: ExpressionNode class Comparator(enum.Enum): @@ -442,9 +432,9 @@ class DivisionNode(BinaryOperatorNode): @dataclass(frozen=True) class ExpressionRange: - start: ExpressionNodeEfficient - stop: ExpressionNodeEfficient - step: Optional[ExpressionNodeEfficient] = None + start: ExpressionNode + stop: ExpressionNode + step: Optional[ExpressionNode] = None def __post_init__(self) -> None: for attribute in self.__dict__: @@ -462,7 +452,7 @@ def __eq__(self, other: Any) -> bool: ) -IntOrExpr = Union[int, ExpressionNodeEfficient] +IntOrExpr = Union[int, ExpressionNode] def expression_range( @@ -486,28 +476,24 @@ class InstancesTimeIndex: 2 expression, or as a list of expressions. """ - expressions: Union[List[ExpressionNodeEfficient], ExpressionRange] + expressions: Union[List[ExpressionNode], ExpressionRange] def __init__( self, - expressions: Union[ - int, ExpressionNodeEfficient, List[ExpressionNodeEfficient], ExpressionRange - ], + expressions: Union[int, ExpressionNode, List[ExpressionNode], ExpressionRange], ) -> None: - if not isinstance( - expressions, (int, ExpressionNodeEfficient, list, ExpressionRange) - ): + if not isinstance(expressions, (int, ExpressionNode, list, ExpressionRange)): raise TypeError( - f"{expressions} must be of type among {{int, ExpressionNodeEfficient, List[ExpressionNodeEfficient], ExpressionRange}}" + f"{expressions} must be of type among {{int, ExpressionNode, List[ExpressionNode], ExpressionRange}}" ) if isinstance(expressions, list) and not all( - isinstance(x, ExpressionNodeEfficient) for x in expressions + isinstance(x, ExpressionNode) for x in expressions ): raise TypeError( - f"All elements of {expressions} must be of type ExpressionNodeEfficient" + f"All elements of {expressions} must be of type ExpressionNode" ) - if isinstance(expressions, (int, ExpressionNodeEfficient)): + if isinstance(expressions, (int, ExpressionNode)): object.__setattr__(self, "expressions", [wrap_in_node(expressions)]) else: object.__setattr__(self, "expressions", expressions) @@ -522,14 +508,11 @@ def __hash__(self) -> int: def __eq__(self, other: Any) -> bool: if isinstance(other, InstancesTimeIndex): if isinstance(self.expressions, list) and all( - isinstance(x, ExpressionNodeEfficient) for x in self.expressions + isinstance(x, ExpressionNode) for x in self.expressions ): return ( isinstance(other.expressions, list) - and all( - isinstance(x, ExpressionNodeEfficient) - for x in other.expressions - ) + and all(isinstance(x, ExpressionNode) for x in other.expressions) and all( expressions_equal(left_expr, right_expr) for left_expr, right_expr in zip( @@ -617,9 +600,7 @@ def __post_init__(self) -> None: f"Relative comparison tolerance must be >= 0, got {self.rel_tol}" ) - def visit( - self, left: ExpressionNodeEfficient, right: ExpressionNodeEfficient - ) -> bool: + def visit(self, left: ExpressionNode, right: ExpressionNode) -> bool: if left.__class__ != right.__class__: return False if isinstance(left, LiteralNode) and isinstance(right, LiteralNode): @@ -754,8 +735,8 @@ def port_field_aggregator( def expressions_equal( - left: ExpressionNodeEfficient, - right: ExpressionNodeEfficient, + left: ExpressionNode, + right: ExpressionNode, abs_tol: float = 0, rel_tol: float = 0, ) -> bool: @@ -766,7 +747,7 @@ def expressions_equal( def expressions_equal_if_present( - lhs: Optional[ExpressionNodeEfficient], rhs: Optional[ExpressionNodeEfficient] + lhs: Optional[ExpressionNode], rhs: Optional[ExpressionNode] ) -> bool: if lhs is None and rhs is None: return True diff --git a/src/andromede/expression/linear_expression.py b/src/andromede/expression/linear_expression.py index 9d81bbb3..c86774c5 100644 --- a/src/andromede/expression/linear_expression.py +++ b/src/andromede/expression/linear_expression.py @@ -35,7 +35,7 @@ from .equality import expressions_equal from .evaluate_parameters import check_resolved_expr, resolve_coefficient from .expression import ( - ExpressionNodeEfficient, + ExpressionNode, ExpressionRange, InstancesTimeIndex, LiteralNode, @@ -68,7 +68,7 @@ @dataclass(frozen=True) -class TermKeyEfficient: +class TermKey: """ Utility class to provide key for a term that contains all term information except coefficient """ @@ -82,7 +82,7 @@ class TermKeyEfficient: # Used for test_expression_parsing def __eq__(self, other: Any) -> bool: return ( - isinstance(other, TermKeyEfficient) + isinstance(other, TermKey) and self.component_id == other.component_id and self.variable_name == other.variable_name and self.time_operator == other.time_operator @@ -92,7 +92,7 @@ def __eq__(self, other: Any) -> bool: @dataclass(frozen=True) -class TermEfficient: +class Term: """ One term in a linear expression: for example the "10x" par in "10x + 5y + 5" @@ -101,7 +101,7 @@ class TermEfficient: variable_name: the name of the variable, for example "x" in "10x" """ - coefficient: ExpressionNodeEfficient + coefficient: ExpressionNode component_id: str variable_name: str structure: IndexingStructure = field( @@ -116,7 +116,7 @@ def __post_init__(self) -> None: def __eq__(self, other: object) -> bool: return ( - isinstance(other, TermEfficient) + isinstance(other, Term) and expressions_equal(self.coefficient, other.coefficient) and self.component_id == other.component_id and self.variable_name == other.variable_name @@ -214,19 +214,19 @@ def sum( self, shift: Union[ int, - "ExpressionNodeEfficient", - List["ExpressionNodeEfficient"], + "ExpressionNode", + List["ExpressionNode"], "ExpressionRange", None, ] = None, eval: Union[ int, - "ExpressionNodeEfficient", - List["ExpressionNodeEfficient"], + "ExpressionNode", + List["ExpressionNode"], "ExpressionRange", None, ] = None, - ) -> "TermEfficient": + ) -> "Term": if shift is not None and eval is not None: raise ValueError("Only shift or eval arguments should specified, not both.") @@ -251,11 +251,11 @@ def shift( self, expressions: Union[ int, - "ExpressionNodeEfficient", - List["ExpressionNodeEfficient"], + "ExpressionNode", + List["ExpressionNode"], "ExpressionRange", ], - ) -> "TermEfficient": + ) -> "Term": """ Shorthand for shift on a single time step @@ -283,11 +283,11 @@ def eval( self, expressions: Union[ int, - "ExpressionNodeEfficient", - List["ExpressionNodeEfficient"], + "ExpressionNode", + List["ExpressionNode"], "ExpressionRange", ], - ) -> "TermEfficient": + ) -> "Term": """ Shorthand for eval on a single time step @@ -311,13 +311,13 @@ def eval( else: return self.sum(eval=expressions) - def expec(self) -> "TermEfficient": + def expec(self) -> "Term": # TODO: Do we need checks, in case a scenario operator is already specified ? return dataclasses.replace(self, scenario_aggregator=Expectation()) -def generate_key(term: TermEfficient) -> TermKeyEfficient: - return TermKeyEfficient( +def generate_key(term: Term) -> TermKey: + return TermKey( term.component_id, term.variable_name, term.time_operator, @@ -344,7 +344,7 @@ class PortFieldKey: @dataclass(frozen=True) class PortFieldTerm: - coefficient: ExpressionNodeEfficient + coefficient: ExpressionNode port_name: str field_name: str aggregator: Optional[PortAggregator] = None @@ -361,7 +361,7 @@ def sum_connections(self) -> "PortFieldTerm": return dataclasses.replace(self, aggregator=PortSum()) -T_val = TypeVar("T_val", bound=Union[TermEfficient, PortFieldTerm]) +T_val = TypeVar("T_val", bound=Union[Term, PortFieldTerm]) def _get_neutral_term(term: T_val, neutral: float) -> T_val: @@ -370,12 +370,11 @@ def _get_neutral_term(term: T_val, neutral: float) -> T_val: @overload def _merge_dicts( - lhs: Dict[TermKeyEfficient, TermEfficient], - rhs: Dict[TermKeyEfficient, TermEfficient], - merge_func: Callable[[TermEfficient, TermEfficient], TermEfficient], + lhs: Dict[TermKey, Term], + rhs: Dict[TermKey, Term], + merge_func: Callable[[Term, Term], Term], neutral: float, -) -> Dict[TermKeyEfficient, TermEfficient]: - ... +) -> Dict[TermKey, Term]: ... @overload @@ -384,8 +383,7 @@ def _merge_dicts( rhs: Dict[PortFieldId, PortFieldTerm], merge_func: Callable[[PortFieldTerm, PortFieldTerm], PortFieldTerm], neutral: float, -) -> Dict[PortFieldId, PortFieldTerm]: - ... +) -> Dict[PortFieldId, PortFieldTerm]: ... def _merge_dicts(lhs, rhs, merge_func, neutral): @@ -399,7 +397,7 @@ def _merge_dicts(lhs, rhs, merge_func, neutral): def _merge_is_possible(lhs: T_val, rhs: T_val) -> None: - if isinstance(lhs, TermEfficient) and isinstance(rhs, TermEfficient): + if isinstance(lhs, Term) and isinstance(rhs, Term): _merge_term_is_possible(lhs, rhs) elif isinstance(lhs, PortFieldTerm) and isinstance(rhs, PortFieldTerm): _merge_port_terms_is_possible(lhs, rhs) @@ -407,7 +405,7 @@ def _merge_is_possible(lhs: T_val, rhs: T_val) -> None: raise TypeError("Cannot merge terms of different types") -def _merge_term_is_possible(lhs: TermEfficient, rhs: TermEfficient) -> None: +def _merge_term_is_possible(lhs: Term, rhs: Term) -> None: if lhs.component_id != rhs.component_id or lhs.variable_name != rhs.variable_name: raise ValueError("Cannot merge terms for different variables") if ( @@ -437,7 +435,7 @@ def _substract_terms(lhs: T_val, rhs: T_val) -> T_val: return dataclasses.replace(lhs, coefficient=lhs.coefficient - rhs.coefficient) -class LinearExpressionEfficient: +class LinearExpression: """ Represents a linear expression with respect to variable names, for example 10x + 5y + 2. @@ -450,21 +448,19 @@ class LinearExpressionEfficient: Examples: Operators may be used for construction: - >>> LinearExpression([], 10) + LinearExpression([TermEfficient(10, "x")], 0) - LinearExpression([TermEfficient(10, "x")], 10) + >>> LinearExpression([], 10) + LinearExpression([Term + LinearExpression([Term """ - terms: Dict[TermKeyEfficient, TermEfficient] - constant: ExpressionNodeEfficient + terms: Dict[TermKey, Term] + constant: ExpressionNode port_field_terms: Dict[PortFieldId, PortFieldTerm] # TODO: We need to check that terms.key is indeed a TermKey and change the tests that this will break def __init__( self, - terms: Optional[ - Union[Dict[TermKeyEfficient, TermEfficient], List[TermEfficient]] - ] = None, - constant: Optional[Union[float, ExpressionNodeEfficient]] = None, + terms: Optional[Union[Dict[TermKey, Term], List[Term]]] = None, + constant: Optional[Union[float, ExpressionNode]] = None, port_field_terms: Optional[ Union[Dict[PortFieldId, PortFieldTerm], List[PortFieldTerm]] ] = None, @@ -477,8 +473,8 @@ def __init__( self.terms = {} if terms is not None: # Allows to give two different syntax in the constructor: - # - List[TermEfficient] is natural - # - Dict[str, TermEfficient] is useful when constructing a linear expression from the terms of another expression + # - List[Term] is natural + # - Dict[str, Term] is useful when constructing a linear expression from the terms of another expression if isinstance(terms, dict): for term_key, term in terms.items(): if not term.is_zero(): @@ -489,7 +485,7 @@ def __init__( self.terms[generate_key(term)] = term else: raise TypeError( - f"Terms must be either of type Dict[TermKeyEfficient, Term] or List[Term], whereas {terms} is of type {type(terms)}" + f"Terms must be either of type Dict[TermKey, Term] or List[Term], whereas {terms} is of type {type(terms)}" ) self.port_field_terms = {} @@ -560,8 +556,8 @@ def __eq__(self, rhs: Any) -> "StandaloneConstraint": # type: ignore ) def __iadd__( - self, rhs: Union["LinearExpressionEfficient", int, float] - ) -> "LinearExpressionEfficient": + self, rhs: Union["LinearExpression", int, float] + ) -> "LinearExpression": rhs = wrap_in_linear_expr(rhs) self.constant += rhs.constant @@ -576,20 +572,18 @@ def __iadd__( self.remove_zeros_from_terms() return self - def __add__( - self, rhs: Union["LinearExpressionEfficient", int, float] - ) -> "LinearExpressionEfficient": - result = LinearExpressionEfficient() + def __add__(self, rhs: Union["LinearExpression", int, float]) -> "LinearExpression": + result = LinearExpression() result += self result += rhs return result - def __radd__(self, rhs: int) -> "LinearExpressionEfficient": + def __radd__(self, rhs: int) -> "LinearExpression": return self.__add__(rhs) def __isub__( - self, rhs: Union["LinearExpressionEfficient", int, float] - ) -> "LinearExpressionEfficient": + self, rhs: Union["LinearExpression", int, float] + ) -> "LinearExpression": rhs = wrap_in_linear_expr(rhs) self.constant -= rhs.constant @@ -604,25 +598,23 @@ def __isub__( self.remove_zeros_from_terms() return self - def __sub__( - self, rhs: Union["LinearExpressionEfficient", int, float] - ) -> "LinearExpressionEfficient": - result = LinearExpressionEfficient() + def __sub__(self, rhs: Union["LinearExpression", int, float]) -> "LinearExpression": + result = LinearExpression() result += self result -= rhs return result - def __rsub__(self, rhs: int) -> "LinearExpressionEfficient": + def __rsub__(self, rhs: int) -> "LinearExpression": return -self + rhs - def __neg__(self) -> "LinearExpressionEfficient": - result = LinearExpressionEfficient() + def __neg__(self) -> "LinearExpression": + result = LinearExpression() result -= self return result def __imul__( - self, rhs: Union["LinearExpressionEfficient", int, float] - ) -> "LinearExpressionEfficient": + self, rhs: Union["LinearExpression", int, float] + ) -> "LinearExpression": rhs = wrap_in_linear_expr(rhs) if not (self.is_constant() or rhs.is_constant()): @@ -635,7 +627,7 @@ def __imul__( left_expr = rhs const_expr = self if is_zero(const_expr.constant): - return LinearExpressionEfficient() + return LinearExpression() elif is_one(const_expr.constant): _copy_expression(left_expr, self) else: @@ -652,20 +644,18 @@ def __imul__( _copy_expression(left_expr, self) return self - def __mul__( - self, rhs: Union["LinearExpressionEfficient", int, float] - ) -> "LinearExpressionEfficient": - result = LinearExpressionEfficient() + def __mul__(self, rhs: Union["LinearExpression", int, float]) -> "LinearExpression": + result = LinearExpression() result += self result *= rhs return result - def __rmul__(self, rhs: int) -> "LinearExpressionEfficient": + def __rmul__(self, rhs: int) -> "LinearExpression": return self.__mul__(rhs) def __itruediv__( - self, rhs: Union["LinearExpressionEfficient", int, float] - ) -> "LinearExpressionEfficient": + self, rhs: Union["LinearExpression", int, float] + ) -> "LinearExpression": rhs = wrap_in_linear_expr(rhs) if not rhs.is_constant(): @@ -688,15 +678,15 @@ def __itruediv__( return self def __truediv__( - self, rhs: Union["LinearExpressionEfficient", int, float] - ) -> "LinearExpressionEfficient": - result = LinearExpressionEfficient() + self, rhs: Union["LinearExpression", int, float] + ) -> "LinearExpression": + result = LinearExpression() result += self result /= rhs return result - def __rtruediv__(self, rhs: Union[int, float]) -> "LinearExpressionEfficient": + def __rtruediv__(self, rhs: Union[int, float]) -> "LinearExpression": return self.__truediv__(rhs) def remove_zeros_from_terms(self) -> None: @@ -741,19 +731,19 @@ def sum( self, shift: Union[ int, - "ExpressionNodeEfficient", - List["ExpressionNodeEfficient"], + "ExpressionNode", + List["ExpressionNode"], "ExpressionRange", None, ] = None, eval: Union[ int, - "ExpressionNodeEfficient", - List["ExpressionNodeEfficient"], + "ExpressionNode", + List["ExpressionNode"], "ExpressionRange", None, ] = None, - ) -> "LinearExpressionEfficient": + ) -> "LinearExpression": """ Examples: >>> x.sum(shift=[1, 2, 4]) represents x[t+1] + x[t+2] + x[t+4] @@ -802,9 +792,7 @@ def sum( stay_roll=False, ) - return LinearExpressionEfficient( - self._apply_operator(sum_args), result_constant - ) + return LinearExpression(self._apply_operator(sum_args), result_constant) def _apply_operator( self, @@ -812,13 +800,13 @@ def _apply_operator( str, Union[ int, - "ExpressionNodeEfficient", - List["ExpressionNodeEfficient"], + "ExpressionNode", + List["ExpressionNode"], "ExpressionRange", None, ], ], - ) -> Dict[TermKeyEfficient, TermEfficient]: + ) -> Dict[TermKey, Term]: result_terms = {} for term in self.terms.values(): term_with_operator = term.sum(**sum_args) @@ -830,11 +818,11 @@ def shift( self, expressions: Union[ int, - "ExpressionNodeEfficient", - List["ExpressionNodeEfficient"], + "ExpressionNode", + List["ExpressionNode"], "ExpressionRange", ], - ) -> "LinearExpressionEfficient": + ) -> "LinearExpression": """ Shorthand for shift on a single time step @@ -856,11 +844,11 @@ def eval( self, expressions: Union[ int, - "ExpressionNodeEfficient", - List["ExpressionNodeEfficient"], + "ExpressionNode", + List["ExpressionNode"], "ExpressionRange", ], - ) -> "LinearExpressionEfficient": + ) -> "LinearExpression": """ Shorthand for eval on a single time step @@ -878,7 +866,7 @@ def eval( else: return self.sum(eval=expressions) - def expec(self) -> "LinearExpressionEfficient": + def expec(self) -> "LinearExpression": """ Expectation of linear expression. As the operator is linear, it distributes over all terms and the constant """ @@ -891,10 +879,10 @@ def expec(self) -> "LinearExpressionEfficient": result_constant = ScenarioOperatorNode( self.constant, ScenarioOperatorName.EXPECTATION ) - result_expr = LinearExpressionEfficient(result_terms, result_constant) + result_expr = LinearExpression(result_terms, result_constant) return result_expr - def sum_connections(self) -> "LinearExpressionEfficient": + def sum_connections(self) -> "LinearExpression": if not self.is_zero(): raise ValueError( "sum_connections only after an expression created with port_field" @@ -902,14 +890,14 @@ def sum_connections(self) -> "LinearExpressionEfficient": port_field_terms = {} for port_field_key, port_field_value in self.port_field_terms.items(): port_field_terms[port_field_key] = port_field_value.sum_connections() - return LinearExpressionEfficient(port_field_terms=port_field_terms) + return LinearExpression(port_field_terms=port_field_terms) def resolve_port( self, component_id: str, - ports_expressions: Dict[PortFieldKey, List["LinearExpressionEfficient"]], - ) -> "LinearExpressionEfficient": - port_expr = LinearExpressionEfficient() + ports_expressions: Dict[PortFieldKey, List["LinearExpression"]], + ) -> "LinearExpression": + port_expr = LinearExpression() for port_term in self.port_field_terms.values(): expressions = ports_expressions.get( PortFieldKey( @@ -930,10 +918,10 @@ def resolve_port( port_expr += sum_expressions( [port_term.coefficient * expression for expression in expressions] ) - self_without_ports = LinearExpressionEfficient(self.terms, self.constant) + self_without_ports = LinearExpression(self.terms, self.constant) return self_without_ports + port_expr - def add_component_context(self, component_id: str) -> "LinearExpressionEfficient": + def add_component_context(self, component_id: str) -> "LinearExpression": result_terms = {} for term in self.terms.values(): # Some terms may involve variable from other component if they arise from previous port resolution @@ -954,9 +942,7 @@ def add_component_context(self, component_id: str) -> "LinearExpressionEfficient ) result_terms[generate_key(result_term)] = result_term result_constant = add_component_context(component_id, self.constant) - return LinearExpressionEfficient( - result_terms, result_constant, self.port_field_terms - ) + return LinearExpression(result_terms, result_constant, self.port_field_terms) def _add_component_context_to_expression_range( @@ -982,25 +968,23 @@ def _add_component_context_to_instances_index( _add_component_context_to_expression_range(component_id, expressions) ) if isinstance(expressions, list): - expressions_list = cast(List[ExpressionNodeEfficient], expressions) + expressions_list = cast(List[ExpressionNode], expressions) copy = [add_component_context(component_id, e) for e in expressions_list] return InstancesTimeIndex(copy) raise ValueError("Unexpected type in instances index") -def linear_expressions_equal( - lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient -) -> bool: +def linear_expressions_equal(lhs: LinearExpression, rhs: LinearExpression) -> bool: return ( - isinstance(lhs, LinearExpressionEfficient) - and isinstance(rhs, LinearExpressionEfficient) + isinstance(lhs, LinearExpression) + and isinstance(rhs, LinearExpression) and expressions_equal(lhs.constant, rhs.constant) and lhs.terms == rhs.terms ) def linear_expressions_equal_if_present( - lhs: Optional[LinearExpressionEfficient], rhs: Optional[LinearExpressionEfficient] + lhs: Optional[LinearExpression], rhs: Optional[LinearExpression] ) -> bool: if lhs is None and rhs is None: return True @@ -1012,8 +996,8 @@ def linear_expressions_equal_if_present( # TODO: Is this function useful ? Could we just rely on the sum operator overloading ? Only the case with an empty list may make the function useful def sum_expressions( - expressions: Sequence[LinearExpressionEfficient], -) -> Union[LinearExpressionEfficient, Literal[0]]: + expressions: Sequence[LinearExpression], +) -> Union[LinearExpression, Literal[0]]: if len(expressions) == 0: return wrap_in_linear_expr(literal(0)) else: @@ -1026,9 +1010,9 @@ class StandaloneConstraint: A standalone constraint, with rigid initialization. """ - expression: LinearExpressionEfficient - lower_bound: LinearExpressionEfficient - upper_bound: LinearExpressionEfficient + expression: LinearExpression + lower_bound: LinearExpression + upper_bound: LinearExpression def __post_init__( self, @@ -1051,47 +1035,41 @@ def __eq__(self, other: Any) -> bool: ) -def wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient: - if isinstance(obj, LinearExpressionEfficient): +def wrap_in_linear_expr(obj: Any) -> LinearExpression: + if isinstance(obj, LinearExpression): return obj elif isinstance(obj, float) or isinstance(obj, int): - return LinearExpressionEfficient([], LiteralNode(float(obj))) - elif isinstance(obj, ExpressionNodeEfficient): - return LinearExpressionEfficient([], obj) + return LinearExpression([], LiteralNode(float(obj))) + elif isinstance(obj, ExpressionNode): + return LinearExpression([], obj) raise TypeError(f"Unable to wrap {obj} into a linear expression") -def wrap_in_linear_expr_if_present(obj: Any) -> Union[None, LinearExpressionEfficient]: +def wrap_in_linear_expr_if_present(obj: Any) -> Union[None, LinearExpression]: if obj is None: return None else: return wrap_in_linear_expr(obj) -def _copy_expression( - src: LinearExpressionEfficient, dst: LinearExpressionEfficient -) -> None: +def _copy_expression(src: LinearExpression, dst: LinearExpression) -> None: dst.terms = src.terms dst.constant = src.constant # TODO : Define shortcuts for "x", is_one etc .... -def var(name: str) -> LinearExpressionEfficient: +def var(name: str) -> LinearExpression: # TODO: At term build time, no information on the variable structure is known, we use a default time, scenario varying, maybe discard structure as term attribute ? - return LinearExpressionEfficient( - [ - TermEfficient( - coefficient=LiteralNode(1), component_id="", variable_name=name - ) - ], + return LinearExpression( + [Term(coefficient=LiteralNode(1), component_id="", variable_name=name)], LiteralNode(0), ) -def comp_var(component_id: str, name: str) -> LinearExpressionEfficient: - return LinearExpressionEfficient( +def comp_var(component_id: str, name: str) -> LinearExpression: + return LinearExpression( [ - TermEfficient( + Term( coefficient=LiteralNode(1), component_id=component_id, variable_name=name, @@ -1101,11 +1079,11 @@ def comp_var(component_id: str, name: str) -> LinearExpressionEfficient: ) -def port_field(port_name: str, field_name: str) -> LinearExpressionEfficient: - return LinearExpressionEfficient( +def port_field(port_name: str, field_name: str) -> LinearExpression: + return LinearExpression( port_field_terms=[PortFieldTerm(literal(1), port_name, field_name)] ) -def is_linear(expr: LinearExpressionEfficient) -> bool: +def is_linear(expr: LinearExpression) -> bool: return True diff --git a/src/andromede/expression/parsing/parse_expression.py b/src/andromede/expression/parsing/parse_expression.py index fcc37459..e535b810 100644 --- a/src/andromede/expression/parsing/parse_expression.py +++ b/src/andromede/expression/parsing/parse_expression.py @@ -24,7 +24,7 @@ param, ) from andromede.expression.linear_expression import ( - LinearExpressionEfficient, + LinearExpression, port_field, var, wrap_in_linear_expr, @@ -58,23 +58,19 @@ class ExpressionNodeBuilderVisitor(ExprVisitor): identifiers: ModelIdentifiers - def visitFullexpr( - self, ctx: ExprParser.FullexprContext - ) -> LinearExpressionEfficient: + def visitFullexpr(self, ctx: ExprParser.FullexprContext) -> LinearExpression: return ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#number. - def visitNumber(self, ctx: ExprParser.NumberContext) -> LinearExpressionEfficient: + def visitNumber(self, ctx: ExprParser.NumberContext) -> LinearExpression: return literal(float(ctx.NUMBER().getText())) # type: ignore # Visit a parse tree produced by ExprParser#identifier. - def visitIdentifier( - self, ctx: ExprParser.IdentifierContext - ) -> LinearExpressionEfficient: + def visitIdentifier(self, ctx: ExprParser.IdentifierContext) -> LinearExpression: return self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore # Visit a parse tree produced by ExprParser#division. - def visitMuldiv(self, ctx: ExprParser.MuldivContext) -> LinearExpressionEfficient: + def visitMuldiv(self, ctx: ExprParser.MuldivContext) -> LinearExpression: left = ctx.expr(0).accept(self) # type: ignore right = ctx.expr(1).accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -85,7 +81,7 @@ def visitMuldiv(self, ctx: ExprParser.MuldivContext) -> LinearExpressionEfficien raise ValueError(f"Invalid operator {op}") # Visit a parse tree produced by ExprParser#subtraction. - def visitAddsub(self, ctx: ExprParser.AddsubContext) -> LinearExpressionEfficient: + def visitAddsub(self, ctx: ExprParser.AddsubContext) -> LinearExpression: left = ctx.expr(0).accept(self) # type: ignore right = ctx.expr(1).accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -96,24 +92,20 @@ def visitAddsub(self, ctx: ExprParser.AddsubContext) -> LinearExpressionEfficien raise ValueError(f"Invalid operator {op}") # Visit a parse tree produced by ExprParser#negation. - def visitNegation( - self, ctx: ExprParser.NegationContext - ) -> LinearExpressionEfficient: + def visitNegation(self, ctx: ExprParser.NegationContext) -> LinearExpression: return -ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#expression. - def visitExpression( - self, ctx: ExprParser.ExpressionContext - ) -> LinearExpressionEfficient: + def visitExpression(self, ctx: ExprParser.ExpressionContext) -> LinearExpression: return ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#unsignedAtom. def visitUnsignedAtom( self, ctx: ExprParser.UnsignedAtomContext - ) -> LinearExpressionEfficient: + ) -> LinearExpression: return ctx.atom().accept(self) # type: ignore - def _convert_identifier(self, identifier: str) -> LinearExpressionEfficient: + def _convert_identifier(self, identifier: str) -> LinearExpression: if self.identifiers.is_variable(identifier): return var(identifier) elif self.identifiers.is_parameter(identifier): @@ -121,48 +113,38 @@ def _convert_identifier(self, identifier: str) -> LinearExpressionEfficient: raise ValueError(f"{identifier} is not a valid variable or parameter name.") # Visit a parse tree produced by ExprParser#portField. - def visitPortField( - self, ctx: ExprParser.PortFieldContext - ) -> LinearExpressionEfficient: + def visitPortField(self, ctx: ExprParser.PortFieldContext) -> LinearExpression: return port_field( port_name=ctx.IDENTIFIER(0).getText(), # type: ignore field_name=ctx.IDENTIFIER(1).getText(), # type: ignore ) # Visit a parse tree produced by ExprParser#comparison. - def visitComparison( - self, ctx: ExprParser.ComparisonContext - ) -> LinearExpressionEfficient: + def visitComparison(self, ctx: ExprParser.ComparisonContext) -> LinearExpression: op = ctx.COMPARISON().getText() # type: ignore exp1 = ctx.expr(0).accept(self) # type: ignore exp2 = ctx.expr(1).accept(self) # type: ignore comp = { - "=": LinearExpressionEfficient.__eq__, - "<=": LinearExpressionEfficient.__le__, - ">=": LinearExpressionEfficient.__ge__, + "=": LinearExpression.__eq__, + "<=": LinearExpression.__le__, + ">=": LinearExpression.__ge__, }[op] return comp(exp1, exp2) # Visit a parse tree produced by ExprParser#timeShift. - def visitTimeIndex( - self, ctx: ExprParser.TimeIndexContext - ) -> LinearExpressionEfficient: + def visitTimeIndex(self, ctx: ExprParser.TimeIndexContext) -> LinearExpression: shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore time_shifts = [e.accept(self) for e in ctx.expr()] # type: ignore return shifted_expr.eval(time_shifts) # Visit a parse tree produced by ExprParser#rangeTimeShift. - def visitTimeRange( - self, ctx: ExprParser.TimeRangeContext - ) -> LinearExpressionEfficient: + def visitTimeRange(self, ctx: ExprParser.TimeRangeContext) -> LinearExpression: shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore expressions = [e.accept(self) for e in ctx.expr()] # type: ignore # TODO: Is there a visitSum somewhere that is not needed ? Are the correct symbol parsed (sum(...) ?) ? return shifted_expr.sum(eval=ExpressionRange(expressions[0], expressions[1])) - def visitTimeShift( - self, ctx: ExprParser.TimeShiftContext - ) -> LinearExpressionEfficient: + def visitTimeShift(self, ctx: ExprParser.TimeShiftContext) -> LinearExpression: shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore time_shifts = [s.accept(self) for s in ctx.shift()] # type: ignore # specifics for x[t] ... @@ -172,34 +154,30 @@ def visitTimeShift( def visitTimeShiftRange( self, ctx: ExprParser.TimeShiftRangeContext - ) -> LinearExpressionEfficient: + ) -> LinearExpression: shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore shift1 = ctx.shift1.accept(self) # type: ignore shift2 = ctx.shift2.accept(self) # type: ignore return shifted_expr.sum(shift=ExpressionRange(shift1, shift2)) # Visit a parse tree produced by ExprParser#function. - def visitFunction( - self, ctx: ExprParser.FunctionContext - ) -> LinearExpressionEfficient: + def visitFunction(self, ctx: ExprParser.FunctionContext) -> LinearExpression: function_name: str = ctx.IDENTIFIER().getText() # type: ignore - operand: LinearExpressionEfficient = ctx.expr().accept(self) # type: ignore + operand: LinearExpression = ctx.expr().accept(self) # type: ignore fn = _FUNCTIONS.get(function_name, None) if fn is None: raise ValueError(f"Encountered invalid function name {function_name}") return fn(operand) # Visit a parse tree produced by ExprParser#shift. - def visitShift(self, ctx: ExprParser.ShiftContext) -> LinearExpressionEfficient: + def visitShift(self, ctx: ExprParser.ShiftContext) -> LinearExpression: if ctx.shift_expr() is None: # type: ignore return literal(0) shift = ctx.shift_expr().accept(self) # type: ignore return shift # Visit a parse tree produced by ExprParser#shiftAddsub. - def visitShiftAddsub( - self, ctx: ExprParser.ShiftAddsubContext - ) -> LinearExpressionEfficient: + def visitShiftAddsub(self, ctx: ExprParser.ShiftAddsubContext) -> LinearExpression: left = ctx.shift_expr().accept(self) # type: ignore right = ctx.right_expr().accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -210,9 +188,7 @@ def visitShiftAddsub( raise ValueError(f"Invalid operator {op}") # Visit a parse tree produced by ExprParser#shiftMuldiv. - def visitShiftMuldiv( - self, ctx: ExprParser.ShiftMuldivContext - ) -> LinearExpressionEfficient: + def visitShiftMuldiv(self, ctx: ExprParser.ShiftMuldivContext) -> LinearExpression: left = ctx.shift_expr().accept(self) # type: ignore right = ctx.right_expr().accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -225,16 +201,14 @@ def visitShiftMuldiv( # Visit a parse tree produced by ExprParser#signedExpression. def visitSignedExpression( self, ctx: ExprParser.SignedExpressionContext - ) -> LinearExpressionEfficient: + ) -> LinearExpression: if ctx.op.text == "-": # type: ignore return -ctx.expr().accept(self) # type: ignore else: return ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#signedAtom. - def visitSignedAtom( - self, ctx: ExprParser.SignedAtomContext - ) -> LinearExpressionEfficient: + def visitSignedAtom(self, ctx: ExprParser.SignedAtomContext) -> LinearExpression: if ctx.op.text == "-": # type: ignore return -ctx.atom().accept(self) # type: ignore else: @@ -243,13 +217,11 @@ def visitSignedAtom( # Visit a parse tree produced by ExprParser#rightExpression. def visitRightExpression( self, ctx: ExprParser.RightExpressionContext - ) -> LinearExpressionEfficient: + ) -> LinearExpression: return ctx.expr().accept(self) # type: ignore # Visit a parse tree produced by ExprParser#rightMuldiv. - def visitRightMuldiv( - self, ctx: ExprParser.RightMuldivContext - ) -> LinearExpressionEfficient: + def visitRightMuldiv(self, ctx: ExprParser.RightMuldivContext) -> LinearExpression: left = ctx.right_expr(0).accept(self) # type: ignore right = ctx.right_expr(1).accept(self) # type: ignore op = ctx.op.text # type: ignore @@ -260,16 +232,14 @@ def visitRightMuldiv( raise ValueError(f"Invalid operator {op}") # Visit a parse tree produced by ExprParser#rightAtom. - def visitRightAtom( - self, ctx: ExprParser.RightAtomContext - ) -> LinearExpressionEfficient: + def visitRightAtom(self, ctx: ExprParser.RightAtomContext) -> LinearExpression: return ctx.atom().accept(self) # type: ignore _FUNCTIONS = { - "sum": LinearExpressionEfficient.sum, - "sum_connections": LinearExpressionEfficient.sum_connections, - "expec": LinearExpressionEfficient.expec, + "sum": LinearExpression.sum, + "sum_connections": LinearExpression.sum_connections, + "expec": LinearExpression.expec, } @@ -279,7 +249,7 @@ class AntaresParseException(Exception): def parse_expression( expression: str, identifiers: ModelIdentifiers -) -> LinearExpressionEfficient: +) -> LinearExpression: """ Parses a string expression to create the corresponding AST representation. """ diff --git a/src/andromede/expression/print.py b/src/andromede/expression/print.py index 7fb5a5de..6b1e1c84 100644 --- a/src/andromede/expression/print.py +++ b/src/andromede/expression/print.py @@ -19,7 +19,7 @@ ComparisonNode, ComponentParameterNode, DivisionNode, - ExpressionNodeEfficient, + ExpressionNode, LiteralNode, MultiplicationNode, NegationNode, @@ -103,5 +103,5 @@ def port_field_aggregator(self, node: PortFieldAggregatorNode) -> str: return f"({visit(node.operand, self)}.{node.aggregator})" -def print_expr(expression: ExpressionNodeEfficient) -> str: +def print_expr(expression: ExpressionNode) -> str: return visit(expression, PrinterVisitor()) diff --git a/src/andromede/expression/visitor.py b/src/andromede/expression/visitor.py index 67f56470..55e93edb 100644 --- a/src/andromede/expression/visitor.py +++ b/src/andromede/expression/visitor.py @@ -21,7 +21,7 @@ ComparisonNode, ComponentParameterNode, DivisionNode, - ExpressionNodeEfficient, + ExpressionNode, LiteralNode, MultiplicationNode, NegationNode, @@ -47,63 +47,49 @@ class ExpressionVisitor(ABC, Generic[T]): """ @abstractmethod - def literal(self, node: LiteralNode) -> T: - ... + def literal(self, node: LiteralNode) -> T: ... @abstractmethod - def negation(self, node: NegationNode) -> T: - ... + def negation(self, node: NegationNode) -> T: ... @abstractmethod - def addition(self, node: AdditionNode) -> T: - ... + def addition(self, node: AdditionNode) -> T: ... @abstractmethod - def substraction(self, node: SubstractionNode) -> T: - ... + def substraction(self, node: SubstractionNode) -> T: ... @abstractmethod - def multiplication(self, node: MultiplicationNode) -> T: - ... + def multiplication(self, node: MultiplicationNode) -> T: ... @abstractmethod - def division(self, node: DivisionNode) -> T: - ... + def division(self, node: DivisionNode) -> T: ... @abstractmethod - def comparison(self, node: ComparisonNode) -> T: - ... + def comparison(self, node: ComparisonNode) -> T: ... @abstractmethod - def parameter(self, node: ParameterNode) -> T: - ... + def parameter(self, node: ParameterNode) -> T: ... @abstractmethod - def comp_parameter(self, node: ComponentParameterNode) -> T: - ... + def comp_parameter(self, node: ComponentParameterNode) -> T: ... @abstractmethod - def time_operator(self, node: TimeOperatorNode) -> T: - ... + def time_operator(self, node: TimeOperatorNode) -> T: ... @abstractmethod - def time_aggregator(self, node: TimeAggregatorNode) -> T: - ... + def time_aggregator(self, node: TimeAggregatorNode) -> T: ... @abstractmethod - def scenario_operator(self, node: ScenarioOperatorNode) -> T: - ... + def scenario_operator(self, node: ScenarioOperatorNode) -> T: ... @abstractmethod - def port_field(self, node: PortFieldNode) -> T: - ... + def port_field(self, node: PortFieldNode) -> T: ... @abstractmethod - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> T: - ... + def port_field_aggregator(self, node: PortFieldAggregatorNode) -> T: ... -def visit(root: ExpressionNodeEfficient, visitor: ExpressionVisitor[T]) -> T: +def visit(root: ExpressionNode, visitor: ExpressionVisitor[T]) -> T: """ Utility method to dispatch calls to the right method of a visitor. """ diff --git a/src/andromede/libs/standard.py b/src/andromede/libs/standard.py index 5e6c8c20..073d74d2 100644 --- a/src/andromede/libs/standard.py +++ b/src/andromede/libs/standard.py @@ -16,11 +16,7 @@ from andromede.expression.expression import ExpressionRange, literal, param from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression import ( - port_field, - var, - wrap_in_linear_expr, -) +from andromede.expression.linear_expression import port_field, var, wrap_in_linear_expr from andromede.model.constraint import Constraint from andromede.model.model import ModelPort, PortFieldDefinition, PortFieldId, model from andromede.model.parameter import float_parameter, int_parameter @@ -266,7 +262,7 @@ ) <= var("nb_on"), ), - # TODO : Improve API so that we are not forced to use sum() on one shifted element for ExpressionNodeEfficient + # TODO : Improve API so that we are not forced to use sum() on one shifted element for ExpressionNode Constraint( "Min down time", var("nb_stop").sum( diff --git a/src/andromede/model/common.py b/src/andromede/model/common.py index 180db628..ffd0ffe5 100644 --- a/src/andromede/model/common.py +++ b/src/andromede/model/common.py @@ -16,12 +16,10 @@ from enum import Enum from typing import Union -from andromede.expression.expression import ExpressionNodeEfficient -from andromede.expression.linear_expression import LinearExpressionEfficient +from andromede.expression.expression import ExpressionNode +from andromede.expression.linear_expression import LinearExpression -ValueOrExprNodeOrLinearExpr = Union[ - int, float, ExpressionNodeEfficient, LinearExpressionEfficient -] +ValueOrExprNodeOrLinearExpr = Union[int, float, ExpressionNode, LinearExpression] class ValueType(Enum): diff --git a/src/andromede/model/constraint.py b/src/andromede/model/constraint.py index 02a03897..a4cd09a8 100644 --- a/src/andromede/model/constraint.py +++ b/src/andromede/model/constraint.py @@ -12,9 +12,9 @@ from dataclasses import InitVar, dataclass, field from typing import Any, Union -from andromede.expression.expression import ExpressionNodeEfficient, literal +from andromede.expression.expression import ExpressionNode, literal from andromede.expression.linear_expression import ( - LinearExpressionEfficient, + LinearExpression, StandaloneConstraint, linear_expressions_equal, wrap_in_linear_expr, @@ -33,22 +33,20 @@ class Constraint: name: str # Used only for mypy type checking, we could have done the same by using only the attribute expression expression_init: InitVar[ - Union[ExpressionNodeEfficient, LinearExpressionEfficient, StandaloneConstraint] + Union[ExpressionNode, LinearExpression, StandaloneConstraint] ] - expression: LinearExpressionEfficient = field(init=False) - lower_bound: LinearExpressionEfficient = field( + expression: LinearExpression = field(init=False) + lower_bound: LinearExpression = field( default=wrap_in_linear_expr(literal(-float("inf"))) ) - upper_bound: LinearExpressionEfficient = field( + upper_bound: LinearExpression = field( default=wrap_in_linear_expr(literal(float("inf"))) ) context: ProblemContext = field(default=ProblemContext.OPERATIONAL) def __post_init__( self, - expression_init: Union[ - ExpressionNodeEfficient, LinearExpressionEfficient, StandaloneConstraint - ], + expression_init: Union[ExpressionNode, LinearExpression, StandaloneConstraint], ) -> None: self.lower_bound = wrap_in_linear_expr(self.lower_bound) self.upper_bound = wrap_in_linear_expr(self.upper_bound) diff --git a/src/andromede/model/model.py b/src/andromede/model/model.py index d19eeb84..1b0c19d0 100644 --- a/src/andromede/model/model.py +++ b/src/andromede/model/model.py @@ -25,7 +25,7 @@ ComparisonNode, ComponentParameterNode, DivisionNode, - ExpressionNodeEfficient, + ExpressionNode, LiteralNode, MultiplicationNode, NegationNode, @@ -40,7 +40,7 @@ from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.linear_expression import ( - LinearExpressionEfficient, + LinearExpression, is_linear, wrap_in_linear_expr, wrap_in_linear_expr_if_present, @@ -80,7 +80,7 @@ def get_component_variable_structure( def _is_objective_contribution_valid( - model: "Model", objective_contribution: LinearExpressionEfficient + model: "Model", objective_contribution: LinearExpression ) -> bool: if not is_linear(objective_contribution): raise ValueError("Objective contribution must be a linear expression.") @@ -123,11 +123,11 @@ class PortFieldDefinition: port_field: PortFieldId # Used only for type checking... - definition_init: InitVar[Union[ExpressionNodeEfficient, LinearExpressionEfficient]] - definition: LinearExpressionEfficient = field(init=False) + definition_init: InitVar[Union[ExpressionNode, LinearExpression]] + definition: LinearExpression = field(init=False) def __post_init__( - self, definition_init: Union[ExpressionNodeEfficient, LinearExpressionEfficient] + self, definition_init: Union[ExpressionNode, LinearExpression] ) -> None: object.__setattr__(self, "definition", wrap_in_linear_expr(definition_init)) _validate_port_field_expression(self) @@ -136,7 +136,7 @@ def __post_init__( def port_field_def( port_name: str, field_name: str, - definition: Union[ExpressionNodeEfficient, LinearExpressionEfficient], + definition: Union[ExpressionNode, LinearExpression], ) -> PortFieldDefinition: return PortFieldDefinition(PortFieldId(port_name, field_name), definition) @@ -154,8 +154,8 @@ class Model: inter_block_dyn: bool = False parameters: Dict[str, Parameter] = field(default_factory=dict) variables: Dict[str, Variable] = field(default_factory=dict) - objective_operational_contribution: Optional[LinearExpressionEfficient] = None - objective_investment_contribution: Optional[LinearExpressionEfficient] = None + objective_operational_contribution: Optional[LinearExpression] = None + objective_investment_contribution: Optional[LinearExpression] = None ports: Dict[str, ModelPort] = field(default_factory=dict) # key = port name port_fields_definitions: Dict[PortFieldId, PortFieldDefinition] = field( default_factory=dict @@ -307,7 +307,7 @@ def _validate_port_field_expression(definition: PortFieldDefinition) -> None: def _check_port_field_expression_type(definition: PortFieldDefinition) -> None: - if not isinstance(definition.definition, LinearExpressionEfficient): + if not isinstance(definition.definition, LinearExpression): raise TypeError( f"Port field definition should be a LinearExpression, not a {type(definition.definition)}" ) diff --git a/src/andromede/model/resolve_library.py b/src/andromede/model/resolve_library.py index fc08feaa..22820e2e 100644 --- a/src/andromede/model/resolve_library.py +++ b/src/andromede/model/resolve_library.py @@ -12,10 +12,10 @@ from typing import Dict, List, Optional, TypedDict, Union # from andromede.expression import ExpressionNode -from andromede.expression.expression import ExpressionNodeEfficient +from andromede.expression.expression import ExpressionNode from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.linear_expression import ( - LinearExpressionEfficient, + LinearExpression, StandaloneConstraint, wrap_in_linear_expr_if_present, ) @@ -129,7 +129,7 @@ def _to_parameter(param: InputParameter) -> Parameter: def _to_expression_if_present( expr: Optional[str], identifiers: ModelIdentifiers -) -> Optional[LinearExpressionEfficient]: +) -> Optional[LinearExpression]: if not expr: return None return parse_expression(expr, identifiers) @@ -155,11 +155,9 @@ def _to_variable(var: InputVariable, identifiers: ModelIdentifiers) -> Variable: # Used only for mypy class ConstraintKwargs(TypedDict, total=False): name: str - expression_init: Union[ - ExpressionNodeEfficient, LinearExpressionEfficient, StandaloneConstraint - ] - lower_bound: LinearExpressionEfficient - upper_bound: LinearExpressionEfficient + expression_init: Union[ExpressionNode, LinearExpression, StandaloneConstraint] + lower_bound: LinearExpression + upper_bound: LinearExpression def _to_constraint( diff --git a/src/andromede/model/variable.py b/src/andromede/model/variable.py index 6900cf88..a1d94447 100644 --- a/src/andromede/model/variable.py +++ b/src/andromede/model/variable.py @@ -16,7 +16,7 @@ from andromede.expression.expression import literal from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.linear_expression import ( - LinearExpressionEfficient, + LinearExpression, linear_expressions_equal_if_present, wrap_in_linear_expr, wrap_in_linear_expr_if_present, @@ -36,8 +36,8 @@ class Variable: name: str data_type: ValueType - lower_bound: Optional[LinearExpressionEfficient] - upper_bound: Optional[LinearExpressionEfficient] + lower_bound: Optional[LinearExpression] + upper_bound: Optional[LinearExpression] structure: IndexingStructure context: ProblemContext diff --git a/src/andromede/simulation/linear_expression_resolver.py b/src/andromede/simulation/linear_expression_resolver.py index 5e15d7e4..95c8c895 100644 --- a/src/andromede/simulation/linear_expression_resolver.py +++ b/src/andromede/simulation/linear_expression_resolver.py @@ -20,14 +20,7 @@ resolve_coefficient, ) from andromede.expression.indexing_structure import RowIndex -from andromede.expression.linear_expression import ( - LinearExpressionEfficient, - TermEfficient, -) -from .resolved_linear_expression import ( - ResolvedLinearExpression, - ResolvedTerm, -) +from andromede.expression.linear_expression import LinearExpression, Term from andromede.expression.scenario_operator import Expectation from andromede.expression.time_operator import TimeShift from andromede.expression.value_provider import ( @@ -35,7 +28,9 @@ TimeScenarioIndices, ValueProvider, ) + from .optimization_context import OptimizationContext +from .resolved_linear_expression import ResolvedLinearExpression, ResolvedTerm @dataclass @@ -44,7 +39,7 @@ class LinearExpressionResolver: value_provider: ValueProvider def resolve( - self, expression: LinearExpressionEfficient, row_id: RowIndex + self, expression: LinearExpression, row_id: RowIndex ) -> ResolvedLinearExpression: resolved_terms = [] for term in expression.terms.values(): @@ -73,53 +68,49 @@ def resolve( return ResolvedLinearExpression(resolved_terms, resolved_constant) def resolve_constant_expr( - self, expression: LinearExpressionEfficient, row_id: RowIndex + self, expression: LinearExpression, row_id: RowIndex ) -> float: if not expression.is_constant(): raise ValueError(f"{str(self)} is not a constant expression") return resolve_coefficient(expression.constant, self.value_provider, row_id) def resolve_variables( - self, term: TermEfficient, row_id: RowIndex + self, term: Term, row_id: RowIndex ) -> Dict[TimeScenarioIndex, lp.Variable]: solver_vars = {} operator_ts_ids = self._row_id_to_term_time_scenario_id(term, row_id) for time in operator_ts_ids.time_indices: for scenario in operator_ts_ids.scenario_indices: - solver_vars[ - TimeScenarioIndex(time, scenario) - ] = self.context.get_component_variable( - time, - scenario, - term.component_id, - term.variable_name, - # At term build time, no information on the variable structure is known, we use it now - self.context.network.get_component(term.component_id) - .model.variables[term.variable_name] - .structure, + solver_vars[TimeScenarioIndex(time, scenario)] = ( + self.context.get_component_variable( + time, + scenario, + term.component_id, + term.variable_name, + # At term build time, no information on the variable structure is known, we use it now + self.context.network.get_component(term.component_id) + .model.variables[term.variable_name] + .structure, + ) ) return solver_vars def _row_id_to_term_time_scenario_id( - self, term: TermEfficient, row_id: RowIndex + self, term: Term, row_id: RowIndex ) -> TimeScenarioIndices: operator_time_ids = self._compute_operator_time_ids(term, row_id) operator_scenario_ids = self._compute_operator_scenario_ids(term, row_id) return TimeScenarioIndices(operator_time_ids, operator_scenario_ids) - def _compute_operator_scenario_ids( - self, term: TermEfficient, row_id: RowIndex - ) -> List[int]: + def _compute_operator_scenario_ids(self, term: Term, row_id: RowIndex) -> List[int]: if term.scenario_aggregator: operator_scenario_ids = list(range(self.context.scenarios)) else: operator_scenario_ids = [row_id.scenario] return operator_scenario_ids - def _compute_operator_time_ids( - self, term: TermEfficient, row_id: RowIndex - ) -> List[int]: + def _compute_operator_time_ids(self, term: Term, row_id: RowIndex) -> List[int]: if not term.time_operator and not term.time_aggregator: operator_time_ids = [row_id.time] elif term.time_operator: diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 28e347c3..755116ff 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -22,10 +22,7 @@ from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure -from andromede.expression.linear_expression import ( - LinearExpressionEfficient, - RowIndex, -) +from andromede.expression.linear_expression import LinearExpression, RowIndex from andromede.model.common import ValueType from andromede.model.constraint import Constraint from andromede.model.model import PortFieldId @@ -66,10 +63,10 @@ def _compute_indexing_structure( def _instantiate_model_expression( - model_expression: LinearExpressionEfficient, + model_expression: LinearExpression, component_id: str, optimization_context: OptimizationContext, -) -> LinearExpressionEfficient: +) -> LinearExpression: """ Performs common operations that are necessary on model expressions before their actual use: 1. add component ID for variables and parameters of THIS component @@ -122,7 +119,7 @@ def _create_objective( solver: lp.Solver, opt_context: OptimizationContext, component: Component, - objective_contribution: LinearExpressionEfficient, + objective_contribution: LinearExpression, ) -> None: instantiated_expr = _instantiate_model_expression( objective_contribution, component.id, opt_context diff --git a/src/andromede/simulation/optimization_context.py b/src/andromede/simulation/optimization_context.py index 530cd013..71735c41 100644 --- a/src/andromede/simulation/optimization_context.py +++ b/src/andromede/simulation/optimization_context.py @@ -21,16 +21,17 @@ from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.linear_expression import ( - LinearExpressionEfficient, + LinearExpression, PortFieldId, PortFieldKey, ) from andromede.expression.value_provider import TimeScenarioIndex, TimeScenarioIndices -from .time_block import TimeBlock from andromede.study.data import DataBase from andromede.study.network import Component, Network from andromede.utils import get_or_add +from .time_block import TimeBlock + @dataclass(eq=True, frozen=True) class TimestepComponentVariableKey: @@ -92,7 +93,7 @@ def __init__( self._component_variables: Dict[TimestepComponentVariableKey, lp.Variable] = {} self._solver_variables: Dict[lp.Variable, SolverVariableInfo] = {} self._connection_fields_expressions: Dict[ - PortFieldKey, List[LinearExpressionEfficient] + PortFieldKey, List[LinearExpression] ] = {} @property @@ -109,7 +110,7 @@ def block_length(self) -> int: @property def connection_fields_expressions( self, - ) -> Dict[PortFieldKey, List[LinearExpressionEfficient]]: + ) -> Dict[PortFieldKey, List[LinearExpression]]: return self._connection_fields_expressions # TODO: Need to think about data processing when creating blocks with varying or inequal time steps length (aggregation, sum ?, mean of data ?) @@ -185,7 +186,7 @@ def register_connection_fields_expressions( component_id: str, port_name: str, field_name: str, - expression: LinearExpressionEfficient, + expression: LinearExpression, ) -> None: key = PortFieldKey(component_id, PortFieldId(port_name, field_name)) get_or_add(self._connection_fields_expressions, key, lambda: []).append( @@ -241,14 +242,14 @@ def get_component_parameter_value( ) for block_timestep in time_scenarios_indices.time_indices: for scenario in time_scenarios_indices.scenario_indices: - result[ - TimeScenarioIndex(block_timestep, scenario) - ] = _get_parameter_value( - context, - _get_data_time_key(block_timestep, param_index), - _get_data_scenario_key(scenario, param_index), - component_id, - name, + result[TimeScenarioIndex(block_timestep, scenario)] = ( + _get_parameter_value( + context, + _get_data_time_key(block_timestep, param_index), + _get_data_scenario_key(scenario, param_index), + component_id, + name, + ) ) return result diff --git a/src/andromede/simulation/strategy.py b/src/andromede/simulation/strategy.py index 6a5326c7..89e753fa 100644 --- a/src/andromede/simulation/strategy.py +++ b/src/andromede/simulation/strategy.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from typing import Generator, Optional -from andromede.expression.linear_expression import LinearExpressionEfficient +from andromede.expression.linear_expression import LinearExpression from andromede.model import Constraint, Model, ProblemContext, Variable @@ -37,14 +37,12 @@ def get_constraints(self, model: Model) -> Generator[Constraint, None, None]: yield constraint @abstractmethod - def _keep_from_context(self, context: ProblemContext) -> bool: - ... + def _keep_from_context(self, context: ProblemContext) -> bool: ... @abstractmethod def get_objectives( self, model: Model - ) -> Generator[Optional[LinearExpressionEfficient], None, None]: - ... + ) -> Generator[Optional[LinearExpression], None, None]: ... class MergedProblemStrategy(ModelSelectionStrategy): @@ -53,7 +51,7 @@ def _keep_from_context(self, context: ProblemContext) -> bool: def get_objectives( self, model: Model - ) -> Generator[Optional[LinearExpressionEfficient], None, None]: + ) -> Generator[Optional[LinearExpression], None, None]: yield model.objective_operational_contribution yield model.objective_investment_contribution @@ -66,7 +64,7 @@ def _keep_from_context(self, context: ProblemContext) -> bool: def get_objectives( self, model: Model - ) -> Generator[Optional[LinearExpressionEfficient], None, None]: + ) -> Generator[Optional[LinearExpression], None, None]: yield model.objective_investment_contribution @@ -78,5 +76,5 @@ def _keep_from_context(self, context: ProblemContext) -> bool: def get_objectives( self, model: Model - ) -> Generator[Optional[LinearExpressionEfficient], None, None]: + ) -> Generator[Optional[LinearExpression], None, None]: yield model.objective_operational_contribution diff --git a/tests/unittests/expressions/parsing/test_expression_parsing.py b/tests/unittests/expressions/parsing/test_expression_parsing.py index 1aede958..b43c3144 100644 --- a/tests/unittests/expressions/parsing/test_expression_parsing.py +++ b/tests/unittests/expressions/parsing/test_expression_parsing.py @@ -15,13 +15,13 @@ from andromede.expression.equality import expressions_equal from andromede.expression.expression import ( - ExpressionNodeEfficient, + ExpressionNode, ExpressionRange, literal, param, ) from andromede.expression.linear_expression import ( - LinearExpressionEfficient, + LinearExpression, StandaloneConstraint, linear_expressions_equal, port_field, @@ -163,18 +163,16 @@ def test_parsing_visitor( variables: Set[str], parameters: Set[str], expression_str: str, - expected: Union[ - ExpressionNodeEfficient, LinearExpressionEfficient, StandaloneConstraint - ], + expected: Union[ExpressionNode, LinearExpression, StandaloneConstraint], ) -> None: identifiers = ModelIdentifiers(variables, parameters) expr = parse_expression(expression_str, identifiers) print() print(f"Expected: \n {str(expected)}") print(f"Parsed: \n {str(expr)}") - if isinstance(expected, ExpressionNodeEfficient): + if isinstance(expected, ExpressionNode): assert expressions_equal(expr, expected) - elif isinstance(expected, LinearExpressionEfficient): + elif isinstance(expected, LinearExpression): assert linear_expressions_equal(expr, expected) elif isinstance(expected, StandaloneConstraint): assert expected == expr diff --git a/tests/unittests/expressions/test_equality.py b/tests/unittests/expressions/test_equality.py index d654bdf1..0d3a74c8 100644 --- a/tests/unittests/expressions/test_equality.py +++ b/tests/unittests/expressions/test_equality.py @@ -16,7 +16,7 @@ from andromede.expression.copy import copy_expression from andromede.expression.equality import expressions_equal from andromede.expression.expression import ( - ExpressionNodeEfficient, + ExpressionNode, InstancesTimeIndex, TimeAggregatorName, TimeAggregatorNode, @@ -28,7 +28,7 @@ ) -def shifted_param() -> ExpressionNodeEfficient: +def shifted_param() -> ExpressionNode: return TimeOperatorNode( param("q"), TimeOperatorName.SHIFT, InstancesTimeIndex(expression_range(0, 2)) ) @@ -71,7 +71,7 @@ def shifted_param() -> ExpressionNodeEfficient: param("q").expec(), ], ) -def test_equals(expr: ExpressionNodeEfficient) -> None: +def test_equals(expr: ExpressionNode) -> None: copy = copy_expression(expr) assert expressions_equal(expr, copy) @@ -133,7 +133,7 @@ def test_equals(expr: ExpressionNodeEfficient) -> None: (param("q").expec(), param("y").expec()), ], ) -def test_not_equals(lhs: ExpressionNodeEfficient, rhs: ExpressionNodeEfficient) -> None: +def test_not_equals(lhs: ExpressionNode, rhs: ExpressionNode) -> None: assert not expressions_equal(lhs, rhs) diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions_efficient.py index 4a690d4d..803917ec 100644 --- a/tests/unittests/expressions/test_expressions_efficient.py +++ b/tests/unittests/expressions/test_expressions_efficient.py @@ -19,7 +19,7 @@ from andromede.expression.equality import expressions_equal from andromede.expression.expression import ( ComponentParameterNode, - ExpressionNodeEfficient, + ExpressionNode, ExpressionRange, InstancesTimeIndex, LiteralNode, @@ -35,10 +35,10 @@ from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure, RowIndex from andromede.expression.linear_expression import ( - LinearExpressionEfficient, + LinearExpression, StandaloneConstraint, - TermEfficient, - TermKeyEfficient, + Term, + TermKey, comp_var, linear_expressions_equal, sum_expressions, @@ -104,12 +104,8 @@ def scenarios() -> int: # TODO: Redundant with add tests in test_linear_expressions_efficient ? def test_comp_parameter() -> None: - expr1 = LinearExpressionEfficient([], 1) + LinearExpressionEfficient( - [TermEfficient(1, "comp1", "x")] - ) - expr2 = expr1 / LinearExpressionEfficient( - constant=ComponentParameterNode("comp1", "p") - ) + expr1 = LinearExpression([], 1) + LinearExpression([Term(1, "comp1", "x")]) + expr2 = expr1 / LinearExpression(constant=ComponentParameterNode("comp1", "p")) assert str(expr2) == "(1.0 / comp1.p)x + (1.0 / comp1.p)" context = ComponentEvaluationContext( @@ -121,10 +117,8 @@ def test_comp_parameter() -> None: # TODO: Find a better name def test_ast() -> None: - expr1 = LinearExpressionEfficient([], 1) + LinearExpressionEfficient( - [TermEfficient(1, "", "x")] - ) - expr2 = expr1 / LinearExpressionEfficient(constant=ParameterNode("p")) + expr1 = LinearExpression([], 1) + LinearExpression([Term(1, "", "x")]) + expr2 = expr1 / LinearExpression(constant=ParameterNode("p")) assert str(expr2) == "(1.0 / p)x + (1.0 / p)" @@ -135,7 +129,7 @@ def test_ast() -> None: def test_operators() -> None: x = var("x") p = param("p") - expr: LinearExpressionEfficient = (5 * x + 3) / p - 2 + expr: LinearExpression = (5 * x + 3) / p - 2 assert str(expr) == "(5.0 / p)x + ((3.0 / p) - 2.0)" @@ -182,19 +176,19 @@ def test_degree_computation_should_take_into_account_simplifications() -> None: # assert expr.resolve_parameters(TestParamProvider()) == (5 * x + 3) / 2 -# TODO: Write tests on ExpressionEfficientNodes for tree simplification, do the same for multiplication, substraction, etc +# TODO: Write tests on ExpressionNodes for tree simplification, do the same for multiplication, substraction, etc @pytest.mark.parametrize( "e1, e2, expected", [ ( var("x"), -var("x"), - LinearExpressionEfficient(), + LinearExpression(), ), ( param("p"), -param("p"), - LinearExpressionEfficient(), + LinearExpression(), ), ( var("x"), @@ -250,9 +244,9 @@ def test_degree_computation_should_take_into_account_simplifications() -> None: ], ) def test_addition( - e1: LinearExpressionEfficient, - e2: LinearExpressionEfficient, - expected: LinearExpressionEfficient, + e1: LinearExpression, + e2: LinearExpression, + expected: LinearExpression, ) -> None: assert linear_expressions_equal( wrap_in_linear_expr(e1) + wrap_in_linear_expr(e2), wrap_in_linear_expr(expected) @@ -305,12 +299,12 @@ def test_addition( ( var("x"), var("x"), - LinearExpressionEfficient(), + LinearExpression(), ), ( param("p"), param("p"), - LinearExpressionEfficient(), + LinearExpression(), ), ( literal(4) * param("p"), @@ -326,9 +320,9 @@ def test_addition( ], ) def test_substraction( - e1: LinearExpressionEfficient, - e2: LinearExpressionEfficient, - expected: LinearExpressionEfficient, + e1: LinearExpression, + e2: LinearExpression, + expected: LinearExpression, ) -> None: assert linear_expressions_equal( wrap_in_linear_expr(e1) - wrap_in_linear_expr(e2), wrap_in_linear_expr(expected) @@ -340,24 +334,24 @@ def test_substraction( [ ( (5 * comp_var("c", "x") + 3) / 2, - LinearExpressionEfficient([TermEfficient(2.5, "c", "x")], 1.5), + LinearExpression([Term(2.5, "c", "x")], 1.5), ), ( param("p") * comp_var("c", "x"), - LinearExpressionEfficient( - [TermEfficient(ParameterNode("p"), "c", "x")], + LinearExpression( + [Term(ParameterNode("p"), "c", "x")], ), ), ( param("p") * comp_var("c", "x"), - LinearExpressionEfficient( - [TermEfficient(ParameterNode("p"), "c", "x")], + LinearExpression( + [Term(ParameterNode("p"), "c", "x")], ), ), ], ) def test_linear_expression_equality( - lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient + lhs: LinearExpression, rhs: LinearExpression ) -> None: assert linear_expressions_equal(lhs, rhs) @@ -404,7 +398,7 @@ def test_comparison() -> None: ( (var("x") + var("y") + literal(1)).shift(1), { - TermKeyEfficient( + TermKey( "", "x", TimeShift(InstancesTimeIndex(1)), @@ -412,7 +406,7 @@ def test_comparison() -> None: stay_roll=True ), # The internal representation of shift(1) is sum(shift=1) scenario_aggregator=None, - ): TermEfficient( + ): Term( LiteralNode(1), "", "x", @@ -421,7 +415,7 @@ def test_comparison() -> None: ), time_aggregator=TimeSum(stay_roll=True), ), - TermKeyEfficient( + TermKey( "", "y", TimeShift( @@ -429,7 +423,7 @@ def test_comparison() -> None: ), time_aggregator=TimeSum(stay_roll=True), scenario_aggregator=None, - ): TermEfficient( + ): Term( LiteralNode(1), "", "y", @@ -448,7 +442,7 @@ def test_comparison() -> None: ( (var("x") + var("y") + literal(1)).eval(1), { - TermKeyEfficient( + TermKey( "", "x", TimeEvaluation(InstancesTimeIndex(1)), @@ -456,7 +450,7 @@ def test_comparison() -> None: stay_roll=True ), # The internal representation of eval(1) is sum(eval=1) scenario_aggregator=None, - ): TermEfficient( + ): Term( LiteralNode(1), "", "x", @@ -465,7 +459,7 @@ def test_comparison() -> None: ), time_aggregator=TimeSum(stay_roll=True), ), - TermKeyEfficient( + TermKey( "", "y", TimeEvaluation( @@ -473,7 +467,7 @@ def test_comparison() -> None: ), time_aggregator=TimeSum(stay_roll=True), scenario_aggregator=None, - ): TermEfficient( + ): Term( LiteralNode(1), "", "y", @@ -492,26 +486,26 @@ def test_comparison() -> None: ( (var("x") + var("y") + literal(1)).sum(), { - TermKeyEfficient( + TermKey( "", "x", time_operator=None, time_aggregator=TimeSum(stay_roll=False), scenario_aggregator=None, - ): TermEfficient( + ): Term( LiteralNode(1), # Sum is not distributed to coeff "", "x", time_operator=None, time_aggregator=TimeSum(stay_roll=False), ), - TermKeyEfficient( + TermKey( "", "y", time_operator=None, time_aggregator=TimeSum(stay_roll=False), scenario_aggregator=None, - ): TermEfficient( + ): Term( LiteralNode(1), # Sum is not distributed to coeff "", "y", @@ -526,9 +520,9 @@ def test_comparison() -> None: ], ) def test_operators_are_correctly_distributed_over_terms( - expr: LinearExpressionEfficient, - expec_terms: Dict[TermKeyEfficient, TermEfficient], - expec_constant: ExpressionNodeEfficient, + expr: LinearExpression, + expec_terms: Dict[TermKey, Term], + expec_constant: ExpressionNode, ) -> None: assert expr.terms == expec_terms assert expressions_equal(expr.constant, expec_constant) @@ -617,7 +611,7 @@ def test_eval_on_time_step_list_raises_value_error() -> None: ], ) def test_compute_indexation( - linear_expr: LinearExpressionEfficient, expected_indexation: IndexingStructure + linear_expr: LinearExpression, expected_indexation: IndexingStructure ) -> None: provider = StructureProvider() assert linear_expr.compute_indexation(provider) == expected_indexation @@ -681,7 +675,7 @@ def get_variable_structure(self, name: str) -> IndexingStructure: ], ) def test_sum_expressions( - sum_expr: LinearExpressionEfficient, expected: LinearExpressionEfficient + sum_expr: LinearExpression, expected: LinearExpression ) -> None: assert linear_expressions_equal(sum_expr, wrap_in_linear_expr(expected)) @@ -696,5 +690,5 @@ def test_sum_expressions( (var("x") + literal(4), False), ], ) -def test_is_unbound(expr: LinearExpressionEfficient, unbound: bool) -> None: +def test_is_unbound(expr: LinearExpression, unbound: bool) -> None: assert wrap_in_linear_expr(expr).is_unbound() == unbound diff --git a/tests/unittests/expressions/test_linear_expressions_efficient.py b/tests/unittests/expressions/test_linear_expressions_efficient.py index eba57c14..d8c4d79f 100644 --- a/tests/unittests/expressions/test_linear_expressions_efficient.py +++ b/tests/unittests/expressions/test_linear_expressions_efficient.py @@ -14,16 +14,12 @@ import pytest -from andromede.expression.expression import ( - TimeAggregatorNode, - expression_range, - param, -) +from andromede.expression.expression import TimeAggregatorNode, expression_range, param from andromede.expression.linear_expression import ( - LinearExpressionEfficient, + LinearExpression, PortFieldId, PortFieldTerm, - TermEfficient, + Term, _copy_expression, linear_expressions_equal, var, @@ -46,7 +42,7 @@ def test_affine_expression_printing_should_reflect_required_formatting( coeff: float, var_name: str, constant: float, expec_str: str ) -> None: - expr = LinearExpressionEfficient([TermEfficient(coeff, "c", var_name)], constant) + expr = LinearExpression([Term(coeff, "c", var_name)], constant) assert str(expr) == expec_str @@ -64,8 +60,8 @@ def test_affine_expression_printing_should_reflect_required_formatting( var("x").expec(), ], ) -def test_linear_expressions_equal(expr: LinearExpressionEfficient) -> None: - copy = LinearExpressionEfficient() +def test_linear_expressions_equal(expr: LinearExpression) -> None: + copy = LinearExpression() _copy_expression(expr, copy) assert linear_expressions_equal(expr, copy) @@ -74,43 +70,41 @@ def test_linear_expressions_equal(expr: LinearExpressionEfficient) -> None: "lhs, rhs", [ ( - LinearExpressionEfficient([], 1) + LinearExpressionEfficient([], 3), - LinearExpressionEfficient([], 4), + LinearExpression([], 1) + LinearExpression([], 3), + LinearExpression([], 4), ), ( - LinearExpressionEfficient([], 4) / LinearExpressionEfficient([], 2), - LinearExpressionEfficient([], 2), + LinearExpression([], 4) / LinearExpression([], 2), + LinearExpression([], 2), ), ( - LinearExpressionEfficient([], 4) * LinearExpressionEfficient([], 2), - LinearExpressionEfficient([], 8), + LinearExpression([], 4) * LinearExpression([], 2), + LinearExpression([], 8), ), ( - LinearExpressionEfficient([], 4) - LinearExpressionEfficient([], 2), - LinearExpressionEfficient([], 2), + LinearExpression([], 4) - LinearExpression([], 2), + LinearExpression([], 2), ), ], ) -def test_constant_expressions( - lhs: LinearExpressionEfficient, rhs: LinearExpressionEfficient -) -> None: +def test_constant_expressions(lhs: LinearExpression, rhs: LinearExpression) -> None: assert linear_expressions_equal(lhs, rhs) @pytest.mark.parametrize( "terms_dict, constant, exp_terms, exp_constant", [ - ({"x": TermEfficient(0, "c", "x")}, 1, {}, 1), - ({"x": TermEfficient(1, "c", "x")}, 1, {"x": TermEfficient(1, "c", "x")}, 1), + ({"x": Term(0, "c", "x")}, 1, {}, 1), + ({"x": Term(1, "c", "x")}, 1, {"x": Term(1, "c", "x")}, 1), ], ) def test_instantiate_linear_expression_from_dict( - terms_dict: Dict[str, TermEfficient], + terms_dict: Dict[str, Term], constant: float, - exp_terms: Dict[str, TermEfficient], + exp_terms: Dict[str, Term], exp_constant: float, ) -> None: - expr = LinearExpressionEfficient(terms_dict, constant) + expr = LinearExpression(terms_dict, constant) assert expr.terms == exp_terms assert expr.constant == exp_constant @@ -118,20 +112,20 @@ def test_instantiate_linear_expression_from_dict( @pytest.mark.parametrize( "expr, expected", [ - (LinearExpressionEfficient(), True), - (LinearExpressionEfficient([]), True), - (LinearExpressionEfficient([], 0, {}), True), - (LinearExpressionEfficient([TermEfficient(1, "c", "x")], 0, {}), False), - (LinearExpressionEfficient([], 1, {}), False), + (LinearExpression(), True), + (LinearExpression([]), True), + (LinearExpression([], 0, {}), True), + (LinearExpression([Term(1, "c", "x")], 0, {}), False), + (LinearExpression([], 1, {}), False), ( - LinearExpressionEfficient( + LinearExpression( [], 1, {PortFieldId("p", "f"): PortFieldTerm(1, "p", "f")} ), False, ), ], ) -def test_is_zero(expr: LinearExpressionEfficient, expected: bool) -> None: +def test_is_zero(expr: LinearExpression, expected: bool) -> None: assert expr.is_zero() == expected @@ -139,55 +133,49 @@ def test_is_zero(expr: LinearExpressionEfficient, expected: bool) -> None: "e1, e2, expected", [ ( - LinearExpressionEfficient([TermEfficient(10, "c", "x")], 1), - LinearExpressionEfficient([TermEfficient(5, "c", "x")], 2), - LinearExpressionEfficient([TermEfficient(15, "c", "x")], 3), + LinearExpression([Term(10, "c", "x")], 1), + LinearExpression([Term(5, "c", "x")], 2), + LinearExpression([Term(15, "c", "x")], 3), ), ( - LinearExpressionEfficient([TermEfficient(10, "c1", "x")], 1), - LinearExpressionEfficient([TermEfficient(5, "c2", "x")], 2), - LinearExpressionEfficient( - [TermEfficient(10, "c1", "x"), TermEfficient(5, "c2", "x")], 3 - ), + LinearExpression([Term(10, "c1", "x")], 1), + LinearExpression([Term(5, "c2", "x")], 2), + LinearExpression([Term(10, "c1", "x"), Term(5, "c2", "x")], 3), ), ( - LinearExpressionEfficient([TermEfficient(10, "c", "x")], 0), - LinearExpressionEfficient([TermEfficient(5, "c", "y")], 0), - LinearExpressionEfficient( - [TermEfficient(10, "c", "x"), TermEfficient(5, "c", "y")], 0 - ), + LinearExpression([Term(10, "c", "x")], 0), + LinearExpression([Term(5, "c", "y")], 0), + LinearExpression([Term(10, "c", "x"), Term(5, "c", "y")], 0), ), ( - LinearExpressionEfficient(), - LinearExpressionEfficient([TermEfficient(10, "c", "x", TimeShift(-1))]), - LinearExpressionEfficient([TermEfficient(10, "c", "x", TimeShift(-1))]), + LinearExpression(), + LinearExpression([Term(10, "c", "x", TimeShift(-1))]), + LinearExpression([Term(10, "c", "x", TimeShift(-1))]), ), ( - LinearExpressionEfficient(), - LinearExpressionEfficient( - [TermEfficient(10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] + LinearExpression(), + LinearExpression( + [Term(10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] ), - LinearExpressionEfficient( - [TermEfficient(10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] + LinearExpression( + [Term(10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] ), ), ( - LinearExpressionEfficient([TermEfficient(10, "c", "x")]), - LinearExpressionEfficient( - [TermEfficient(10, "c", "x", time_operator=TimeShift(-1))] - ), - LinearExpressionEfficient( + LinearExpression([Term(10, "c", "x")]), + LinearExpression([Term(10, "c", "x", time_operator=TimeShift(-1))]), + LinearExpression( [ - TermEfficient(10, "c", "x"), - TermEfficient(10, "c", "x", time_operator=TimeShift(-1)), + Term(10, "c", "x"), + Term(10, "c", "x", time_operator=TimeShift(-1)), ] ), ), ( - LinearExpressionEfficient([TermEfficient(10, "c", "x")]), - LinearExpressionEfficient( + LinearExpression([Term(10, "c", "x")]), + LinearExpression( [ - TermEfficient( + Term( 10, "c", "x", @@ -196,10 +184,10 @@ def test_is_zero(expr: LinearExpressionEfficient, expected: bool) -> None: ) ] ), - LinearExpressionEfficient( + LinearExpression( [ - TermEfficient(10, "c", "x"), - TermEfficient( + Term(10, "c", "x"), + Term( 10, "c", "x", @@ -212,9 +200,9 @@ def test_is_zero(expr: LinearExpressionEfficient, expected: bool) -> None: ], ) def test_addition( - e1: LinearExpressionEfficient, - e2: LinearExpressionEfficient, - expected: LinearExpressionEfficient, + e1: LinearExpression, + e2: LinearExpression, + expected: LinearExpression, ) -> None: assert linear_expressions_equal(e1 + e2, expected) @@ -222,8 +210,8 @@ def test_addition( def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_from_terms() -> ( None ): - e1 = LinearExpressionEfficient([TermEfficient(10, "c", "x")], 1) - e2 = LinearExpressionEfficient([TermEfficient(10, "c", "x")], 2) + e1 = LinearExpression([Term(10, "c", "x")], 1) + e2 = LinearExpression([Term(10, "c", "x")], 2) e3 = e2 - e1 assert e3.terms == {} @@ -232,24 +220,24 @@ def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_fr "e1, e2, expected", [ ( - LinearExpressionEfficient([TermEfficient(10, "c", "x")], 3), - LinearExpressionEfficient([], 2), - LinearExpressionEfficient([TermEfficient(20, "c", "x")], 6), + LinearExpression([Term(10, "c", "x")], 3), + LinearExpression([], 2), + LinearExpression([Term(20, "c", "x")], 6), ), ( - LinearExpressionEfficient([TermEfficient(10, "c", "x")], 3), - LinearExpressionEfficient([], 1), - LinearExpressionEfficient([TermEfficient(10, "c", "x")], 3), + LinearExpression([Term(10, "c", "x")], 3), + LinearExpression([], 1), + LinearExpression([Term(10, "c", "x")], 3), ), ( - LinearExpressionEfficient([TermEfficient(10, "c", "x")], 3), - LinearExpressionEfficient(), - LinearExpressionEfficient(), + LinearExpression([Term(10, "c", "x")], 3), + LinearExpression(), + LinearExpression(), ), ( - LinearExpressionEfficient( + LinearExpression( [ - TermEfficient( + Term( 10, "c", "x", @@ -259,10 +247,10 @@ def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_fr ], 3, ), - LinearExpressionEfficient([], 2), - LinearExpressionEfficient( + LinearExpression([], 2), + LinearExpression( [ - TermEfficient( + Term( 20, "c", "x", @@ -276,17 +264,17 @@ def test_operation_that_leads_to_term_with_zero_coefficient_should_be_removed_fr ], ) def test_multiplication( - e1: LinearExpressionEfficient, - e2: LinearExpressionEfficient, - expected: LinearExpressionEfficient, + e1: LinearExpression, + e2: LinearExpression, + expected: LinearExpression, ) -> None: assert linear_expressions_equal(e1 * e2, expected) assert linear_expressions_equal(e2 * e1, expected) def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> None: - e1 = LinearExpressionEfficient([TermEfficient(10, "c", "x")], 0) - e2 = LinearExpressionEfficient([TermEfficient(5, "c", "x")], 0) + e1 = LinearExpression([Term(10, "c", "x")], 0) + e2 = LinearExpression([Term(5, "c", "x")], 0) with pytest.raises(ValueError) as exc: _ = e1 * e2 assert str(exc.value) == "Cannot multiply two non constant expression" @@ -296,13 +284,13 @@ def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> "e1, expected", [ ( - LinearExpressionEfficient([TermEfficient(10, "c", "x")], 5), - LinearExpressionEfficient([TermEfficient(-10, "c", "x")], -5), + LinearExpression([Term(10, "c", "x")], 5), + LinearExpression([Term(-10, "c", "x")], -5), ), ( - LinearExpressionEfficient( + LinearExpression( [ - TermEfficient( + Term( 10, "c", "x", @@ -313,9 +301,9 @@ def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> ], 5, ), - LinearExpressionEfficient( + LinearExpression( [ - TermEfficient( + Term( -10, "c", "x", @@ -329,9 +317,7 @@ def test_multiplication_of_two_non_constant_terms_should_raise_value_error() -> ), ], ) -def test_negation( - e1: LinearExpressionEfficient, expected: LinearExpressionEfficient -) -> None: +def test_negation(e1: LinearExpression, expected: LinearExpression) -> None: assert linear_expressions_equal(-e1, expected) @@ -339,59 +325,49 @@ def test_negation( "e1, e2, expected", [ ( - LinearExpressionEfficient([TermEfficient(10, "c", "x")], 1), - LinearExpressionEfficient([TermEfficient(5, "c", "x")], 2), - LinearExpressionEfficient([TermEfficient(5, "c", "x")], -1), + LinearExpression([Term(10, "c", "x")], 1), + LinearExpression([Term(5, "c", "x")], 2), + LinearExpression([Term(5, "c", "x")], -1), ), ( - LinearExpressionEfficient([TermEfficient(10, "c1", "x")], 1), - LinearExpressionEfficient([TermEfficient(5, "c2", "x")], 2), - LinearExpressionEfficient( - [TermEfficient(10, "c1", "x"), TermEfficient(-5, "c2", "x")], -1 - ), + LinearExpression([Term(10, "c1", "x")], 1), + LinearExpression([Term(5, "c2", "x")], 2), + LinearExpression([Term(10, "c1", "x"), Term(-5, "c2", "x")], -1), ), ( - LinearExpressionEfficient([TermEfficient(10, "c", "x")], 0), - LinearExpressionEfficient([TermEfficient(5, "c", "y")], 0), - LinearExpressionEfficient( - [TermEfficient(10, "c", "x"), TermEfficient(-5, "c", "y")], 0 - ), + LinearExpression([Term(10, "c", "x")], 0), + LinearExpression([Term(5, "c", "y")], 0), + LinearExpression([Term(10, "c", "x"), Term(-5, "c", "y")], 0), ), ( - LinearExpressionEfficient(), - LinearExpressionEfficient( - [TermEfficient(10, "c", "x", time_operator=TimeShift(-1))] - ), - LinearExpressionEfficient( - [TermEfficient(-10, "c", "x", time_operator=TimeShift(-1))] - ), + LinearExpression(), + LinearExpression([Term(10, "c", "x", time_operator=TimeShift(-1))]), + LinearExpression([Term(-10, "c", "x", time_operator=TimeShift(-1))]), ), ( - LinearExpressionEfficient(), - LinearExpressionEfficient( - [TermEfficient(10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] + LinearExpression(), + LinearExpression( + [Term(10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] ), - LinearExpressionEfficient( - [TermEfficient(-10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] + LinearExpression( + [Term(-10, "c", "x", time_aggregator=TimeSum(stay_roll=True))] ), ), ( - LinearExpressionEfficient([TermEfficient(10, "c", "x")]), - LinearExpressionEfficient( - [TermEfficient(10, "c", "x", time_operator=TimeShift(-1))] - ), - LinearExpressionEfficient( + LinearExpression([Term(10, "c", "x")]), + LinearExpression([Term(10, "c", "x", time_operator=TimeShift(-1))]), + LinearExpression( [ - TermEfficient(10, "c", "x"), - TermEfficient(-10, "c", "x", time_operator=TimeShift(-1)), + Term(10, "c", "x"), + Term(-10, "c", "x", time_operator=TimeShift(-1)), ] ), ), ( - LinearExpressionEfficient([TermEfficient(10, "c", "x")]), - LinearExpressionEfficient( + LinearExpression([Term(10, "c", "x")]), + LinearExpression( [ - TermEfficient( + Term( 10, "c", "x", @@ -401,10 +377,10 @@ def test_negation( ) ] ), - LinearExpressionEfficient( + LinearExpression( [ - TermEfficient(10, "c", "x"), - TermEfficient( + Term(10, "c", "x"), + Term( -10, "c", "x", @@ -418,9 +394,9 @@ def test_negation( ], ) def test_substraction( - e1: LinearExpressionEfficient, - e2: LinearExpressionEfficient, - expected: LinearExpressionEfficient, + e1: LinearExpression, + e2: LinearExpression, + expected: LinearExpression, ) -> None: assert linear_expressions_equal(e1 - e2, expected) @@ -429,19 +405,19 @@ def test_substraction( "e1, e2, expected", [ ( - LinearExpressionEfficient([TermEfficient(10, "c", "x")], 15), - LinearExpressionEfficient([], 5), - LinearExpressionEfficient([TermEfficient(2, "c", "x")], 3), + LinearExpression([Term(10, "c", "x")], 15), + LinearExpression([], 5), + LinearExpression([Term(2, "c", "x")], 3), ), ( - LinearExpressionEfficient([TermEfficient(10, "c", "x")], 15), - LinearExpressionEfficient([], 1), - LinearExpressionEfficient([TermEfficient(10, "c", "x")], 15), + LinearExpression([Term(10, "c", "x")], 15), + LinearExpression([], 1), + LinearExpression([Term(10, "c", "x")], 15), ), ( - LinearExpressionEfficient( + LinearExpression( [ - TermEfficient( + Term( 10, "c", "x", @@ -452,10 +428,10 @@ def test_substraction( ], 15, ), - LinearExpressionEfficient([], 5), - LinearExpressionEfficient( + LinearExpression([], 5), + LinearExpression( [ - TermEfficient( + Term( 2, "c", "x", @@ -470,24 +446,24 @@ def test_substraction( ], ) def test_division( - e1: LinearExpressionEfficient, - e2: LinearExpressionEfficient, - expected: LinearExpressionEfficient, + e1: LinearExpression, + e2: LinearExpression, + expected: LinearExpression, ) -> None: assert linear_expressions_equal(e1 / e2, expected) def test_division_by_zero_sould_raise_zero_division_error() -> None: - e1 = LinearExpressionEfficient([TermEfficient(10, "c", "x")], 15) - e2 = LinearExpressionEfficient() + e1 = LinearExpression([Term(10, "c", "x")], 15) + e2 = LinearExpression() with pytest.raises(ZeroDivisionError) as exc: _ = e1 / e2 assert str(exc.value) == "Cannot divide expression by zero" def test_division_by_non_constant_expr_sould_raise_value_error() -> None: - e1 = LinearExpressionEfficient([TermEfficient(10, "c", "x")], 15) - e2 = LinearExpressionEfficient() + e1 = LinearExpression([Term(10, "c", "x")], 15) + e2 = LinearExpression() with pytest.raises(ValueError) as exc: _ = e2 / e1 assert str(exc.value) == "Cannot divide by a non constant expression" @@ -496,9 +472,9 @@ def test_division_by_non_constant_expr_sould_raise_value_error() -> None: def test_imul_preserve_identity() -> None: # technical test to check the behaviour of reassigning "self" in imul operator: # it did not preserve identity, which could lead to weird behaviour - e1 = LinearExpressionEfficient([], 15) + e1 = LinearExpression([], 15) e2 = e1 - e1 *= LinearExpressionEfficient([], 2) - assert linear_expressions_equal(e1, LinearExpressionEfficient([], 30)) + e1 *= LinearExpression([], 2) + assert linear_expressions_equal(e1, LinearExpression([], 30)) assert linear_expressions_equal(e2, e1) assert e2 is e1 diff --git a/tests/unittests/expressions/test_port_resolver.py b/tests/unittests/expressions/test_port_resolver.py index 4008fba9..ed30fcd9 100644 --- a/tests/unittests/expressions/test_port_resolver.py +++ b/tests/unittests/expressions/test_port_resolver.py @@ -16,7 +16,7 @@ from andromede.expression.equality import expressions_equal from andromede.expression.linear_expression import ( - LinearExpressionEfficient, + LinearExpression, PortFieldId, PortFieldKey, linear_expressions_equal, @@ -32,13 +32,13 @@ (port_field("port", "field") - 2, var("flow") - 2), (port_field("port", "field") * 2, 2 * var("flow")), (port_field("port", "field") / 2, var("flow") / 2), - (port_field("port", "field") * 0, LinearExpressionEfficient()), + (port_field("port", "field") * 0, LinearExpression()), ], ) def test_port_field_resolution( - port_expr: LinearExpressionEfficient, expected: LinearExpressionEfficient + port_expr: LinearExpression, expected: LinearExpression ) -> None: - ports_expressions: Dict[PortFieldKey, List[LinearExpressionEfficient]] = {} + ports_expressions: Dict[PortFieldKey, List[LinearExpression]] = {} key = PortFieldKey("com_id", PortFieldId(field_name="field", port_name="port")) expression = var("flow") @@ -55,7 +55,7 @@ def test_port_field_resolution( def test_port_field_resolution_sum() -> None: - ports_expressions: Dict[PortFieldKey, List[LinearExpressionEfficient]] = {} + ports_expressions: Dict[PortFieldKey, List[LinearExpression]] = {} key = PortFieldKey("com_id", PortFieldId(field_name="field", port_name="port")) diff --git a/tests/unittests/expressions/test_resolve_coefficients.py b/tests/unittests/expressions/test_resolve_coefficients.py index 08daa4ab..f5619f0c 100644 --- a/tests/unittests/expressions/test_resolve_coefficients.py +++ b/tests/unittests/expressions/test_resolve_coefficients.py @@ -20,7 +20,7 @@ from andromede.expression.expression import ( Comparator, ComparisonNode, - ExpressionNodeEfficient, + ExpressionNode, ExpressionRange, InstancesTimeIndex, LiteralNode, @@ -150,7 +150,7 @@ def provider() -> CustomValueProvider: ], ) def test_resolve_coefficient_raises_value_error_on_port_field_node( - port_node: ExpressionNodeEfficient, provider: CustomValueProvider + port_node: ExpressionNode, provider: CustomValueProvider ) -> None: with pytest.raises( ValueError, match="Port fields must be resolved before evaluating parameters" @@ -186,7 +186,7 @@ def test_resolve_coefficient_raises_value_error_on_comparison_node( ], ) def test_resolve_coefficient_raises_value_error_on_expressions_that_are_not_aggregated_on_a_single_time_and_scenario( - expr: ExpressionNodeEfficient, provider: CustomValueProvider + expr: ExpressionNode, provider: CustomValueProvider ) -> None: with pytest.raises( ValueError, match="Evaluation of expression cannot be reduced to a float value" @@ -203,7 +203,7 @@ def test_resolve_coefficient_raises_value_error_on_expressions_that_are_not_aggr ], ) def test_resolve_coefficient_on_expression_with_shift_but_without_sum_raises_value_error( - expr: ExpressionNodeEfficient, + expr: ExpressionNode, provider: CustomValueProvider, ) -> None: with pytest.raises( @@ -235,7 +235,7 @@ def test_resolve_coefficient_on_expression_with_shift_but_without_sum_raises_val ], ) def test_resolve_coefficient_with_no_time_varying_parameter_in_time_operator_argument_raises_value_error( - expr: ExpressionNodeEfficient, + expr: ExpressionNode, ) -> None: class TimeVaryingParameterValueProvider(CustomValueProvider): def parameter_is_constant_over_time(self, name: str) -> bool: @@ -263,7 +263,7 @@ def parameter_is_constant_over_time(self, name: str) -> bool: ], ) def test_resolve_coefficient_on_elementary_operations( - expr: ExpressionNodeEfficient, + expr: ExpressionNode, row_id: RowIndex, expected: float, provider: CustomValueProvider, @@ -290,7 +290,7 @@ def test_resolve_coefficient_on_elementary_operations( ], ) def test_resolve_coefficient_on_time_shift_and_sum( - expr: ExpressionNodeEfficient, + expr: ExpressionNode, row_id: RowIndex, expected: float, provider: CustomValueProvider, @@ -307,7 +307,7 @@ def test_resolve_coefficient_on_time_shift_and_sum( ], ) def test_resolve_coefficient_on_expectation( - expr: ExpressionNodeEfficient, + expr: ExpressionNode, row_id: RowIndex, expected: float, provider: CustomValueProvider, @@ -326,7 +326,7 @@ def test_resolve_coefficient_on_expectation( ], ) def test_resolve_coefficient_on_sum_and_expectation( - expr: ExpressionNodeEfficient, + expr: ExpressionNode, row_id: RowIndex, expected: float, provider: CustomValueProvider, diff --git a/tests/unittests/expressions/test_term_efficient.py b/tests/unittests/expressions/test_term_efficient.py index 45bb66e4..539ba994 100644 --- a/tests/unittests/expressions/test_term_efficient.py +++ b/tests/unittests/expressions/test_term_efficient.py @@ -13,7 +13,7 @@ import pytest from andromede.expression.expression import LiteralNode -from andromede.expression.linear_expression import TermEfficient +from andromede.expression.linear_expression import Term from andromede.expression.scenario_operator import Expectation, Variance from andromede.expression.time_operator import TimeShift, TimeSum @@ -21,14 +21,14 @@ @pytest.mark.parametrize( "term, expected", [ - (TermEfficient(1, "c", "x"), "+x"), - (TermEfficient(-1, "c", "x"), "-x"), - (TermEfficient(2.50, "c", "x"), "2.5x"), - (TermEfficient(-3, "c", "x"), "-3.0x"), - (TermEfficient(-3, "c", "x", time_operator=TimeShift(-1)), "-3.0x.shift(-1)"), - (TermEfficient(-3, "c", "x", time_aggregator=TimeSum(True)), "-3.0x.sum(True)"), + (Term(1, "c", "x"), "+x"), + (Term(-1, "c", "x"), "-x"), + (Term(2.50, "c", "x"), "2.5x"), + (Term(-3, "c", "x"), "-3.0x"), + (Term(-3, "c", "x", time_operator=TimeShift(-1)), "-3.0x.shift(-1)"), + (Term(-3, "c", "x", time_aggregator=TimeSum(True)), "-3.0x.sum(True)"), ( - TermEfficient( + Term( -3, "c", "x", @@ -38,11 +38,11 @@ "-3.0x.shift([2, 3]).sum(False)", ), ( - TermEfficient(-3, "c", "x", scenario_aggregator=Expectation()), + Term(-3, "c", "x", scenario_aggregator=Expectation()), "-3.0x.expec()", ), ( - TermEfficient( + Term( -3, "c", "x", @@ -53,52 +53,52 @@ ), ], ) -def test_printing_term(term: TermEfficient, expected: str) -> None: +def test_printing_term(term: Term, expected: str) -> None: assert str(term) == expected @pytest.mark.parametrize( "lhs, rhs, expected", [ - (TermEfficient(1, "c", "x"), TermEfficient(1, "c", "x"), True), - (TermEfficient(1, "c", "x"), TermEfficient(2, "c", "x"), False), + (Term(1, "c", "x"), Term(1, "c", "x"), True), + (Term(1, "c", "x"), Term(2, "c", "x"), False), ( - TermEfficient(LiteralNode(1), "c", "x"), - TermEfficient(LiteralNode(2), "c", "x"), + Term(LiteralNode(1), "c", "x"), + Term(LiteralNode(2), "c", "x"), False, ), - (TermEfficient(-1, "c", "x"), TermEfficient(-1, "", "x"), False), - (TermEfficient(2.50, "c", "x"), TermEfficient(2.50, "c", ""), False), - (TermEfficient(-3, "c", "x"), TermEfficient(-3, "c", "y"), False), + (Term(-1, "c", "x"), Term(-1, "", "x"), False), + (Term(2.50, "c", "x"), Term(2.50, "c", ""), False), + (Term(-3, "c", "x"), Term(-3, "c", "y"), False), ( - TermEfficient(-3, "c", "x", time_operator=TimeShift(-1)), - TermEfficient(-3, "c", "x", time_operator=TimeShift(-1)), + Term(-3, "c", "x", time_operator=TimeShift(-1)), + Term(-3, "c", "x", time_operator=TimeShift(-1)), True, ), ( - TermEfficient(-3, "c", "x", time_operator=TimeShift(-1)), - TermEfficient(-3, "c", "x"), + Term(-3, "c", "x", time_operator=TimeShift(-1)), + Term(-3, "c", "x"), False, ), ( - TermEfficient(-3, "c", "x", time_aggregator=TimeSum(True)), - TermEfficient(-3, "c", "x", time_aggregator=TimeSum(True)), + Term(-3, "c", "x", time_aggregator=TimeSum(True)), + Term(-3, "c", "x", time_aggregator=TimeSum(True)), True, ), ( - TermEfficient(-3, "c", "x", time_aggregator=TimeSum(True)), - TermEfficient(-3, "c", "x", time_operator=TimeShift(-1)), + Term(-3, "c", "x", time_aggregator=TimeSum(True)), + Term(-3, "c", "x", time_operator=TimeShift(-1)), False, ), ( - TermEfficient( + Term( -3, "c", "x", time_operator=TimeShift([2, 3]), time_aggregator=TimeSum(False), ), - TermEfficient( + Term( -3, "c", "x", @@ -108,24 +108,24 @@ def test_printing_term(term: TermEfficient, expected: str) -> None: False, ), ( - TermEfficient(-3, "c", "x", scenario_aggregator=Expectation()), - TermEfficient(-3, "c", "x", scenario_aggregator=Expectation()), + Term(-3, "c", "x", scenario_aggregator=Expectation()), + Term(-3, "c", "x", scenario_aggregator=Expectation()), True, ), ( - TermEfficient(-3, "c", "x", scenario_aggregator=Expectation()), - TermEfficient(-3, "c", "x", scenario_aggregator=Variance()), + Term(-3, "c", "x", scenario_aggregator=Expectation()), + Term(-3, "c", "x", scenario_aggregator=Variance()), False, ), ( - TermEfficient( + Term( -3, "c", "x", time_aggregator=TimeSum(True), scenario_aggregator=Expectation(), ), - TermEfficient( + Term( -3, "c", "x", @@ -136,5 +136,5 @@ def test_printing_term(term: TermEfficient, expected: str) -> None: ), ], ) -def test_term_equality(lhs: TermEfficient, rhs: TermEfficient, expected: bool) -> None: +def test_term_equality(lhs: Term, rhs: Term, expected: bool) -> None: assert (lhs == rhs) == expected diff --git a/tests/unittests/test_model.py b/tests/unittests/test_model.py index 6d61caae..bf5cae24 100644 --- a/tests/unittests/test_model.py +++ b/tests/unittests/test_model.py @@ -17,7 +17,7 @@ from andromede.expression.expression import ExpressionRange, comp_param, param from andromede.expression.linear_expression import ( - LinearExpressionEfficient, + LinearExpression, comp_var, linear_expressions_equal, literal, @@ -126,13 +126,13 @@ ) def test_constraint_instantiation( name: str, - expression: LinearExpressionEfficient, - lb: Optional[LinearExpressionEfficient], - ub: Optional[LinearExpressionEfficient], + expression: LinearExpression, + lb: Optional[LinearExpression], + ub: Optional[LinearExpression], exp_name: str, - exp_expr: LinearExpressionEfficient, - exp_lb: LinearExpressionEfficient, - exp_ub: LinearExpressionEfficient, + exp_expr: LinearExpression, + exp_lb: LinearExpression, + exp_ub: LinearExpression, ) -> None: if lb is None and ub is None: constraint = Constraint(name, expression) @@ -254,7 +254,7 @@ def test_instantiating_a_model_with_non_linear_scenario_operator_in_the_objectiv ], ) def test_invalid_port_field_definition_should_raise( - expression: LinearExpressionEfficient, error_type: Type, error_msg: str + expression: LinearExpression, error_type: Type, error_msg: str ) -> None: with pytest.raises(error_type, match=re.escape(error_msg)): port_field_def(port_name="p", field_name="f", definition=expression) From f50f57f1163e37ccdff7d8dae9467aee149a5c53 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Tue, 27 Aug 2024 16:20:00 +0200 Subject: [PATCH 47/51] Rename test files --- .../{test_performance_efficient.py => test_performance.py} | 0 .../{test_expressions_efficient.py => test_expressions.py} | 0 ...linear_expressions_efficient.py => test_linear_expressions.py} | 0 .../expressions/{test_term_efficient.py => test_term.py} | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename tests/functional/{test_performance_efficient.py => test_performance.py} (100%) rename tests/unittests/expressions/{test_expressions_efficient.py => test_expressions.py} (100%) rename tests/unittests/expressions/{test_linear_expressions_efficient.py => test_linear_expressions.py} (100%) rename tests/unittests/expressions/{test_term_efficient.py => test_term.py} (100%) diff --git a/tests/functional/test_performance_efficient.py b/tests/functional/test_performance.py similarity index 100% rename from tests/functional/test_performance_efficient.py rename to tests/functional/test_performance.py diff --git a/tests/unittests/expressions/test_expressions_efficient.py b/tests/unittests/expressions/test_expressions.py similarity index 100% rename from tests/unittests/expressions/test_expressions_efficient.py rename to tests/unittests/expressions/test_expressions.py diff --git a/tests/unittests/expressions/test_linear_expressions_efficient.py b/tests/unittests/expressions/test_linear_expressions.py similarity index 100% rename from tests/unittests/expressions/test_linear_expressions_efficient.py rename to tests/unittests/expressions/test_linear_expressions.py diff --git a/tests/unittests/expressions/test_term_efficient.py b/tests/unittests/expressions/test_term.py similarity index 100% rename from tests/unittests/expressions/test_term_efficient.py rename to tests/unittests/expressions/test_term.py From abd1eeda031b3ca7e731a8171ba772a02f4454eb Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Tue, 27 Aug 2024 16:35:07 +0200 Subject: [PATCH 48/51] Remove useless comments --- src/andromede/expression/expression.py | 5 ++--- src/andromede/expression/linear_expression.py | 3 +-- src/andromede/model/model.py | 2 -- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/andromede/expression/expression.py b/src/andromede/expression/expression.py index bbcb2aef..d2fee655 100644 --- a/src/andromede/expression/expression.py +++ b/src/andromede/expression/expression.py @@ -183,7 +183,7 @@ def _add_node(lhs: ExpressionNode, rhs: ExpressionNode) -> ExpressionNode: return rhs if is_zero(rhs): return lhs - # TODO: How can we use the equality visitor here (simple import -> circular import), copy code here ? + # TODO: How can we use the equality visitor here (simple import -> circular import) -> equality visitor code is placed at the bottom of this file... if expressions_equal(lhs, -rhs): return LiteralNode(0) if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode): @@ -237,7 +237,6 @@ def _substract_node(lhs: ExpressionNode, rhs: ExpressionNode) -> ExpressionNode: return -rhs if is_zero(rhs): return lhs - # TODO: How can we use the equality visitor here (simple import -> circular import), copy code here ? if expressions_equal(lhs, rhs): return LiteralNode(0) if isinstance(lhs, LiteralNode) and isinstance(rhs, LiteralNode): @@ -657,7 +656,7 @@ def negation(self, left: NegationNode, right: NegationNode) -> bool: return self.visit(left.operand, right.operand) def addition(self, left: AdditionNode, right: AdditionNode) -> bool: - # TODO: Commutativty ??? Cannot detect that a+b == b+a + # TODO: Commutativty ??? Cannot detect that a+b == b+a ? Do we want to do this ? return self._visit_operands(left, right) def substraction(self, left: SubstractionNode, right: SubstractionNode) -> bool: diff --git a/src/andromede/expression/linear_expression.py b/src/andromede/expression/linear_expression.py index c86774c5..2c3e0556 100644 --- a/src/andromede/expression/linear_expression.py +++ b/src/andromede/expression/linear_expression.py @@ -456,7 +456,6 @@ class LinearExpression: constant: ExpressionNode port_field_terms: Dict[PortFieldId, PortFieldTerm] - # TODO: We need to check that terms.key is indeed a TermKey and change the tests that this will break def __init__( self, terms: Optional[Union[Dict[TermKey, Term], List[Term]]] = None, @@ -698,6 +697,7 @@ def remove_zeros_from_terms(self) -> None: if is_zero(port_term.coefficient): del self.port_field_terms[port_term_key] + # Function used only in tests... def evaluate(self, context: ValueProvider, time_scenario_index: RowIndex) -> float: return sum( [ @@ -1057,7 +1057,6 @@ def _copy_expression(src: LinearExpression, dst: LinearExpression) -> None: dst.constant = src.constant -# TODO : Define shortcuts for "x", is_one etc .... def var(name: str) -> LinearExpression: # TODO: At term build time, no information on the variable structure is known, we use a default time, scenario varying, maybe discard structure as term attribute ? return LinearExpression( diff --git a/src/andromede/model/model.py b/src/andromede/model/model.py index 1b0c19d0..93011db6 100644 --- a/src/andromede/model/model.py +++ b/src/andromede/model/model.py @@ -53,7 +53,6 @@ from andromede.model.variable import Variable -# TODO: Introduce bool_variable ? def _make_structure_provider(model: "Model") -> IndexingStructureProvider: class Provider(IndexingStructureProvider): def get_parameter_structure(self, name: str) -> IndexingStructure: @@ -92,7 +91,6 @@ def _is_objective_contribution_valid( if objective_structure != IndexingStructure(time=False, scenario=False): raise ValueError("Objective contribution should be a real-valued expression.") - # TODO: We should also check that the number of instances is equal to 1, but this would require a linearization here, do not want to do that for now... return True From f6a03c6a606fc621eb8a1ce231c486792ab403b1 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Tue, 27 Aug 2024 16:45:23 +0200 Subject: [PATCH 49/51] Comment and reformatting --- src/andromede/expression/linear_expression.py | 6 ++- src/andromede/expression/visitor.py | 42 +++++++++++------ .../simulation/linear_expression_resolver.py | 22 ++++----- .../simulation/optimization_context.py | 16 +++---- src/andromede/simulation/strategy.py | 6 ++- .../unittests/expressions/test_expressions.py | 45 ++----------------- .../expressions/test_linear_expressions.py | 2 +- 7 files changed, 59 insertions(+), 80 deletions(-) diff --git a/src/andromede/expression/linear_expression.py b/src/andromede/expression/linear_expression.py index 2c3e0556..5050888f 100644 --- a/src/andromede/expression/linear_expression.py +++ b/src/andromede/expression/linear_expression.py @@ -374,7 +374,8 @@ def _merge_dicts( rhs: Dict[TermKey, Term], merge_func: Callable[[Term, Term], Term], neutral: float, -) -> Dict[TermKey, Term]: ... +) -> Dict[TermKey, Term]: + ... @overload @@ -383,7 +384,8 @@ def _merge_dicts( rhs: Dict[PortFieldId, PortFieldTerm], merge_func: Callable[[PortFieldTerm, PortFieldTerm], PortFieldTerm], neutral: float, -) -> Dict[PortFieldId, PortFieldTerm]: ... +) -> Dict[PortFieldId, PortFieldTerm]: + ... def _merge_dicts(lhs, rhs, merge_func, neutral): diff --git a/src/andromede/expression/visitor.py b/src/andromede/expression/visitor.py index 55e93edb..37d9f507 100644 --- a/src/andromede/expression/visitor.py +++ b/src/andromede/expression/visitor.py @@ -47,46 +47,60 @@ class ExpressionVisitor(ABC, Generic[T]): """ @abstractmethod - def literal(self, node: LiteralNode) -> T: ... + def literal(self, node: LiteralNode) -> T: + ... @abstractmethod - def negation(self, node: NegationNode) -> T: ... + def negation(self, node: NegationNode) -> T: + ... @abstractmethod - def addition(self, node: AdditionNode) -> T: ... + def addition(self, node: AdditionNode) -> T: + ... @abstractmethod - def substraction(self, node: SubstractionNode) -> T: ... + def substraction(self, node: SubstractionNode) -> T: + ... @abstractmethod - def multiplication(self, node: MultiplicationNode) -> T: ... + def multiplication(self, node: MultiplicationNode) -> T: + ... @abstractmethod - def division(self, node: DivisionNode) -> T: ... + def division(self, node: DivisionNode) -> T: + ... @abstractmethod - def comparison(self, node: ComparisonNode) -> T: ... + def comparison(self, node: ComparisonNode) -> T: + ... @abstractmethod - def parameter(self, node: ParameterNode) -> T: ... + def parameter(self, node: ParameterNode) -> T: + ... @abstractmethod - def comp_parameter(self, node: ComponentParameterNode) -> T: ... + def comp_parameter(self, node: ComponentParameterNode) -> T: + ... @abstractmethod - def time_operator(self, node: TimeOperatorNode) -> T: ... + def time_operator(self, node: TimeOperatorNode) -> T: + ... @abstractmethod - def time_aggregator(self, node: TimeAggregatorNode) -> T: ... + def time_aggregator(self, node: TimeAggregatorNode) -> T: + ... @abstractmethod - def scenario_operator(self, node: ScenarioOperatorNode) -> T: ... + def scenario_operator(self, node: ScenarioOperatorNode) -> T: + ... @abstractmethod - def port_field(self, node: PortFieldNode) -> T: ... + def port_field(self, node: PortFieldNode) -> T: + ... @abstractmethod - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> T: ... + def port_field_aggregator(self, node: PortFieldAggregatorNode) -> T: + ... def visit(root: ExpressionNode, visitor: ExpressionVisitor[T]) -> T: diff --git a/src/andromede/simulation/linear_expression_resolver.py b/src/andromede/simulation/linear_expression_resolver.py index 95c8c895..7bfbf2ca 100644 --- a/src/andromede/simulation/linear_expression_resolver.py +++ b/src/andromede/simulation/linear_expression_resolver.py @@ -81,17 +81,17 @@ def resolve_variables( operator_ts_ids = self._row_id_to_term_time_scenario_id(term, row_id) for time in operator_ts_ids.time_indices: for scenario in operator_ts_ids.scenario_indices: - solver_vars[TimeScenarioIndex(time, scenario)] = ( - self.context.get_component_variable( - time, - scenario, - term.component_id, - term.variable_name, - # At term build time, no information on the variable structure is known, we use it now - self.context.network.get_component(term.component_id) - .model.variables[term.variable_name] - .structure, - ) + solver_vars[ + TimeScenarioIndex(time, scenario) + ] = self.context.get_component_variable( + time, + scenario, + term.component_id, + term.variable_name, + # At term build time, no information on the variable structure is known, we use it now + self.context.network.get_component(term.component_id) + .model.variables[term.variable_name] + .structure, ) return solver_vars diff --git a/src/andromede/simulation/optimization_context.py b/src/andromede/simulation/optimization_context.py index 71735c41..789cd456 100644 --- a/src/andromede/simulation/optimization_context.py +++ b/src/andromede/simulation/optimization_context.py @@ -242,14 +242,14 @@ def get_component_parameter_value( ) for block_timestep in time_scenarios_indices.time_indices: for scenario in time_scenarios_indices.scenario_indices: - result[TimeScenarioIndex(block_timestep, scenario)] = ( - _get_parameter_value( - context, - _get_data_time_key(block_timestep, param_index), - _get_data_scenario_key(scenario, param_index), - component_id, - name, - ) + result[ + TimeScenarioIndex(block_timestep, scenario) + ] = _get_parameter_value( + context, + _get_data_time_key(block_timestep, param_index), + _get_data_scenario_key(scenario, param_index), + component_id, + name, ) return result diff --git a/src/andromede/simulation/strategy.py b/src/andromede/simulation/strategy.py index 89e753fa..288cced2 100644 --- a/src/andromede/simulation/strategy.py +++ b/src/andromede/simulation/strategy.py @@ -37,12 +37,14 @@ def get_constraints(self, model: Model) -> Generator[Constraint, None, None]: yield constraint @abstractmethod - def _keep_from_context(self, context: ProblemContext) -> bool: ... + def _keep_from_context(self, context: ProblemContext) -> bool: + ... @abstractmethod def get_objectives( self, model: Model - ) -> Generator[Optional[LinearExpression], None, None]: ... + ) -> Generator[Optional[LinearExpression], None, None]: + ... class MergedProblemStrategy(ModelSelectionStrategy): diff --git a/tests/unittests/expressions/test_expressions.py b/tests/unittests/expressions/test_expressions.py index 803917ec..699dd1a0 100644 --- a/tests/unittests/expressions/test_expressions.py +++ b/tests/unittests/expressions/test_expressions.py @@ -10,6 +10,8 @@ # # This file is part of the Antares project. +# This test file should contain tests on ExpressionNode objects rather than on LinearExpression objects which are already tested in test_linear_expressions.py ... + import re from dataclasses import dataclass, field from typing import Dict @@ -102,7 +104,7 @@ def scenarios() -> int: raise NotImplementedError() -# TODO: Redundant with add tests in test_linear_expressions_efficient ? +# TODO: Redundant with tests in test_linear_expressions_efficient ? def test_comp_parameter() -> None: expr1 = LinearExpression([], 1) + LinearExpression([Term(1, "comp1", "x")]) expr2 = expr1 / LinearExpression(constant=ComponentParameterNode("comp1", "p")) @@ -139,18 +141,6 @@ def test_operators() -> None: assert -expr.evaluate(context, RowIndex(0, 0)) == pytest.approx(-2.5, 1e-16) -# def test_degree() -> None: -# x = var("x") -# p = param("p") -# expr = (5 * x + 3) / p - -# assert expr.compute_degree() == 1 - -# # TODO: Should this be allowed ? If so, how should we represent is ? -# expr = x * expr -# assert expr.compute_degree() == 2 - - def test_degree_computation_should_take_into_account_simplifications() -> None: x = var("x") expr = x - x @@ -161,21 +151,6 @@ def test_degree_computation_should_take_into_account_simplifications() -> None: assert expr.is_zero() -# def test_parameters_resolution() -> None: -# class TestParamProvider(ParameterValueProvider): -# def get_component_parameter_value(self, component_id: str, name: str) -> float: -# raise NotImplementedError() - -# def get_parameter_value(self, name: str) -> float: -# return 2 - -# x = var("x") -# p = param("p") -# expr = (5 * x + 3) / p -# # TODO: We do not want this in the API, but rather expr.get(t, w) -# assert expr.resolve_parameters(TestParamProvider()) == (5 * x + 3) / 2 - - # TODO: Write tests on ExpressionNodes for tree simplification, do the same for multiplication, substraction, etc @pytest.mark.parametrize( "e1, e2, expected", @@ -356,20 +331,6 @@ def test_linear_expression_equality( assert linear_expressions_equal(lhs, rhs) -# TODO: What is the equivalent of this test ? -# def test_linearization_of_non_linear_expressions_should_raise_value_error() -> None: -# x = var("x") -# expr = x.variance() - -# provider = StructureProvider() -# with pytest.raises(ValueError) as exc: -# linearize_expression(expr, provider) -# assert ( -# str(exc.value) -# == "Cannot linearize expression with a non-linear operator: Variance" -# ) - - def test_standalone_constraint() -> None: cst = StandaloneConstraint( var("x"), wrap_in_linear_expr(literal(0)), wrap_in_linear_expr(literal(10)) diff --git a/tests/unittests/expressions/test_linear_expressions.py b/tests/unittests/expressions/test_linear_expressions.py index d8c4d79f..b7e8771c 100644 --- a/tests/unittests/expressions/test_linear_expressions.py +++ b/tests/unittests/expressions/test_linear_expressions.py @@ -14,7 +14,7 @@ import pytest -from andromede.expression.expression import TimeAggregatorNode, expression_range, param +from andromede.expression.expression import expression_range, param from andromede.expression.linear_expression import ( LinearExpression, PortFieldId, From 9431fe759720fcb940462f7ba94ba7817cb317c2 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Tue, 27 Aug 2024 18:39:21 +0200 Subject: [PATCH 50/51] WIP for parsing --- grammar/Expr.g4 | 11 +- .../expression/parsing/antlr/Expr.interp | 8 +- .../expression/parsing/antlr/Expr.tokens | 8 +- .../expression/parsing/antlr/ExprLexer.interp | 7 +- .../expression/parsing/antlr/ExprLexer.py | 1044 +------- .../expression/parsing/antlr/ExprLexer.tokens | 8 +- .../expression/parsing/antlr/ExprParser.py | 2321 +++++------------ .../expression/parsing/antlr/ExprVisitor.py | 89 +- .../expression/parsing/parse_expression.py | 39 +- 9 files changed, 769 insertions(+), 2766 deletions(-) diff --git a/grammar/Expr.g4 b/grammar/Expr.g4 index 072bf52e..56f0babe 100644 --- a/grammar/Expr.g4 +++ b/grammar/Expr.g4 @@ -26,10 +26,9 @@ expr | expr op=('+' | '-') expr # addsub | expr COMPARISON expr # comparison | IDENTIFIER '(' expr ')' # function - | IDENTIFIER '[' shift (',' shift)* ']' # timeShift - | IDENTIFIER '[' expr (',' expr )* ']' # timeIndex - | IDENTIFIER '[' shift1=shift '..' shift2=shift ']' # timeShiftRange - | IDENTIFIER '[' expr '..' expr ']' # timeRange + | IDENTIFIER '[' shift ']' # timeShift + | IDENTIFIER '[' expr ']' # timeIndex + | TIME_SUM '(' (expr | shift | timeShiftRange | timeRange) ',' IDENTIFIER ')' #timeSum ; atom @@ -37,6 +36,9 @@ atom | IDENTIFIER # identifier ; +timeShiftRange: shift1=shift '..' shift2=shift; +timeRange: expr1=expr '..' expr2=expr; + // a shift is required to be either "t" or "t + ..." or "t - ..." // Note: simply defining it as "shift: TIME ('+' | '-') expr" won't work // because the minus sign will not have the expected precedence: @@ -74,5 +76,6 @@ NUMBER : DIGIT+ ('.' DIGIT+)?; TIME : 't'; IDENTIFIER : CHAR CHAR_OR_DIGIT*; COMPARISON : ( '=' | '>=' | '<=' ); +TIME_SUM : 'sum'; WS: (' ' | '\t' | '\r'| '\n') -> skip; diff --git a/src/andromede/expression/parsing/antlr/Expr.interp b/src/andromede/expression/parsing/antlr/Expr.interp index bf05ae28..18ea1ac5 100644 --- a/src/andromede/expression/parsing/antlr/Expr.interp +++ b/src/andromede/expression/parsing/antlr/Expr.interp @@ -8,13 +8,14 @@ null '*' '+' '[' -',' ']' +',' '..' null 't' null null +'sum' null token symbolic names: @@ -34,16 +35,19 @@ NUMBER TIME IDENTIFIER COMPARISON +TIME_SUM WS rule names: fullexpr expr atom +timeShiftRange +timeRange shift shift_expr right_expr atn: -[4, 1, 16, 131, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 37, 8, 1, 10, 1, 12, 1, 40, 9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 49, 8, 1, 10, 1, 12, 1, 52, 9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 70, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 81, 8, 1, 10, 1, 12, 1, 84, 9, 1, 1, 2, 1, 2, 3, 2, 88, 8, 2, 1, 3, 1, 3, 3, 3, 92, 8, 3, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 3, 4, 102, 8, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 5, 4, 110, 8, 4, 10, 4, 12, 4, 113, 9, 4, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 1, 5, 3, 5, 121, 8, 5, 1, 5, 1, 5, 1, 5, 5, 5, 126, 8, 5, 10, 5, 12, 5, 129, 9, 5, 1, 5, 0, 3, 2, 8, 10, 6, 0, 2, 4, 6, 8, 10, 0, 2, 1, 0, 5, 6, 2, 0, 2, 2, 7, 7, 144, 0, 12, 1, 0, 0, 0, 2, 69, 1, 0, 0, 0, 4, 87, 1, 0, 0, 0, 6, 89, 1, 0, 0, 0, 8, 101, 1, 0, 0, 0, 10, 120, 1, 0, 0, 0, 12, 13, 3, 2, 1, 0, 13, 14, 5, 0, 0, 1, 14, 1, 1, 0, 0, 0, 15, 16, 6, 1, -1, 0, 16, 70, 3, 4, 2, 0, 17, 18, 5, 14, 0, 0, 18, 19, 5, 1, 0, 0, 19, 70, 5, 14, 0, 0, 20, 21, 5, 2, 0, 0, 21, 70, 3, 2, 1, 10, 22, 23, 5, 3, 0, 0, 23, 24, 3, 2, 1, 0, 24, 25, 5, 4, 0, 0, 25, 70, 1, 0, 0, 0, 26, 27, 5, 14, 0, 0, 27, 28, 5, 3, 0, 0, 28, 29, 3, 2, 1, 0, 29, 30, 5, 4, 0, 0, 30, 70, 1, 0, 0, 0, 31, 32, 5, 14, 0, 0, 32, 33, 5, 8, 0, 0, 33, 38, 3, 6, 3, 0, 34, 35, 5, 9, 0, 0, 35, 37, 3, 6, 3, 0, 36, 34, 1, 0, 0, 0, 37, 40, 1, 0, 0, 0, 38, 36, 1, 0, 0, 0, 38, 39, 1, 0, 0, 0, 39, 41, 1, 0, 0, 0, 40, 38, 1, 0, 0, 0, 41, 42, 5, 10, 0, 0, 42, 70, 1, 0, 0, 0, 43, 44, 5, 14, 0, 0, 44, 45, 5, 8, 0, 0, 45, 50, 3, 2, 1, 0, 46, 47, 5, 9, 0, 0, 47, 49, 3, 2, 1, 0, 48, 46, 1, 0, 0, 0, 49, 52, 1, 0, 0, 0, 50, 48, 1, 0, 0, 0, 50, 51, 1, 0, 0, 0, 51, 53, 1, 0, 0, 0, 52, 50, 1, 0, 0, 0, 53, 54, 5, 10, 0, 0, 54, 70, 1, 0, 0, 0, 55, 56, 5, 14, 0, 0, 56, 57, 5, 8, 0, 0, 57, 58, 3, 6, 3, 0, 58, 59, 5, 11, 0, 0, 59, 60, 3, 6, 3, 0, 60, 61, 5, 10, 0, 0, 61, 70, 1, 0, 0, 0, 62, 63, 5, 14, 0, 0, 63, 64, 5, 8, 0, 0, 64, 65, 3, 2, 1, 0, 65, 66, 5, 11, 0, 0, 66, 67, 3, 2, 1, 0, 67, 68, 5, 10, 0, 0, 68, 70, 1, 0, 0, 0, 69, 15, 1, 0, 0, 0, 69, 17, 1, 0, 0, 0, 69, 20, 1, 0, 0, 0, 69, 22, 1, 0, 0, 0, 69, 26, 1, 0, 0, 0, 69, 31, 1, 0, 0, 0, 69, 43, 1, 0, 0, 0, 69, 55, 1, 0, 0, 0, 69, 62, 1, 0, 0, 0, 70, 82, 1, 0, 0, 0, 71, 72, 10, 8, 0, 0, 72, 73, 7, 0, 0, 0, 73, 81, 3, 2, 1, 9, 74, 75, 10, 7, 0, 0, 75, 76, 7, 1, 0, 0, 76, 81, 3, 2, 1, 8, 77, 78, 10, 6, 0, 0, 78, 79, 5, 15, 0, 0, 79, 81, 3, 2, 1, 7, 80, 71, 1, 0, 0, 0, 80, 74, 1, 0, 0, 0, 80, 77, 1, 0, 0, 0, 81, 84, 1, 0, 0, 0, 82, 80, 1, 0, 0, 0, 82, 83, 1, 0, 0, 0, 83, 3, 1, 0, 0, 0, 84, 82, 1, 0, 0, 0, 85, 88, 5, 12, 0, 0, 86, 88, 5, 14, 0, 0, 87, 85, 1, 0, 0, 0, 87, 86, 1, 0, 0, 0, 88, 5, 1, 0, 0, 0, 89, 91, 5, 13, 0, 0, 90, 92, 3, 8, 4, 0, 91, 90, 1, 0, 0, 0, 91, 92, 1, 0, 0, 0, 92, 7, 1, 0, 0, 0, 93, 94, 6, 4, -1, 0, 94, 95, 7, 1, 0, 0, 95, 102, 3, 4, 2, 0, 96, 97, 7, 1, 0, 0, 97, 98, 5, 3, 0, 0, 98, 99, 3, 2, 1, 0, 99, 100, 5, 4, 0, 0, 100, 102, 1, 0, 0, 0, 101, 93, 1, 0, 0, 0, 101, 96, 1, 0, 0, 0, 102, 111, 1, 0, 0, 0, 103, 104, 10, 4, 0, 0, 104, 105, 7, 0, 0, 0, 105, 110, 3, 10, 5, 0, 106, 107, 10, 3, 0, 0, 107, 108, 7, 1, 0, 0, 108, 110, 3, 10, 5, 0, 109, 103, 1, 0, 0, 0, 109, 106, 1, 0, 0, 0, 110, 113, 1, 0, 0, 0, 111, 109, 1, 0, 0, 0, 111, 112, 1, 0, 0, 0, 112, 9, 1, 0, 0, 0, 113, 111, 1, 0, 0, 0, 114, 115, 6, 5, -1, 0, 115, 116, 5, 3, 0, 0, 116, 117, 3, 2, 1, 0, 117, 118, 5, 4, 0, 0, 118, 121, 1, 0, 0, 0, 119, 121, 3, 4, 2, 0, 120, 114, 1, 0, 0, 0, 120, 119, 1, 0, 0, 0, 121, 127, 1, 0, 0, 0, 122, 123, 10, 3, 0, 0, 123, 124, 7, 0, 0, 0, 124, 126, 3, 10, 5, 4, 125, 122, 1, 0, 0, 0, 126, 129, 1, 0, 0, 0, 127, 125, 1, 0, 0, 0, 127, 128, 1, 0, 0, 0, 128, 11, 1, 0, 0, 0, 129, 127, 1, 0, 0, 0, 12, 38, 50, 69, 80, 82, 87, 91, 101, 109, 111, 120, 127] \ No newline at end of file +[4, 1, 17, 127, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 52, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 58, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 69, 8, 1, 10, 1, 12, 1, 72, 9, 1, 1, 2, 1, 2, 3, 2, 76, 8, 2, 1, 3, 1, 3, 1, 3, 1, 3, 1, 4, 1, 4, 1, 4, 1, 4, 1, 5, 1, 5, 3, 5, 88, 8, 5, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 3, 6, 98, 8, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 5, 6, 106, 8, 6, 10, 6, 12, 6, 109, 9, 6, 1, 7, 1, 7, 1, 7, 1, 7, 1, 7, 1, 7, 3, 7, 117, 8, 7, 1, 7, 1, 7, 1, 7, 5, 7, 122, 8, 7, 10, 7, 12, 7, 125, 9, 7, 1, 7, 0, 3, 2, 12, 14, 8, 0, 2, 4, 6, 8, 10, 12, 14, 0, 2, 1, 0, 5, 6, 2, 0, 2, 2, 7, 7, 138, 0, 16, 1, 0, 0, 0, 2, 57, 1, 0, 0, 0, 4, 75, 1, 0, 0, 0, 6, 77, 1, 0, 0, 0, 8, 81, 1, 0, 0, 0, 10, 85, 1, 0, 0, 0, 12, 97, 1, 0, 0, 0, 14, 116, 1, 0, 0, 0, 16, 17, 3, 2, 1, 0, 17, 18, 5, 0, 0, 1, 18, 1, 1, 0, 0, 0, 19, 20, 6, 1, -1, 0, 20, 58, 3, 4, 2, 0, 21, 22, 5, 14, 0, 0, 22, 23, 5, 1, 0, 0, 23, 58, 5, 14, 0, 0, 24, 25, 5, 2, 0, 0, 25, 58, 3, 2, 1, 9, 26, 27, 5, 3, 0, 0, 27, 28, 3, 2, 1, 0, 28, 29, 5, 4, 0, 0, 29, 58, 1, 0, 0, 0, 30, 31, 5, 14, 0, 0, 31, 32, 5, 3, 0, 0, 32, 33, 3, 2, 1, 0, 33, 34, 5, 4, 0, 0, 34, 58, 1, 0, 0, 0, 35, 36, 5, 14, 0, 0, 36, 37, 5, 8, 0, 0, 37, 38, 3, 10, 5, 0, 38, 39, 5, 9, 0, 0, 39, 58, 1, 0, 0, 0, 40, 41, 5, 14, 0, 0, 41, 42, 5, 8, 0, 0, 42, 43, 3, 2, 1, 0, 43, 44, 5, 9, 0, 0, 44, 58, 1, 0, 0, 0, 45, 46, 5, 16, 0, 0, 46, 51, 5, 3, 0, 0, 47, 52, 3, 2, 1, 0, 48, 52, 3, 10, 5, 0, 49, 52, 3, 6, 3, 0, 50, 52, 3, 8, 4, 0, 51, 47, 1, 0, 0, 0, 51, 48, 1, 0, 0, 0, 51, 49, 1, 0, 0, 0, 51, 50, 1, 0, 0, 0, 52, 53, 1, 0, 0, 0, 53, 54, 5, 10, 0, 0, 54, 55, 5, 14, 0, 0, 55, 56, 5, 4, 0, 0, 56, 58, 1, 0, 0, 0, 57, 19, 1, 0, 0, 0, 57, 21, 1, 0, 0, 0, 57, 24, 1, 0, 0, 0, 57, 26, 1, 0, 0, 0, 57, 30, 1, 0, 0, 0, 57, 35, 1, 0, 0, 0, 57, 40, 1, 0, 0, 0, 57, 45, 1, 0, 0, 0, 58, 70, 1, 0, 0, 0, 59, 60, 10, 7, 0, 0, 60, 61, 7, 0, 0, 0, 61, 69, 3, 2, 1, 8, 62, 63, 10, 6, 0, 0, 63, 64, 7, 1, 0, 0, 64, 69, 3, 2, 1, 7, 65, 66, 10, 5, 0, 0, 66, 67, 5, 15, 0, 0, 67, 69, 3, 2, 1, 6, 68, 59, 1, 0, 0, 0, 68, 62, 1, 0, 0, 0, 68, 65, 1, 0, 0, 0, 69, 72, 1, 0, 0, 0, 70, 68, 1, 0, 0, 0, 70, 71, 1, 0, 0, 0, 71, 3, 1, 0, 0, 0, 72, 70, 1, 0, 0, 0, 73, 76, 5, 12, 0, 0, 74, 76, 5, 14, 0, 0, 75, 73, 1, 0, 0, 0, 75, 74, 1, 0, 0, 0, 76, 5, 1, 0, 0, 0, 77, 78, 3, 10, 5, 0, 78, 79, 5, 11, 0, 0, 79, 80, 3, 10, 5, 0, 80, 7, 1, 0, 0, 0, 81, 82, 3, 2, 1, 0, 82, 83, 5, 11, 0, 0, 83, 84, 3, 2, 1, 0, 84, 9, 1, 0, 0, 0, 85, 87, 5, 13, 0, 0, 86, 88, 3, 12, 6, 0, 87, 86, 1, 0, 0, 0, 87, 88, 1, 0, 0, 0, 88, 11, 1, 0, 0, 0, 89, 90, 6, 6, -1, 0, 90, 91, 7, 1, 0, 0, 91, 98, 3, 4, 2, 0, 92, 93, 7, 1, 0, 0, 93, 94, 5, 3, 0, 0, 94, 95, 3, 2, 1, 0, 95, 96, 5, 4, 0, 0, 96, 98, 1, 0, 0, 0, 97, 89, 1, 0, 0, 0, 97, 92, 1, 0, 0, 0, 98, 107, 1, 0, 0, 0, 99, 100, 10, 4, 0, 0, 100, 101, 7, 0, 0, 0, 101, 106, 3, 14, 7, 0, 102, 103, 10, 3, 0, 0, 103, 104, 7, 1, 0, 0, 104, 106, 3, 14, 7, 0, 105, 99, 1, 0, 0, 0, 105, 102, 1, 0, 0, 0, 106, 109, 1, 0, 0, 0, 107, 105, 1, 0, 0, 0, 107, 108, 1, 0, 0, 0, 108, 13, 1, 0, 0, 0, 109, 107, 1, 0, 0, 0, 110, 111, 6, 7, -1, 0, 111, 112, 5, 3, 0, 0, 112, 113, 3, 2, 1, 0, 113, 114, 5, 4, 0, 0, 114, 117, 1, 0, 0, 0, 115, 117, 3, 4, 2, 0, 116, 110, 1, 0, 0, 0, 116, 115, 1, 0, 0, 0, 117, 123, 1, 0, 0, 0, 118, 119, 10, 3, 0, 0, 119, 120, 7, 0, 0, 0, 120, 122, 3, 14, 7, 4, 121, 118, 1, 0, 0, 0, 122, 125, 1, 0, 0, 0, 123, 121, 1, 0, 0, 0, 123, 124, 1, 0, 0, 0, 124, 15, 1, 0, 0, 0, 125, 123, 1, 0, 0, 0, 11, 51, 57, 68, 70, 75, 87, 97, 105, 107, 116, 123] \ No newline at end of file diff --git a/src/andromede/expression/parsing/antlr/Expr.tokens b/src/andromede/expression/parsing/antlr/Expr.tokens index 9401c83a..1c48b7d6 100644 --- a/src/andromede/expression/parsing/antlr/Expr.tokens +++ b/src/andromede/expression/parsing/antlr/Expr.tokens @@ -13,7 +13,8 @@ NUMBER=12 TIME=13 IDENTIFIER=14 COMPARISON=15 -WS=16 +TIME_SUM=16 +WS=17 '.'=1 '-'=2 '('=3 @@ -22,7 +23,8 @@ WS=16 '*'=6 '+'=7 '['=8 -','=9 -']'=10 +']'=9 +','=10 '..'=11 't'=13 +'sum'=16 diff --git a/src/andromede/expression/parsing/antlr/ExprLexer.interp b/src/andromede/expression/parsing/antlr/ExprLexer.interp index 2e85e1b7..c29ef905 100644 --- a/src/andromede/expression/parsing/antlr/ExprLexer.interp +++ b/src/andromede/expression/parsing/antlr/ExprLexer.interp @@ -8,13 +8,14 @@ null '*' '+' '[' -',' ']' +',' '..' null 't' null null +'sum' null token symbolic names: @@ -34,6 +35,7 @@ NUMBER TIME IDENTIFIER COMPARISON +TIME_SUM WS rule names: @@ -55,6 +57,7 @@ NUMBER TIME IDENTIFIER COMPARISON +TIME_SUM WS channel names: @@ -65,4 +68,4 @@ mode names: DEFAULT_MODE atn: -[4, 0, 16, 103, 6, -1, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 2, 8, 7, 8, 2, 9, 7, 9, 2, 10, 7, 10, 2, 11, 7, 11, 2, 12, 7, 12, 2, 13, 7, 13, 2, 14, 7, 14, 2, 15, 7, 15, 2, 16, 7, 16, 2, 17, 7, 17, 2, 18, 7, 18, 1, 0, 1, 0, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 3, 1, 4, 1, 4, 1, 5, 1, 5, 1, 6, 1, 6, 1, 7, 1, 7, 1, 8, 1, 8, 1, 9, 1, 9, 1, 10, 1, 10, 1, 10, 1, 11, 1, 11, 1, 12, 1, 12, 1, 13, 1, 13, 3, 13, 69, 8, 13, 1, 14, 4, 14, 72, 8, 14, 11, 14, 12, 14, 73, 1, 14, 1, 14, 4, 14, 78, 8, 14, 11, 14, 12, 14, 79, 3, 14, 82, 8, 14, 1, 15, 1, 15, 1, 16, 1, 16, 5, 16, 88, 8, 16, 10, 16, 12, 16, 91, 9, 16, 1, 17, 1, 17, 1, 17, 1, 17, 1, 17, 3, 17, 98, 8, 17, 1, 18, 1, 18, 1, 18, 1, 18, 0, 0, 19, 1, 1, 3, 2, 5, 3, 7, 4, 9, 5, 11, 6, 13, 7, 15, 8, 17, 9, 19, 10, 21, 11, 23, 0, 25, 0, 27, 0, 29, 12, 31, 13, 33, 14, 35, 15, 37, 16, 1, 0, 3, 1, 0, 48, 57, 3, 0, 65, 90, 95, 95, 97, 122, 3, 0, 9, 10, 13, 13, 32, 32, 106, 0, 1, 1, 0, 0, 0, 0, 3, 1, 0, 0, 0, 0, 5, 1, 0, 0, 0, 0, 7, 1, 0, 0, 0, 0, 9, 1, 0, 0, 0, 0, 11, 1, 0, 0, 0, 0, 13, 1, 0, 0, 0, 0, 15, 1, 0, 0, 0, 0, 17, 1, 0, 0, 0, 0, 19, 1, 0, 0, 0, 0, 21, 1, 0, 0, 0, 0, 29, 1, 0, 0, 0, 0, 31, 1, 0, 0, 0, 0, 33, 1, 0, 0, 0, 0, 35, 1, 0, 0, 0, 0, 37, 1, 0, 0, 0, 1, 39, 1, 0, 0, 0, 3, 41, 1, 0, 0, 0, 5, 43, 1, 0, 0, 0, 7, 45, 1, 0, 0, 0, 9, 47, 1, 0, 0, 0, 11, 49, 1, 0, 0, 0, 13, 51, 1, 0, 0, 0, 15, 53, 1, 0, 0, 0, 17, 55, 1, 0, 0, 0, 19, 57, 1, 0, 0, 0, 21, 59, 1, 0, 0, 0, 23, 62, 1, 0, 0, 0, 25, 64, 1, 0, 0, 0, 27, 68, 1, 0, 0, 0, 29, 71, 1, 0, 0, 0, 31, 83, 1, 0, 0, 0, 33, 85, 1, 0, 0, 0, 35, 97, 1, 0, 0, 0, 37, 99, 1, 0, 0, 0, 39, 40, 5, 46, 0, 0, 40, 2, 1, 0, 0, 0, 41, 42, 5, 45, 0, 0, 42, 4, 1, 0, 0, 0, 43, 44, 5, 40, 0, 0, 44, 6, 1, 0, 0, 0, 45, 46, 5, 41, 0, 0, 46, 8, 1, 0, 0, 0, 47, 48, 5, 47, 0, 0, 48, 10, 1, 0, 0, 0, 49, 50, 5, 42, 0, 0, 50, 12, 1, 0, 0, 0, 51, 52, 5, 43, 0, 0, 52, 14, 1, 0, 0, 0, 53, 54, 5, 91, 0, 0, 54, 16, 1, 0, 0, 0, 55, 56, 5, 44, 0, 0, 56, 18, 1, 0, 0, 0, 57, 58, 5, 93, 0, 0, 58, 20, 1, 0, 0, 0, 59, 60, 5, 46, 0, 0, 60, 61, 5, 46, 0, 0, 61, 22, 1, 0, 0, 0, 62, 63, 7, 0, 0, 0, 63, 24, 1, 0, 0, 0, 64, 65, 7, 1, 0, 0, 65, 26, 1, 0, 0, 0, 66, 69, 3, 25, 12, 0, 67, 69, 3, 23, 11, 0, 68, 66, 1, 0, 0, 0, 68, 67, 1, 0, 0, 0, 69, 28, 1, 0, 0, 0, 70, 72, 3, 23, 11, 0, 71, 70, 1, 0, 0, 0, 72, 73, 1, 0, 0, 0, 73, 71, 1, 0, 0, 0, 73, 74, 1, 0, 0, 0, 74, 81, 1, 0, 0, 0, 75, 77, 5, 46, 0, 0, 76, 78, 3, 23, 11, 0, 77, 76, 1, 0, 0, 0, 78, 79, 1, 0, 0, 0, 79, 77, 1, 0, 0, 0, 79, 80, 1, 0, 0, 0, 80, 82, 1, 0, 0, 0, 81, 75, 1, 0, 0, 0, 81, 82, 1, 0, 0, 0, 82, 30, 1, 0, 0, 0, 83, 84, 5, 116, 0, 0, 84, 32, 1, 0, 0, 0, 85, 89, 3, 25, 12, 0, 86, 88, 3, 27, 13, 0, 87, 86, 1, 0, 0, 0, 88, 91, 1, 0, 0, 0, 89, 87, 1, 0, 0, 0, 89, 90, 1, 0, 0, 0, 90, 34, 1, 0, 0, 0, 91, 89, 1, 0, 0, 0, 92, 98, 5, 61, 0, 0, 93, 94, 5, 62, 0, 0, 94, 98, 5, 61, 0, 0, 95, 96, 5, 60, 0, 0, 96, 98, 5, 61, 0, 0, 97, 92, 1, 0, 0, 0, 97, 93, 1, 0, 0, 0, 97, 95, 1, 0, 0, 0, 98, 36, 1, 0, 0, 0, 99, 100, 7, 2, 0, 0, 100, 101, 1, 0, 0, 0, 101, 102, 6, 18, 0, 0, 102, 38, 1, 0, 0, 0, 7, 0, 68, 73, 79, 81, 89, 97, 1, 6, 0, 0] \ No newline at end of file +[4, 0, 17, 109, 6, -1, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 2, 8, 7, 8, 2, 9, 7, 9, 2, 10, 7, 10, 2, 11, 7, 11, 2, 12, 7, 12, 2, 13, 7, 13, 2, 14, 7, 14, 2, 15, 7, 15, 2, 16, 7, 16, 2, 17, 7, 17, 2, 18, 7, 18, 2, 19, 7, 19, 1, 0, 1, 0, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 3, 1, 4, 1, 4, 1, 5, 1, 5, 1, 6, 1, 6, 1, 7, 1, 7, 1, 8, 1, 8, 1, 9, 1, 9, 1, 10, 1, 10, 1, 10, 1, 11, 1, 11, 1, 12, 1, 12, 1, 13, 1, 13, 3, 13, 71, 8, 13, 1, 14, 4, 14, 74, 8, 14, 11, 14, 12, 14, 75, 1, 14, 1, 14, 4, 14, 80, 8, 14, 11, 14, 12, 14, 81, 3, 14, 84, 8, 14, 1, 15, 1, 15, 1, 16, 1, 16, 5, 16, 90, 8, 16, 10, 16, 12, 16, 93, 9, 16, 1, 17, 1, 17, 1, 17, 1, 17, 1, 17, 3, 17, 100, 8, 17, 1, 18, 1, 18, 1, 18, 1, 18, 1, 19, 1, 19, 1, 19, 1, 19, 0, 0, 20, 1, 1, 3, 2, 5, 3, 7, 4, 9, 5, 11, 6, 13, 7, 15, 8, 17, 9, 19, 10, 21, 11, 23, 0, 25, 0, 27, 0, 29, 12, 31, 13, 33, 14, 35, 15, 37, 16, 39, 17, 1, 0, 3, 1, 0, 48, 57, 3, 0, 65, 90, 95, 95, 97, 122, 3, 0, 9, 10, 13, 13, 32, 32, 112, 0, 1, 1, 0, 0, 0, 0, 3, 1, 0, 0, 0, 0, 5, 1, 0, 0, 0, 0, 7, 1, 0, 0, 0, 0, 9, 1, 0, 0, 0, 0, 11, 1, 0, 0, 0, 0, 13, 1, 0, 0, 0, 0, 15, 1, 0, 0, 0, 0, 17, 1, 0, 0, 0, 0, 19, 1, 0, 0, 0, 0, 21, 1, 0, 0, 0, 0, 29, 1, 0, 0, 0, 0, 31, 1, 0, 0, 0, 0, 33, 1, 0, 0, 0, 0, 35, 1, 0, 0, 0, 0, 37, 1, 0, 0, 0, 0, 39, 1, 0, 0, 0, 1, 41, 1, 0, 0, 0, 3, 43, 1, 0, 0, 0, 5, 45, 1, 0, 0, 0, 7, 47, 1, 0, 0, 0, 9, 49, 1, 0, 0, 0, 11, 51, 1, 0, 0, 0, 13, 53, 1, 0, 0, 0, 15, 55, 1, 0, 0, 0, 17, 57, 1, 0, 0, 0, 19, 59, 1, 0, 0, 0, 21, 61, 1, 0, 0, 0, 23, 64, 1, 0, 0, 0, 25, 66, 1, 0, 0, 0, 27, 70, 1, 0, 0, 0, 29, 73, 1, 0, 0, 0, 31, 85, 1, 0, 0, 0, 33, 87, 1, 0, 0, 0, 35, 99, 1, 0, 0, 0, 37, 101, 1, 0, 0, 0, 39, 105, 1, 0, 0, 0, 41, 42, 5, 46, 0, 0, 42, 2, 1, 0, 0, 0, 43, 44, 5, 45, 0, 0, 44, 4, 1, 0, 0, 0, 45, 46, 5, 40, 0, 0, 46, 6, 1, 0, 0, 0, 47, 48, 5, 41, 0, 0, 48, 8, 1, 0, 0, 0, 49, 50, 5, 47, 0, 0, 50, 10, 1, 0, 0, 0, 51, 52, 5, 42, 0, 0, 52, 12, 1, 0, 0, 0, 53, 54, 5, 43, 0, 0, 54, 14, 1, 0, 0, 0, 55, 56, 5, 91, 0, 0, 56, 16, 1, 0, 0, 0, 57, 58, 5, 93, 0, 0, 58, 18, 1, 0, 0, 0, 59, 60, 5, 44, 0, 0, 60, 20, 1, 0, 0, 0, 61, 62, 5, 46, 0, 0, 62, 63, 5, 46, 0, 0, 63, 22, 1, 0, 0, 0, 64, 65, 7, 0, 0, 0, 65, 24, 1, 0, 0, 0, 66, 67, 7, 1, 0, 0, 67, 26, 1, 0, 0, 0, 68, 71, 3, 25, 12, 0, 69, 71, 3, 23, 11, 0, 70, 68, 1, 0, 0, 0, 70, 69, 1, 0, 0, 0, 71, 28, 1, 0, 0, 0, 72, 74, 3, 23, 11, 0, 73, 72, 1, 0, 0, 0, 74, 75, 1, 0, 0, 0, 75, 73, 1, 0, 0, 0, 75, 76, 1, 0, 0, 0, 76, 83, 1, 0, 0, 0, 77, 79, 5, 46, 0, 0, 78, 80, 3, 23, 11, 0, 79, 78, 1, 0, 0, 0, 80, 81, 1, 0, 0, 0, 81, 79, 1, 0, 0, 0, 81, 82, 1, 0, 0, 0, 82, 84, 1, 0, 0, 0, 83, 77, 1, 0, 0, 0, 83, 84, 1, 0, 0, 0, 84, 30, 1, 0, 0, 0, 85, 86, 5, 116, 0, 0, 86, 32, 1, 0, 0, 0, 87, 91, 3, 25, 12, 0, 88, 90, 3, 27, 13, 0, 89, 88, 1, 0, 0, 0, 90, 93, 1, 0, 0, 0, 91, 89, 1, 0, 0, 0, 91, 92, 1, 0, 0, 0, 92, 34, 1, 0, 0, 0, 93, 91, 1, 0, 0, 0, 94, 100, 5, 61, 0, 0, 95, 96, 5, 62, 0, 0, 96, 100, 5, 61, 0, 0, 97, 98, 5, 60, 0, 0, 98, 100, 5, 61, 0, 0, 99, 94, 1, 0, 0, 0, 99, 95, 1, 0, 0, 0, 99, 97, 1, 0, 0, 0, 100, 36, 1, 0, 0, 0, 101, 102, 5, 115, 0, 0, 102, 103, 5, 117, 0, 0, 103, 104, 5, 109, 0, 0, 104, 38, 1, 0, 0, 0, 105, 106, 7, 2, 0, 0, 106, 107, 1, 0, 0, 0, 107, 108, 6, 19, 0, 0, 108, 40, 1, 0, 0, 0, 7, 0, 70, 75, 81, 83, 91, 99, 1, 6, 0, 0] \ No newline at end of file diff --git a/src/andromede/expression/parsing/antlr/ExprLexer.py b/src/andromede/expression/parsing/antlr/ExprLexer.py index 1ad7f368..aa15abd0 100644 --- a/src/andromede/expression/parsing/antlr/ExprLexer.py +++ b/src/andromede/expression/parsing/antlr/ExprLexer.py @@ -1,9 +1,7 @@ -# Generated from Expr.g4 by ANTLR 4.13.1 -import sys -from io import StringIO - +# Generated from Expr.g4 by ANTLR 4.13.2 from antlr4 import * - +from io import StringIO +import sys if sys.version_info[1] > 5: from typing import TextIO else: @@ -12,945 +10,50 @@ def serializedATN(): return [ - 4, - 0, - 16, - 103, - 6, - -1, - 2, - 0, - 7, - 0, - 2, - 1, - 7, - 1, - 2, - 2, - 7, - 2, - 2, - 3, - 7, - 3, - 2, - 4, - 7, - 4, - 2, - 5, - 7, - 5, - 2, - 6, - 7, - 6, - 2, - 7, - 7, - 7, - 2, - 8, - 7, - 8, - 2, - 9, - 7, - 9, - 2, - 10, - 7, - 10, - 2, - 11, - 7, - 11, - 2, - 12, - 7, - 12, - 2, - 13, - 7, - 13, - 2, - 14, - 7, - 14, - 2, - 15, - 7, - 15, - 2, - 16, - 7, - 16, - 2, - 17, - 7, - 17, - 2, - 18, - 7, - 18, - 1, - 0, - 1, - 0, - 1, - 1, - 1, - 1, - 1, - 2, - 1, - 2, - 1, - 3, - 1, - 3, - 1, - 4, - 1, - 4, - 1, - 5, - 1, - 5, - 1, - 6, - 1, - 6, - 1, - 7, - 1, - 7, - 1, - 8, - 1, - 8, - 1, - 9, - 1, - 9, - 1, - 10, - 1, - 10, - 1, - 10, - 1, - 11, - 1, - 11, - 1, - 12, - 1, - 12, - 1, - 13, - 1, - 13, - 3, - 13, - 69, - 8, - 13, - 1, - 14, - 4, - 14, - 72, - 8, - 14, - 11, - 14, - 12, - 14, - 73, - 1, - 14, - 1, - 14, - 4, - 14, - 78, - 8, - 14, - 11, - 14, - 12, - 14, - 79, - 3, - 14, - 82, - 8, - 14, - 1, - 15, - 1, - 15, - 1, - 16, - 1, - 16, - 5, - 16, - 88, - 8, - 16, - 10, - 16, - 12, - 16, - 91, - 9, - 16, - 1, - 17, - 1, - 17, - 1, - 17, - 1, - 17, - 1, - 17, - 3, - 17, - 98, - 8, - 17, - 1, - 18, - 1, - 18, - 1, - 18, - 1, - 18, - 0, - 0, - 19, - 1, - 1, - 3, - 2, - 5, - 3, - 7, - 4, - 9, - 5, - 11, - 6, - 13, - 7, - 15, - 8, - 17, - 9, - 19, - 10, - 21, - 11, - 23, - 0, - 25, - 0, - 27, - 0, - 29, - 12, - 31, - 13, - 33, - 14, - 35, - 15, - 37, - 16, - 1, - 0, - 3, - 1, - 0, - 48, - 57, - 3, - 0, - 65, - 90, - 95, - 95, - 97, - 122, - 3, - 0, - 9, - 10, - 13, - 13, - 32, - 32, - 106, - 0, - 1, - 1, - 0, - 0, - 0, - 0, - 3, - 1, - 0, - 0, - 0, - 0, - 5, - 1, - 0, - 0, - 0, - 0, - 7, - 1, - 0, - 0, - 0, - 0, - 9, - 1, - 0, - 0, - 0, - 0, - 11, - 1, - 0, - 0, - 0, - 0, - 13, - 1, - 0, - 0, - 0, - 0, - 15, - 1, - 0, - 0, - 0, - 0, - 17, - 1, - 0, - 0, - 0, - 0, - 19, - 1, - 0, - 0, - 0, - 0, - 21, - 1, - 0, - 0, - 0, - 0, - 29, - 1, - 0, - 0, - 0, - 0, - 31, - 1, - 0, - 0, - 0, - 0, - 33, - 1, - 0, - 0, - 0, - 0, - 35, - 1, - 0, - 0, - 0, - 0, - 37, - 1, - 0, - 0, - 0, - 1, - 39, - 1, - 0, - 0, - 0, - 3, - 41, - 1, - 0, - 0, - 0, - 5, - 43, - 1, - 0, - 0, - 0, - 7, - 45, - 1, - 0, - 0, - 0, - 9, - 47, - 1, - 0, - 0, - 0, - 11, - 49, - 1, - 0, - 0, - 0, - 13, - 51, - 1, - 0, - 0, - 0, - 15, - 53, - 1, - 0, - 0, - 0, - 17, - 55, - 1, - 0, - 0, - 0, - 19, - 57, - 1, - 0, - 0, - 0, - 21, - 59, - 1, - 0, - 0, - 0, - 23, - 62, - 1, - 0, - 0, - 0, - 25, - 64, - 1, - 0, - 0, - 0, - 27, - 68, - 1, - 0, - 0, - 0, - 29, - 71, - 1, - 0, - 0, - 0, - 31, - 83, - 1, - 0, - 0, - 0, - 33, - 85, - 1, - 0, - 0, - 0, - 35, - 97, - 1, - 0, - 0, - 0, - 37, - 99, - 1, - 0, - 0, - 0, - 39, - 40, - 5, - 46, - 0, - 0, - 40, - 2, - 1, - 0, - 0, - 0, - 41, - 42, - 5, - 45, - 0, - 0, - 42, - 4, - 1, - 0, - 0, - 0, - 43, - 44, - 5, - 40, - 0, - 0, - 44, - 6, - 1, - 0, - 0, - 0, - 45, - 46, - 5, - 41, - 0, - 0, - 46, - 8, - 1, - 0, - 0, - 0, - 47, - 48, - 5, - 47, - 0, - 0, - 48, - 10, - 1, - 0, - 0, - 0, - 49, - 50, - 5, - 42, - 0, - 0, - 50, - 12, - 1, - 0, - 0, - 0, - 51, - 52, - 5, - 43, - 0, - 0, - 52, - 14, - 1, - 0, - 0, - 0, - 53, - 54, - 5, - 91, - 0, - 0, - 54, - 16, - 1, - 0, - 0, - 0, - 55, - 56, - 5, - 44, - 0, - 0, - 56, - 18, - 1, - 0, - 0, - 0, - 57, - 58, - 5, - 93, - 0, - 0, - 58, - 20, - 1, - 0, - 0, - 0, - 59, - 60, - 5, - 46, - 0, - 0, - 60, - 61, - 5, - 46, - 0, - 0, - 61, - 22, - 1, - 0, - 0, - 0, - 62, - 63, - 7, - 0, - 0, - 0, - 63, - 24, - 1, - 0, - 0, - 0, - 64, - 65, - 7, - 1, - 0, - 0, - 65, - 26, - 1, - 0, - 0, - 0, - 66, - 69, - 3, - 25, - 12, - 0, - 67, - 69, - 3, - 23, - 11, - 0, - 68, - 66, - 1, - 0, - 0, - 0, - 68, - 67, - 1, - 0, - 0, - 0, - 69, - 28, - 1, - 0, - 0, - 0, - 70, - 72, - 3, - 23, - 11, - 0, - 71, - 70, - 1, - 0, - 0, - 0, - 72, - 73, - 1, - 0, - 0, - 0, - 73, - 71, - 1, - 0, - 0, - 0, - 73, - 74, - 1, - 0, - 0, - 0, - 74, - 81, - 1, - 0, - 0, - 0, - 75, - 77, - 5, - 46, - 0, - 0, - 76, - 78, - 3, - 23, - 11, - 0, - 77, - 76, - 1, - 0, - 0, - 0, - 78, - 79, - 1, - 0, - 0, - 0, - 79, - 77, - 1, - 0, - 0, - 0, - 79, - 80, - 1, - 0, - 0, - 0, - 80, - 82, - 1, - 0, - 0, - 0, - 81, - 75, - 1, - 0, - 0, - 0, - 81, - 82, - 1, - 0, - 0, - 0, - 82, - 30, - 1, - 0, - 0, - 0, - 83, - 84, - 5, - 116, - 0, - 0, - 84, - 32, - 1, - 0, - 0, - 0, - 85, - 89, - 3, - 25, - 12, - 0, - 86, - 88, - 3, - 27, - 13, - 0, - 87, - 86, - 1, - 0, - 0, - 0, - 88, - 91, - 1, - 0, - 0, - 0, - 89, - 87, - 1, - 0, - 0, - 0, - 89, - 90, - 1, - 0, - 0, - 0, - 90, - 34, - 1, - 0, - 0, - 0, - 91, - 89, - 1, - 0, - 0, - 0, - 92, - 98, - 5, - 61, - 0, - 0, - 93, - 94, - 5, - 62, - 0, - 0, - 94, - 98, - 5, - 61, - 0, - 0, - 95, - 96, - 5, - 60, - 0, - 0, - 96, - 98, - 5, - 61, - 0, - 0, - 97, - 92, - 1, - 0, - 0, - 0, - 97, - 93, - 1, - 0, - 0, - 0, - 97, - 95, - 1, - 0, - 0, - 0, - 98, - 36, - 1, - 0, - 0, - 0, - 99, - 100, - 7, - 2, - 0, - 0, - 100, - 101, - 1, - 0, - 0, - 0, - 101, - 102, - 6, - 18, - 0, - 0, - 102, - 38, - 1, - 0, - 0, - 0, - 7, - 0, - 68, - 73, - 79, - 81, - 89, - 97, - 1, - 6, - 0, - 0, + 4,0,17,109,6,-1,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5, + 2,6,7,6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2, + 13,7,13,2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7, + 19,1,0,1,0,1,1,1,1,1,2,1,2,1,3,1,3,1,4,1,4,1,5,1,5,1,6,1,6,1,7,1, + 7,1,8,1,8,1,9,1,9,1,10,1,10,1,10,1,11,1,11,1,12,1,12,1,13,1,13,3, + 13,71,8,13,1,14,4,14,74,8,14,11,14,12,14,75,1,14,1,14,4,14,80,8, + 14,11,14,12,14,81,3,14,84,8,14,1,15,1,15,1,16,1,16,5,16,90,8,16, + 10,16,12,16,93,9,16,1,17,1,17,1,17,1,17,1,17,3,17,100,8,17,1,18, + 1,18,1,18,1,18,1,19,1,19,1,19,1,19,0,0,20,1,1,3,2,5,3,7,4,9,5,11, + 6,13,7,15,8,17,9,19,10,21,11,23,0,25,0,27,0,29,12,31,13,33,14,35, + 15,37,16,39,17,1,0,3,1,0,48,57,3,0,65,90,95,95,97,122,3,0,9,10,13, + 13,32,32,112,0,1,1,0,0,0,0,3,1,0,0,0,0,5,1,0,0,0,0,7,1,0,0,0,0,9, + 1,0,0,0,0,11,1,0,0,0,0,13,1,0,0,0,0,15,1,0,0,0,0,17,1,0,0,0,0,19, + 1,0,0,0,0,21,1,0,0,0,0,29,1,0,0,0,0,31,1,0,0,0,0,33,1,0,0,0,0,35, + 1,0,0,0,0,37,1,0,0,0,0,39,1,0,0,0,1,41,1,0,0,0,3,43,1,0,0,0,5,45, + 1,0,0,0,7,47,1,0,0,0,9,49,1,0,0,0,11,51,1,0,0,0,13,53,1,0,0,0,15, + 55,1,0,0,0,17,57,1,0,0,0,19,59,1,0,0,0,21,61,1,0,0,0,23,64,1,0,0, + 0,25,66,1,0,0,0,27,70,1,0,0,0,29,73,1,0,0,0,31,85,1,0,0,0,33,87, + 1,0,0,0,35,99,1,0,0,0,37,101,1,0,0,0,39,105,1,0,0,0,41,42,5,46,0, + 0,42,2,1,0,0,0,43,44,5,45,0,0,44,4,1,0,0,0,45,46,5,40,0,0,46,6,1, + 0,0,0,47,48,5,41,0,0,48,8,1,0,0,0,49,50,5,47,0,0,50,10,1,0,0,0,51, + 52,5,42,0,0,52,12,1,0,0,0,53,54,5,43,0,0,54,14,1,0,0,0,55,56,5,91, + 0,0,56,16,1,0,0,0,57,58,5,93,0,0,58,18,1,0,0,0,59,60,5,44,0,0,60, + 20,1,0,0,0,61,62,5,46,0,0,62,63,5,46,0,0,63,22,1,0,0,0,64,65,7,0, + 0,0,65,24,1,0,0,0,66,67,7,1,0,0,67,26,1,0,0,0,68,71,3,25,12,0,69, + 71,3,23,11,0,70,68,1,0,0,0,70,69,1,0,0,0,71,28,1,0,0,0,72,74,3,23, + 11,0,73,72,1,0,0,0,74,75,1,0,0,0,75,73,1,0,0,0,75,76,1,0,0,0,76, + 83,1,0,0,0,77,79,5,46,0,0,78,80,3,23,11,0,79,78,1,0,0,0,80,81,1, + 0,0,0,81,79,1,0,0,0,81,82,1,0,0,0,82,84,1,0,0,0,83,77,1,0,0,0,83, + 84,1,0,0,0,84,30,1,0,0,0,85,86,5,116,0,0,86,32,1,0,0,0,87,91,3,25, + 12,0,88,90,3,27,13,0,89,88,1,0,0,0,90,93,1,0,0,0,91,89,1,0,0,0,91, + 92,1,0,0,0,92,34,1,0,0,0,93,91,1,0,0,0,94,100,5,61,0,0,95,96,5,62, + 0,0,96,100,5,61,0,0,97,98,5,60,0,0,98,100,5,61,0,0,99,94,1,0,0,0, + 99,95,1,0,0,0,99,97,1,0,0,0,100,36,1,0,0,0,101,102,5,115,0,0,102, + 103,5,117,0,0,103,104,5,109,0,0,104,38,1,0,0,0,105,106,7,2,0,0,106, + 107,1,0,0,0,107,108,6,19,0,0,108,40,1,0,0,0,7,0,70,75,81,83,91,99, + 1,6,0,0 ] - class ExprLexer(Lexer): + atn = ATNDeserializer().deserialize(serializedATN()) - decisionsToDFA = [DFA(ds, i) for i, ds in enumerate(atn.decisionToState)] + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] T__0 = 1 T__1 = 2 @@ -967,59 +70,32 @@ class ExprLexer(Lexer): TIME = 13 IDENTIFIER = 14 COMPARISON = 15 - WS = 16 + TIME_SUM = 16 + WS = 17 - channelNames = ["DEFAULT_TOKEN_CHANNEL", "HIDDEN"] + channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] - modeNames = ["DEFAULT_MODE"] + modeNames = [ "DEFAULT_MODE" ] - literalNames = [ - "", - "'.'", - "'-'", - "'('", - "')'", - "'/'", - "'*'", - "'+'", - "'['", - "','", - "']'", - "'..'", - "'t'", - ] + literalNames = [ "", + "'.'", "'-'", "'('", "')'", "'/'", "'*'", "'+'", "'['", "']'", + "','", "'..'", "'t'", "'sum'" ] - symbolicNames = ["", "NUMBER", "TIME", "IDENTIFIER", "COMPARISON", "WS"] + symbolicNames = [ "", + "NUMBER", "TIME", "IDENTIFIER", "COMPARISON", "TIME_SUM", "WS" ] - ruleNames = [ - "T__0", - "T__1", - "T__2", - "T__3", - "T__4", - "T__5", - "T__6", - "T__7", - "T__8", - "T__9", - "T__10", - "DIGIT", - "CHAR", - "CHAR_OR_DIGIT", - "NUMBER", - "TIME", - "IDENTIFIER", - "COMPARISON", - "WS", - ] + ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6", + "T__7", "T__8", "T__9", "T__10", "DIGIT", "CHAR", "CHAR_OR_DIGIT", + "NUMBER", "TIME", "IDENTIFIER", "COMPARISON", "TIME_SUM", + "WS" ] grammarFileName = "Expr.g4" - def __init__(self, input=None, output: TextIO = sys.stdout): + def __init__(self, input=None, output:TextIO = sys.stdout): super().__init__(input, output) - self.checkVersion("4.13.1") - self._interp = LexerATNSimulator( - self, self.atn, self.decisionsToDFA, PredictionContextCache() - ) + self.checkVersion("4.13.2") + self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) self._actions = None self._predicates = None + + diff --git a/src/andromede/expression/parsing/antlr/ExprLexer.tokens b/src/andromede/expression/parsing/antlr/ExprLexer.tokens index 9401c83a..1c48b7d6 100644 --- a/src/andromede/expression/parsing/antlr/ExprLexer.tokens +++ b/src/andromede/expression/parsing/antlr/ExprLexer.tokens @@ -13,7 +13,8 @@ NUMBER=12 TIME=13 IDENTIFIER=14 COMPARISON=15 -WS=16 +TIME_SUM=16 +WS=17 '.'=1 '-'=2 '('=3 @@ -22,7 +23,8 @@ WS=16 '*'=6 '+'=7 '['=8 -','=9 -']'=10 +']'=9 +','=10 '..'=11 't'=13 +'sum'=16 diff --git a/src/andromede/expression/parsing/antlr/ExprParser.py b/src/andromede/expression/parsing/antlr/ExprParser.py index 8f312fe9..5ce55309 100644 --- a/src/andromede/expression/parsing/antlr/ExprParser.py +++ b/src/andromede/expression/parsing/antlr/ExprParser.py @@ -1,1296 +1,129 @@ -# Generated from Expr.g4 by ANTLR 4.13.1 +# Generated from Expr.g4 by ANTLR 4.13.2 # encoding: utf-8 -import sys -from io import StringIO - from antlr4 import * - +from io import StringIO +import sys if sys.version_info[1] > 5: - from typing import TextIO + from typing import TextIO else: - from typing.io import TextIO - + from typing.io import TextIO def serializedATN(): return [ - 4, - 1, - 16, - 131, - 2, - 0, - 7, - 0, - 2, - 1, - 7, - 1, - 2, - 2, - 7, - 2, - 2, - 3, - 7, - 3, - 2, - 4, - 7, - 4, - 2, - 5, - 7, - 5, - 1, - 0, - 1, - 0, - 1, - 0, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 5, - 1, - 37, - 8, - 1, - 10, - 1, - 12, - 1, - 40, - 9, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 5, - 1, - 49, - 8, - 1, - 10, - 1, - 12, - 1, - 52, - 9, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 3, - 1, - 70, - 8, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 5, - 1, - 81, - 8, - 1, - 10, - 1, - 12, - 1, - 84, - 9, - 1, - 1, - 2, - 1, - 2, - 3, - 2, - 88, - 8, - 2, - 1, - 3, - 1, - 3, - 3, - 3, - 92, - 8, - 3, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 3, - 4, - 102, - 8, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 1, - 4, - 5, - 4, - 110, - 8, - 4, - 10, - 4, - 12, - 4, - 113, - 9, - 4, - 1, - 5, - 1, - 5, - 1, - 5, - 1, - 5, - 1, - 5, - 1, - 5, - 3, - 5, - 121, - 8, - 5, - 1, - 5, - 1, - 5, - 1, - 5, - 5, - 5, - 126, - 8, - 5, - 10, - 5, - 12, - 5, - 129, - 9, - 5, - 1, - 5, - 0, - 3, - 2, - 8, - 10, - 6, - 0, - 2, - 4, - 6, - 8, - 10, - 0, - 2, - 1, - 0, - 5, - 6, - 2, - 0, - 2, - 2, - 7, - 7, - 144, - 0, - 12, - 1, - 0, - 0, - 0, - 2, - 69, - 1, - 0, - 0, - 0, - 4, - 87, - 1, - 0, - 0, - 0, - 6, - 89, - 1, - 0, - 0, - 0, - 8, - 101, - 1, - 0, - 0, - 0, - 10, - 120, - 1, - 0, - 0, - 0, - 12, - 13, - 3, - 2, - 1, - 0, - 13, - 14, - 5, - 0, - 0, - 1, - 14, - 1, - 1, - 0, - 0, - 0, - 15, - 16, - 6, - 1, - -1, - 0, - 16, - 70, - 3, - 4, - 2, - 0, - 17, - 18, - 5, - 14, - 0, - 0, - 18, - 19, - 5, - 1, - 0, - 0, - 19, - 70, - 5, - 14, - 0, - 0, - 20, - 21, - 5, - 2, - 0, - 0, - 21, - 70, - 3, - 2, - 1, - 10, - 22, - 23, - 5, - 3, - 0, - 0, - 23, - 24, - 3, - 2, - 1, - 0, - 24, - 25, - 5, - 4, - 0, - 0, - 25, - 70, - 1, - 0, - 0, - 0, - 26, - 27, - 5, - 14, - 0, - 0, - 27, - 28, - 5, - 3, - 0, - 0, - 28, - 29, - 3, - 2, - 1, - 0, - 29, - 30, - 5, - 4, - 0, - 0, - 30, - 70, - 1, - 0, - 0, - 0, - 31, - 32, - 5, - 14, - 0, - 0, - 32, - 33, - 5, - 8, - 0, - 0, - 33, - 38, - 3, - 6, - 3, - 0, - 34, - 35, - 5, - 9, - 0, - 0, - 35, - 37, - 3, - 6, - 3, - 0, - 36, - 34, - 1, - 0, - 0, - 0, - 37, - 40, - 1, - 0, - 0, - 0, - 38, - 36, - 1, - 0, - 0, - 0, - 38, - 39, - 1, - 0, - 0, - 0, - 39, - 41, - 1, - 0, - 0, - 0, - 40, - 38, - 1, - 0, - 0, - 0, - 41, - 42, - 5, - 10, - 0, - 0, - 42, - 70, - 1, - 0, - 0, - 0, - 43, - 44, - 5, - 14, - 0, - 0, - 44, - 45, - 5, - 8, - 0, - 0, - 45, - 50, - 3, - 2, - 1, - 0, - 46, - 47, - 5, - 9, - 0, - 0, - 47, - 49, - 3, - 2, - 1, - 0, - 48, - 46, - 1, - 0, - 0, - 0, - 49, - 52, - 1, - 0, - 0, - 0, - 50, - 48, - 1, - 0, - 0, - 0, - 50, - 51, - 1, - 0, - 0, - 0, - 51, - 53, - 1, - 0, - 0, - 0, - 52, - 50, - 1, - 0, - 0, - 0, - 53, - 54, - 5, - 10, - 0, - 0, - 54, - 70, - 1, - 0, - 0, - 0, - 55, - 56, - 5, - 14, - 0, - 0, - 56, - 57, - 5, - 8, - 0, - 0, - 57, - 58, - 3, - 6, - 3, - 0, - 58, - 59, - 5, - 11, - 0, - 0, - 59, - 60, - 3, - 6, - 3, - 0, - 60, - 61, - 5, - 10, - 0, - 0, - 61, - 70, - 1, - 0, - 0, - 0, - 62, - 63, - 5, - 14, - 0, - 0, - 63, - 64, - 5, - 8, - 0, - 0, - 64, - 65, - 3, - 2, - 1, - 0, - 65, - 66, - 5, - 11, - 0, - 0, - 66, - 67, - 3, - 2, - 1, - 0, - 67, - 68, - 5, - 10, - 0, - 0, - 68, - 70, - 1, - 0, - 0, - 0, - 69, - 15, - 1, - 0, - 0, - 0, - 69, - 17, - 1, - 0, - 0, - 0, - 69, - 20, - 1, - 0, - 0, - 0, - 69, - 22, - 1, - 0, - 0, - 0, - 69, - 26, - 1, - 0, - 0, - 0, - 69, - 31, - 1, - 0, - 0, - 0, - 69, - 43, - 1, - 0, - 0, - 0, - 69, - 55, - 1, - 0, - 0, - 0, - 69, - 62, - 1, - 0, - 0, - 0, - 70, - 82, - 1, - 0, - 0, - 0, - 71, - 72, - 10, - 8, - 0, - 0, - 72, - 73, - 7, - 0, - 0, - 0, - 73, - 81, - 3, - 2, - 1, - 9, - 74, - 75, - 10, - 7, - 0, - 0, - 75, - 76, - 7, - 1, - 0, - 0, - 76, - 81, - 3, - 2, - 1, - 8, - 77, - 78, - 10, - 6, - 0, - 0, - 78, - 79, - 5, - 15, - 0, - 0, - 79, - 81, - 3, - 2, - 1, - 7, - 80, - 71, - 1, - 0, - 0, - 0, - 80, - 74, - 1, - 0, - 0, - 0, - 80, - 77, - 1, - 0, - 0, - 0, - 81, - 84, - 1, - 0, - 0, - 0, - 82, - 80, - 1, - 0, - 0, - 0, - 82, - 83, - 1, - 0, - 0, - 0, - 83, - 3, - 1, - 0, - 0, - 0, - 84, - 82, - 1, - 0, - 0, - 0, - 85, - 88, - 5, - 12, - 0, - 0, - 86, - 88, - 5, - 14, - 0, - 0, - 87, - 85, - 1, - 0, - 0, - 0, - 87, - 86, - 1, - 0, - 0, - 0, - 88, - 5, - 1, - 0, - 0, - 0, - 89, - 91, - 5, - 13, - 0, - 0, - 90, - 92, - 3, - 8, - 4, - 0, - 91, - 90, - 1, - 0, - 0, - 0, - 91, - 92, - 1, - 0, - 0, - 0, - 92, - 7, - 1, - 0, - 0, - 0, - 93, - 94, - 6, - 4, - -1, - 0, - 94, - 95, - 7, - 1, - 0, - 0, - 95, - 102, - 3, - 4, - 2, - 0, - 96, - 97, - 7, - 1, - 0, - 0, - 97, - 98, - 5, - 3, - 0, - 0, - 98, - 99, - 3, - 2, - 1, - 0, - 99, - 100, - 5, - 4, - 0, - 0, - 100, - 102, - 1, - 0, - 0, - 0, - 101, - 93, - 1, - 0, - 0, - 0, - 101, - 96, - 1, - 0, - 0, - 0, - 102, - 111, - 1, - 0, - 0, - 0, - 103, - 104, - 10, - 4, - 0, - 0, - 104, - 105, - 7, - 0, - 0, - 0, - 105, - 110, - 3, - 10, - 5, - 0, - 106, - 107, - 10, - 3, - 0, - 0, - 107, - 108, - 7, - 1, - 0, - 0, - 108, - 110, - 3, - 10, - 5, - 0, - 109, - 103, - 1, - 0, - 0, - 0, - 109, - 106, - 1, - 0, - 0, - 0, - 110, - 113, - 1, - 0, - 0, - 0, - 111, - 109, - 1, - 0, - 0, - 0, - 111, - 112, - 1, - 0, - 0, - 0, - 112, - 9, - 1, - 0, - 0, - 0, - 113, - 111, - 1, - 0, - 0, - 0, - 114, - 115, - 6, - 5, - -1, - 0, - 115, - 116, - 5, - 3, - 0, - 0, - 116, - 117, - 3, - 2, - 1, - 0, - 117, - 118, - 5, - 4, - 0, - 0, - 118, - 121, - 1, - 0, - 0, - 0, - 119, - 121, - 3, - 4, - 2, - 0, - 120, - 114, - 1, - 0, - 0, - 0, - 120, - 119, - 1, - 0, - 0, - 0, - 121, - 127, - 1, - 0, - 0, - 0, - 122, - 123, - 10, - 3, - 0, - 0, - 123, - 124, - 7, - 0, - 0, - 0, - 124, - 126, - 3, - 10, - 5, - 4, - 125, - 122, - 1, - 0, - 0, - 0, - 126, - 129, - 1, - 0, - 0, - 0, - 127, - 125, - 1, - 0, - 0, - 0, - 127, - 128, - 1, - 0, - 0, - 0, - 128, - 11, - 1, - 0, - 0, - 0, - 129, - 127, - 1, - 0, - 0, - 0, - 12, - 38, - 50, - 69, - 80, - 82, - 87, - 91, - 101, - 109, - 111, - 120, - 127, + 4,1,17,127,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5,2,6,7, + 6,2,7,7,7,1,0,1,0,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,1,1,1,1,1,1,1,1,1,1,3,1,52,8,1,1,1,1,1,1,1,1,1,3,1,58,8,1,1,1, + 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,5,1,69,8,1,10,1,12,1,72,9,1,1,2, + 1,2,3,2,76,8,2,1,3,1,3,1,3,1,3,1,4,1,4,1,4,1,4,1,5,1,5,3,5,88,8, + 5,1,6,1,6,1,6,1,6,1,6,1,6,1,6,1,6,3,6,98,8,6,1,6,1,6,1,6,1,6,1,6, + 1,6,5,6,106,8,6,10,6,12,6,109,9,6,1,7,1,7,1,7,1,7,1,7,1,7,3,7,117, + 8,7,1,7,1,7,1,7,5,7,122,8,7,10,7,12,7,125,9,7,1,7,0,3,2,12,14,8, + 0,2,4,6,8,10,12,14,0,2,1,0,5,6,2,0,2,2,7,7,138,0,16,1,0,0,0,2,57, + 1,0,0,0,4,75,1,0,0,0,6,77,1,0,0,0,8,81,1,0,0,0,10,85,1,0,0,0,12, + 97,1,0,0,0,14,116,1,0,0,0,16,17,3,2,1,0,17,18,5,0,0,1,18,1,1,0,0, + 0,19,20,6,1,-1,0,20,58,3,4,2,0,21,22,5,14,0,0,22,23,5,1,0,0,23,58, + 5,14,0,0,24,25,5,2,0,0,25,58,3,2,1,9,26,27,5,3,0,0,27,28,3,2,1,0, + 28,29,5,4,0,0,29,58,1,0,0,0,30,31,5,14,0,0,31,32,5,3,0,0,32,33,3, + 2,1,0,33,34,5,4,0,0,34,58,1,0,0,0,35,36,5,14,0,0,36,37,5,8,0,0,37, + 38,3,10,5,0,38,39,5,9,0,0,39,58,1,0,0,0,40,41,5,14,0,0,41,42,5,8, + 0,0,42,43,3,2,1,0,43,44,5,9,0,0,44,58,1,0,0,0,45,46,5,16,0,0,46, + 51,5,3,0,0,47,52,3,2,1,0,48,52,3,10,5,0,49,52,3,6,3,0,50,52,3,8, + 4,0,51,47,1,0,0,0,51,48,1,0,0,0,51,49,1,0,0,0,51,50,1,0,0,0,52,53, + 1,0,0,0,53,54,5,10,0,0,54,55,5,14,0,0,55,56,5,4,0,0,56,58,1,0,0, + 0,57,19,1,0,0,0,57,21,1,0,0,0,57,24,1,0,0,0,57,26,1,0,0,0,57,30, + 1,0,0,0,57,35,1,0,0,0,57,40,1,0,0,0,57,45,1,0,0,0,58,70,1,0,0,0, + 59,60,10,7,0,0,60,61,7,0,0,0,61,69,3,2,1,8,62,63,10,6,0,0,63,64, + 7,1,0,0,64,69,3,2,1,7,65,66,10,5,0,0,66,67,5,15,0,0,67,69,3,2,1, + 6,68,59,1,0,0,0,68,62,1,0,0,0,68,65,1,0,0,0,69,72,1,0,0,0,70,68, + 1,0,0,0,70,71,1,0,0,0,71,3,1,0,0,0,72,70,1,0,0,0,73,76,5,12,0,0, + 74,76,5,14,0,0,75,73,1,0,0,0,75,74,1,0,0,0,76,5,1,0,0,0,77,78,3, + 10,5,0,78,79,5,11,0,0,79,80,3,10,5,0,80,7,1,0,0,0,81,82,3,2,1,0, + 82,83,5,11,0,0,83,84,3,2,1,0,84,9,1,0,0,0,85,87,5,13,0,0,86,88,3, + 12,6,0,87,86,1,0,0,0,87,88,1,0,0,0,88,11,1,0,0,0,89,90,6,6,-1,0, + 90,91,7,1,0,0,91,98,3,4,2,0,92,93,7,1,0,0,93,94,5,3,0,0,94,95,3, + 2,1,0,95,96,5,4,0,0,96,98,1,0,0,0,97,89,1,0,0,0,97,92,1,0,0,0,98, + 107,1,0,0,0,99,100,10,4,0,0,100,101,7,0,0,0,101,106,3,14,7,0,102, + 103,10,3,0,0,103,104,7,1,0,0,104,106,3,14,7,0,105,99,1,0,0,0,105, + 102,1,0,0,0,106,109,1,0,0,0,107,105,1,0,0,0,107,108,1,0,0,0,108, + 13,1,0,0,0,109,107,1,0,0,0,110,111,6,7,-1,0,111,112,5,3,0,0,112, + 113,3,2,1,0,113,114,5,4,0,0,114,117,1,0,0,0,115,117,3,4,2,0,116, + 110,1,0,0,0,116,115,1,0,0,0,117,123,1,0,0,0,118,119,10,3,0,0,119, + 120,7,0,0,0,120,122,3,14,7,4,121,118,1,0,0,0,122,125,1,0,0,0,123, + 121,1,0,0,0,123,124,1,0,0,0,124,15,1,0,0,0,125,123,1,0,0,0,11,51, + 57,68,70,75,87,97,105,107,116,123 ] +class ExprParser ( Parser ): -class ExprParser(Parser): grammarFileName = "Expr.g4" atn = ATNDeserializer().deserialize(serializedATN()) - decisionsToDFA = [DFA(ds, i) for i, ds in enumerate(atn.decisionToState)] + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] sharedContextCache = PredictionContextCache() - literalNames = [ - "", - "'.'", - "'-'", - "'('", - "')'", - "'/'", - "'*'", - "'+'", - "'['", - "','", - "']'", - "'..'", - "", - "'t'", - ] + literalNames = [ "", "'.'", "'-'", "'('", "')'", "'/'", "'*'", + "'+'", "'['", "']'", "','", "'..'", "", "'t'", + "", "", "'sum'" ] - symbolicNames = [ - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "NUMBER", - "TIME", - "IDENTIFIER", - "COMPARISON", - "WS", - ] + symbolicNames = [ "", "", "", "", + "", "", "", "", + "", "", "", "", + "NUMBER", "TIME", "IDENTIFIER", "COMPARISON", "TIME_SUM", + "WS" ] RULE_fullexpr = 0 RULE_expr = 1 RULE_atom = 2 - RULE_shift = 3 - RULE_shift_expr = 4 - RULE_right_expr = 5 + RULE_timeShiftRange = 3 + RULE_timeRange = 4 + RULE_shift = 5 + RULE_shift_expr = 6 + RULE_right_expr = 7 - ruleNames = ["fullexpr", "expr", "atom", "shift", "shift_expr", "right_expr"] + ruleNames = [ "fullexpr", "expr", "atom", "timeShiftRange", "timeRange", + "shift", "shift_expr", "right_expr" ] EOF = Token.EOF - T__0 = 1 - T__1 = 2 - T__2 = 3 - T__3 = 4 - T__4 = 5 - T__5 = 6 - T__6 = 7 - T__7 = 8 - T__8 = 9 - T__9 = 10 - T__10 = 11 - NUMBER = 12 - TIME = 13 - IDENTIFIER = 14 - COMPARISON = 15 - WS = 16 - - def __init__(self, input: TokenStream, output: TextIO = sys.stdout): + T__0=1 + T__1=2 + T__2=3 + T__3=4 + T__4=5 + T__5=6 + T__6=7 + T__7=8 + T__8=9 + T__9=10 + T__10=11 + NUMBER=12 + TIME=13 + IDENTIFIER=14 + COMPARISON=15 + TIME_SUM=16 + WS=17 + + def __init__(self, input:TokenStream, output:TextIO = sys.stdout): super().__init__(input, output) - self.checkVersion("4.13.1") - self._interp = ParserATNSimulator( - self, self.atn, self.decisionsToDFA, self.sharedContextCache - ) + self.checkVersion("4.13.2") + self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) self._predicates = None + + + class FullexprContext(ParserRuleContext): - __slots__ = "parser" + __slots__ = 'parser' - def __init__( - self, parser, parent: ParserRuleContext = None, invokingState: int = -1 - ): + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser def expr(self): - return self.getTypedRuleContext(ExprParser.ExprContext, 0) + return self.getTypedRuleContext(ExprParser.ExprContext,0) + def EOF(self): return self.getToken(ExprParser.EOF, 0) @@ -1298,20 +131,24 @@ def EOF(self): def getRuleIndex(self): return ExprParser.RULE_fullexpr - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitFullexpr"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitFullexpr" ): return visitor.visitFullexpr(self) else: return visitor.visitChildren(self) + + + def fullexpr(self): + localctx = ExprParser.FullexprContext(self, self._ctx, self.state) self.enterRule(localctx, 0, self.RULE_fullexpr) try: self.enterOuterAlt(localctx, 1) - self.state = 12 + self.state = 16 self.expr(0) - self.state = 13 + self.state = 17 self.match(ExprParser.EOF) except RecognitionException as re: localctx.exception = re @@ -1321,278 +158,264 @@ def fullexpr(self): self.exitRule() return localctx + class ExprContext(ParserRuleContext): - __slots__ = "parser" + __slots__ = 'parser' - def __init__( - self, parser, parent: ParserRuleContext = None, invokingState: int = -1 - ): + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser + def getRuleIndex(self): return ExprParser.RULE_expr - def copyFrom(self, ctx: ParserRuleContext): + + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) + + class TimeSumContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def TIME_SUM(self): + return self.getToken(ExprParser.TIME_SUM, 0) + def IDENTIFIER(self): + return self.getToken(ExprParser.IDENTIFIER, 0) + def expr(self): + return self.getTypedRuleContext(ExprParser.ExprContext,0) + + def shift(self): + return self.getTypedRuleContext(ExprParser.ShiftContext,0) + + def timeShiftRange(self): + return self.getTypedRuleContext(ExprParser.TimeShiftRangeContext,0) + + def timeRange(self): + return self.getTypedRuleContext(ExprParser.TimeRangeContext,0) + + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTimeSum" ): + return visitor.visitTimeSum(self) + else: + return visitor.visitChildren(self) + + class NegationContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) def expr(self): - return self.getTypedRuleContext(ExprParser.ExprContext, 0) + return self.getTypedRuleContext(ExprParser.ExprContext,0) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitNegation"): + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitNegation" ): return visitor.visitNegation(self) else: return visitor.visitChildren(self) + class UnsignedAtomContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) def atom(self): - return self.getTypedRuleContext(ExprParser.AtomContext, 0) + return self.getTypedRuleContext(ExprParser.AtomContext,0) + - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitUnsignedAtom"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitUnsignedAtom" ): return visitor.visitUnsignedAtom(self) else: return visitor.visitChildren(self) + class ExpressionContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) def expr(self): - return self.getTypedRuleContext(ExprParser.ExprContext, 0) + return self.getTypedRuleContext(ExprParser.ExprContext,0) + - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitExpression"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitExpression" ): return visitor.visitExpression(self) else: return visitor.visitChildren(self) + class TimeIndexContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) def IDENTIFIER(self): return self.getToken(ExprParser.IDENTIFIER, 0) + def expr(self): + return self.getTypedRuleContext(ExprParser.ExprContext,0) - def expr(self, i: int = None): - if i is None: - return self.getTypedRuleContexts(ExprParser.ExprContext) - else: - return self.getTypedRuleContext(ExprParser.ExprContext, i) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitTimeIndex"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTimeIndex" ): return visitor.visitTimeIndex(self) else: return visitor.visitChildren(self) + class ComparisonContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) - def expr(self, i: int = None): + def expr(self, i:int=None): if i is None: return self.getTypedRuleContexts(ExprParser.ExprContext) else: - return self.getTypedRuleContext(ExprParser.ExprContext, i) + return self.getTypedRuleContext(ExprParser.ExprContext,i) def COMPARISON(self): return self.getToken(ExprParser.COMPARISON, 0) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitComparison"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitComparison" ): return visitor.visitComparison(self) else: return visitor.visitChildren(self) + class TimeShiftContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) def IDENTIFIER(self): return self.getToken(ExprParser.IDENTIFIER, 0) + def shift(self): + return self.getTypedRuleContext(ExprParser.ShiftContext,0) - def shift(self, i: int = None): - if i is None: - return self.getTypedRuleContexts(ExprParser.ShiftContext) - else: - return self.getTypedRuleContext(ExprParser.ShiftContext, i) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitTimeShift"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTimeShift" ): return visitor.visitTimeShift(self) else: return visitor.visitChildren(self) + class FunctionContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) def IDENTIFIER(self): return self.getToken(ExprParser.IDENTIFIER, 0) - def expr(self): - return self.getTypedRuleContext(ExprParser.ExprContext, 0) + return self.getTypedRuleContext(ExprParser.ExprContext,0) + - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitFunction"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitFunction" ): return visitor.visitFunction(self) else: return visitor.visitChildren(self) + class AddsubContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) - self.op = None # Token + self.op = None # Token self.copyFrom(ctx) - def expr(self, i: int = None): + def expr(self, i:int=None): if i is None: return self.getTypedRuleContexts(ExprParser.ExprContext) else: - return self.getTypedRuleContext(ExprParser.ExprContext, i) + return self.getTypedRuleContext(ExprParser.ExprContext,i) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitAddsub"): + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitAddsub" ): return visitor.visitAddsub(self) else: return visitor.visitChildren(self) - class TimeShiftRangeContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext - super().__init__(parser) - self.shift1 = None # ShiftContext - self.shift2 = None # ShiftContext - self.copyFrom(ctx) - - def IDENTIFIER(self): - return self.getToken(ExprParser.IDENTIFIER, 0) - - def shift(self, i: int = None): - if i is None: - return self.getTypedRuleContexts(ExprParser.ShiftContext) - else: - return self.getTypedRuleContext(ExprParser.ShiftContext, i) - - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitTimeShiftRange"): - return visitor.visitTimeShiftRange(self) - else: - return visitor.visitChildren(self) class PortFieldContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) self.copyFrom(ctx) - def IDENTIFIER(self, i: int = None): + def IDENTIFIER(self, i:int=None): if i is None: return self.getTokens(ExprParser.IDENTIFIER) else: return self.getToken(ExprParser.IDENTIFIER, i) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitPortField"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitPortField" ): return visitor.visitPortField(self) else: return visitor.visitChildren(self) + class MuldivContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.ExprContext super().__init__(parser) - self.op = None # Token + self.op = None # Token self.copyFrom(ctx) - def expr(self, i: int = None): + def expr(self, i:int=None): if i is None: return self.getTypedRuleContexts(ExprParser.ExprContext) else: - return self.getTypedRuleContext(ExprParser.ExprContext, i) + return self.getTypedRuleContext(ExprParser.ExprContext,i) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitMuldiv"): + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitMuldiv" ): return visitor.visitMuldiv(self) else: return visitor.visitChildren(self) - class TimeRangeContext(ExprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.ExprContext - super().__init__(parser) - self.copyFrom(ctx) - - def IDENTIFIER(self): - return self.getToken(ExprParser.IDENTIFIER, 0) - - def expr(self, i: int = None): - if i is None: - return self.getTypedRuleContexts(ExprParser.ExprContext) - else: - return self.getTypedRuleContext(ExprParser.ExprContext, i) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitTimeRange"): - return visitor.visitTimeRange(self) - else: - return visitor.visitChildren(self) - def expr(self, _p: int = 0): + def expr(self, _p:int=0): _parentctx = self._ctx _parentState = self.state localctx = ExprParser.ExprContext(self, self._ctx, _parentState) _prevctx = localctx _startState = 2 self.enterRecursionRule(localctx, 2, self.RULE_expr, _p) - self._la = 0 # Token type + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 69 + self.state = 57 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input, 2, self._ctx) + la_ = self._interp.adaptivePredict(self._input,1,self._ctx) if la_ == 1: localctx = ExprParser.UnsignedAtomContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 16 + self.state = 20 self.atom() pass @@ -1600,11 +423,11 @@ def expr(self, _p: int = 0): localctx = ExprParser.PortFieldContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 17 + self.state = 21 self.match(ExprParser.IDENTIFIER) - self.state = 18 + self.state = 22 self.match(ExprParser.T__0) - self.state = 19 + self.state = 23 self.match(ExprParser.IDENTIFIER) pass @@ -1612,21 +435,21 @@ def expr(self, _p: int = 0): localctx = ExprParser.NegationContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 20 + self.state = 24 self.match(ExprParser.T__1) - self.state = 21 - self.expr(10) + self.state = 25 + self.expr(9) pass elif la_ == 4: localctx = ExprParser.ExpressionContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 22 + self.state = 26 self.match(ExprParser.T__2) - self.state = 23 + self.state = 27 self.expr(0) - self.state = 24 + self.state = 28 self.match(ExprParser.T__3) pass @@ -1634,13 +457,13 @@ def expr(self, _p: int = 0): localctx = ExprParser.FunctionContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 26 + self.state = 30 self.match(ExprParser.IDENTIFIER) - self.state = 27 + self.state = 31 self.match(ExprParser.T__2) - self.state = 28 + self.state = 32 self.expr(0) - self.state = 29 + self.state = 33 self.match(ExprParser.T__3) pass @@ -1648,177 +471,138 @@ def expr(self, _p: int = 0): localctx = ExprParser.TimeShiftContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 31 + self.state = 35 self.match(ExprParser.IDENTIFIER) - self.state = 32 + self.state = 36 self.match(ExprParser.T__7) - self.state = 33 + self.state = 37 self.shift() self.state = 38 - self._errHandler.sync(self) - _la = self._input.LA(1) - while _la == 9: - self.state = 34 - self.match(ExprParser.T__8) - self.state = 35 - self.shift() - self.state = 40 - self._errHandler.sync(self) - _la = self._input.LA(1) - - self.state = 41 - self.match(ExprParser.T__9) + self.match(ExprParser.T__8) pass elif la_ == 7: localctx = ExprParser.TimeIndexContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 43 + self.state = 40 self.match(ExprParser.IDENTIFIER) - self.state = 44 + self.state = 41 self.match(ExprParser.T__7) - self.state = 45 + self.state = 42 self.expr(0) - self.state = 50 + self.state = 43 + self.match(ExprParser.T__8) + pass + + elif la_ == 8: + localctx = ExprParser.TimeSumContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 45 + self.match(ExprParser.TIME_SUM) + self.state = 46 + self.match(ExprParser.T__2) + self.state = 51 self._errHandler.sync(self) - _la = self._input.LA(1) - while _la == 9: - self.state = 46 - self.match(ExprParser.T__8) + la_ = self._interp.adaptivePredict(self._input,0,self._ctx) + if la_ == 1: self.state = 47 self.expr(0) - self.state = 52 - self._errHandler.sync(self) - _la = self._input.LA(1) + pass + + elif la_ == 2: + self.state = 48 + self.shift() + pass + + elif la_ == 3: + self.state = 49 + self.timeShiftRange() + pass + + elif la_ == 4: + self.state = 50 + self.timeRange() + pass + self.state = 53 self.match(ExprParser.T__9) - pass - - elif la_ == 8: - localctx = ExprParser.TimeShiftRangeContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 55 + self.state = 54 self.match(ExprParser.IDENTIFIER) - self.state = 56 - self.match(ExprParser.T__7) - self.state = 57 - localctx.shift1 = self.shift() - self.state = 58 - self.match(ExprParser.T__10) - self.state = 59 - localctx.shift2 = self.shift() - self.state = 60 - self.match(ExprParser.T__9) + self.state = 55 + self.match(ExprParser.T__3) pass - elif la_ == 9: - localctx = ExprParser.TimeRangeContext(self, localctx) - self._ctx = localctx - _prevctx = localctx - self.state = 62 - self.match(ExprParser.IDENTIFIER) - self.state = 63 - self.match(ExprParser.T__7) - self.state = 64 - self.expr(0) - self.state = 65 - self.match(ExprParser.T__10) - self.state = 66 - self.expr(0) - self.state = 67 - self.match(ExprParser.T__9) - pass self._ctx.stop = self._input.LT(-1) - self.state = 82 + self.state = 70 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 4, self._ctx) - while _alt != 2 and _alt != ATN.INVALID_ALT_NUMBER: - if _alt == 1: + _alt = self._interp.adaptivePredict(self._input,3,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: if self._parseListeners is not None: self.triggerExitRuleEvent() _prevctx = localctx - self.state = 80 + self.state = 68 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input, 3, self._ctx) + la_ = self._interp.adaptivePredict(self._input,2,self._ctx) if la_ == 1: - localctx = ExprParser.MuldivContext( - self, ExprParser.ExprContext(self, _parentctx, _parentState) - ) - self.pushNewRecursionContext( - localctx, _startState, self.RULE_expr - ) - self.state = 71 - if not self.precpred(self._ctx, 8): + localctx = ExprParser.MuldivContext(self, ExprParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 59 + if not self.precpred(self._ctx, 7): from antlr4.error.Errors import FailedPredicateException - - raise FailedPredicateException( - self, "self.precpred(self._ctx, 8)" - ) - self.state = 72 + raise FailedPredicateException(self, "self.precpred(self._ctx, 7)") + self.state = 60 localctx.op = self._input.LT(1) _la = self._input.LA(1) - if not (_la == 5 or _la == 6): + if not(_la==5 or _la==6): localctx.op = self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) self.consume() - self.state = 73 - self.expr(9) + self.state = 61 + self.expr(8) pass elif la_ == 2: - localctx = ExprParser.AddsubContext( - self, ExprParser.ExprContext(self, _parentctx, _parentState) - ) - self.pushNewRecursionContext( - localctx, _startState, self.RULE_expr - ) - self.state = 74 - if not self.precpred(self._ctx, 7): + localctx = ExprParser.AddsubContext(self, ExprParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 62 + if not self.precpred(self._ctx, 6): from antlr4.error.Errors import FailedPredicateException - - raise FailedPredicateException( - self, "self.precpred(self._ctx, 7)" - ) - self.state = 75 + raise FailedPredicateException(self, "self.precpred(self._ctx, 6)") + self.state = 63 localctx.op = self._input.LT(1) _la = self._input.LA(1) - if not (_la == 2 or _la == 7): + if not(_la==2 or _la==7): localctx.op = self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) self.consume() - self.state = 76 - self.expr(8) + self.state = 64 + self.expr(7) pass elif la_ == 3: - localctx = ExprParser.ComparisonContext( - self, ExprParser.ExprContext(self, _parentctx, _parentState) - ) - self.pushNewRecursionContext( - localctx, _startState, self.RULE_expr - ) - self.state = 77 - if not self.precpred(self._ctx, 6): + localctx = ExprParser.ComparisonContext(self, ExprParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 65 + if not self.precpred(self._ctx, 5): from antlr4.error.Errors import FailedPredicateException - - raise FailedPredicateException( - self, "self.precpred(self._ctx, 6)" - ) - self.state = 78 + raise FailedPredicateException(self, "self.precpred(self._ctx, 5)") + self.state = 66 self.match(ExprParser.COMPARISON) - self.state = 79 - self.expr(7) + self.state = 67 + self.expr(6) pass - self.state = 84 + + self.state = 72 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 4, self._ctx) + _alt = self._interp.adaptivePredict(self._input,3,self._ctx) except RecognitionException as re: localctx.exception = re @@ -1828,70 +612,75 @@ def expr(self, _p: int = 0): self.unrollRecursionContexts(_parentctx) return localctx + class AtomContext(ParserRuleContext): - __slots__ = "parser" + __slots__ = 'parser' - def __init__( - self, parser, parent: ParserRuleContext = None, invokingState: int = -1 - ): + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser + def getRuleIndex(self): return ExprParser.RULE_atom - def copyFrom(self, ctx: ParserRuleContext): + + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) + + class NumberContext(AtomContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.AtomContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.AtomContext super().__init__(parser) self.copyFrom(ctx) def NUMBER(self): return self.getToken(ExprParser.NUMBER, 0) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitNumber"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitNumber" ): return visitor.visitNumber(self) else: return visitor.visitChildren(self) + class IdentifierContext(AtomContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.AtomContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.AtomContext super().__init__(parser) self.copyFrom(ctx) def IDENTIFIER(self): return self.getToken(ExprParser.IDENTIFIER, 0) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitIdentifier"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitIdentifier" ): return visitor.visitIdentifier(self) else: return visitor.visitChildren(self) + + def atom(self): + localctx = ExprParser.AtomContext(self, self._ctx, self.state) self.enterRule(localctx, 4, self.RULE_atom) try: - self.state = 87 + self.state = 75 self._errHandler.sync(self) token = self._input.LA(1) if token in [12]: localctx = ExprParser.NumberContext(self, localctx) self.enterOuterAlt(localctx, 1) - self.state = 85 + self.state = 73 self.match(ExprParser.NUMBER) pass elif token in [14]: localctx = ExprParser.IdentifierContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 86 + self.state = 74 self.match(ExprParser.IDENTIFIER) pass else: @@ -1905,12 +694,109 @@ def atom(self): self.exitRule() return localctx + + class TimeShiftRangeContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.shift1 = None # ShiftContext + self.shift2 = None # ShiftContext + + def shift(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(ExprParser.ShiftContext) + else: + return self.getTypedRuleContext(ExprParser.ShiftContext,i) + + + def getRuleIndex(self): + return ExprParser.RULE_timeShiftRange + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTimeShiftRange" ): + return visitor.visitTimeShiftRange(self) + else: + return visitor.visitChildren(self) + + + + + def timeShiftRange(self): + + localctx = ExprParser.TimeShiftRangeContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_timeShiftRange) + try: + self.enterOuterAlt(localctx, 1) + self.state = 77 + localctx.shift1 = self.shift() + self.state = 78 + self.match(ExprParser.T__10) + self.state = 79 + localctx.shift2 = self.shift() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class TimeRangeContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.expr1 = None # ExprContext + self.expr2 = None # ExprContext + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(ExprParser.ExprContext) + else: + return self.getTypedRuleContext(ExprParser.ExprContext,i) + + + def getRuleIndex(self): + return ExprParser.RULE_timeRange + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitTimeRange" ): + return visitor.visitTimeRange(self) + else: + return visitor.visitChildren(self) + + + + + def timeRange(self): + + localctx = ExprParser.TimeRangeContext(self, self._ctx, self.state) + self.enterRule(localctx, 8, self.RULE_timeRange) + try: + self.enterOuterAlt(localctx, 1) + self.state = 81 + localctx.expr1 = self.expr(0) + self.state = 82 + self.match(ExprParser.T__10) + self.state = 83 + localctx.expr2 = self.expr(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + class ShiftContext(ParserRuleContext): - __slots__ = "parser" + __slots__ = 'parser' - def __init__( - self, parser, parent: ParserRuleContext = None, invokingState: int = -1 - ): + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser @@ -1918,32 +804,38 @@ def TIME(self): return self.getToken(ExprParser.TIME, 0) def shift_expr(self): - return self.getTypedRuleContext(ExprParser.Shift_exprContext, 0) + return self.getTypedRuleContext(ExprParser.Shift_exprContext,0) + def getRuleIndex(self): return ExprParser.RULE_shift - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitShift"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitShift" ): return visitor.visitShift(self) else: return visitor.visitChildren(self) + + + def shift(self): + localctx = ExprParser.ShiftContext(self, self._ctx, self.state) - self.enterRule(localctx, 6, self.RULE_shift) - self._la = 0 # Token type + self.enterRule(localctx, 10, self.RULE_shift) + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 89 + self.state = 85 self.match(ExprParser.TIME) - self.state = 91 + self.state = 87 self._errHandler.sync(self) _la = self._input.LA(1) - if _la == 2 or _la == 7: - self.state = 90 + if _la==2 or _la==7: + self.state = 86 self.shift_expr(0) + except RecognitionException as re: localctx.exception = re self._errHandler.reportError(self, re) @@ -1952,122 +844,129 @@ def shift(self): self.exitRule() return localctx + class Shift_exprContext(ParserRuleContext): - __slots__ = "parser" + __slots__ = 'parser' - def __init__( - self, parser, parent: ParserRuleContext = None, invokingState: int = -1 - ): + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser + def getRuleIndex(self): return ExprParser.RULE_shift_expr - def copyFrom(self, ctx: ParserRuleContext): + + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) + class SignedAtomContext(Shift_exprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.Shift_exprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.Shift_exprContext super().__init__(parser) - self.op = None # Token + self.op = None # Token self.copyFrom(ctx) def atom(self): - return self.getTypedRuleContext(ExprParser.AtomContext, 0) + return self.getTypedRuleContext(ExprParser.AtomContext,0) + - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitSignedAtom"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitSignedAtom" ): return visitor.visitSignedAtom(self) else: return visitor.visitChildren(self) + class SignedExpressionContext(Shift_exprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.Shift_exprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.Shift_exprContext super().__init__(parser) - self.op = None # Token + self.op = None # Token self.copyFrom(ctx) def expr(self): - return self.getTypedRuleContext(ExprParser.ExprContext, 0) + return self.getTypedRuleContext(ExprParser.ExprContext,0) + - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitSignedExpression"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitSignedExpression" ): return visitor.visitSignedExpression(self) else: return visitor.visitChildren(self) + class ShiftMuldivContext(Shift_exprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.Shift_exprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.Shift_exprContext super().__init__(parser) - self.op = None # Token + self.op = None # Token self.copyFrom(ctx) def shift_expr(self): - return self.getTypedRuleContext(ExprParser.Shift_exprContext, 0) + return self.getTypedRuleContext(ExprParser.Shift_exprContext,0) def right_expr(self): - return self.getTypedRuleContext(ExprParser.Right_exprContext, 0) + return self.getTypedRuleContext(ExprParser.Right_exprContext,0) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitShiftMuldiv"): + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitShiftMuldiv" ): return visitor.visitShiftMuldiv(self) else: return visitor.visitChildren(self) + class ShiftAddsubContext(Shift_exprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.Shift_exprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.Shift_exprContext super().__init__(parser) - self.op = None # Token + self.op = None # Token self.copyFrom(ctx) def shift_expr(self): - return self.getTypedRuleContext(ExprParser.Shift_exprContext, 0) + return self.getTypedRuleContext(ExprParser.Shift_exprContext,0) def right_expr(self): - return self.getTypedRuleContext(ExprParser.Right_exprContext, 0) + return self.getTypedRuleContext(ExprParser.Right_exprContext,0) + - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitShiftAddsub"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitShiftAddsub" ): return visitor.visitShiftAddsub(self) else: return visitor.visitChildren(self) - def shift_expr(self, _p: int = 0): + + + def shift_expr(self, _p:int=0): _parentctx = self._ctx _parentState = self.state localctx = ExprParser.Shift_exprContext(self, self._ctx, _parentState) _prevctx = localctx - _startState = 8 - self.enterRecursionRule(localctx, 8, self.RULE_shift_expr, _p) - self._la = 0 # Token type + _startState = 12 + self.enterRecursionRule(localctx, 12, self.RULE_shift_expr, _p) + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 101 + self.state = 97 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input, 7, self._ctx) + la_ = self._interp.adaptivePredict(self._input,6,self._ctx) if la_ == 1: localctx = ExprParser.SignedAtomContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 94 + self.state = 90 localctx.op = self._input.LT(1) _la = self._input.LA(1) - if not (_la == 2 or _la == 7): + if not(_la==2 or _la==7): localctx.op = self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) self.consume() - self.state = 95 + self.state = 91 self.atom() pass @@ -2075,95 +974,77 @@ def shift_expr(self, _p: int = 0): localctx = ExprParser.SignedExpressionContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 96 + self.state = 92 localctx.op = self._input.LT(1) _la = self._input.LA(1) - if not (_la == 2 or _la == 7): + if not(_la==2 or _la==7): localctx.op = self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) self.consume() - self.state = 97 + self.state = 93 self.match(ExprParser.T__2) - self.state = 98 + self.state = 94 self.expr(0) - self.state = 99 + self.state = 95 self.match(ExprParser.T__3) pass + self._ctx.stop = self._input.LT(-1) - self.state = 111 + self.state = 107 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 9, self._ctx) - while _alt != 2 and _alt != ATN.INVALID_ALT_NUMBER: - if _alt == 1: + _alt = self._interp.adaptivePredict(self._input,8,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: if self._parseListeners is not None: self.triggerExitRuleEvent() _prevctx = localctx - self.state = 109 + self.state = 105 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input, 8, self._ctx) + la_ = self._interp.adaptivePredict(self._input,7,self._ctx) if la_ == 1: - localctx = ExprParser.ShiftMuldivContext( - self, - ExprParser.Shift_exprContext( - self, _parentctx, _parentState - ), - ) - self.pushNewRecursionContext( - localctx, _startState, self.RULE_shift_expr - ) - self.state = 103 + localctx = ExprParser.ShiftMuldivContext(self, ExprParser.Shift_exprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_shift_expr) + self.state = 99 if not self.precpred(self._ctx, 4): from antlr4.error.Errors import FailedPredicateException - - raise FailedPredicateException( - self, "self.precpred(self._ctx, 4)" - ) - self.state = 104 + raise FailedPredicateException(self, "self.precpred(self._ctx, 4)") + self.state = 100 localctx.op = self._input.LT(1) _la = self._input.LA(1) - if not (_la == 5 or _la == 6): + if not(_la==5 or _la==6): localctx.op = self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) self.consume() - self.state = 105 + self.state = 101 self.right_expr(0) pass elif la_ == 2: - localctx = ExprParser.ShiftAddsubContext( - self, - ExprParser.Shift_exprContext( - self, _parentctx, _parentState - ), - ) - self.pushNewRecursionContext( - localctx, _startState, self.RULE_shift_expr - ) - self.state = 106 + localctx = ExprParser.ShiftAddsubContext(self, ExprParser.Shift_exprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_shift_expr) + self.state = 102 if not self.precpred(self._ctx, 3): from antlr4.error.Errors import FailedPredicateException - - raise FailedPredicateException( - self, "self.precpred(self._ctx, 3)" - ) - self.state = 107 + raise FailedPredicateException(self, "self.precpred(self._ctx, 3)") + self.state = 103 localctx.op = self._input.LT(1) _la = self._input.LA(1) - if not (_la == 2 or _la == 7): + if not(_la==2 or _la==7): localctx.op = self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) self.consume() - self.state = 108 + self.state = 104 self.right_expr(0) pass - self.state = 113 + + self.state = 109 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 9, self._ctx) + _alt = self._interp.adaptivePredict(self._input,8,self._ctx) except RecognitionException as re: localctx.exception = re @@ -2173,84 +1054,90 @@ def shift_expr(self, _p: int = 0): self.unrollRecursionContexts(_parentctx) return localctx + class Right_exprContext(ParserRuleContext): - __slots__ = "parser" + __slots__ = 'parser' - def __init__( - self, parser, parent: ParserRuleContext = None, invokingState: int = -1 - ): + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): super().__init__(parent, invokingState) self.parser = parser + def getRuleIndex(self): return ExprParser.RULE_right_expr - def copyFrom(self, ctx: ParserRuleContext): + + def copyFrom(self, ctx:ParserRuleContext): super().copyFrom(ctx) + class RightExpressionContext(Right_exprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.Right_exprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.Right_exprContext super().__init__(parser) self.copyFrom(ctx) def expr(self): - return self.getTypedRuleContext(ExprParser.ExprContext, 0) + return self.getTypedRuleContext(ExprParser.ExprContext,0) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitRightExpression"): + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitRightExpression" ): return visitor.visitRightExpression(self) else: return visitor.visitChildren(self) + class RightMuldivContext(Right_exprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.Right_exprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.Right_exprContext super().__init__(parser) - self.op = None # Token + self.op = None # Token self.copyFrom(ctx) - def right_expr(self, i: int = None): + def right_expr(self, i:int=None): if i is None: return self.getTypedRuleContexts(ExprParser.Right_exprContext) else: - return self.getTypedRuleContext(ExprParser.Right_exprContext, i) + return self.getTypedRuleContext(ExprParser.Right_exprContext,i) - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitRightMuldiv"): + + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitRightMuldiv" ): return visitor.visitRightMuldiv(self) else: return visitor.visitChildren(self) + class RightAtomContext(Right_exprContext): - def __init__( - self, parser, ctx: ParserRuleContext - ): # actually a ExprParser.Right_exprContext + + def __init__(self, parser, ctx:ParserRuleContext): # actually a ExprParser.Right_exprContext super().__init__(parser) self.copyFrom(ctx) def atom(self): - return self.getTypedRuleContext(ExprParser.AtomContext, 0) + return self.getTypedRuleContext(ExprParser.AtomContext,0) + - def accept(self, visitor: ParseTreeVisitor): - if hasattr(visitor, "visitRightAtom"): + def accept(self, visitor:ParseTreeVisitor): + if hasattr( visitor, "visitRightAtom" ): return visitor.visitRightAtom(self) else: return visitor.visitChildren(self) - def right_expr(self, _p: int = 0): + + + def right_expr(self, _p:int=0): _parentctx = self._ctx _parentState = self.state localctx = ExprParser.Right_exprContext(self, self._ctx, _parentState) _prevctx = localctx - _startState = 10 - self.enterRecursionRule(localctx, 10, self.RULE_right_expr, _p) - self._la = 0 # Token type + _startState = 14 + self.enterRecursionRule(localctx, 14, self.RULE_right_expr, _p) + self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 120 + self.state = 116 self._errHandler.sync(self) token = self._input.LA(1) if token in [3]: @@ -2258,59 +1145,51 @@ def right_expr(self, _p: int = 0): self._ctx = localctx _prevctx = localctx - self.state = 115 + self.state = 111 self.match(ExprParser.T__2) - self.state = 116 + self.state = 112 self.expr(0) - self.state = 117 + self.state = 113 self.match(ExprParser.T__3) pass elif token in [12, 14]: localctx = ExprParser.RightAtomContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 119 + self.state = 115 self.atom() pass else: raise NoViableAltException(self) self._ctx.stop = self._input.LT(-1) - self.state = 127 + self.state = 123 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 11, self._ctx) - while _alt != 2 and _alt != ATN.INVALID_ALT_NUMBER: - if _alt == 1: + _alt = self._interp.adaptivePredict(self._input,10,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: if self._parseListeners is not None: self.triggerExitRuleEvent() _prevctx = localctx - localctx = ExprParser.RightMuldivContext( - self, - ExprParser.Right_exprContext(self, _parentctx, _parentState), - ) - self.pushNewRecursionContext( - localctx, _startState, self.RULE_right_expr - ) - self.state = 122 + localctx = ExprParser.RightMuldivContext(self, ExprParser.Right_exprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_right_expr) + self.state = 118 if not self.precpred(self._ctx, 3): from antlr4.error.Errors import FailedPredicateException - - raise FailedPredicateException( - self, "self.precpred(self._ctx, 3)" - ) - self.state = 123 + raise FailedPredicateException(self, "self.precpred(self._ctx, 3)") + self.state = 119 localctx.op = self._input.LT(1) _la = self._input.LA(1) - if not (_la == 5 or _la == 6): + if not(_la==5 or _la==6): localctx.op = self._errHandler.recoverInline(self) else: self._errHandler.reportMatch(self) self.consume() - self.state = 124 - self.right_expr(4) - self.state = 129 + self.state = 120 + self.right_expr(4) + self.state = 125 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input, 11, self._ctx) + _alt = self._interp.adaptivePredict(self._input,10,self._ctx) except RecognitionException as re: localctx.exception = re @@ -2320,35 +1199,47 @@ def right_expr(self, _p: int = 0): self.unrollRecursionContexts(_parentctx) return localctx - def sempred(self, localctx: RuleContext, ruleIndex: int, predIndex: int): + + + def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): if self._predicates == None: self._predicates = dict() self._predicates[1] = self.expr_sempred - self._predicates[4] = self.shift_expr_sempred - self._predicates[5] = self.right_expr_sempred + self._predicates[6] = self.shift_expr_sempred + self._predicates[7] = self.right_expr_sempred pred = self._predicates.get(ruleIndex, None) if pred is None: raise Exception("No predicate with index:" + str(ruleIndex)) else: return pred(localctx, predIndex) - def expr_sempred(self, localctx: ExprContext, predIndex: int): - if predIndex == 0: - return self.precpred(self._ctx, 8) + def expr_sempred(self, localctx:ExprContext, predIndex:int): + if predIndex == 0: + return self.precpred(self._ctx, 7) + + + if predIndex == 1: + return self.precpred(self._ctx, 6) + + + if predIndex == 2: + return self.precpred(self._ctx, 5) + + + def shift_expr_sempred(self, localctx:Shift_exprContext, predIndex:int): + if predIndex == 3: + return self.precpred(self._ctx, 4) + + + if predIndex == 4: + return self.precpred(self._ctx, 3) + - if predIndex == 1: - return self.precpred(self._ctx, 7) + def right_expr_sempred(self, localctx:Right_exprContext, predIndex:int): + if predIndex == 5: + return self.precpred(self._ctx, 3) + - if predIndex == 2: - return self.precpred(self._ctx, 6) - def shift_expr_sempred(self, localctx: Shift_exprContext, predIndex: int): - if predIndex == 3: - return self.precpred(self._ctx, 4) - if predIndex == 4: - return self.precpred(self._ctx, 3) - def right_expr_sempred(self, localctx: Right_exprContext, predIndex: int): - if predIndex == 5: - return self.precpred(self._ctx, 3) diff --git a/src/andromede/expression/parsing/antlr/ExprVisitor.py b/src/andromede/expression/parsing/antlr/ExprVisitor.py index 0e924349..ee405aa2 100644 --- a/src/andromede/expression/parsing/antlr/ExprVisitor.py +++ b/src/andromede/expression/parsing/antlr/ExprVisitor.py @@ -1,6 +1,5 @@ -# Generated from Expr.g4 by ANTLR 4.13.1 +# Generated from Expr.g4 by ANTLR 4.13.2 from antlr4 import * - if "." in __name__: from .ExprParser import ExprParser else: @@ -8,99 +7,127 @@ # This class defines a complete generic visitor for a parse tree produced by ExprParser. - class ExprVisitor(ParseTreeVisitor): + # Visit a parse tree produced by ExprParser#fullexpr. - def visitFullexpr(self, ctx: ExprParser.FullexprContext): + def visitFullexpr(self, ctx:ExprParser.FullexprContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by ExprParser#timeSum. + def visitTimeSum(self, ctx:ExprParser.TimeSumContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#negation. - def visitNegation(self, ctx: ExprParser.NegationContext): + def visitNegation(self, ctx:ExprParser.NegationContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#unsignedAtom. - def visitUnsignedAtom(self, ctx: ExprParser.UnsignedAtomContext): + def visitUnsignedAtom(self, ctx:ExprParser.UnsignedAtomContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#expression. - def visitExpression(self, ctx: ExprParser.ExpressionContext): + def visitExpression(self, ctx:ExprParser.ExpressionContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#timeIndex. - def visitTimeIndex(self, ctx: ExprParser.TimeIndexContext): + def visitTimeIndex(self, ctx:ExprParser.TimeIndexContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#comparison. - def visitComparison(self, ctx: ExprParser.ComparisonContext): + def visitComparison(self, ctx:ExprParser.ComparisonContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#timeShift. - def visitTimeShift(self, ctx: ExprParser.TimeShiftContext): + def visitTimeShift(self, ctx:ExprParser.TimeShiftContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#function. - def visitFunction(self, ctx: ExprParser.FunctionContext): + def visitFunction(self, ctx:ExprParser.FunctionContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#addsub. - def visitAddsub(self, ctx: ExprParser.AddsubContext): + def visitAddsub(self, ctx:ExprParser.AddsubContext): return self.visitChildren(ctx) - # Visit a parse tree produced by ExprParser#timeShiftRange. - def visitTimeShiftRange(self, ctx: ExprParser.TimeShiftRangeContext): - return self.visitChildren(ctx) # Visit a parse tree produced by ExprParser#portField. - def visitPortField(self, ctx: ExprParser.PortFieldContext): + def visitPortField(self, ctx:ExprParser.PortFieldContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#muldiv. - def visitMuldiv(self, ctx: ExprParser.MuldivContext): + def visitMuldiv(self, ctx:ExprParser.MuldivContext): return self.visitChildren(ctx) - # Visit a parse tree produced by ExprParser#timeRange. - def visitTimeRange(self, ctx: ExprParser.TimeRangeContext): - return self.visitChildren(ctx) # Visit a parse tree produced by ExprParser#number. - def visitNumber(self, ctx: ExprParser.NumberContext): + def visitNumber(self, ctx:ExprParser.NumberContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#identifier. - def visitIdentifier(self, ctx: ExprParser.IdentifierContext): + def visitIdentifier(self, ctx:ExprParser.IdentifierContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by ExprParser#timeShiftRange. + def visitTimeShiftRange(self, ctx:ExprParser.TimeShiftRangeContext): + return self.visitChildren(ctx) + + + # Visit a parse tree produced by ExprParser#timeRange. + def visitTimeRange(self, ctx:ExprParser.TimeRangeContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#shift. - def visitShift(self, ctx: ExprParser.ShiftContext): + def visitShift(self, ctx:ExprParser.ShiftContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#signedAtom. - def visitSignedAtom(self, ctx: ExprParser.SignedAtomContext): + def visitSignedAtom(self, ctx:ExprParser.SignedAtomContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#signedExpression. - def visitSignedExpression(self, ctx: ExprParser.SignedExpressionContext): + def visitSignedExpression(self, ctx:ExprParser.SignedExpressionContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#shiftMuldiv. - def visitShiftMuldiv(self, ctx: ExprParser.ShiftMuldivContext): + def visitShiftMuldiv(self, ctx:ExprParser.ShiftMuldivContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#shiftAddsub. - def visitShiftAddsub(self, ctx: ExprParser.ShiftAddsubContext): + def visitShiftAddsub(self, ctx:ExprParser.ShiftAddsubContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#rightExpression. - def visitRightExpression(self, ctx: ExprParser.RightExpressionContext): + def visitRightExpression(self, ctx:ExprParser.RightExpressionContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#rightMuldiv. - def visitRightMuldiv(self, ctx: ExprParser.RightMuldivContext): + def visitRightMuldiv(self, ctx:ExprParser.RightMuldivContext): return self.visitChildren(ctx) + # Visit a parse tree produced by ExprParser#rightAtom. - def visitRightAtom(self, ctx: ExprParser.RightAtomContext): + def visitRightAtom(self, ctx:ExprParser.RightAtomContext): return self.visitChildren(ctx) -del ExprParser + +del ExprParser \ No newline at end of file diff --git a/src/andromede/expression/parsing/parse_expression.py b/src/andromede/expression/parsing/parse_expression.py index e535b810..961c0e51 100644 --- a/src/andromede/expression/parsing/parse_expression.py +++ b/src/andromede/expression/parsing/parse_expression.py @@ -134,31 +134,26 @@ def visitComparison(self, ctx: ExprParser.ComparisonContext) -> LinearExpression # Visit a parse tree produced by ExprParser#timeShift. def visitTimeIndex(self, ctx: ExprParser.TimeIndexContext) -> LinearExpression: shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore - time_shifts = [e.accept(self) for e in ctx.expr()] # type: ignore - return shifted_expr.eval(time_shifts) - - # Visit a parse tree produced by ExprParser#rangeTimeShift. - def visitTimeRange(self, ctx: ExprParser.TimeRangeContext) -> LinearExpression: - shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore - expressions = [e.accept(self) for e in ctx.expr()] # type: ignore - # TODO: Is there a visitSum somewhere that is not needed ? Are the correct symbol parsed (sum(...) ?) ? - return shifted_expr.sum(eval=ExpressionRange(expressions[0], expressions[1])) + time_shift = ctx.expr().accept(self) # type: ignore + return shifted_expr.eval(time_shift) def visitTimeShift(self, ctx: ExprParser.TimeShiftContext) -> LinearExpression: shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore - time_shifts = [s.accept(self) for s in ctx.shift()] # type: ignore - # specifics for x[t] ... - if len(time_shifts) == 1 and expressions_equal(time_shifts[0], literal(0)): - return shifted_expr - return shifted_expr.sum(shift=time_shifts) - - def visitTimeShiftRange( - self, ctx: ExprParser.TimeShiftRangeContext - ) -> LinearExpression: - shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore - shift1 = ctx.shift1.accept(self) # type: ignore - shift2 = ctx.shift2.accept(self) # type: ignore - return shifted_expr.sum(shift=ExpressionRange(shift1, shift2)) + time_shift = ctx.shift().accept(self) # type: ignore + return shifted_expr.sum(shift=time_shift) + + # Visit a parse tree produced by ExprParser#timeSum. + def visitTimeSum(self, ctx:ExprParser.TimeSumContext): + return self.visitChildren(ctx) + + # Visit a parse tree produced by ExprParser#timeShiftRange. + def visitTimeShiftRange(self, ctx:ExprParser.TimeShiftRangeContext): + return ExpressionRange(ctx.shift, ctx.shift2) + + + # Visit a parse tree produced by ExprParser#timeRange. + def visitTimeRange(self, ctx:ExprParser.TimeRangeContext): + return self.visitChildren(ctx) # Visit a parse tree produced by ExprParser#function. def visitFunction(self, ctx: ExprParser.FunctionContext) -> LinearExpression: From 02f8a398979ff89512c389185347e0154e1b5500 Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Thu, 5 Sep 2024 10:34:02 +0200 Subject: [PATCH 51/51] WIP for yaml parsing --- grammar/Expr.g4 | 2 +- .../expression/parsing/antlr/Expr.interp | 2 +- .../expression/parsing/antlr/ExprParser.py | 251 +++++++++--------- .../expression/parsing/parse_expression.py | 28 +- .../parsing/test_expression_parsing.py | 7 +- 5 files changed, 159 insertions(+), 131 deletions(-) diff --git a/grammar/Expr.g4 b/grammar/Expr.g4 index 56f0babe..c49a5a02 100644 --- a/grammar/Expr.g4 +++ b/grammar/Expr.g4 @@ -28,7 +28,7 @@ expr | IDENTIFIER '(' expr ')' # function | IDENTIFIER '[' shift ']' # timeShift | IDENTIFIER '[' expr ']' # timeIndex - | TIME_SUM '(' (expr | shift | timeShiftRange | timeRange) ',' IDENTIFIER ')' #timeSum + | TIME_SUM '(' ((expr | shift | timeShiftRange | timeRange) ',')? IDENTIFIER ')' #timeSum ; atom diff --git a/src/andromede/expression/parsing/antlr/Expr.interp b/src/andromede/expression/parsing/antlr/Expr.interp index 18ea1ac5..436305a4 100644 --- a/src/andromede/expression/parsing/antlr/Expr.interp +++ b/src/andromede/expression/parsing/antlr/Expr.interp @@ -50,4 +50,4 @@ right_expr atn: -[4, 1, 17, 127, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 52, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 58, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 69, 8, 1, 10, 1, 12, 1, 72, 9, 1, 1, 2, 1, 2, 3, 2, 76, 8, 2, 1, 3, 1, 3, 1, 3, 1, 3, 1, 4, 1, 4, 1, 4, 1, 4, 1, 5, 1, 5, 3, 5, 88, 8, 5, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 3, 6, 98, 8, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 5, 6, 106, 8, 6, 10, 6, 12, 6, 109, 9, 6, 1, 7, 1, 7, 1, 7, 1, 7, 1, 7, 1, 7, 3, 7, 117, 8, 7, 1, 7, 1, 7, 1, 7, 5, 7, 122, 8, 7, 10, 7, 12, 7, 125, 9, 7, 1, 7, 0, 3, 2, 12, 14, 8, 0, 2, 4, 6, 8, 10, 12, 14, 0, 2, 1, 0, 5, 6, 2, 0, 2, 2, 7, 7, 138, 0, 16, 1, 0, 0, 0, 2, 57, 1, 0, 0, 0, 4, 75, 1, 0, 0, 0, 6, 77, 1, 0, 0, 0, 8, 81, 1, 0, 0, 0, 10, 85, 1, 0, 0, 0, 12, 97, 1, 0, 0, 0, 14, 116, 1, 0, 0, 0, 16, 17, 3, 2, 1, 0, 17, 18, 5, 0, 0, 1, 18, 1, 1, 0, 0, 0, 19, 20, 6, 1, -1, 0, 20, 58, 3, 4, 2, 0, 21, 22, 5, 14, 0, 0, 22, 23, 5, 1, 0, 0, 23, 58, 5, 14, 0, 0, 24, 25, 5, 2, 0, 0, 25, 58, 3, 2, 1, 9, 26, 27, 5, 3, 0, 0, 27, 28, 3, 2, 1, 0, 28, 29, 5, 4, 0, 0, 29, 58, 1, 0, 0, 0, 30, 31, 5, 14, 0, 0, 31, 32, 5, 3, 0, 0, 32, 33, 3, 2, 1, 0, 33, 34, 5, 4, 0, 0, 34, 58, 1, 0, 0, 0, 35, 36, 5, 14, 0, 0, 36, 37, 5, 8, 0, 0, 37, 38, 3, 10, 5, 0, 38, 39, 5, 9, 0, 0, 39, 58, 1, 0, 0, 0, 40, 41, 5, 14, 0, 0, 41, 42, 5, 8, 0, 0, 42, 43, 3, 2, 1, 0, 43, 44, 5, 9, 0, 0, 44, 58, 1, 0, 0, 0, 45, 46, 5, 16, 0, 0, 46, 51, 5, 3, 0, 0, 47, 52, 3, 2, 1, 0, 48, 52, 3, 10, 5, 0, 49, 52, 3, 6, 3, 0, 50, 52, 3, 8, 4, 0, 51, 47, 1, 0, 0, 0, 51, 48, 1, 0, 0, 0, 51, 49, 1, 0, 0, 0, 51, 50, 1, 0, 0, 0, 52, 53, 1, 0, 0, 0, 53, 54, 5, 10, 0, 0, 54, 55, 5, 14, 0, 0, 55, 56, 5, 4, 0, 0, 56, 58, 1, 0, 0, 0, 57, 19, 1, 0, 0, 0, 57, 21, 1, 0, 0, 0, 57, 24, 1, 0, 0, 0, 57, 26, 1, 0, 0, 0, 57, 30, 1, 0, 0, 0, 57, 35, 1, 0, 0, 0, 57, 40, 1, 0, 0, 0, 57, 45, 1, 0, 0, 0, 58, 70, 1, 0, 0, 0, 59, 60, 10, 7, 0, 0, 60, 61, 7, 0, 0, 0, 61, 69, 3, 2, 1, 8, 62, 63, 10, 6, 0, 0, 63, 64, 7, 1, 0, 0, 64, 69, 3, 2, 1, 7, 65, 66, 10, 5, 0, 0, 66, 67, 5, 15, 0, 0, 67, 69, 3, 2, 1, 6, 68, 59, 1, 0, 0, 0, 68, 62, 1, 0, 0, 0, 68, 65, 1, 0, 0, 0, 69, 72, 1, 0, 0, 0, 70, 68, 1, 0, 0, 0, 70, 71, 1, 0, 0, 0, 71, 3, 1, 0, 0, 0, 72, 70, 1, 0, 0, 0, 73, 76, 5, 12, 0, 0, 74, 76, 5, 14, 0, 0, 75, 73, 1, 0, 0, 0, 75, 74, 1, 0, 0, 0, 76, 5, 1, 0, 0, 0, 77, 78, 3, 10, 5, 0, 78, 79, 5, 11, 0, 0, 79, 80, 3, 10, 5, 0, 80, 7, 1, 0, 0, 0, 81, 82, 3, 2, 1, 0, 82, 83, 5, 11, 0, 0, 83, 84, 3, 2, 1, 0, 84, 9, 1, 0, 0, 0, 85, 87, 5, 13, 0, 0, 86, 88, 3, 12, 6, 0, 87, 86, 1, 0, 0, 0, 87, 88, 1, 0, 0, 0, 88, 11, 1, 0, 0, 0, 89, 90, 6, 6, -1, 0, 90, 91, 7, 1, 0, 0, 91, 98, 3, 4, 2, 0, 92, 93, 7, 1, 0, 0, 93, 94, 5, 3, 0, 0, 94, 95, 3, 2, 1, 0, 95, 96, 5, 4, 0, 0, 96, 98, 1, 0, 0, 0, 97, 89, 1, 0, 0, 0, 97, 92, 1, 0, 0, 0, 98, 107, 1, 0, 0, 0, 99, 100, 10, 4, 0, 0, 100, 101, 7, 0, 0, 0, 101, 106, 3, 14, 7, 0, 102, 103, 10, 3, 0, 0, 103, 104, 7, 1, 0, 0, 104, 106, 3, 14, 7, 0, 105, 99, 1, 0, 0, 0, 105, 102, 1, 0, 0, 0, 106, 109, 1, 0, 0, 0, 107, 105, 1, 0, 0, 0, 107, 108, 1, 0, 0, 0, 108, 13, 1, 0, 0, 0, 109, 107, 1, 0, 0, 0, 110, 111, 6, 7, -1, 0, 111, 112, 5, 3, 0, 0, 112, 113, 3, 2, 1, 0, 113, 114, 5, 4, 0, 0, 114, 117, 1, 0, 0, 0, 115, 117, 3, 4, 2, 0, 116, 110, 1, 0, 0, 0, 116, 115, 1, 0, 0, 0, 117, 123, 1, 0, 0, 0, 118, 119, 10, 3, 0, 0, 119, 120, 7, 0, 0, 0, 120, 122, 3, 14, 7, 4, 121, 118, 1, 0, 0, 0, 122, 125, 1, 0, 0, 0, 123, 121, 1, 0, 0, 0, 123, 124, 1, 0, 0, 0, 124, 15, 1, 0, 0, 0, 125, 123, 1, 0, 0, 0, 11, 51, 57, 68, 70, 75, 87, 97, 105, 107, 116, 123] \ No newline at end of file +[4, 1, 17, 129, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 2, 6, 7, 6, 2, 7, 7, 7, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 52, 8, 1, 1, 1, 1, 1, 3, 1, 56, 8, 1, 1, 1, 1, 1, 3, 1, 60, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 71, 8, 1, 10, 1, 12, 1, 74, 9, 1, 1, 2, 1, 2, 3, 2, 78, 8, 2, 1, 3, 1, 3, 1, 3, 1, 3, 1, 4, 1, 4, 1, 4, 1, 4, 1, 5, 1, 5, 3, 5, 90, 8, 5, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 3, 6, 100, 8, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 1, 6, 5, 6, 108, 8, 6, 10, 6, 12, 6, 111, 9, 6, 1, 7, 1, 7, 1, 7, 1, 7, 1, 7, 1, 7, 3, 7, 119, 8, 7, 1, 7, 1, 7, 1, 7, 5, 7, 124, 8, 7, 10, 7, 12, 7, 127, 9, 7, 1, 7, 0, 3, 2, 12, 14, 8, 0, 2, 4, 6, 8, 10, 12, 14, 0, 2, 1, 0, 5, 6, 2, 0, 2, 2, 7, 7, 141, 0, 16, 1, 0, 0, 0, 2, 59, 1, 0, 0, 0, 4, 77, 1, 0, 0, 0, 6, 79, 1, 0, 0, 0, 8, 83, 1, 0, 0, 0, 10, 87, 1, 0, 0, 0, 12, 99, 1, 0, 0, 0, 14, 118, 1, 0, 0, 0, 16, 17, 3, 2, 1, 0, 17, 18, 5, 0, 0, 1, 18, 1, 1, 0, 0, 0, 19, 20, 6, 1, -1, 0, 20, 60, 3, 4, 2, 0, 21, 22, 5, 14, 0, 0, 22, 23, 5, 1, 0, 0, 23, 60, 5, 14, 0, 0, 24, 25, 5, 2, 0, 0, 25, 60, 3, 2, 1, 9, 26, 27, 5, 3, 0, 0, 27, 28, 3, 2, 1, 0, 28, 29, 5, 4, 0, 0, 29, 60, 1, 0, 0, 0, 30, 31, 5, 14, 0, 0, 31, 32, 5, 3, 0, 0, 32, 33, 3, 2, 1, 0, 33, 34, 5, 4, 0, 0, 34, 60, 1, 0, 0, 0, 35, 36, 5, 14, 0, 0, 36, 37, 5, 8, 0, 0, 37, 38, 3, 10, 5, 0, 38, 39, 5, 9, 0, 0, 39, 60, 1, 0, 0, 0, 40, 41, 5, 14, 0, 0, 41, 42, 5, 8, 0, 0, 42, 43, 3, 2, 1, 0, 43, 44, 5, 9, 0, 0, 44, 60, 1, 0, 0, 0, 45, 46, 5, 16, 0, 0, 46, 55, 5, 3, 0, 0, 47, 52, 3, 2, 1, 0, 48, 52, 3, 10, 5, 0, 49, 52, 3, 6, 3, 0, 50, 52, 3, 8, 4, 0, 51, 47, 1, 0, 0, 0, 51, 48, 1, 0, 0, 0, 51, 49, 1, 0, 0, 0, 51, 50, 1, 0, 0, 0, 52, 53, 1, 0, 0, 0, 53, 54, 5, 10, 0, 0, 54, 56, 1, 0, 0, 0, 55, 51, 1, 0, 0, 0, 55, 56, 1, 0, 0, 0, 56, 57, 1, 0, 0, 0, 57, 58, 5, 14, 0, 0, 58, 60, 5, 4, 0, 0, 59, 19, 1, 0, 0, 0, 59, 21, 1, 0, 0, 0, 59, 24, 1, 0, 0, 0, 59, 26, 1, 0, 0, 0, 59, 30, 1, 0, 0, 0, 59, 35, 1, 0, 0, 0, 59, 40, 1, 0, 0, 0, 59, 45, 1, 0, 0, 0, 60, 72, 1, 0, 0, 0, 61, 62, 10, 7, 0, 0, 62, 63, 7, 0, 0, 0, 63, 71, 3, 2, 1, 8, 64, 65, 10, 6, 0, 0, 65, 66, 7, 1, 0, 0, 66, 71, 3, 2, 1, 7, 67, 68, 10, 5, 0, 0, 68, 69, 5, 15, 0, 0, 69, 71, 3, 2, 1, 6, 70, 61, 1, 0, 0, 0, 70, 64, 1, 0, 0, 0, 70, 67, 1, 0, 0, 0, 71, 74, 1, 0, 0, 0, 72, 70, 1, 0, 0, 0, 72, 73, 1, 0, 0, 0, 73, 3, 1, 0, 0, 0, 74, 72, 1, 0, 0, 0, 75, 78, 5, 12, 0, 0, 76, 78, 5, 14, 0, 0, 77, 75, 1, 0, 0, 0, 77, 76, 1, 0, 0, 0, 78, 5, 1, 0, 0, 0, 79, 80, 3, 10, 5, 0, 80, 81, 5, 11, 0, 0, 81, 82, 3, 10, 5, 0, 82, 7, 1, 0, 0, 0, 83, 84, 3, 2, 1, 0, 84, 85, 5, 11, 0, 0, 85, 86, 3, 2, 1, 0, 86, 9, 1, 0, 0, 0, 87, 89, 5, 13, 0, 0, 88, 90, 3, 12, 6, 0, 89, 88, 1, 0, 0, 0, 89, 90, 1, 0, 0, 0, 90, 11, 1, 0, 0, 0, 91, 92, 6, 6, -1, 0, 92, 93, 7, 1, 0, 0, 93, 100, 3, 4, 2, 0, 94, 95, 7, 1, 0, 0, 95, 96, 5, 3, 0, 0, 96, 97, 3, 2, 1, 0, 97, 98, 5, 4, 0, 0, 98, 100, 1, 0, 0, 0, 99, 91, 1, 0, 0, 0, 99, 94, 1, 0, 0, 0, 100, 109, 1, 0, 0, 0, 101, 102, 10, 4, 0, 0, 102, 103, 7, 0, 0, 0, 103, 108, 3, 14, 7, 0, 104, 105, 10, 3, 0, 0, 105, 106, 7, 1, 0, 0, 106, 108, 3, 14, 7, 0, 107, 101, 1, 0, 0, 0, 107, 104, 1, 0, 0, 0, 108, 111, 1, 0, 0, 0, 109, 107, 1, 0, 0, 0, 109, 110, 1, 0, 0, 0, 110, 13, 1, 0, 0, 0, 111, 109, 1, 0, 0, 0, 112, 113, 6, 7, -1, 0, 113, 114, 5, 3, 0, 0, 114, 115, 3, 2, 1, 0, 115, 116, 5, 4, 0, 0, 116, 119, 1, 0, 0, 0, 117, 119, 3, 4, 2, 0, 118, 112, 1, 0, 0, 0, 118, 117, 1, 0, 0, 0, 119, 125, 1, 0, 0, 0, 120, 121, 10, 3, 0, 0, 121, 122, 7, 0, 0, 0, 122, 124, 3, 14, 7, 4, 123, 120, 1, 0, 0, 0, 124, 127, 1, 0, 0, 0, 125, 123, 1, 0, 0, 0, 125, 126, 1, 0, 0, 0, 126, 15, 1, 0, 0, 0, 127, 125, 1, 0, 0, 0, 12, 51, 55, 59, 70, 72, 77, 89, 99, 107, 109, 118, 125] \ No newline at end of file diff --git a/src/andromede/expression/parsing/antlr/ExprParser.py b/src/andromede/expression/parsing/antlr/ExprParser.py index 5ce55309..8fef9021 100644 --- a/src/andromede/expression/parsing/antlr/ExprParser.py +++ b/src/andromede/expression/parsing/antlr/ExprParser.py @@ -10,48 +10,49 @@ def serializedATN(): return [ - 4,1,17,127,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5,2,6,7, + 4,1,17,129,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5,2,6,7, 6,2,7,7,7,1,0,1,0,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, - 1,1,1,1,1,1,1,1,1,1,1,3,1,52,8,1,1,1,1,1,1,1,1,1,3,1,58,8,1,1,1, - 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,5,1,69,8,1,10,1,12,1,72,9,1,1,2, - 1,2,3,2,76,8,2,1,3,1,3,1,3,1,3,1,4,1,4,1,4,1,4,1,5,1,5,3,5,88,8, - 5,1,6,1,6,1,6,1,6,1,6,1,6,1,6,1,6,3,6,98,8,6,1,6,1,6,1,6,1,6,1,6, - 1,6,5,6,106,8,6,10,6,12,6,109,9,6,1,7,1,7,1,7,1,7,1,7,1,7,3,7,117, - 8,7,1,7,1,7,1,7,5,7,122,8,7,10,7,12,7,125,9,7,1,7,0,3,2,12,14,8, - 0,2,4,6,8,10,12,14,0,2,1,0,5,6,2,0,2,2,7,7,138,0,16,1,0,0,0,2,57, - 1,0,0,0,4,75,1,0,0,0,6,77,1,0,0,0,8,81,1,0,0,0,10,85,1,0,0,0,12, - 97,1,0,0,0,14,116,1,0,0,0,16,17,3,2,1,0,17,18,5,0,0,1,18,1,1,0,0, - 0,19,20,6,1,-1,0,20,58,3,4,2,0,21,22,5,14,0,0,22,23,5,1,0,0,23,58, - 5,14,0,0,24,25,5,2,0,0,25,58,3,2,1,9,26,27,5,3,0,0,27,28,3,2,1,0, - 28,29,5,4,0,0,29,58,1,0,0,0,30,31,5,14,0,0,31,32,5,3,0,0,32,33,3, - 2,1,0,33,34,5,4,0,0,34,58,1,0,0,0,35,36,5,14,0,0,36,37,5,8,0,0,37, - 38,3,10,5,0,38,39,5,9,0,0,39,58,1,0,0,0,40,41,5,14,0,0,41,42,5,8, - 0,0,42,43,3,2,1,0,43,44,5,9,0,0,44,58,1,0,0,0,45,46,5,16,0,0,46, - 51,5,3,0,0,47,52,3,2,1,0,48,52,3,10,5,0,49,52,3,6,3,0,50,52,3,8, - 4,0,51,47,1,0,0,0,51,48,1,0,0,0,51,49,1,0,0,0,51,50,1,0,0,0,52,53, - 1,0,0,0,53,54,5,10,0,0,54,55,5,14,0,0,55,56,5,4,0,0,56,58,1,0,0, - 0,57,19,1,0,0,0,57,21,1,0,0,0,57,24,1,0,0,0,57,26,1,0,0,0,57,30, - 1,0,0,0,57,35,1,0,0,0,57,40,1,0,0,0,57,45,1,0,0,0,58,70,1,0,0,0, - 59,60,10,7,0,0,60,61,7,0,0,0,61,69,3,2,1,8,62,63,10,6,0,0,63,64, - 7,1,0,0,64,69,3,2,1,7,65,66,10,5,0,0,66,67,5,15,0,0,67,69,3,2,1, - 6,68,59,1,0,0,0,68,62,1,0,0,0,68,65,1,0,0,0,69,72,1,0,0,0,70,68, - 1,0,0,0,70,71,1,0,0,0,71,3,1,0,0,0,72,70,1,0,0,0,73,76,5,12,0,0, - 74,76,5,14,0,0,75,73,1,0,0,0,75,74,1,0,0,0,76,5,1,0,0,0,77,78,3, - 10,5,0,78,79,5,11,0,0,79,80,3,10,5,0,80,7,1,0,0,0,81,82,3,2,1,0, - 82,83,5,11,0,0,83,84,3,2,1,0,84,9,1,0,0,0,85,87,5,13,0,0,86,88,3, - 12,6,0,87,86,1,0,0,0,87,88,1,0,0,0,88,11,1,0,0,0,89,90,6,6,-1,0, - 90,91,7,1,0,0,91,98,3,4,2,0,92,93,7,1,0,0,93,94,5,3,0,0,94,95,3, - 2,1,0,95,96,5,4,0,0,96,98,1,0,0,0,97,89,1,0,0,0,97,92,1,0,0,0,98, - 107,1,0,0,0,99,100,10,4,0,0,100,101,7,0,0,0,101,106,3,14,7,0,102, - 103,10,3,0,0,103,104,7,1,0,0,104,106,3,14,7,0,105,99,1,0,0,0,105, - 102,1,0,0,0,106,109,1,0,0,0,107,105,1,0,0,0,107,108,1,0,0,0,108, - 13,1,0,0,0,109,107,1,0,0,0,110,111,6,7,-1,0,111,112,5,3,0,0,112, - 113,3,2,1,0,113,114,5,4,0,0,114,117,1,0,0,0,115,117,3,4,2,0,116, - 110,1,0,0,0,116,115,1,0,0,0,117,123,1,0,0,0,118,119,10,3,0,0,119, - 120,7,0,0,0,120,122,3,14,7,4,121,118,1,0,0,0,122,125,1,0,0,0,123, - 121,1,0,0,0,123,124,1,0,0,0,124,15,1,0,0,0,125,123,1,0,0,0,11,51, - 57,68,70,75,87,97,105,107,116,123 + 1,1,1,1,1,1,1,1,1,1,1,3,1,52,8,1,1,1,1,1,3,1,56,8,1,1,1,1,1,3,1, + 60,8,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,5,1,71,8,1,10,1,12,1, + 74,9,1,1,2,1,2,3,2,78,8,2,1,3,1,3,1,3,1,3,1,4,1,4,1,4,1,4,1,5,1, + 5,3,5,90,8,5,1,6,1,6,1,6,1,6,1,6,1,6,1,6,1,6,3,6,100,8,6,1,6,1,6, + 1,6,1,6,1,6,1,6,5,6,108,8,6,10,6,12,6,111,9,6,1,7,1,7,1,7,1,7,1, + 7,1,7,3,7,119,8,7,1,7,1,7,1,7,5,7,124,8,7,10,7,12,7,127,9,7,1,7, + 0,3,2,12,14,8,0,2,4,6,8,10,12,14,0,2,1,0,5,6,2,0,2,2,7,7,141,0,16, + 1,0,0,0,2,59,1,0,0,0,4,77,1,0,0,0,6,79,1,0,0,0,8,83,1,0,0,0,10,87, + 1,0,0,0,12,99,1,0,0,0,14,118,1,0,0,0,16,17,3,2,1,0,17,18,5,0,0,1, + 18,1,1,0,0,0,19,20,6,1,-1,0,20,60,3,4,2,0,21,22,5,14,0,0,22,23,5, + 1,0,0,23,60,5,14,0,0,24,25,5,2,0,0,25,60,3,2,1,9,26,27,5,3,0,0,27, + 28,3,2,1,0,28,29,5,4,0,0,29,60,1,0,0,0,30,31,5,14,0,0,31,32,5,3, + 0,0,32,33,3,2,1,0,33,34,5,4,0,0,34,60,1,0,0,0,35,36,5,14,0,0,36, + 37,5,8,0,0,37,38,3,10,5,0,38,39,5,9,0,0,39,60,1,0,0,0,40,41,5,14, + 0,0,41,42,5,8,0,0,42,43,3,2,1,0,43,44,5,9,0,0,44,60,1,0,0,0,45,46, + 5,16,0,0,46,55,5,3,0,0,47,52,3,2,1,0,48,52,3,10,5,0,49,52,3,6,3, + 0,50,52,3,8,4,0,51,47,1,0,0,0,51,48,1,0,0,0,51,49,1,0,0,0,51,50, + 1,0,0,0,52,53,1,0,0,0,53,54,5,10,0,0,54,56,1,0,0,0,55,51,1,0,0,0, + 55,56,1,0,0,0,56,57,1,0,0,0,57,58,5,14,0,0,58,60,5,4,0,0,59,19,1, + 0,0,0,59,21,1,0,0,0,59,24,1,0,0,0,59,26,1,0,0,0,59,30,1,0,0,0,59, + 35,1,0,0,0,59,40,1,0,0,0,59,45,1,0,0,0,60,72,1,0,0,0,61,62,10,7, + 0,0,62,63,7,0,0,0,63,71,3,2,1,8,64,65,10,6,0,0,65,66,7,1,0,0,66, + 71,3,2,1,7,67,68,10,5,0,0,68,69,5,15,0,0,69,71,3,2,1,6,70,61,1,0, + 0,0,70,64,1,0,0,0,70,67,1,0,0,0,71,74,1,0,0,0,72,70,1,0,0,0,72,73, + 1,0,0,0,73,3,1,0,0,0,74,72,1,0,0,0,75,78,5,12,0,0,76,78,5,14,0,0, + 77,75,1,0,0,0,77,76,1,0,0,0,78,5,1,0,0,0,79,80,3,10,5,0,80,81,5, + 11,0,0,81,82,3,10,5,0,82,7,1,0,0,0,83,84,3,2,1,0,84,85,5,11,0,0, + 85,86,3,2,1,0,86,9,1,0,0,0,87,89,5,13,0,0,88,90,3,12,6,0,89,88,1, + 0,0,0,89,90,1,0,0,0,90,11,1,0,0,0,91,92,6,6,-1,0,92,93,7,1,0,0,93, + 100,3,4,2,0,94,95,7,1,0,0,95,96,5,3,0,0,96,97,3,2,1,0,97,98,5,4, + 0,0,98,100,1,0,0,0,99,91,1,0,0,0,99,94,1,0,0,0,100,109,1,0,0,0,101, + 102,10,4,0,0,102,103,7,0,0,0,103,108,3,14,7,0,104,105,10,3,0,0,105, + 106,7,1,0,0,106,108,3,14,7,0,107,101,1,0,0,0,107,104,1,0,0,0,108, + 111,1,0,0,0,109,107,1,0,0,0,109,110,1,0,0,0,110,13,1,0,0,0,111,109, + 1,0,0,0,112,113,6,7,-1,0,113,114,5,3,0,0,114,115,3,2,1,0,115,116, + 5,4,0,0,116,119,1,0,0,0,117,119,3,4,2,0,118,112,1,0,0,0,118,117, + 1,0,0,0,119,125,1,0,0,0,120,121,10,3,0,0,121,122,7,0,0,0,122,124, + 3,14,7,4,123,120,1,0,0,0,124,127,1,0,0,0,125,123,1,0,0,0,125,126, + 1,0,0,0,126,15,1,0,0,0,127,125,1,0,0,0,12,51,55,59,70,72,77,89,99, + 107,109,118,125 ] class ExprParser ( Parser ): @@ -407,9 +408,9 @@ def expr(self, _p:int=0): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 57 + self.state = 59 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,1,self._ctx) + la_ = self._interp.adaptivePredict(self._input,2,self._ctx) if la_ == 1: localctx = ExprParser.UnsignedAtomContext(self, localctx) self._ctx = localctx @@ -503,59 +504,65 @@ def expr(self, _p:int=0): self.match(ExprParser.TIME_SUM) self.state = 46 self.match(ExprParser.T__2) - self.state = 51 + self.state = 55 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,0,self._ctx) + la_ = self._interp.adaptivePredict(self._input,1,self._ctx) if la_ == 1: - self.state = 47 - self.expr(0) - pass + self.state = 51 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,0,self._ctx) + if la_ == 1: + self.state = 47 + self.expr(0) + pass + + elif la_ == 2: + self.state = 48 + self.shift() + pass + + elif la_ == 3: + self.state = 49 + self.timeShiftRange() + pass - elif la_ == 2: - self.state = 48 - self.shift() - pass + elif la_ == 4: + self.state = 50 + self.timeRange() + pass - elif la_ == 3: - self.state = 49 - self.timeShiftRange() - pass - elif la_ == 4: - self.state = 50 - self.timeRange() - pass + self.state = 53 + self.match(ExprParser.T__9) - self.state = 53 - self.match(ExprParser.T__9) - self.state = 54 + self.state = 57 self.match(ExprParser.IDENTIFIER) - self.state = 55 + self.state = 58 self.match(ExprParser.T__3) pass self._ctx.stop = self._input.LT(-1) - self.state = 70 + self.state = 72 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,3,self._ctx) + _alt = self._interp.adaptivePredict(self._input,4,self._ctx) while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: if _alt==1: if self._parseListeners is not None: self.triggerExitRuleEvent() _prevctx = localctx - self.state = 68 + self.state = 70 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,2,self._ctx) + la_ = self._interp.adaptivePredict(self._input,3,self._ctx) if la_ == 1: localctx = ExprParser.MuldivContext(self, ExprParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 59 + self.state = 61 if not self.precpred(self._ctx, 7): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException(self, "self.precpred(self._ctx, 7)") - self.state = 60 + self.state = 62 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not(_la==5 or _la==6): @@ -563,18 +570,18 @@ def expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 61 + self.state = 63 self.expr(8) pass elif la_ == 2: localctx = ExprParser.AddsubContext(self, ExprParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 62 + self.state = 64 if not self.precpred(self._ctx, 6): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException(self, "self.precpred(self._ctx, 6)") - self.state = 63 + self.state = 65 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not(_la==2 or _la==7): @@ -582,27 +589,27 @@ def expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 64 + self.state = 66 self.expr(7) pass elif la_ == 3: localctx = ExprParser.ComparisonContext(self, ExprParser.ExprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) - self.state = 65 + self.state = 67 if not self.precpred(self._ctx, 5): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException(self, "self.precpred(self._ctx, 5)") - self.state = 66 + self.state = 68 self.match(ExprParser.COMPARISON) - self.state = 67 + self.state = 69 self.expr(6) pass - self.state = 72 + self.state = 74 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,3,self._ctx) + _alt = self._interp.adaptivePredict(self._input,4,self._ctx) except RecognitionException as re: localctx.exception = re @@ -668,19 +675,19 @@ def atom(self): localctx = ExprParser.AtomContext(self, self._ctx, self.state) self.enterRule(localctx, 4, self.RULE_atom) try: - self.state = 75 + self.state = 77 self._errHandler.sync(self) token = self._input.LA(1) if token in [12]: localctx = ExprParser.NumberContext(self, localctx) self.enterOuterAlt(localctx, 1) - self.state = 73 + self.state = 75 self.match(ExprParser.NUMBER) pass elif token in [14]: localctx = ExprParser.IdentifierContext(self, localctx) self.enterOuterAlt(localctx, 2) - self.state = 74 + self.state = 76 self.match(ExprParser.IDENTIFIER) pass else: @@ -729,11 +736,11 @@ def timeShiftRange(self): self.enterRule(localctx, 6, self.RULE_timeShiftRange) try: self.enterOuterAlt(localctx, 1) - self.state = 77 + self.state = 79 localctx.shift1 = self.shift() - self.state = 78 + self.state = 80 self.match(ExprParser.T__10) - self.state = 79 + self.state = 81 localctx.shift2 = self.shift() except RecognitionException as re: localctx.exception = re @@ -778,11 +785,11 @@ def timeRange(self): self.enterRule(localctx, 8, self.RULE_timeRange) try: self.enterOuterAlt(localctx, 1) - self.state = 81 + self.state = 83 localctx.expr1 = self.expr(0) - self.state = 82 + self.state = 84 self.match(ExprParser.T__10) - self.state = 83 + self.state = 85 localctx.expr2 = self.expr(0) except RecognitionException as re: localctx.exception = re @@ -826,13 +833,13 @@ def shift(self): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 85 - self.match(ExprParser.TIME) self.state = 87 + self.match(ExprParser.TIME) + self.state = 89 self._errHandler.sync(self) _la = self._input.LA(1) if _la==2 or _la==7: - self.state = 86 + self.state = 88 self.shift_expr(0) @@ -950,15 +957,15 @@ def shift_expr(self, _p:int=0): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 97 + self.state = 99 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,6,self._ctx) + la_ = self._interp.adaptivePredict(self._input,7,self._ctx) if la_ == 1: localctx = ExprParser.SignedAtomContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 90 + self.state = 92 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not(_la==2 or _la==7): @@ -966,7 +973,7 @@ def shift_expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 91 + self.state = 93 self.atom() pass @@ -974,7 +981,7 @@ def shift_expr(self, _p:int=0): localctx = ExprParser.SignedExpressionContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 92 + self.state = 94 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not(_la==2 or _la==7): @@ -982,35 +989,35 @@ def shift_expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 93 + self.state = 95 self.match(ExprParser.T__2) - self.state = 94 + self.state = 96 self.expr(0) - self.state = 95 + self.state = 97 self.match(ExprParser.T__3) pass self._ctx.stop = self._input.LT(-1) - self.state = 107 + self.state = 109 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,8,self._ctx) + _alt = self._interp.adaptivePredict(self._input,9,self._ctx) while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: if _alt==1: if self._parseListeners is not None: self.triggerExitRuleEvent() _prevctx = localctx - self.state = 105 + self.state = 107 self._errHandler.sync(self) - la_ = self._interp.adaptivePredict(self._input,7,self._ctx) + la_ = self._interp.adaptivePredict(self._input,8,self._ctx) if la_ == 1: localctx = ExprParser.ShiftMuldivContext(self, ExprParser.Shift_exprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_shift_expr) - self.state = 99 + self.state = 101 if not self.precpred(self._ctx, 4): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException(self, "self.precpred(self._ctx, 4)") - self.state = 100 + self.state = 102 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not(_la==5 or _la==6): @@ -1018,18 +1025,18 @@ def shift_expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 101 + self.state = 103 self.right_expr(0) pass elif la_ == 2: localctx = ExprParser.ShiftAddsubContext(self, ExprParser.Shift_exprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_shift_expr) - self.state = 102 + self.state = 104 if not self.precpred(self._ctx, 3): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException(self, "self.precpred(self._ctx, 3)") - self.state = 103 + self.state = 105 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not(_la==2 or _la==7): @@ -1037,14 +1044,14 @@ def shift_expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 104 + self.state = 106 self.right_expr(0) pass - self.state = 109 + self.state = 111 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,8,self._ctx) + _alt = self._interp.adaptivePredict(self._input,9,self._ctx) except RecognitionException as re: localctx.exception = re @@ -1137,7 +1144,7 @@ def right_expr(self, _p:int=0): self._la = 0 # Token type try: self.enterOuterAlt(localctx, 1) - self.state = 116 + self.state = 118 self._errHandler.sync(self) token = self._input.LA(1) if token in [3]: @@ -1145,27 +1152,27 @@ def right_expr(self, _p:int=0): self._ctx = localctx _prevctx = localctx - self.state = 111 + self.state = 113 self.match(ExprParser.T__2) - self.state = 112 + self.state = 114 self.expr(0) - self.state = 113 + self.state = 115 self.match(ExprParser.T__3) pass elif token in [12, 14]: localctx = ExprParser.RightAtomContext(self, localctx) self._ctx = localctx _prevctx = localctx - self.state = 115 + self.state = 117 self.atom() pass else: raise NoViableAltException(self) self._ctx.stop = self._input.LT(-1) - self.state = 123 + self.state = 125 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,10,self._ctx) + _alt = self._interp.adaptivePredict(self._input,11,self._ctx) while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: if _alt==1: if self._parseListeners is not None: @@ -1173,11 +1180,11 @@ def right_expr(self, _p:int=0): _prevctx = localctx localctx = ExprParser.RightMuldivContext(self, ExprParser.Right_exprContext(self, _parentctx, _parentState)) self.pushNewRecursionContext(localctx, _startState, self.RULE_right_expr) - self.state = 118 + self.state = 120 if not self.precpred(self._ctx, 3): from antlr4.error.Errors import FailedPredicateException raise FailedPredicateException(self, "self.precpred(self._ctx, 3)") - self.state = 119 + self.state = 121 localctx.op = self._input.LT(1) _la = self._input.LA(1) if not(_la==5 or _la==6): @@ -1185,11 +1192,11 @@ def right_expr(self, _p:int=0): else: self._errHandler.reportMatch(self) self.consume() - self.state = 120 + self.state = 122 self.right_expr(4) - self.state = 125 + self.state = 127 self._errHandler.sync(self) - _alt = self._interp.adaptivePredict(self._input,10,self._ctx) + _alt = self._interp.adaptivePredict(self._input,11,self._ctx) except RecognitionException as re: localctx.exception = re diff --git a/src/andromede/expression/parsing/parse_expression.py b/src/andromede/expression/parsing/parse_expression.py index 961c0e51..f4ed9142 100644 --- a/src/andromede/expression/parsing/parse_expression.py +++ b/src/andromede/expression/parsing/parse_expression.py @@ -140,20 +140,38 @@ def visitTimeIndex(self, ctx: ExprParser.TimeIndexContext) -> LinearExpression: def visitTimeShift(self, ctx: ExprParser.TimeShiftContext) -> LinearExpression: shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore time_shift = ctx.shift().accept(self) # type: ignore - return shifted_expr.sum(shift=time_shift) + # specifics for x[t] ... + if expressions_equal(time_shift, literal(0)): + # TODO: Should expression simplification be handled only in linear expression building rather than here in the parsing ? + return shifted_expr + return shifted_expr.shift(time_shift) # Visit a parse tree produced by ExprParser#timeSum. - def visitTimeSum(self, ctx:ExprParser.TimeSumContext): - return self.visitChildren(ctx) + def visitTimeSum(self, ctx:ExprParser.TimeSumContext) -> LinearExpression: + summed_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) + expr = ctx.expr() + shift = ctx.shift() + time_range = ctx.timeRange() + time_shift_range = ctx.timeShiftRange() + if expr is not None: + return summed_expr.sum(eval=expr.accept(self)) + elif shift is not None: + return summed_expr.sum(shift=shift.accept(self)) + elif time_range is not None: + return summed_expr.sum(eval=time_range.accept(self)) + elif time_shift_range is not None: + return summed_expr.sum(shift=time_shift_range.accept(self)) + else: + return summed_expr.sum() # Visit a parse tree produced by ExprParser#timeShiftRange. def visitTimeShiftRange(self, ctx:ExprParser.TimeShiftRangeContext): - return ExpressionRange(ctx.shift, ctx.shift2) + return ExpressionRange(ctx.shift1, ctx.shift2) # Visit a parse tree produced by ExprParser#timeRange. def visitTimeRange(self, ctx:ExprParser.TimeRangeContext): - return self.visitChildren(ctx) + return ExpressionRange(ctx.expr1, ctx.expr2) # Visit a parse tree produced by ExprParser#function. def visitFunction(self, ctx: ExprParser.FunctionContext) -> LinearExpression: diff --git a/tests/unittests/expressions/parsing/test_expression_parsing.py b/tests/unittests/expressions/parsing/test_expression_parsing.py index b43c3144..81ab09f5 100644 --- a/tests/unittests/expressions/parsing/test_expression_parsing.py +++ b/tests/unittests/expressions/parsing/test_expression_parsing.py @@ -58,7 +58,7 @@ ( {"x"}, {}, - "x[-1..5]", + "sum(-1..5, x)", var("x").sum(eval=ExpressionRange(-literal(1), literal(5))), ), ({"x"}, {}, "x[1]", var("x").eval(1)), @@ -66,7 +66,7 @@ ( {"x"}, {}, - "x[t-1, t+4]", # TODO: Should raise ValueError: shift always with sum + "sum((t-1, t+4), x)", # TODO: Should raise ValueError: shift always with sum var("x").sum(shift=[-literal(1), literal(4)]), ), ( @@ -186,6 +186,9 @@ def test_parsing_visitor( "x[t+1-t]", "x[2*t]", "x[t 4]", + "x[t..4]", + "x[t+1..t+4]", + "x[1..4]", ], ) def test_parse_cancellation_should_throw(expression_str: str) -> None: