Skip to content

Commit 89bed77

Browse files
committed
refactor projections tests
1 parent f3bfc54 commit 89bed77

4 files changed

Lines changed: 81 additions & 568 deletions

File tree

test/mooncake/projections.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_mooncake_projections(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
18+
end
19+
end

test/testsuite/TestSuite.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,6 @@ end
9393

9494
include("ad_utils.jl")
9595

96-
include("projections.jl")
97-
9896
# Decompositions
9997
# --------------
10098
include("decompositions/qr.jl")
@@ -116,6 +114,7 @@ include("mooncake/eigh.jl")
116114
include("mooncake/svd.jl")
117115
include("mooncake/polar.jl")
118116
include("mooncake/orthnull.jl")
117+
include("mooncake/projections.jl")
119118

120119
include("enzyme.jl")
121120
include("chainrules.jl")

0 commit comments

Comments
 (0)