Skip to content

Commit 85835c9

Browse files
dmjioclaude
andcommitted
Fix bitwise op return types, add bitNot, expand test coverage
- Arith: fix bitAnd/bitOr/bitXor/bitShiftL/bitShiftR to return Array a instead of Array CBool, using op2 instead of op2bool - Data: add bitNot (bitwise complement via XOR with all-ones array) - Main: replace unsafePerformIO-based Arbitrary with mkArray, add Scalar newtype for Num laws, expand type coverage to include Complex and 64-bit types, wire in hspec spec - NumericalSpec: new test module - AlgorithmSpec, ArithSpec, ArraySpec, LAPACKSpec, SignalSpec, SparseSpec: expanded coverage Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 4a1e18c commit 85835c9

14 files changed

Lines changed: 601 additions & 112 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ result/
77
cabal.project.local
88
tags
99
/.stack-work/
10+
/.ghc.environment*

arrayfire.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ test-suite test
177177
ArrayFire.ImageSpec
178178
ArrayFire.IndexSpec
179179
ArrayFire.LAPACKSpec
180+
ArrayFire.NumericalSpec
180181
ArrayFire.RandomSpec
181182
ArrayFire.SignalSpec
182183
ArrayFire.SparseSpec

src/ArrayFire/Arith.hs

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -526,10 +526,10 @@ bitAnd
526526
-- ^ First input
527527
-> Array a
528528
-- ^ Second input
529-
-> Array CBool
529+
-> Array a
530530
-- ^ Result of bitwise and
531531
bitAnd x y =
532-
x `op2bool` y $ \arr arr1 arr2 ->
532+
x `op2` y $ \arr arr1 arr2 ->
533533
af_bitand arr arr1 arr2 1
534534

535535
-- | Bitwise and the values in one 'Array' against another 'Array'
@@ -546,10 +546,10 @@ bitAndBatched
546546
-- ^ Second input
547547
-> Bool
548548
-- ^ Use batch
549-
-> Array CBool
549+
-> Array a
550550
-- ^ Result of bitwise and
551551
bitAndBatched x y (fromIntegral . fromEnum -> batch) = do
552-
x `op2bool` y $ \arr arr1 arr2 ->
552+
x `op2` y $ \arr arr1 arr2 ->
553553
af_bitand arr arr1 arr2 batch
554554

555555
-- | Bitwise or the values in one 'Array' against another 'Array'
@@ -564,10 +564,10 @@ bitOr
564564
-- ^ First input
565565
-> Array a
566566
-- ^ Second input
567-
-> Array CBool
568-
-- ^ Result of bit or
567+
-> Array a
568+
-- ^ Result of bitwise or
569569
bitOr x y = do
570-
x `op2bool` y $ \arr arr1 arr2 ->
570+
x `op2` y $ \arr arr1 arr2 ->
571571
af_bitor arr arr1 arr2 1
572572

573573
-- | Bitwise or the values in one 'Array' against another 'Array'
@@ -584,10 +584,10 @@ bitOrBatched
584584
-- ^ Second input
585585
-> Bool
586586
-- ^ Use batch
587-
-> Array CBool
588-
-- ^ Result of bit or
587+
-> Array a
588+
-- ^ Result of bitwise or
589589
bitOrBatched x y (fromIntegral . fromEnum -> batch) = do
590-
x `op2bool` y $ \arr arr1 arr2 ->
590+
x `op2` y $ \arr arr1 arr2 ->
591591
af_bitor arr arr1 arr2 batch
592592

593593
-- | Bitwise xor the values in one 'Array' against another 'Array'
@@ -602,10 +602,10 @@ bitXor
602602
-- ^ First input
603603
-> Array a
604604
-- ^ Second input
605-
-> Array CBool
606-
-- ^ Result of bit xor
605+
-> Array a
606+
-- ^ Result of bitwise xor
607607
bitXor x y = do
608-
x `op2bool` y $ \arr arr1 arr2 ->
608+
x `op2` y $ \arr arr1 arr2 ->
609609
af_bitxor arr arr1 arr2 1
610610

611611
-- | Bitwise xor the values in one 'Array' against another 'Array'
@@ -622,10 +622,10 @@ bitXorBatched
622622
-- ^ Second input
623623
-> Bool
624624
-- ^ Use batch
625-
-> Array CBool
626-
-- ^ Result of bit xor
625+
-> Array a
626+
-- ^ Result of bitwise xor
627627
bitXorBatched x y (fromIntegral . fromEnum -> batch) = do
628-
x `op2bool` y $ \arr arr1 arr2 ->
628+
x `op2` y $ \arr arr1 arr2 ->
629629
af_bitxor arr arr1 arr2 batch
630630

