Skip to content

Commit 46c5209

Browse files
committed
refactor projections tests
1 parent 2bbe9d2 commit 46c5209

5 files changed

Lines changed: 81 additions & 712 deletions

File tree

test/mooncake/projections.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using LinearAlgebra: Diagonal
4+
using CUDA, AMDGPU
5+
6+
BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
7+
GenericFloats = ()
8+
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
9+
using .TestSuite
10+
11+
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
12+
13+
m = 19
14+
for T in (BLASFloats..., GenericFloats...)
15+
TestSuite.seed_rng!(123)
16+
if !is_buildkite
17+
TestSuite.test_mooncake_projections(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
18+
end
19+
end

test/testsuite/TestSuite.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,6 @@ instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A
8787

8888
include("ad_utils.jl")
8989

90-
include("projections.jl")
91-
9290
# Decompositions
9391
# --------------
9492
include("decompositions/qr.jl")
@@ -110,6 +108,7 @@ include("mooncake/eigh.jl")
110108
include("mooncake/svd.jl")
111109
include("mooncake/polar.jl")
112110
include("mooncake/orthnull.jl")
111+
include("mooncake/projections.jl")
113112

114113
include("enzyme.jl")
115114
include("chainrules.jl")

0 commit comments

Comments
 (0)