You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: traincheck/config/config.py
+7Lines changed: 7 additions & 0 deletions
Original file line number
Diff line number
Diff line change
@@ -95,6 +95,13 @@
95
95
TYPE_ERR_THRESHOLD=3
96
96
RECURSION_ERR_THRESHOLD=5
97
97
98
+
INSTRUMENTATION_POLICY= {
99
+
"interval": 1,
100
+
"warm_up": 1, # default to 1 to ensure the first step is always instrumented: before warm-up is depleted, we do instrumentation with interval=1, after warm-up is depleted, we do instrumentation with the specified interval
Copy file name to clipboardExpand all lines: traincheck/instrumentor/proxy_wrapper/proxy.py
+4Lines changed: 4 additions & 0 deletions
Original file line number
Diff line number
Diff line change
@@ -8,6 +8,7 @@
8
8
9
9
importtorch
10
10
11
+
importtraincheck.config.configasconfig
11
12
importtraincheck.instrumentor.proxy_wrapper.proxy_configasproxy_config# HACK: cannot directly import config variables as then they would be local variables
# We need to insert the import at the top of the file ideally,
183
+
# but inserting inside the loop works if we deal with python scoping (imports are valid statements).
184
+
# Actually proper way is to add import at module level.
185
+
# But `visit_Module` is not here.
186
+
# For simplicity, let's just use fully qualified name or inject import in the loop (a bit inefficient but works).
187
+
# Better: Inject `import traincheck.instrumentor.control` at top of loop or use `traincheck.instrumentor.control.start_step()` with import logic handled elsewhere?
188
+
# The `InsertTracerVisitor` modifies the module. We can add an import to the module body if we had access.
189
+
# `visit_Import` adds imports.
190
+
# Let's assume `traincheck` is importable.
191
+
192
+
# Helper to create `traincheck.instrumentor.control.start_step()` call
193
+
# And ensure import is present.
194
+
# Actually `InsertTracerVisitor` is used on the whole file.
195
+
# Let's just blindly insert the call logic and rely on the fact that we can insert an import at the top of the loop
196
+
# or just assume the user code can handle it if we inject the import statement right before the call.
197
+
198
+
# Let's inject:
199
+
# from traincheck.instrumentor.control import start_step
200
+
# start_step()
201
+
202
+
node.body.insert(0, call_stmt)
203
+
node.body.insert(0, import_stmt)
204
+
returnnode
205
+
90
206
91
207
definstrument_library(
92
208
source: str,
@@ -95,6 +211,8 @@ def instrument_library(
95
211
use_full_instr: bool,
96
212
funcs_to_instr: list[str] |None,
97
213
API_dump_stack_trace: bool,
214
+
sampling_interval: int,
215
+
warm_up_steps: int,
98
216
) ->str:
99
217
"""
100
218
Instruments the given source code and returns the instrumented source code.
@@ -116,6 +234,8 @@ def instrument_library(
116
234
use_full_instr,
117
235
funcs_to_instr,
118
236
API_dump_stack_trace,
237
+
sampling_interval,
238
+
warm_up_steps,
119
239
)
120
240
root=visitor.visit(root)
121
241
source=ast.unparse(root)
@@ -811,6 +931,8 @@ def instrument_file(
811
931
instr_descriptors: bool,
812
932
no_auto_var_instr: bool,
813
933
use_torch_compile: bool,
934
+
sampling_interval: int=1,
935
+
warm_up_steps: int=0,
814
936
) ->str:
815
937
"""
816
938
Instruments the given file and returns the instrumented source code.
# Meta var update for step is now handled by traincheck.instrumentor.control.start_step
415
+
# which is injected into training loops.
416
+
# However, for backward compatibility or if injection fails, we might want to keep basic step counting?
417
+
# User specifically asked to move logic. If we keep it here, we might double count if both run.
418
+
# But injection is "An easy way out".
419
+
# Let's check META_VARS.
420
+
pass
421
+
416
422
returnfunction_wrapper(
417
423
original_function,
418
424
original_function_name,
@@ -554,6 +560,8 @@ def __init__(
554
560
use_full_instr: bool,
555
561
funcs_to_instr: Optional[list[str]] =None,
556
562
API_dump_stack_trace: bool=False,
563
+
sampling_interval: int=1,
564
+
warm_up_steps: int=0,
557
565
):
558
566
"""
559
567
Instruments the specified target with additional tracing functionality.
@@ -576,12 +584,24 @@ def __init__(
576
584
and the functions in this list will be instrumented with dump enabled. NOTE: If this list is provided, use_full_str must be set to False. WRAP_WITHOUT_DUMP will be ignored.
577
585
API_dump_stack_trace (bool):
578
586
Whether to dump the stack trace of the function call. Enabling this will add the stack trace to the trace log.
587
+
sampling_interval (int):
588
+
The interval for sampling-based instrumentation. Every Nth step will be instrumented. Defaults to 1.
589
+
warm_up_steps (int):
590
+
The number of initial steps to always instrument. Defaults to 0.
579
591
580
592
Indirectly, at initialization, the instrumentor will also load the instr_opts.json file if it exists.
581
593
This file is automatically generated by the `collect_trace` script when `--invariants` is provided.
582
594
The user should not need to interact with this file directly.
0 commit comments