Skip to content

Commit 07787ab

Browse files
committed
add: support sampling and warmup instrumentation policies
1 parent 53d90ca commit 07787ab

8 files changed

Lines changed: 191 additions & 7 deletions

File tree

traincheck/collect_trace.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,20 @@ def main():
380380
help="Indicate whether to use torch.compile to speed up the model, necessary to realize compatibility",
381381
)
382382

383+
## instrumentation policy configs
384+
parser.add_argument(
385+
"--sampling-interval",
386+
type=int,
387+
default=None,
388+
help="Interval of steps to instrument (e.g., 10 for every 10th step).",
389+
)
390+
parser.add_argument(
391+
"--warm-up-steps",
392+
type=int,
393+
default=0,
394+
help="Number of initial steps to always instrument.",
395+
)
396+
383397
args = parser.parse_args()
384398

385399
# read the configuration file
@@ -508,6 +522,8 @@ def main():
508522
instr_descriptors=args.instr_descriptors,
509523
no_auto_var_instr=args.no_auto_var_instr,
510524
use_torch_compile=args.use_torch_compile,
525+
sampling_interval=args.sampling_interval,
526+
warm_up_steps=args.warm_up_steps,
511527
)
512528

513529
if args.copy_all_files:

traincheck/config/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@
9595
TYPE_ERR_THRESHOLD = 3
9696
RECURSION_ERR_THRESHOLD = 5
9797

98+
INSTRUMENTATION_POLICY = {
99+
"interval": 1,
100+
"warm_up": 1,  # defaults to 1 so the first step is always instrumented: while warm-up steps remain, every step is instrumented (interval=1); once warm-up is depleted, instrumentation follows the specified interval
101+
}
102+
103+
DISABLE_WRAPPER = False
104+
98105

99106
class InstrOpt:
100107
def __init__(

traincheck/developer/annotations.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import traincheck.config.config as config
12
import traincheck.instrumentor.tracer as tracer
23
from traincheck.config.config import ALL_STAGE_NAMES
34
from traincheck.instrumentor import META_VARS
@@ -16,8 +17,13 @@ def annotate_stage(stage_name: str):
1617
stage_name in ALL_STAGE_NAMES
1718
), f"Invalid stage name: {stage_name}, valid ones are {ALL_STAGE_NAMES}"
1819

20+
old_stage = META_VARS.get("stage", None)
1921
META_VARS["stage"] = stage_name
2022

23+
# We always reset the wrapper when stage changes, and let the policy decide later if we should skip
24+
if old_stage != stage_name:
25+
config.DISABLE_WRAPPER = False
26+
2127

2228
def annotate_answer_start_token_ids(
2329
answer_start_token_id: int, include_start_token: bool = False

traincheck/instrumentor/proxy_wrapper/proxy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import torch
1010

11+
import traincheck.config.config as config
1112
import traincheck.instrumentor.proxy_wrapper.proxy_config as proxy_config # HACK: cannot directly import config variables as then they would be local variables
1213
import traincheck.instrumentor.proxy_wrapper.proxy_methods as proxy_methods
1314
from traincheck.config.config import should_disable_proxy_dumping
@@ -158,6 +159,9 @@ def __deepcopy__(self, memo):
158159
return new_copy
159160

160161
def dump_trace(self, phase, dump_loc):
162+
if config.DISABLE_WRAPPER:
163+
return
164+
161165
obj = self._obj
162166
var_name = self.__dict__["var_name"]
163167
assert var_name is not None # '' is allowed as a var_name (root object)

traincheck/instrumentor/proxy_wrapper/proxy_observer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import functools
22
import typing
33

4+
import traincheck.config.config as config
45
from traincheck.config.config import should_disable_proxy_dumping
56
from traincheck.instrumentor.proxy_wrapper.subclass import ProxyParameter
67
from traincheck.utils import typename
@@ -21,6 +22,8 @@ def observe_proxy_var(
2122
phase,
2223
observe_api_name: str,
2324
):
25+
if config.DISABLE_WRAPPER:
26+
return
2427

2528
# update the proxy object's timestamp
2629
var.update_timestamp()

traincheck/instrumentor/proxy_wrapper/subclass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
from torch import nn
88

9+
import traincheck.config.config as config
910
from traincheck.config.config import should_disable_proxy_dumping
1011
from traincheck.instrumentor.dumper import dump_trace_VAR
1112
from traincheck.instrumentor.proxy_wrapper.dumper import dump_attributes, get_meta_vars
@@ -178,6 +179,9 @@ def register_object(self):
178179
)
179180

