Skip to content

Commit 888be21

Browse files
dmjioclaude
andcommitted
Fix gemm API, add tests for bitNot and complex number functions.
- Remove dead `beta` parameter from `gemm`: the C binding always starts with a null C array, so beta*C_prev was silently a no-op. Beta memory is now zero-filled internally. - Add tests for `bitNot`: complement of 0/-1 for Int32/Word32, and round-trip identity. - Add tests for `cplx`, `cplx2`, `real`, `imag`: scalar/vector construction, extraction, and the round-trip property `cplx2 (real c) (imag c) == c`. - Add non-trivial gemm test (A*B with known exact result). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 6907d0f commit 888be21

4 files changed

Lines changed: 75 additions & 21 deletions

File tree

src/ArrayFire/BLAS.hs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE ScopedTypeVariables #-}
12
{-# LANGUAGE ViewPatterns #-}
23
--------------------------------------------------------------------------------
34
-- |
@@ -35,8 +36,9 @@ import Control.Exception (mask_)
3536
import Data.Complex
3637
import Foreign.ForeignPtr (newForeignPtr, withForeignPtr)
3738
import Foreign.Marshal.Alloc (alloca)
38-
import Foreign.Ptr (castPtr)
39-
import Foreign.Storable (peek, poke)
39+
import Foreign.Marshal.Utils (fillBytes)
40+
import Foreign.Ptr (Ptr, castPtr)
41+
import Foreign.Storable (peek, poke, sizeOf)
4042
import System.IO.Unsafe (unsafePerformIO)
4143

4244
import ArrayFire.Exception
@@ -175,18 +177,18 @@ transposeInPlace
175177
transposeInPlace arr (fromIntegral . fromEnum -> b) =
176178
arr `inPlace` (`af_transpose_inplace` b)
177179

178-
-- | General Matrix Multiply: C = alpha * op(A) * op(B) + beta * C_prev
180+
-- | General Matrix Multiply: C = alpha * op(A) * op(B)
179181
--
180-
-- More general than 'matmul': supports scaling and accumulation.
181-
-- When @beta = 0@, equivalent to @alpha * op(A) * op(B)@.
182+
-- More general than 'matmul': supports per-element scaling and optional
183+
-- transposition via 'MatProp'.
182184
--
183-
-- >>> gemm None None 1.0 (matrix @Double (2,2) [[1,0],[0,1]]) (matrix @Double (2,2) [[3,4],[5,6]]) 0.0
185+
-- >>> gemm None None 1.0 (matrix @Double (2,2) [[1,0],[0,1]]) (matrix @Double (2,2) [[3,4],[5,6]])
184186
-- ArrayFire Array
185187
-- [2 2 1 1]
186188
-- 3.0000 5.0000
187189
-- 4.0000 6.0000
188190
gemm
189-
:: AFType a
191+
:: forall a . AFType a
190192
=> MatProp
191193
-- ^ Transformation applied to A ('None', 'Trans', or 'CTrans')
192194
-> MatProp
@@ -197,20 +199,18 @@ gemm
197199
-- ^ Matrix A
198200
-> Array a
199201
-- ^ Matrix B
200-
-> a
201-
-- ^ Scalar beta (use 0 for pure multiply)
202202
-> Array a
203-
-- ^ Result C = alpha * op(A) * op(B) + beta * C_prev
204-
gemm opA opB alpha (Array fptrA) (Array fptrB) beta =
203+
-- ^ Result C = alpha * op(A) * op(B)
204+
gemm opA opB alpha (Array fptrA) (Array fptrB) =
205205
unsafePerformIO . mask_ $
206206
withForeignPtr fptrA $ \ptrA ->
207207
withForeignPtr fptrB $ \ptrB ->
208208
alloca $ \pOut ->
209209
alloca $ \pAlpha ->
210-
alloca $ \pBeta -> do
210+
alloca $ \(pBeta :: Ptr a) -> do
211211
zeroOutArray pOut
212212
poke pAlpha alpha
213-
poke pBeta beta
213+
fillBytes pBeta 0 (sizeOf alpha)
214214
throwAFError =<< af_gemm pOut (toMatProp opA) (toMatProp opB) (castPtr pAlpha) ptrA ptrB (castPtr pBeta)
215215
Array <$> (newForeignPtr af_release_array_finalizer =<< peek pOut)
216216
{-# NOINLINE gemm #-}

test/ArrayFire/ArithSpec.hs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
module ArrayFire.ArithSpec where
66

7-
import ArrayFire (AFType, Array, cast, clamp, getType, isInf, isZero, matrix, maxOf, minOf, mkArray, scalar, vector)
7+
import ArrayFire (AFType, Array, cast, clamp, cplx, cplx2, getType, imag, isInf, isZero, matrix, maxOf, minOf, mkArray, real, scalar, vector)
88
import qualified ArrayFire
9+
import Data.Complex (Complex (..))
910
import Control.Exception (throwIO)
1011
import Control.Monad (unless, when)
1112
import Foreign.C
@@ -226,3 +227,37 @@ spec =
226227
evalf (signum (scalar @Double (-2.5))) `shouldBeApprox` (-1)
227228
it "signum vector" $
228229
signum (vector @Int 3 [-4, 0, 7]) `shouldBe` vector @Int 3 [-1, 0, 1]
230+
231+
describe "cplx" $ do
232+
it "lifts a real scalar to complex with zero imaginary part" $
233+
cplx (scalar @Double 5.0) `shouldBe` scalar @(Complex Double) (5.0 :+ 0.0)
234+
it "real . cplx == id on a vector" $ do
235+
let v = vector @Double 4 [1, 2, 3, 4]
236+
(real (cplx v) :: Array Double) `shouldBe` v
237+
it "imag . cplx == 0 on a vector" $ do
238+
let v = vector @Double 4 [1, 2, 3, 4]
239+
ArrayFire.toList (imag (cplx v) :: Array Double) `shouldBe` [0, 0, 0, 0]
240+
241+
describe "cplx2" $ do
242+
it "combines real and imaginary parts into a complex scalar" $
243+
cplx2 (scalar @Double 3.0) (scalar @Double 4.0)
244+
`shouldBe` scalar @(Complex Double) (3.0 :+ 4.0)
245+
it "real . cplx2 r i == r" $ do
246+
let r = vector @Double 3 [1, 2, 3]
247+
i = vector @Double 3 [4, 5, 6]
248+
(real (cplx2 r i) :: Array Double) `shouldBe` r
249+
it "imag . cplx2 r i == i" $ do
250+
let r = vector @Double 3 [1, 2, 3]
251+
i = vector @Double 3 [4, 5, 6]
252+
(imag (cplx2 r i) :: Array Double) `shouldBe` i
253+
254+
describe "real / imag" $ do
255+
it "real extracts the real part of a complex scalar" $
256+
(real (scalar @(Complex Double) (7.0 :+ 3.0)) :: Array Double)
257+
`shouldBe` scalar @Double 7.0
258+
it "imag extracts the imaginary part of a complex scalar" $
259+
(imag (scalar @(Complex Double) (7.0 :+ 3.0)) :: Array Double)
260+
`shouldBe` scalar @Double 3.0
261+
it "real and imag round-trip via cplx2" $ do
262+
let c = vector @(Complex Double) 3 [1:+2, 3:+4, 5:+6]
263+
cplx2 (real c :: Array Double) (imag c :: Array Double) `shouldBe` c

test/ArrayFire/BLASSpec.hs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,25 @@ spec =
2828
let m = matrix @Double (2,2) [[1,1],[2,2]]
2929
transposeInPlace m False
3030
m `shouldBe` matrix @Double (2,2) [[1,2],[1,2]]
31-
it "Should perform gemm: C = 1*A*B + 0*C (identity scaling)" $ do
31+
it "Should perform gemm: alpha=1, A*I = A" $ do
3232
let a = matrix @Double (2,2) [[1,2],[3,4]]
3333
b = matrix @Double (2,2) [[1,0],[0,1]]
34-
gemm None None 1.0 a b 0.0 `shouldBe` a
35-
it "Should perform gemm: C = alpha*A*B with alpha=2" $ do
36-
-- b is column-major: col0=[3,4], col1=[5,6] → matrix [[3,5],[4,6]]
34+
gemm None None 1.0 a b `shouldBe` a
35+
it "Should perform gemm: alpha=2 scales the result" $ do
36+
-- b col-major: col0=[3,4], col1=[5,6]
3737
-- 2 * I * b = 2b → col0=[6,8], col1=[10,12]
3838
let a = matrix @Double (2,2) [[1,0],[0,1]]
3939
b = matrix @Double (2,2) [[3,4],[5,6]]
40-
gemm None None 2.0 a b 0.0 `shouldBe` matrix @Double (2,2) [[6,8],[10,12]]
41-
it "Should perform gemm with transposed A: C = A^T * B" $ do
40+
gemm None None 2.0 a b `shouldBe` matrix @Double (2,2) [[6,8],[10,12]]
41+
it "Should perform gemm with transposed A" $ do
4242
let a = matrix @Double (2,2) [[1,3],[2,4]]
4343
b = matrix @Double (2,2) [[1,0],[0,1]]
44-
gemm Trans None 1.0 a b 0.0 `shouldBe` matrix @Double (2,2) [[1,2],[3,4]]
44+
gemm Trans None 1.0 a b `shouldBe` matrix @Double (2,2) [[1,2],[3,4]]
45+
it "Should perform gemm: non-trivial A*B" $ do
46+
-- matrix (2,2) [[c0r0,c0r1],[c1r0,c1r1]] is column-major.
47+
-- A = [[1,3],[2,4]], B = [[5,7],[6,8]] (rows displayed by ArrayFire)
48+
-- A*B col0 = [1*5+3*6, 2*5+4*6] = [23,34]
49+
-- A*B col1 = [1*7+3*8, 2*7+4*8] = [31,46]
50+
let a = matrix @Double (2,2) [[1,2],[3,4]]
51+
b = matrix @Double (2,2) [[5,6],[7,8]]
52+
gemm None None 1.0 a b `shouldBe` matrix @Double (2,2) [[23,34],[31,46]]

test/ArrayFire/DataSpec.hs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,14 @@ spec =
148148
join 1 (constant @Int [1, 2] 1) (constant @Int [1, 2] 2) `shouldBe` mkArray @Int [1, 4] [1, 1, 2, 2]
149149
joinMany 0 [constant @Int [1, 3] 1, constant @Int [1, 3] 2] `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2]
150150
joinMany 1 [constant @Int [1, 2] 1, constant @Int [1, 1] 2, constant @Int [1, 3] 3] `shouldBe` mkArray @Int [1, 6] [1, 1, 2, 3, 3, 3]
151+
152+
describe "bitNot" $ do
153+
it "complements 0 to all-ones (-1 in two's complement) for Int32" $ do
154+
bitNot (scalar @Int32 0) `shouldBe` scalar @Int32 (-1)
155+
it "complements -1 to 0 for Int32" $ do
156+
bitNot (scalar @Int32 (-1)) `shouldBe` scalar @Int32 0
157+
it "complements 0 to maxBound for Word32" $ do
158+
bitNot (scalar @Word32 0) `shouldBe` scalar @Word32 maxBound
159+
it "bitNot . bitNot == id" $ do
160+
let v = vector @Int32 4 [0, 1, -1, 42]
161+
bitNot (bitNot v) `shouldBe` v

0 commit comments

Comments
 (0)