Expand API: BLAS, reductions, statistics, index ops, bitwise; type & FFI fixes #68
Open
dmjio wants to merge 10 commits into
…gnGen, index type fixes

## New functions

### BLAS: `gemm`

Adds `gemm :: AFType a => MatProp -> MatProp -> a -> Array a -> Array a -> a -> Array a`, the general matrix multiply `C = alpha * op(A) * op(B) + beta * C_prev`. This is more expressive than the existing `matmul`: it supports in-place accumulation and scalar scaling, making it directly useful for iterative eigenvalue algorithms (e.g. Jacobi rotations) that accumulate orthogonal transformations in Q.

Implemented via the C FFI binding `af_gemm`; scalars are passed through `Storable` alloca/poke so any `AFType` element type is supported. Three new unit tests cover identity scaling, alpha-scaling, and transposition.

### Algorithm: key-value (segmented) reductions

Adds nine new functions mirroring ArrayFire's `af_*_by_key` family:

`sumByKey`, `sumByKeyNaN`, `productByKey`, `productByKeyNaN`, `minByKey`, `maxByKey`, `allTrueByKey`, `anyTrueByKey`, `countByKey`

Each takes a keys `Array Int` and a values `Array a`, performs the named reduction over contiguous equal-key runs along a given dimension, and returns `(Array Int, Array a)`. These are essential for sparse tensor contractions that arise in many-body quantum systems and tensor network methods (e.g. grouping indices in an MPO sweep).

A new internal FFI helper `op2p2kv` handles the keys–values two-output calling convention. Because ArrayFire requires the key array to be `s32` (C int) while Haskell uses `Int` (typically `s64`), the helper casts input keys to `s32` before calling the C function and casts the output keys back to `s64`, keeping the Haskell API uniform at `Array Int`.

### Statistics: `meanVar` and `meanVarWeighted`

Adds `meanVar :: AFType a => Array a -> VarBias -> Int -> (Array a, Array a)` and its weighted variant, bound to `af_meanvar`. Computing mean and variance in a single pass is both more accurate and more efficient than calling them separately, which matters for normalisation steps in quantum state tomography and Hamiltonian learning.
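To make the single-pass idea concrete, here is a minimal host-side sketch using plain lists. It uses Welford's update, which is a technique swapped in for illustration; the commit message does not say which algorithm `af_meanvar` uses internally. The names `Bias` and `meanVarRef` are hypothetical and are not part of the library.

```haskell
import Data.List (foldl')

-- | Hypothetical stand-in for the bias choice (sample vs population).
data Bias = Sample | Population
  deriving (Eq, Show)

-- | One-pass mean and variance using Welford's update.
-- This is a reference model on lists, not the ArrayFire implementation.
-- An empty input yields NaN for both results.
meanVarRef :: Bias -> [Double] -> (Double, Double)
meanVarRef bias xs = (mean, m2 / denom)
  where
    -- Accumulator: (count so far, running mean, sum of squared deviations)
    (n, mean, m2) = foldl' step (0 :: Int, 0, 0) xs
    step (k, mu, s) x =
      let k'  = k + 1
          d   = x - mu
          mu' = mu + d / fromIntegral k'
      in (k', mu', s + d * (x - mu'))
    denom = case bias of
      Population -> fromIntegral n
      Sample     -> fromIntegral (n - 1)
```

For the input `[1, 2, 3, 4]` the model gives a mean of 2.5, a population variance of 1.25, and a sample variance of 5/3.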
Introduces the `VarBias` high-level type (`VarianceDefault | VarianceSample | VariancePopulation`) backed by the previously-commented-out `AFVarBias` newtype in `Internal/Defines.hsc` (now uncommented and given a `Storable` instance). `VarBias` and its conversion `fromVarBias` are exported from `ArrayFire.Types`.

### Index: `assignSeq`, `indexGen`, `assignGen`; rename `span` → `afSpan`

Implements three functions that were previously stubs (`error "Not implemented"`):

- `assignSeq :: Array a -> [Seq] -> Array a -> Array a`: write a source array into a sequential slice of a destination array, bound to `af_assign_seq`.
- `indexGen :: Array a -> [Index] -> Array a`: generalised indexing by a list of `Index` values (sequence or array), bound to `af_index_gen`.
- `assignGen :: Array a -> [Index] -> Array a -> Array a`: generalised slice assignment, bound to `af_assign_gen`.

These are needed for constructing sparse interaction terms (e.g. projecting onto a subspace defined by an index set).

`span` is renamed to `afSpan` to avoid shadowing `Prelude.span`, which caused silent import errors in downstream modules.

## Type corrections and bug fixes

### `Index` type redesign (`Internal/Types.hsc`)

The `Index a` type (which parameterised over the array element type) is replaced by a simpler unparameterised sum type:

`data Index = SeqIndex Bool Seq | ArrIndex Bool (Array Int)`

This removes a phantom type parameter that was never meaningful (index arrays are always integral), and fixes the `toAFIndex` implementation, which was using `unsafeForeignPtrToPtr` incorrectly: the old version passed a pointer whose lifetime was not guaranteed by `withForeignPtr`. The new version stores the raw pointer and relies on `touchForeignPtr` calls at the use site to keep the ForeignPtr alive.

The `Storable` peek instance for `AFIndex` also had the `Left`/`Right` branches swapped (`isSeq == True` should produce a sequence, not an array pointer); this is fixed.
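The lifetime issue described for `toAFIndex` is a general `ForeignPtr` pattern. The sketch below is not the PR's code; it only contrasts the two idioms from `base` on a plain `ForeignPtr Int`. The names `readScoped`, `readRaw`, and `demo` are hypothetical.

```haskell
import Foreign.ForeignPtr (ForeignPtr, mallocForeignPtr, touchForeignPtr, withForeignPtr)
import Foreign.ForeignPtr.Unsafe (unsafeForeignPtrToPtr)
import Foreign.Storable (peek, poke)

-- | Scoped access: the raw pointer is only used inside the callback,
-- and withForeignPtr keeps the ForeignPtr alive for that duration.
readScoped :: ForeignPtr Int -> IO Int
readScoped fp = withForeignPtr fp peek

-- | Raw-pointer access: nothing keeps the ForeignPtr alive automatically,
-- so touchForeignPtr is called after the last use of the raw pointer.
readRaw :: ForeignPtr Int -> IO Int
readRaw fp = do
  let p = unsafeForeignPtrToPtr fp
  v <- peek p
  touchForeignPtr fp  -- the finalizer cannot run before this point
  pure v

-- | Writes 42 once, then reads it back through both access styles.
demo :: IO (Int, Int)
demo = do
  fp <- mallocForeignPtr
  withForeignPtr fp (\p -> poke p 42)
  a <- readScoped fp
  b <- readRaw fp
  pure (a, b)
```

The scoped form is usually the safer default; the raw-pointer form is only sound when every use site is followed by a `touchForeignPtr`, which is the discipline the commit message says the new `toAFIndex` relies on.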
### Return types for index-returning operations

`imin`, `imax`, `sortIndex`, and `topk` all return an index array. Their return types are corrected from `(Array a, Array a)` to `(Array a, Array Word32)`, matching ArrayFire's documented `u32` output for index arrays. The corresponding `op2p` helper in `FFI.hs` is generalised from `(Array a, Array a)` to `(Array a, Array b)`.

### `afBackendCpu` constant (`Internal/Defines.hsc`)

Fixed: `afBackendCpu` was mistakenly bound to `AF_BACKEND_DEFAULT` instead of `AF_BACKEND_CPU`.

### `toConnectivity` (`Internal/Types.hsc`)

Fixed: `AFConnectivity 8` was mapped to `Conn4` instead of `Conn8`.

### `histogram` (`Image.hs`)

Removed a spurious `cast` wrapping around the `af_histogram` call; the C function already returns `u32`, so double-casting was wrong.

## FFI infrastructure

### `op1d` removed; `op1` generalised

`op1d :: Array a -> (...) -> Array b` was an alias for `op1` but with the output type fixed to `Array b` (different from input). All call sites that used `op1d` (`not`, `real`, `imag`, `count`) are migrated to `op1`. `op1` itself is generalised from `Array a -> ... -> Array a` to `Array a -> ... -> Array b`, making `op1d` redundant.

### `mask_` added to all `unsafePerformIO` helpers

Every `op*` helper in `FFI.hs` now wraps its `unsafePerformIO` block with `mask_`. Without `mask_`, an asynchronous exception arriving during the FFI call can leave the output `AFArray` pointer uninitialised, producing a segfault or a garbage `ForeignPtr` finalization.

### `af_cast` disambiguation (`Arith.hs`)

`af_cast` is now qualified as `ArrayFire.Internal.Arith.af_cast` at its call site in `cast` because `FFI.hs` also imports the same C symbol (needed for `op2p2kv`), creating an ambiguous occurrence error under GHC 9.10.

## `Num` / `Floating` instance fixes (`Orphans.hs`)

- `negate` is simplified from an allocate-a-zero-constant approach to ``scalar (-1) `mul` arr``, removing a dependency on dimension information.
- `Eq` checks now compare dimensions first before invoking `allTrueAll`, avoiding a broadcast-induced wrong answer when shapes differ.
- `pi` now uses `realToFrac (Prelude.pi :: Double)` instead of the hard-coded literal `3.14159`, gaining full IEEE 754 double precision.
- Added `NFData (Array a)` instance (shallow: evaluates the `ForeignPtr` to WHNF).

## Documentation

- Haddock constructor comments added to all sum types: `Backend`, `MatProp`, `BinaryOp`, `Storage`, `InterpType`, `CSpace`, `YccStd`, `MomentType`, `CannyThreshold`, `FluxFunction`, `DiffusionEq`, `IterativeDeconvAlgo`, `InverseDeconvAlgo`, `Cell`, `ColorMap`, `MarkerType`, `MatchType`, `TopK`, `HomographyType`, and the new `VarBias`.
- Fixed stale parameter documentation in `drawVectorField2d` (previously all four array parameters were labelled "is the window handle").

## Tests

- `AlgorithmSpec`: seven new tests covering all `*ByKey` functions.
- `BLASSpec`: three new tests for `gemm` (identity, alpha-scaling, transpose).
- `IndexSpec`: complete rewrite; `index`, `afSpan`, `lookup`, `assignSeq`, `indexGen`, `assignGen` each covered with multiple cases.
- `LAPACKSpec`: variable names corrected (`s,v,d` → `l,u,piv` / `q,r,tau`); `det` test split into real and complex cases with exact expected values; `inverse`, `rank`, and `norm` tests added.
- `StatisticsSpec`: `topk` index type updated to `Word32`; three new tests for `meanVar` (population, sample) and `meanVarWeighted`.
- `ArraySpec`: placeholder `1+1==2` replaced with a real `Array` addition test.
- `ApproxExpect`: `shouldBeApprox` rewritten to use numpy-compatible `|a-b| <= atol + rtol * max(|a|, |b|)` (rtol=1e-5, atol=1e-8) instead of the fragile scale-and-compare hack; signature now requires `Ord` and is exported cleanly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
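The tolerance rule quoted for `shouldBeApprox` can be written as a small pure predicate. This is a sketch of the stated formula only, not the test helper itself; `approxEq` and `approxEqDefault` are hypothetical names. For comparison, NumPy's own `isclose` scales `rtol` by `|b|` alone, whereas the symmetric `max` form below is the one given in the commit message.

```haskell
-- | Combined absolute/relative tolerance check:
--   |a - b| <= atol + rtol * max(|a|, |b|)
approxEq :: Double -> Double -> Double -> Double -> Bool
approxEq rtol atol a b =
  abs (a - b) <= atol + rtol * max (abs a) (abs b)

-- | The defaults quoted in the commit message (rtol = 1e-5, atol = 1e-8).
approxEqDefault :: Double -> Double -> Bool
approxEqDefault = approxEq 1e-5 1e-8
```

The absolute term `atol` is what lets values near zero compare equal, since the relative term vanishes there; the relative term dominates for large magnitudes.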
gemm, by-key reductions, meanVar, index ops, type fixes
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Keeps the gen tool in sync with the manually-added bindings for by-key reductions, gemm, and meanvar. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
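For readers unfamiliar with the by-key reductions these bindings cover, the semantics described in the first commit message (a reduction over contiguous equal-key runs) can be modelled on plain lists. This is a reference sketch only: `reduceByKey`, `sumByKeyRef`, and `countByKeyRef` are hypothetical names, and the count model assumes that the count reduction counts non-zero values.

```haskell
import Data.Function (on)
import Data.List (groupBy)

-- | Apply a reduction to each contiguous run of equal keys.
-- Keys that reappear after a different key start a new run.
reduceByKey :: Eq k => ([v] -> r) -> [k] -> [v] -> ([k], [r])
reduceByKey f ks vs = unzip
  [ (k, f (map snd run))
  | run@((k, _) : _) <- groupBy ((==) `on` fst) (zip ks vs)
  ]

-- | List model of a by-key sum.
sumByKeyRef :: (Eq k, Num v) => [k] -> [v] -> ([k], [v])
sumByKeyRef = reduceByKey sum

-- | List model of a by-key count (assumption: counts non-zero values).
countByKeyRef :: (Eq k, Eq v, Num v) => [k] -> [v] -> ([k], [Int])
countByKeyRef = reduceByKey (length . filter (/= 0))
```

With keys `[0, 0, 1, 1, 0]` and values `[1, 2, 3, 4, 5]`, `sumByKeyRef` yields keys `[0, 1, 0]` and sums `[3, 7, 5]`: the trailing `0` forms its own run because the reduction is segmented rather than a global group-by.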
- Arith: fix bitAnd/bitOr/bitXor/bitShiftL/bitShiftR to return Array a instead of Array CBool, using op2 instead of op2bool
- Data: add bitNot (bitwise complement via XOR with all-ones array)
- Main: replace unsafePerformIO-based Arbitrary with mkArray, add Scalar newtype for Num laws, expand type coverage to include Complex and 64-bit types, wire in hspec spec
- NumericalSpec: new test module
- AlgorithmSpec, ArithSpec, ArraySpec, LAPACKSpec, SignalSpec, SparseSpec: expanded coverage

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
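The "complement via XOR with all-ones" identity mentioned for `bitNot` can be checked on scalar values with `Data.Bits`. This is a scalar sketch under that identity, not the array implementation; `allOnes`, `bitNotViaXor`, and `example` are hypothetical names.

```haskell
import Data.Bits (Bits, complement, xor, zeroBits)
import Data.Word (Word8)

-- | A value with every bit set, for any Bits instance.
allOnes :: Bits a => a
allOnes = complement zeroBits

-- | Bitwise NOT expressed as XOR with the all-ones value.
bitNotViaXor :: Bits a => a -> a
bitNotViaXor x = x `xor` allOnes

-- | 0x0F xor 0xFF flips every bit, giving 0xF0.
example :: Word8
example = bitNotViaXor 0x0F
```

Applying `bitNotViaXor` twice returns the original value, which is the round-trip property the test plan refers to.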
Avoids the linked-list traversal and intermediate newArray allocation of mkArray by pinning the vector's buffer and passing it directly to af_create_array. Includes round-trip and dimension-mismatch tests. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
gemm, by-key reductions, meanVar, index ops, type fixes

- isZero, isInf, isNaN: Array a -> Array CBool (af_is* always emits u8)
- allTrue, anyTrue: Array a -> Int -> Array CBool (af_all/any_true emits u8)
- where': Array a -> Array Word32 (af_where emits u32 indices)
- cplx, cplx2, cplx2Batched: return Array (Complex a), not Array a
- real, imag: simplified to (RealFloat a, AFType a, AFType (Complex a)) => Array (Complex a) -> Array a; previous signature was unlinked (a, b)
- Update tests to match corrected return types

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
sign(-x) - sign(x) broke for two reasons:

- Unsigned types (CBool, Word32): negate wraps (e.g. -1_u8 = 255), making sign(-x) = 0 for all positive inputs, so signum always returns 0
- Float zero: af_sign(-0.0) = 1 due to sign-bit check, giving signum(0.0) = 1

Replace with cast(gt x 0) - cast(lt x 0), which avoids negate entirely and correctly handles unsigned types and IEEE 754 negative zero.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
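Both failure modes can be reproduced on scalars. The sketch below models `af_sign` as described in this commit message (1 when the sign bit is set, 0 otherwise); that model and all names (`signBitF`, `signBitU`, `oldSignumF`, `oldSignumU`, `newSignum`) are assumptions for illustration, not library code.

```haskell
import Data.Word (Word8)

-- | Model of a sign-bit check on floats: 1 if the sign bit is set, else 0.
-- Negative zero has its sign bit set, so it reports 1.
signBitF :: Double -> Double
signBitF x = if x < 0 || isNegativeZero x then 1 else 0

-- | Unsigned values have no sign bit, so the check is always 0.
signBitU :: Word8 -> Word8
signBitU _ = 0

-- | Old formulation on floats: sign(-x) - sign(x).
oldSignumF :: Double -> Double
oldSignumF x = signBitF (negate x) - signBitF x

-- | Old formulation on an unsigned type: negate wraps, and both terms are 0.
oldSignumU :: Word8 -> Word8
oldSignumU x = signBitU (negate x) - signBitU x

-- | New formulation: (x > 0) - (x < 0), with the comparisons cast to numbers.
newSignum :: (Ord a, Num a) => a -> a
newSignum x =
  fromIntegral (fromEnum (x > 0)) - fromIntegral (fromEnum (x < 0))
```

In this model `oldSignumF 0.0` evaluates to 1 and `oldSignumU 5` evaluates to 0, matching the two bugs listed above, while `newSignum` gives 0 for both zeros and 1 for any positive unsigned input.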
## Summary

This PR adds several new functions, fixes type errors and bugs, hardens the FFI layer, and expands test coverage.

### New API surface

- `gemm` (BLAS): General matrix multiply `C = α·op(A)·op(B) + β·C`, bound to `af_gemm`. Useful for iterative eigenvalue algorithms (Jacobi rotations, power iteration) where accumulated orthogonal transformations need scaling.
- `sumByKey`, `sumByKeyNaN`, `productByKey`, `productByKeyNaN`, `minByKey`, `maxByKey`, `allTrueByKey`, `anyTrueByKey`, `countByKey`: all bound to their `af_*_by_key` C counterparts. These enable sparse tensor contractions and grouped reductions needed for MPO sweeps in tensor network methods.
- `meanVar` / `meanVarWeighted` (Statistics): simultaneous mean+variance in one pass via `af_meanvar`. Introduces the `VarBias` type (`VarianceDefault | VarianceSample | VariancePopulation`).
- `assignSeq`, `indexGen`, `assignGen` (Index): three functions that were previously `error "Not implemented"` stubs, now fully implemented via `af_assign_seq`, `af_index_gen`, `af_assign_gen`.
- `bitNot` (Arith): bitwise complement via `af_bitnot`, completing the bitwise operator set alongside `bitAnd`, `bitOr`, `bitXor`, `bitShiftL`, `bitShiftR`.
- `fromVector` (Data): zero-copy ingestion of a `Storable` `Vector` into an `Array` via `af_create_array`. Avoids an intermediate list conversion and is the idiomatic path when data is already in a pinned `Vector`.

### Type corrections and bug fixes

- `imin`, `imax`, `sortIndex`, `topk`: index output changed from `Array a` to `Array Word32` (matching ArrayFire's `u32` contract).
- `bitAnd`, `bitOr`, `bitXor`, `bitShiftL`, `bitShiftR`: return types corrected from `Array b` to `Array a` (same element type as inputs, as required by ArrayFire).
- `afBackendCpu` was bound to `AF_BACKEND_DEFAULT` instead of `AF_BACKEND_CPU`.
- `toConnectivity`: `AFConnectivity 8` mapped to `Conn4` instead of `Conn8`.
- `AFIndex` `Storable` peek: `Left`/`Right` branches were swapped (seq vs array pointer).
- `histogram`: spurious double-`cast` removed.
- `span` renamed to `afSpan` to stop shadowing `Prelude.span`.
- `op1` generalised from `Array a -> ... -> Array a` to `Array a -> ... -> Array b`; `op1d` removed.
- `op2p` return type generalised to `(Array a, Array b)`.
- `af_cast` qualified in `Arith.hs` to resolve GHC 9.10 ambiguous occurrence error.

### FFI hardening

- `include/` headers: `af_bitnot`, `af_create_array`, and other functions missing from the C binding layer.
- `unsafePerformIO` helpers in `FFI.hs` now use `mask_` to prevent async exceptions from leaving output pointers uninitialised.
- `op2p2kv` added for the key-value two-output calling convention (handles `Int` ↔ `s32`/`s64` casting transparently).

### `Num`/`Floating` fixes (`Orphans.hs`)

- `negate` simplified to ``scalar (-1) `mul` arr``.
- `Eq` checks dimension-guards before broadcasting.
- `pi` uses full IEEE 754 precision via `realToFrac Prelude.pi`.
- `NFData (Array a)` instance added.

### Documentation

Haddock constructor comments added to all major sum types in `Internal/Types.hsc`. Fixed stale parameter docs in `drawVectorField2d`.

### Tests

Full test coverage added or corrected for all new and fixed functions. `shouldBeApprox` rewritten to use numpy-compatible tolerances (`rtol=1e-5`, `atol=1e-8`).

## Test plan

- `cabal test` passes (Algorithm, BLAS, Index, LAPACK, Statistics specs)
- `gemm` tests cover identity, alpha-scaling, and transpose cases
- `*ByKey` tests cover sum, product, min, max, count, allTrue, anyTrue
- `meanVar` tests cover population variance, sample variance, and weighted variant
- `assignSeq` / `indexGen` / `assignGen` tests cover 1D and 2D cases
- `topk` and `imin`/`imax` index outputs are now correctly typed as `Word32`
- `bitNot` tests cover round-trip complement identity
- `fromVector` tests cover 1D ingestion and round-trip equality with `toVector`

🤖 Generated with Claude Code
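As a reference for the three `gemm` cases in the test plan, the formula `C = α·op(A)·op(B) + β·C` can be modelled on lists of rows. This is a sketch of the arithmetic only, not the binding: `Op`, `Matrix`, `applyOp`, `matMul`, and `gemmRef` are hypothetical names, and the model takes the previous `C` as an explicit argument for illustration, whereas the `gemm` signature quoted in the commit message has no separate `C` parameter.

```haskell
import Data.List (transpose)

-- | Hypothetical stand-in for the transpose/no-transpose property.
data Op = NoTrans | Trans
  deriving (Eq, Show)

-- | Row-major matrix as a list of rows.
type Matrix = [[Double]]

applyOp :: Op -> Matrix -> Matrix
applyOp NoTrans = id
applyOp Trans   = transpose

-- | Plain matrix product; assumes the inner dimensions agree.
matMul :: Matrix -> Matrix -> Matrix
matMul a b = [ [ sum (zipWith (*) row col) | col <- transpose b ] | row <- a ]

-- | Reference model of C = alpha * op(A) * op(B) + beta * C.
gemmRef :: Op -> Op -> Double -> Matrix -> Matrix -> Double -> Matrix -> Matrix
gemmRef opA opB alpha a b beta c =
  zipWith (zipWith combine) (matMul (applyOp opA a) (applyOp opB b)) c
  where
    combine p q = alpha * p + beta * q
```

With `A = [[1,2],[3,4]]`, the identity for `B`, and `β = 0`, the model returns `A` for `α = 1`, `2·A` for `α = 2`, and the transpose of `A` when `op(A)` is `Trans`.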