Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 48 additions & 54 deletions Numeric/Sum.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE BangPatterns, DeriveDataTypeable, FlexibleContexts,
MultiParamTypeClasses, TypeFamilies, CPP #-}
FlexibleInstances, MultiParamTypeClasses, ScopedTypeVariables,
TypeFamilies, CPP #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
-- |
-- Module : Numeric.Sum
Expand Down Expand Up @@ -69,25 +70,21 @@ import qualified Data.Vector.Generic.Mutable as GM
import qualified Data.Vector.Unboxed as U

-- | A class for summation of floating point numbers.
class Summation s where
class RealFloat a => Summation s a where
-- | The identity for summation.
zero :: s
zero :: s a

-- | Add a value to a sum.
add :: s -> Double -> s
add :: s a -> a -> s a

-- | Sum a collection of values.
--
-- Example:
-- @foo = 'Numeric.Sum.sum' 'kbn' [1,2,3]@
sum :: (F.Foldable f) => (s -> Double) -> f Double -> Double
sum f = f . F.foldl' add zero
sum :: F.Foldable f => (s a -> a) -> f a -> a
sum f = f . F.foldl' add (zero :: s a)
{-# INLINE sum #-}

instance Summation Double where
zero = 0
add = (+)

-- | Kahan summation. This is the least accurate of the compensated
-- summation methods. In practice, it only beats naive summation for
-- inputs with large magnitude. Kahan summation can be /less/
Expand All @@ -96,12 +93,12 @@ instance Summation Double where
-- This summation method is included for completeness. Its use is not
-- recommended. In practice, 'KBNSum' is both 30% faster and more
-- accurate.
data KahanSum = KahanSum {-# UNPACK #-} !Double {-# UNPACK #-} !Double
data KahanSum a = KahanSum !a !a
deriving (Eq, Show, Typeable, Data)

instance U.Unbox KahanSum
newtype instance U.MVector s KahanSum = MV_KahanSum (U.MVector s (Double, Double))
instance MVector U.MVector KahanSum where
instance U.Unbox a => U.Unbox (KahanSum a)
newtype instance U.MVector s (KahanSum a) = MV_KahanSum (U.MVector s (a, a))
instance U.Unbox a => MVector U.MVector (KahanSum a) where
{-# INLINE GM.basicLength #-}
{-# INLINE GM.basicUnsafeSlice #-}
{-# INLINE basicOverlaps #-}
Expand Down Expand Up @@ -129,8 +126,8 @@ instance MVector U.MVector KahanSum where
basicUnsafeMove (MV_KahanSum mvec) (MV_KahanSum mvec') = basicUnsafeMove mvec mvec'
basicUnsafeGrow (MV_KahanSum mvec) len = MV_KahanSum `liftM` basicUnsafeGrow mvec len

newtype instance U.Vector KahanSum = V_KahanSum (U.Vector (Double, Double))
instance Vector U.Vector KahanSum where
newtype instance U.Vector (KahanSum a) = V_KahanSum (U.Vector (a, a))
instance U.Unbox a => Vector U.Vector (KahanSum a) where
{-# INLINE basicUnsafeFreeze #-}
{-# INLINE basicUnsafeThaw #-}
{-# INLINE G.basicLength #-}
Expand All @@ -147,43 +144,43 @@ instance Vector U.Vector KahanSum where
elemseq (V_KahanSum vec) val = elemseq vec ((\ (KahanSum a b) -> (a, b)) val)


instance Summation KahanSum where
instance RealFloat a => Summation KahanSum a where
zero = KahanSum 0 0
add = kahanAdd

instance NFData KahanSum where
instance NFData (KahanSum a) where
rnf !_ = ()

-- | @since 0.3.0.0
instance Monoid KahanSum where
instance RealFloat a => Monoid (KahanSum a) where
mempty = zero
s `mappend` KahanSum s' _ = add s s'

#if MIN_VERSION_base(4,9,0)
-- | @since 0.3.0.0
instance Semigroup KahanSum where
instance RealFloat a => Semigroup (KahanSum a) where
(<>) = mappend
#endif

kahanAdd :: KahanSum -> Double -> KahanSum
kahanAdd :: RealFloat a => KahanSum a -> a -> KahanSum a
kahanAdd (KahanSum sum c) x = KahanSum sum' c'
where sum' = sum + y
c' = (sum' - sum) - y
y = x - c

-- | Return the result of a Kahan sum.
kahan :: KahanSum -> Double
kahan :: KahanSum a -> a
kahan (KahanSum sum _) = sum

-- | Kahan-Babuška-Neumaier summation. This is a little more
-- computationally costly than plain Kahan summation, but is /always/
-- at least as accurate.
data KBNSum = KBNSum {-# UNPACK #-} !Double {-# UNPACK #-} !Double
data KBNSum a = KBNSum !a !a
deriving (Eq, Show, Typeable, Data)

instance U.Unbox KBNSum
newtype instance U.MVector s KBNSum = MV_KBNSum (U.MVector s (Double, Double))
instance MVector U.MVector KBNSum where
instance U.Unbox a => U.Unbox (KBNSum a)
newtype instance U.MVector s (KBNSum a) = MV_KBNSum (U.MVector s (a, a))
instance U.Unbox a => MVector U.MVector (KBNSum a) where
{-# INLINE GM.basicLength #-}
{-# INLINE GM.basicUnsafeSlice #-}
{-# INLINE basicOverlaps #-}
Expand Down Expand Up @@ -211,8 +208,8 @@ instance MVector U.MVector KBNSum where
basicUnsafeMove (MV_KBNSum mvec) (MV_KBNSum mvec') = basicUnsafeMove mvec mvec'
basicUnsafeGrow (MV_KBNSum mvec) len = MV_KBNSum `liftM` basicUnsafeGrow mvec len

newtype instance U.Vector KBNSum = V_KBNSum (U.Vector (Double, Double))
instance Vector U.Vector KBNSum where
newtype instance U.Vector (KBNSum a) = V_KBNSum (U.Vector (a, a))
instance U.Unbox a => Vector U.Vector (KBNSum a) where
{-# INLINE basicUnsafeFreeze #-}
{-# INLINE basicUnsafeThaw #-}
{-# INLINE G.basicLength #-}
Expand All @@ -229,32 +226,32 @@ instance Vector U.Vector KBNSum where
elemseq (V_KBNSum vec) val = elemseq vec ((\ (KBNSum a b) -> (a, b)) val)


instance Summation KBNSum where
instance RealFloat a => Summation KBNSum a where
zero = KBNSum 0 0
add = kbnAdd

instance NFData KBNSum where
instance NFData (KBNSum a) where
rnf !_ = ()

-- | @since 0.3.0.0
instance Monoid KBNSum where
instance RealFloat a => Monoid (KBNSum a) where
mempty = zero
s `mappend` KBNSum s' c' = add (add s s') c'

#if MIN_VERSION_base(4,9,0)
-- | @since 0.3.0.0
instance Semigroup KBNSum where
instance RealFloat a => Semigroup (KBNSum a) where
(<>) = mappend
#endif

kbnAdd :: KBNSum -> Double -> KBNSum
kbnAdd :: (Num a, Ord a) => KBNSum a -> a -> KBNSum a
kbnAdd (KBNSum sum c) x = KBNSum sum' c'
where c' | abs sum >= abs x = c + ((sum - sum') + x)
| otherwise = c + ((x - sum') + sum)
sum' = sum + x

-- | Return the result of a Kahan-Babuška-Neumaier sum.
kbn :: KBNSum -> Double
kbn :: Num a => KBNSum a -> a
kbn (KBNSum sum c) = sum + c

-- | Second-order Kahan-Babuška summation. This is more
Expand All @@ -265,14 +262,12 @@ kbn (KBNSum sum c) = sum + c
-- This method compensates for error in both the sum and the
-- first-order compensation term, hence the use of \"second order\" in
-- the name.
data KB2Sum = KB2Sum {-# UNPACK #-} !Double
{-# UNPACK #-} !Double
{-# UNPACK #-} !Double
data KB2Sum a = KB2Sum !a !a !a
deriving (Eq, Show, Typeable, Data)

instance U.Unbox KB2Sum
newtype instance U.MVector s KB2Sum = MV_KB2Sum (U.MVector s (Double, Double, Double))
instance MVector U.MVector KB2Sum where
instance U.Unbox a => U.Unbox (KB2Sum a)
newtype instance U.MVector s (KB2Sum a) = MV_KB2Sum (U.MVector s (a, a, a))
instance U.Unbox a => MVector U.MVector (KB2Sum a) where
{-# INLINE GM.basicLength #-}
{-# INLINE GM.basicUnsafeSlice #-}
{-# INLINE basicOverlaps #-}
Expand Down Expand Up @@ -300,8 +295,8 @@ instance MVector U.MVector KB2Sum where
basicUnsafeMove (MV_KB2Sum mvec) (MV_KB2Sum mvec') = basicUnsafeMove mvec mvec'
basicUnsafeGrow (MV_KB2Sum mvec) len = MV_KB2Sum `liftM` basicUnsafeGrow mvec len

newtype instance U.Vector KB2Sum = V_KB2Sum (U.Vector (Double, Double, Double))
instance Vector U.Vector KB2Sum where
newtype instance U.Vector (KB2Sum a) = V_KB2Sum (U.Vector (a, a, a))
instance U.Unbox a => Vector U.Vector (KB2Sum a) where
{-# INLINE basicUnsafeFreeze #-}
{-# INLINE basicUnsafeThaw #-}
{-# INLINE G.basicLength #-}
Expand All @@ -317,26 +312,26 @@ instance Vector U.Vector KB2Sum where
basicUnsafeCopy (MV_KB2Sum mvec) (V_KB2Sum vec) = G.basicUnsafeCopy mvec vec
elemseq (V_KB2Sum vec) val = elemseq vec ((\ (KB2Sum a b c) -> (a, b, c)) val)

instance Summation KB2Sum where
instance RealFloat a => Summation KB2Sum a where
zero = KB2Sum 0 0 0
add = kb2Add

instance NFData KB2Sum where
instance NFData (KB2Sum a) where
rnf !_ = ()

-- | @since 0.3.0.0
instance Monoid KB2Sum where
instance RealFloat a => Monoid (KB2Sum a) where
mempty = zero
s `mappend` KB2Sum s' c' cc' = add (add (add s s') c') cc'

#if MIN_VERSION_base(4,9,0)
-- | @since 0.3.0.0
instance Semigroup KB2Sum where
instance RealFloat a => Semigroup (KB2Sum a) where
(<>) = mappend
#endif


kb2Add :: KB2Sum -> Double -> KB2Sum
kb2Add :: (Num a, Ord a) => KB2Sum a -> a -> KB2Sum a
kb2Add (KB2Sum sum c cc) x = KB2Sum sum' c' cc'
where sum' = sum + x
c' = c + k
Expand All @@ -346,12 +341,11 @@ kb2Add (KB2Sum sum c cc) x = KB2Sum sum' c' cc'
| otherwise = (x - sum') + sum

-- | Return the result of an order-2 Kahan-Babuška sum.
kb2 :: KB2Sum -> Double
kb2 :: Num a => KB2Sum a -> a
kb2 (KB2Sum sum c cc) = sum + c + cc

-- | /O(n)/ Sum a vector of values.
sumVector :: (Vector v Double, Summation s) =>
(s -> Double) -> v Double -> Double
sumVector :: RealFloat a => (Vector v a, Summation s a) => (s a -> a) -> v a -> a
sumVector f = f . foldl' add zero
{-# INLINE sumVector #-}

Expand All @@ -361,7 +355,7 @@ sumVector f = f . foldl' add zero
-- bounds on its error growth. Instead of having roughly constant
-- error regardless of the size of the input vector, in the worst case
-- its accumulated error grows with /O(log n)/.
pairwiseSum :: (Vector v Double) => v Double -> Double
pairwiseSum :: RealFloat a => (Vector v a) => v a -> a
pairwiseSum v
| len <= 256 = G.sum v
| otherwise = uncurry (+) . (pairwiseSum *** pairwiseSum) .
Expand All @@ -383,7 +377,7 @@ pairwiseSum v
-- computes the sum of elements in a list.
--
-- @
-- sillySumList :: [Double] -> Double
-- sillySumList :: RealFloat a => [a] -> a
-- sillySumList = loop 'zero'
-- where loop s [] = 'kbn' s
-- loop s (x:xs) = 'seq' s' loop s' xs
Expand All @@ -397,7 +391,7 @@ pairwiseSum v
-- -- Avoid ambiguity around which sum function we are using.
-- import Prelude hiding (sum)
-- --
-- betterSumList :: [Double] -> Double
-- betterSumList :: RealFloat a => [a] -> a
-- betterSumList xs = 'Numeric.Sum.sum' 'kbn' xs
-- @

Expand All @@ -410,7 +404,7 @@ pairwiseSum v
-- intermediate values are as accurate as possible.
--
-- @
-- prefixSum :: [Double] -> [Double]
-- prefixSum :: RealFloat a => [a] -> [a]
-- prefixSum xs = map 'kbn' . 'scanl' 'add' 'zero' $ xs
-- @

Expand Down
Loading