Skip to content

Commit 7b653b3

Browse files
committed
[OpenReg] Add Python interface for device stream, event API
1 parent 47c563e commit 7b653b3

4 files changed

Lines changed: 439 additions & 0 deletions

File tree

PyTorchSimDevice/torch_openreg/csrc/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ file(GLOB_RECURSE SOURCE_FILES
66

77
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
88

9+
target_include_directories(${LIBRARY_NAME} PRIVATE
10+
${PROJECT_SOURCE_DIR}/third_party/openreg
11+
)
12+
913
target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python_library torch_openreg)
1014

1115
if(WIN32)

PyTorchSimDevice/torch_openreg/csrc/Module.cpp

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
#include <runtime/OpenRegFunctions.h>
1212
#include <amp/OpenRegAmp.h>
13+
#include <include/openreg.h>
14+
#include <functional>
15+
#include <memory>
16+
#include <thread>
1317

1418
static PyObject* _initExtension(PyObject* self, PyObject* noargs) {
1519
HANDLE_TH_ERRORS
@@ -135,6 +139,274 @@ PyObject* _getAmpSupportedDtype(PyObject* self, PyObject* noargs) {
135139
END_HANDLE_TH_ERRORS
136140
}
137141

142+
// Stream functions
143+
PyObject* _streamCreate(PyObject* self, PyObject* noargs) {
144+
HANDLE_TH_ERRORS
145+
torch::utils::device_lazy_init(at::kPrivateUse1);
146+
orStream_t stream = nullptr;
147+
orError_t err = orStreamCreate(&stream);
148+
std::cerr << "[DEBUG] Stream created: " << stream << std::endl;
149+
if (err != orSuccess) {
150+
TORCH_CHECK(false, "Failed to create stream");
151+
}
152+
return THPUtils_packInt64(reinterpret_cast<int64_t>(stream));
153+
END_HANDLE_TH_ERRORS
154+
}
155+
156+
PyObject* _streamCreateWithPriority(PyObject* self, PyObject* args) {
157+
HANDLE_TH_ERRORS
158+
TORCH_CHECK(PyTuple_Size(args) == 2, "stream_create_with_priority expects 2 arguments");
159+
PyObject* flags_obj = PyTuple_GetItem(args, 0);
160+
PyObject* priority_obj = PyTuple_GetItem(args, 1);
161+
TORCH_CHECK(THPUtils_checkLong(flags_obj), "flags must be an int");
162+
TORCH_CHECK(THPUtils_checkLong(priority_obj), "priority must be an int");
163+
unsigned int flags = static_cast<unsigned int>(THPUtils_unpackLong(flags_obj));
164+
int priority = static_cast<int>(THPUtils_unpackLong(priority_obj));
165+
166+
torch::utils::device_lazy_init(at::kPrivateUse1);
167+
orStream_t stream = nullptr;
168+
orError_t err = orStreamCreateWithPriority(&stream, flags, priority);
169+
if (err != orSuccess) {
170+
TORCH_CHECK(false, "Failed to create stream with priority");
171+
}
172+
return THPUtils_packInt64(reinterpret_cast<int64_t>(stream));
173+
END_HANDLE_TH_ERRORS
174+
}
175+
176+
PyObject* _streamDestroy(PyObject* self, PyObject* arg) {
177+
HANDLE_TH_ERRORS
178+
TORCH_CHECK(THPUtils_checkLong(arg), "stream_destroy expects an int");
179+
orStream_t stream = reinterpret_cast<orStream_t>(THPUtils_unpackLong(arg));
180+
orError_t err = orStreamDestroy(stream);
181+
if (err != orSuccess) {
182+
TORCH_CHECK(false, "Failed to destroy stream");
183+
}
184+
Py_RETURN_NONE;
185+
END_HANDLE_TH_ERRORS
186+
}
187+
188+
PyObject* _streamSynchronize(PyObject* self, PyObject* arg) {
189+
HANDLE_TH_ERRORS
190+
TORCH_CHECK(THPUtils_checkLong(arg), "stream_synchronize expects an int");
191+
orStream_t stream = reinterpret_cast<orStream_t>(THPUtils_unpackLong(arg));
192+
193+
orError_t err;
194+
Py_BEGIN_ALLOW_THREADS
195+
err = orStreamSynchronize(stream);
196+
Py_END_ALLOW_THREADS
197+
198+
if (err != orSuccess) {
199+
TORCH_CHECK(false, "Failed to synchronize stream");
200+
}
201+
Py_RETURN_NONE;
202+
END_HANDLE_TH_ERRORS
203+
}
204+
205+
PyObject* _streamQuery(PyObject* self, PyObject* arg) {
206+
HANDLE_TH_ERRORS
207+
TORCH_CHECK(THPUtils_checkLong(arg), "stream_query expects an int");
208+
orStream_t stream = reinterpret_cast<orStream_t>(THPUtils_unpackLong(arg));
209+
orError_t err = orStreamQuery(stream);
210+
if (err == orSuccess) {
211+
Py_RETURN_TRUE;
212+
} else {
213+
Py_RETURN_FALSE;
214+
}
215+
END_HANDLE_TH_ERRORS
216+
}
217+
218+
PyObject* _streamGetPriority(PyObject* self, PyObject* arg) {
219+
HANDLE_TH_ERRORS
220+
TORCH_CHECK(THPUtils_checkLong(arg), "stream_get_priority expects an int");
221+
orStream_t stream = reinterpret_cast<orStream_t>(THPUtils_unpackLong(arg));
222+
int priority = 0;
223+
orError_t err = orStreamGetPriority(stream, &priority);
224+
if (err != orSuccess) {
225+
TORCH_CHECK(false, "Failed to get stream priority");
226+
}
227+
return THPUtils_packInt32(priority);
228+
END_HANDLE_TH_ERRORS
229+
}
230+
231+
PyObject* _streamWaitEvent(PyObject* self, PyObject* args) {
232+
HANDLE_TH_ERRORS
233+
TORCH_CHECK(PyTuple_Size(args) == 2, "stream_wait_event expects 2 arguments");
234+
PyObject* stream_obj = PyTuple_GetItem(args, 0);
235+
PyObject* event_obj = PyTuple_GetItem(args, 1);
236+
TORCH_CHECK(THPUtils_checkLong(stream_obj), "stream must be an int");
237+
TORCH_CHECK(THPUtils_checkLong(event_obj), "event must be an int");
238+
orStream_t stream = reinterpret_cast<orStream_t>(THPUtils_unpackLong(stream_obj));
239+
orEvent_t event = reinterpret_cast<orEvent_t>(THPUtils_unpackLong(event_obj));
240+
orError_t err = orStreamWaitEvent(stream, event, 0);
241+
if (err != orSuccess) {
242+
TORCH_CHECK(false, "Failed to wait for event");
243+
}
244+
Py_RETURN_NONE;
245+
END_HANDLE_TH_ERRORS
246+
}
247+
248+
// Event functions
249+
PyObject* _eventCreate(PyObject* self, PyObject* noargs) {
250+
HANDLE_TH_ERRORS
251+
torch::utils::device_lazy_init(at::kPrivateUse1);
252+
orEvent_t event = nullptr;
253+
orError_t err = orEventCreate(&event);
254+
if (err != orSuccess) {
255+
TORCH_CHECK(false, "Failed to create event");
256+
}
257+
return THPUtils_packInt64(reinterpret_cast<int64_t>(event));
258+
END_HANDLE_TH_ERRORS
259+
}
260+
261+
PyObject* _eventCreateWithFlags(PyObject* self, PyObject* arg) {
262+
HANDLE_TH_ERRORS
263+
TORCH_CHECK(THPUtils_checkLong(arg), "event_create_with_flags expects an int");
264+
unsigned int flags = static_cast<unsigned int>(THPUtils_unpackLong(arg));
265+
266+
torch::utils::device_lazy_init(at::kPrivateUse1);
267+
orEvent_t event = nullptr;
268+
orError_t err = orEventCreateWithFlags(&event, flags);
269+
if (err != orSuccess) {
270+
TORCH_CHECK(false, "Failed to create event with flags");
271+
}
272+
return THPUtils_packInt64(reinterpret_cast<int64_t>(event));
273+
END_HANDLE_TH_ERRORS
274+
}
275+
276+
PyObject* _eventDestroy(PyObject* self, PyObject* arg) {
277+
HANDLE_TH_ERRORS
278+
TORCH_CHECK(THPUtils_checkLong(arg), "event_destroy expects an int");
279+
orEvent_t event = reinterpret_cast<orEvent_t>(THPUtils_unpackLong(arg));
280+
orError_t err = orEventDestroy(event);
281+
if (err != orSuccess) {
282+
TORCH_CHECK(false, "Failed to destroy event");
283+
}
284+
Py_RETURN_NONE;
285+
END_HANDLE_TH_ERRORS
286+
}
287+
288+
PyObject* _eventRecord(PyObject* self, PyObject* args) {
289+
HANDLE_TH_ERRORS
290+
TORCH_CHECK(PyTuple_Size(args) == 2, "event_record expects 2 arguments");
291+
PyObject* event_obj = PyTuple_GetItem(args, 0);
292+
PyObject* stream_obj = PyTuple_GetItem(args, 1);
293+
TORCH_CHECK(THPUtils_checkLong(event_obj), "event must be an int");
294+
TORCH_CHECK(THPUtils_checkLong(stream_obj), "stream must be an int");
295+
orEvent_t event = reinterpret_cast<orEvent_t>(THPUtils_unpackLong(event_obj));
296+
orStream_t stream = reinterpret_cast<orStream_t>(THPUtils_unpackLong(stream_obj));
297+
orError_t err = orEventRecord(event, stream);
298+
if (err != orSuccess) {
299+
TORCH_CHECK(false, "Failed to record event");
300+
}
301+
Py_RETURN_NONE;
302+
END_HANDLE_TH_ERRORS
303+
}
304+
305+
PyObject* _eventSynchronize(PyObject* self, PyObject* arg) {
306+
HANDLE_TH_ERRORS
307+
TORCH_CHECK(THPUtils_checkLong(arg), "event_synchronize expects an int");
308+
orEvent_t event = reinterpret_cast<orEvent_t>(THPUtils_unpackLong(arg));
309+
310+
orError_t err;
311+
Py_BEGIN_ALLOW_THREADS
312+
err = orEventSynchronize(event);
313+
Py_END_ALLOW_THREADS
314+
315+
if (err != orSuccess) {
316+
TORCH_CHECK(false, "Failed to synchronize event");
317+
}
318+
Py_RETURN_NONE;
319+
END_HANDLE_TH_ERRORS
320+
}
321+
322+
PyObject* _eventQuery(PyObject* self, PyObject* arg) {
323+
HANDLE_TH_ERRORS
324+
TORCH_CHECK(THPUtils_checkLong(arg), "event_query expects an int");
325+
orEvent_t event = reinterpret_cast<orEvent_t>(THPUtils_unpackLong(arg));
326+
orError_t err = orEventQuery(event);
327+
if (err == orSuccess) {
328+
Py_RETURN_TRUE;
329+
} else {
330+
Py_RETURN_FALSE;
331+
}
332+
END_HANDLE_TH_ERRORS
333+
}
334+
335+
PyObject* _eventElapsedTime(PyObject* self, PyObject* args) {
336+
HANDLE_TH_ERRORS
337+
TORCH_CHECK(PyTuple_Size(args) == 2, "event_elapsed_time expects 2 arguments");
338+
PyObject* start_obj = PyTuple_GetItem(args, 0);
339+
PyObject* end_obj = PyTuple_GetItem(args, 1);
340+
TORCH_CHECK(THPUtils_checkLong(start_obj), "start event must be an int");
341+
TORCH_CHECK(THPUtils_checkLong(end_obj), "end event must be an int");
342+
orEvent_t start = reinterpret_cast<orEvent_t>(THPUtils_unpackLong(start_obj));
343+
orEvent_t end = reinterpret_cast<orEvent_t>(THPUtils_unpackLong(end_obj));
344+
float ms = 0.0f;
345+
orError_t err = orEventElapsedTime(&ms, start, end);
346+
if (err != orSuccess) {
347+
TORCH_CHECK(false, "Failed to get elapsed time");
348+
}
349+
return PyFloat_FromDouble(static_cast<double>(ms));
350+
END_HANDLE_TH_ERRORS
351+
}
352+
353+
PyObject* _deviceSynchronize(PyObject* self, PyObject* noargs) {
354+
HANDLE_TH_ERRORS
355+
torch::utils::device_lazy_init(at::kPrivateUse1);
356+
357+
orError_t err;
358+
Py_BEGIN_ALLOW_THREADS
359+
err = orDeviceSynchronize();
360+
Py_END_ALLOW_THREADS
361+
362+
if (err != orSuccess) {
363+
TORCH_CHECK(false, "Failed to synchronize device");
364+
}
365+
Py_RETURN_NONE;
366+
END_HANDLE_TH_ERRORS
367+
}
368+
369+
PyObject* _addTaskToStream(PyObject* self, PyObject* args) {
370+
HANDLE_TH_ERRORS
371+
TORCH_CHECK(PyTuple_Size(args) == 2, "add_task_to_stream expects 2 arguments");
372+
PyObject* stream_obj = PyTuple_GetItem(args, 0);
373+
PyObject* callable_obj = PyTuple_GetItem(args, 1);
374+
375+
TORCH_CHECK(THPUtils_checkLong(stream_obj), "stream must be an int");
376+
TORCH_CHECK(PyCallable_Check(callable_obj), "task must be callable");
377+
378+
orStream_t stream = reinterpret_cast<orStream_t>(THPUtils_unpackLong(stream_obj));
379+
380+
Py_INCREF(callable_obj);
381+
auto py_callable = std::shared_ptr<PyObject>(callable_obj, [](PyObject* obj) {
382+
PyGILState_STATE gstate = PyGILState_Ensure();
383+
Py_DECREF(obj);
384+
PyGILState_Release(gstate);
385+
});
386+
387+
auto task = [py_callable]() {
388+
PyGILState_STATE gstate = PyGILState_Ensure();
389+
try {
390+
PyObject* result = PyObject_CallObject(py_callable.get(), nullptr);
391+
if (result == nullptr) {
392+
PyErr_Print();
393+
PyErr_Clear();
394+
} else {
395+
Py_DECREF(result);
396+
}
397+
} catch (...) {
398+
}
399+
400+
PyGILState_Release(gstate);
401+
};
402+
orError_t err = openreg::addTaskToStream(stream, task);
403+
if (err != orSuccess) {
404+
TORCH_CHECK(false, "Failed to add task to stream");
405+
}
406+
Py_RETURN_NONE;
407+
END_HANDLE_TH_ERRORS
408+
}
409+
138410
static PyMethodDef methods[] = {
139411
{"_init", _initExtension, METH_NOARGS, nullptr},
140412
{"_get_default_generator", _getDefaultGenerator, METH_O, nullptr},
@@ -147,6 +419,26 @@ static PyMethodDef methods[] = {
147419
{"get_autocast_dtype", _getAutocastDtype, METH_NOARGS, nullptr},
148420
{"set_autocast_dtype", _setAutocastDtype, METH_O, nullptr},
149421
{"get_amp_supported_dtype", _getAmpSupportedDtype, METH_NOARGS, nullptr},
422+
// Stream functions
423+
{"_stream_create", _streamCreate, METH_NOARGS, nullptr},
424+
{"_stream_create_with_priority", _streamCreateWithPriority, METH_VARARGS, nullptr},
425+
{"_stream_destroy", _streamDestroy, METH_O, nullptr},
426+
{"_stream_synchronize", _streamSynchronize, METH_O, nullptr},
427+
{"_stream_query", _streamQuery, METH_O, nullptr},
428+
{"_stream_get_priority", _streamGetPriority, METH_O, nullptr},
429+
{"_stream_wait_event", _streamWaitEvent, METH_VARARGS, nullptr},
430+
// Event functions
431+
{"_event_create", _eventCreate, METH_NOARGS, nullptr},
432+
{"_event_create_with_flags", _eventCreateWithFlags, METH_O, nullptr},
433+
{"_event_destroy", _eventDestroy, METH_O, nullptr},
434+
{"_event_record", _eventRecord, METH_VARARGS, nullptr},
435+
{"_event_synchronize", _eventSynchronize, METH_O, nullptr},
436+
{"_event_query", _eventQuery, METH_O, nullptr},
437+
{"_event_elapsed_time", _eventElapsedTime, METH_VARARGS, nullptr},
438+
// Device functions
439+
{"_device_synchronize", _deviceSynchronize, METH_NOARGS, nullptr},
440+
// Stream task functions
441+
{"_add_task_to_stream", _addTaskToStream, METH_VARARGS, nullptr},
150442
{nullptr, nullptr, 0, nullptr}};
151443

152444
/*

0 commit comments

Comments
 (0)