Skip to content

Commit 814743d

Browse files
committed
fix:support np.array in traverse
1 parent 9fbb804 commit 814743d

3 files changed

Lines changed: 16 additions & 1 deletion

File tree

padiff/abstracts/hooks/hook.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ def info_hook(model, input, output, net_id):
139139

140140
# if this api is not processing tensors, do not create report
141141
if output is None or all([not isinstance(x, (paddle.Tensor, torch.Tensor)) for x in flatten(output)]):
142+
logger.warning_once(
143+
f"All outputs of {model.__class__.__name__} are not tensors. Skip capturing these outputs."
144+
)
142145
return None
143146

144147
# if an api under black_list_recursively, do not create report

padiff/tools/dump.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections import defaultdict
1516
import json
1617
import os
1718
import sys
@@ -130,6 +131,8 @@ def dump_report_node(wrap_node, tensor_dumper):
130131

131132

132133
def dump_param_prototype(model, dump_fn, file_path):
134+
skiped_layers = defaultdict(list)
135+
133136
def dump_param_with_fn(model, fn, target_models):
134137
param_info = {
135138
"name": model.class_name,
@@ -150,7 +153,7 @@ def dump_param_with_fn(model, fn, target_models):
150153
if buffer_name not in params_found:
151154
fn(buffer_name, buffer, param_info)
152155
else:
153-
logger.debug(f"Layer {model.class_name} ({model.route}) is NOT in target_models. Skipping.")
156+
skiped_layers[model.class_name].append(model.route)
154157

155158
for name, child in model.named_children():
156159
param_info["children"].append(dump_param_with_fn(child, fn, target_models))
@@ -159,6 +162,11 @@ def dump_param_with_fn(model, fn, target_models):
159162
target_models = [layer.model for layer in model.marker.traversal_for_assign_weight()]
160163
param_info = dump_param_with_fn(model, dump_fn, target_models)
161164

165+
logger.debug_once("Params dump SKIPPED: Some layers have no available parameters(like weights).\n")
166+
for model_name, routes in skiped_layers.items():
167+
routes_str = "\n".join([f" {route}" for route in routes])
168+
logger.debug(f"Params dump SKIPPED: {model_name}.\nIncluded routes:\n{routes_str}")
169+
162170
model_info = {
163171
"model_name": model.name,
164172
"framework": model.framework,

padiff/utils/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ def traverse(structure, on_leaf, on_container=None):
144144
if isinstance(structure, (paddle.Tensor, torch.Tensor)):
145145
return on_leaf(structure)
146146

147+
# numpy
148+
if isinstance(structure, np.ndarray):
149+
return on_leaf(structure)
150+
147151
# namedtuple
148152
if hasattr(structure, "_fields"):
149153
result = type(structure)(

0 commit comments

Comments
 (0)