631631
-- | Left bit shift the values in one 'Array' against another 'Array'
@@ -640,10 +640,10 @@ bitShiftL
640640
-- ^ First input
641641
-> Array a
642642
-- ^ Second input
643-
-> Array CBool
643+
-> Array a
644644
-- ^ Result of bit shift left
645645
bitShiftL x y =
646-
x `op2bool` y $ \arr arr1 arr2 ->
646+
x `op2` y $ \arr arr1 arr2 ->
647647
af_bitshiftl arr arr1 arr2 1
648648

649649
-- | Left bit shift the values in one 'Array' against another 'Array'
@@ -660,10 +660,10 @@ bitShiftLBatched
660660
-- ^ Second input
661661
-> Bool
662662
-- ^ Use batch
663-
-> Array CBool
663+
-> Array a
664664
-- ^ Result of bit shift left
665665
bitShiftLBatched x y (fromIntegral . fromEnum -> batch) = do
666-
x `op2bool` y $ \arr arr1 arr2 ->
666+
x `op2` y $ \arr arr1 arr2 ->
667667
af_bitshiftl arr arr1 arr2 batch
668668

669669
-- | Right bit shift the values in one 'Array' against another 'Array'
@@ -678,10 +678,10 @@ bitShiftR
678678
-- ^ First input
679679
-> Array a
680680
-- ^ Second input
681-
-> Array CBool
681+
-> Array a
682682
-- ^ Result of bit shift right
683683
bitShiftR x y =
684-
x `op2bool` y $ \arr arr1 arr2 ->
684+
x `op2` y $ \arr arr1 arr2 ->
685685
af_bitshiftr arr arr1 arr2 1
686686

687687
-- | Right bit shift the values in one 'Array' against another 'Array'
@@ -698,10 +698,10 @@ bitShiftRBatched
698698
-- ^ Second input
699699
-> Bool
700700
-- ^ Use batch
701-
-> Array CBool
702-
-- ^ Result of bit shift left
701+
-> Array a
702+
-- ^ Result of bit shift right
703703
bitShiftRBatched x y (fromIntegral . fromEnum -> batch) = do
704-
x `op2bool` y $ \arr arr1 arr2 ->
704+
x `op2` y $ \arr arr1 arr2 ->
705705
af_bitshiftr arr arr1 arr2 batch
706706

707707
-- | Cast one 'Array' into another