180181
def dump_trace(self, phase, dump_loc):
182+
if config.DISABLE_WRAPPER:
183+
return
184+
181185
# TODO
182186
var_name = self.__dict__["var_name"]
183187
# assert var_name is not None # '' is allowed as a var_name (root object)

traincheck/instrumentor/source_file.py

Lines changed: 127 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ def __init__(
3333
use_full_instr: bool,
3434
funcs_to_instr: list[str] | None,
3535
API_dump_stack_trace: bool,
36+
sampling_interval: int,
37+
warm_up_steps: int,
3638
):
3739
super().__init__()
3840
if not modules_to_instr:
@@ -44,10 +46,27 @@ def __init__(
4446
self.use_full_instr = use_full_instr
4547
self.funcs_to_instr = funcs_to_instr
4648
self.API_dump_stack_trace = API_dump_stack_trace
49+
self.sampling_interval = sampling_interval
50+
self.warm_up_steps = warm_up_steps
51+
self.current_function = None
52+
53+
def visit_FunctionDef(self, node):
54+
old_function = self.current_function
55+
self.current_function = node.name
56+
self.generic_visit(node)
57+
self.current_function = old_function
58+
return node
59+
60+
def visit_AsyncFunctionDef(self, node):
61+
old_function = self.current_function
62+
self.current_function = node.name
63+
self.generic_visit(node)
64+
self.current_function = old_function
65+
return node
4766

4867
def get_instrument_node(self, module_name: str):
4968
return ast.parse(
50-
f"from traincheck.instrumentor.tracer import Instrumentor; Instrumentor({module_name}, scan_proxy_in_args={self.scan_proxy_in_args}, use_full_instr={self.use_full_instr}, funcs_to_instr={str(self.funcs_to_instr)}, API_dump_stack_trace={self.API_dump_stack_trace}).instrument()"
69+
f"from traincheck.instrumentor.tracer import Instrumentor; Instrumentor({module_name}, scan_proxy_in_args={self.scan_proxy_in_args}, use_full_instr={self.use_full_instr}, funcs_to_instr={str(self.funcs_to_instr)}, API_dump_stack_trace={self.API_dump_stack_trace}, sampling_interval={str(self.sampling_interval)}, warm_up_steps={str(self.warm_up_steps)}).instrument()"
5170
).body
5271

5372
def visit_Import(self, node):
@@ -65,8 +84,6 @@ def visit_Import(self, node):
6584
instrument_nodes.append(self.get_instrument_node(n.asname))
6685
else:
6786
instrument_nodes.append(self.get_instrument_node(n.name))
68-
# let's see if there are aliases, if yes, use them
69-
# if not, let's use the module name directly
7087
return [node] + instrument_nodes
7188

7289
def visit_ImportFrom(self, node):
@@ -87,6 +104,105 @@ def visit_ImportFrom(self, node):
87104
instrument_nodes.append(self.get_instrument_node(n.name))
88105
return [node] + instrument_nodes
89106

107+
def _get_loop_context(self, node):
108+
# Heuristic: Inject into loops that look like training loops.
109+
# Check for calls to .step() or .backward()
110+
has_training_signal = False
111+
for child in ast.walk(node):
112+
if isinstance(child, ast.Call):
113+
if isinstance(child.func, ast.Attribute):
114+
if child.func.attr in ["step", "backward"]:
115+
has_training_signal = True
116+
117+
if has_training_signal:
118+
return "training"
119+
120+
# If no explicit training signal, check if we are in an eval/test function
121+
if self.current_function:
122+
name_lower = self.current_function.lower()
123+
if "test" in name_lower or "eval" in name_lower or "valid" in name_lower:
124+
return "eval"
125+
126+
return None
127+
128+
def _inject_call(self, node, func_name):
129+
import_stmt = ast.ImportFrom(
130+
module="traincheck.instrumentor.control",
131+
names=[ast.alias(name=func_name, asname=None)],
132+
level=0,
133+
)
134+
call_stmt = ast.Expr(
135+
value=ast.Call(
136+
func=ast.Name(id=func_name, ctx=ast.Load()), args=[], keywords=[]
137+
)
138+
)
139+
node.body.insert(0, call_stmt)
140+
node.body.insert(0, import_stmt)
141+
return node
142+
143+
def visit_For(self, node):
144+
self.generic_visit(node)
145+
context = self._get_loop_context(node)
146+
if context == "training":
147+
return self._inject_call(node, "start_step")
148+
elif context == "eval":
149+
return self._inject_call(node, "start_eval_step")
150+
return node
151+
152+
def visit_While(self, node):
153+
self.generic_visit(node)
154+
context = self._get_loop_context(node)
155+
if context == "training":
156+
return self._inject_call(node, "start_step")
157+
elif context == "eval":
158+
return self._inject_call(node, "start_eval_step")
159+
return node
160+
161+
def _should_inject_control(self, node):
162+
# Heuristic: Inject into loops that look like training loops.
163+
# Check for calls to .step() or .backward()
164+
for child in ast.walk(node):
165+
if isinstance(child, ast.Call):
166+
if isinstance(child.func, ast.Attribute):
167+
if child.func.attr in ["step", "backward"]:
168+
return True
169+
return False
170+
171+
def _inject_start_step(self, node):
172+
import_stmt = ast.ImportFrom(
173+
module="traincheck.instrumentor.control",
174+
names=[ast.alias(name="start_step", asname=None)],
175+
level=0,
176+
)
177+
call_stmt = ast.Expr(
178+
value=ast.Call(
179+
func=ast.Name(id="start_step", ctx=ast.Load()), args=[], keywords=[]
180+
)
181+
)
182+
# We need to insert the import at the top of the file ideally,
183+
# but inserting inside the loop works if we deal with python scoping (imports are valid statements).
184+
# Actually proper way is to add import at module level.
185+
# But `visit_Module` is not here.
186+
# For simplicity, let's just use fully qualified name or inject import in the loop (a bit inefficient but works).
187+
# Better: Inject `import traincheck.instrumentor.control` at top of loop or use `traincheck.instrumentor.control.start_step()` with import logic handled elsewhere?
188+
# The `InsertTracerVisitor` modifies the module. We can add an import to the module body if we had access.
189+
# `visit_Import` adds imports.
190+
# Let's assume `traincheck` is importable.
191+
192+
# Helper to create `traincheck.instrumentor.control.start_step()` call
193+
# And ensure import is present.
194+
# Actually `InsertTracerVisitor` is used on the whole file.
195+
# Let's just blindly insert the call logic and rely on the fact that we can insert an import at the top of the loop
196+
# or just assume the user code can handle it if we inject the import statement right before the call.
197+
198+
# Let's inject:
199+
# from traincheck.instrumentor.control import start_step
200+
# start_step()
201+
202+
node.body.insert(0, call_stmt)
203+
node.body.insert(0, import_stmt)
204+
return node
205+
90206

91207
def instrument_library(
92208
source: str,
@@ -95,6 +211,8 @@ def instrument_library(
95211
use_full_instr: bool,
96212
funcs_to_instr: list[str] | None,
97213
API_dump_stack_trace: bool,
214+
sampling_interval: int,
215+
warm_up_steps: int,
98216
) -> str:
99217
"""
100218
Instruments the given source code and returns the instrumented source code.
@@ -116,6 +234,8 @@ def instrument_library(
116234
use_full_instr,
117235
funcs_to_instr,
118236
API_dump_stack_trace,
237+
sampling_interval,
238+
warm_up_steps,
119239
)
120240
root = visitor.visit(root)
121241
source = ast.unparse(root)
@@ -811,6 +931,8 @@ def instrument_file(
811931
instr_descriptors: bool,
812932
no_auto_var_instr: bool,
813933
use_torch_compile: bool,
934+
sampling_interval: int = 1,
935+
warm_up_steps: int = 0,
814936
) -> str:
815937
"""
816938
Instruments the given file and returns the instrumented source code.
@@ -827,6 +949,8 @@ def instrument_file(
827949
use_full_instr,
828950
funcs_to_instr,
829951
API_dump_stack_trace,
952+
sampling_interval,
953+
warm_up_steps,
830954
)
831955
# annotate stages
832956
instrumented_source = annotate_stage(instrumented_source)

traincheck/instrumentor/tracer.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,9 @@ def function_wrapper(
162162
TypeError: function_wrapper() got multiple values for argument 'arg_name'
163163
"""
164164

165-
global DISABLE_WRAPPER
166165
global PROCESS_ID
167166

168-
if DISABLE_WRAPPER:
167+
if config.DISABLE_WRAPPER:
169168
# TODO: all meta vars update should be done outside the function_wrapper (e.g. step increment) by applying a separate wrapper
170169
return original_function(*args, **kwargs)
171170

@@ -403,7 +402,7 @@ def wrapper(
403402
increment_step = False
404403
if original_function_name.endswith(".step"):
405404
owner = get_owner_class(original_function)
406-
if issubclass(owner, torch.optim.Optimizer):
405+
if owner and issubclass(owner, torch.optim.Optimizer):
407406
increment_step = True
408407
# determine statically whether to dump the trace
409408
if not disable_dump:
@@ -412,7 +411,14 @@ def wrapper(
412411
@functools.wraps(original_function)
413412
def wrapped(*args, **kwargs):
414413
if increment_step:
415-
META_VARS["step"] += 1
414+
# Meta var update for step is now handled by traincheck.instrumentor.control.start_step
415+
# which is injected into training loops.
416+
# However, for backward compatibility or if injection fails, we might want to keep basic step counting?
417+
# User specifically asked to move logic. If we keep it here, we might double count if both run.
418+
# But injection is "An easy way out".
419+
# Let's check META_VARS.
420+
pass
421+
416422
return function_wrapper(
417423
original_function,
418424
original_function_name,
@@ -554,6 +560,8 @@ def __init__(
554560
use_full_instr: bool,
555561
funcs_to_instr: Optional[list[str]] = None,
556562
API_dump_stack_trace: bool = False,
563+
sampling_interval: int = 1,
564+
warm_up_steps: int = 0,
557565
):
558566
"""
559567
Instruments the specified target with additional tracing functionality.
@@ -576,12 +584,24 @@ def __init__(
576584
and the functions in this list will be instrumented with dump enabled. NOTE: If this list is provided, use_full_str must be set to False. WRAP_WITHOUT_DUMP will be ignored.
577585
API_dump_stack_trace (bool):
578586
Whether to dump the stack trace of the function call. Enabling this will add the stack trace to the trace log.
587+
sampling_interval (int):
588+
The interval for sampling-based instrumentation. Every Nth step will be instrumented. Defaults to 1.
589+
warm_up_steps (int):
590+
The number of initial steps to always instrument. Defaults to 0.
579591
580592
Indirectly, at initialization, the instrumentor will also load the instr_opts.json file if it exists.
581593
This file is automatically generated by the `collect_trace` script when `--invariants` is provided.
582594
The user should not need to interact with this file directly.
583595
584596
"""
597+
if sampling_interval:
598+
if config.INSTRUMENTATION_POLICY is None:
599+
config.INSTRUMENTATION_POLICY = {}
600+
config.INSTRUMENTATION_POLICY["interval"] = sampling_interval
601+
if warm_up_steps is not None:
602+
if config.INSTRUMENTATION_POLICY is None:
603+
config.INSTRUMENTATION_POLICY = {}
604+
config.INSTRUMENTATION_POLICY["warm_up"] = warm_up_steps
585605

586606
self.instrumenting = True
587607
if isinstance(target, types.ModuleType):

0 commit comments

Comments
 (0)