Skip to content

Commit 8a46229

Browse files
committed
smart matrix multiplication (not enabled) and common factor detection
1 parent 9fe6c10 commit 8a46229

File tree

6 files changed

+206
-64
lines changed

6 files changed

+206
-64
lines changed

matrizer.cabal

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ executable matrizer
2828

2929

3030
Library
31-
Build-Depends: base >=4.5 && < 5, containers, mtl >=2.1 && <2.2, parsec >=3.1.3 && <3.2, multimap >=1.2, transformers >=0.3 && <0.4, heap >= 1.0.2, filepath, directory
31+
Build-Depends: base >=4.5 && < 5, containers, mtl >=2.1 && <2.2, parsec >=3.1.3 && <3.2, multimap >=1.2, transformers >=0.3 && <0.4, heap >= 1.0.2, filepath, directory, array
3232
hs-source-dirs: src
3333
Exposed-modules: Matrizer.MTypes, Matrizer.Parsing, Matrizer.Optimization, Matrizer.Analysis, Matrizer.CodeGen, Matrizer.Util, Matrizer.Search, Matrizer.Preprocess, Matrizer.Derivatives, Matrizer.RewriteRules, Matrizer.Equivalence
3434
default-language: Haskell2010

src/Matrizer/Derivatives.hs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import Matrizer.Optimization
1212
import Matrizer.Analysis
1313
import Matrizer.Search
1414

15+
import Debug.Trace
16+
1517
llexpr1 = (Branch3 MTernaryProduct (Branch1 MTranspose (Leaf "w")) (Branch2 MProduct (Branch1 MTranspose (Leaf "X")) (Leaf "X") ) (Leaf "w"))
1618
llexpr2 = (Branch2 MScalarProduct (LiteralScalar 2) (Branch2 MProduct (Branch1 MTranspose (Leaf "w")) (Branch2 MProduct (Branch1 MTranspose (Leaf "X")) (Leaf "y"))))
1719
llexpr3 = (Branch2 MProduct (Branch1 MTranspose (Leaf "y")) (Leaf "y"))
@@ -167,18 +169,15 @@ beamSearch2 fn iters beamSize nRewrites tbl beam =
167169

168170

169171
llSymbols2 :: SymbolTable
170-
llSymbols2 = Map.fromList [("X", Matrix 100 100 []), ("B", Matrix 100 100 []),("A", Matrix 100 100 []),("S", Matrix 100 100 []),("y", Matrix 100 1 []), ("C", Matrix 100 100 [])]
172+
llSymbols2 = Map.fromList [("X", Matrix 100 60 []), ("B", Matrix 100 100 []),("A", Matrix 100 100 []),("S", Matrix 60 60 []),("y", Matrix 100 1 []), ("C", Matrix 100 100 [])]
171173

172174
pp = (Branch3 MTernaryProduct (Leaf "X") (Leaf "S") (Branch1 MTranspose (Leaf "X")))
173175
ll17 = (Branch1 MTrace (Branch3 MTernaryProduct (Leaf "A") (Branch1 MInverse pp) (Leaf "B")))
174176
dl17 = reduceDifferential "X" ll17
175177

176178

177-
178-
dl17_1 = Branch1 MTrace (Branch2 MProduct (Leaf "A") (Branch2 MProduct (Branch3 MTernaryProduct (Leaf "C") (Branch2 MProduct (Branch1 MDifferential (Leaf "X")) (Branch2 MProduct (Leaf "S") (Branch1 MTranspose (Leaf "X")))) (Leaf "C")) (Leaf "B")))
179-
dl17_2 = Branch1 MTrace (Branch2 MProduct (Leaf "A") (Branch2 MProduct (Branch3 MTernaryProduct (Leaf "C") (Branch2 MProduct (Leaf "X") (Branch2 MProduct (Leaf "S") (Branch1 MTranspose (Branch1 MDifferential (Leaf "X"))))) (Leaf "C")) (Leaf "B")))
180-
181-
179+
llS = Map.fromList [("X", Matrix 100 2 []), ("w", Matrix 2 1 []),("d", Matrix 2 1 []) ]
180+
r = (Branch2 MSum (Branch2 MProduct (Branch2 MProduct (Branch2 MProduct (Branch1 MTranspose (Leaf "d")) (Branch1 MTranspose (Leaf "X"))) (Leaf "X")) (Leaf "w")) (Branch2 MProduct (Branch2 MProduct (Branch2 MProduct (Branch1 MTranspose (Leaf "w")) (Branch1 MTranspose (Leaf "X"))) (Leaf "X")) (Leaf "d")))
182181

