Skip to content

Commit 402f667

Browse files
committed
Implement Fisher Yates algorithm
1 parent 77a140e commit 402f667

1 file changed

Lines changed: 23 additions & 8 deletions

File tree

src/DataFrame/Operations/Permutation.hs

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@ import qualified Data.List as L
99
import qualified Data.Text as T
1010
import qualified Data.Vector as V
1111
import qualified Data.Vector.Unboxed as VU
12+
import qualified Data.Vector.Unboxed.Mutable as VUM
1213

1314
import Control.Exception (throw)
15+
import Control.Monad.ST (runST)
1416
import DataFrame.Errors (DataFrameException (..))
15-
import DataFrame.Internal.Column
17+
import DataFrame.Internal.Column ( Columnable, atIndicesStable )
1618
import DataFrame.Internal.DataFrame (DataFrame (..))
17-
import DataFrame.Internal.Expression
18-
import DataFrame.Internal.Row
19-
import DataFrame.Operations.Core
20-
import System.Random
21-
import System.Random.Shuffle (shuffle')
19+
import DataFrame.Internal.Expression ( Expr(Col) )
20+
import DataFrame.Internal.Row ( sortedIndexes', toRowVector )
21+
import DataFrame.Operations.Core ( columnNames, dimensions )
22+
import System.Random ( RandomGen, Random(randomR))
23+
2224

2325
-- | Sort order taken as a parameter by the 'sortBy' function.
2426
data SortOrder where
@@ -75,5 +77,18 @@ shuffle pureGen df =
7577
in
7678
df{columns = V.map (atIndicesStable indexes) (columns df)}
7779

78-
shuffledIndices :: (RandomGen g) => g -> Int -> VU.Vector Int
79-
shuffledIndices pureGen k = VU.fromList (shuffle' [0 .. (k - 1)] k pureGen)
80+
{- | Build a random permutation of the indices @[0 .. k - 1]@ using a
forward Fisher-Yates shuffle driven by the supplied pure generator.

The result has length @k@ (empty when @k <= 0@) and contains every index
exactly once. The same generator always yields the same permutation.
-}
shuffledIndices :: (RandomGen g) => g -> Int -> VU.Vector Int
shuffledIndices pureGen k = runST $ do
    vm <- VU.thaw (VU.fromList [0 .. k - 1])
    -- The shuffle must start at the last valid index so that every
    -- position takes part; starting at a random index would leave the
    -- tail of the vector untouched.
    go vm (k - 1) pureGen
    -- Safe: 'vm' is not mutated after this point.
    VU.unsafeFreeze vm
  where
    -- @go v maxInd gen@ shuffles the slice @v@, whose last valid index is
    -- @maxInd@: pick a random element of the slice for position 0, then
    -- recurse on the remainder. Slices with at most one element
    -- (@maxInd <= 0@, which also covers empty input) need no work.
    go v maxInd gen
        | maxInd <= 0 = pure ()
        | otherwise =
            let (n, nextGen) = randomR (0, maxInd) gen
             in VUM.swap v 0 n *> go (VUM.tail v) (maxInd - 1) nextGen

0 commit comments

Comments
 (0)