src/ArrayFire/Array.hs

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -177,21 +177,30 @@ mkArray
177177
-- ^ Returned array
178178
{-# NOINLINE mkArray #-}
179179
mkArray dims xs =
180-
unsafePerformIO $ do
181-
when (Prelude.length (take size xs) < size) $ do
182-
let msg = "Invalid elements provided. "
183-
<> "Expected "
184-
<> show size
185-
<> " elements received "
186-
<> show (Prelude.length xs)
187-
throwIO (AFException SizeError 203 msg)
188-
dataPtr <- castPtr <$> newArray (Prelude.take size xs)
180+
unsafePerformIO . mask_ $ do
189181
let ndims = fromIntegral (Prelude.length dims)
190182
alloca $ \arrayPtr -> do
191183
zeroOutArray arrayPtr
192184
dimsPtr <- newArray (DimT . fromIntegral <$> dims)
193-
throwAFError =<< af_create_array arrayPtr dataPtr ndims dimsPtr dType
194-
free dataPtr >> free dimsPtr
185+
if size == 0
186+
then onException
187+
(do throwAFError =<< af_create_handle arrayPtr ndims dimsPtr dType
188+
free dimsPtr)
189+
(free dimsPtr)
190+
else do
191+
when (Prelude.length (Prelude.take size xs) < size) $ do
192+
free dimsPtr
193+
let msg = "Invalid elements provided. "
194+
<> "Expected "
195+
<> show size
196+
<> " elements received "
197+
<> show (Prelude.length xs)
198+
throwIO (AFException SizeError 203 msg)
199+
dataPtr <- castPtr <$> newArray (Prelude.take size xs)
200+
onException
201+
(do throwAFError =<< af_create_array arrayPtr dataPtr ndims dimsPtr dType
202+
free dataPtr >> free dimsPtr)
203+
(free dataPtr >> free dimsPtr)
195204
arr <- peek arrayPtr
196205
Array <$> newForeignPtr af_release_array_finalizer arr
197206
where
@@ -484,7 +493,7 @@ toVector arr@(Array fptr) =
484493
unsafePerformIO . mask_ . withForeignPtr fptr $ \arrPtr -> do
485494
let len = getElements arr
486495
size = len * getSizeOf (Proxy @a)
487-
ptr <- mallocBytes (len * size)
496+
ptr <- mallocBytes size
488497
throwAFError =<< af_get_data_ptr (castPtr ptr) arrPtr
489498
newFptr <- newForeignPtr finalizerFree ptr
490499
pure $ unsafeFromForeignPtr0 newFptr len

src/ArrayFire/Data.hs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,37 @@ import Foreign.Storable
4242
import System.IO.Unsafe
4343
import Unsafe.Coerce
4444

45+
import Data.Bits
46+
4547
import ArrayFire.Exception
4648
import ArrayFire.FFI
49+
import ArrayFire.Internal.Array (af_get_dims)
4750
import ArrayFire.Internal.Data
4851
import ArrayFire.Internal.Defines
4952
import ArrayFire.Internal.Types
5053
import ArrayFire.Arith
5154

55+
-- | Bitwise complement of every element in an 'Array'
56+
--
57+
-- >>> A.bitNot (A.scalar @Int32 0)
58+
-- ArrayFire Array
59+
-- [1 1 1 1]
60+
-- -1
61+
bitNot
62+
:: (AFType a, Bits a)
63+
=> Array a
64+
-> Array a
65+
bitNot arr = arr `bitXor` ones
66+
where
67+
(d0, d1, d2, d3) = arr `infoFromArray4` af_get_dims
68+
ones = constant
69+
[ fromIntegral d0
70+
, fromIntegral d1
71+
, fromIntegral d2
72+
, fromIntegral d3
73+
]
74+
(complement zeroBits)
75+
5276
-- | Creates an 'Array' from a scalar value from given dimensions
5377
--
5478
-- >>> constant @Double [2,2] 2.0

src/ArrayFire/FFI.hs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,16 @@ op2p2kv (Array fptr1) (Array fptr2) op =
201201
peek p
202202
alloca $ \ptrOutput1 ->
203203
alloca $ \ptrOutput2 -> do
204-
throwAFError =<< op ptrOutput1 ptrOutput2 castedKey ptr2
204+
onException
205+
(throwAFError =<< op ptrOutput1 ptrOutput2 castedKey ptr2)
206+
(af_release_array_ffi castedKey)
205207
_ <- af_release_array_ffi castedKey
206208
outKey <- peek ptrOutput1
207209
outVal <- peek ptrOutput2
208210
finalKey <- alloca $ \p -> do
209-
throwAFError =<< af_cast p outKey s64
211+
onException
212+
(throwAFError =<< af_cast p outKey s64)
213+
(af_release_array_ffi outKey)
210214
peek p
211215
_ <- af_release_array_ffi outKey
212216
pure (finalKey, outVal)
@@ -415,7 +419,7 @@ infoFromFeatures
415419
-> a
416420
{-# NOINLINE infoFromFeatures #-}
417421
infoFromFeatures (Features fptr1) op =
418-
unsafePerformIO $ do
422+
unsafePerformIO . mask_ $ do
419423
withForeignPtr fptr1 $ \ptr1 -> do
420424
alloca $ \ptrInput -> do
421425
throwAFError =<< op ptrInput ptr1
@@ -450,7 +454,7 @@ infoFromArray
450454
-> a
451455
{-# NOINLINE infoFromArray #-}
452456
infoFromArray (Array fptr1) op =
453-
unsafePerformIO $ do
457+
unsafePerformIO . mask_ $ do
454458
withForeignPtr fptr1 $ \ptr1 -> do
455459
alloca $ \ptrInput -> do
456460
throwAFError =<< op ptrInput ptr1
@@ -463,7 +467,7 @@ infoFromArray2
463467
-> (a,b)
464468
{-# NOINLINE infoFromArray2 #-}
465469
infoFromArray2 (Array fptr1) op =
466-
unsafePerformIO $ do
470+
unsafePerformIO . mask_ $ do
467471
withForeignPtr fptr1 $ \ptr1 -> do
468472
alloca $ \ptrInput1 -> do
469473
alloca $ \ptrInput2 -> do
@@ -478,7 +482,7 @@ infoFromArray22
478482
-> (a,b)
479483
{-# NOINLINE infoFromArray22 #-}
480484
infoFromArray22 (Array fptr1) (Array fptr2) op =
481-
unsafePerformIO $ do
485+
unsafePerformIO . mask_ $ do
482486
withForeignPtr fptr1 $ \ptr1 -> do
483487
withForeignPtr fptr2 $ \ptr2 -> do
484488
alloca $ \ptrInput1 -> do
@@ -493,7 +497,7 @@ infoFromArray3
493497
-> (a,b,c)
494498
{-# NOINLINE infoFromArray3 #-}
495499
infoFromArray3 (Array fptr1) op =
496-
unsafePerformIO $
500+
unsafePerformIO . mask_ $
497501
withForeignPtr fptr1 $ \ptr1 -> do
498502
alloca $ \ptrInput1 -> do
499503
alloca $ \ptrInput2 -> do
@@ -510,7 +514,7 @@ infoFromArray4
510514
-> (a,b,c,d)
511515
{-# NOINLINE infoFromArray4 #-}
512516
infoFromArray4 (Array fptr1) op =
513-
unsafePerformIO $
517+
unsafePerformIO . mask_ $
514518
withForeignPtr fptr1 $ \ptr1 ->
515519
alloca $ \ptrInput1 ->
516520
alloca $ \ptrInput2 ->

0 commit comments

Comments
 (0)