183182
decompose :: Expr -> [(Float, Expr)]
184183
decompose (Branch2 MScalarProduct (LiteralScalar c) a) = [(c1*c, d) | (c1, d) <- decompose a]
@@ -214,7 +213,7 @@ differentiate tbl (Branch2 MDiff a b) c = do da <- differentiate tbl a c
214213
differentiate tbl (Branch2 MScalarProduct (LiteralScalar s) a) c =
215214
do da <- differentiate tbl a c
216215
return $ (Branch2 MScalarProduct (LiteralScalar s) da)
217-
differentiate tbl expr c = differentiateBySearch tbl expr c
216+
differentiate tbl expr c = differentiateBySearch tbl expr c
218217

219218

220219
differentiateBySearch tbl expr c =
@@ -227,6 +226,8 @@ differentiateBySearch tbl expr c =
227226
derivFromAstar :: SymbolTable -> Expr -> Maybe Expr
228227
derivFromAstar tbl expr = do r <- runAstar tbl expr
229228
extractDeriv tbl (nvalue r)
229+
230+
230231

231232
reduceWithTrace tbl expr c = let d = reduceDifferential c expr
232233
Right (Matrix d1 d2 dprops) = treeMatrix d tbl in

src/Matrizer/MTypes.hs

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ data Expr = Leaf VarName
1818
| Branch2 BinOp Expr Expr
1919
| Branch3 TernOp Expr Expr Expr
2020
| Let VarName Expr Bool Expr -- bool flag specifies whether this intermediate variable can be optimized out
21-
deriving (Eq, Ord)
21+
deriving (Eq, Ord, Show)
2222

23-
data TernOp = MTernaryProduct deriving (Eq, Ord)
23+
data TernOp = MTernaryProduct deriving (Eq, Ord, Show)
2424
data BinOp = MProduct
2525
| MSum
2626
| MDiff
@@ -30,7 +30,7 @@ data BinOp = MProduct
3030
| MScalarProduct
3131
| MHadamardProduct
3232
| MColProduct
33-
deriving (Eq, Ord, Enum)
33+
deriving (Eq, Ord, Enum, Show)
3434

3535
data UnOp = MInverse
3636
| MTranspose
@@ -44,60 +44,58 @@ data UnOp = MInverse
4444
| MDiagMV -- extract a matrix diagonal as a vector
4545
| MEntrySum
4646
| MElementWise ScalarOp
47-
deriving (Eq, Ord)
47+
deriving (Eq, Ord, Show)
4848

4949
data ScalarOp = MLog
5050
| MExp -- TODO: support matrix exponentials
5151
| MReciprocal
52-
deriving (Eq, Ord, Enum)
52+
deriving (Eq, Ord, Enum, Show)
5353

5454
-- AST pretty printing
55+
pprint_ternop MTernaryProduct = "***"
5556

