Skip to content

Commit 0768b01

Browse files
committed
fix applyQ for GPU in Rsvd_notruncate.cpp
1 parent 519cbf7 commit 0768b01

1 file changed

Lines changed: 8 additions & 22 deletions

File tree

src/linalg/Rsvd_notruncate.cpp

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -64,44 +64,30 @@ namespace cytnx {
6464
cytnx::linalg_internal::lii.Gesvd_ii[in.dtype()](
6565
in._impl->storage()._impl, U._impl->storage()._impl, vT._impl->storage()._impl,
6666
S._impl->storage()._impl, in.shape()[0], in.shape()[1]);
67-
68-
std::vector<Tensor> out;
69-
out.push_back(S);
70-
if (is_U) {
71-
if (applyQ) {
72-
U = Matmul(Q, U);
73-
}
74-
out.push_back(U);
75-
}
76-
if (is_vT) {
77-
out.push_back(vT);
78-
}
79-
80-
return out;
81-
8267
} else {
8368
#ifdef UNI_GPU
8469
checkCudaErrors(cudaSetDevice(in.device()));
8570
cytnx::linalg_internal::lii.cuGeSvd_ii[in.dtype()](
8671
in._impl->storage()._impl, U._impl->storage()._impl, vT._impl->storage()._impl,
8772
S._impl->storage()._impl, in.shape()[0], in.shape()[1]);
88-
73+
#else
74+
cytnx_error_msg(true, "[Rsvd] fatal error,%s",
75+
"try to call the gpu section without CUDA support.\n");
76+
return std::vector<Tensor>();
77+
#endif
8978
std::vector<Tensor> out;
9079
out.push_back(S);
9180
if (is_U) {
92-
U = Matmul(Q, U);
81+
if (applyQ) {
82+
U = Matmul(Q, U);
83+
}
9384
out.push_back(U);
9485
}
9586
if (is_vT) {
9687
out.push_back(vT);
9788
}
9889

9990
return out;
100-
#else
101-
cytnx_error_msg(true, "[Rsvd] fatal error,%s",
102-
"try to call the gpu section without CUDA support.\n");
103-
return std::vector<Tensor>();
104-
#endif
10591
}
10692
} // Rsvd(Tensor)
10793

0 commit comments

Comments
 (0)