22{-# LANGUAGE TypeApplications #-}
33module ArrayFire.DataSpec where
44
5- import Control.Exception
6- import Data.Complex
7- import Data.Word
8- import Foreign.C.Types
9- import GHC.Int
10- import Test.Hspec
5+ import Control.Exception
6+ import Data.Complex
7+ import Data.Word
8+ import Foreign.C.Types
9+ import GHC.Int
10+ import Prelude hiding (flip )
11+ import Test.Hspec
1112
12- import ArrayFire
13+ import ArrayFire
1314
1415spec :: Spec
1516spec =
@@ -32,6 +33,116 @@ spec =
3233 constant @ (Complex Float ) [1 ] (1.0 :+ 1.0 )
3334 `shouldBe`
3435 constant @ (Complex Float ) [1 ] (1.0 :+ 1.0 )
36+
37+ describe " arange" $ do
38+ it " generates a sequence along dim 0 for a 1D array" $ do
39+ arange @ Double [5 ] (- 1 ) `shouldBe` vector @ Double 5 [0 ,1 ,2 ,3 ,4 ]
40+ it " generates a sequence along dim 1 for a 2D array" $ do
41+ arange @ Double [3 ,2 ] 1 `shouldBe` mkArray @ Double [3 ,2 ] [0 ,0 ,0 ,1 ,1 ,1 ]
42+
43+ describe " iota" $ do
44+ it " generates a flat sequence without tiling" $ do
45+ iota @ Double [5 ] [] `shouldBe` vector @ Double 5 [0 ,1 ,2 ,3 ,4 ]
46+ it " tiles the sequence along dim 0" $ do
47+ iota @ Double [3 ] [2 ] `shouldBe` vector @ Double 6 [0 ,1 ,2 ,0 ,1 ,2 ]
48+
49+ describe " identity" $ do
50+ it " creates a 2x2 identity matrix" $ do
51+ identity @ Double [2 ,2 ]
52+ `shouldBe` mkArray @ Double [2 ,2 ] [1 ,0 ,0 ,1 ]
53+ it " creates a 3x3 identity matrix" $ do
54+ identity @ Double [3 ,3 ]
55+ `shouldBe` mkArray @ Double [3 ,3 ] [1 ,0 ,0 ,0 ,1 ,0 ,0 ,0 ,1 ]
56+
57+ describe " diagCreate" $ do
58+ it " creates a diagonal matrix from a vector (diag 0)" $ do
59+ diagCreate (vector @ Double 3 [1 ,2 ,3 ]) 0
60+ `shouldBe` mkArray @ Double [3 ,3 ] [1 ,0 ,0 ,0 ,2 ,0 ,0 ,0 ,3 ]
61+ it " creates a superdiagonal matrix (diag 1)" $ do
62+ diagCreate (vector @ Double 2 [5 ,6 ]) 1
63+ `shouldBe` mkArray @ Double [3 ,3 ] [0 ,0 ,0 ,5 ,0 ,0 ,0 ,6 ,0 ]
64+
65+ describe " diagExtract" $ do
66+ it " extracts the main diagonal of a square matrix" $ do
67+ diagExtract (mkArray @ Double [3 ,3 ] [1 ,0 ,0 ,0 ,2 ,0 ,0 ,0 ,3 ]) 0
68+ `shouldBe` vector @ Double 3 [1 ,2 ,3 ]
69+ it " is the inverse of diagCreate on the main diagonal" $ do
70+ let v = vector @ Double 4 [1 ,2 ,3 ,4 ]
71+ diagExtract (diagCreate v 0 ) 0 `shouldBe` v
72+
73+ describe " lower" $ do
74+ it " extracts the lower triangular part (unit diagonal)" $ do
75+ let m = mkArray @ Double [3 ,3 ] [1 ,2 ,3 ,4 ,5 ,6 ,7 ,8 ,9 ]
76+ lower m True
77+ `shouldBe` mkArray @ Double [3 ,3 ] [1 ,2 ,3 ,0 ,1 ,6 ,0 ,0 ,1 ]
78+ it " extracts the lower triangular part (non-unit diagonal)" $ do
79+ let m = mkArray @ Double [3 ,3 ] [1 ,2 ,3 ,4 ,5 ,6 ,7 ,8 ,9 ]
80+ lower m False
81+ `shouldBe` mkArray @ Double [3 ,3 ] [1 ,2 ,3 ,0 ,5 ,6 ,0 ,0 ,9 ]
82+
83+ describe " upper" $ do
84+ it " extracts the upper triangular part (unit diagonal)" $ do
85+ let m = mkArray @ Double [3 ,3 ] [1 ,2 ,3 ,4 ,5 ,6 ,7 ,8 ,9 ]
86+ upper m True
87+ `shouldBe` mkArray @ Double [3 ,3 ] [1 ,0 ,0 ,4 ,1 ,0 ,7 ,8 ,1 ]
88+ it " extracts the upper triangular part (non-unit diagonal)" $ do
89+ let m = mkArray @ Double [3 ,3 ] [1 ,2 ,3 ,4 ,5 ,6 ,7 ,8 ,9 ]
90+ upper m False
91+ `shouldBe` mkArray @ Double [3 ,3 ] [1 ,0 ,0 ,4 ,5 ,0 ,7 ,8 ,9 ]
92+
93+ describe " tile" $ do
94+ it " tiles a scalar into a 3x3 array" $ do
95+ tile (scalar @ Int 7 ) [3 ,3 ]
96+ `shouldBe` constant @ Int [3 ,3 ] 7
97+ it " tiles a row vector along dim 0" $ do
98+ tile (mkArray @ Int [1 ,3 ] [1 ,2 ,3 ]) [2 ,1 ]
99+ `shouldBe` mkArray @ Int [2 ,3 ] [1 ,1 ,2 ,2 ,3 ,3 ]
100+
101+ describe " moddims" $ do
102+ it " reshapes a vector into a matrix" $ do
103+ moddims (vector @ Int 6 [1 .. 6 ]) [2 ,3 ]
104+ `shouldBe` mkArray @ Int [2 ,3 ] [1 ,2 ,3 ,4 ,5 ,6 ]
105+ it " reshapes a matrix back to a vector" $ do
106+ let v = vector @ Int 6 [1 .. 6 ]
107+ moddims (moddims v [2 ,3 ]) [6 ] `shouldBe` v
108+
109+ describe " flat" $ do
110+ it " flattens a 2x3 matrix to a 6-element vector" $ do
111+ flat (mkArray @ Int [2 ,3 ] [1 ,2 ,3 ,4 ,5 ,6 ])
112+ `shouldBe` vector @ Int 6 [1 ,2 ,3 ,4 ,5 ,6 ]
113+
114+ describe " flip" $ do
115+ it " reverses a vector (dim 0)" $ do
116+ flip (vector @ Int 4 [1 ,2 ,3 ,4 ]) 0
117+ `shouldBe` vector @ Int 4 [4 ,3 ,2 ,1 ]
118+ it " reverses columns of a matrix (dim 1)" $ do
119+ flip (mkArray @ Int [2 ,2 ] [1 ,2 ,3 ,4 ]) 1
120+ `shouldBe` mkArray @ Int [2 ,2 ] [3 ,4 ,1 ,2 ]
121+
122+ describe " shift" $ do
123+ it " shifts a vector by 2 elements (wrapping)" $ do
124+ shift (vector @ Double 4 [1 ,2 ,3 ,4 ]) 2 0 0 0
125+ `shouldBe` vector @ Double 4 [3 ,4 ,1 ,2 ]
126+
127+ describe " select" $ do
128+ it " selects elements from two arrays based on a boolean mask" $ do
129+ let cond = vector @ CBool 4 [1 ,0 ,1 ,0 ]
130+ a = vector @ Double 4 [10 ,20 ,30 ,40 ]
131+ b = vector @ Double 4 [1 ,2 ,3 ,4 ]
132+ select cond a b `shouldBe` vector @ Double 4 [10 ,2 ,30 ,4 ]
133+
134+ describe " selectScalarR" $ do
135+ it " uses scalar for false positions" $ do
136+ let cond = vector @ CBool 4 [1 ,0 ,1 ,0 ]
137+ a = vector @ Double 4 [10 ,20 ,30 ,40 ]
138+ selectScalarR cond a 99 `shouldBe` vector @ Double 4 [10 ,99 ,30 ,99 ]
139+
140+ describe " selectScalarL" $ do
141+ it " uses scalar for true positions" $ do
142+ let cond = vector @ CBool 4 [1 ,0 ,1 ,0 ]
143+ b = vector @ Double 4 [1 ,2 ,3 ,4 ]
144+ selectScalarL cond 99 b `shouldBe` vector @ Double 4 [99 ,2 ,99 ,4 ]
145+
35146 it " Should join Arrays along the specified dimension" $ do
36147 join 0 (constant @ Int [1 , 3 ] 1 ) (constant @ Int [1 , 3 ] 2 ) `shouldBe` mkArray @ Int [2 , 3 ] [1 , 2 , 1 , 2 , 1 , 2 ]
37148 join 1 (constant @ Int [1 , 2 ] 1 ) (constant @ Int [1 , 2 ] 2 ) `shouldBe` mkArray @ Int [1 , 4 ] [1 , 1 , 2 , 2 ]
0 commit comments