56-
instance Show TernOp where
57-
show _ = "***"
58-
59-
instance Show BinOp where
60-
show MProduct = "mmul"
61-
show MScalarProduct = "smul"
62-
show MHadamardProduct = "hmul"
63-
show MColProduct = "cmul" -- don't really expect people to use this input syntax
57+
pprint_binop MProduct = "mmul"
58+
pprint_binop MScalarProduct = "smul"
59+
pprint_binop MHadamardProduct = "hmul"
60+
pprint_binop MColProduct = "cmul" -- don't really expect people to use this input syntax
6461
-- except for internal test cases
65-
show MSum = "add"
66-
show MDiff = "sub"
67-
show MLinSolve = "solve"
68-
show MTriSolve = "triSolve"
69-
show MCholSolve = "cholSolve"
70-
71-
instance Show UnOp where
72-
show MInverse = "inv"
73-
show MTranspose = "transpose"
74-
show MChol = "chol"
75-
show MTrace = "tr"
76-
show (MDeriv v) = "deriv_" ++ v
77-
show (MUnresolvedDeriv v) = "unresolved_deriv_" ++ v
78-
show MDifferential = "differential"
79-
show MDet = "det"
80-
show MDiagVM = "toDiag"
81-
show MDiagMV = "diag"
82-
show MEntrySum = "sum"
83-
show (MElementWise sop) = show sop
84-
85-
instance Show ScalarOp where
86-
show MLog = "log"
87-
show MExp = "exp"
88-
show MReciprocal = "recip"
89-
90-
instance Show Expr where
91-
show (Leaf a) = a
92-
show (IdentityLeaf _) = "I"
93-
show (ZeroLeaf _ _) = "0"
94-
show (LiteralScalar x) = show x
95-
show (Branch1 op c) = "(" ++ show op ++ " " ++ show c ++ ")"
96-
show (Branch2 op a b) = "(" ++ show op ++ " " ++ show a ++ " "
97-
++ show b ++ ")"
98-
show (Branch3 op a b c) = "(" ++ show op ++ " " ++ show a ++ " "
99-
++ show b ++ " " ++ show c ++ ")"
100-
show (Let v a tmp b) = "(let (" ++ v ++ " := " ++ show a ++ (if tmp then " #temporary ) " else ") ") ++ "\n" ++ show b ++ ")"
62+
pprint_binop MSum = "add"
63+
pprint_binop MDiff = "sub"
64+
pprint_binop MLinSolve = "solve"
65+
pprint_binop MTriSolve = "triSolve"
66+
pprint_binop MCholSolve = "cholSolve"
67+
68+
69+
pprint_unop MInverse = "inv"
70+
pprint_unop MTranspose = "transpose"
71+
pprint_unop MChol = "chol"
72+
pprint_unop MTrace = "tr"
73+
pprint_unop (MDeriv v) = "deriv_" ++ v
74+
pprint_unop (MUnresolvedDeriv v) = "unresolved_deriv_" ++ v
75+
pprint_unop MDifferential = "differential"
76+
pprint_unop MDet = "det"
77+
pprint_unop MDiagVM = "toDiag"
78+
pprint_unop MDiagMV = "diag"
79+
pprint_unop MEntrySum = "sum"
80+
pprint_unop (MElementWise sop) = pprint_scalarop sop
81+
82+
pprint_scalarop MLog = "log"
83+
pprint_scalarop MExp = "exp"
84+
pprint_scalarop MReciprocal = "recip"
85+
86+
87+
pprint (Leaf a) = a
88+
pprint (IdentityLeaf _) = "I"
89+
pprint (ZeroLeaf _ _) = "0"
90+
pprint (LiteralScalar x) = show x
91+
pprint (Branch1 op c) = "(" ++ pprint_unop op ++ " " ++ pprint c ++ ")"
92+
pprint (Branch2 op a b) = "(" ++ pprint_binop op ++ " " ++ pprint a ++ " "
93+
++ pprint b ++ ")"
94+
pprint (Branch3 op a b c) = "(" ++ pprint_ternop op ++ " " ++ pprint a ++ " "
95+
++ pprint b ++ " " ++ pprint c ++ ")"
96+
pprint (Let v a tmp b) = "(let (" ++ v ++ " := " ++ pprint a ++ (if tmp then " #temporary ) " else ") ") ++ "\n" ++ pprint b ++ ")"
97+
98+
10199

102100
------------------------------------------------------------------------
103101
-- Symbol Table Definition

