Skip to content

Commit 1bcfe8e

Browse files
Security hardening: replace unsafe eval() with AST-based evaluator in configure_optimizers
1 parent 669fef0 commit 1bcfe8e

3 files changed

Lines changed: 96 additions & 2 deletions

File tree

.github/workflows/automate-waiting-labels.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
runs-on: ubuntu-latest
1515
steps:
1616
- name: Remove label
17-
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd
17+
uses: actions/github-script@9513afd82fbc900d13362cdd4fcdab0538f5124e
1818
with:
1919
script: |
2020
const issueNumber = context.issue.number || context.pull_request.number;

src/deepforest/main.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# entry point for deepforest model
2+
import ast
23
import importlib
34
import os
45
import warnings
@@ -887,7 +888,7 @@ def configure_optimizers(self):
887888

888889
# Assume the lambda is a function of epoch
889890
def lr_lambda(epoch):
890-
return eval(params.lr_lambda)
891+
return self._safe_eval_lr_lambda(params.lr_lambda, epoch)
891892

892893
if scheduler_type == "cosine":
893894
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
@@ -940,6 +941,69 @@ def lr_lambda(epoch):
940941
else:
941942
return optimizer
942943

944+
@staticmethod
945+
def _safe_eval_lr_lambda(expr: str, epoch: int) -> float:
946+
"""Safely evaluate arithmetic scheduler expressions against `epoch`.
947+
948+
Supported syntax is intentionally limited to numeric constants,
949+
parentheses, unary +/- and arithmetic operators (+, -, *, /, //, %, **),
950+
with `epoch` as the only allowed variable name.
951+
"""
952+
953+
allowed_binary_ops = (
954+
ast.Add,
955+
ast.Sub,
956+
ast.Mult,
957+
ast.Div,
958+
ast.FloorDiv,
959+
ast.Mod,
960+
ast.Pow,
961+
)
962+
allowed_unary_ops = (ast.UAdd, ast.USub)
963+
964+
def _eval(node):
965+
if isinstance(node, ast.Expression):
966+
return _eval(node.body)
967+
968+
if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
969+
return node.value
970+
971+
if isinstance(node, ast.Name) and node.id == "epoch":
972+
return epoch
973+
974+
if isinstance(node, ast.BinOp) and isinstance(node.op, allowed_binary_ops):
975+
left = _eval(node.left)
976+
right = _eval(node.right)
977+
if isinstance(node.op, ast.Add):
978+
return left + right
979+
if isinstance(node.op, ast.Sub):
980+
return left - right
981+
if isinstance(node.op, ast.Mult):
982+
return left * right
983+
if isinstance(node.op, ast.Div):
984+
return left / right
985+
if isinstance(node.op, ast.FloorDiv):
986+
return left // right
987+
if isinstance(node.op, ast.Mod):
988+
return left % right
989+
if isinstance(node.op, ast.Pow):
990+
return left**right
991+
992+
if isinstance(node, ast.UnaryOp) and isinstance(node.op, allowed_unary_ops):
993+
operand = _eval(node.operand)
994+
if isinstance(node.op, ast.UAdd):
995+
return +operand
996+
if isinstance(node.op, ast.USub):
997+
return -operand
998+
999+
raise ValueError(f"Unsafe lr_lambda expression: {expr}")
1000+
1001+
try:
1002+
parsed = ast.parse(expr, mode="eval")
1003+
return float(_eval(parsed))
1004+
except (SyntaxError, TypeError, ValueError, ZeroDivisionError) as exc:
1005+
raise ValueError(f"Unsafe lr_lambda expression: {expr}") from exc
1006+
9431007
def evaluate(
9441008
self,
9451009
csv_file,

tests/test_main.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,36 @@ def test_custom_log_root(m, tmpdir):
11731173
version_dir = version_dirs[0]
11741174
assert version_dir.join("hparams.yaml").exists(), "hparams.yaml not found"
11751175

1176+
def test_configure_optimizers_rejects_unsafe_lr_lambda(tmp_path):
1177+
"""Regression test: malicious lr_lambda expressions must be rejected."""
1178+
annotations_file = get_data("testfile_deepforest.csv")
1179+
root_dir = os.path.dirname(get_data("testfile_deepforest.csv"))
1180+
1181+
config_args = {
1182+
"train": {
1183+
"lr": 0.01,
1184+
"scheduler": {
1185+
"type": "lambdaLR",
1186+
"params": {
1187+
"lr_lambda": "__import__('os').system('echo injected')",
1188+
},
1189+
},
1190+
"csv_file": annotations_file,
1191+
"root_dir": root_dir,
1192+
"fast_dev_run": False,
1193+
},
1194+
"validation": {
1195+
"csv_file": None,
1196+
"root_dir": root_dir,
1197+
},
1198+
"log_root": str(tmp_path),
1199+
}
1200+
1201+
m = main.deepforest(model=torch.nn.Linear(1, 1), config_args=config_args)
1202+
1203+
with pytest.raises(ValueError, match="Unsafe lr_lambda"):
1204+
m.configure_optimizers()
1205+
11761206
def test_huggingface_model_loads_correct_label_dict():
11771207
"""Regression test for #1286:
11781208
HuggingFace models should load correct label_dict from config.json.

0 commit comments

Comments
 (0)