Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 156 additions & 10 deletions src/LAoP/Matrix/Internal.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE ConstraintKinds #-}
Expand Down Expand Up @@ -30,7 +32,7 @@
-- motivation behind the library, the underlying theory, and implementation details.
--
-- This module offers many of the combinators mentioned in the work of
-- Macedo (2012) and Oliveira (2012).
-- Macedo (2012) and Oliveira (2012).
--
-- This is an Internal module and it is no supposed to be imported.
--
Expand All @@ -43,8 +45,8 @@ module LAoP.Matrix.Internal
--
-- There exists two type families that make it easier to write
-- matrix dimensions: 'FromNat' and 'Count'. This approach
-- leads to a very straightforward implementation
-- of LAoP combinators.
-- leads to a very straightforward implementation
-- of LAoP combinators.

-- * Type safe matrix representation
Matrix (..),
Expand Down Expand Up @@ -154,12 +156,18 @@ module LAoP.Matrix.Internal
negateM,
orM,
andM,
subM
subM,

-- * Semantics

function
)
where

import LAoP.Utils.Internal
import Data.Bool
import Data.Bifunctor
import Data.Functor.Contravariant
import Data.Kind
import Data.List
import Data.Maybe
Expand Down Expand Up @@ -340,7 +348,7 @@ instance {-# OVERLAPPING #-} (FromLists e () rows) => FromLists e () (Either ()
fromLists _ = error "Wrong dimensions"

instance {-# OVERLAPPABLE #-} (FromLists e () a, FromLists e () b, Countable a) => FromLists e () (Either a b) where
fromLists l@([_] : _) =
fromLists l@([_] : _) =
let rowsA = fromInteger (natVal (Proxy :: Proxy (Count a)))
in Fork (fromLists (take rowsA l)) (fromLists (drop rowsA l))
fromLists _ = error "Wrong dimensions"
Expand Down Expand Up @@ -518,7 +526,7 @@ comp (Join a b) (Fork c d) = comp a c + comp b d -- Divide-and-conquer l
comp (Fork a b) c = Fork (comp a c) (comp b c) -- Fork fusion law
comp c (Join a b) = Join (comp c a) (comp c b) -- Join fusion law
{-# NOINLINE comp #-}
{-# RULES
{-# RULES
"comp/iden1" forall m. comp m iden = m ;
"comp/iden2" forall m. comp iden m = m
#-}
Expand Down Expand Up @@ -576,7 +584,7 @@ rows :: forall e cols rows. (Countable rows) => Matrix e cols rows -> Int
rows _ = fromInteger $ natVal (Proxy :: Proxy (Count rows))

-- | Obtain the number of columns.
--
--
-- NOTE: The 'KnownNat' constaint is needed in order to obtain the
-- dimensions in constant time.
--
Expand Down Expand Up @@ -639,7 +647,7 @@ sndM = matrixBuilder' f
-- | Khatri Rao Matrix product also known as matrix pairing.
--
-- NOTE: That this is not a true categorical product, see for instance:
--
--
-- @
-- | fstM . kr a b == a
-- kr a b ==> |
Expand Down Expand Up @@ -675,7 +683,7 @@ infixl 4 ><
FromLists e (Normalize (m, n)) n,
FromLists e (Normalize (p, q)) p,
FromLists e (Normalize (p, q)) q
)
)
=> Matrix e m p -> Matrix e n q -> Matrix e (Normalize (m, n)) (Normalize (p, q))
(><) a b =
let fstM' = fstM @e @m @n
Expand Down Expand Up @@ -1008,7 +1016,7 @@ toRel ::
Eq b,
CountableDimensionsN a b,
FromListsN Boolean b a
)
)
=> (a -> b -> Bool) -> Relation (Normalize a) (Normalize b)
toRel f =
let minA = minBound @a
Expand All @@ -1028,3 +1036,141 @@ toRel f =
where
buildList [] _ = []
buildList l r = take r l : buildList (drop r l) r


----------------------------- Linear map semantics -----------------------------
newtype Vector e a = Vector { at :: a -> e }

instance Contravariant (Vector e) where
contramap f (Vector g) = Vector (g . f)

instance Num e => Num (Vector e a) where
fromInteger = Vector . const . fromInteger

(+) = liftV2 (+)
(-) = liftV2 (-)
(*) = liftV2 (*)
abs = liftV1 abs
negate = liftV1 negate
signum = error "No sensible definition"

liftV1 :: (e -> e) -> Vector e a -> Vector e a
liftV1 f x = Vector (\a -> f (at x a))

liftV2 :: (e -> e -> e) -> Vector e a -> Vector e a -> Vector e a
liftV2 f x y = Vector (\a -> f (at x a) (at y a))

-- Semantics of Matrix e a b
type LinearMap e a b = Vector e a -> Vector e b

semantics :: Num e => Matrix e a b -> LinearMap e a b
semantics m = case m of
Empty -> Prelude.id
One e -> const (Vector (const e))
Join x y -> \v -> semantics x (Left >$< v) + semantics y (Right >$< v)
Fork x y -> \v -> Vector $ either (at (semantics x v)) (at (semantics y v))

padLeft :: Num e => Vector e b -> Vector e (Either a b)
padLeft v = Vector $ \case Left _ -> 0
Right b -> at v b

padRight :: Num e => Vector e a -> Vector e (Either a b)
padRight v = Vector $ \case Left a -> at v a
Right _ -> 0

dot :: (Num e, Enumerable a) => Vector e a -> Vector e a -> e
dot x y = sum [ at x a * at y a | a <- enumerate ]

--------------------------------- Construction ---------------------------------
class Construct a where
row' :: Num e => Vector e a -> Matrix e a ()

linearMap :: (Construct b, Num e) => LinearMap e a b -> Matrix e a b

instance Construct () where
row' v = one (at v ())

linearMap m = column (m 1)

instance (Construct a, Construct b) => Construct (Either a b) where
row' v = row' (Left >$< v) ||| row' (Right >$< v)

linearMap m = linearMap (m . padRight) ||| linearMap (m . padLeft)

column :: (Construct a, Num e) => Vector e a -> Matrix e () a
column = tr . row'

function :: (Construct a, Construct b, Enumerable a, Num e) => (a -> b -> e) -> Matrix e a b
function f = linearMap $ \v -> Vector $ \b -> dot v $ Vector $ \a -> f a b

-------------------------------- Deconstruction --------------------------------
class Enumerable a where
enumerate :: [a]
default enumerate :: Enum a => [a]
enumerate = enumFrom (toEnum 0)

instance Enumerable Void where
enumerate = []

-- 1, 2, 3...
instance Enumerable ()
instance Enumerable Bool
instance Enumerable Ordering

instance (Enumerable a, Enumerable b) => Enumerable (Either a b) where
enumerate = (Left <$> enumerate) ++ (Right <$> enumerate)

instance (Enumerable a, Enumerable b) => Enumerable (a, b) where
enumerate = [ (a, b) | a <- enumerate, b <- enumerate ]

basis :: (Enumerable a, Eq a, Num e) => [Vector e a]
basis = [ Vector (bool 0 1 . (==a)) | a <- enumerate ]

toLists' :: (Enumerable a, Enumerable b, Eq a, Num e) => Matrix e a b -> [[e]]
toLists' m = transpose
[ [ at r i | i <- enumerate ] | c <- basis, let r = semantics m c ]

dump :: (Enumerable a, Enumerable b, Eq a, Num e, Show e) => Matrix e a b -> IO ()
dump = mapM_ print . toLists'

------------------------------- Normalised types -------------------------------
class Profunctor p where
dimap :: (a -> b) -> (c -> d) -> p b c -> p a d

instance Profunctor (->) where
dimap f g h = g . h . f

class Construct (Norm a) => ConstructNorm a where
type Norm a
toNorm :: a -> Norm a
fromNorm :: Norm a -> a

instance ConstructNorm () where
type Norm () = ()
toNorm = Prelude.id
fromNorm = Prelude.id

instance (ConstructNorm a, ConstructNorm b) => ConstructNorm (Either a b) where
type Norm (Either a b) = Either (Norm a) (Norm b)
toNorm = either (Left . toNorm) (Right . toNorm)
fromNorm = either (Left . fromNorm) (Right . fromNorm)

-- Instance for other data types can be obtained mechanically using Generics
instance ConstructNorm Bool where
type Norm Bool = Either () ()
toNorm = bool (Left ()) (Right ())
fromNorm = either (const False) (const True)

rowN :: (ConstructNorm a, Num e) => Vector e a -> Matrix e (Norm a) ()
rowN = row' . contramap fromNorm

linearMapN :: (ConstructNorm a, ConstructNorm b, Num e)
=> LinearMap e a b -> Matrix e (Norm a) (Norm b)
linearMapN = linearMap . dimap (contramap toNorm) (contramap fromNorm)

columnN :: (ConstructNorm a, Num e) => Vector e a -> Matrix e () (Norm a)
columnN = tr . rowN

functionN :: (ConstructNorm a, ConstructNorm b, Enumerable a, Num e)
=> (a -> b -> e) -> Matrix e (Norm a) (Norm b)
functionN f = linearMapN $ \v -> Vector $ \b -> dot v $ Vector $ \a -> f a b