Skip to content

Commit 5c5e61c

Browse files
wok1909YWHyuk
authored and committed
[Implement] Hook and GuardImpl for extension device
1 parent b7a275e commit 5c5e61c

8 files changed

Lines changed: 221 additions & 10 deletions
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#include "ExtensionDeviceGuardImpl.h"
2+
#include <c10/core/impl/DeviceGuardImplRegistry.h>
3+
4+
namespace c10::extension_device::impl {
5+
6+
// Register the guard under PrivateUse1, the DeviceType that
// ExtensionDeviceGuardImpl reports as its static_type. The first macro
// argument has to name a c10::DeviceType enumerator, which `extension_device`
// is not.
// NOTE(review): extension_device.cpp registers the same guard for PrivateUse1;
// keep a single registration if both translation units end up in one build.
C10_REGISTER_GUARD_IMPL(PrivateUse1, ExtensionDeviceGuardImpl);
7+
8+
} // namespace c10::extension_device::impl
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#pragma once
2+
3+
#include <c10/core/DeviceGuard.h>
4+
#include <c10/core/impl/DeviceGuardImplInterface.h>
5+
#include <c10/core/Stream.h>
6+
#include <c10/core/Event.h>
7+
#include <c10/core/DeviceType.h>
8+
#include <c10/util/Optional.h>
9+
10+
namespace c10::extension_device::impl {
11+
12+
struct ExtensionDeviceGuardImpl final : public c10::impl::DeviceGuardImplInterface {
13+
static constexpr DeviceType static_type = DeviceType::PrivateUse1; // ✅ your backend type
14+
15+
ExtensionDeviceGuardImpl() = default;
16+
17+
explicit ExtensionDeviceGuardImpl(DeviceType t) {
18+
TORCH_CHECK(
19+
t == static_type,
20+
"ExtensionDeviceGuardImpl initialized with non-extension_device DeviceType: ",
21+
t);
22+
}
23+
24+
// --------------------------------------------------------------------------
25+
// 기본적인 device guard (CPU처럼 동작)
26+
// --------------------------------------------------------------------------
27+
DeviceType type() const override {
28+
return static_type;
29+
}
30+
31+
Device exchangeDevice(Device d) const override {
32+
TORCH_CHECK(d.type() == static_type, "Expected extension_device but got ", d);
33+
return d; // nothing to exchange, CPU-like
34+
}
35+
36+
Device getDevice() const override {
37+
return Device(static_type, 0);
38+
}
39+
40+
void setDevice(Device d) const override {
41+
TORCH_CHECK(d.type() == static_type, "Expected extension_device but got ", d);
42+
}
43+
44+
void uncheckedSetDevice(Device d) const noexcept override {}
45+
46+
DeviceIndex deviceCount() const noexcept override {
47+
return 1; // pretend single device
48+
}
49+
50+
// --------------------------------------------------------------------------
51+
// Stream handling (동기식이므로 기본 stream만 사용)
52+
// --------------------------------------------------------------------------
53+
Stream getStream(Device d) const override {
54+
return Stream(Stream::DEFAULT, d);
55+
}
56+
57+
Stream getNewStream(Device d, int priority = 0) const override {
58+
return Stream(Stream::DEFAULT, d);
59+
}
60+
61+
Stream getStreamFromGlobalPool(Device d, bool = false) const override {
62+
return Stream(Stream::DEFAULT, d);
63+
}
64+
65+
Stream exchangeStream(Stream s) const override {
66+
return s;
67+
}
68+
69+
bool queryStream(const Stream& stream) const override {
70+
(void)stream;
71+
return true;
72+
}
73+
74+
void synchronizeStream(const Stream& stream) const override {
75+
(void)stream;
76+
}
77+
78+
void synchronizeDevice(DeviceIndex device_index) const override {
79+
(void)device_index;
80+
}
81+
82+
// --------------------------------------------------------------------------
83+
// Event handling (전부 no-op)
84+
// --------------------------------------------------------------------------
85+
void destroyEvent(void* event, const DeviceIndex device_index) const noexcept override {
86+
(void)event;
87+
(void)device_index;
88+
}
89+
90+
void record(void** event, const Stream& stream, const DeviceIndex device_index, const EventFlag flag) const override {
91+
(void)event;
92+
(void)stream;
93+
(void)device_index;
94+
(void)flag;
95+
}
96+
97+
void block(void* event, const Stream& stream) const override {
98+
(void)event;
99+
(void)stream;
100+
}
101+
102+
bool queryEvent(void* event) const override {
103+
(void)event;
104+
return true;
105+
}
106+
107+
void synchronizeEvent(void* event) const override {
108+
(void)event;
109+
}
110+
111+
double elapsedTime(void* start_event, void* end_event, const DeviceIndex device_index) const override {
112+
(void)start_event;
113+
(void)end_event;
114+
(void)device_index;
115+
return 0.0;
116+
}
117+
118+
// --------------------------------------------------------------------------
119+
// Misc (allocator integration)
120+
// --------------------------------------------------------------------------
121+
void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) const override {
122+
(void)data_ptr;
123+
(void)stream;
124+
}
125+
};
126+
127+
} // namespace c10::extension_device::impl
Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,12 @@ static inline at::MemoryFormat fix_memory_format(c10::optional<at::MemoryFormat>
5555
return mf;
5656
}
5757