src/Matrizer/Optimization.hs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import Control.Monad.Error
1111
import Matrizer.MTypes
1212
import Matrizer.Analysis
1313
import Matrizer.RewriteRules
14+
import Debug.Trace
1415

1516
----------------------------------
1617

@@ -343,7 +344,7 @@ optimizerTraversal fn tbl rules z@( n@(Branch1 _ _), _) =
343344
-- the net change in FLOPs
344345
optimizeAtNode :: ScoreFn -> SymbolTable -> [Rule] -> Expr -> ThrowsError [(Expr, Int)]
345346
optimizeAtNode fn tbl rules t = let opts = mapMaybeFunc t [f tbl | f <- rules ] in
346-
scoreOptimizations fn tbl t opts
347+
scoreOptimizations fn tbl t opts
347348

348349
scoreOptimizations :: ScoreFn -> SymbolTable -> Expr -> [Expr] -> ThrowsError [(Expr, Int)]
349350
scoreOptimizations fn tbl t opts = mapM (scoreOpt t) opts where
@@ -373,3 +374,4 @@ mapMaybeFunc x (f:fs) =
373374

374375
recognizeVar :: VarName -> Expr -> SymbolTable -> Expr -> Maybe Expr
375376
recognizeVar var rhs tbl tree = if rhs==tree then Just (Leaf var) else Nothing
377+

src/Matrizer/RewriteRules.hs

Lines changed: 143 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,23 @@ module Matrizer.RewriteRules (
33
Rules,
44
optimizationRules,
55
optimizationRulesFull,
6+
rotateTraceLeft,
7+
rotateTraceRight,
8+
productsToList,
9+
smartCommonFactor,
10+
optimalProduct,
11+
triplePartition,
12+
commonPrefix
613
) where
714

815
import Data.Maybe
16+
import Data.Array
917

1018
import Matrizer.MTypes
1119
import Matrizer.Analysis
1220

21+
import Debug.Trace
22+
1323
------------------------------------------------------------------
1424
-- List of optimizations
1525
--
@@ -39,12 +49,12 @@ type Rules = [Rule]
3949

4050
optimizationRules :: Rules
4151
optimizationRules = inverseRules ++ transposeRules ++ binopSumRules ++
42-
binopProductRules ++ ternProductRules ++ letExpRules ++ traceRules ++ detRules ++ diagRules ++ entrySumRules ++ hadamardProductRules ++ elementWiseRules
52+
binopProductRules ++ ternProductRules ++ letExpRules ++ traceRules ++ detRules ++ diagRules ++ entrySumRules ++ hadamardProductRules ++ elementWiseRules ++ [smartCommonFactor]
4353

4454
-- moves that are valid and sometimes necessary, but generate many
4555
-- matches and can slow down inference.
4656
expensiveMoves :: Rules
47-
expensiveMoves = [introduceTranspose]
57+
expensiveMoves = [introduceTranspose, productRtL, productLtR]
4858

4959
optimizationRulesFull = optimizationRules ++ expensiveMoves
5060

@@ -107,6 +117,8 @@ traceRules = [dissolveTrace
107117
, linearTrace
108118
, identityOps
109119
, traceProduct
120+
, rotateTraceLeft
121+
, rotateTraceRight
110122
, traceDiag]
111123

