Skip to content

Commit d99615f

Browse files
dmjioclaude
andcommitted
Add fromVector: zero-copy Storable Vector → Array ingestion
Avoids the linked-list traversal and intermediate newArray allocation of mkArray by pinning the vector's buffer and passing it directly to af_create_array. Includes round-trip and dimension-mismatch tests. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 85835c9 commit d99615f

2 files changed

Lines changed: 62 additions & 0 deletions

File tree

src/ArrayFire/Array.hs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,46 @@ mkArray dims xs =
209209

210210
-- af_err af_create_handle(af_array *arr, const unsigned ndims, const dim_t * const dims, const af_dtype type);
211211

212+
-- | Constructs an 'Array' from a 'Storable' 'Vector', avoiding the intermediate list allocation of 'mkArray'.
213+
--
214+
-- The vector's pinned buffer is passed directly to @af_create_array@.
215+
-- Throws 'AFException' if the vector length does not match the product of the given dimensions.
216+
--
217+
-- >>> fromVector @Double [3] (Data.Vector.Storable.fromList [1,2,3])
218+
-- ArrayFire Array
219+
-- [3 1 1 1]
220+
-- 1.0000
221+
-- 2.0000
222+
-- 3.0000
223+
fromVector
224+
:: forall a
225+
. AFType a
226+
=> [Int]
227+
-- ^ Dimensions
228+
-> Vector a
229+
-- ^ Source storable vector
230+
-> Array a
231+
{-# NOINLINE fromVector #-}
232+
fromVector dims vec =
233+
unsafePerformIO . mask_ $ do
234+
let size = Prelude.product dims
235+
ndims = fromIntegral (Prelude.length dims)
236+
dType = afType (Proxy @a)
237+
when (V.length vec /= size) $
238+
throwIO $ AFException SizeError 203 $
239+
"fromVector: dimension product " <> show size <>
240+
" does not match vector length " <> show (V.length vec)
241+
alloca $ \arrayPtr -> do
242+
zeroOutArray arrayPtr
243+
dimsPtr <- newArray (DimT . fromIntegral <$> dims)
244+
onException
245+
(V.unsafeWith vec $ \ptr -> do
246+
throwAFError =<< af_create_array arrayPtr (castPtr ptr) ndims dimsPtr dType
247+
free dimsPtr)
248+
(free dimsPtr)
249+
arr <- peek arrayPtr
250+
Array <$> newForeignPtr af_release_array_finalizer arr
251+
212252
-- | Copies an 'Array' to a new 'Array'
213253
--
214254
-- >>> copyArray (scalar @Double 10)

test/ArrayFire/ArraySpec.hs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,25 @@ spec =
168168
it "length of toVector matches getElements" $ do
169169
let arr = mkArray @Double [7, 13] (repeat 0)
170170
V.length (toVector arr) `shouldBe` getElements arr
171+
172+
describe "fromVector" $ do
173+
it "round-trips a Double vector" $ do
174+
let xs = V.fromList [1..10 :: Double]
175+
arr = fromVector @Double [10] xs
176+
toVector arr `shouldBe` xs
177+
it "round-trips an Int vector" $ do
178+
let xs = V.fromList [1..100 :: Int]
179+
arr = fromVector @Int [100] xs
180+
toVector arr `shouldBe` xs
181+
it "round-trips a Complex Double vector" $ do
182+
let xs = V.fromList [1 :+ 2, 3 :+ 4 :: Complex Double]
183+
arr = fromVector @(Complex Double) [2] xs
184+
toVector arr `shouldBe` xs
185+
it "produces the same result as mkArray" $ do
186+
let xs = [1..25 :: Double]
187+
arr1 = mkArray @Double [5,5] xs
188+
arr2 = fromVector @Double [5,5] (V.fromList xs)
189+
arr2 `shouldBe` arr1
190+
it "throws on dimension mismatch" $ do
191+
let xs = V.fromList [1,2,3 :: Double]
192+
evaluate (fromVector @Double [4] xs) `shouldThrow` anyException

0 commit comments

Comments
 (0)