@@ -3644,6 +3644,11 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
36443644 // Load the HSA executable.
36453645 if (Error Err = AMDImage->loadExecutable (*this ))
36463646 return std::move (Err);
3647+
3648+ // Launch the special kernel for device memory initialization
3649+ if (Error Err = launchDMInitKernel (*AMDImage))
3650+ return std::move (Err);
3651+
36473652 return AMDImage;
36483653 }
36493654
@@ -4642,13 +4647,16 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
46424647 Error preAllocateDeviceMemoryPool () {
46434648
46444649 void *DevPtr;
4650+ // Use PER_DEVICE_PREALLOC_SIZE (128KB) as heap and allocate 512MB for
4651+ // device memory
4652+ size_t PreAllocSize = hsa_utils::PER_DEVICE_PREALLOC_SIZE + DMSlabSize;
4653+
46454654 for (AMDGPUMemoryPoolTy *MemoryPool : AllMemoryPools) {
46464655 if (!MemoryPool->isGlobal ())
46474656 continue ;
46484657
46494658 if (MemoryPool->isCoarseGrained ()) {
46504659 DevPtr = nullptr ;
4651- size_t PreAllocSize = hsa_utils::PER_DEVICE_PREALLOC_SIZE;
46524660
46534661 Error Err = MemoryPool->allocate (PreAllocSize, &DevPtr);
46544662 if (Err)
@@ -4664,6 +4672,11 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
46644672 " Zero initialization of preallocated device memory pool failed" );
46654673
46664674 PreAllocatedDeviceMemoryPool = DevPtr;
4675+
4676+ DMHeapPtr = DevPtr;
4677+ DMSlabPtr =
4678+ reinterpret_cast <void *>(reinterpret_cast <uintptr_t >(DevPtr) +
4679+ hsa_utils::PER_DEVICE_PREALLOC_SIZE);
46674680 }
46684681 }
46694682 return Plugin::success ();
@@ -5070,6 +5083,13 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
50705083 // / True if in multi-device mode.
50715084 bool IsMultiDeviceEnabled = false ;
50725085
5086+ // / Arguments for device memory initialization.
5087+ void *DMHeapPtr = nullptr ;
5088+ void *DMSlabPtr = nullptr ;
5089+ bool DMInitialized = false ;
5090+ static constexpr uint32_t DMNumSlabs = 256 ;
5091+ static constexpr size_t DMSlabSize = DMNumSlabs * (2 * 1024 * 1024 ); // 512MB
5092+
50735093 // / Struct holding time in ns at a point in time for both host and device
50745094 // / This is used to compute a device-to-host offset and skew. Required for
50755095 // / OMPT function translate_time.
@@ -5167,6 +5187,70 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
51675187 return It->second ;
51685188 }
51695189
5190+ // / Launch the device memory initialization kernel.
5191+ Error launchDMInitKernel (AMDGPUDeviceImageTy &Image) {
5192+ // Already initialized, skip
5193+ if (DMInitialized)
5194+ return Plugin::success ();
5195+
5196+ if (!DMHeapPtr || !DMSlabPtr)
5197+ return Plugin::error (
5198+ ErrorCode::UNKNOWN,
5199+ " Device memory not allocated for launching DM init kernel." );
5200+
5201+ // Check if this image contains the DM init kernel
5202+ const char *KernelName = " __omp_dm_init_kernel" ;
5203+
5204+ GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler ();
5205+ if (!Handler.isSymbolInImage (*this , Image, KernelName)) {
5206+ DP (" DM init kernel is not in this image.\n " );
5207+ return Plugin::success ();
5208+ }
5209+
5210+ AMDGPUKernelTy DMInitKernel (KernelName, Plugin.getGlobalHandler ());
5211+ if (auto Err = DMInitKernel.init (*this , Image)) {
5212+ return Err;
5213+ }
5214+
5215+ DP (" Device memory initializing...\n " );
5216+
5217+ // Prepare kernel arguments
5218+ struct __attribute__ ((packed)) {
5219+ uint64_t HeapAddr;
5220+ uint64_t SlabAddr;
5221+ } Args;
5222+
5223+ Args.HeapAddr = reinterpret_cast <uint64_t >(DMHeapPtr);
5224+ Args.SlabAddr = reinterpret_cast <uint64_t >(DMSlabPtr);
5225+
5226+ KernelArgsTy KernelArgs;
5227+ KernelLaunchParamsTy LaunchParams;
5228+ LaunchParams.Data = &Args;
5229+ LaunchParams.Size = sizeof (Args);
5230+
5231+ AsyncInfoWrapperTy AsyncInfo (*this , nullptr );
5232+
5233+ uint32_t NumThreads[3 ] = {256u , 1u , 1u };
5234+ uint32_t NumBlocks[3 ] = {1u , 1u , 1u };
5235+
5236+ // Launch kernel with 256 threads and 1 block
5237+ if (auto Err = DMInitKernel.launchImpl (*this , NumThreads, NumBlocks,
5238+ KernelArgs, LaunchParams, AsyncInfo))
5239+ return Err;
5240+
5241+ // Wait for completion
5242+ Error Err = Plugin::success ();
5243+ AsyncInfo.finalize (Err);
5244+
5245+ // Mark as successfully initialized
5246+ if (!Err) {
5247+ DMInitialized = true ;
5248+ DP (" Device memory initialized successfully\n " );
5249+ }
5250+
5251+ return Err;
5252+ }
5253+
51705254public:
51715255 // / Return if it is an MI300 series device.
51725256 bool checkIfMI300Device () {
0 commit comments