Skip to content

Commit d444ea7

Browse files
committed
[Device] Use torch.device(npu)
1 parent 9d6b550 commit d444ea7

41 files changed

Lines changed: 41 additions & 124 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

experiments/gemm.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,4 @@ def custom_matmul(a, b):
4848
if 'pytorchsim_functional_mode' in os.environ:
4949
del os.environ['pytorchsim_functional_mode']
5050

51-
from Scheduler.scheduler import PyTorchSimRunner
52-
module = PyTorchSimRunner.setup_device()
53-
device = module.custom_device()
5451
run_matmul(size[0], size[1], size[2], config)

scripts/ILS_experiment/test_matmul.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,5 @@ def custom_matmul(bias, a, b):
6060
args = parser.parse_args()
6161
shape = tuple(map(int, args.shape.strip('()').split(',')))
6262

63-
from Scheduler.scheduler import PyTorchSimRunner
64-
module = PyTorchSimRunner.setup_device()
65-
device = module.custom_device()
66-
test_matmul(device, *shape)
63+
device = torch.device("npu:0")
64+
test_matmul(device, *shape)

scripts/chiplet_prep.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ def modify_file(dump_path, name, address_numa_stride=None, subgraph_map=None):
6464
import sys
6565
sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim'))
6666

67-
from Scheduler.scheduler import PyTorchSimRunner
68-
module = PyTorchSimRunner.setup_device()
69-
device = module.custom_device()
67+
device = torch.device("npu:0")
7068
parser = argparse.ArgumentParser(description='Process folder argument.')
7169
parser.add_argument('size', type=int, help='Folder value', default=256)
7270
args = parser.parse_args()

tests/Diffusion/test_diffusion.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -637,9 +637,7 @@ def test_timesteps(
637637
args = parser.parse_args()
638638

639639
sys.path.append(os.environ.get("TORCHSIM_DIR", "/workspace/PyTorchSim"))
640-
from Scheduler.scheduler import PyTorchSimRunner
641-
module = PyTorchSimRunner.setup_device()
642-
device = module.custom_device()
640+
device = torch.device("npu:0")
643641

644642
#test_upsample2d(device)
645643
#test_groupnorm(device)

tests/Fusion/test_addmm_residual.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,7 @@ def addmm_residual(a, b, c, d):
4343
import sys
4444
sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim'))
4545

46-
from Scheduler.scheduler import PyTorchSimRunner
47-
module = PyTorchSimRunner.setup_device()
48-
device = module.custom_device()
46+
device = torch.device("npu:0")
4947
test_addmm_residual(device, 32, 32, 32)
5048
test_addmm_residual(device, 128, 128, 128)
5149
test_addmm_residual(device, 512, 512, 512)

tests/Fusion/test_attention_fusion.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,7 @@ def test_MHA(device, num_heads=12, embed_dim=768, input_seq=512):
7575
import sys
7676
sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim'))
7777

78-
from Scheduler.scheduler import PyTorchSimRunner
79-
module = PyTorchSimRunner.setup_device()
80-
device = module.custom_device()
78+
device = torch.device("npu:0")
8179
test_MHA(device)
8280
# test_Attention(device, head=16, seq=512, d_k=64)
8381
# test_MHA(device, num_heads=12, embed_dim=768)

tests/Fusion/test_bmm_reduction.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,7 @@ def bmm(a, b):
4242
import sys
4343
sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim'))
4444

45-
from Scheduler.scheduler import PyTorchSimRunner
46-
module = PyTorchSimRunner.setup_device()
47-
device = module.custom_device()
45+
device = torch.device("npu:0")
4846
#test_bmm_reduce(device)
4947
test_bmm_reduce(device, 12, 512)
5048
test_bmm_reduce(device, 4, 256)

tests/Fusion/test_conv_fusion.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,7 @@ def custom_conv_bn_relu(a, b, bias, c, d, e, f):
101101
import sys
102102
sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim'))
103103

104-
from Scheduler.scheduler import PyTorchSimRunner
105-
module = PyTorchSimRunner.setup_device()
106-
device = module.custom_device()
104+
device = torch.device("npu:0")
107105

108106
# Vanila test
109107
test_conv_residual(device, batch_size=3, in_channels=64, out_channels=64, input_size=28, kernel_size=3, stride=1, padding=1)

tests/Fusion/test_matmul_activation.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ def test_matmul_activation(device, batch_size=16, input_size=32, output_size=8,
7373
import sys
7474
sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim'))
7575

76-
from Scheduler.scheduler import PyTorchSimRunner
77-
module = PyTorchSimRunner.setup_device()
78-
device = module.custom_device()
76+
device = torch.device("npu:0")
7977
test_matmul_activation(device)
8078
test_matmul_activation(device, batch_size=32, input_size=32, output_size=32, activation_fn="sigmoid")
8179
test_matmul_activation(device, batch_size=42, input_size=42, output_size=42, activation_fn="sigmoid")

tests/Fusion/test_matmul_reduction.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,7 @@ def matmul_fused(a, b, c, d):
8989
import sys
9090
sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim'))
9191

92-
from Scheduler.scheduler import PyTorchSimRunner
93-
module = PyTorchSimRunner.setup_device()
94-
device = module.custom_device()
92+
device = torch.device("npu:0")
9593
test_matmul_reduce(device, 3072, 512, 768)
9694
test_matmul_var_mean(device)
9795
test_matmul_add_var_mean(device)

0 commit comments

Comments
 (0)