Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Data/Array/Accelerate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ module Data.Array.Accelerate (
fst, afst, snd, asnd, curry, uncurry,

-- *** Flow control
(?), match, cond, while, iterate,
(?), match, cond, select, while, iterate,
Assert(assert, assertMessage),

-- *** Scalar reduction
Expand Down
10 changes: 10 additions & 0 deletions src/Data/Array/Accelerate/AST/Exp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ data PreOpenExp arr env t where
-> PreOpenExp arr env t
-> PreOpenExp arr env t

-- Conditional expression that chooses one value based on the condition without branching.
Select :: PreOpenExp arr env PrimBool
-> PreOpenExp arr env t
-> PreOpenExp arr env t
-> PreOpenExp arr env t

Assert :: Text
-> PreOpenExp arr env PrimBool
-> PreOpenExp arr env t
Expand Down Expand Up @@ -385,6 +391,7 @@ expType = \case
Case _ [] (Just e) -> expType e
Case{} -> internalError "empty case encountered"
Cond _ e _ -> expType e
Select _ e _ -> expType e
While _ (Lam lhs _) _ -> lhsToTupR lhs
While{} -> internalError "What's the matter, you're running in the shadows"
Const tR _ -> TupRsingle tR
Expand Down Expand Up @@ -515,6 +522,7 @@ rnfOpenExp topExp =
FromIndex shr sh ix -> rnfShapeR shr `seq` rnfE sh `seq` rnfE ix
Case e rhs def -> rnfE e `seq` rnfList (\(t,c) -> t `seq` rnfE c) rhs `seq` rnfMaybe rnfE def
Cond p e1 e2 -> rnfE p `seq` rnfE e1 `seq` rnfE e2
Select p e1 e2 -> rnfE p `seq` rnfE e1 `seq` rnfE e2
While p f x -> rnfF p `seq` rnfF f `seq` rnfE x
PrimApp f x -> rnfPrimFun f `seq` rnfE x
ArrayInstr arr e -> rnfArrayInstr arr `seq` rnfE e
Expand Down Expand Up @@ -639,6 +647,7 @@ liftOpenExp pexp =
FromIndex shr sh ix -> [|| FromIndex $$(liftShapeR shr) $$(liftE sh) $$(liftE ix) ||]
Case p rhs def -> [|| Case $$(liftE p) $$(liftList (\(t,c) -> [|| (t, $$(liftE c)) ||]) rhs) $$(liftMaybe liftE def) ||]
Cond p t e -> [|| Cond $$(liftE p) $$(liftE t) $$(liftE e) ||]
Select p t e -> [|| Select $$(liftE p) $$(liftE t) $$(liftE e) ||]
While p f x -> [|| While $$(liftF p) $$(liftF f) $$(liftE x) ||]
PrimApp f x -> [|| PrimApp $$(liftPrimFun f) $$(liftE x) ||]
ArrayInstr arr x -> [|| ArrayInstr $$(liftArrayInstr arr) $$(liftE x) ||]
Expand Down Expand Up @@ -757,6 +766,7 @@ formatExpOp = later $ \case
FromIndex{} -> "FromIndex"
Case{} -> "Case"
Cond{} -> "Cond"
Select{} -> "Select"
While{} -> "While"
PrimApp{} -> "PrimApp"
ArrayInstr ar _ -> fromString $ showArrayInstrOp ar
Expand Down
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Analysis/Hash/Exp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ encodeOpenExp exp =
FromIndex _ sh i -> intHost $(hashQ "FromIndex") <> travE sh <> travE i
Case e rhs def -> intHost $(hashQ "Case") <> travE e <> mconcat [ word8 t <> travE c | (t,c) <- rhs ] <> encodeMaybe travE def
Cond c t e -> intHost $(hashQ "Cond") <> travE c <> travE t <> travE e
Select c t e -> intHost $(hashQ "Select") <> travE c <> travE t <> travE e
While p f x -> intHost $(hashQ "While") <> travF p <> travF f <> travE x
PrimApp f x -> intHost $(hashQ "PrimApp") <> encodePrimFun f <> travE x
ArrayInstr arr e -> intHost $(hashQ "ArrayInstr") <> encodeArrayInstr arr <> travE e
Expand Down
7 changes: 7 additions & 0 deletions src/Data/Array/Accelerate/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,13 @@ evalOpenExp pexp env arr@(EvalArrayInstr runArrayInstr) =
!v <- evalE c
if toBool v then evalE t else evalE e

Select c t e -> do
!v <- evalE c
-- evaluate both branches
let !t' = evalE t
let !e' = evalE e
if toBool v then t' else e'

While cond body seed ->
evalE seed >>= go
where
Expand Down
6 changes: 6 additions & 0 deletions src/Data/Array/Accelerate/Interpreter/Simple.hs
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,12 @@ evalOpenExp pexp runarr env =
| toBool (evalE c) -> evalE t
| otherwise -> evalE e

Select c t e ->
-- evaluate both branches
let !t' = evalE t
!e' = evalE e
in if toBool (evalE c) then t' else e'

While cond body seed -> go (evalE seed)
where
f = evalF body
Expand Down
11 changes: 10 additions & 1 deletion src/Data/Array/Accelerate/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ module Data.Array.Accelerate.Language (

-- * Flow-control
acond, awhile,
cond, while,
cond, select, while,
Assert(..), assertBounds,

-- * Utilities for bounds and shape checks
Expand Down Expand Up @@ -1296,6 +1296,15 @@ cond :: Elt t
-> Exp t
cond (Exp c) (Exp x) (Exp y) = mkExp $ Cond (mkCoerce' c) x y

-- | A scalar-level conditional expression
-- that chooses one value based on the condition without branching.
select :: (Elt t)
=> Exp Bool -- ^ condition
-> Exp t -- ^ then-expression
-> Exp t -- ^ else-expression
-> Exp t
select (Exp c) (Exp x) (Exp y) = mkExp $ Select (mkCoerce' c) x y

-- | While construct. Continue to apply the given function, starting with the
-- initial value, until the test function evaluates to 'False'.
--
Expand Down
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Pretty/Exp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ prettyPreOpenExp ctx prettyArrayInstr env exp =
, hang shiftwidth (sep [ then_, t' ])
, hang shiftwidth (sep [ else_, e' ]) ]
--
Select p t e -> ppF3 "select" (ppE p) (ppE t) (ppE e)
ToIndex _ sh ix -> ppF2 "toIndex" (ppE sh) (ppE ix)
FromIndex _ sh ix -> ppF2 "fromIndex" (ppE sh) (ppE ix)
While p f x -> ppF3 "while" (ppF p) (ppF f) (ppE x)
Expand Down
7 changes: 7 additions & 0 deletions src/Data/Array/Accelerate/Smart.hs
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,11 @@ data PreSmartExp acc exp t where
-> exp t
-> PreSmartExp acc exp t

Select :: exp PrimBool
-> exp t
-> exp t
-> PreSmartExp acc exp t

While :: TypeR t
-> (SmartExp t -> exp PrimBool)
-> (SmartExp t -> exp t)
Expand Down Expand Up @@ -901,6 +906,7 @@ instance HasTypeR exp => HasTypeR (PreSmartExp acc exp) where
Case _ ((_,c):_) -> typeR c
Case{} -> internalError "encountered empty case"
Cond _ e _ -> typeR e
Select _ e _ -> typeR e
While t _ _ _ -> t
PrimApp f _ -> snd $ primFunType f
Index tp _ _ -> tp
Expand Down Expand Up @@ -1417,6 +1423,7 @@ formatPreExpOp = later $ \case
FromIndex{} -> "FromIndex"
Case{} -> "Case"
Cond{} -> "Cond"
Select{} -> "Select"
While{} -> "While"
PrimApp{} -> "PrimApp"
Index{} -> "Index"
Expand Down
3 changes: 3 additions & 0 deletions src/Data/Array/Accelerate/Trafo/Exp/Shrink.hs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ shrinkExp = Stats.substitution "shrinkE" . first getAny . shrinkE
FromIndex shr sh i -> FromIndex shr <$> shrinkE sh <*> shrinkE i
Case e rhs def -> Case <$> shrinkE e <*> sequenceA [ (t,) <$> shrinkE c | (t,c) <- rhs ] <*> shrinkMaybeE def
Cond p t e -> Cond <$> shrinkE p <*> shrinkE t <*> shrinkE e
Select p t e -> Select <$> shrinkE p <*> shrinkE t <*> shrinkE e
While p f x -> While <$> shrinkF p <*> shrinkF f <*> shrinkE x
PrimApp f x -> PrimApp f <$> shrinkE x
ArrayInstr arr e -> ArrayInstr arr <$> shrinkE e
Expand Down Expand Up @@ -366,6 +367,7 @@ usesOfExp range = countE
ToIndex _ sh e -> countE sh <> countE e
Case e rhs def -> countE e <> mconcat [ countE c | (_,c) <- rhs ] <> maybe (Finite 0) countE def
Cond p t e -> countE p <> countE t <> countE e
Select p t e -> countE p <> countE t <> countE e
While p f x -> countE x <> loopCount (usesOfFun range p) <> loopCount (usesOfFun range f)
PrimApp _ x -> countE x
ArrayInstr _ e -> countE e
Expand Down Expand Up @@ -396,6 +398,7 @@ arrayInstrsInExp = (`travE` [])
FromIndex _ sh i -> travE sh $ travE i acc
Case e rhs def -> travE e $ travAE rhs $ travME def acc
Cond p t e -> travE p $ travE t $ travE e acc
Select p t e -> travE p $ travE t $ travE e acc
While p f x -> travF p $ travF f $ travE x acc
PrimApp _ x -> travE x acc
ArrayInstr arr e -> Exists arr : travE e acc
Expand Down
73 changes: 73 additions & 0 deletions src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import Data.Array.Accelerate.Representation.Shape ( ShapeR(..), shapeToList )
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Representation.Vec
import Data.Array.Accelerate.Representation.Slice ( SliceIndex(..) )

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.4 release

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.8 release

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.12 release

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | windows-latest-x64 ghc-9.4 release

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.8 debug

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.6 debug

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.10 debug

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.12 debug

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.6 debug

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.8 debug

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.12 release

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.10 release

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.6 release

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.4 release

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.8 release

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant

Check warning on line 46 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.4 debug

The import of ‘Data.Array.Accelerate.Representation.Slice’ is redundant
import Data.Array.Accelerate.Trafo.Exp.Algebra
import Data.Array.Accelerate.Trafo.Environment
import Data.Array.Accelerate.Trafo.Shrink
Expand Down Expand Up @@ -238,6 +238,7 @@
FromIndex shr sh ix -> hoist2 (fromIndex shr) (cvtE sh) (cvtE ix)
Case e rhs def -> hoist (\e' -> caseof e' (sequenceA [ (t,) <$> cvtE c | (t,c) <- rhs ]) (cvtMaybeE def)) (cvtE e)
Cond p t e -> hoist (\p' -> cond p' (cvtE t) (cvtE e)) (cvtE p)
Select p t e -> hoist (\p' -> select p' (cvtE t) (cvtE e)) (cvtE p)
PrimApp f x -> hoist (evalPrimApp env f) (cvtE x)
ArrayInstr arr e -> hoist (arrayInstr arr) (cvtE e)
ShapeSize shr sh -> hoist (shapeSize shr) (cvtE sh)
Expand Down Expand Up @@ -306,6 +307,17 @@
shouldInline Const{} = True
shouldInline _ = False

select :: PreOpenExp arr env PrimBool
-> (Any, PreOpenExp arr env t)
-> (Any, PreOpenExp arr env t)
-> (Any, PreOpenExp arr env t)
select p t@(_,t') e@(_,e')
| Const _ 1 <- p = Stats.knownBranch "True" (yes t')
| Const _ 0 <- p = Stats.knownBranch "False" (yes e')
| Just Refl <- matchOpenExp t' e' = Stats.knownBranch "redundant" (yes e')
| PrimApp PrimLNot c <- p = yes $ snd $ select c e t
| otherwise = Select p <$> t <*> e

-- Simplify conditional expressions, in particular by eliminating branches
-- when the predicate is a known constant.
--
Expand All @@ -317,8 +329,68 @@
| Const _ 1 <- p = Stats.knownBranch "True" (yes t')
| Const _ 0 <- p = Stats.knownBranch "False" (yes e')
| Just Refl <- matchOpenExp t' e' = Stats.knownBranch "redundant" (yes e')
| isCheap t' && isCheap e' = yes $ snd $ select p t e
Comment thread
Doppie23 marked this conversation as resolved.
| PrimApp PrimLNot c <- p = yes $ snd $ cond c e t
| otherwise = Cond p <$> t <*> e

isCheap :: PreOpenExp arr env t -> Bool
isCheap = maybe False (<= maxCost) . expCost
where
maxCost = 5

expCost :: PreOpenExp arr env' t -> Maybe Int
expCost = \case
ArrayInstr a _ -> if inlineArrayInstr a
then Just 1
else Nothing
Evar{} -> Just 1
Nil -> Just 1
Const{} -> Just 1
Undef{} -> Just 1
PrimApp f e -> primCost f .+. expCost e
Let _ bnd body -> Just 1 .+. expCost bnd .+. expCost body
Pair e1 e2 -> Just 1 .+. expCost e1 .+. expCost e2
VecPack _ e -> Just 1 .+. expCost e
VecUnpack _ e -> Just 1 .+. expCost e
Coerce _ _ e -> Just 1 .+. expCost e
Assume e1 e2 -> Just 1 .+. expCost e1 .+. expCost e2
Foreign{} -> Nothing
ToIndex{} -> Nothing
FromIndex{} -> Nothing
Case{} -> Nothing
Cond{} -> Nothing
Select{} -> Nothing
While{} -> Nothing
ShapeSize{} -> Nothing
Assert{} -> Nothing

primCost :: PrimFun f -> Maybe Int
primCost = \case
PrimAdd _ -> Just 1
PrimSub _ -> Just 1
PrimMul _ -> Just 1
PrimNeg _ -> Just 1
PrimAbs _ -> Just 1
PrimSig _ -> Just 1
PrimBAnd _ -> Just 1
PrimBOr _ -> Just 1
PrimBXor _ -> Just 1
PrimBNot _ -> Just 1
PrimBShiftL _ -> Just 1
PrimBShiftR _ -> Just 1
PrimBRotateL _ -> Just 1
PrimBRotateR _ -> Just 1
PrimCmp _ _ -> Just 1
PrimMax _ -> Just 1
PrimMin _ -> Just 1
PrimLAnd -> Just 1
PrimLOr -> Just 1
PrimLNot -> Just 1
_ -> Nothing

(.+.) :: Maybe Int -> Maybe Int -> Maybe Int
a .+. b = (+) <$> a <*> b

caseof :: PreOpenExp arr env TAG
-> (Any, [(TAG, PreOpenExp arr env b)])
-> (Any, Maybe (PreOpenExp arr env b))
Expand Down Expand Up @@ -619,7 +691,7 @@
travNumType (FloatingNumType t) = travFloatingType t & types +~ 1

travBoundedType :: BoundedType t -> Stats
travBoundedType (IntegralBoundedType t) = travIntegralType t & types +~ 1

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.4 release

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.8 release

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.12 release

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | windows-latest-x64 ghc-9.4 release

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.8 debug

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.6 debug

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.10 debug

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | ubuntu-latest-x64 ghc-9.12 debug

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.6 debug

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.8 debug

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.12 release

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.10 release

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.6 release

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.4 release

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.8 release

Defined but not used: ‘travBoundedType’

Check warning on line 694 in src/Data/Array/Accelerate/Trafo/Exp/Simplify.hs

View workflow job for this annotation

GitHub Actions / cabal | macOS-latest-x64 ghc-9.4 debug

Defined but not used: ‘travBoundedType’

-- travScalarType :: ScalarType t -> Stats
-- travScalarType (SingleScalarType t) = travSingleType t & types +~ 1
Expand Down Expand Up @@ -652,6 +724,7 @@
FromIndex _ sh ix -> travE sh +++ travE ix
Case e rhs def -> travE e +++ mconcat [ travE c | (_,c) <- rhs ] +++ maybe zero travE def
Cond p t e -> travE p +++ travE t +++ travE e
Select p t e -> travE p +++ travE t +++ travE e
While p f x -> travF p +++ travF f +++ travE x
ArrayInstr a e -> travA a +++ travE e
ShapeSize _ sh -> travE sh
Expand Down
3 changes: 3 additions & 0 deletions src/Data/Array/Accelerate/Trafo/Exp/Substitution.hs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ inlineVars lhsBound expr bound
FromIndex shr e1 e2 -> FromIndex shr <$> travE e1 <*> travE e2
Case e1 rhs def -> Case <$> travE e1 <*> mapM (\(t,c) -> (t,) <$> travE c) rhs <*> travMaybeE def
Cond e1 e2 e3 -> Cond <$> travE e1 <*> travE e2 <*> travE e3
Select e1 e2 e3 -> Select <$> travE e1 <*> travE e2 <*> travE e3
While f1 f2 e1 -> While <$> travF f1 <*> travF f2 <*> travE e1
Const t c -> Just $ Const t c
PrimApp p e1 -> PrimApp p <$> travE e1
Expand Down Expand Up @@ -460,6 +461,7 @@ rebuildOpenExp v exp =
FromIndex shr sh ix -> FromIndex shr <$> rebuildOpenExp v sh <*> rebuildOpenExp v ix
Case e rhs def -> Case <$> rebuildOpenExp v e <*> sequenceA [ (t,) <$> rebuildOpenExp v c | (t,c) <- rhs ] <*> rebuildMaybeExp v def
Cond p t e -> Cond <$> rebuildOpenExp v p <*> rebuildOpenExp v t <*> rebuildOpenExp v e
Select p t e -> Select <$> rebuildOpenExp v p <*> rebuildOpenExp v t <*> rebuildOpenExp v e
While p f x -> While <$> rebuildFun v p <*> rebuildFun v f <*> rebuildOpenExp v x
PrimApp f x -> PrimApp f <$> rebuildOpenExp v x
ArrayInstr arr e -> ArrayInstr arr <$> rebuildOpenExp v e
Expand Down Expand Up @@ -534,6 +536,7 @@ rebuildArrayInstrOpenExp v = \case
FromIndex shr sh ix -> FromIndex shr <$> travE sh <*> travE ix
Case e rhs def -> Case <$> travE e <*> sequenceA [ (t,) <$> travE c | (t,c) <- rhs ] <*> travME def
Cond e1 e2 e3 -> Cond <$> travE e1 <*> travE e2 <*> travE e3
Select e1 e2 e3 -> Select <$> travE e1 <*> travE e2 <*> travE e3
While c f x -> While <$> travF c <*> travF f <*> travE x
Const tp c -> pure $ Const tp c
PrimApp f x -> PrimApp f <$> travE x
Expand Down
32 changes: 32 additions & 0 deletions src/Data/Array/Accelerate/Trafo/Operation/Bounds.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ module Data.Array.Accelerate.Trafo.Operation.Bounds (
boundsOptimizeFun, boundsOptimizeFun1, boundsOptimizeFun2, boundsOptimizeExp,
) where

import Data.Array.Accelerate.AST.Exp
import Data.Array.Accelerate.AST.Environment
import qualified Data.Array.Accelerate.AST.Graph as Graph
import Data.Array.Accelerate.AST.Idx
Expand All @@ -51,6 +52,7 @@ import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Analysis.Match

import Data.Array.Accelerate.Trafo.Operation.Bounds.Algebra
import Data.Array.Accelerate.Trafo.Operation.Bounds.Environment
Expand Down Expand Up @@ -326,6 +328,36 @@ boundsOptimizeExp env@(BoundsEnv _ _ zero _) expr = detectConst env $ case expr
( unions (makeTransitives env trueBounds) (makeTransitives env falseBounds)
, Cond c' true' false' )

Select c t f -> case travE c of
-- Check if the condition is already known based on the bounds analysis
Const _ 1 -> boundsOptimizeExp env t
Const _ 0 -> boundsOptimizeExp env f

c' ->
let (trueBounds, t') = boundsOptimizeExp env t
(falseBounds, f') = boundsOptimizeExp env f
bs = unions (makeTransitives env trueBounds)
(makeTransitives env falseBounds)

isBoolBounds :: TermBounds (UniformEnv benv env) Word8 -> Bool
isBoolBounds (TupRsingle bound) = (l >= 0 && h <= 1) -- (0, 1) (0, 0) (1, 1)
where
(l, h) = getBoundRange (boundsZero env) TypeWord8
(makeTransitive env bound)

simplify
| Just Refl <- matchTypeR (expType t') (TupRsingle scalarTypeWord8)
, isBoolBounds trueBounds
, isBoolBounds falseBounds
= case (t', f') of
(_ , Const _ 0) -> (bs, PrimApp PrimLAnd (Pair c' t')) -- c ? t : 0 => c && t
(Const _ 0, _ ) -> (bs, PrimApp PrimLAnd (Pair (PrimApp PrimLNot c') f')) -- c ? 0 : f => (not c) && f
(Const _ 1, _ ) -> (bs, PrimApp PrimLOr (Pair c' f')) -- c ? 1 : f => c || f
(_ , Const _ 1) -> (bs, PrimApp PrimLOr (Pair (PrimApp PrimLNot c') t')) -- c ? t : 1 => (not c) || t
_ -> (bs, Select c' t' f')
| otherwise = (bs, Select c' t' f')
in simplify

Assert msg c body -> case travE c of
Const _ 1 -> boundsOptimizeExp env body
c'@(Const _ 0) -> Assert msg c' <$> boundsOptimizeExp env (undefs $ expType body)
Expand Down
3 changes: 3 additions & 0 deletions src/Data/Array/Accelerate/Trafo/Partitioning/ILP/Labels.hs
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,9 @@ getExpDeps (Case poe1 poes poe2) env = getExpDeps poe1 env <>
getExpDeps (Cond poe1 poe2 exp3) env = getExpDeps poe1 env <>
getExpDeps poe2 env <>
getExpDeps exp3 env
getExpDeps (Select poe1 poe2 exp3) env = getExpDeps poe1 env <>
getExpDeps poe2 env <>
getExpDeps exp3 env
getExpDeps (While pof1 pof2 poe) env = getFunDeps pof1 env <>
getFunDeps pof2 env <>
getExpDeps poe env
Expand Down
Loading
Loading