|
1 | 1 | # entry point for deepforest model |
| 2 | +import ast |
2 | 3 | import importlib |
3 | 4 | import os |
4 | 5 | import warnings |
@@ -887,7 +888,7 @@ def configure_optimizers(self): |
887 | 888 |
|
888 | 889 | # Assume the lambda is a function of epoch |
889 | 890 | def lr_lambda(epoch): |
890 | | - return eval(params.lr_lambda) |
| 891 | + return self._safe_eval_lr_lambda(params.lr_lambda, epoch) |
891 | 892 |
|
892 | 893 | if scheduler_type == "cosine": |
893 | 894 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( |
@@ -940,6 +941,69 @@ def lr_lambda(epoch): |
940 | 941 | else: |
941 | 942 | return optimizer |
942 | 943 |
|
| 944 | + @staticmethod |
| 945 | + def _safe_eval_lr_lambda(expr: str, epoch: int) -> float: |
| 946 | + """Safely evaluate arithmetic scheduler expressions against `epoch`. |
| 947 | +
|
| 948 | + Supported syntax is intentionally limited to numeric constants, |
| 949 | + parentheses, unary +/- and arithmetic operators (+, -, *, /, //, %, **), |
| 950 | + with `epoch` as the only allowed variable name. |
| 951 | + """ |
| 952 | + |
| 953 | + allowed_binary_ops = ( |
| 954 | + ast.Add, |
| 955 | + ast.Sub, |
| 956 | + ast.Mult, |
| 957 | + ast.Div, |
| 958 | + ast.FloorDiv, |
| 959 | + ast.Mod, |
| 960 | + ast.Pow, |
| 961 | + ) |
| 962 | + allowed_unary_ops = (ast.UAdd, ast.USub) |
| 963 | + |
| 964 | + def _eval(node): |
| 965 | + if isinstance(node, ast.Expression): |
| 966 | + return _eval(node.body) |
| 967 | + |
| 968 | + if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)): |
| 969 | + return node.value |
| 970 | + |
| 971 | + if isinstance(node, ast.Name) and node.id == "epoch": |
| 972 | + return epoch |
| 973 | + |
| 974 | + if isinstance(node, ast.BinOp) and isinstance(node.op, allowed_binary_ops): |
| 975 | + left = _eval(node.left) |
| 976 | + right = _eval(node.right) |
| 977 | + if isinstance(node.op, ast.Add): |
| 978 | + return left + right |
| 979 | + if isinstance(node.op, ast.Sub): |
| 980 | + return left - right |
| 981 | + if isinstance(node.op, ast.Mult): |
| 982 | + return left * right |
| 983 | + if isinstance(node.op, ast.Div): |
| 984 | + return left / right |
| 985 | + if isinstance(node.op, ast.FloorDiv): |
| 986 | + return left // right |
| 987 | + if isinstance(node.op, ast.Mod): |
| 988 | + return left % right |
| 989 | + if isinstance(node.op, ast.Pow): |
| 990 | + return left**right |
| 991 | + |
| 992 | + if isinstance(node, ast.UnaryOp) and isinstance(node.op, allowed_unary_ops): |
| 993 | + operand = _eval(node.operand) |
| 994 | + if isinstance(node.op, ast.UAdd): |
| 995 | + return +operand |
| 996 | + if isinstance(node.op, ast.USub): |
| 997 | + return -operand |
| 998 | + |
| 999 | + raise ValueError(f"Unsafe lr_lambda expression: {expr}") |
| 1000 | + |
| 1001 | + try: |
| 1002 | + parsed = ast.parse(expr, mode="eval") |
| 1003 | + return float(_eval(parsed)) |
| 1004 | + except (SyntaxError, TypeError, ValueError, ZeroDivisionError) as exc: |
| 1005 | + raise ValueError(f"Unsafe lr_lambda expression: {expr}") from exc |
| 1006 | + |
943 | 1007 | def evaluate( |
944 | 1008 | self, |
945 | 1009 | csv_file, |
|
0 commit comments