-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcuda_utils.cpp
More file actions
173 lines (156 loc) · 4.89 KB
/
cuda_utils.cpp
File metadata and controls
173 lines (156 loc) · 4.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
/*
Copyright [2024] [Yao Yao]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "cuda_utils.h"
#include "CudaEventPool.h"
// Creates a CUDA event via cudaEventCreateWithFlags and hands it to an owning
// CudaEvent wrapper.
//   flags: creation flags forwarded verbatim to cudaEventCreateWithFlags.
// Runtime failures are reported through cudaCheck.
CudaEvent makeCudaEvent(unsigned flags) {
    cudaEvent_t ev{};
    cudaCheck(cudaEventCreateWithFlags(&ev, flags));
    return CudaEvent{ev};
}
// Creates a CUDA stream via cudaStreamCreateWithFlags and hands it to an owning
// CudaStream wrapper.
//   flags: creation flags forwarded verbatim to cudaStreamCreateWithFlags.
// Runtime failures are reported through cudaCheck.
CudaStream makeCudaStream(unsigned flags) {
    cudaStream_t stream{};
    cudaCheck(cudaStreamCreateWithFlags(&stream, flags));
    return CudaStream{stream};
}
// Creates a CUDA stream with an explicit scheduling priority and hands it to an
// owning CudaStream wrapper.
//   priority: forwarded to cudaStreamCreateWithPriority.
//   flags:    stream creation flags, forwarded unchanged.
// Note that cudaStreamCreateWithPriority takes (flags, priority), i.e. the
// reverse of this function's parameter order.
CudaStream makeCudaStreamWithPriority(int priority, unsigned flags) {
    cudaStream_t stream{};
    cudaCheck(cudaStreamCreateWithPriority(&stream, flags, priority));
    return CudaStream{stream};
}
// Creates an empty CUDA graph (cudaGraphCreate with flags = 0) and hands it to
// an owning CudaGraph wrapper. Runtime failures are reported through cudaCheck.
CudaGraph makeCudaGraph()
{
    cudaGraph_t g{};
    cudaCheck(cudaGraphCreate(&g, 0));
    return CudaGraph{g};
}
// Instantiates `graph` into an executable graph and hands it to an owning
// CudaGraphExec wrapper.
// The error-node and log-buffer outputs of cudaGraphInstantiate are not
// requested (nullptr, nullptr, 0), so on failure only the error code reported
// through cudaCheck is available.
CudaGraphExec instantiateCudaGraph(cudaGraph_t graph)
{
    cudaGraphExec_t exec{};
    cudaCheck(cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0));
    return CudaGraphExec{exec};
}
// Makes `second` wait for all work enqueued on `first` up to this call.
// A temporary event is obtained from cudapp::createPooledCudaEvent(), recorded
// on `first`, and waited on by `second`. The event object is released when this
// function returns (presumably back to the pool). cudaStreamWaitEvent only
// depends on the event's most recent record at the time of the call, so later
// reuse of the event does not affect `second`.
void connectStreams(cudaStream_t first, cudaStream_t second) {
    const cudapp::PooledCudaEvent event = cudapp::createPooledCudaEvent();
    cudaCheck(cudaEventRecord(event.get(), first));    // mark the current tail of `first`
    cudaCheck(cudaStreamWaitEvent(second, event.get(), 0)); // gate `second` on it
}
// Makes `second` wait for all work enqueued on `first`, using a caller-supplied
// event.
//   event:  recorded on `first` and then waited on by `second`.
//   pMutex: optional. When non-null it is held across the record + wait pair,
//           so concurrent callers sharing the same `event` cannot interleave
//           between the two calls. When null, no locking is done.
void connectStreams(cudaStream_t first, cudaStream_t second, cudaEvent_t event, std::mutex* pMutex) {
    // Owns (and later releases) *pMutex only when a mutex was supplied.
    std::unique_lock<std::mutex> lock = (pMutex != nullptr)
        ? std::unique_lock<std::mutex>{*pMutex}
        : std::unique_lock<std::mutex>{};
    cudaCheck(cudaEventRecord(event, first));
    cudaCheck(cudaStreamWaitEvent(second, event, 0));
}
// Out-of-line definition of the interface's destructor (declared in the header),
// so that concrete CudaMultiEvent objects can be destroyed through
// std::unique_ptr<ICudaMultiEvent>.
ICudaMultiEvent::~ICudaMultiEvent() = default;
// A set of CUDA events, each recorded on some stream, that can be waited on as a
// group. Every access to the stored events is serialized by mMutex.
//   isPooled == true:  events come from cudapp::createPooledCudaEvent().
//   isPooled == false: each event is a fresh CudaEvent from makeCudaEvent().
template <bool isPooled>
class CudaMultiEvent : public ICudaMultiEvent
{
public:
    using Event = std::conditional_t<isPooled, cudapp::PooledCudaEvent, CudaEvent>;

    // Discards every stored event.
    void clear() override
    {
        const std::scoped_lock guard{mMutex};
        mEvents.clear();
    }

    // Records a new event on `stream` and adds it to the set.
    void recordEvent(cudaStream_t stream) override;

    // Makes `stream` wait until every currently stored event has completed.
    void streamWaitEvent(cudaStream_t stream) const override;

    // Blocks the calling host thread until every stored event has completed.
    void sync() const override;

    // Drops the events that have already completed.
    void scrub() override;

    // Drops completed events, then returns true if no event remains.
    bool query() override;

    // Returns true if no event is currently stored.
    bool empty() const override
    {
        const std::scoped_lock guard{mMutex};
        return mEvents.empty();
    }

private:
    mutable std::mutex mMutex;  // guards mEvents; mutable so const members can lock
    std::vector<Event> mEvents;
};
// Non-pooled variant: every recorded event is a freshly created CudaEvent.
// The event is recorded before taking the lock; only the append to mEvents
// needs to be serialized.
template <>
void CudaMultiEvent<false>::recordEvent(cudaStream_t stream)
{
    Event event = makeCudaEvent();
    cudaCheck(cudaEventRecord(event.get(), stream));
    const std::scoped_lock guard{mMutex};
    mEvents.push_back(std::move(event));
}
// Pooled variant: the recorded event is obtained from
// cudapp::createPooledCudaEvent(). The event is recorded before taking the
// lock; only the append to mEvents needs to be serialized.
template <>
void CudaMultiEvent<true>::recordEvent(cudaStream_t stream)
{
    Event event = cudapp::createPooledCudaEvent();
    cudaCheck(cudaEventRecord(event.get(), stream));
    const std::scoped_lock guard{mMutex};
    mEvents.push_back(std::move(event));
}
// Enqueues a wait on `stream` for every stored event, so that work submitted to
// `stream` afterwards does not start until all of those events have completed.
// The host is not blocked. The lock is held so that no event can be destroyed
// while it is being referenced.
template <bool isPooled>
void CudaMultiEvent<isPooled>::streamWaitEvent(cudaStream_t stream) const
{
    const std::scoped_lock guard{mMutex};
    for (const Event& ev : mEvents)
    {
        cudaCheck(cudaStreamWaitEvent(stream, ev.get(), 0));
    }
}
// Blocks the calling host thread until every stored event has completed,
// synchronizing on them one at a time. Events are not removed; call scrub() or
// clear() for that. mMutex is held for the whole wait.
template <bool isPooled>
void CudaMultiEvent<isPooled>::sync() const
{
    const std::scoped_lock guard{mMutex};
    for (const Event& ev : mEvents)
    {
        cudaCheck(cudaEventSynchronize(ev.get()));
    }
}
// Removes every stored event that has already completed. Pending events keep
// their relative order.
//
// Each event is polled with cudaEventQuery:
//   cudaSuccess       -> completed: removed.
//   cudaErrorNotReady -> still pending: kept.
//   any other code    -> reported through cudaCheck.
//
// Pending events are compacted toward the front with iterator swaps rather than
// std::remove_if, because the completion check can throw (via cudaCheck). With
// swaps, a throw part-way leaves mEvents as a permutation of valid events. A
// move-based compaction would instead leave moved-from entries behind, since the
// trailing erase would never run.
template <bool isPooled>
void CudaMultiEvent<isPooled>::scrub() {
    std::lock_guard lk{mMutex};
    const auto isComplete = [](const Event& u) -> bool {
        const cudaError_t state = cudaEventQuery(u.get());
        switch (state)
        {
        case cudaSuccess:
            // NOTE(review): cudaEventQuery already reported completion, so this
            // should return immediately; presumably kept as an extra error
            // check. Confirm before removing.
            cudaCheck(cudaEventSynchronize(u.get()));
            return true;
        case cudaErrorNotReady: return false;
        default: cudaCheck(state);
        }
        // Only reached if cudaCheck did not throw on a genuine error code.
        throw std::logic_error("You should never reach here");
    };
    auto kept = mEvents.begin();  // one past the last pending event found so far
    for (auto it = mEvents.begin(); it != mEvents.end(); ++it) {
        if (!isComplete(*it)) {
            if (it != kept) {
                std::iter_swap(kept, it);
            }
            ++kept;
        }
    }
    mEvents.erase(kept, mEvents.end());  // destroys the completed events
}
// Drops completed events, then reports whether nothing remains pending.
// scrub() and the emptiness check lock mMutex separately. An event recorded by
// another thread in between is therefore counted, and this returns false.
template <bool isPooled>
bool CudaMultiEvent<isPooled>::query()
{
    this->scrub();
    const std::scoped_lock guard{mMutex};
    const bool nothingPending = mEvents.empty();
    return nothingPending;
}
std::unique_ptr<ICudaMultiEvent> createCudaMultiEvent(bool isPooled)
{
if (isPooled) {
return std::make_unique<CudaMultiEvent<true>>();
}
else {
return std::make_unique<CudaMultiEvent<false>>();
}
}
namespace cudapp
{
// Blocks the calling host thread until all work previously enqueued on `stream`
// has finished, reporting failures through cudaCheck.
//
// The active implementation calls cudaStreamSynchronize directly. An
// event-based alternative is kept below, disabled, for reference: it records a
// pooled event on the stream and then synchronizes on that event.
void streamSync(cudaStream_t stream)
{
#if 0  // event-based alternative (disabled)
    const auto ev = createPooledCudaEvent();
    cudaCheck(cudaEventRecord(ev.get(), stream));
    cudaCheck(cudaEventSynchronize(ev.get()));
#else
    cudaCheck(cudaStreamSynchronize(stream));
#endif
}
} // namespace cudapp