58+
#include "ExtensionDeviceGuardImpl.h"
59+
5860
static uint64_t op_counter = 0;
5961
static uint64_t last_saved_value = 0;
6062

61-
// register guard
62-
namespace at {
63-
namespace detail {
64-
65-
C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
66-
67-
}} // namespace at::detail
63+
// Register ExtensionDeviceGuardImpl as the device guard for PrivateUse1. This
// replaces the NoOpDeviceGuardImpl<DeviceType::PrivateUse1> registration that
// this file used previously.
C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::extension_device::impl::ExtensionDeviceGuardImpl);
6864

6965
// basic dummy add function
7066
at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
File renamed without changes.
File renamed without changes.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#include "extension_hooks.h"

#include <cstddef>
#include <deque>
#include <mutex>
2+
3+
bool ExtensionPU1Hooks::isBuilt() const { return true; }
4+
bool ExtensionPU1Hooks::isAvailable() const { return true; }
5+
6+
// Returns the lazily created default generator for device `idx`; a negative
// index is treated as device 0. The returned reference points at a
// function-local static and stays valid for later calls.
const at::Generator& ExtensionPU1Hooks::getDefaultGenerator(c10::DeviceIndex idx) const {
  if (idx < 0) {
    idx = 0;
  }
  const auto slot = static_cast<std::size_t>(idx);

  // std::deque rather than std::vector: growing a deque at the back never
  // relocates existing elements, so a reference returned by an earlier call is
  // not invalidated when a later call with a larger index extends the cache.
  static std::deque<at::Generator> generators;
  static std::mutex generators_mutex;

  const std::lock_guard<std::mutex> lock(generators_mutex);
  if (generators.size() <= slot) {
    generators.resize(slot + 1);
  }
  if (!generators[slot].defined()) {
    generators[slot] = at::GetGeneratorForPrivateuse1(idx);
  }
  return generators[slot];
}
15+
16+
// Creates a fresh generator for device `idx`; a negative index is treated as
// device 0.
at::Generator ExtensionPU1Hooks::getNewGenerator(c10::DeviceIndex idx) const {
  const c10::DeviceIndex resolved = (idx < 0) ? c10::DeviceIndex{0} : idx;
  return at::GetGeneratorForPrivateuse1(resolved);
}
20+
21+
// MVP: assumes a single device, so every pointer is attributed to
// PrivateUse1 index 0 and the pointer itself is not inspected.
at::Device ExtensionPU1Hooks::getDeviceFromPtr(void* /*data*/) const {
  const at::Device only_device(at::kPrivateUse1, 0);
  return only_device;
}
24+
25+
// Reports every pointer as not pinned; the pointer is not inspected.
bool ExtensionPU1Hooks::isPinnedPtr(const void* /*data*/) const {
  constexpr bool kPinned = false;
  return kPinned;
}
28+
29+
// Returns the host allocator that at::getHostAllocator yields for
// PrivateUse1.
at::Allocator* ExtensionPU1Hooks::getPinnedMemoryAllocator() const {
  at::Allocator* const host_allocator = at::getHostAllocator(at::kPrivateUse1);
  return host_allocator;
}
32+
33+
// Reports a primary context for every device index; the index is ignored.
bool ExtensionPU1Hooks::hasPrimaryContext(c10::DeviceIndex /*device_index*/) const {
  return true;
}
34+
35+
// Storage resizing is not supported: every call fails the TORCH_CHECK below.
void ExtensionPU1Hooks::resizePrivateUse1Bytes(
    const c10::Storage& /*storage*/,
    size_t /*newsize*/) const {
  TORCH_CHECK(false, "resizePrivateUse1Bytes not implemented");
}
38+
39+
// REGISTER_EXTENSION_HOOKS(ExtensionPU1Hooks);
40+
41+
namespace {
42+
struct AutoRegistrar {
43+
AutoRegistrar() {
44+
at::RegisterPrivateUse1HooksInterface(new ExtensionPU1Hooks());
45+
}
46+
};
47+
static AutoRegistrar _auto_registrar;
48+
}

