Skip to content

Commit ff07a4e

Browse files
committed
Unit tests for Fisher Yates algorithm
1 parent 3e69bd5 commit ff07a4e

2 files changed

Lines changed: 28 additions & 6 deletions

File tree

src/DataFrame/Operations/Permutation.hs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import qualified Data.Vector.Unboxed.Mutable as VUM
1313

1414
import Control.Exception (throw)
1515
import Control.Monad.ST (runST)
16+
import Data.Vector.Internal.Check (HasCallStack)
1617
import DataFrame.Errors (DataFrameException (..))
1718
import DataFrame.Internal.Column (Columnable, atIndicesStable)
1819
import DataFrame.Internal.DataFrame (DataFrame (..))
@@ -76,20 +77,23 @@ shuffle pureGen df =
7677
in
7778
df{columns = V.map (atIndicesStable indexes) (columns df)}
7879

79-
shuffledIndices :: (RandomGen g) => g -> Int -> VU.Vector Int
80+
shuffledIndices :: (HasCallStack, RandomGen g) => g -> Int -> VU.Vector Int
8081
shuffledIndices pureGen k
81-
| k <= 0 = VU.empty
82+
| k < 0 = error $ "Vector index may not be a neative number: " <> show k
83+
| k == 0 = VU.empty
8284
| otherwise = shuffleVec pureGen
8385
where
8486
shuffleVec :: (RandomGen g) => g -> VU.Vector Int
8587
shuffleVec g = runST $ do
8688
vm <- VUM.generate k id
87-
let (n, nGen) = randomR (1, (k - 1)) g
89+
let (n, nGen) = randomR (1, k - 1) g
8890
go vm n nGen
8991
VU.unsafeFreeze vm
9092

9193
go v (-1) _ = pure ()
9294
go v 0 _ = pure ()
9395
go v maxInd gen =
94-
let (n, nextGen) = randomR (1, maxInd) gen
95-
in VUM.swap v 0 n *> go (VUM.tail v) (maxInd - 1) nextGen
96+
let
97+
(n, nextGen) = randomR (1, maxInd) gen
98+
in
99+
VUM.swap v 0 n *> go (VUM.tail v) (maxInd - 1) nextGen

tests/Operations/Shuffle.hs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ module Operations.Shuffle where
55

66
import qualified DataFrame as D
77

8-
import DataFrame.Operations.Permutation (shuffle)
8+
import qualified Data.Set as Set
9+
import qualified Data.Vector.Unboxed as VU
10+
import DataFrame.Operations.Permutation (shuffle, shuffledIndices)
911
import System.Random (mkStdGen)
1012
import Test.HUnit (Test (..), assertEqual)
1113

@@ -74,11 +76,27 @@ shuffleDifferentSeedIsDifferent =
7476
(shuffled1 == shuffled2)
7577
)
7678

79+
-- Test that shuffledIndices does not drop, add, or repeat any index
80+
shuffleDoesNotAddOrDropIndices :: Test
81+
shuffleDoesNotAddOrDropIndices =
82+
let
83+
gen = mkStdGen 42
84+
actual = (Set.fromList [0 .. 10])
85+
computedVector = shuffledIndices gen 11
86+
computed = (Set.fromList $ VU.toList $ shuffledIndices gen 11)
87+
in
88+
TestList
89+
[
90+
TestCase (assertEqual "Indecis are not dropped or added" (VU.length computedVector) 11)
91+
, TestCase (assertEqual "There are no repeated indecis" computed actual)
92+
]
93+
7794
tests :: [Test]
7895
tests =
7996
[ TestLabel "shuffleShuffles" shuffleShuffles
8097
, TestLabel "shufflePreservesData" shufflePreservesData
8198
, TestLabel "shufflePreservesColumnNames" shufflePreservesColumnNames
8299
, TestLabel "shuffleSameSeedIsSameShuffle" shuffleSameSeedIsSameShuffle
83100
, TestLabel "shuffleDifferentSeedIsDifferent" shuffleDifferentSeedIsDifferent
101+
, TestLabel "shuffleDoesNotAddOrDropIndices" shuffleDoesNotAddOrDropIndices
84102
]

0 commit comments

Comments
 (0)