Skip to content

Commit ecb7468

Browse files
authored
fix:support np.array in traverse (#132)
1 parent 9fbb804 commit ecb7468

4 files changed

Lines changed: 25 additions & 1 deletion

File tree

.github/workflows/tests.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,20 @@ jobs:
2626
with:
2727
python-version: '3.10'
2828
cache: 'pip'
29+
- name: Clean system disk space
30+
run: |
31+
sudo apt-get clean
32+
sudo rm -rf /var/lib/apt/lists/*
33+
echo "Disk space after cleanup:"
34+
df -h /
2935
- name: Install dependencies
3036
run: |
3137
python -m pip install --upgrade pip
3238
pip install -e .[full]
3339
make install
40+
pip cache purge
41+
sudo apt-get clean
42+
sudo rm -rf /var/lib/apt/lists/*
3443
- name: Run tests
3544
run: make test
3645
- name: Upload coverage report

padiff/abstracts/hooks/hook.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ def info_hook(model, input, output, net_id):
139139

140140
# if this api is not processing tensors, do not create report
141141
if output is None or all([not isinstance(x, (paddle.Tensor, torch.Tensor)) for x in flatten(output)]):
142+
logger.warning_once(
143+
f"All outputs of {model.__class__.__name__} are not tensors. Skip capturing these outputs."
144+
)
142145
return None
143146

144147
# if an api under black_list_recursively, do not create report

padiff/tools/dump.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections import defaultdict
1516
import json
1617
import os
1718
import sys
@@ -130,6 +131,8 @@ def dump_report_node(wrap_node, tensor_dumper):
130131

131132

132133
def dump_param_prototype(model, dump_fn, file_path):
134+
skiped_layers = defaultdict(list)
135+
133136
def dump_param_with_fn(model, fn, target_models):
134137
param_info = {
135138
"name": model.class_name,
@@ -150,7 +153,7 @@ def dump_param_with_fn(model, fn, target_models):
150153
if buffer_name not in params_found:
151154
fn(buffer_name, buffer, param_info)
152155
else:
153-
logger.debug(f"Layer {model.class_name} ({model.route}) is NOT in target_models. Skipping.")
156+
skiped_layers[model.class_name].append(model.route)
154157

155158
for name, child in model.named_children():
156159
param_info["children"].append(dump_param_with_fn(child, fn, target_models))
@@ -159,6 +162,11 @@ def dump_param_with_fn(model, fn, target_models):
159162
target_models = [layer.model for layer in model.marker.traversal_for_assign_weight()]
160163
param_info = dump_param_with_fn(model, dump_fn, target_models)
161164

165+
logger.debug_once("Params dump SKIPPED: Some layers have no available parameters(like weights).\n")
166+
for model_name, routes in skiped_layers.items():
167+
routes_str = "\n".join([f" {route}" for route in routes])
168+
logger.debug(f"Params dump SKIPPED: {model_name}.\nIncluded routes:\n{routes_str}")
169+
162170
model_info = {
163171
"model_name": model.name,
164172
"framework": model.framework,

padiff/utils/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ def traverse(structure, on_leaf, on_container=None):
144144
if isinstance(structure, (paddle.Tensor, torch.Tensor)):
145145
return on_leaf(structure)
146146

147+
# numpy
148+
if isinstance(structure, np.ndarray):
149+
return on_leaf(structure)
150+
147151
# namedtuple
148152
if hasattr(structure, "_fields"):
149153
result = type(structure)(

0 commit comments

Comments
 (0)