diff --git a/src/Futhark/CodeGen/ImpGen/Multicore.hs b/src/Futhark/CodeGen/ImpGen/Multicore.hs index e430a5775b..a57cc59e8e 100644 --- a/src/Futhark/CodeGen/ImpGen/Multicore.hs +++ b/src/Futhark/CodeGen/ImpGen/Multicore.hs @@ -48,6 +48,27 @@ parallelCopy pt destloc srcloc = do emit $ Imp.Read tmp srcmem srcidx pt srcspace Imp.Nonvolatile emit $ Imp.Write destmem destidx pt destspace Imp.Nonvolatile $ Imp.var tmp pt +parallelRotate :: VName -> [Imp.TExp Int64] -> VName -> MulticoreGen () +parallelRotate dest rs src = do + t <- lookupType src + let ds = map pe64 $ arrayDims t + seq_code <- collect $ localOps inThreadOps $ do + body <- genRotate ds + free_params <- freeParams body + emit $ Imp.Op $ Imp.ParLoop "rotate" body free_params + free_params <- freeParams seq_code + s <- prettyString <$> newVName "rotate" + iterations <- dPrimVE "iterations" $ product ds + let scheduling = Imp.SchedulerInfo (untyped iterations) Imp.Static + emit . Imp.Op $ + Imp.SegOp s free_params (Imp.ParallelTask seq_code) Nothing [] scheduling + where + genRotate ds = collect . inISPC . generateChunkLoop "rotate" Vectorized $ \i -> do + is' <- dIndexSpace' "rep_i" ds i + is'' <- sequence $ zipWith3 rotate ds rs is' + copyDWIMFix dest is' (Var src) is'' + rotate d r i = dPrimVE "rot_i" $ rotateIndex d r i + topLevelOps, inThreadOps :: Operations MCMem HostEnv Imp.Multicore inThreadOps = (defaultOperations opCompiler) @@ -113,6 +134,10 @@ withAcc pat inputs lam = do locksForInputs atomics inputs' compileMCExp :: ExpCompiler MCMem HostEnv Imp.Multicore +compileMCExp (Pat [pe]) (BasicOp (Rotate rs arr)) + | Acc {} <- patElemType pe = pure () + | otherwise = + parallelRotate (patElemName pe) (map pe64 rs) arr compileMCExp _ (BasicOp (UpdateAcc acc is vs)) = updateAcc acc is vs compileMCExp pat (WithAcc inputs lam) =