- diff --git a/mlir/example/Ch8/include/toy/Ops.td b/mlir/example/Ch8/include/toy/Ops.td
- index 157e207..298bd3e 100644
- --- a/mlir/example/Ch8/include/toy/Ops.td
- +++ b/mlir/example/Ch8/include/toy/Ops.td
- @@ -367,4 +367,31 @@ def TransposeOp : Toy_Op<"transpose",
+ diff -urN Ch7/CMakeLists.txt Ch8/CMakeLists.txt
+ --- Ch7/CMakeLists.txt 2023-12-06 04:57:18.788273480 +0000
+ +++ Ch8/CMakeLists.txt 2024-10-01 13:51:09.920421616 +0000
+ @@ -6,10 +6,10 @@
+
+ set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
+ mlir_tablegen(ToyCombine.inc -gen-rewriters)
+ - add_public_tablegen_target(ToyCh7CombineIncGen)
+ + add_public_tablegen_target(ToyCh8CombineIncGen)
+
+ add_executable(
+ - mlir-example-ch7
+ + mlir-example-ch8
+ toyc.cpp
+ parser/AST.cpp
+ mlir/MLIRGen.cpp
+ @@ -19,8 +19,8 @@
+ mlir/ShapeInferencePass.cpp
+ mlir/ToyCombine.cpp)
+
+ - add_dependencies(mlir-example-ch7 ToyCh7ShapeInferenceInterfaceIncGen
+ - ToyCh7OpsIncGen ToyCh7CombineIncGen)
+ + add_dependencies(mlir-example-ch8 ToyCh8ShapeInferenceInterfaceIncGen
+ + ToyCh8OpsIncGen ToyCh8CombineIncGen)
+
+ include_directories(${CMAKE_CURRENT_BINARY_DIR})
+ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
+ @@ -28,7 +28,7 @@
+ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
+ get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
+ target_link_libraries(
+ - mlir-example-ch7
+ + mlir-example-ch8
+ PRIVATE ${dialect_libs}
+ ${conversion_libs}
+ ${extension_libs}
+ diff -urN Ch7/include/toy/AST.h Ch8/include/toy/AST.h
+ --- Ch7/include/toy/AST.h 2024-09-22 10:55:44.710339034 +0000
+ +++ Ch8/include/toy/AST.h 2024-10-01 13:51:14.420421786 +0000
+ @@ -20,9 +20,9 @@
+ #include "llvm/ADT/ArrayRef.h"
+ #include "llvm/ADT/StringRef.h"
+ #include "llvm/Support/Casting.h"
+ + #include <optional>
+ #include <utility>
+ #include <vector>
+ - #include <optional>
+
+ namespace toy {
+
+ diff -urN Ch7/include/toy/CMakeLists.txt Ch8/include/toy/CMakeLists.txt
+ --- Ch7/include/toy/CMakeLists.txt 2023-12-06 04:57:18.788273480 +0000
+ +++ Ch8/include/toy/CMakeLists.txt 2024-10-01 13:51:15.848421840 +0000
+ @@ -4,10 +4,10 @@
+ mlir_tablegen(Ops.cpp.inc -gen-op-defs)
+ mlir_tablegen(Dialect.h.inc -gen-dialect-decls)
+ mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs)
+ - add_public_tablegen_target(ToyCh7OpsIncGen)
+ + add_public_tablegen_target(ToyCh8OpsIncGen)
+
+ # Most dialects should use add_mlir_interfaces().
+ set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
+ mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls)
+ mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs)
+ - add_public_tablegen_target(ToyCh7ShapeInferenceInterfaceIncGen)
+ + add_public_tablegen_target(ToyCh8ShapeInferenceInterfaceIncGen)
+ diff -urN Ch7/include/toy/Ops.td Ch8/include/toy/Ops.td
+ --- Ch7/include/toy/Ops.td 2024-09-22 10:55:44.710339034 +0000
+ +++ Ch8/include/toy/Ops.td 2024-10-01 13:51:17.112421888 +0000
+ @@ -450,4 +450,31 @@
let hasVerifier = 1;
}

@@ -34,11 +99,41 @@ index 157e207..298bd3e 100644
+ }
+
#endif // TOY_OPS
- diff --git a/mlir/example/Ch8/matmul.toy.mlir b/mlir/example/Ch8/matmul.toy.mlir
- new file mode 100644
- index 0000000..5a0cd7e
- --- /dev/null
- +++ b/mlir/example/Ch8/matmul.toy.mlir
+ diff -urN Ch7/include/toy/Parser.h Ch8/include/toy/Parser.h
+ --- Ch7/include/toy/Parser.h 2024-09-22 10:55:44.714339101 +0000
+ +++ Ch8/include/toy/Parser.h 2024-10-01 13:51:18.412421937 +0000
+ @@ -22,9 +22,9 @@
+ #include "llvm/Support/raw_ostream.h"
+
+ #include <map>
+ + #include <optional>
+ #include <utility>
+ #include <vector>
+ - #include <optional>
+
+ namespace toy {
+
+ diff -urN Ch7/matmul.toy Ch8/matmul.toy
+ --- Ch7/matmul.toy 1970-01-01 00:00:00.000000000 +0000
+ +++ Ch8/matmul.toy 2024-10-01 13:51:11.744421685 +0000
+ @@ -0,0 +1,14 @@
+ + def main() {
+ + # Define a variable `a` with shape <2, 3>, initialized with the literal value.
+ + # The shape is inferred from the supplied literal.
+ + var a = [[1, 2, 3], [4, 5, 6]];
+ +
+ + # b is identical to a, the literal tensor is implicitly reshaped: defining new
+ + # variables is the way to reshape tensors (element count must match).
+ + var b<2, 3> = [1, 2, 3, 4, 5, 6];
+ +
+ + # transpose() and print() are the only builtin, the following will transpose
+ + # a and b and perform an element-wise multiplication before printing the result.
+ + # print(a * b + b);
+ + print(matmul(a, transpose(b)));
+ + }
+ diff -urN Ch7/matmul.toy.mlir Ch8/matmul.toy.mlir
+ --- Ch7/matmul.toy.mlir 1970-01-01 00:00:00.000000000 +0000
+ +++ Ch8/matmul.toy.mlir 2024-10-01 13:51:13.056421735 +0000
@@ -0,0 +1,16 @@
+ toy.func private @matmul_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
+ %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
@@ -56,11 +151,28 @@ index 0000000..5a0cd7e
+ toy.print %4 : tensor<*xf64>
+ toy.return
+ }
- diff --git a/mlir/example/Ch8/mlir/Dialect.cpp b/mlir/example/Ch8/mlir/Dialect.cpp
- index 6ec105a..d750782 100644
- --- a/mlir/example/Ch8/mlir/Dialect.cpp
- +++ b/mlir/example/Ch8/mlir/Dialect.cpp
- @@ -439,6 +439,63 @@ mlir::LogicalResult TransposeOp::verify() {
+ diff -urN Ch7/mlir/Dialect.cpp Ch8/mlir/Dialect.cpp
+ --- Ch7/mlir/Dialect.cpp 2024-09-22 10:55:44.714339101 +0000
+ +++ Ch8/mlir/Dialect.cpp 2024-10-01 13:51:19.988421996 +0000
+ @@ -13,6 +13,7 @@
+
+ #include "toy/Dialect.h"
+
+ + #include "mlir/Dialect/Arith/Utils/Utils.h"
+ #include "mlir/IR/Attributes.h"
+ #include "mlir/IR/Builders.h"
+ #include "mlir/IR/BuiltinAttributes.h"
+ @@ -429,7 +430,8 @@
+ auto resultType = results.front();
+
+ // Check that the result type of the function matches the operand type.
+ - if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+ + if (inputType == resultType ||
+ + llvm::isa<mlir::UnrankedTensorType>(inputType) ||
+ llvm::isa<mlir::UnrankedTensorType>(resultType))
+ return mlir::success();
+
+ @@ -497,6 +499,58 @@
return mlir::success();
}

@@ -115,6 +227,147 @@ index 6ec105a..d750782 100644
+
+ return mlir::success();
+ }
+ +
//===----------------------------------------------------------------------===//
- // TableGen'd op method definitions
+ // Toy Types
//===----------------------------------------------------------------------===//
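
Note: the body of the new Dialect.cpp hunk @@ -497,6 +499,58 @@ is collapsed in this view, so the MatMulOp verifier and shape-inference code it adds is not shown. Purely as an orientation sketch, and not the commit's actual code, shape inference for a 2-D matmul in the Toy dialect could look roughly like the following, assuming the new MatMulOp names its operands lhs/rhs (as the MatMulOpAdaptor::getLhs()/getRhs() calls in the lowering below suggest) and opts into the tutorial's ShapeInference interface:

// Hypothetical sketch only; the real Ch8 code lives in the collapsed hunk above.
void MatMulOp::inferShapes() {
  auto lhsType = llvm::dyn_cast<mlir::RankedTensorType>(getLhs().getType());
  auto rhsType = llvm::dyn_cast<mlir::RankedTensorType>(getRhs().getType());
  if (!lhsType || !rhsType)
    return;
  // An (m x k) * (k x n) product yields an (m x n) result.
  llvm::SmallVector<int64_t, 2> dims{lhsType.getShape()[0],
                                     rhsType.getShape()[1]};
  getResult().setType(
      mlir::RankedTensorType::get(dims, lhsType.getElementType()));
}

Everything in this sketch beyond the names already visible in the diff (MatMulOp, getLhs, getRhs) is an assumption.
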
+ diff -urN Ch7/mlir/LowerToAffineLoops.cpp Ch8/mlir/LowerToAffineLoops.cpp
+ --- Ch7/mlir/LowerToAffineLoops.cpp 2024-09-22 10:55:44.714339101 +0000
+ +++ Ch8/mlir/LowerToAffineLoops.cpp 2024-10-01 13:51:21.668422059 +0000
+ @@ -19,6 +19,7 @@
+ #include "mlir/IR/Diagnostics.h"
+ #include "mlir/IR/DialectRegistry.h"
+ #include "mlir/IR/PatternMatch.h"
+ + #include "mlir/IR/Value.h"
+ #include "mlir/IR/ValueRange.h"
+ #include "mlir/Support/LLVM.h"
+ #include "mlir/Support/TypeID.h"
+ @@ -31,6 +32,7 @@
+ #include "mlir/Dialect/MemRef/IR/MemRef.h"
+ #include "mlir/Pass/Pass.h"
+ #include "mlir/Transforms/DialectConversion.h"
+ + #include "llvm/ADT/APFloat.h"
+ #include "llvm/ADT/ArrayRef.h"
+ #include "llvm/ADT/STLExtras.h"
+ #include "llvm/ADT/Sequence.h"
+ @@ -315,6 +317,91 @@
+ }
+ };
+
+ + //===----------------------------------------------------------------------===//
+ + // ToyToAffine RewritePatterns: MatMul operations
+ + //===----------------------------------------------------------------------===//
+ +
+ + struct MatMulOpLowering : public ConversionPattern {
+ + MatMulOpLowering(MLIRContext *ctx)
+ + : ConversionPattern(toy::MatMulOp::getOperationName(), 1, ctx) {}
+ +
+ + LogicalResult
+ + matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ + ConversionPatternRewriter &rewriter) const final {
+ + auto loc = op->getLoc();
+ +
+ + RankedTensorType lhsType =
+ + llvm::dyn_cast<RankedTensorType>(op->getOperand(0).getType());
+ + RankedTensorType rhsType =
+ + llvm::dyn_cast<RankedTensorType>(op->getOperand(1).getType());
+ + auto lhsShape = lhsType.getShape();
+ + auto rhsShape = rhsType.getShape();
+ +
+ + auto tensorType =
+ + llvm::dyn_cast<RankedTensorType>((*op->result_type_begin()));
+ +
+ + auto elemType = llvm::dyn_cast<FloatType>(tensorType.getElementType());
+ +
+ + // Insert an allocation and deallocation for the result of this operation.
+ + auto memRefType = convertTensorToMemRef(tensorType);
+ + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
+ +
+ + SmallVector<int64_t, 4> lowerBounds(tensorType.getRank() + 1, /*Value=*/0);
+ + SmallVector<int64_t, 4> steps(tensorType.getRank() + 1, /*Value=*/1);
+ + SmallVector<int64_t, 4> upperBounds{lhsShape[0], rhsShape[0], rhsShape[1]};
+ +
+ + // add initialization of result tensor.
+ + // Create a nest of affine loops to initialize the result tensor to 0.
+ + affine::buildAffineLoopNest(
+ + rewriter, loc, {0, 0}, tensorType.getShape(), {1, 1},
+ + [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
+ + // Create a constant float value of 0.0.
+ + auto valueToStore = nestedBuilder.create<arith::ConstantFloatOp>(
+ + loc, llvm::APFloat(0.0), elemType);
+ + // Store the constant value into the allocated memory.
+ + nestedBuilder.create<affine::AffineStoreOp>(loc, valueToStore, alloc,
+ + ivs);
+ + });
+ +
+ + // Create a nest of affine loops for matrix multiplication.
+ + affine::buildAffineLoopNest(
+ + rewriter, loc, lowerBounds, upperBounds, steps,
+ + [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
+ + // Extract loop induction variables.
+ + Value m = ivs[0];
+ + Value k = ivs[1];
+ + Value n = ivs[2];
+ +
+ + // Create an adaptor for the remapped operands of the MatMulOp.
+ + toy::MatMulOpAdaptor matmulAdaptor(operands);
+ +
+ + // Load elements from the left-hand side and right-hand side matrices.
+ + auto loadedLhs = nestedBuilder.create<affine::AffineLoadOp>(
+ + loc, matmulAdaptor.getLhs(), ValueRange{m, k});
+ + auto loadedRhs = nestedBuilder.create<affine::AffineLoadOp>(
+ + loc, matmulAdaptor.getRhs(), ValueRange{k, n});
+ + // Load elements from the result tensor from initial process above.
+ + auto loadedRes = nestedBuilder.create<affine::AffineLoadOp>(
+ + loc, alloc, ValueRange{m, n});
+ +
+ + // Perform the multiplication and addition operations.
+ + auto mulop =
+ + nestedBuilder.create<arith::MulFOp>(loc, loadedLhs, loadedRhs);
+ + auto valueToStore =
+ + nestedBuilder.create<arith::AddFOp>(loc, loadedRes, mulop);
+ +
+ + // Store the result back into the allocated memory.
+ + nestedBuilder.create<affine::AffineStoreOp>(loc, valueToStore, alloc,
+ + ValueRange{m, n});
+ + });
+ +
+ + // Replace this operation with the generated alloc.
+ + rewriter.replaceOp(op, alloc);
+ +
+ + return success();
+ + }
+ + };
+ +
+ } // namespace
+
+ //===----------------------------------------------------------------------===//
+ @@ -365,8 +452,8 @@
+ // the set of patterns that will lower the Toy operations.
+ RewritePatternSet patterns(&getContext());
+ patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
+ - PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
+ - &getContext());
+ + PrintOpLowering, ReturnOpLowering, TransposeOpLowering,
+ + MatMulOpLowering>(&getContext());
+
+ // With the target and rewrite patterns defined, we can now attempt the
+ // conversion. The conversion will signal failure if any of our `illegal`
+ diff -urN Ch7/mlir/MLIRGen.cpp Ch8/mlir/MLIRGen.cpp
+ --- Ch7/mlir/MLIRGen.cpp 2024-09-22 10:55:44.714339101 +0000
+ +++ Ch8/mlir/MLIRGen.cpp 2024-10-01 13:51:23.564422131 +0000
+ @@ -525,6 +525,14 @@
+ return builder.create<TransposeOp>(location, operands[0]);
+ }
+
+ + if (callee == "matmul") {
+ + if (call.getArgs().size() != 2) {
+ + emitError(location, "MLIR codegen encountered an error: toy.matmul "
+ + "expected 2 arguments");
+ + }
+ + return builder.create<MatMulOp>(location, operands[0], operands[1]);
+ + }
+ +
+ // Otherwise this is a call to a user-defined function. Calls to
+ // user-defined functions are mapped to a custom call that takes the callee
+ // name as an attribute.
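
One detail worth noting about the matmul branch above: emitError reports a diagnostic but does not by itself stop code generation, so builder.create<MatMulOp> still runs after a bad argument count. In the upstream Toy tutorial, the analogous transpose special case returns nullptr immediately after its argument-count error; a sketch of the same guard applied to the matmul branch (not part of this commit) would be:

// Sketch only; mirrors the tutorial's transpose handling, not the commit's code.
if (callee == "matmul") {
  if (call.getArgs().size() != 2) {
    emitError(location, "MLIR codegen encountered an error: toy.matmul "
                        "expected 2 arguments");
    return nullptr; // bail out instead of building the op with bad operands
  }
  return builder.create<MatMulOp>(location, operands[0], operands[1]);
}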