diff --git a/app/context/physical_quantity.py b/app/context/physical_quantity.py index 28348de..480bdab 100644 --- a/app/context/physical_quantity.py +++ b/app/context/physical_quantity.py @@ -282,10 +282,9 @@ def quantity_match(unused_inputs): ans = parameters["reserved_expressions"]["answer"]["standard"]["value"].simplify() res = parameters["reserved_expressions"]["response"]["standard"]["value"].simplify() if (ans is not None and ans.is_constant()) and (res is not None and res.is_constant()): - if parsing_params.get('rtol', 0) > 0 and (ans != 0): - value_match = bool(abs(float((ans-res)/ans)) < parsing_params['rtol']) - elif parsing_params.get('atol', 0) > 0 or (ans == 0): - value_match = bool(abs(float(ans-res)) < parsing_params['atol']) + atol = float(parsing_params.get('atol', 0)) + rtol = float(parsing_params.get('rtol', 0)) + value_match = bool(abs(float(ans - res)) <= atol + rtol * abs(float(ans))) substitutions = [(key, expr["standard"]["unit"]) for (key, expr) in reserved_expressions] unit_match = is_equal(lhs, rhs, substitutions) diff --git a/app/context/symbolic.py b/app/context/symbolic.py index 960c989..4da8e9c 100644 --- a/app/context/symbolic.py +++ b/app/context/symbolic.py @@ -142,8 +142,6 @@ def check_equality(criterion, parameters_dict, local_substitutions=[]): # TODO: Make numerical comparison its own context if result is False: - error_below_rtol = None - error_below_atol = None if parameters_dict.get("numerical", False) or float(parameters_dict.get("rtol", 0)) > 0 or float(parameters_dict.get("atol", 0)) > 0: # REMARK: 'pi' should be a reserved symbol but it is sometimes not treated as one, possibly because of input symbols. # The two lines below this comments fixes the issue but a more robust solution should be found for cases where there @@ -158,26 +156,22 @@ def replace_pi(expr): # Separates LHS and RHS, parses and evaluates them res = N(replace_pi(lhs_expr)) ans = N(replace_pi(rhs_expr)) - if float(parameters_dict.get("atol", 0)) > 0: - try: - absolute_error = abs(float(ans-res)) - error_below_atol = bool(absolute_error < float(parameters_dict["atol"])) - except TypeError: - error_below_atol = None - else: - error_below_atol = True - if float(parameters_dict.get("rtol", 0)) > 0: + + atol = float(parameters_dict.get("atol", 0)) + rtol = float(parameters_dict.get("rtol", 0)) + + try: + # Dividing by ans cancels symbols (e.g. rho*L^3 / rho*L^3 = 1) + # Equivalent to numpy.isclose: |ans-res| <= atol + rtol*|ans| + ratio = float((ans - res) / ans) try: - relative_error = abs(float((ans-res)/ans)) - error_below_rtol = bool(relative_error < float(parameters_dict["rtol"])) + tolerance = rtol + atol / abs(float(ans)) except TypeError: - error_below_rtol = None - else: - error_below_rtol = True - if error_below_atol is None or error_below_rtol is None: + # Defaulting to relative tolerance if symbol cancellation fails to preserve the previous implementation + tolerance = rtol + result = bool(abs(ratio) <= tolerance) + except (TypeError, ZeroDivisionError): result = False - elif error_below_atol is True and error_below_rtol is True: - result = True return result