Skip to content

Commit fcd1563

Browse files
committed
Try with cblas fix
1 parent 44ad7bc commit fcd1563

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

Sources/Matft/library/cblas.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,10 @@ internal func copy_mfarray<T: MfStorable>(_ mfarray: MfArray, dsttmpMfarray: MfA
854854
mfarray.withUnsafeMutableStartPointer(datatype: T.self){
855855
[unowned mfarray] (srcptr) in
856856
for cblasPrams in OptOffsetParamsSequence(shape: dsttmpMfarray.shape, bigger_strides: dsttmpMfarray.strides, smaller_strides: mfarray.strides){
857-
wrap_cblas_copy(cblasPrams.blocksize, srcptr + cblasPrams.s_offset, cblasPrams.s_stride, dstptr + cblasPrams.b_offset, cblasPrams.b_stride, cblas_func)
857+
// Match Accelerate behavior: adjust pointer for negative strides
858+
let srcptr = cblasPrams.s_stride >= 0 ? srcptr + cblasPrams.s_offset : srcptr + (cblasPrams.blocksize - 1) * cblasPrams.s_stride + cblasPrams.s_offset
859+
let dstptr = cblasPrams.b_stride >= 0 ? dstptr + cblasPrams.b_offset : dstptr + (cblasPrams.blocksize - 1) * cblasPrams.b_stride + cblasPrams.b_offset
860+
wrap_cblas_copy(cblasPrams.blocksize, srcptr, cblasPrams.s_stride, dstptr, cblasPrams.b_stride, cblas_func)
858861
}
859862
}
860863
}
@@ -873,7 +876,10 @@ internal func copy_by_cblas<T: MfStorable>(_ src_mfarray: MfArray, _ dst_mfarray
873876
src_mfarray.withUnsafeMutableStartPointer(datatype: T.self){
874877
srcptr in
875878
for cblasPrams in OptOffsetParamsSequence(shape: shape, bigger_strides: bigger_strides, smaller_strides: smaller_strides){
876-
wrap_cblas_copy(cblasPrams.blocksize, srcptr + cblasPrams.s_offset, cblasPrams.s_stride, dstptr + cblasPrams.b_offset, cblasPrams.b_stride, cblas_func)
879+
// Match Accelerate behavior: adjust pointer for negative strides
880+
let srcptr = cblasPrams.s_stride >= 0 ? srcptr + cblasPrams.s_offset : srcptr + (cblasPrams.blocksize - 1) * cblasPrams.s_stride + cblasPrams.s_offset
881+
let dstptr = cblasPrams.b_stride >= 0 ? dstptr + cblasPrams.b_offset : dstptr + (cblasPrams.blocksize - 1) * cblasPrams.b_stride + cblasPrams.b_offset
882+
wrap_cblas_copy(cblasPrams.blocksize, srcptr, cblasPrams.s_stride, dstptr, cblasPrams.b_stride, cblas_func)
877883
}
878884
}
879885
}

0 commit comments

Comments
 (0)