Skip to content

Commit fc6f819

Browse files
authored
feat: Add GraphPass trait and add some passes. (#28)
* Add GraphPass trait. * Add AllocateVar graph pass. * Add AllocateEdge graph pass. * Fix cargo fmt. * Add pass test dir. * Add Python bindings for the allocate_var pass and a test. * Add allocate edge pass.
1 parent f8a58bd commit fc6f819

File tree

14 files changed

+497
-6
lines changed

14 files changed

+497
-6
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
/target
22
TiledCUDA/
33
**/__pycache__/
4+
**/build/

thriller-bindings/Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
TEST ?= test_bindings.py
1+
TEST ?= pass/allocate_edge.py
22
TEST_DIR := tests
33

44
.PHONY: build test
@@ -7,5 +7,8 @@ build:
77
@cargo build
88
@maturin develop
99

10+
# NOTE(review): make runs each recipe line in its own shell, so sourcing the
# virtualenv here does not carry over to other targets or to the caller's
# shell. `source` is also a bash builtin, while make uses /bin/sh by default.
# This target is not listed in .PHONY either. Confirm the intended usage.
active:
11+
@source .env/bin/activate
12+
1013
test: build
1114
@python3 $(TEST_DIR)/$(TEST)

thriller-bindings/src/graph.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ use pyo3::prelude::*;
22
use pyo3::types::PyList;
33

44
use thriller_core::{
5-
AccessMap, Convert, DataType, Gemm, Task, ThrillerEdge, ThrillerGraph, ThrillerNode,
6-
ThrillerNodeInner,
5+
AccessMap, AllocateEdge, AllocateVar, Convert, DataType, Gemm, GraphPass, Task, ThrillerEdge,
6+
ThrillerGraph, ThrillerNode, ThrillerNodeInner,
77
};
88

99
use crate::buffer::PyBuffer;
@@ -53,6 +53,20 @@ impl PyGraph {
5353
self.0.borrow_mut().connect();
5454
}
5555

56+
/// Runs a freshly constructed `AllocateVar` pass over the wrapped graph and
/// returns a copy of the code string the pass holds afterwards.
///
/// Always returns `Ok`; the return value of `pass.run` is not inspected.
///
/// NOTE(review): `self.0.borrow_mut()` suggests `self.0` is a `RefCell`-like
/// wrapper, in which case this would panic if the graph is already borrowed
/// elsewhere. Confirm against the definition of `PyGraph`.
fn allocate_vars(&mut self) -> PyResult<String> {
57+
let mut graph = self.0.borrow_mut();
58+
let mut pass = AllocateVar::new();
59+
pass.run(&mut graph);
60+
// Clone so the caller gets an owned `String` independent of the pass object.
Ok(pass.code().clone())
61+
}
62+
63+
/// Runs a freshly constructed `AllocateEdge` pass over the wrapped graph and
/// returns a copy of the code string the pass holds afterwards.
///
/// Always returns `Ok`; the return value of `pass.run` is not inspected.
///
/// NOTE(review): `self.0.borrow_mut()` suggests `self.0` is a `RefCell`-like
/// wrapper, in which case this would panic if the graph is already borrowed
/// elsewhere. Confirm against the definition of `PyGraph`.
fn allocate_edges(&mut self) -> PyResult<String> {
64+
let mut graph = self.0.borrow_mut();
65+
let mut pass = AllocateEdge::new();
66+
pass.run(&mut graph);
67+
// Clone so the caller gets an owned `String` independent of the pass object.
Ok(pass.code().clone())
68+
}
69+
5670
fn codegen(&self) -> PyResult<String> {
5771
self.0
5872
.borrow()
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import context
2+
3+
from pythriller import initialize_thriller_flow, Layout, Tensor, TensorType
4+
from pythriller import Graph, Node, Edge, AttachedEdge, IterationVar, AccessMap
5+
from pythriller import Block
6+
7+
8+
if __name__ == '__main__':
9+
# Initialize runtime.
10+
initialize_thriller_flow()
11+
12+
# Define reg layout for A, B, C.
13+
RegLayoutA = Layout.RowMajor
14+
RegLayoutB = Layout.RowMajor
15+
RegLayoutC = Layout.RowMajor
16+
17+
# Define shared layout for A, B, C.
18+
SharedLayoutA = Layout.RowMajor
19+
SharedLayoutB = Layout.ColMajor
20+
SharedLayoutC = Layout.RowMajor
21+
22+
# Define global layout for A, B, C.
23+
GlobalLayoutA = Layout.RowMajor
24+
GlobalLayoutB = Layout.ColMajor
25+
GlobalLayoutC = Layout.RowMajor
26+
27+
# Define Reg Dim for A, B, C.
28+
RegDimA = [64, 64]
29+
RegDimB = [64, 64]
30+
RegDimC = [64, 64]
31+
32+
# Define Shared Dim for A, B, C.
33+
SharedDimA = [64, 64]
34+
SharedDimB = [64, 64]
35+
SharedDimC = [64, 64]
36+
37+
# Define Global Dim for A, B, C.
38+
GlobalDimA = [256, 256]
39+
GlobalDimB = [256, 256]
40+
GlobalDimC = [256, 256]
41+
42+
# Define Reg Tensor for A, B, C.
43+
rA = Tensor("rA", RegDimA, RegLayoutA, TensorType.RegTile)
44+
rB = Tensor("rB", RegDimB, RegLayoutB, TensorType.RegTile)
45+
acc = Tensor("acc", RegDimC, RegLayoutC, TensorType.RegTile)
46+
47+
# Define Shared Tensor for A, B, C.
48+
sA = Tensor("sA", SharedDimA, SharedLayoutA, TensorType.SharedTile)
49+
sB = Tensor("sB", SharedDimB, SharedLayoutB, TensorType.SharedTile)
50+
sC = Tensor("sC", SharedDimC, SharedLayoutC, TensorType.SharedTile)
51+
52+
# Define Global Tensor for A, B, C.
53+
gA = Tensor("gA", GlobalDimA, GlobalLayoutA, TensorType.GlobalTile)
54+
gB = Tensor("gB", GlobalDimB, GlobalLayoutB, TensorType.GlobalTile)
55+
gC = Tensor("gC", GlobalDimC, GlobalLayoutC, TensorType.GlobalTile)
56+
57+
# Define Reg Node for A, B, C.
58+
NodeRA = Node.tensor(rA)
59+
NodeRB = Node.tensor(rB)
60+
NodeRC = Node.tensor(acc)
61+
62+
# Define Reg GEMM Node.
63+
RegGemmNode = Node.gemm(NodeRA, NodeRB, NodeRC)
64+
65+
# Define Reg Edge for A, B, C, GEMM.
66+
RegEdgeA = Edge(NodeRA, RegGemmNode)
67+
RegEdgeB = Edge(NodeRB, RegGemmNode)
68+
RegEdgeC = Edge(RegGemmNode, NodeRC)
69+
70+
# Define Shared Node for A, B, C.
71+
NodeSA = Node.tensor(sA)
72+
NodeSB = Node.tensor(sB)
73+
NodeSC = Node.tensor(sC)
74+
75+
# Define Global Node for A, B, C.
76+
NodeGA = Node.tensor(gA)
77+
NodeGB = Node.tensor(gB)
78+
NodeGC = Node.tensor(gC)
79+
80+
# Define loop iter from shared to register
81+
LoopIterS2R = IterationVar('j', (0, 1))
82+
83+
# Define loop iter from global to shared
84+
LoopIterG2S = IterationVar('i', (0, 4))
85+
86+
# Build AccessMap from Shared to Register.
87+
AccessMapSA2RA = AccessMap(
88+
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterS2R])
89+
AccessMapSB2RB = AccessMap(
90+
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterS2R])
91+
AccessMapRC2SC = AccessMap([0], [[[]], [[]]], [[], []], [])
92+
93+
# Build AccessMap from Global to Shared.
94+
AccessMapGA2SA = AccessMap(
95+
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterG2S])
96+
AccessMapGB2SB = AccessMap(
97+
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterG2S])
98+
AccessMapSC2GC = AccessMap([0], [[[]], [[]]], [[], []], [])
99+
100+
# Build Attached Edge from Shared to Register.
101+
AttachedEdgeSA2RA = AttachedEdge(sA, rA, AccessMapSA2RA)
102+
AttachedEdgeSB2RB = AttachedEdge(sB, rB, AccessMapSB2RB)
103+
AttachedEdgeSC2RC = AttachedEdge(acc, sC, AccessMapRC2SC)
104+
105+
# Build Attached Edge from Global to Shared.
106+
AttachedEdgeGA2SA = AttachedEdge(gA, sA, AccessMapGA2SA)
107+
AttachedEdgeGB2SB = AttachedEdge(gB, sB, AccessMapGB2SB)
108+
AttachedEdgeSC2GC = AttachedEdge(sC, gC, AccessMapSC2GC)
109+
110+
# Build Register Level ETDG.
111+
RegGraph = Graph()
112+
113+
# Add Reg Nodes into Reg Graph.
114+
RegGraph.add_nodes([NodeRA, NodeRB, NodeRC, RegGemmNode])
115+
# Add Reg Edges into Reg Graph.
116+
RegGraph.add_edges([RegEdgeA, RegEdgeB, RegEdgeC])
117+
# Connect Reg Graph.
118+
RegGraph.connect()
119+
120+
# Generate code for Reg Graph.
121+
reg_code = RegGraph.codegen()
122+
123+
# Build Block for Shared to Register.
124+
SharedToRegBlock = Block(
125+
[AttachedEdgeSA2RA, AttachedEdgeSB2RB], [AttachedEdgeSC2RC], RegGraph, [LoopIterS2R])
126+
127+
# Generate code for Shared to Register Block.
128+
shared_to_reg_code = SharedToRegBlock.codegen()
129+
130+
# Define BlockNode for SharedToRegBlock
131+
SharedBlockNode = Node.block(SharedToRegBlock)
132+
133+
# Define Edge for SA, SB, SC, SharedBlockNode.
134+
EdgeSA2Block = Edge(NodeSA, SharedBlockNode)
135+
EdgeSB2Block = Edge(NodeSB, SharedBlockNode)
136+
EdgeBlock2SC = Edge(SharedBlockNode, NodeSC)
137+
138+
# Build Shared Level ETDG.
139+
SharedGraph = Graph()
140+
# Add Shared Nodes into Shared Graph.
141+
SharedGraph.add_nodes([NodeSA, NodeSB, NodeSC, SharedBlockNode])
142+
# Add Shared Edges into Shared Graph.
143+
SharedGraph.add_edges([EdgeSA2Block, EdgeSB2Block, EdgeBlock2SC])
144+
# Connect Shared Graph.
145+
SharedGraph.connect()
146+
147+
allocate_edges = SharedGraph.allocate_edges()
148+
print(allocate_edges)
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import context
2+
3+
from pythriller import initialize_thriller_flow, Layout, Tensor, TensorType
4+
from pythriller import Graph, Node, Edge, AttachedEdge, IterationVar, AccessMap
5+
from pythriller import Block
6+
7+
8+
if __name__ == '__main__':
9+
# Initialize runtime.
10+
initialize_thriller_flow()
11+
12+
# Define reg layout for A, B, C.
13+
RegLayoutA = Layout.RowMajor
14+
RegLayoutB = Layout.RowMajor
15+
RegLayoutC = Layout.RowMajor
16+
17+
# Define shared layout for A, B, C.
18+
SharedLayoutA = Layout.RowMajor
19+
SharedLayoutB = Layout.ColMajor
20+
SharedLayoutC = Layout.RowMajor
21+
22+
# Define global layout for A, B, C.
23+
GlobalLayoutA = Layout.RowMajor
24+
GlobalLayoutB = Layout.ColMajor
25+
GlobalLayoutC = Layout.RowMajor
26+
27+
# Define Reg Dim for A, B, C.
28+
RegDimA = [64, 64]
29+
RegDimB = [64, 64]
30+
RegDimC = [64, 64]
31+
32+
# Define Shared Dim for A, B, C.
33+
SharedDimA = [64, 64]
34+
SharedDimB = [64, 64]
35+
SharedDimC = [64, 64]
36+
37+
# Define Global Dim for A, B, C.
38+
GlobalDimA = [256, 256]
39+
GlobalDimB = [256, 256]
40+
GlobalDimC = [256, 256]
41+
42+
# Define Reg Tensor for A, B, C.
43+
rA = Tensor("rA", RegDimA, RegLayoutA, TensorType.RegTile)
44+
rB = Tensor("rB", RegDimB, RegLayoutB, TensorType.RegTile)
45+
acc = Tensor("acc", RegDimC, RegLayoutC, TensorType.RegTile)
46+
47+
# Define Shared Tensor for A, B, C.
48+
sA = Tensor("sA", SharedDimA, SharedLayoutA, TensorType.SharedTile)
49+
sB = Tensor("sB", SharedDimB, SharedLayoutB, TensorType.SharedTile)
50+
sC = Tensor("sC", SharedDimC, SharedLayoutC, TensorType.SharedTile)
51+
52+
# Define Global Tensor for A, B, C.
53+
gA = Tensor("gA", GlobalDimA, GlobalLayoutA, TensorType.GlobalTile)
54+
gB = Tensor("gB", GlobalDimB, GlobalLayoutB, TensorType.GlobalTile)
55+
gC = Tensor("gC", GlobalDimC, GlobalLayoutC, TensorType.GlobalTile)
56+
57+
# Define Reg Node for A, B, C.
58+
NodeRA = Node.tensor(rA)
59+
NodeRB = Node.tensor(rB)
60+
NodeRC = Node.tensor(acc)
61+
62+
# Define Reg GEMM Node.
63+
RegGemmNode = Node.gemm(NodeRA, NodeRB, NodeRC)
64+
65+
# Define Reg Edge for A, B, C, GEMM.
66+
RegEdgeA = Edge(NodeRA, RegGemmNode)
67+
RegEdgeB = Edge(NodeRB, RegGemmNode)
68+
RegEdgeC = Edge(RegGemmNode, NodeRC)
69+
70+
# Define Shared Node for A, B, C.
71+
NodeSA = Node.tensor(sA)
72+
NodeSB = Node.tensor(sB)
73+
NodeSC = Node.tensor(sC)
74+
75+
# Define Global Node for A, B, C.
76+
NodeGA = Node.tensor(gA)
77+
NodeGB = Node.tensor(gB)
78+
NodeGC = Node.tensor(gC)
79+
80+
# Define loop iter from shared to register
81+
LoopIterS2R = IterationVar('j', (0, 1))
82+
83+
# Define loop iter from global to shared
84+
LoopIterG2S = IterationVar('i', (0, 4))
85+
86+
# Build AccessMap from Shared to Register.
87+
AccessMapSA2RA = AccessMap(
88+
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterS2R])
89+
AccessMapSB2RB = AccessMap(
90+
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterS2R])
91+
AccessMapRC2SC = AccessMap([0], [[[]], [[]]], [[], []], [])
92+
93+
# Build AccessMap from Global to Shared.
94+
AccessMapGA2SA = AccessMap(
95+
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterG2S])
96+
AccessMapGB2SB = AccessMap(
97+
[0], [[[1]], [[0]]], [[0], [0]], [LoopIterG2S])
98+
AccessMapSC2GC = AccessMap([0], [[[]], [[]]], [[], []], [])
99+
100+
# Build Attached Edge from Shared to Register.
101+
AttachedEdgeSA2RA = AttachedEdge(sA, rA, AccessMapSA2RA)
102+
AttachedEdgeSB2RB = AttachedEdge(sB, rB, AccessMapSB2RB)
103+
AttachedEdgeSC2RC = AttachedEdge(acc, sC, AccessMapRC2SC)
104+
105+
# Build Attached Edge from Global to Shared.
106+
AttachedEdgeGA2SA = AttachedEdge(gA, sA, AccessMapGA2SA)
107+
AttachedEdgeGB2SB = AttachedEdge(gB, sB, AccessMapGB2SB)
108+
AttachedEdgeSC2GC = AttachedEdge(sC, gC, AccessMapSC2GC)
109+
110+
# Build Register Level ETDG.
111+
RegGraph = Graph()
112+
113+
# Add Reg Nodes into Reg Graph.
114+
RegGraph.add_nodes([NodeRA, NodeRB, NodeRC, RegGemmNode])
115+
# Add Reg Edges into Reg Graph.
116+
RegGraph.add_edges([RegEdgeA, RegEdgeB, RegEdgeC])
117+
# Connect Reg Graph.
118+
RegGraph.connect()
119+
120+
# Generate code for Reg Graph.
121+
reg_code = RegGraph.codegen()
122+
123+
# Build Block for Shared to Register.
124+
SharedToRegBlock = Block(
125+
[AttachedEdgeSA2RA, AttachedEdgeSB2RB], [AttachedEdgeSC2RC], RegGraph, [LoopIterS2R])
126+
127+
# Generate code for Shared to Register Block.
128+
shared_to_reg_code = SharedToRegBlock.codegen()
129+
130+
# Define BlockNode for SharedToRegBlock
131+
SharedBlockNode = Node.block(SharedToRegBlock)
132+
133+
# Define Edge for SA, SB, SC, SharedBlockNode.
134+
EdgeSA2Block = Edge(NodeSA, SharedBlockNode)
135+
EdgeSB2Block = Edge(NodeSB, SharedBlockNode)
136+
EdgeBlock2SC = Edge(SharedBlockNode, NodeSC)
137+
138+
# Build Shared Level ETDG.
139+
SharedGraph = Graph()
140+
# Add Shared Nodes into Shared Graph.
141+
SharedGraph.add_nodes([NodeSA, NodeSB, NodeSC, SharedBlockNode])
142+
# Add Shared Edges into Shared Graph.
143+
SharedGraph.add_edges([EdgeSA2Block, EdgeSB2Block, EdgeBlock2SC])
144+
# Connect Shared Graph.
145+
SharedGraph.connect()
146+
147+
allocate_vars = SharedGraph.allocate_vars()
148+
print(allocate_vars)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import os
2+
import sys
3+
4+
# Prepend the absolute path of the directory two levels above this file's
# directory to sys.path, giving it priority in module lookup.
# Presumably this lets the test scripts import project modules located there;
# verify against the repository layout.
sys.path.insert(
5+
0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))

thriller-core/src/backend/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

thriller-core/src/dataflow/graph.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ use crate::{next_id, ThrillerResult};
1414
pub struct ThrillerGraph {
1515
#[allow(dead_code)]
1616
id: usize,
17-
nodes: Vec<Rc<RefCell<ThrillerNode>>>,
18-
edges: Vec<Rc<ThrillerEdge>>,
17+
pub(crate) nodes: Vec<Rc<RefCell<ThrillerNode>>>,
18+
pub(crate) edges: Vec<Rc<ThrillerEdge>>,
1919
}
2020

2121
impl ThrillerGraph {

thriller-core/src/dataflow/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ mod block;
22
mod edge;
33
mod graph;
44
mod node;
5+
mod pass;
56

67
pub use block::ThrillerBlock;
78
pub use edge::{AttachedEdge, ThrillerEdge};
89
pub use graph::ThrillerGraph;
910
pub use node::{ThrillerNode, ThrillerNodeInner};
11+
pub use pass::{AllocateEdge, AllocateVar, GraphPass};

0 commit comments

Comments
 (0)