Skip to content

Commit ec15d44

Browse files
committed
basic case funky-trig-substitution integral works! epic
1 parent 703968b commit ec15d44

5 files changed

Lines changed: 139 additions & 7 deletions

File tree

src/simpy/expr.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,11 @@ def evalf(self, subs: Optional[Dict[str, "Expr"]] = None) -> "Expr":
228228
def children(self) -> List["Expr"]:
229229
raise NotImplementedError(f"Cannot get children of {self.__class__.__name__}")
230230

231+
@property
232+
def childless(self) -> bool:
233+
"""Returns True if it's a basic element like a symbol or a number that doesn't have any subcomponents at all."""
234+
return len(self.children()) == 0
235+
231236
def contains(self: "Expr", var: "Symbol") -> bool:
232237
is_var = isinstance(self, Symbol) and self.name == var.name
233238
return is_var or any(e.contains(var) for e in self.children())

src/simpy/regex.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
"""Custom library for checking what shit exists. Replaces searching the repr with regex.
1+
"""Custom library for checking what exists within exprs. Replaces searching the repr with regex.
22
33
This module is currently still developmental. It does the job often but is not promised to be robust
44
outside of the cases it is currently used for. Use with caution.
5+
6+
TODO: write a regex quickstart guide. Maybe spin this off into a subfolder and make the outward facing API more
7+
intuitive. Perhaps make it similar to the regex library.
8+
+ Unit testing with comprehensively thought-out cases.
59
"""
610

711
from collections import defaultdict
@@ -395,6 +399,28 @@ def count(expr: Expr, query: Expr) -> int:
395399
return sum(count(e, query) for e in expr.children())
396400

397401

402+
def contains(expr: Expr, query: Expr) -> EqResult:
403+
"""Checks if `query` appears in `expr`. Assumes query contains Any_ objects.
404+
Exact any-matches only, no up to factor or sum.
405+
Returns a results dictionary like eq() does. with `success` and `matches` keys.
406+
"""
407+
## base cases ##
408+
eq_output = eq(expr, query)
409+
if eq_output["success"]:
410+
return eq_output
411+
if expr.childless:
412+
return {"success": False}
413+
414+
## recursive cases ##
415+
# this isn't super sophisticated for dupes but it's fine for now.
416+
for e in expr.children():
417+
eq_output_ = contains(e, query)
418+
if eq_output_["success"]:
419+
return eq_output_
420+
421+
return {"success": False}
422+
423+
398424
def contains_cls(expr: Expr, cls: Type[Expr]) -> bool:
399425
if isinstance(expr, cls):
400426
return True
@@ -412,6 +438,10 @@ def general_count(expr: Expr, condition: ExprCondition) -> int:
412438

413439

414440
def general_contains(expr: Expr, condition: ExprCondition) -> bool:
441+
"""contains with a condition function instead of any.
442+
443+
this is sorta-legacy --- i think we should use contains instead; it's cuter. any-matches are cute.
444+
"""
415445
if condition(expr):
416446
return True
417447
return any(general_contains(e, condition) for e in expr.children())

src/simpy/transforms.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from .integral_table import check_integral_table
3434
from .linalg import invert
3535
from .polynomial import Polynomial, is_polynomial, polynomial_to_expr, rid_ending_zeros, to_const_polynomial
36-
from .regex import count, general_contains, replace, replace_class, replace_factory
36+
from .regex import Any_, contains, count, general_contains, replace, replace_class, replace_factory
3737
from .simplify import pythagorean_simplification
3838
from .simplify.product_to_sum import product_to_sum_unit
3939
from .utils import ExprFn, eq_with_var, random_id
@@ -211,6 +211,7 @@ def check(self, node: Node) -> bool:
211211
class USub(Transform, ABC):
212212
"""Base class for u-substituion transforms."""
213213

214+
# u (new var) written in terms of x (old var)
214215
_u: Expr = None
215216

216217
def backward(self, node: Node) -> None:
@@ -377,8 +378,6 @@ def _get_last_heuristic_transform(node: Node, tup=(PullConstant, Additivity)):
377378
return node.transform
378379

379380

380-
# Let's just add all the transforms we've used for now.
381-
# and we will make this shit good and generalized later.
382381
class TrigUSub2(USub):
383382
"""
384383
u-sub of a trig function
@@ -410,8 +409,8 @@ def check(self, node: Node) -> bool:
410409
if super().check(node) is False:
411410
return False
412411

413-
# Since B and C essentially undo each other, we want to make sure that the last
414-
# heuristic transform wasn't C.
412+
# Since TrigUSub2 and InverseTrigUSub essentially undo each other, we want to make sure that the last
413+
# heuristic transform wasn't InverseTrigUSub.
415414

416415
t = _get_last_heuristic_transform(node)
417416
if isinstance(t, InverseTrigUSub):
@@ -966,6 +965,67 @@ def forward(self, node: Node) -> None:
966965
self._u = self._u
967966

968967

