# simulate.py — gradient-check driver for the sparse-train compiler/simulator pipeline.
# (Removed GitHub web-UI residue and the scraped line-number gutter, which made
# the file invalid Python.)
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"  # must be set before torch loads
import torch
from compiler.target_gen.memory.memory_manager import MemoryManager
from compiler.scheduler.normal_scheduler import NormalScheduler
from compiler.scheduler.wu_imm_scheduler import WUImmScheduler
from simulator.memory import Memory
import torch.nn as nn
from model.resnet import resnet18_cifar
from model.lenet import *
from model.alexnet import AlexNet
from compiler.utils.unique_class_name import unique_class_name
from compiler.graph.replace_tool import Finder,ReplaceTool
from compiler.config import Config,CodeGen
from converter import Converter
from backends.sparse_train.target_code.instruction_gen import InstructionGenerator
from compiler.target.dataflow import Dataflow
import numpy as np
from executer.executer import Executer
def run():
    """Compile a torch model through the custom pipeline and gradient-check it.

    Converts ResNet-18 (CIFAR variant) into the compiler's graph form,
    schedules it, lays out tensor memory, executes the graph up to node
    "BEdge_0" with the project's Executer, and compares the simulated input
    gradient ("output_grad") against the reference gradient produced by
    PyTorch autograd on the same input.

    Prints the max absolute difference and whether it is below 0.01, or a
    shape-mismatch message when the two gradients disagree in shape.
    """
    # Alternative test models, kept for quick manual swaps:
    # torch_net = TestNet()
    # torch_net = AlexNet()
    torch_net = resnet18_cifar()
    torch_net.eval()

    in_shape = [4, 3, 32, 32]  # NCHW: batch of 4 CIFAR-sized RGB images
    print("in_shape:", in_shape)

    # Convert the torch module into the compiler's internal graph.
    converter = Converter(torch_net, in_shape=in_shape)
    converter.convert()
    net = converter.net

    # Operator-replacement pass is available but currently disabled.
    replace_tool = ReplaceTool(net=net, config_path="./backends/sparse_train/replace.yaml")
    # replace_tool.replace_all()

    scheduler = NormalScheduler()
    # scheduler = WUImmScheduler()
    scheduler.schedule(net)

    # Tensor-count before/after reduction, then assign indices and memory.
    print(net.count())
    net.reduce_tensor()
    print(net.count())
    net.set_tensor_index()
    MemoryManager().tensor_memory_layout2(net)

    # NOTE(review): renamed from `input`, which shadowed the builtin.
    x = torch.randn(in_shape)
    x.requires_grad = True

    executer = Executer(net)
    simulated_grad = executer.execute(x, to="BEdge_0").tensors.get_data("output_grad")

    # Reference gradient from autograd: d(sum(net(x)))/dx.
    loss = torch.sum(torch_net(x))
    loss.backward()
    expected_grad = x.grad

    if simulated_grad.shape == expected_grad.shape:
        # Hoisted so the max-abs reduction is computed once, not twice.
        max_diff = torch.max(torch.abs(simulated_grad - expected_grad))
        print(max_diff)
        print(max_diff < 0.01)
    else:
        print(f"Shape is not equal! output.shape={simulated_grad.shape}, torch_output.shape={expected_grad.shape}")


if __name__ == "__main__":
    run()