|
33 | 33 | from .integral_table import check_integral_table |
34 | 34 | from .linalg import invert |
35 | 35 | from .polynomial import Polynomial, is_polynomial, polynomial_to_expr, rid_ending_zeros, to_const_polynomial |
36 | | -from .regex import count, general_contains, replace, replace_class, replace_factory |
| 36 | +from .regex import Any_, contains, count, general_contains, replace, replace_class, replace_factory |
37 | 37 | from .simplify import pythagorean_simplification |
38 | 38 | from .simplify.product_to_sum import product_to_sum_unit |
39 | 39 | from .utils import ExprFn, eq_with_var, random_id |
@@ -211,6 +211,7 @@ def check(self, node: Node) -> bool: |
211 | 211 | class USub(Transform, ABC): |
212 | 212 | """Base class for u-substituion transforms.""" |
213 | 213 |
|
| 214 | + # u (new var) written in terms of x (old var) |
214 | 215 | _u: Expr = None |
215 | 216 |
|
216 | 217 | def backward(self, node: Node) -> None: |
@@ -377,8 +378,6 @@ def _get_last_heuristic_transform(node: Node, tup=(PullConstant, Additivity)): |
377 | 378 | return node.transform |
378 | 379 |
|
379 | 380 |
|
380 | | -# Let's just add all the transforms we've used for now. |
381 | | -# and we will make this shit good and generalized later. |
382 | 381 | class TrigUSub2(USub): |
383 | 382 | """ |
384 | 383 | u-sub of a trig function |
@@ -410,8 +409,8 @@ def check(self, node: Node) -> bool: |
410 | 409 | if super().check(node) is False: |
411 | 410 | return False |
412 | 411 |
|
413 | | - # Since B and C essentially undo each other, we want to make sure that the last |
414 | | - # heuristic transform wasn't C. |
| 412 | + # Since TrigUSub2 and InverseTrigUSub essentially undo each other, we want to make sure that the last |
| 413 | + # heuristic transform wasn't InverseTrigUSub. |
415 | 414 |
|
416 | 415 | t = _get_last_heuristic_transform(node) |
417 | 416 | if isinstance(t, InverseTrigUSub): |
@@ -966,6 +965,67 @@ def forward(self, node: Node) -> None: |
966 | 965 | self._u = self._u |
967 | 966 |
|
968 | 967 |
|
| 968 | +class TrigUSub(USub): |
| 969 | + """This is the substitution of the form x = 4*cos(theta) (TODO: write better descrip later)""" |
| 970 | + |
| 971 | + _a: Rat = None # constant in the square root |
| 972 | + _exponent: Rat = None # exponent of the square root (includes the square root, is a fraction w denom = 2) |
| 973 | + _case: int = None # 0 for sqrt(a^2 - x^2), 1 for sqrt(a^2 + x^2), 2 for sqrt(-a^2 + x^2) |
| 974 | + |
| 975 | + def check(self, node: Node): |
| 976 | + """Check if node.expr contains sqrt(a^2-x^2) or sqrt(a^2+x^2) where a is a constant.""" |
| 977 | + if super().check(node) is False: |
| 978 | + return False |
| 979 | + |
| 980 | + # check if any instance of sqrt(a^2 - x^2) or sqrt(a^2 + x^2) or sqrt(-a^2 + x^2) appears. |
| 981 | + def squared_integer_condition(expr: Expr) -> bool: |
| 982 | + return isinstance(expr, Rat) and isinstance(sqrt(expr), Rat) |
| 983 | + |
| 984 | + a_squared = Any_("squared_integer", squared_integer_condition, is_constant=True) |
| 985 | + any_square_root_exponent = Any_( |
| 986 | + "square_root_exponent", lambda expr: isinstance(expr, Rat) and expr.denominator == 2, is_constant=True |
| 987 | + ) |
| 988 | + queries = [ |
| 989 | + (a_squared - node.var**2) ** any_square_root_exponent, |
| 990 | + (node.var**2 - a_squared) ** any_square_root_exponent, |
| 991 | + (node.var**2 + a_squared) ** any_square_root_exponent, |
| 992 | + ] |
| 993 | + for i, query in enumerate(queries): |
| 994 | + out = contains(node.expr, query) |
| 995 | + if out["success"]: |
| 996 | + self._a = sqrt(out["matches"]["squared_integer"]) |
| 997 | + self._exponent = out["matches"]["square_root_exponent"] |
| 998 | + self._case = i |
| 999 | + return True |
| 1000 | + |
| 1001 | + return False |
| 1002 | + |
| 1003 | + def forward(self, node: Node) -> None: |
| 1004 | + if self._case == 0: |
| 1005 | + # in the case of sqrt(a^2 - x^2): |
| 1006 | + # x = a * sin(theta) |
| 1007 | + # dx = a * cos(theta) d(theta) |
| 1008 | + theta = generate_intermediate_var() |
| 1009 | + dx_dtheta = self._a * cos(theta) |
| 1010 | + |
| 1011 | + # so we replace x = a * sin(theta), this effectively leads to |
| 1012 | + # replacing sqrt(a^2 - x^2) = a * cos^2(theta) |
| 1013 | + theta_expr = replace( |
| 1014 | + node.expr, |
| 1015 | + (self._a**2 - node.var**2) ** self._exponent, |
| 1016 | + (self._a * cos(theta)) ** self._exponent.numerator, |
| 1017 | + ) |
| 1018 | + if theta_expr.contains(node.var): |
| 1019 | + theta_expr = replace(theta_expr, node.var, self._a * sin(theta)) |
| 1020 | + |
| 1021 | + new_integrand = theta_expr * dx_dtheta |
| 1022 | + node.add_child(Node(new_integrand, theta, self, node)) |
| 1023 | + |
| 1024 | + self._u = asin(node.var / self._a) # theta in terms of x |
| 1025 | + |
| 1026 | + return NotImplemented |
| 1027 | + |
| 1028 | + |
969 | 1029 | class CompleteTheSquare(Transform): |
970 | 1030 | """Integration via completing the square""" |
971 | 1031 |
|
@@ -1120,6 +1180,7 @@ def backward(self, node: Node) -> None: |
1120 | 1180 | RewriteTrig, |
1121 | 1181 | RewritePythagorean, |
1122 | 1182 | InverseTrigUSub, |
| 1183 | + TrigUSub, |
1123 | 1184 | CompleteTheSquare, |
1124 | 1185 | GenericUSub, |
1125 | 1186 | ] |
|
0 commit comments