Skip to content

Commit 849a7c8

Browse files
security: replace eval() with AST-based evaluator in configure_optimizers
Replace the unsafe eval(params.lr_lambda) call with a restricted AST-based evaluator (_safe_eval_lr_lambda) that only permits numeric constants, the epoch variable, parentheses, unary +/- and arithmetic operators (+, -, *, /, //, %, **). This prevents arbitrary code execution via malicious config overrides or CLI arguments (CWE-95) while preserving full flexibility for valid scheduler expressions. Add a regression test that verifies malicious expressions (e.g. __import__('os').system(...)) raise ValueError. Closes #1305
1 parent 4db10b2 commit 849a7c8

2 files changed

Lines changed: 420 additions & 183 deletions

File tree

src/deepforest/main.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# entry point for deepforest model
2+
import ast
23
import importlib
34
import os
45
import warnings
@@ -945,7 +946,7 @@ def configure_optimizers(self):
945946

946947
# Assume the lambda is a function of epoch
947948
def lr_lambda(epoch):
948-
return eval(params.lr_lambda)
949+
return self._safe_eval_lr_lambda(params.lr_lambda, epoch)
949950

950951
if scheduler_type == "cosine":
951952
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
@@ -998,6 +999,69 @@ def lr_lambda(epoch):
998999
else:
9991000
return optimizer
10001001

1002+
@staticmethod
1003+
def _safe_eval_lr_lambda(expr: str, epoch: int) -> float:
1004+
"""Safely evaluate arithmetic scheduler expressions against `epoch`.
1005+
1006+
Supported syntax is intentionally limited to numeric constants,
1007+
parentheses, unary +/- and arithmetic operators (+, -, *, /, //, %, **),
1008+
with `epoch` as the only allowed variable name.
1009+
"""
1010+
1011+
allowed_binary_ops = (
1012+
ast.Add,
1013+
ast.Sub,
1014+
ast.Mult,
1015+
ast.Div,
1016+
ast.FloorDiv,
1017+
ast.Mod,
1018+
ast.Pow,
1019+
)
1020+
allowed_unary_ops = (ast.UAdd, ast.USub)
1021+
1022+
def _eval(node):
1023+
if isinstance(node, ast.Expression):
1024+
return _eval(node.body)
1025+
1026+
if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
1027+
return node.value
1028+
1029+
if isinstance(node, ast.Name) and node.id == "epoch":
1030+
return epoch
1031+
1032+
if isinstance(node, ast.BinOp) and isinstance(node.op, allowed_binary_ops):
1033+
left = _eval(node.left)
1034+
right = _eval(node.right)
1035+
if isinstance(node.op, ast.Add):
1036+
return left + right
1037+
if isinstance(node.op, ast.Sub):
1038+
return left - right
1039+
if isinstance(node.op, ast.Mult):
1040+
return left * right
1041+
if isinstance(node.op, ast.Div):
1042+
return left / right
1043+
if isinstance(node.op, ast.FloorDiv):
1044+
return left // right
1045+
if isinstance(node.op, ast.Mod):
1046+
return left % right
1047+
if isinstance(node.op, ast.Pow):
1048+
return left**right
1049+
1050+
if isinstance(node, ast.UnaryOp) and isinstance(node.op, allowed_unary_ops):
1051+
operand = _eval(node.operand)
1052+
if isinstance(node.op, ast.UAdd):
1053+
return +operand
1054+
if isinstance(node.op, ast.USub):
1055+
return -operand
1056+
1057+
raise ValueError(f"Unsafe lr_lambda expression: {expr}")
1058+
1059+
try:
1060+
parsed = ast.parse(expr, mode="eval")
1061+
return float(_eval(parsed))
1062+
except (SyntaxError, TypeError, ValueError, ZeroDivisionError) as exc:
1063+
raise ValueError(f"Unsafe lr_lambda expression: {expr}") from exc
1064+
10011065
def __evaluate__(
10021066
self,
10031067
csv_file,

0 commit comments

Comments
 (0)