@@ -3,13 +3,23 @@ module Matrizer.RewriteRules (
3
3
Rules ,
4
4
optimizationRules ,
5
5
optimizationRulesFull ,
6
+ rotateTraceLeft ,
7
+ rotateTraceRight ,
8
+ productsToList ,
9
+ smartCommonFactor ,
10
+ optimalProduct ,
11
+ triplePartition ,
12
+ commonPrefix
6
13
) where
7
14
8
15
import Data.Maybe
16
+ import Data.Array
9
17
10
18
import Matrizer.MTypes
11
19
import Matrizer.Analysis
12
20
21
+ import Debug.Trace
22
+
13
23
------------------------------------------------------------------
14
24
-- List of optimizations
15
25
--
@@ -39,12 +49,12 @@ type Rules = [Rule]
39
49
40
50
optimizationRules :: Rules
41
51
optimizationRules = inverseRules ++ transposeRules ++ binopSumRules ++
42
- binopProductRules ++ ternProductRules ++ letExpRules ++ traceRules ++ detRules ++ diagRules ++ entrySumRules ++ hadamardProductRules ++ elementWiseRules
52
+ binopProductRules ++ ternProductRules ++ letExpRules ++ traceRules ++ detRules ++ diagRules ++ entrySumRules ++ hadamardProductRules ++ elementWiseRules ++ [smartCommonFactor]
43
53
44
54
-- moves that are valid and sometimes necessary, but generate many
45
55
-- matches and can slow down inference.
46
56
expensiveMoves :: Rules
47
- expensiveMoves = [introduceTranspose]
57
+ expensiveMoves = [introduceTranspose, productRtL, productLtR ]
48
58
49
59
optimizationRulesFull = optimizationRules ++ expensiveMoves
50
60
@@ -107,6 +117,8 @@ traceRules = [dissolveTrace
107
117
, linearTrace
108
118
, identityOps
109
119
, traceProduct
120
+ , rotateTraceLeft
121
+ , rotateTraceRight
110
122
, traceDiag]
111
123
112
124
detRules :: Rules
@@ -167,6 +179,135 @@ assocMult _ (Branch2 MScalarProduct (Branch2 MScalarProduct l c) r) = Just (Bran
167
179
assocMult _ (Branch2 MScalarProduct l (Branch2 MScalarProduct c r)) = Just (Branch2 MScalarProduct (Branch2 MScalarProduct l c) r)
168
180
assocMult _ _ = Nothing
169
181
182
+
183
+ -----------------------------------------------------------------------
184
+ -- 'smart' product rules that collapse an entire chain of products (a*b*c*d*e*....) into a list,
185
+ -- then re-expand that list
186
+ -- TODO: write code for the optimal re-expansion
187
+
188
+ productsToList :: Expr -> [Expr ]
189
+ productsToList (Branch2 MProduct a b) = (productsToList a) ++ (productsToList b)
190
+ productsToList (Branch3 MTernaryProduct a b c) = (productsToList a) ++ (productsToList b) ++ (productsToList c)
191
+ productsToList a = [a]
192
+
193
+ expandProductLtR :: [Expr ] -> Expr
194
+ expandProductLtR (x: [] ) = x
195
+ expandProductLtR (x: xs) = (Branch2 MProduct x (expandProductLtR xs))
196
+
197
+ -- takes a reversed list
198
+ expandProductRtL :: [Expr ] -> Expr
199
+ expandProductRtL (x: [] ) = x
200
+ expandProductRtL (x: xs) = (Branch2 MProduct (expandProductRtL xs) x)
201
+
202
+ productLtR :: Rule
203
+ productLtR _ z@ (Branch2 MProduct (Branch2 MProduct _ _) _) = Just $ expandProductLtR $ productsToList z
204
+ productLtR _ z@ (Branch2 MProduct _ (Branch2 MProduct _ _)) = Just $ expandProductLtR $ productsToList z
205
+ productLtR _ z@ (Branch3 MTernaryProduct _ _ _) = Just $ expandProductLtR $ productsToList z
206
+ productLtR _ _ = Nothing
207
+
208
+ productRtL :: Rule
209
+ productRtL _ z@ (Branch2 MProduct (Branch2 MProduct _ _) _) = Just $ expandProductRtL $ reverse $ productsToList z
210
+ productRtL _ z@ (Branch2 MProduct _ (Branch2 MProduct _ _)) = Just $ expandProductRtL $ reverse $ productsToList z
211
+ productRtL _ z@ (Branch3 MTernaryProduct _ _ _) = Just $ expandProductRtL $ reverse $ productsToList z
212
+ productRtL _ _ = Nothing
213
+
214
+
215
+ -- flopCost, split, rows, cols
216
+ optimalProductArray :: [(Int , Int )] -> Array (Int , Int ) (Int , Int , Int , Int )
217
+ optimalProductArray sizeList = r
218
+ where
219
+ n = length sizeList
220
+ (rows, cols) = unzip sizeList
221
+ r = listArray ((0 ,0 ),(n,n)) (map oP (range ((0 ,0 ), (n,n))))
222
+ oP (i, j) = if i== j then (0 , i, 0 , 0 )
223
+ else if (j== i+ 1 ) then (0 , i, rows!! i, cols!! i)
224
+ else let costs = [(prodCost (r! (i,k)) (r! (k,j)), k) | k <- [(i+ 1 ).. (j- 1 )]]
225
+ (minCost, k) = minimum costs in
226
+ (minCost, k, rows!! i, cols!! (j- 1 ))
227
+ prodCost (f1, k1, r1, c1) (f2, k2, r2, c2) = f1 + f2 + (r1* c2 * (2 * c1- 1 ))
228
+
229
+ optimalProduct tbl exprList =
230
+ case length exprList of
231
+ 1 -> head exprList
232
+ 2 -> (Branch2 MProduct (exprList!! 0 ) (exprList!! 1 ) )
233
+ _ -> let sizeList = map exprSize exprList
234
+ r = optimalProductArray sizeList in
235
+ assembleProduct exprList r 0 (length exprList)
236
+ where exprSize e = let (Right (Matrix a b _)) = treeMatrix e tbl in
237
+ (a, b)
238
+ assembleProduct exps r i j = if j== i+ 1 then exps!! i
239
+ else let (flops, k, a, b) = r! (i,j) in
240
+ (Branch2 MProduct (assembleProduct exps r i k)
241
+ (assembleProduct exps r k j))
242
+
243
+
244
+
245
+
246
+
247
+ smartCommonFactor tbl (Branch2 MSum a@ (Branch2 MProduct _ _) b@ (Branch2 MProduct _ _)) = scfHelperOuter tbl a b
248
+ smartCommonFactor tbl (Branch2 MSum a@ (Branch3 MTernaryProduct _ _ _) b@ (Branch2 MProduct _ _)) = scfHelperOuter tbl a b
249
+ smartCommonFactor tbl (Branch2 MSum a@ (Branch2 MProduct _ _) b@ (Branch3 MTernaryProduct _ _ _)) = scfHelperOuter tbl a b
250
+ smartCommonFactor tbl (Branch2 MSum a@ (Branch3 MTernaryProduct _ _ _) b@ (Branch3 MTernaryProduct _ _ _)) = scfHelperOuter tbl a b
251
+ smartCommonFactor _ _ = Nothing
252
+
253
+
254
+ scfHelperOuter tbl a b =
255
+ let list1 = productsToList a
256
+ list2 = productsToList b in
257
+ if list1== list2 then Just (Branch2 MScalarProduct (LiteralScalar 2.0 ) a)
258
+ else let (ncommon, clist) = (scfHelper tbl list1 list2) in
259
+ if ncommon == 0 then Nothing
260
+ else Just (optimalProduct tbl clist)
261
+
262
+ -- divide two lists into a common prefix, common suffix, and 'cores' that are different.
263
+ -- for example, [1,4,8,2,3] and [1,4,0,3] becomes prefix=[1,4], cores=([8,2], [0]), suffix=[3]
264
+ triplePartition a b = let (start, rA, rB) = commonPrefix a b
265
+ (rend, rcoreA, rcoreB) = commonPrefix (reverse rA) (reverse rB)
266
+ (end, coreA, coreB) = (reverse rend, reverse rcoreA, reverse rcoreB) in
267
+ (start, (coreA, coreB), end)
268
+ commonPrefix (x: xs) (y: ys) = if x== y
269
+ then let (p1, r1, r2) = commonPrefix xs ys in
270
+ ((x: p1), r1, r2)
271
+ else ([] , (x: xs), (y: ys))
272
+ commonPrefix a [] = ([] , a, [] )
273
+ commonPrefix [] b = ([] , [] , b)
274
+
275
+ -- we assume the lists are unequal, so at most one of the cores can be empty
276
+ scfHelper tbl list1 list2 =
277
+ let (start, (coreA, coreB), end) = triplePartition list1 list2
278
+ (cA, cB) = (optProdOrEmpty tbl coreA coreB, optProdOrEmpty tbl coreB coreA) in
279
+ (((length start) + (length end)), start ++ [Branch2 MSum cA cB] ++ end)
280
+ where
281
+ optProdOrEmpty tbl [] (b: bs) = let Right (Matrix r1 r2 [] ) = treeMatrix b tbl in
282
+ (IdentityLeaf r1)
283
+ optProdOrEmpty tbl a _ = optimalProduct tbl a
284
+
285
+
286
+ rotateTraceLeft :: Rule
287
+ rotateTraceLeft tbl (Branch1 MTrace z@ (Branch2 MProduct _ _)) = rtlHelper tbl z
288
+ rotateTraceLeft tbl (Branch1 MTrace z@ (Branch3 MTernaryProduct _ _ _)) = rtlHelper tbl z
289
+ rotateTraceLeft _ _ = Nothing
290
+ rtlHelper tbl z =
291
+ let x: xs = productsToList z
292
+ rotated = xs ++ [x]
293
+ candidate = expandProductRtL $ reverse rotated in
294
+ case treeMatrix candidate tbl of
295
+ Right m -> Just (Branch1 MTrace candidate)
296
+ Left err -> Nothing
297
+
298
+ rotateTraceRight :: Rule
299
+ rotateTraceRight tbl (Branch1 MTrace z@ (Branch2 MProduct _ _)) = rtrHelper tbl z
300
+ rotateTraceRight tbl (Branch1 MTrace z@ (Branch3 MTernaryProduct _ _ _)) = rtrHelper tbl z
301
+ rotateTraceRight _ _ = Nothing
302
+ rtrHelper tbl z =
303
+ let x: xs = reverse $ productsToList z
304
+ candidate = expandProductRtL $ (xs ++ [x]) in
305
+ case treeMatrix candidate tbl of
306
+ Right m -> Just (Branch1 MTrace candidate)
307
+ Left err -> Nothing
308
+
309
+ -------------------------------------------------
310
+
170
311
-- (AC + BC) -> (A+B)C
171
312
commonFactorRight :: Rule
172
313
commonFactorRight _ (Branch2 MSum (Branch2 MProduct l1 l2) (Branch2 MProduct r1 r2)) =
0 commit comments