Skip to content

Commit fab809b

Browse files
dmjioclaude
andcommitted
Expand test coverage: Data, Index, Algorithm by-key NaN variants
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent c44d1f7 commit fab809b

7 files changed

Lines changed: 244 additions & 24 deletions

File tree

src/ArrayFire/Data.hs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ constant dims val =
192192

193193
-- | Creates a range of values in an Array
194194
--
195-
-- >>> range @Double [10] (-1)
195+
-- >>> arange @Double [10] (-1)
196196
-- ArrayFire Array
197197
-- [10 1 1 1]
198198
-- 0.0000
@@ -205,14 +205,14 @@ constant dims val =
205205
-- 7.0000
206206
-- 8.0000
207207
-- 9.0000
208-
range
208+
arange
209209
:: forall a
210210
. AFType a
211211
=> [Int]
212212
-> Int
213213
-> Array a
214-
{-# NOINLINE range #-}
215-
range dims (fromIntegral -> k) = unsafePerformIO . mask_ $ do
214+
{-# NOINLINE arange #-}
215+
arange dims (fromIntegral -> k) = unsafePerformIO . mask_ $ do
216216
ptr <- alloca $ \ptrPtr -> do
217217
withArray (fromIntegral <$> dims) $ \dimArray -> do
218218
throwAFError =<< af_range ptrPtr n dimArray k typ

src/ArrayFire/Index.hs

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
-- Functions for indexing into an 'Array'
1111
--
1212
--------------------------------------------------------------------------------
13+
{-# LANGUAGE FlexibleInstances #-}
1314
module ArrayFire.Index where
1415

1516
import ArrayFire.Internal.Index
@@ -52,7 +53,7 @@ lookup
5253
-> Array a
5354
lookup a b n = op2 a b $ \p x y -> af_lookup p x y (fromIntegral n)
5455

55-
-- | Assign values into an 'Array' slice defined by 'Seq' indices
56+
-- | Assign values into an 'Array' range defined by 'Seq' indices
5657
--
5758
-- @
5859
-- >>> let a = vector \@Double 5 [1..]
@@ -62,7 +63,7 @@ assignSeq
6263
:: Array a
6364
-- ^ Destination array
6465
-> [Seq]
65-
-- ^ Indices defining the slice to assign into
66+
-- ^ Indices defining the range to assign into
6667
-> Array a
6768
-- ^ Source array
6869
-> Array a
@@ -118,7 +119,7 @@ assignGen
118119
:: Array a
119120
-- ^ Destination array
120121
-> [Index]
121-
-- ^ List of 'Index' values defining the slice to assign into
122+
-- ^ List of 'Index' values defining the range to assign into
122123
-> Array a
123124
-- ^ Source array
124125
-> Array a
@@ -140,8 +141,58 @@ assignGen (Array fptr) indices (Array rhsFptr) =
140141
touchIdxFPtr _ = pure ()
141142

142143
-- | A special 'Seq' value representing the entire axis of an 'Array'.
143-
--
144-
-- Use this instead of @Prelude.span@.
145144
-- Hard-coded from include\/af\/seq.h because FFI cannot import static const values.
146145
afSpan :: Seq
147146
afSpan = Seq 1 1 0
147+
148+
-- | Select the full extent of a dimension. Use in tuple indices where you want all elements along an axis.
149+
--
150+
-- @
151+
-- arr ! (range 0 2, full, at 1)
152+
-- @
153+
full :: Index
154+
full = SeqIndex False afSpan
155+
156+
-- | Convert index expressions to a list of 'Index'.
157+
-- Supports a single 'Index' or tuples of up to four 'Index' values
158+
-- (matching ArrayFire's maximum of 4 dimensions).
159+
class ToIndexList a where
160+
toIndexList :: a -> [Index]
161+
162+
instance ToIndexList Index where
163+
toIndexList x = [x]
164+
165+
instance ToIndexList (Index, Index) where
166+
toIndexList (a, b) = [a, b]
167+
168+
instance ToIndexList (Index, Index, Index) where
169+
toIndexList (a, b, c) = [a, b, c]
170+
171+
instance ToIndexList (Index, Index, Index, Index) where
172+
toIndexList (a, b, c, d) = [a, b, c, d]
173+
174+
-- | Lift a 'Seq' to an 'Index' for use in tuple-based indexing.
175+
idx :: Seq -> Index
176+
idx s = SeqIndex False s
177+
178+
-- | Index an 'Array'. Accepts a single 'Index' or a tuple of up to four.
179+
--
180+
-- @
181+
-- arr ! at 0 -- 1D: element 0
182+
-- arr ! range 1 3 -- 1D: rows 1-3
183+
-- arr ! (range 0 2, at 1) -- 2D
184+
-- arr ! (range 0 2, full, at 1) -- 3D, full second axis
185+
-- @
186+
(!) :: ToIndexList ix => Array a -> ix -> Array a
187+
a ! ix = indexGen a (toIndexList ix)
188+
infixl 9 !
189+
190+
-- | Assign into a range of an 'Array'. Lens-style: use with '(&)'.
191+
--
192+
-- @
193+
-- arr & range 1 3 .~ src
194+
-- arr & (range 0 1, at 2) .~ src
195+
-- @
196+
(.~) :: ToIndexList ix => ix -> Array a -> Array a -> Array a
197+
(ix .~ rhs) arr = assignGen arr (toIndexList ix) rhs
198+
infixr 4 .~

src/ArrayFire/Internal/Types.hsc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,18 @@ seqIdx s batch = SeqIndex batch s
706706
arrIdx :: Array Int -> Bool -> Index
707707
arrIdx a batch = ArrIndex batch a
708708

709+
-- | Index a contiguous range [begin..end] with step 1.
710+
range :: Int -> Int -> Index
711+
range b e = SeqIndex False (Seq (fromIntegral b) (fromIntegral e) 1)
712+
713+
-- | Index a range [begin..end] with an explicit step.
714+
rangeStep :: Int -> Int -> Int -> Index
715+
rangeStep b e s = SeqIndex False (Seq (fromIntegral b) (fromIntegral e) (fromIntegral s))
716+
717+
-- | Index a single element.
718+
at :: Int -> Index
719+
at n = let d = fromIntegral n in SeqIndex False (Seq d d 1)
720+
709721
toAFIndex :: Index -> IO AFIndex
710722
toAFIndex (SeqIndex batch s) =
711723
pure $ AFIndex (Right (toAFSeq s)) True batch

src/ArrayFire/Types.hs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ module ArrayFire.Types
5555
, Index (..)
5656
, seqIdx
5757
, arrIdx
58+
, range
59+
, rangeStep
60+
, at
5861
, NormType (..)
5962
, ConvMode (..)
6063
, ConvDomain (..)

test/ArrayFire/AlgorithmSpec.hs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,16 @@ spec =
156156
(ko, vo) = A.anyTrueByKey keys vals 0
157157
ko `shouldBe` A.vector @Int 2 [1,2]
158158
vo `shouldBe` A.vector @A.CBool 2 [0,1]
159+
it "Should sum values grouped by key, substituting NaN with 0" $ do
160+
let keys = A.vector @Int 4 [1,1,2,2]
161+
vals = A.vector @Double 4 [10, (acos 2), 3, 4]
162+
(ko, vo) = A.sumByKeyNaN keys vals 0 0
163+
ko `shouldBe` A.vector @Int 2 [1,2]
164+
vo `shouldBe` A.vector @Double 2 [10, 7]
165+
it "Should take the product of values grouped by key, substituting NaN with 1" $ do
166+
let keys = A.vector @Int 4 [1,1,2,2]
167+
vals = A.vector @Double 4 [2, (acos 2), 4, 5]
168+
(ko, vo) = A.productByKeyNaN keys vals 0 1
169+
ko `shouldBe` A.vector @Int 2 [1,2]
170+
vo `shouldBe` A.vector @Double 2 [2, 20]
159171

test/ArrayFire/DataSpec.hs

Lines changed: 118 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
{-# LANGUAGE TypeApplications #-}
33
module ArrayFire.DataSpec where
44

5-
import Control.Exception
6-
import Data.Complex
7-
import Data.Word
8-
import Foreign.C.Types
9-
import GHC.Int
10-
import Test.Hspec
5+
import Control.Exception
6+
import Data.Complex
7+
import Data.Word
8+
import Foreign.C.Types
9+
import GHC.Int
10+
import Prelude hiding (flip)
11+
import Test.Hspec
1112

12-
import ArrayFire
13+
import ArrayFire
1314

1415
spec :: Spec
1516
spec =
@@ -32,6 +33,116 @@ spec =
3233
constant @(Complex Float) [1] (1.0 :+ 1.0)
3334
`shouldBe`
3435
constant @(Complex Float) [1] (1.0 :+ 1.0)
36+
37+
describe "arange" $ do
38+
it "generates a sequence along dim 0 for a 1D array" $ do
39+
arange @Double [5] (-1) `shouldBe` vector @Double 5 [0,1,2,3,4]
40+
it "generates a sequence along dim 1 for a 2D array" $ do
41+
arange @Double [3,2] 1 `shouldBe` mkArray @Double [3,2] [0,0,0,1,1,1]
42+
43+
describe "iota" $ do
44+
it "generates a flat sequence without tiling" $ do
45+
iota @Double [5] [] `shouldBe` vector @Double 5 [0,1,2,3,4]
46+
it "tiles the sequence along dim 0" $ do
47+
iota @Double [3] [2] `shouldBe` vector @Double 6 [0,1,2,0,1,2]
48+
49+
describe "identity" $ do
50+
it "creates a 2x2 identity matrix" $ do
51+
identity @Double [2,2]
52+
`shouldBe` mkArray @Double [2,2] [1,0,0,1]
53+
it "creates a 3x3 identity matrix" $ do
54+
identity @Double [3,3]
55+
`shouldBe` mkArray @Double [3,3] [1,0,0,0,1,0,0,0,1]
56+
57+
describe "diagCreate" $ do
58+
it "creates a diagonal matrix from a vector (diag 0)" $ do
59+
diagCreate (vector @Double 3 [1,2,3]) 0
60+
`shouldBe` mkArray @Double [3,3] [1,0,0,0,2,0,0,0,3]
61+
it "creates a superdiagonal matrix (diag 1)" $ do
62+
diagCreate (vector @Double 2 [5,6]) 1
63+
`shouldBe` mkArray @Double [3,3] [0,0,0,5,0,0,0,6,0]
64+
65+
describe "diagExtract" $ do
66+
it "extracts the main diagonal of a square matrix" $ do
67+
diagExtract (mkArray @Double [3,3] [1,0,0,0,2,0,0,0,3]) 0
68+
`shouldBe` vector @Double 3 [1,2,3]
69+
it "is the inverse of diagCreate on the main diagonal" $ do
70+
let v = vector @Double 4 [1,2,3,4]
71+
diagExtract (diagCreate v 0) 0 `shouldBe` v
72+
73+
describe "lower" $ do
74+
it "extracts the lower triangular part (unit diagonal)" $ do
75+
let m = mkArray @Double [3,3] [1,2,3,4,5,6,7,8,9]
76+
lower m True
77+
`shouldBe` mkArray @Double [3,3] [1,2,3,0,1,6,0,0,1]
78+
it "extracts the lower triangular part (non-unit diagonal)" $ do
79+
let m = mkArray @Double [3,3] [1,2,3,4,5,6,7,8,9]
80+
lower m False
81+
`shouldBe` mkArray @Double [3,3] [1,2,3,0,5,6,0,0,9]
82+
83+
describe "upper" $ do
84+
it "extracts the upper triangular part (unit diagonal)" $ do
85+
let m = mkArray @Double [3,3] [1,2,3,4,5,6,7,8,9]
86+
upper m True
87+
`shouldBe` mkArray @Double [3,3] [1,0,0,4,1,0,7,8,1]
88+
it "extracts the upper triangular part (non-unit diagonal)" $ do
89+
let m = mkArray @Double [3,3] [1,2,3,4,5,6,7,8,9]
90+
upper m False
91+
`shouldBe` mkArray @Double [3,3] [1,0,0,4,5,0,7,8,9]
92+
93+
describe "tile" $ do
94+
it "tiles a scalar into a 3x3 array" $ do
95+
tile (scalar @Int 7) [3,3]
96+
`shouldBe` constant @Int [3,3] 7
97+
it "tiles a row vector along dim 0" $ do
98+
tile (mkArray @Int [1,3] [1,2,3]) [2,1]
99+
`shouldBe` mkArray @Int [2,3] [1,1,2,2,3,3]
100+
101+
describe "moddims" $ do
102+
it "reshapes a vector into a matrix" $ do
103+
moddims (vector @Int 6 [1..6]) [2,3]
104+
`shouldBe` mkArray @Int [2,3] [1,2,3,4,5,6]
105+
it "reshapes a matrix back to a vector" $ do
106+
let v = vector @Int 6 [1..6]
107+
moddims (moddims v [2,3]) [6] `shouldBe` v
108+
109+
describe "flat" $ do
110+
it "flattens a 2x3 matrix to a 6-element vector" $ do
111+
flat (mkArray @Int [2,3] [1,2,3,4,5,6])
112+
`shouldBe` vector @Int 6 [1,2,3,4,5,6]
113+
114+
describe "flip" $ do
115+
it "reverses a vector (dim 0)" $ do
116+
flip (vector @Int 4 [1,2,3,4]) 0
117+
`shouldBe` vector @Int 4 [4,3,2,1]
118+
it "reverses columns of a matrix (dim 1)" $ do
119+
flip (mkArray @Int [2,2] [1,2,3,4]) 1
120+
`shouldBe` mkArray @Int [2,2] [3,4,1,2]
121+
122+
describe "shift" $ do
123+
it "shifts a vector by 2 elements (wrapping)" $ do
124+
shift (vector @Double 4 [1,2,3,4]) 2 0 0 0
125+
`shouldBe` vector @Double 4 [3,4,1,2]
126+
127+
describe "select" $ do
128+
it "selects elements from two arrays based on a boolean mask" $ do
129+
let cond = vector @CBool 4 [1,0,1,0]
130+
a = vector @Double 4 [10,20,30,40]
131+
b = vector @Double 4 [1,2,3,4]
132+
select cond a b `shouldBe` vector @Double 4 [10,2,30,4]
133+
134+
describe "selectScalarR" $ do
135+
it "uses scalar for false positions" $ do
136+
let cond = vector @CBool 4 [1,0,1,0]
137+
a = vector @Double 4 [10,20,30,40]
138+
selectScalarR cond a 99 `shouldBe` vector @Double 4 [10,99,30,99]
139+
140+
describe "selectScalarL" $ do
141+
it "uses scalar for true positions" $ do
142+
let cond = vector @CBool 4 [1,0,1,0]
143+
b = vector @Double 4 [1,2,3,4]
144+
selectScalarL cond 99 b `shouldBe` vector @Double 4 [99,2,99,4]
145+
35146
it "Should join Arrays along the specified dimension" $ do
36147
join 0 (constant @Int [1, 3] 1) (constant @Int [1, 3] 2) `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2]
37148
join 1 (constant @Int [1, 2] 1) (constant @Int [1, 2] 2) `shouldBe` mkArray @Int [1, 4] [1, 1, 2, 2]

test/ArrayFire/IndexSpec.hs

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
module ArrayFire.IndexSpec where
33

44
import qualified ArrayFire as A
5+
import Data.Function ((&))
56
import Test.Hspec
67

78
spec :: Spec
@@ -25,14 +26,14 @@ spec =
2526

2627
describe "lookup" $ do
2728
it "gathers elements by an index array" $ do
28-
let arr = A.vector @Double 5 [10, 20, 30, 40, 50]
29-
idx = A.vector @Int 3 [0, 2, 4]
30-
A.lookup arr idx 0
29+
let arr = A.vector @Double 5 [10, 20, 30, 40, 50]
30+
ixArr = A.vector @Int 3 [0, 2, 4]
31+
A.lookup arr ixArr 0
3132
`shouldBe` A.vector @Double 3 [10, 30, 50]
3233
it "allows repeated indices" $ do
33-
let arr = A.vector @Int 5 [10, 20, 30, 40, 50]
34-
idx = A.vector @Int 4 [0, 0, 4, 4]
35-
A.lookup arr idx 0
34+
let arr = A.vector @Int 5 [10, 20, 30, 40, 50]
35+
ixArr = A.vector @Int 4 [0, 0, 4, 4]
36+
A.lookup arr ixArr 0
3637
`shouldBe` A.vector @Int 4 [10, 10, 50, 50]
3738

3839
describe "assignSeq" $ do
@@ -57,8 +58,6 @@ spec =
5758
A.indexGen arr [A.seqIdx (A.Seq 0 2 1) False]
5859
`shouldBe` A.vector @Double 3 [10, 20, 30]
5960
it "indexes a 2D sub-matrix with two seqIdx" $ do
60-
-- matrix (3,3): columns [[1,2,3],[4,5,6],[7,8,9]]
61-
-- rows 0-1, cols 0-1 → columns [[1,2],[4,5]]
6261
let arr = A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]]
6362
A.indexGen arr [ A.seqIdx (A.Seq 0 1 1) False
6463
, A.seqIdx (A.Seq 0 1 1) False ]
@@ -78,3 +77,35 @@ spec =
7877
A.indexGen result [ A.seqIdx (A.Seq 0 1 1) False
7978
, A.seqIdx (A.Seq 0 1 1) False ]
8079
`shouldBe` src
80+
81+
describe "(!) operator" $ do
82+
it "indexes a 1D sub-range with range" $ do
83+
let arr = A.vector @Double 5 [10, 20, 30, 40, 50]
84+
(arr A.! A.range 0 2)
85+
`shouldBe` A.vector @Double 3 [10, 20, 30]
86+
it "indexes a single element with at" $ do
87+
let arr = A.vector @Double 5 [10, 20, 30, 40, 50]
88+
(arr A.! A.at 2)
89+
`shouldBe` A.scalar @Double 30
90+
it "indexes a 2D sub-matrix with a tuple" $ do
91+
let arr = A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]]
92+
(arr A.! (A.range 0 1, A.range 0 1))
93+
`shouldBe` A.matrix @Double (2,2) [[1,2],[4,5]]
94+
95+
describe "(.~) operator" $ do
96+
it "assigns into a 1D slice" $ do
97+
let arr = A.vector @Double 5 [1..]
98+
src = A.vector @Double 3 [0, 0, 0]
99+
result = arr & A.range 1 3 A..~ src
100+
(result A.! A.range 1 3) `shouldBe` src
101+
it "assigns into a 2D sub-matrix" $ do
102+
let arr = A.matrix @Double (3,3) [[1,2,3],[4,5,6],[7,8,9]]
103+
src = A.matrix @Double (2,2) [[0,0],[0,0]]
104+
result = arr & (A.range 0 1, A.range 0 1) A..~ src
105+
(result A.! (A.range 0 1, A.range 0 1)) `shouldBe` src
106+
107+
describe "rangeStep" $ do
108+
it "selects every other element" $ do
109+
let arr = A.vector @Double 6 [0,1,2,3,4,5]
110+
(arr A.! A.rangeStep 0 4 2)
111+
`shouldBe` A.vector @Double 3 [0,2,4]

0 commit comments

Comments
 (0)