55 *
66 */
77
8- #include " host_ptr_manager .h"
8+ #include " runtime/command_stream/command_stream_receiver .h"
99#include " runtime/helpers/ptr_math.h"
1010#include " runtime/helpers/abort.h"
11+ #include " runtime/memory_manager/memory_manager.h"
1112
1213using namespace OCLRT ;
1314
14- std::map<const void *, FragmentStorage>::iterator OCLRT:: HostPtrManager::findElement (const void *ptr) {
15+ std::map<const void *, FragmentStorage>::iterator HostPtrManager::findElement (const void *ptr) {
1516 auto nextElement = partialAllocations.lower_bound (ptr);
1617 auto element = nextElement;
1718 if (element != partialAllocations.end ()) {
@@ -43,7 +44,7 @@ std::map<const void *, FragmentStorage>::iterator OCLRT::HostPtrManager::findEle
4344 return partialAllocations.end ();
4445}
4546
46- AllocationRequirements OCLRT:: HostPtrManager::getAllocationRequirements (const void *inputPtr, size_t size) {
47+ AllocationRequirements HostPtrManager::getAllocationRequirements (const void *inputPtr, size_t size) {
4748 AllocationRequirements requiredAllocations;
4849
4950 auto allocationCount = 0 ;
@@ -89,7 +90,7 @@ AllocationRequirements OCLRT::HostPtrManager::getAllocationRequirements(const vo
8990 return requiredAllocations;
9091}
9192
92- OsHandleStorage OCLRT:: HostPtrManager::populateAlreadyAllocatedFragments (AllocationRequirements &requirements, CheckedFragments *checkedFragments) {
93+ OsHandleStorage HostPtrManager::populateAlreadyAllocatedFragments (AllocationRequirements &requirements, CheckedFragments *checkedFragments) {
9394 OsHandleStorage handleStorage;
9495 for (unsigned int i = 0 ; i < requirements.requiredFragmentsCount ; i++) {
9596 OverlapStatus overlapStatus = OverlapStatus::FRAGMENT_NOT_CHECKED;
@@ -133,8 +134,8 @@ OsHandleStorage OCLRT::HostPtrManager::populateAlreadyAllocatedFragments(Allocat
133134 return handleStorage;
134135}
135136
136- void OCLRT:: HostPtrManager::storeFragment (FragmentStorage &fragment) {
137- std::lock_guard<std::mutex > lock (allocationsMutex);
137+ void HostPtrManager::storeFragment (FragmentStorage &fragment) {
138+ std::lock_guard<decltype (allocationsMutex) > lock (allocationsMutex);
138139 auto element = findElement (fragment.fragmentCpuPointer );
139140 if (element != partialAllocations.end ()) {
140141 element->second .refCount ++;
@@ -144,7 +145,7 @@ void OCLRT::HostPtrManager::storeFragment(FragmentStorage &fragment) {
144145 }
145146}
146147
147- void OCLRT:: HostPtrManager::storeFragment (AllocationStorageData &storageData) {
148+ void HostPtrManager::storeFragment (AllocationStorageData &storageData) {
148149 FragmentStorage fragment;
149150 fragment.fragmentCpuPointer = const_cast <void *>(storageData.cpuPtr );
150151 fragment.fragmentSize = storageData.fragmentSize ;
@@ -153,16 +154,16 @@ void OCLRT::HostPtrManager::storeFragment(AllocationStorageData &storageData) {
153154 storeFragment (fragment);
154155}
155156
156- void OCLRT:: HostPtrManager::releaseHandleStorage (OsHandleStorage &fragments) {
157+ void HostPtrManager::releaseHandleStorage (OsHandleStorage &fragments) {
157158 for (int i = 0 ; i < maxFragmentsCount; i++) {
158159 if (fragments.fragmentStorageData [i].fragmentSize || fragments.fragmentStorageData [i].cpuPtr ) {
159160 fragments.fragmentStorageData [i].freeTheFragment = releaseHostPtr (fragments.fragmentStorageData [i].cpuPtr );
160161 }
161162 }
162163}
163164
164- bool OCLRT:: HostPtrManager::releaseHostPtr (const void *ptr) {
165- std::lock_guard<std::mutex > lock (allocationsMutex);
165+ bool HostPtrManager::releaseHostPtr (const void *ptr) {
166+ std::lock_guard<decltype (allocationsMutex) > lock (allocationsMutex);
166167 bool fragmentReadyToBeReleased = false ;
167168
168169 auto element = findElement (ptr);
@@ -178,8 +179,8 @@ bool OCLRT::HostPtrManager::releaseHostPtr(const void *ptr) {
178179 return fragmentReadyToBeReleased;
179180}
180181
181- FragmentStorage *OCLRT:: HostPtrManager::getFragment (const void *inputPtr) {
182- std::lock_guard<std::mutex > lock (allocationsMutex);
182+ FragmentStorage *HostPtrManager::getFragment (const void *inputPtr) {
183+ std::lock_guard<decltype (allocationsMutex) > lock (allocationsMutex);
183184 auto element = findElement (inputPtr);
184185 if (element != partialAllocations.end ()) {
185186 return &element->second ;
@@ -188,8 +189,8 @@ FragmentStorage *OCLRT::HostPtrManager::getFragment(const void *inputPtr) {
188189}
189190
190191// for given inputs see if any allocation overlaps
191- FragmentStorage *OCLRT:: HostPtrManager::getFragmentAndCheckForOverlaps (const void *inPtr, size_t size, OverlapStatus &overlappingStatus) {
192- std::lock_guard<std::mutex > lock (allocationsMutex);
192+ FragmentStorage *HostPtrManager::getFragmentAndCheckForOverlaps (const void *inPtr, size_t size, OverlapStatus &overlappingStatus) {
193+ std::lock_guard<decltype (allocationsMutex) > lock (allocationsMutex);
193194 void *inputPtr = const_cast <void *>(inPtr);
194195 auto nextElement = partialAllocations.lower_bound (inputPtr);
195196 auto element = nextElement;
@@ -246,3 +247,64 @@ FragmentStorage *OCLRT::HostPtrManager::getFragmentAndCheckForOverlaps(const voi
246247 }
247248 return nullptr ;
248249}
250+
251+ OsHandleStorage HostPtrManager::prepareOsStorageForAllocation (MemoryManager &memoryManager, size_t size, const void *ptr) {
252+ std::lock_guard<decltype (allocationsMutex)> lock (allocationsMutex);
253+ auto requirements = HostPtrManager::getAllocationRequirements (ptr, size);
254+
255+ CheckedFragments checkedFragments;
256+ UNRECOVERABLE_IF (checkAllocationsForOverlapping (memoryManager, &requirements, &checkedFragments) == RequirementsStatus::FATAL);
257+
258+ auto osStorage = populateAlreadyAllocatedFragments (requirements, &checkedFragments);
259+ if (osStorage.fragmentCount > 0 ) {
260+ if (memoryManager.populateOsHandles (osStorage) != MemoryManager::AllocationStatus::Success) {
261+ memoryManager.cleanOsHandles (osStorage);
262+ osStorage.fragmentCount = 0 ;
263+ }
264+ }
265+ return osStorage;
266+ }
267+
268+ RequirementsStatus HostPtrManager::checkAllocationsForOverlapping (MemoryManager &memoryManager, AllocationRequirements *requirements, CheckedFragments *checkedFragments) {
269+ DEBUG_BREAK_IF (requirements == nullptr );
270+ DEBUG_BREAK_IF (checkedFragments == nullptr );
271+
272+ RequirementsStatus status = RequirementsStatus::SUCCESS;
273+ checkedFragments->count = 0 ;
274+
275+ for (unsigned int i = 0 ; i < maxFragmentsCount; i++) {
276+ checkedFragments->status [i] = OverlapStatus::FRAGMENT_NOT_CHECKED;
277+ checkedFragments->fragments [i] = nullptr ;
278+ }
279+ for (unsigned int i = 0 ; i < requirements->requiredFragmentsCount ; i++) {
280+ checkedFragments->count ++;
281+ checkedFragments->fragments [i] = getFragmentAndCheckForOverlaps (requirements->AllocationFragments [i].allocationPtr , requirements->AllocationFragments [i].allocationSize , checkedFragments->status [i]);
282+ if (checkedFragments->status [i] == OverlapStatus::FRAGMENT_OVERLAPING_AND_BIGGER_THEN_STORED_FRAGMENT) {
283+ // clean temporary allocations
284+
285+ auto commandStreamReceiver = memoryManager.getCommandStreamReceiver (0 );
286+ uint32_t taskCount = *commandStreamReceiver->getTagAddress ();
287+ memoryManager.cleanAllocationList (taskCount, TEMPORARY_ALLOCATION);
288+
289+ // check overlapping again
290+ checkedFragments->fragments [i] = getFragmentAndCheckForOverlaps (requirements->AllocationFragments [i].allocationPtr , requirements->AllocationFragments [i].allocationSize , checkedFragments->status [i]);
291+ if (checkedFragments->status [i] == OverlapStatus::FRAGMENT_OVERLAPING_AND_BIGGER_THEN_STORED_FRAGMENT) {
292+
293+ // Wait for completion
294+ while (*commandStreamReceiver->getTagAddress () < commandStreamReceiver->peekLatestSentTaskCount ())
295+ ;
296+
297+ taskCount = *commandStreamReceiver->getTagAddress ();
298+ memoryManager.cleanAllocationList (taskCount, TEMPORARY_ALLOCATION);
299+
300+ // check overlapping last time
301+ checkedFragments->fragments [i] = getFragmentAndCheckForOverlaps (requirements->AllocationFragments [i].allocationPtr , requirements->AllocationFragments [i].allocationSize , checkedFragments->status [i]);
302+ if (checkedFragments->status [i] == OverlapStatus::FRAGMENT_OVERLAPING_AND_BIGGER_THEN_STORED_FRAGMENT) {
303+ status = RequirementsStatus::FATAL;
304+ break ;
305+ }
306+ }
307+ }
308+ }
309+ return status;
310+ }
0 commit comments