Skip to content

Commit 7eaf4e7

Browse files
committed
Implement Fisher Yates algorithm
1 parent 77a140e commit 7eaf4e7

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

src/DataFrame/Operations/Permutation.hs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,17 @@ import qualified Data.List as L
99
import qualified Data.Text as T
1010
import qualified Data.Vector as V
1111
import qualified Data.Vector.Unboxed as VU
12+
import qualified Data.Vector.Unboxed.Mutable as VUM
1213

1314
import Control.Exception (throw)
15+
import Control.Monad.ST (runST)
1416
import DataFrame.Errors (DataFrameException (..))
15-
import DataFrame.Internal.Column
17+
import DataFrame.Internal.Column (Columnable, atIndicesStable)
1618
import DataFrame.Internal.DataFrame (DataFrame (..))
17-
import DataFrame.Internal.Expression
18-
import DataFrame.Internal.Row
19-
import DataFrame.Operations.Core
20-
import System.Random
21-
import System.Random.Shuffle (shuffle')
19+
import DataFrame.Internal.Expression (Expr (Col))
20+
import DataFrame.Internal.Row (sortedIndexes', toRowVector)
21+
import DataFrame.Operations.Core (columnNames, dimensions)
22+
import System.Random (Random (randomR), RandomGen)
2223

2324
-- | Sort order taken as a parameter by the 'sortBy' function.
2425
data SortOrder where
@@ -76,4 +77,17 @@ shuffle pureGen df =
7677
df{columns = V.map (atIndicesStable indexes) (columns df)}
7778

7879
-- | Produce a random permutation of the indices @0 .. k - 1@ using the
-- Fisher-Yates shuffle.
--
-- The generator is threaded purely, so the same generator and the same @k@
-- always yield the same permutation. A non-positive @k@ yields an empty
-- vector.
shuffledIndices :: (RandomGen g) => g -> Int -> VU.Vector Int
shuffledIndices pureGen k = runST $ do
    vm <- VU.thaw (VU.fromList [0 .. (k - 1)])
    -- Start from the last valid index so that every position takes part in
    -- the shuffle. (Seeding the start index with a random draw would shuffle
    -- only a random-length prefix and leave the rest in order.)
    go vm (k - 1) pureGen
    -- 'VU.unsafeFreeze' is fine here: @vm@ is not used after this point.
    VU.unsafeFreeze vm
  where
    -- @go v maxInd gen@ shuffles the slice @v@ whose last valid index is
    -- @maxInd@: pick a position in @[0, maxInd]@, swap it into slot 0, then
    -- continue on the remaining slice with the advanced generator.
    go v maxInd gen
        -- Zero or one element left (or a negative @k@): nothing to swap.
        | maxInd <= 0 = pure ()
        | otherwise =
            let (n, nextGen) = randomR (0, maxInd) gen
             in VUM.swap v 0 n *> go (VUM.tail v) (maxInd - 1) nextGen

0 commit comments

Comments
 (0)