Skip to content

Commit

Permalink
Remove Copy from IR. (#1963)
Browse files Browse the repository at this point in the history
Copy has long been identical to Replicate with an empty shape, and
redundancy is bad.

Closes #1962.
  • Loading branch information
athas committed Jun 15, 2023
1 parent 46a8719 commit 94a6ce1
Show file tree
Hide file tree
Showing 57 changed files with 126 additions and 125 deletions.
3 changes: 0 additions & 3 deletions src/Futhark/AD/Fwd.hs
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,6 @@ basicFwd pat aux op = do
arr_tan <- tangent arr
arrs_tans <- mapM tangent arrs
addStm $ Let pat_tan aux $ BasicOp $ Concat d (arr_tan :| arrs_tans) w
Copy arr -> do
arr_tan <- tangent arr
addStm $ Let pat_tan aux $ BasicOp $ Copy arr_tan
Manifest ds arr -> do
arr_tan <- tangent arr
addStm $ Let pat_tan aux $ BasicOp $ Manifest ds arr_tan
Expand Down
12 changes: 7 additions & 5 deletions src/Futhark/AD/Rev.hs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ diffBasicOp pat aux e m =
updateAdj arr <=< letExp "adj_rotate" . BasicOp $
Rotate rots' pat_adj
--
Replicate (Shape []) (Var se) -> do
(_pat_v, pat_adj) <- commonBasicOp pat aux e m
returnSweepCode $ void $ updateAdj se pat_adj
--
Replicate (Shape ns) x -> do
(_pat_v, pat_adj) <- commonBasicOp pat aux e m
returnSweepCode $ do
Expand Down Expand Up @@ -184,10 +188,6 @@ diffBasicOp pat aux e m =

zipWithM_ updateAdj (arr : arrs) slices
--
Copy se -> do
(_pat_v, pat_adj) <- commonBasicOp pat aux e m
returnSweepCode $ void $ updateAdj se pat_adj
--
Manifest _ se -> do
(_pat_v, pat_adj) <- commonBasicOp pat aux e m
returnSweepCode $ void $ updateAdj se pat_adj
Expand All @@ -211,7 +211,9 @@ diffBasicOp pat aux e m =
t <- lookupType v_adj
v_adj_copy <-
case t of
Array {} -> letExp "update_val_adj_copy" $ BasicOp $ Copy v_adj
Array {} ->
letExp "update_val_adj_copy" . BasicOp $
Replicate mempty (Var v_adj)
_ -> pure v_adj
updateSubExpAdj v v_adj_copy
zeroes <- letSubExp "update_zero" . zeroExp =<< subExpType v
Expand Down
8 changes: 6 additions & 2 deletions src/Futhark/AD/Rev/Hist.hs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,9 @@ diffMinMaxHist _ops x aux n minmax ne is vs w rf dst m = do
dst_type <- lookupType dst
let dst_dims = arrayDims dst_type

dst_cpy <- letExp (baseString dst <> "_copy") $ BasicOp $ Copy dst
dst_cpy <-
letExp (baseString dst <> "_copy") . BasicOp $
Replicate mempty (Var dst)

acc_v_p <- newParam "acc_v" $ Prim t
acc_i_p <- newParam "acc_i" $ Prim int64
Expand Down Expand Up @@ -492,7 +494,9 @@ diffAddHist ::
diffAddHist _ops x aux n add ne is vs w rf dst m = do
let t = paramDec $ head $ lambdaParams add

dst_cpy <- letExp (baseString dst <> "_copy") $ BasicOp $ Copy dst
dst_cpy <-
letExp (baseString dst <> "_copy") . BasicOp $
Replicate mempty (Var dst)

f <- mkIdentityLambda [Prim int64, t]
auxing aux . letBindNames [x] . Op $
Expand Down
3 changes: 2 additions & 1 deletion src/Futhark/AD/Rev/Loop.hs
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ restore stms_adj loop_params' i' =
v' <- letExp "restore" $ BasicOp $ Index vs $ fullSlice vs_t [DimFix i_i64']
t <- lookupType v
v'' <- case (t, v `elem` consumed) of
(Array {}, True) -> letExp "restore_copy" $ BasicOp $ Copy v'
(Array {}, True) ->
letExp "restore_copy" $ BasicOp $ Replicate mempty $ Var v'
_ -> pure v'
pure $ Just (v, v'')
| otherwise = pure Nothing
Expand Down
10 changes: 8 additions & 2 deletions src/Futhark/AD/Rev/Monad.hs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,9 @@ copyConsumedArrsInStm s = inScopeOf s $ collectStms $ copyConsumedArrsInStm' s
v_t <- lookupType v
case v_t of
Array {} -> do
v' <- letExp (baseString v <> "_ad_copy") (BasicOp $ Copy v)
v' <-
letExp (baseString v <> "_ad_copy") . BasicOp $
Replicate mempty (Var v)
addSubstitution v' v
pure [(v, v')]
_ -> pure mempty
Expand All @@ -269,7 +271,11 @@ copyConsumedArrsInBody dontCopy b =
v_t <- lookupType v
case v_t of
Acc {} -> error $ "copyConsumedArrsInBody: Acc " <> prettyString v
Array {} -> M.singleton v <$> letExp (baseString v <> "_ad_copy") (BasicOp $ Copy v)
Array {} ->
M.singleton v
<$> letExp
(baseString v <> "_ad_copy")
(BasicOp $ Replicate mempty (Var v))
_ -> pure mempty

returnSweepCode :: ADM a -> ADM a
Expand Down
4 changes: 3 additions & 1 deletion src/Futhark/AD/Rev/Scatter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ vjpScatter1 pys aux (w, ass, (shp, num_vals, xs)) m = do
-- of the program. In that case the asymptotics will not be
-- (locally) preserved, but since ys must necessarily have been
-- constructed somewhere close, they are probably globally OK.
ys_copy <- letExp (baseString ys <> "_copy") $ BasicOp $ Copy ys
ys_copy <-
letExp (baseString ys <> "_copy") . BasicOp $
Replicate mempty (Var ys)
returnSweepCode $ do
ys_adj <- lookupAdjVal ys
-- computing vs_ctrbs and updating vs_adj
Expand Down
1 change: 0 additions & 1 deletion src/Futhark/Analysis/Metrics.hs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ basicOpMetrics Update {} = seen "Update"
basicOpMetrics FlatIndex {} = seen "FlatIndex"
basicOpMetrics FlatUpdate {} = seen "FlatUpdate"
basicOpMetrics Concat {} = seen "Concat"
basicOpMetrics Copy {} = seen "Copy"
basicOpMetrics Manifest {} = seen "Manifest"
basicOpMetrics Iota {} = seen "Iota"
basicOpMetrics Replicate {} = seen "Replicate"
Expand Down
1 change: 1 addition & 0 deletions src/Futhark/Analysis/SymbolTable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ indexExp table (BasicOp (Replicate (Shape ds) v)) _ is
Just $ Indexed mempty $ primExpFromSubExp t v
indexExp table (BasicOp (Replicate s (Var v))) _ is = do
guard $ v `available` table
guard $ s /= mempty
index' v (drop (shapeRank s) is) table
indexExp table (BasicOp (Reshape _ newshape v)) _ is
| Just oldshape <- arrayDims <$> lookupType v table =
Expand Down
4 changes: 2 additions & 2 deletions src/Futhark/CodeGen/ImpGen.hs
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,8 @@ defCompileBasicOp (Pat [pe]) (FlatUpdate _ slice v) = do
slice' = fmap pe64 slice
defCompileBasicOp (Pat [pe]) (Replicate shape se)
| Acc {} <- patElemType pe = pure ()
| shape == mempty =
copyDWIM (patElemName pe) [] se []
| otherwise =
sLoopNest shape $ \is -> copyDWIMFix (patElemName pe) is se []
defCompileBasicOp _ Scratch {} =
Expand All @@ -951,8 +953,6 @@ defCompileBasicOp (Pat [pe]) (Iota n e s it) = do
BinOpExp (Add it OverflowUndef) e' $
BinOpExp (Mul it OverflowUndef) i' s'
copyDWIM (patElemName pe) [DimFix i] (Var (tvVar x)) []
defCompileBasicOp (Pat [pe]) (Copy src) =
copyDWIM (patElemName pe) [] (Var src) []
defCompileBasicOp (Pat [pe]) (Manifest _ src) =
copyDWIM (patElemName pe) [] (Var src) []
defCompileBasicOp (Pat [pe]) (Concat i (x :| ys) _) = do
Expand Down
4 changes: 1 addition & 3 deletions src/Futhark/CodeGen/ImpGen/GPU/Group.hs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ compileGroupExp _ (BasicOp (UpdateAcc acc is vs)) = do
ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
sWhen (ltid .==. 0) $ updateAcc acc is vs
sOp $ Imp.Barrier Imp.FenceLocal
compileGroupExp (Pat [dest]) (BasicOp (Replicate ds se)) = do
compileGroupExp (Pat [dest]) (BasicOp (Replicate ds se)) | ds /= mempty = do
flat <- newVName "rep_flat"
is <- replicateM (arrayRank dest_t) (newVName "rep_i")
let is' = map le64 is
Expand Down Expand Up @@ -696,8 +696,6 @@ segOpSizes = onStms
S.singleton $ arrayDims $ patElemType pe
onStm (Let (Pat [pe]) _ (BasicOp (Iota {}))) =
S.singleton $ arrayDims $ patElemType pe
onStm (Let (Pat [pe]) _ (BasicOp (Copy {}))) =
S.singleton $ arrayDims $ patElemType pe
onStm (Let (Pat [pe]) _ (BasicOp (Manifest {}))) =
S.singleton $ arrayDims $ patElemType pe
onStm (Let _ _ (Match _ cases defbody _)) =
Expand Down
4 changes: 2 additions & 2 deletions src/Futhark/Construct.hs
Original file line number Diff line number Diff line change
Expand Up @@ -348,12 +348,12 @@ eSignum em = do
_ ->
error $ "eSignum: operand " ++ prettyString e ++ " has invalid type."

-- | Construct a 'Copy' expression.
-- | Copy a value.
eCopy ::
MonadBuilder m =>
m (Exp (Rep m)) ->
m (Exp (Rep m))
eCopy e = BasicOp . Copy <$> (letExp "copy_arg" =<< e)
eCopy e = BasicOp . Replicate mempty <$> (letSubExp "copy_arg" =<< e)

-- | Construct a body from expressions. If multiple expressions are
-- provided, their results will be concatenated in order and returned
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/Mem/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ unExistentialiseMemory vtable pat _ (cond, cases, defbody, ifdec)
v_copy <- newVName $ baseString v <> "_nonext_copy"
let v_pat =
Pat [PatElem v_copy $ MemArray pt shape u $ ArrayIn mem ixfun]
addStm $ mkWiseStm v_pat (defAux ()) $ BasicOp (Copy v)
addStm $ mkWiseStm v_pat (defAux ()) $ BasicOp $ Replicate mempty $ Var v
pure $ SubExpRes cs $ Var v_copy
| Just mem <- lookup (patElemName pat_elem) oldmem_to_mem =
pure $ SubExpRes cs $ Var mem
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/Parse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ pBasicOp =
keyword "trace"
$> uncurry (Opaque . OpaqueTrace)
<*> parens ((,) <$> pStringLiteral <* pComma <*> pSubExp),
keyword "copy" $> Copy <*> parens pVName,
keyword "copy" $> Replicate mempty . Var <*> parens pVName,
keyword "assert"
*> parens
( Assert
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/Pretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ instance Pretty BasicOp where
pretty (Iota e x s et) = "iota" <> et' <> apply [pretty e, pretty x, pretty s]
where
et' = pretty $ show $ primBitSize $ IntType et
pretty (Replicate (Shape []) e) = "copy" <> parens (pretty e)
pretty (Replicate ne ve) =
"replicate" <> apply [pretty ne, align (pretty ve)]
pretty (Scratch t shape) =
Expand All @@ -234,7 +235,6 @@ instance Pretty BasicOp where
"rotate" <> apply [apply (map pretty es), pretty e]
pretty (Concat i (x :| xs) w) =
"concat" <> "@" <> pretty i <> apply (pretty w : pretty x : map pretty xs)
pretty (Copy e) = "copy" <> parens (pretty e)
pretty (Manifest perm e) = "manifest" <> apply [apply (map pretty perm), pretty e]
pretty (Assert e msg (loc, _)) =
"assert" <> apply [pretty e, pretty msg, pretty $ show $ locStr loc]
Expand Down
1 change: 0 additions & 1 deletion src/Futhark/IR/Prop.hs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ safeExp (BasicOp op) = safeBasicOp op
safeBasicOp Manifest {} = True
safeBasicOp Iota {} = True
safeBasicOp Replicate {} = True
safeBasicOp Copy {} = True
safeBasicOp _ = False
safeExp (DoLoop _ _ body) = safeBody body
safeExp (Apply fname _ _ _) =
Expand Down
1 change: 0 additions & 1 deletion src/Futhark/IR/Prop/Aliases.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ basicOpAliases (Reshape _ _ e) = [vnameAliases e]
basicOpAliases (Rearrange _ e) = [vnameAliases e]
basicOpAliases (Rotate _ e) = [vnameAliases e]
basicOpAliases Concat {} = [mempty]
basicOpAliases Copy {} = [mempty]
basicOpAliases Manifest {} = [mempty]
basicOpAliases Assert {} = [mempty]
basicOpAliases UpdateAcc {} = [mempty]
Expand Down
2 changes: 0 additions & 2 deletions src/Futhark/IR/Prop/TypeOf.hs
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ basicOpType (Concat i (x :| _) ressize) =
result <$> lookupType x
where
result xt = [setDimSize i xt ressize]
basicOpType (Copy v) =
pure <$> lookupType v
basicOpType (Manifest _ v) =
pure <$> lookupType v
basicOpType Assert {} =
Expand Down
12 changes: 6 additions & 6 deletions src/Futhark/IR/SOACS/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ liftIdentityMapping _ pat aux op
where
e inp = case patElemType outId of
Acc {} -> BasicOp $ SubExp $ Var inp
_ -> BasicOp (Copy inp)
_ -> BasicOp (Replicate mempty (Var inp))
checkInvariance (outId, SubExpRes _ e, t) (invariant, mapresult, rettype')
| freeOrConst e =
( (Pat [outId], BasicOp $ Replicate (Shape [w]) e) : invariant,
Expand Down Expand Up @@ -308,7 +308,7 @@ liftIdentityStreaming _ (Pat pes) aux (Stream w arrs nes lam)
partitionEithers $ map isInvariantRes $ zip3 map_ts map_pes map_res,
not $ null invariant_map = Simplify $ do
forM_ invariant_map $ \(pe, arr) ->
letBind (Pat [pe]) $ BasicOp $ Copy arr
letBind (Pat [pe]) $ BasicOp $ Replicate mempty $ Var arr

let (variant_map_ts, variant_map_pes, variant_map_res) = unzip3 variant_map
lam' =
Expand Down Expand Up @@ -462,7 +462,7 @@ removeDuplicateMapOutput _ (Pat pes) aux (Screma w arrs form)
}
auxing aux $ letBind (Pat pes') $ Op $ Screma w arrs $ mapSOAC fun'
forM_ copies $ \(from, to) ->
letBind (Pat [to]) $ BasicOp $ Copy $ patElemName from
letBind (Pat [to]) $ BasicOp $ Replicate mempty $ Var $ patElemName from
where
checkForDuplicates (ses_ts_pes', copies) (se, t, pe)
| Just (_, _, pe') <- find (\(x, _, _) -> resSubExp x == resSubExp se) ses_ts_pes' =
Expand Down Expand Up @@ -728,7 +728,7 @@ isArrayOp cs (BasicOp (Rotate rots arr)) =
Just $ ArrayRotate cs arr rots
isArrayOp cs (BasicOp (Reshape k new_shape arr)) =
Just $ ArrayReshape cs arr k new_shape
isArrayOp cs (BasicOp (Copy arr)) =
isArrayOp cs (BasicOp (Replicate (Shape []) (Var arr))) =
Just $ ArrayCopy cs arr
isArrayOp _ _ =
Nothing
Expand All @@ -738,7 +738,7 @@ fromArrayOp (ArrayIndexing cs arr slice) = (cs, BasicOp $ Index arr slice)
fromArrayOp (ArrayRearrange cs arr perm) = (cs, BasicOp $ Rearrange perm arr)
fromArrayOp (ArrayRotate cs arr rots) = (cs, BasicOp $ Rotate rots arr)
fromArrayOp (ArrayReshape cs arr k new_shape) = (cs, BasicOp $ Reshape k new_shape arr)
fromArrayOp (ArrayCopy cs arr) = (cs, BasicOp $ Copy arr)
fromArrayOp (ArrayCopy cs arr) = (cs, BasicOp $ Replicate mempty $ Var arr)
fromArrayOp (ArrayVar cs arr) = (cs, BasicOp $ SubExp $ Var arr)

arrayOps ::
Expand Down Expand Up @@ -964,7 +964,7 @@ moveTransformToInput vtable screma_pat aux soac@(Screma w arrs (ScremaForm scan
ArrayReshape _ _ k new_shape ->
BasicOp $ Reshape k (Shape [w] <> new_shape) arr
ArrayCopy {} ->
BasicOp $ Copy arr
BasicOp $ Replicate mempty $ Var arr
ArrayVar {} ->
BasicOp $ SubExp $ Var arr
arr_transformed_t <- lookupType arr_transformed
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/SegOp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1393,7 +1393,7 @@ bottomUpSegOp (vtable, used) (Pat kpes) dec segop = Simplify $ do
then do
precopy <- newVName $ baseString (patElemName kpe) <> "_precopy"
index kpe {patElemName = precopy}
letBindNames [patElemName kpe] $ BasicOp $ Copy precopy
letBindNames [patElemName kpe] $ BasicOp $ Replicate mempty $ Var precopy
else index kpe
pure
( kpes'',
Expand Down
5 changes: 2 additions & 3 deletions src/Futhark/IR/Syntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,6 @@ data BasicOp
--
-- @concat(1, [[1,2], [3, 4]] :| [[[5,6]], [[7, 8]]], 4) = [[1, 2, 5, 6], [3, 4, 7, 8]]@
Concat Int (NonEmpty VName) SubExp
| -- | Copy the given array. The result will not alias anything.
Copy VName
| -- | Manifest an array with dimensions represented in the given
-- order. The result will not alias anything.
Manifest [Int] VName
Expand All @@ -368,7 +366,8 @@ data BasicOp
-- The t'IntType' indicates the type of the array returned and the
-- offset/stride arguments, but not the length argument.
Iota SubExp SubExp SubExp IntType
| -- | @replicate([3][2],1) = [[1,1], [1,1], [1,1]]@
| -- | @replicate([3][2],1) = [[1,1], [1,1], [1,1]]@. The result
-- has no aliases. Copy a value by passing an empty shape.
Replicate Shape SubExp
| -- | Create array of given type and shape, with undefined elements.
Scratch PrimType [SubExp]
Expand Down
4 changes: 0 additions & 4 deletions src/Futhark/IR/Traversals.hs
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,6 @@ mapExpM tv (BasicOp (Concat i (x :| ys) size)) = do
ys' <- mapM (mapOnVName tv) ys
size' <- mapOnSubExp tv size
pure $ BasicOp $ Concat i (x' :| ys') size'
mapExpM tv (BasicOp (Copy e)) =
BasicOp <$> (Copy <$> mapOnVName tv e)
mapExpM tv (BasicOp (Manifest perm e)) =
BasicOp <$> (Manifest perm <$> mapOnVName tv e)
mapExpM tv (BasicOp (Assert e msg loc)) =
Expand Down Expand Up @@ -337,8 +335,6 @@ walkExpM tv (BasicOp (Rotate es e)) =
mapM_ (walkOnSubExp tv) es >> walkOnVName tv e
walkExpM tv (BasicOp (Concat _ (x :| ys) size)) =
walkOnVName tv x >> mapM_ (walkOnVName tv) ys >> walkOnSubExp tv size
walkExpM tv (BasicOp (Copy e)) =
walkOnVName tv e
walkExpM tv (BasicOp (Manifest _ e)) =
walkOnVName tv e
walkExpM tv (BasicOp (Assert e msg _)) =
Expand Down
2 changes: 0 additions & 2 deletions src/Futhark/IR/TypeCheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -934,8 +934,6 @@ checkBasicOp (Concat i (arr1exp :| arr2exps) ressize) = do
bad $
TypeError "Types of arguments to concat do not match."
require [Prim int64] ressize
checkBasicOp (Copy e) =
void $ checkArrIdent e
checkBasicOp (Manifest perm arr) =
checkBasicOp $ Rearrange perm arr -- Basically same thing!
checkBasicOp (Assert e (ErrorMsg parts) _) = do
Expand Down
5 changes: 3 additions & 2 deletions src/Futhark/Internalise/Exps.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1771,13 +1771,14 @@ isIntrinsicFunction qname args loc = do
r <- I.arrayRank <$> lookupType v
pure $ I.Rearrange ([1, 0] ++ [2 .. r - 1]) v
handleRest [x, y] "zip" = Just $ \desc ->
mapM (letSubExp "zip_copy" . BasicOp . Copy)
mapM (letSubExp "zip_copy" . BasicOp . Replicate mempty . I.Var)
=<< ( (++)
<$> internaliseExpToVars (desc ++ "_zip_x") x
<*> internaliseExpToVars (desc ++ "_zip_y") y
)
handleRest [x] "unzip" = Just $ \desc ->
mapM (letSubExp desc . BasicOp . Copy) =<< internaliseExpToVars desc x
mapM (letSubExp desc . BasicOp . Replicate mempty . I.Var)
=<< internaliseExpToVars desc x
handleRest [arr, offset, n1, s1, n2, s2] "flat_index_2d" = Just $ \desc -> do
flatIndexHelper desc loc arr offset [(n1, s1), (n2, s2)]
handleRest [arr1, offset, s1, s2, arr2] "flat_update_2d" = Just $ \desc -> do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,7 @@ genCoalStmtInfo ::
Stm (Aliases rep) ->
ShortCircuitM rep (Maybe [SSPointInfo])
-- CASE a) @let x <- copy(b^{lu})@
genCoalStmtInfo lutab _ scopetab (Let pat aux (BasicOp (Copy b)))
genCoalStmtInfo lutab _ scopetab (Let pat aux (BasicOp (Replicate (Shape []) (Var b))))
| Pat [PatElem x (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat =
pure $ case (M.lookup x lutab, getScopeMemInfo b scopetab) of
(Just last_uses, Just (MemBlock tpb shpb m_b ind_b)) ->
Expand Down
1 change: 0 additions & 1 deletion src/Futhark/Optimise/ArrayShortCircuiting/DataStructs.hs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,6 @@ createsNewArrOK :: Exp rep -> Bool
createsNewArrOK (BasicOp Replicate {}) = True
createsNewArrOK (BasicOp Iota {}) = True
createsNewArrOK (BasicOp Manifest {}) = True
createsNewArrOK (BasicOp Copy {}) = True
createsNewArrOK (BasicOp Concat {}) = True
createsNewArrOK (BasicOp ArrayLit {}) = True
createsNewArrOK (BasicOp Scratch {}) = True
Expand Down
5 changes: 3 additions & 2 deletions src/Futhark/Optimise/ArrayShortCircuiting/MemRefAggreg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,13 @@ getUseSumFromStm td_env coal_tab (Let (Pat [x']) _ (BasicOp (Update _ _x (Slice
Var a -> case getDirAliasedIxfn td_env coal_tab a of
Nothing -> Just ([r1], [r1])
Just r2 -> Just ([r1], [r1, r2])
getUseSumFromStm td_env coal_tab (Let (Pat [y]) _ (BasicOp (Copy x))) = do
getUseSumFromStm td_env coal_tab (Let (Pat [y]) _ (BasicOp (Replicate (Shape []) (Var x)))) = do
-- y = copy x
wrt <- getDirAliasedIxfn td_env coal_tab $ patElemName y
rd <- getDirAliasedIxfn td_env coal_tab x
pure ([wrt], [wrt, rd])
getUseSumFromStm _ _ (Let Pat {} _ (BasicOp Copy {})) = error "Impossible"
getUseSumFromStm _ _ (Let Pat {} _ (BasicOp (Replicate (Shape []) _))) =
error "Impossible"
getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp (Concat _i (a :| bs) _ses))) =
-- concat
let ws = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys
Expand Down
Loading

0 comments on commit 94a6ce1

Please sign in to comment.