1010
1111#include < runtime/OpenRegFunctions.h>
1212#include < amp/OpenRegAmp.h>
13+ #include < include/openreg.h>
14+ #include < functional>
15+ #include < memory>
16+ #include < thread>
1317
1418static PyObject* _initExtension (PyObject* self, PyObject* noargs) {
1519 HANDLE_TH_ERRORS
@@ -135,6 +139,274 @@ PyObject* _getAmpSupportedDtype(PyObject* self, PyObject* noargs) {
135139 END_HANDLE_TH_ERRORS
136140}
137141
142+ // Stream functions
143+ PyObject* _streamCreate (PyObject* self, PyObject* noargs) {
144+ HANDLE_TH_ERRORS
145+ torch::utils::device_lazy_init (at::kPrivateUse1 );
146+ orStream_t stream = nullptr ;
147+ orError_t err = orStreamCreate (&stream);
148+ std::cerr << " [DEBUG] Stream created: " << stream << std::endl;
149+ if (err != orSuccess) {
150+ TORCH_CHECK (false , " Failed to create stream" );
151+ }
152+ return THPUtils_packInt64 (reinterpret_cast <int64_t >(stream));
153+ END_HANDLE_TH_ERRORS
154+ }
155+
156+ PyObject* _streamCreateWithPriority (PyObject* self, PyObject* args) {
157+ HANDLE_TH_ERRORS
158+ TORCH_CHECK (PyTuple_Size (args) == 2 , " stream_create_with_priority expects 2 arguments" );
159+ PyObject* flags_obj = PyTuple_GetItem (args, 0 );
160+ PyObject* priority_obj = PyTuple_GetItem (args, 1 );
161+ TORCH_CHECK (THPUtils_checkLong (flags_obj), " flags must be an int" );
162+ TORCH_CHECK (THPUtils_checkLong (priority_obj), " priority must be an int" );
163+ unsigned int flags = static_cast <unsigned int >(THPUtils_unpackLong (flags_obj));
164+ int priority = static_cast <int >(THPUtils_unpackLong (priority_obj));
165+
166+ torch::utils::device_lazy_init (at::kPrivateUse1 );
167+ orStream_t stream = nullptr ;
168+ orError_t err = orStreamCreateWithPriority (&stream, flags, priority);
169+ if (err != orSuccess) {
170+ TORCH_CHECK (false , " Failed to create stream with priority" );
171+ }
172+ return THPUtils_packInt64 (reinterpret_cast <int64_t >(stream));
173+ END_HANDLE_TH_ERRORS
174+ }
175+
176+ PyObject* _streamDestroy (PyObject* self, PyObject* arg) {
177+ HANDLE_TH_ERRORS
178+ TORCH_CHECK (THPUtils_checkLong (arg), " stream_destroy expects an int" );
179+ orStream_t stream = reinterpret_cast <orStream_t>(THPUtils_unpackLong (arg));
180+ orError_t err = orStreamDestroy (stream);
181+ if (err != orSuccess) {
182+ TORCH_CHECK (false , " Failed to destroy stream" );
183+ }
184+ Py_RETURN_NONE;
185+ END_HANDLE_TH_ERRORS
186+ }
187+
188+ PyObject* _streamSynchronize (PyObject* self, PyObject* arg) {
189+ HANDLE_TH_ERRORS
190+ TORCH_CHECK (THPUtils_checkLong (arg), " stream_synchronize expects an int" );
191+ orStream_t stream = reinterpret_cast <orStream_t>(THPUtils_unpackLong (arg));
192+
193+ orError_t err;
194+ Py_BEGIN_ALLOW_THREADS
195+ err = orStreamSynchronize (stream);
196+ Py_END_ALLOW_THREADS
197+
198+ if (err != orSuccess) {
199+ TORCH_CHECK (false , " Failed to synchronize stream" );
200+ }
201+ Py_RETURN_NONE;
202+ END_HANDLE_TH_ERRORS
203+ }
204+
205+ PyObject* _streamQuery (PyObject* self, PyObject* arg) {
206+ HANDLE_TH_ERRORS
207+ TORCH_CHECK (THPUtils_checkLong (arg), " stream_query expects an int" );
208+ orStream_t stream = reinterpret_cast <orStream_t>(THPUtils_unpackLong (arg));
209+ orError_t err = orStreamQuery (stream);
210+ if (err == orSuccess) {
211+ Py_RETURN_TRUE;
212+ } else {
213+ Py_RETURN_FALSE;
214+ }
215+ END_HANDLE_TH_ERRORS
216+ }
217+
218+ PyObject* _streamGetPriority (PyObject* self, PyObject* arg) {
219+ HANDLE_TH_ERRORS
220+ TORCH_CHECK (THPUtils_checkLong (arg), " stream_get_priority expects an int" );
221+ orStream_t stream = reinterpret_cast <orStream_t>(THPUtils_unpackLong (arg));
222+ int priority = 0 ;
223+ orError_t err = orStreamGetPriority (stream, &priority);
224+ if (err != orSuccess) {
225+ TORCH_CHECK (false , " Failed to get stream priority" );
226+ }
227+ return THPUtils_packInt32 (priority);
228+ END_HANDLE_TH_ERRORS
229+ }
230+
231+ PyObject* _streamWaitEvent (PyObject* self, PyObject* args) {
232+ HANDLE_TH_ERRORS
233+ TORCH_CHECK (PyTuple_Size (args) == 2 , " stream_wait_event expects 2 arguments" );
234+ PyObject* stream_obj = PyTuple_GetItem (args, 0 );
235+ PyObject* event_obj = PyTuple_GetItem (args, 1 );
236+ TORCH_CHECK (THPUtils_checkLong (stream_obj), " stream must be an int" );
237+ TORCH_CHECK (THPUtils_checkLong (event_obj), " event must be an int" );
238+ orStream_t stream = reinterpret_cast <orStream_t>(THPUtils_unpackLong (stream_obj));
239+ orEvent_t event = reinterpret_cast <orEvent_t>(THPUtils_unpackLong (event_obj));
240+ orError_t err = orStreamWaitEvent (stream, event, 0 );
241+ if (err != orSuccess) {
242+ TORCH_CHECK (false , " Failed to wait for event" );
243+ }
244+ Py_RETURN_NONE;
245+ END_HANDLE_TH_ERRORS
246+ }
247+
248+ // Event functions
249+ PyObject* _eventCreate (PyObject* self, PyObject* noargs) {
250+ HANDLE_TH_ERRORS
251+ torch::utils::device_lazy_init (at::kPrivateUse1 );
252+ orEvent_t event = nullptr ;
253+ orError_t err = orEventCreate (&event);
254+ if (err != orSuccess) {
255+ TORCH_CHECK (false , " Failed to create event" );
256+ }
257+ return THPUtils_packInt64 (reinterpret_cast <int64_t >(event));
258+ END_HANDLE_TH_ERRORS
259+ }
260+
261+ PyObject* _eventCreateWithFlags (PyObject* self, PyObject* arg) {
262+ HANDLE_TH_ERRORS
263+ TORCH_CHECK (THPUtils_checkLong (arg), " event_create_with_flags expects an int" );
264+ unsigned int flags = static_cast <unsigned int >(THPUtils_unpackLong (arg));
265+
266+ torch::utils::device_lazy_init (at::kPrivateUse1 );
267+ orEvent_t event = nullptr ;
268+ orError_t err = orEventCreateWithFlags (&event, flags);
269+ if (err != orSuccess) {
270+ TORCH_CHECK (false , " Failed to create event with flags" );
271+ }
272+ return THPUtils_packInt64 (reinterpret_cast <int64_t >(event));
273+ END_HANDLE_TH_ERRORS
274+ }
275+
276+ PyObject* _eventDestroy (PyObject* self, PyObject* arg) {
277+ HANDLE_TH_ERRORS
278+ TORCH_CHECK (THPUtils_checkLong (arg), " event_destroy expects an int" );
279+ orEvent_t event = reinterpret_cast <orEvent_t>(THPUtils_unpackLong (arg));
280+ orError_t err = orEventDestroy (event);
281+ if (err != orSuccess) {
282+ TORCH_CHECK (false , " Failed to destroy event" );
283+ }
284+ Py_RETURN_NONE;
285+ END_HANDLE_TH_ERRORS
286+ }
287+
288+ PyObject* _eventRecord (PyObject* self, PyObject* args) {
289+ HANDLE_TH_ERRORS
290+ TORCH_CHECK (PyTuple_Size (args) == 2 , " event_record expects 2 arguments" );
291+ PyObject* event_obj = PyTuple_GetItem (args, 0 );
292+ PyObject* stream_obj = PyTuple_GetItem (args, 1 );
293+ TORCH_CHECK (THPUtils_checkLong (event_obj), " event must be an int" );
294+ TORCH_CHECK (THPUtils_checkLong (stream_obj), " stream must be an int" );
295+ orEvent_t event = reinterpret_cast <orEvent_t>(THPUtils_unpackLong (event_obj));
296+ orStream_t stream = reinterpret_cast <orStream_t>(THPUtils_unpackLong (stream_obj));
297+ orError_t err = orEventRecord (event, stream);
298+ if (err != orSuccess) {
299+ TORCH_CHECK (false , " Failed to record event" );
300+ }
301+ Py_RETURN_NONE;
302+ END_HANDLE_TH_ERRORS
303+ }
304+
305+ PyObject* _eventSynchronize (PyObject* self, PyObject* arg) {
306+ HANDLE_TH_ERRORS
307+ TORCH_CHECK (THPUtils_checkLong (arg), " event_synchronize expects an int" );
308+ orEvent_t event = reinterpret_cast <orEvent_t>(THPUtils_unpackLong (arg));
309+
310+ orError_t err;
311+ Py_BEGIN_ALLOW_THREADS
312+ err = orEventSynchronize (event);
313+ Py_END_ALLOW_THREADS
314+
315+ if (err != orSuccess) {
316+ TORCH_CHECK (false , " Failed to synchronize event" );
317+ }
318+ Py_RETURN_NONE;
319+ END_HANDLE_TH_ERRORS
320+ }
321+
322+ PyObject* _eventQuery (PyObject* self, PyObject* arg) {
323+ HANDLE_TH_ERRORS
324+ TORCH_CHECK (THPUtils_checkLong (arg), " event_query expects an int" );
325+ orEvent_t event = reinterpret_cast <orEvent_t>(THPUtils_unpackLong (arg));
326+ orError_t err = orEventQuery (event);
327+ if (err == orSuccess) {
328+ Py_RETURN_TRUE;
329+ } else {
330+ Py_RETURN_FALSE;
331+ }
332+ END_HANDLE_TH_ERRORS
333+ }
334+
335+ PyObject* _eventElapsedTime (PyObject* self, PyObject* args) {
336+ HANDLE_TH_ERRORS
337+ TORCH_CHECK (PyTuple_Size (args) == 2 , " event_elapsed_time expects 2 arguments" );
338+ PyObject* start_obj = PyTuple_GetItem (args, 0 );
339+ PyObject* end_obj = PyTuple_GetItem (args, 1 );
340+ TORCH_CHECK (THPUtils_checkLong (start_obj), " start event must be an int" );
341+ TORCH_CHECK (THPUtils_checkLong (end_obj), " end event must be an int" );
342+ orEvent_t start = reinterpret_cast <orEvent_t>(THPUtils_unpackLong (start_obj));
343+ orEvent_t end = reinterpret_cast <orEvent_t>(THPUtils_unpackLong (end_obj));
344+ float ms = 0 .0f ;
345+ orError_t err = orEventElapsedTime (&ms, start, end);
346+ if (err != orSuccess) {
347+ TORCH_CHECK (false , " Failed to get elapsed time" );
348+ }
349+ return PyFloat_FromDouble (static_cast <double >(ms));
350+ END_HANDLE_TH_ERRORS
351+ }
352+
353+ PyObject* _deviceSynchronize (PyObject* self, PyObject* noargs) {
354+ HANDLE_TH_ERRORS
355+ torch::utils::device_lazy_init (at::kPrivateUse1 );
356+
357+ orError_t err;
358+ Py_BEGIN_ALLOW_THREADS
359+ err = orDeviceSynchronize ();
360+ Py_END_ALLOW_THREADS
361+
362+ if (err != orSuccess) {
363+ TORCH_CHECK (false , " Failed to synchronize device" );
364+ }
365+ Py_RETURN_NONE;
366+ END_HANDLE_TH_ERRORS
367+ }
368+
369+ PyObject* _addTaskToStream (PyObject* self, PyObject* args) {
370+ HANDLE_TH_ERRORS
371+ TORCH_CHECK (PyTuple_Size (args) == 2 , " add_task_to_stream expects 2 arguments" );
372+ PyObject* stream_obj = PyTuple_GetItem (args, 0 );
373+ PyObject* callable_obj = PyTuple_GetItem (args, 1 );
374+
375+ TORCH_CHECK (THPUtils_checkLong (stream_obj), " stream must be an int" );
376+ TORCH_CHECK (PyCallable_Check (callable_obj), " task must be callable" );
377+
378+ orStream_t stream = reinterpret_cast <orStream_t>(THPUtils_unpackLong (stream_obj));
379+
380+ Py_INCREF (callable_obj);
381+ auto py_callable = std::shared_ptr<PyObject>(callable_obj, [](PyObject* obj) {
382+ PyGILState_STATE gstate = PyGILState_Ensure ();
383+ Py_DECREF (obj);
384+ PyGILState_Release (gstate);
385+ });
386+
387+ auto task = [py_callable]() {
388+ PyGILState_STATE gstate = PyGILState_Ensure ();
389+ try {
390+ PyObject* result = PyObject_CallObject (py_callable.get (), nullptr );
391+ if (result == nullptr ) {
392+ PyErr_Print ();
393+ PyErr_Clear ();
394+ } else {
395+ Py_DECREF (result);
396+ }
397+ } catch (...) {
398+ }
399+
400+ PyGILState_Release (gstate);
401+ };
402+ orError_t err = openreg::addTaskToStream (stream, task);
403+ if (err != orSuccess) {
404+ TORCH_CHECK (false , " Failed to add task to stream" );
405+ }
406+ Py_RETURN_NONE;
407+ END_HANDLE_TH_ERRORS
408+ }
409+
138410static PyMethodDef methods[] = {
139411 {" _init" , _initExtension, METH_NOARGS, nullptr },
140412 {" _get_default_generator" , _getDefaultGenerator, METH_O, nullptr },
@@ -147,6 +419,26 @@ static PyMethodDef methods[] = {
147419 {" get_autocast_dtype" , _getAutocastDtype, METH_NOARGS, nullptr },
148420 {" set_autocast_dtype" , _setAutocastDtype, METH_O, nullptr },
149421 {" get_amp_supported_dtype" , _getAmpSupportedDtype, METH_NOARGS, nullptr },
422+ // Stream functions
423+ {" _stream_create" , _streamCreate, METH_NOARGS, nullptr },
424+ {" _stream_create_with_priority" , _streamCreateWithPriority, METH_VARARGS, nullptr },
425+ {" _stream_destroy" , _streamDestroy, METH_O, nullptr },
426+ {" _stream_synchronize" , _streamSynchronize, METH_O, nullptr },
427+ {" _stream_query" , _streamQuery, METH_O, nullptr },
428+ {" _stream_get_priority" , _streamGetPriority, METH_O, nullptr },
429+ {" _stream_wait_event" , _streamWaitEvent, METH_VARARGS, nullptr },
430+ // Event functions
431+ {" _event_create" , _eventCreate, METH_NOARGS, nullptr },
432+ {" _event_create_with_flags" , _eventCreateWithFlags, METH_O, nullptr },
433+ {" _event_destroy" , _eventDestroy, METH_O, nullptr },
434+ {" _event_record" , _eventRecord, METH_VARARGS, nullptr },
435+ {" _event_synchronize" , _eventSynchronize, METH_O, nullptr },
436+ {" _event_query" , _eventQuery, METH_O, nullptr },
437+ {" _event_elapsed_time" , _eventElapsedTime, METH_VARARGS, nullptr },
438+ // Device functions
439+ {" _device_synchronize" , _deviceSynchronize, METH_NOARGS, nullptr },
440+ // Stream task functions
441+ {" _add_task_to_stream" , _addTaskToStream, METH_VARARGS, nullptr },
150442 {nullptr , nullptr , 0 , nullptr }};
151443
152444/*
0 commit comments