PyTorchSimDevice/extension_hooks.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#pragma once
2+
3+
#include <ATen/core/CachingHostAllocator.h>
4+
#include <ATen/detail/PrivateUse1HooksInterface.h>
5+
6+
#include <ATen/core/Generator.h>
7+
#include <c10/core/Allocator.h>
8+
#include <c10/core/Device.h>
9+
#include <c10/core/Storage.h>
10+
#include <c10/util/Exception.h>
11+
12+
struct ExtensionPU1Hooks final : public at::PrivateUse1HooksInterface {
13+
ExtensionPU1Hooks() {}
14+
bool isBuilt() const;
15+
bool isAvailable() const;
16+
17+
const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) const override;
18+
19+
at::Generator getNewGenerator(c10::DeviceIndex device_index = -1) const override;
20+
21+
at::Device getDeviceFromPtr(void* data) const override;
22+
23+
bool isPinnedPtr(const void* data) const override;
24+
25+
at::Allocator* getPinnedMemoryAllocator() const override;
26+
27+
bool hasPrimaryContext(c10::DeviceIndex device_index) const override;
28+
29+
void resizePrivateUse1Bytes(const c10::Storage& /*storage*/, size_t /*newsize*/) const override;
30+
};

Scheduler/scheduler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from PyTorchSimFrontend.extension_codecache import hash_prefix
99
from Simulator.simulator import TOGSimulator
1010
from PyTorchSimFrontend import extension_config
11-
from PyTorchSimFrontend.extension_device_interface import ExtensionDeviceInterface
11+
from PyTorchSimDevice.extension_device_interface import ExtensionDeviceInterface
1212

1313
from torch._dynamo.device_interface import register_interface_for_device
1414

@@ -173,14 +173,16 @@ def setup_device(cls):
173173
return cls.NPU_MODULE
174174
source_file_path = os.path.dirname(os.path.abspath(__file__))
175175
source_file = os.path.join(
176-
source_file_path, f"{extension_config.CONFIG_TORCHSIM_DIR}/PyTorchSimFrontend/extension_device.cpp"
176+
source_file_path, f"{extension_config.CONFIG_TORCHSIM_DIR}/PyTorchSimDevice/extension_device.cpp"
177177
)
178+
hook_file = os.path.join(source_file_path, f"{extension_config.CONFIG_TORCHSIM_DIR}/PyTorchSimDevice/extension_hooks.cpp")
178179

179180
import torch.utils.cpp_extension
180181
module = torch.utils.cpp_extension.load(
181182
name="npu",
182183
sources=[
183184
str(source_file),
185+
str(hook_file),
184186
],
185187
extra_cflags=["-g"],
186188
verbose=True,
@@ -205,7 +207,7 @@ def setup_device(cls):
205207
lambda scheduling: MLIRScheduling(scheduling),
206208
ExtensionWrapperCodegen
207209
)
208-
import PyTorchSimFrontend.extension_device_op_overrides
210+
import PyTorchSimDevice.extension_device_op_overrides
209211

210212
assert(
211213
get_wrapper_codegen_for_device("npu")

0 commit comments

Comments
 (0)