112124
detRules :: Rules
@@ -167,6 +179,135 @@ assocMult _ (Branch2 MScalarProduct (Branch2 MScalarProduct l c) r) = Just (Bran
167179
assocMult _ (Branch2 MScalarProduct l (Branch2 MScalarProduct c r)) = Just (Branch2 MScalarProduct (Branch2 MScalarProduct l c) r)
168180
assocMult _ _ = Nothing
169181

182+
183+
-----------------------------------------------------------------------
184+
-- 'smart' product rules that collapse an entire chain of products (a*b*c*d*e*....) into a list,
185+
-- then re-expand that list
186+
-- TODO: write code for the optimal re-expansion
187+
188+
productsToList :: Expr -> [Expr]
189+
productsToList (Branch2 MProduct a b) = (productsToList a) ++ (productsToList b)
190+
productsToList (Branch3 MTernaryProduct a b c) = (productsToList a) ++ (productsToList b) ++ (productsToList c)
191+
productsToList a = [a]
192+
193+
expandProductLtR :: [Expr] -> Expr
194+
expandProductLtR (x:[]) = x
195+
expandProductLtR (x:xs) = (Branch2 MProduct x (expandProductLtR xs))
196+
197+
-- takes a reversed list
198+
expandProductRtL :: [Expr] -> Expr
199+
expandProductRtL (x:[]) = x
200+
expandProductRtL (x:xs) = (Branch2 MProduct (expandProductRtL xs) x)
201+
202+
productLtR :: Rule
203+
productLtR _ z@(Branch2 MProduct (Branch2 MProduct _ _) _) = Just $ expandProductLtR $ productsToList z
204+
productLtR _ z@(Branch2 MProduct _ (Branch2 MProduct _ _)) = Just $ expandProductLtR $ productsToList z
205+
productLtR _ z@(Branch3 MTernaryProduct _ _ _) = Just $ expandProductLtR $ productsToList z
206+
productLtR _ _ = Nothing
207+
208+
productRtL :: Rule
209+
productRtL _ z@(Branch2 MProduct (Branch2 MProduct _ _) _) = Just $ expandProductRtL $ reverse $ productsToList z
210+
productRtL _ z@(Branch2 MProduct _ (Branch2 MProduct _ _)) = Just $ expandProductRtL $ reverse $ productsToList z
211+
productRtL _ z@(Branch3 MTernaryProduct _ _ _) = Just $ expandProductRtL $ reverse $ productsToList z
212+
productRtL _ _ = Nothing
213+
214+
215+
-- flopCost, split, rows, cols
216+
optimalProductArray :: [(Int, Int)] -> Array (Int, Int) (Int, Int, Int, Int)
217+
optimalProductArray sizeList = r
218+
where
219+
n = length sizeList
220+
(rows, cols) = unzip sizeList
221+
r = listArray ((0,0),(n,n)) (map oP (range ((0,0), (n,n))))
222+
oP (i, j) = if i==j then (0, i, 0, 0)
223+
else if (j==i+1) then (0, i, rows!!i, cols!!i)
224+
else let costs = [(prodCost (r!(i,k)) (r!(k,j)), k) | k <- [(i+1)..(j-1)]]
225+
(minCost, k) = minimum costs in
226+
(minCost, k, rows!!i, cols!!(j-1))
227+
prodCost (f1, k1, r1, c1) (f2, k2, r2, c2) = f1 + f2 + (r1*c2 * (2*c1-1))
228+
229+
optimalProduct tbl exprList =
230+
case length exprList of
231+
1 -> head exprList
232+
2 -> (Branch2 MProduct (exprList!!0) (exprList!!1) )
233+
_ -> let sizeList = map exprSize exprList
234+
r = optimalProductArray sizeList in
235+
assembleProduct exprList r 0 (length exprList)
236+
where exprSize e = let (Right (Matrix a b _)) = treeMatrix e tbl in
237+
(a, b)
238+
assembleProduct exps r i j = if j==i+1 then exps!!i
239+
else let (flops, k, a, b) = r!(i,j) in
240+
(Branch2 MProduct (assembleProduct exps r i k)
241+
(assembleProduct exps r k j))
242+
243+
244+
245+
246+
247+
smartCommonFactor tbl (Branch2 MSum a@(Branch2 MProduct _ _) b@(Branch2 MProduct _ _)) = scfHelperOuter tbl a b
248+
smartCommonFactor tbl (Branch2 MSum a@(Branch3 MTernaryProduct _ _ _) b@(Branch2 MProduct _ _)) = scfHelperOuter tbl a b
249+
smartCommonFactor tbl (Branch2 MSum a@(Branch2 MProduct _ _) b@(Branch3 MTernaryProduct _ _ _)) = scfHelperOuter tbl a b
250+
smartCommonFactor tbl (Branch2 MSum a@(Branch3 MTernaryProduct _ _ _) b@(Branch3 MTernaryProduct _ _ _)) = scfHelperOuter tbl a b
251+
smartCommonFactor _ _ = Nothing
252+
253+
254+
scfHelperOuter tbl a b =
255+
let list1 = productsToList a
256+
list2 = productsToList b in
257+
if list1==list2 then Just (Branch2 MScalarProduct (LiteralScalar 2.0) a)
258+
else let (ncommon, clist) = (scfHelper tbl list1 list2) in
259+
if ncommon == 0 then Nothing
260+
else Just (optimalProduct tbl clist)
261+
262+
-- divide two lists into a common prefix, common suffix, and 'cores' that are different.
263+
-- for example, [1,4,8,2,3] and [1,4,0,3] becomes prefix=[1,4], cores=([8,2], [0]), suffix=[3]
264+
triplePartition a b = let (start, rA, rB) = commonPrefix a b
265+
(rend, rcoreA, rcoreB) = commonPrefix (reverse rA) (reverse rB)
266+
(end, coreA, coreB) = (reverse rend, reverse rcoreA, reverse rcoreB) in
267+
(start, (coreA, coreB), end)
268+
commonPrefix (x:xs) (y:ys) = if x==y
269+
then let (p1, r1, r2) = commonPrefix xs ys in
270+
((x:p1), r1, r2)
271+
else ([], (x:xs), (y:ys))
272+
commonPrefix a [] = ([], a, [])
273+
commonPrefix [] b = ([], [], b)
274+
275+
-- we assume the lists are unequal, so at most one of the cores can be empty
276+
scfHelper tbl list1 list2 =
277+
let (start, (coreA, coreB), end) = triplePartition list1 list2
278+
(cA, cB) = (optProdOrEmpty tbl coreA coreB, optProdOrEmpty tbl coreB coreA) in
279+
(((length start) + (length end)), start ++ [Branch2 MSum cA cB] ++ end)
280+
where
281+
optProdOrEmpty tbl [] (b:bs) = let Right (Matrix r1 r2 []) = treeMatrix b tbl in
282+
(IdentityLeaf r1)
283+
optProdOrEmpty tbl a _ = optimalProduct tbl a
284+
285+
286+
rotateTraceLeft :: Rule
287+
rotateTraceLeft tbl (Branch1 MTrace z@(Branch2 MProduct _ _)) = rtlHelper tbl z
288+
rotateTraceLeft tbl (Branch1 MTrace z@(Branch3 MTernaryProduct _ _ _)) = rtlHelper tbl z
289+
rotateTraceLeft _ _ = Nothing
290+
rtlHelper tbl z =
291+
let x:xs = productsToList z
292+
rotated = xs ++ [x]
293+
candidate = expandProductRtL $ reverse rotated in
294+
case treeMatrix candidate tbl of
295+
Right m -> Just (Branch1 MTrace candidate)
296+
Left err -> Nothing
297+
298+
rotateTraceRight :: Rule
299+
rotateTraceRight tbl (Branch1 MTrace z@(Branch2 MProduct _ _)) = rtrHelper tbl z
300+
rotateTraceRight tbl (Branch1 MTrace z@(Branch3 MTernaryProduct _ _ _)) = rtrHelper tbl z
301+
rotateTraceRight _ _ = Nothing
302+
rtrHelper tbl z =
303+
let x:xs = reverse $ productsToList z
304+
candidate = expandProductRtL $ (xs ++ [x]) in
305+
case treeMatrix candidate tbl of
306+
Right m -> Just (Branch1 MTrace candidate)
307+
Left err -> Nothing
308+
309+
-------------------------------------------------
310+
170311
-- (AC + BC) -> (A+B)C
171312
commonFactorRight :: Rule
172313
commonFactorRight _ (Branch2 MSum (Branch2 MProduct l1 l2) (Branch2 MProduct r1 r2)) =

0 commit comments

Comments
 (0)