968+
class TrigUSub(USub):
969+
"""This is the substitution of the form x = 4*cos(theta) (TODO: write better descrip later)"""
970+
971+
_a: Rat = None # constant in the square root
972+
_exponent: Rat = None # exponent of the square root (includes the square root, is a fraction w denom = 2)
973+
_case: int = None # 0 for sqrt(a^2 - x^2), 1 for sqrt(a^2 + x^2), 2 for sqrt(-a^2 + x^2)
974+
975+
def check(self, node: Node):
976+
"""Check if node.expr contains sqrt(a^2-x^2) or sqrt(a^2+x^2) where a is a constant."""
977+
if super().check(node) is False:
978+
return False
979+
980+
# check if any instance of sqrt(a^2 - x^2) or sqrt(a^2 + x^2) or sqrt(-a^2 + x^2) appears.
981+
def squared_integer_condition(expr: Expr) -> bool:
982+
return isinstance(expr, Rat) and isinstance(sqrt(expr), Rat)
983+
984+
a_squared = Any_("squared_integer", squared_integer_condition, is_constant=True)
985+
any_square_root_exponent = Any_(
986+
"square_root_exponent", lambda expr: isinstance(expr, Rat) and expr.denominator == 2, is_constant=True
987+
)
988+
queries = [
989+
(a_squared - node.var**2) ** any_square_root_exponent,
990+
(node.var**2 - a_squared) ** any_square_root_exponent,
991+
(node.var**2 + a_squared) ** any_square_root_exponent,
992+
]
993+
for i, query in enumerate(queries):
994+
out = contains(node.expr, query)
995+
if out["success"]:
996+
self._a = sqrt(out["matches"]["squared_integer"])
997+
self._exponent = out["matches"]["square_root_exponent"]
998+
self._case = i
999+
return True
1000+
1001+
return False
1002+
1003+
def forward(self, node: Node) -> None:
1004+
if self._case == 0:
1005+
# in the case of sqrt(a^2 - x^2):
1006+
# x = a * sin(theta)
1007+
# dx = a * cos(theta) d(theta)
1008+
theta = generate_intermediate_var()
1009+
dx_dtheta = self._a * cos(theta)
1010+
1011+
# so we replace x = a * sin(theta), this effectively leads to
1012+
# replacing sqrt(a^2 - x^2) = a * cos^2(theta)
1013+
theta_expr = replace(
1014+
node.expr,
1015+
(self._a**2 - node.var**2) ** self._exponent,
1016+
(self._a * cos(theta)) ** self._exponent.numerator,
1017+
)
1018+
if theta_expr.contains(node.var):
1019+
theta_expr = replace(theta_expr, node.var, self._a * sin(theta))
1020+
1021+
new_integrand = theta_expr * dx_dtheta
1022+
node.add_child(Node(new_integrand, theta, self, node))
1023+
1024+
self._u = asin(node.var / self._a) # theta in terms of x
1025+
1026+
return NotImplemented
1027+
1028+
9691029
class CompleteTheSquare(Transform):
9701030
"""Integration via completing the square"""
9711031

@@ -1120,6 +1180,7 @@ def backward(self, node: Node) -> None:
11201180
RewriteTrig,
11211181
RewritePythagorean,
11221182
InverseTrigUSub,
1183+
TrigUSub,
11231184
CompleteTheSquare,
11241185
GenericUSub,
11251186
]

tests/test_khan_academy_integrals.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,11 @@ def test_more_complicated_trig():
202202
expr = tan(x) ** 5 * sec(x) ** 4
203203
expected_ans = tan(x) ** 6 / 6 + tan(x) ** 8 / 8
204204
assert_integral(expr, expected_ans)
205+
206+
207+
def test_trigonometric_substitution():
208+
# that's this one: https://www.khanacademy.org/math/integral-calculus/ic-integration/ic-trig-substitution/e/integration-using-trigonometric-substitution
209+
expr = (4 - x**2) ** Fraction(3, 2)
210+
expected_ans = sin(4 * asin(x / 2)) / 2 + 4 * sin(2 * asin(x / 2)) + 6 * asin(x / 2) # TODO: simplify this
211+
ans = integrate(expr)
212+
assert_eq_plusc(expected_ans, ans)

tests/test_regex.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from test_utils import x, y
33

44
from simpy.expr import *
5-
from simpy.regex import Any_, any_, any_constant, eq
5+
from simpy.regex import Any_, any_, any_constant, contains, eq
66

77

88
def test_any_basic():
@@ -18,6 +18,14 @@ def test_sort_anys():
1818
assert eq(sin(x) * sec(x), sec(any_) * sin(any_))
1919

2020

21+
def test_eq_with_different_anys():
22+
any2 = Any_()
23+
expr = sin(x) + cos(y)
24+
query = sin(any_) + cos(any2)
25+
out = eq(expr, query)
26+
assert out["success"]
27+
28+
2129
@pytest.mark.parametrize(
2230
["sum", "expected"],
2331
[
@@ -118,3 +126,23 @@ def test_any_constant_fail():
118126
query = 2 * x + any_constant
119127
out = eq(expr, query)
120128
assert not out["success"]
129+
130+
131+
def test_any_constant_with_multiple_anys():
132+
expr = 2 * x + 3
133+
query = 2 * any_ + any_constant
134+
out = eq(expr, query)
135+
assert out["success"]
136+
137+
138+
def test_contains():
139+
query = log(sin(any_) ** 2 + cos(any_) ** 2)
140+
expr = (log(sin(x) ** 2 + cos(x) ** 2) + 3) ** 2
141+
assert contains(expr, query)["success"]
142+
assert contains(expr, query)["matches"] == x
143+
144+
145+
def test_contains_fail():
146+
query = log(sin(any_) ** 2 + cos(any_) ** 2)
147+
expr = (log(sin(x) ** 2 + cos(x) ** 3) + 3) ** 2 + 1
148+
assert not contains(expr, query)["success"]

0 commit comments

Comments
 (0)