Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions app/context/physical_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,9 @@ def quantity_match(unused_inputs):
ans = parameters["reserved_expressions"]["answer"]["standard"]["value"].simplify()
res = parameters["reserved_expressions"]["response"]["standard"]["value"].simplify()
if (ans is not None and ans.is_constant()) and (res is not None and res.is_constant()):
if parsing_params.get('rtol', 0) > 0 and (ans != 0):
value_match = bool(abs(float((ans-res)/ans)) < parsing_params['rtol'])
elif parsing_params.get('atol', 0) > 0 or (ans == 0):
value_match = bool(abs(float(ans-res)) < parsing_params['atol'])
atol = float(parsing_params.get('atol', 0))
rtol = float(parsing_params.get('rtol', 0))
value_match = bool(abs(float(ans - res)) <= atol + rtol * abs(float(ans)))

substitutions = [(key, expr["standard"]["unit"]) for (key, expr) in reserved_expressions]
unit_match = is_equal(lhs, rhs, substitutions)
Expand Down
32 changes: 13 additions & 19 deletions app/context/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,6 @@ def check_equality(criterion, parameters_dict, local_substitutions=[]):

# TODO: Make numerical comparison its own context
if result is False:
error_below_rtol = None
error_below_atol = None
if parameters_dict.get("numerical", False) or float(parameters_dict.get("rtol", 0)) > 0 or float(parameters_dict.get("atol", 0)) > 0:
# REMARK: 'pi' should be a reserved symbol but it is sometimes not treated as one, possibly because of input symbols.
# The two lines below this comments fixes the issue but a more robust solution should be found for cases where there
Expand All @@ -158,26 +156,22 @@ def replace_pi(expr):
# Separates LHS and RHS, parses and evaluates them
res = N(replace_pi(lhs_expr))
ans = N(replace_pi(rhs_expr))
if float(parameters_dict.get("atol", 0)) > 0:
try:
absolute_error = abs(float(ans-res))
error_below_atol = bool(absolute_error < float(parameters_dict["atol"]))
except TypeError:
error_below_atol = None
else:
error_below_atol = True
if float(parameters_dict.get("rtol", 0)) > 0:

atol = float(parameters_dict.get("atol", 0))
rtol = float(parameters_dict.get("rtol", 0))

try:
# Dividing by ans cancels symbols (e.g. rho*L^3 / rho*L^3 = 1)
# Equivalent to numpy.isclose: |ans-res| <= atol + rtol*|ans|
ratio = float((ans - res) / ans)
try:
relative_error = abs(float((ans-res)/ans))
error_below_rtol = bool(relative_error < float(parameters_dict["rtol"]))
tolerance = rtol + atol / abs(float(ans))
except TypeError:
error_below_rtol = None
else:
error_below_rtol = True
if error_below_atol is None or error_below_rtol is None:
# Defaulting to relative tolerance if symbol cancellation fails to preserve the previous implementation
tolerance = rtol
result = bool(abs(ratio) <= tolerance)
except (TypeError, ZeroDivisionError):
result = False
elif error_below_atol is True and error_below_rtol is True:
result = True

return result

Expand Down
Loading