Improved backend support #182

Draft · wants to merge 245 commits into base: master

Changes from 1 commit

Commits (245)
60070f4
add lit test
christopherhjung Sep 28, 2022
d9199a9
save
christopherhjung Sep 29, 2022
7d07813
start of eta problem solution
NeuralCoder3 Oct 5, 2022
0338d45
ideas for type transformation
NeuralCoder3 Oct 5, 2022
970c2c9
inner autodiff type fix
NeuralCoder3 Oct 6, 2022
4869d56
fixed eta problem
NeuralCoder3 Oct 6, 2022
58b9519
one step towards fixing the infinite loop
NeuralCoder3 Oct 6, 2022
d1568c4
investigation of endless loop
NeuralCoder3 Oct 6, 2022
a330129
start of real tests
NeuralCoder3 Oct 7, 2022
d4ebad4
64 bit width is working
NeuralCoder3 Oct 7, 2022
e903622
save
christopherhjung Oct 9, 2022
75b96b2
Merge remote-tracking branch 'origin/autodiff' into autodiff
christopherhjung Oct 9, 2022
91c15e7
save
christopherhjung Oct 10, 2022
4fd9d09
real ad test
NeuralCoder3 Oct 10, 2022
67192e3
reordered dialects
NeuralCoder3 Oct 10, 2022
f67665b
made filter explicit
NeuralCoder3 Oct 10, 2022
087340f
enzyme test overview
NeuralCoder3 Oct 11, 2022
ed023e7
extracted relevant code
NeuralCoder3 Oct 11, 2022
70850ad
simple brussel test
NeuralCoder3 Oct 12, 2022
ccaf515
initialization
NeuralCoder3 Oct 12, 2022
bf0a218
time meassurement
NeuralCoder3 Oct 12, 2022
372973f
added brusselator functions
NeuralCoder3 Oct 12, 2022
f6ae248
fixed dominance
NeuralCoder3 Oct 12, 2022
5c528c8
more debugging on problem
NeuralCoder3 Oct 12, 2022
abe63c5
more attempts
NeuralCoder3 Oct 12, 2022
4231c0b
tried filters
NeuralCoder3 Oct 13, 2022
affc4e8
Merge branch 'master' into autodiff
NeuralCoder3 Oct 13, 2022
2903c96
localized problem even more
NeuralCoder3 Oct 13, 2022
5e0071b
working state without lamspec
NeuralCoder3 Oct 13, 2022
4849db5
earliest failure point
NeuralCoder3 Oct 13, 2022
c4326c6
bug in thorin
NeuralCoder3 Oct 13, 2022
cdd9f65
stuck at error in line 170
NeuralCoder3 Oct 13, 2022
1deea17
Merge branch 'eta-problem' into autodiff
NeuralCoder3 Oct 13, 2022
55fd344
Merge branch 'ad_real' into autodiff
NeuralCoder3 Oct 13, 2022
661e329
fixed merge/vscode issue
NeuralCoder3 Oct 13, 2022
670270a
new tests with pow deriv
NeuralCoder3 Oct 13, 2022
9a73cc7
refactor
NeuralCoder3 Oct 13, 2022
571a012
adapted tests to filter printing
NeuralCoder3 Oct 14, 2022
bdd9876
disabled LamSpec
NeuralCoder3 Oct 14, 2022
503565b
Merge branch 'master' into autodiff
NeuralCoder3 Oct 14, 2022
cf1c313
moved issues to closure
NeuralCoder3 Oct 14, 2022
84af1a0
removed debug print
NeuralCoder3 Oct 14, 2022
2baff33
Comparison to other functional languages
NeuralCoder3 Oct 17, 2022
53fd9f2
implement higher order scalerize
christopherhjung Oct 17, 2022
c0fd971
grouped evaluation, added thorin timing
NeuralCoder3 Oct 18, 2022
819bf8e
added timing
NeuralCoder3 Oct 18, 2022
392607d
alloca timing
NeuralCoder3 Oct 18, 2022
c78306b
commented out alloca run
NeuralCoder3 Oct 18, 2022
7e2d588
complete alloca
NeuralCoder3 Oct 18, 2022
a99aa96
implement autodiff arr and ptr support
christopherhjung Oct 20, 2022
a35f56c
refactoring
christopherhjung Oct 20, 2022
8566c36
more refactoring
christopherhjung Oct 20, 2022
a0d328c
haskell fully cps
NeuralCoder3 Oct 21, 2022
56274a6
polymorphic
NeuralCoder3 Oct 21, 2022
eaf5c6e
large program tests
NeuralCoder3 Oct 21, 2022
b1772d4
map issue
NeuralCoder3 Oct 21, 2022
e760116
avoid OS dependent ignores
NeuralCoder3 Oct 21, 2022
2054750
removed editor dependent files
NeuralCoder3 Oct 21, 2022
536f8e7
removed temporary files
NeuralCoder3 Oct 21, 2022
ea68622
started to improve comment style
NeuralCoder3 Oct 21, 2022
639b8fd
Merge branch 'master' into autodiff
NeuralCoder3 Oct 21, 2022
e7c681e
fix missing builder
christopherhjung Oct 21, 2022
aff8b00
add DS_Store to gitignore
christopherhjung Oct 21, 2022
d4a4e34
remove DS_Store
christopherhjung Oct 21, 2022
c2b7874
repeat commit e760116
NeuralCoder3 Oct 22, 2022
1c0b82b
Merge branch 'autodiff' into autodiff_ptr (only technically, inspecti…
NeuralCoder3 Oct 24, 2022
ec371b7
Merge branch 'master' into autodiff
NeuralCoder3 Oct 25, 2022
be13f84
top_level flat arguments
NeuralCoder3 Oct 25, 2022
1f9c5f3
reorder memory
NeuralCoder3 Oct 25, 2022
e05a86e
fixed thorin printer errors
NeuralCoder3 Oct 25, 2022
010e724
fixed compilation preventing bugs
NeuralCoder3 Oct 25, 2022
be06774
update
NeuralCoder3 Oct 25, 2022
e1c1f28
Merge branch 'autodiff' into ad_ptr_merge
NeuralCoder3 Oct 25, 2022
e3d0f71
clang format
NeuralCoder3 Oct 25, 2022
15a153e
clang format
NeuralCoder3 Oct 25, 2022
eb6cc63
readded unicodes
NeuralCoder3 Oct 25, 2022
21b58f3
filter print
NeuralCoder3 Oct 25, 2022
4cff7dd
removed unused declaration
NeuralCoder3 Oct 25, 2022
88acd2b
removed superfluous code from tests
NeuralCoder3 Oct 25, 2022
ecdfe71
also for simple_mem
NeuralCoder3 Oct 25, 2022
0719560
corrections
NeuralCoder3 Oct 25, 2022
62c08ae
moved phases, passes
NeuralCoder3 Oct 25, 2022
e15e605
removed duplicate files
NeuralCoder3 Oct 25, 2022
2888aaa
synchronized optimization order
NeuralCoder3 Oct 25, 2022
aed00fc
removed builder
NeuralCoder3 Oct 25, 2022
e94ebf4
update
NeuralCoder3 Oct 25, 2022
d30e495
cleanup
NeuralCoder3 Oct 26, 2022
d50ebe6
removed unused imports
NeuralCoder3 Oct 26, 2022
32635c8
fixed parsing error
NeuralCoder3 Oct 26, 2022
b17aa52
Merge branch 'master' into autodiff
NeuralCoder3 Oct 26, 2022
72e9dd0
readable error message
NeuralCoder3 Oct 26, 2022
cba4f59
fixed real test
NeuralCoder3 Oct 26, 2022
e967893
fixed secondary closure problem with higher order functions (by elmin…
NeuralCoder3 Oct 26, 2022
35be7ad
added optional filter annotation to expected output
NeuralCoder3 Oct 26, 2022
2444b29
split pow test in thorin and llvm part
NeuralCoder3 Oct 26, 2022
725f7ad
split toplevel memory operations
NeuralCoder3 Oct 26, 2022
fc15f7d
rewrote argument preparation
NeuralCoder3 Oct 26, 2022
db758c9
more toplevel updates
NeuralCoder3 Oct 26, 2022
27654d3
thoughts about autodiff_zero
NeuralCoder3 Oct 26, 2022
92fe7a9
update
NeuralCoder3 Oct 26, 2022
b641029
refactor extract
NeuralCoder3 Oct 26, 2022
23e4f1a
alphabetical order
NeuralCoder3 Oct 26, 2022
1b8d949
refactor
NeuralCoder3 Oct 26, 2022
d97bb60
refactor
NeuralCoder3 Oct 26, 2022
4a5c15d
separated memory functions
NeuralCoder3 Oct 27, 2022
aeca3d0
refactor
NeuralCoder3 Oct 28, 2022
f77617c
matrix.transpose differentiation
NeuralCoder3 Oct 28, 2022
100b680
code metric measurements
NeuralCoder3 Oct 31, 2022
169c9ff
visualization
NeuralCoder3 Oct 31, 2022
fd313e5
added impala files
NeuralCoder3 Oct 31, 2022
4dc44ed
Merge branch 'master' into autodiff
NeuralCoder3 Nov 2, 2022
dc22856
Merge branch 'autodiff' of https://github.com/NeuralCoder3/thorin2 in…
NeuralCoder3 Nov 2, 2022
a3daddd
removed get pullback function
NeuralCoder3 Nov 3, 2022
f173c28
Merge branch 'autodiff' into ad_ptr_merge
NeuralCoder3 Nov 3, 2022
9d4c253
more variants to try get it to work
NeuralCoder3 Nov 3, 2022
d89f551
Merge branch 'master' into autodiff
NeuralCoder3 Nov 4, 2022
d9fa330
additional test file
NeuralCoder3 Nov 4, 2022
8551c4f
merged
NeuralCoder3 Nov 4, 2022
0858ecc
working simplified test
NeuralCoder3 Nov 4, 2022
4b9ae49
fixed version 1
NeuralCoder3 Nov 4, 2022
1d0a08f
correctly fixed program
NeuralCoder3 Nov 4, 2022
9f8dae1
fixed (most) tests
NeuralCoder3 Nov 4, 2022
b8cb071
fixed memory test (now working)
NeuralCoder3 Nov 4, 2022
20bbd1f
format
NeuralCoder3 Nov 4, 2022
a455c0b
more informative error message
NeuralCoder3 Nov 4, 2022
0b66fba
Merge branch 'autodiff' into ad_ptr_merge
NeuralCoder3 Nov 4, 2022
66ccd45
fixed some more tests
NeuralCoder3 Nov 4, 2022
7dadddf
partially reverted tuple, app for now
NeuralCoder3 Nov 7, 2022
23f7244
split memory test in two parts to debug code gen in the presence/abse…
NeuralCoder3 Nov 7, 2022
965a811
disabled reshape (fixed 5 additional test cases) for now
NeuralCoder3 Nov 7, 2022
dd569fc
fixed non-memory tests
NeuralCoder3 Nov 7, 2022
c215f17
fixed zero application for now
NeuralCoder3 Nov 8, 2022
5253c78
rebuild correct memory tuple
NeuralCoder3 Nov 8, 2022
1779f61
create pullback for rebuild tuples
NeuralCoder3 Nov 8, 2022
8d00f43
simpler imperative memory test
NeuralCoder3 Nov 8, 2022
4688fca
association of (mostly) correct pullbacks
NeuralCoder3 Nov 8, 2022
393f8f6
found cps2ds error
NeuralCoder3 Nov 8, 2022
9c24614
fixed parser errors
NeuralCoder3 Nov 8, 2022
f86922d
more tests on the cps2ds issue
NeuralCoder3 Nov 9, 2022
428908c
Merge branch 'master' into autodiff
NeuralCoder3 Nov 9, 2022
d7b63e9
Merge branch 'autodiff' of https://github.com/NeuralCoder3/thorin2 in…
NeuralCoder3 Nov 9, 2022
50532e0
Merge branch 'autodiff' into ad_ptr_merge
NeuralCoder3 Nov 10, 2022
939456b
adapted test case
NeuralCoder3 Nov 10, 2022
ea6d984
more debugging
NeuralCoder3 Nov 10, 2022
fdf001c
merge bug test case from ad_ptr_merge branch
NeuralCoder3 Nov 10, 2022
7ad6892
Merge branch 'master' into direct_old_var
NeuralCoder3 Nov 10, 2022
50b3aef
Merge remote-tracking branch 'origin/optimize_phase_extensions' into …
NeuralCoder3 Nov 10, 2022
11e54e3
skeleton for add_mem conversion
NeuralCoder3 Nov 10, 2022
d410eb0
plan
NeuralCoder3 Nov 10, 2022
702d848
add mem
NeuralCoder3 Nov 11, 2022
66e55fb
removed eval
NeuralCoder3 Nov 14, 2022
b41b437
Merge branch 'master' into ad_ptr_merge
NeuralCoder3 Nov 14, 2022
e75ee79
Merge branch 'ad_ptr_merge' into ho_codegen
NeuralCoder3 Nov 14, 2022
94d8918
removed ad
NeuralCoder3 Nov 14, 2022
be2005f
merge ad from master
NeuralCoder3 Nov 14, 2022
b8a0c95
fixed name preservation for tests, small improvements
NeuralCoder3 Nov 15, 2022
734823a
add mem related tests
NeuralCoder3 Nov 15, 2022
28a2095
fixed tests
NeuralCoder3 Nov 15, 2022
480f812
added commands to failing tests
NeuralCoder3 Nov 15, 2022
d8afb79
arg style test
NeuralCoder3 Nov 15, 2022
35c4924
made functions external
NeuralCoder3 Nov 15, 2022
3178a71
tests without extern
NeuralCoder3 Nov 16, 2022
3ca0edc
add reshape flat
NeuralCoder3 Nov 16, 2022
4c292b7
test with preprocessor
NeuralCoder3 Nov 16, 2022
6f08aca
fixed test case lea access
NeuralCoder3 Nov 17, 2022
be1d7e1
fixed oblivious lea reconstruction
NeuralCoder3 Nov 17, 2022
1c46504
fixed more test cases for reshape (divergence)
NeuralCoder3 Nov 17, 2022
226ec64
fixed llvm codegen for conditionals
NeuralCoder3 Nov 17, 2022
81da58b
investigated disabled tests for clos
NeuralCoder3 Nov 17, 2022
005dc1a
moved reshape call to clos
NeuralCoder3 Nov 17, 2022
420984e
more robust testing
NeuralCoder3 Nov 17, 2022
212a498
Merge remote-tracking branch 'upstream/feature/add-mem' into ho_codegen
NeuralCoder3 Nov 18, 2022
1517b90
moved add mem to closure
NeuralCoder3 Nov 18, 2022
d72d448
added complex no mem test case
NeuralCoder3 Nov 18, 2022
5a1603b
Merge remote-tracking branch 'upstream/feature/add-mem' into ho_codegen
NeuralCoder3 Nov 22, 2022
d4094d7
Merge remote-tracking branch 'upstream/feature/add-mem' into ho_codegen
NeuralCoder3 Nov 23, 2022
897388e
Merge remote-tracking branch 'upstream/feature/add-mem' into ho_codegen
NeuralCoder3 Nov 24, 2022
b0aead5
resolved circular dependency
NeuralCoder3 Nov 29, 2022
6fbb8f6
Merge remote-tracking branch 'origin/master' into ho_codegen
NeuralCoder3 Nov 29, 2022
8e540be
merge reshape from ad_ptr_merge
NeuralCoder3 Nov 29, 2022
fdc74c7
handle external functions
NeuralCoder3 Nov 29, 2022
a076278
fixed handling of main function
NeuralCoder3 Nov 29, 2022
2d11419
merge ad_ptr_merge compilation extension
NeuralCoder3 Nov 29, 2022
4e08350
Merge remote-tracking branch 'origin/master' into ho_codegen
NeuralCoder3 Nov 30, 2022
f6fa75a
Merge remote-tracking branch 'origin/master' into test/direct_old_var
NeuralCoder3 Dec 2, 2022
31c8751
fixed tests
NeuralCoder3 Dec 2, 2022
03f8a59
updated test
NeuralCoder3 Dec 2, 2022
794726d
more attempts to fix cps2ds
NeuralCoder3 Dec 5, 2022
87cd669
fixed lambda rewrite
NeuralCoder3 Dec 5, 2022
1792c0a
refactoring
NeuralCoder3 Dec 5, 2022
eeaa8ab
add expected result
NeuralCoder3 Dec 5, 2022
a5b9c7b
re-enable curr nom print
NeuralCoder3 Dec 5, 2022
f88f568
more complex case
NeuralCoder3 Dec 5, 2022
8b3fa6a
removed comment
NeuralCoder3 Dec 5, 2022
cca49e6
rewrite recursively
NeuralCoder3 Dec 5, 2022
eae0fab
rewrite callee first
NeuralCoder3 Dec 5, 2022
668cd96
Merge remote-tracking branch 'origin/default-compilation' into ho_cod…
NeuralCoder3 Dec 5, 2022
87b6356
Merge remote-tracking branch 'origin/direct_fix' into ho_codegen
NeuralCoder3 Dec 6, 2022
fb1ec79
Merge remote-tracking branch 'origin/default-compilation' into ho_cod…
NeuralCoder3 Dec 6, 2022
299da40
only reshape small arrays
NeuralCoder3 Dec 7, 2022
78a8a13
move mem to front (for closure)
NeuralCoder3 Dec 8, 2022
d6f4ae2
Merge branch 'ho_codegen' of https://github.com/NeuralCoder3/thorin2 …
NeuralCoder3 Dec 8, 2022
4733334
merge
NeuralCoder3 Dec 8, 2022
9de0b35
allow for additional backends
NeuralCoder3 Dec 8, 2022
43bbf4d
removed empty function
NeuralCoder3 Dec 8, 2022
6a187f9
start of haskell backend
NeuralCoder3 Dec 8, 2022
4af170b
function shape
NeuralCoder3 Dec 8, 2022
ff2980c
start of haskell emitting
NeuralCoder3 Dec 8, 2022
685d2f1
add mem tests
NeuralCoder3 Dec 9, 2022
fe95409
do not overwrite mem from non-apps
NeuralCoder3 Dec 9, 2022
0add729
some comments
NeuralCoder3 Dec 9, 2022
f528b4b
error reporting script
NeuralCoder3 Dec 9, 2022
5b7d6bb
tuple
NeuralCoder3 Dec 9, 2022
3977590
dumper based emitter
NeuralCoder3 Dec 9, 2022
92c2a2e
switched to OCaml
NeuralCoder3 Dec 12, 2022
f961510
removed comments
NeuralCoder3 Dec 12, 2022
0ce1c6f
example code
NeuralCoder3 Dec 12, 2022
9b1d2f6
todos, fixes
NeuralCoder3 Dec 12, 2022
1fe61ac
syntax highlight
NeuralCoder3 Dec 12, 2022
0e6f52c
correctly negated check
NeuralCoder3 Dec 13, 2022
a3c67d9
Merge branch 'ho_codegen' of https://github.com/NeuralCoder3/thorin2 …
NeuralCoder3 Dec 13, 2022
a5a7595
merge codegen fixes from matrix
NeuralCoder3 Dec 20, 2022
0a96f0f
Merge branch 'ho_codegen' of https://github.com/NeuralCoder3/thorin2 …
NeuralCoder3 Dec 20, 2022
255797e
generalized internal cleanup
NeuralCoder3 Dec 22, 2022
5e2b2ea
more cases, complicated test case
NeuralCoder3 Dec 22, 2022
2d38fe5
emitter phase
NeuralCoder3 Dec 22, 2022
6c0024c
register phase
NeuralCoder3 Dec 23, 2022
f360a89
Merge remote-tracking branch 'origin/ho_codegen' into haskell-backend
NeuralCoder3 Dec 23, 2022
9e05da4
fix merge
NeuralCoder3 Dec 23, 2022
7d6871c
register emitter phase
NeuralCoder3 Dec 23, 2022
42a2431
completed backend (working pow test)
NeuralCoder3 Feb 1, 2023
1d42191
more complex test
NeuralCoder3 Feb 3, 2023
3d48e2a
naming
NeuralCoder3 Feb 3, 2023
b925f8a
thorin opt timing test
NeuralCoder3 Feb 3, 2023
12e1ac2
haskell backend
NeuralCoder3 Feb 3, 2023
06ffc75
timing tests
NeuralCoder3 Feb 3, 2023
4de7655
thorin time
NeuralCoder3 Feb 3, 2023
6a701f7
updated tests
NeuralCoder3 Feb 6, 2023
d34d699
rust timing
NeuralCoder3 Feb 6, 2023
c834870
updated tests
NeuralCoder3 Feb 7, 2023
d5e582f
explanation
NeuralCoder3 Feb 7, 2023
0878018
rust backend
NeuralCoder3 Feb 9, 2023
d940d00
cleanup comments
NeuralCoder3 Feb 9, 2023
8408bd5
static vs dynamic dispatch in rust
NeuralCoder3 Feb 13, 2023
3585006
reversed skeleton of minimal emitter from llvm backend
NeuralCoder3 Feb 16, 2023
save
christopherhjung committed Oct 9, 2022
commit e903622b2baf15bd35e259f60ea343a7c2fd29ca
Binary file added .DS_Store
Binary file not shown.
7 changes: 6 additions & 1 deletion .vscode/launch.json
@@ -11,6 +11,8 @@
"request": "launch",
"program": "${workspaceFolder}/build/bin/thorin",
"args": [
"-d",
"affine",
"-d",
"mem",
"-d",
@@ -21,8 +23,11 @@
"tool",
"-d",
"autodiff",
"${workspaceFolder}/lit/autodiff/simple_autodiff.thorin",
"${workspaceFolder}/lit/autodiff/arr.thorin",
//"${workspaceFolder}/lit/autodiff/simple_autodiff.thorin",
//"${file}",
"--output-ll",
"-",
"--output-thorin",
"-",
"-VVVV",
62 changes: 62 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,62 @@
{
"files.associations": {
"array": "cpp",
"atomic": "cpp",
"bit": "cpp",
"*.tcc": "cpp",
"bitset": "cpp",
"cctype": "cpp",
"cfenv": "cpp",
"chrono": "cpp",
"clocale": "cpp",
"cmath": "cpp",
"codecvt": "cpp",
"compare": "cpp",
"concepts": "cpp",
"cstddef": "cpp",
"cstdint": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"deque": "cpp",
"forward_list": "cpp",
"list": "cpp",
"map": "cpp",
"set": "cpp",
"unordered_map": "cpp",
"unordered_set": "cpp",
"vector": "cpp",
"exception": "cpp",
"functional": "cpp",
"initializer_list": "cpp",
"iomanip": "cpp",
"iosfwd": "cpp",
"iostream": "cpp",
"istream": "cpp",
"limits": "cpp",
"memory": "cpp",
"new": "cpp",
"numbers": "cpp",
"numeric": "cpp",
"optional": "cpp",
"ostream": "cpp",
"ranges": "cpp",
"ratio": "cpp",
"sstream": "cpp",
"stdexcept": "cpp",
"streambuf": "cpp",
"string": "cpp",
"string_view": "cpp",
"system_error": "cpp",
"type_traits": "cpp",
"tuple": "cpp",
"typeindex": "cpp",
"typeinfo": "cpp",
"utility": "cpp",
"variant": "cpp",
"any": "cpp"
}
}
Binary file added dialects/.DS_Store
Binary file not shown.
Binary file added dialects/affine/.DS_Store
Binary file not shown.
12 changes: 7 additions & 5 deletions dialects/affine/passes/lower_for.cpp
@@ -35,18 +35,20 @@ const Def* LowerFor::rewrite(const Def* def) {
{ // construct for
auto [iter, end, step, acc] = for_lam->vars<4>();



// reduce the body to remove the cn parameter
auto nom_body = body->as_nom<Lam>();
auto new_body = nom_body->stub(w, w.cn(w.sigma()), body->dbg());
new_body->set(nom_body->reduce(w.tuple({iter, acc, yield_lam})));
auto new_body = nom_body->stub(w, w.cn(acc->type()), body->dbg())->as<Lam>();
new_body->set(nom_body->reduce(w.tuple({iter, new_body->var(), yield_lam})));

// break
auto if_else_cn = w.cn(w.sigma());
auto if_else_cn = w.cn(acc->type());
auto if_else = w.nom_lam(if_else_cn, nullptr);
if_else->app(false, brk, acc);
if_else->app(false, brk, if_else->var());

auto cmp = core::op(core::icmp::ul, iter, end);
for_lam->branch(false, cmp, new_body, if_else, w.tuple());
for_lam->branch(false, cmp, new_body, if_else, acc);
}

DefArray for_args{for_ax->num_args() - 2, [&](size_t i) { return for_ax->arg(i); }};
Binary file added dialects/autodiff/.DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions dialects/autodiff/autodiff.cpp
@@ -18,6 +18,7 @@ extern "C" THORIN_EXPORT thorin::DialectInfo thorin_get_dialect_info() {
[](thorin::PipelineBuilder& builder) {
builder.add_opt(110);
builder.add_opt(120);
builder.add_opt(300);
builder.extend_opt_phase(105, [](thorin::PassMan& man) { man.add<thorin::autodiff::AutoDiffEval>(); });
builder.extend_opt_phase(111, [](thorin::PassMan& man) {
// in theory only after partial eval (beta, ...)
1 change: 1 addition & 0 deletions dialects/autodiff/autodiff.h
@@ -127,6 +127,7 @@ inline const Def* op_autodiff(const Def* fun) {

inline const Def* op_zero(const Def* A) {
World& world = A->world();
A->dump();
return world.app(world.ax<zero>(), A);
}

182 changes: 138 additions & 44 deletions dialects/autodiff/auxiliary/autodiff_aux.cpp
@@ -44,40 +44,6 @@ const Def* zero_pullback(const Def* E, const Def* A) {
// TODO: rename to op_tangent_type
const Def* tangent_type_fun(const Def* ty) { return ty; }

const Def* equip_mem(const Def* def){
auto& world = def->world();
auto memType = mem::type_mem(world);
if(match<mem::M>(def->proj(0))){
return def;
}

def->dump();
if(def->isa<Sigma>()){
size_t size = def->num_ops() + 1;
DefArray newOps(size, [&](size_t i){
return i == 0 ? memType : def->op(i - 1);
});

return world.sigma(newOps);
}else if(auto pack = def->isa<Pack>()){
auto count = as_lit(pack->shape());
DefArray newOps(count + 1, [&](size_t i){
return i == 0 ? memType : pack->body();
});

return world.sigma(newOps);
}else if(auto pack = def->isa<Arr>()){
auto count = as_lit(pack->shape());
DefArray newOps(count + 1, [&](size_t i){
return i == 0 ? memType : pack->body();
});

return world.sigma(newOps);
}else{
return world.sigma({memType, def});
}
}

/// computes pb type E* -> A*
/// E - type of the expression (return type for a function)
/// A - type of the argument (point of orientation resp. derivative - argument type for partial pullbacks)
@@ -186,7 +152,7 @@ const Def* zero_def(const Def* T) {
auto shape = arr->shape();
auto body = arr->body();
// auto inner_zero = zero_def(body);
auto inner_zero = world.app(world.ax<zero>(), body);
auto inner_zero = op_zero(body);
auto zero_arr = world.pack(shape, inner_zero);
world.DLOG("zero_def for array of shape {} with type {}", shape, body);
world.DLOG("zero_arr: {}", zero_arr);
@@ -206,14 +172,23 @@ const Def* zero_def(const Def* T) {
return zero;
}
} else if (auto sig = T->isa<Sigma>()) {
DefArray ops(sig->ops(), [&](const Def* op) { return world.app(world.ax<zero>(), op); });
DefArray ops(sig->ops(), [&](const Def* op) { return op_zero(op); });
return world.tuple(ops);
}

if(match<mem::M>(T)){
return world.bot(mem::type_mem(world));
}

if(match<mem::Ptr>(T)){
auto lit_zero = world.lit_int_(64, 0);
return core::op_bitcast(T, lit_zero);
}

// or return bot
// assert(0);
// or id => zero T
// return world.app(world.ax<zero>(), T);
T->dump();
return nullptr;
}

@@ -227,6 +202,121 @@ const Def* op_sum(const Def* T, DefArray defs) {

namespace thorin {



const Pi* cn_mem_wrap(const Pi* pi){
auto &world = pi->world();
auto dom = pi->dom();

const Pi* result;
if(pi->ret_pi()){
auto arg = equip_mem(dom->proj(0));
auto ret_pi = cn_mem_wrap(dom->proj(1)->as<Pi>());
result = world.cn({arg, ret_pi});
}else{
auto arg = equip_mem(dom);
result = world.cn({arg});
}

return result;
}

const Def* lam_mem_wrap(const Def* lam){
auto &world = lam->world();
auto type = lam->type()->as<Pi>();
if(!match<mem::M>(type->dom(0)->proj(0))){

auto wrap = cn_mem_wrap(type);

auto mem_lam = world.nom_lam(wrap, world.dbg("memorized_" + lam->name()));
auto lam_return = world.nom_lam(type->ret_pi(), world.dbg("memorized_return_" + lam->name()));

auto mem_vars = mem_lam->var((nat_t)0)->projs();
auto mem = mem_vars[0];
auto vars = lam_return->vars();

auto compound = world.builder().add(mem).add(vars).tuple();
auto compound2 = world.builder().add(mem_vars.skip_front()).add(lam_return).tuple();

lam_return->set_body(
world.app(mem_lam->ret_var(), compound)
);

mem_lam->set_body(
world.app(lam, compound2)
);

mem_lam->set_filter(true);
lam_return->set_filter(true);
return mem_lam;
}

return lam;
}

const Def* get_mem(const Def* def){
auto& world = def->world();

if(match<mem::M>(def->type()->op(0))){
return def->proj(0);
}

return nullptr;
}

const Def* remove_mem(const Def* def){
auto& world = def->world();

if(def->isa<Sigma>() && match<mem::M>(def->op(0))){
auto ops = def->ops();
return world.sigma(ops.skip_front(1));
}else if(match<mem::M>(def->type()->op(0))){
auto ops = def->projs();
return world.tuple(ops.skip_front(1));
}

return def;
}

const Def* equip_mem(const Def* def){
auto& world = def->world();
auto memType = mem::type_mem(world);
if(match<mem::M>(def->proj(0))){
return def;
}

if(def->isa<Sigma>()){
size_t size = def->num_ops() + 1;
DefArray newOps(size, [&](size_t i){
return i == 0 ? memType : def->op(i - 1);
});

return world.sigma(newOps);
}else if(auto pack = def->isa<Pack>()){
auto count = as_lit(pack->shape());
DefArray newOps(count + 1, [&](size_t i){
return i == 0 ? memType : pack->body();
});

return world.sigma(newOps);
}else if(auto pack = def->isa<Arr>()){
auto count = as_lit(pack->shape());
DefArray newOps(count + 1, [&](size_t i){
return i == 0 ? memType : pack->body();
});

return world.sigma(newOps);
}else{
return world.sigma({memType, def});
}
}

const Def* continuation_codom(const Def* E) {
auto pi = E->as<Pi>();
assert(pi != NULL);
return pi->dom(1)->as<Pi>()->dom();
}

bool is_continuation_type(const Def* E) {
if (auto pi = E->isa<Pi>()) { return pi->codom()->isa<Bot>(); }
return false;
@@ -262,12 +352,6 @@ const Def* continuation_dom(const Def* E) {
return pi->dom(0);
}

const Def* continuation_codom(const Def* E) {
auto pi = E->as<Pi>();
assert(pi != NULL);
return pi->dom(1)->as<Pi>()->dom();
}

/// high level
/// f: B -> C
/// g: A -> B
@@ -287,12 +371,22 @@ const Def* compose_continuation(const Def* f, const Def* g) {
world.DLOG("compose g (A->B): {} : {}", g, g->type());
assert(is_returning_continuation(f));
assert(is_returning_continuation(g));

f = lam_mem_wrap(f);
g = lam_mem_wrap(g);

auto F = f->type()->as<Pi>();
auto G = g->type()->as<Pi>();

auto is_mem = match<mem::M>(F->dom(0)->proj(0));

F->dump();
G->dump();

auto A = continuation_dom(G);
auto B = continuation_codom(G);

auto B_hat = continuation_dom(F);
auto C = continuation_codom(F);
// better handled by application type checks
// auto B2 = continuation_dom(F);
@@ -307,7 +401,7 @@ const Def* compose_continuation(const Def* f, const Def* g) {
auto H = world.cn({A, world.cn(C)});
auto Hcont = world.cn(B);

auto h = world.nom_lam(H, world.dbg("comp_" + f->name() + "_" + g->name()));
auto h = world.nom_lam(H, world.dbg("comp_" + f->name() + "_" + g->name()));
auto hcont = world.nom_lam(Hcont, world.dbg("comp_" + f->name() + "_" + g->name() + "_cont"));

h->app(true, g, {h->var((nat_t)0), hcont});
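For orientation (a reviewer's summary, not taken from the patch itself): this hunk leaves the actual composition in `compose_continuation` untouched and only routes both operands through `lam_mem_wrap` first, so that each side carries a `%mem.M` token. In the direct-style notation of the high-level comment in this file, the intended typing is:

    g : A -> B    as a returning continuation    g : cn [A, cn B]
    f : B -> C    as a returning continuation    f : cn [B, cn C]
    h = f after g : A -> C,    built above as    h : H = cn [A, cn C]    with helper    hcont : cn B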
7 changes: 7 additions & 0 deletions dialects/autodiff/auxiliary/autodiff_aux.h
@@ -22,6 +22,13 @@ const Def* op_sum(const Def* T, DefArray defs);

namespace thorin {


const Def* get_mem(const Def* def);
const Def* equip_mem(const Def* def);
const Def* remove_mem(const Def* def);
const Def* lam_mem_wrap(const Def* lam);
const Pi* cn_mem_wrap(const Pi* pi);

bool is_continuation_type(const Def* E);
bool is_continuation(const Def* e);
// TODO: change name to returning_continuation
208 changes: 139 additions & 69 deletions dialects/autodiff/auxiliary/autodiff_rewrite_inner.cpp
@@ -106,6 +106,12 @@ const Def* AutoDiffEval::augment_extract(const Extract* ext, Lam* f, Lam* f_diff
auto aug_tuple = augment(tuple, f, f_diff);
auto aug_index = augment(index, f, f_diff);

auto aug_ext = world.extract(aug_tuple, aug_index);

if(match<mem::M>(ext->type())){
return aug_ext;
}

// TODO: if not exists use:
// e:T, b:B
// b = e#i
@@ -125,16 +131,40 @@ const Def* AutoDiffEval::augment_extract(const Extract* ext, Lam* f, Lam* f_diff
auto pb_fun = world.nom_lam(pb_ty, world.dbg("extract_pb"));
world.DLOG("Pullback: {} : {}", pb_fun, pb_fun->type());
auto pb_tangent = pb_fun->var((nat_t)0, world.dbg("s"));
auto tuple_tan = world.insert(op_zero(aug_tuple->type()), aug_index, pb_tangent, world.dbg("tup_s"));
pb_fun->app(true, tuple_pb,
{
tuple_tan,
pb_fun->var(1) // ret_var but make sure to select correct one
});

auto aug_tuple_type = aug_tuple->type();

auto rm_test = remove_mem(aug_tuple_type);

auto mem = get_mem(pb_tangent);
pb_tangent = remove_mem(pb_tangent);

//auto test = world.app(world.ax<mem::remem>(), mem);


//core::op(core::wrap::sub, core::WMode::none, );

auto init = op_zero(rm_test);
const Def* tuple_tan;
if(rm_test != aug_tuple_type){
auto index_lit = as_lit(aug_index);
auto arity_lit = as_lit(aug_index->type()->as<Idx>()->size());

auto aug_index = world.lit_idx(index_lit - 1, arity_lit - 1);
if(arity_lit == 2){
tuple_tan = pb_tangent;
}else if(arity_lit >= 2){
tuple_tan = world.insert(init, aug_index, pb_tangent, world.dbg("tup_s"));
}
}else{
tuple_tan = world.insert(init, aug_index, pb_tangent, world.dbg("tup_s"));
}

auto arg = world.builder().add(mem).add(tuple_tan).tuple();

pb_fun->app(true, tuple_pb, {arg, pb_fun->var(1)});
pb = pb_fun;
}

auto aug_ext = world.extract(aug_tuple, aug_index);
partial_pullback[aug_ext] = pb;

return aug_ext;
@@ -149,10 +179,16 @@ const Def* AutoDiffEval::augment_tuple(const Tuple* tup, Lam* f, Lam* f_diff) {

auto projs = tup->projs();
// TODO: should use ops instead?
DefArray aug_ops(projs.skip_front(isMem), [&](const Def* op) { return augment(op, f, f_diff); });
DefArray aug_ops(projs, [&](const Def* op) {
return augment(op, f, f_diff);
});
auto aug_tup = world.tuple(aug_ops);

DefArray pbs(aug_ops, [&](const Def* op) { return partial_pullback[op]; });
DefArray pbs(aug_ops.skip_front(isMem), [&](const Def* op) {
auto pb = partial_pullback[op];
assert(pb);
return pb;
});
world.DLOG("tuple pbs {,}", pbs);
// create shadow pb
auto shadow_pb = world.tuple(pbs);
@@ -173,18 +209,21 @@ const Def* AutoDiffEval::augment_tuple(const Tuple* tup, Lam* f, Lam* f_diff) {
// TODO: move op_cps2ds to direct dialect and merge then
auto T = tangent_type_fun(f_arg_ty);
auto mem = pb->var((nat_t)0)->proj(0);
auto sum = world.app(world.ax<zero>(), T);
//auto T_without_mem = remove_mem(T);
auto sum = autodiff_zero(mem, f);

for(size_t i = 0 ; i < pbs.size() ; i++){
auto re = direct::op_cps2ds_dep(pbs[i]);
auto codom = re->type()->as<Pi>()->codom();
auto isMem = match<mem::M>(codom->proj(0));
auto extract = world.extract(pb_tangent, i + 1);
auto app = world.app(re, {mem, extract});
mem = world.extract(app, (nat_t)0);
mem->dump();
mem->dump();
sum = world.app(world.app(world.ax<add>(), T), {sum, app});
}

sum = world.insert(sum, (u64)0, mem);

//DefArray tangents(pbs.size(), [&](nat_t i) { return world.app(direct::op_cps2ds_dep(pbs[i]), world.extract(pb_tangent, i)); });
pb->app(true, pb->var(1), sum);
partial_pullback[aug_tup] = pb;
@@ -241,6 +280,7 @@ const Def* AutoDiffEval::augment_app(const App* app, Lam* f, Lam* f_diff) {
auto callee = app->callee();
auto arg = app->arg();

arg->type()->dump();
auto aug_arg = augment(arg, f, f_diff);
auto aug_callee = augment(callee, f, f_diff);
// auto arg_ppb = partial_pullback[aug_arg];
@@ -406,53 +446,6 @@ const Def* AutoDiffEval::augment_app(const App* app, Lam* f, Lam* f_diff) {
auto r_pb = c1->var(1);
c1->app(true, aug_cont, {res, compose_continuation(e_pb, r_pb)});

// auto X = continuation_codom(g->type());
// // auto A = f_diff->var((nat_t)0);
// auto A = f_diff->type()->dom(0);
// // auto E = g_deriv->var((nat_t)0);
// auto E = g_deriv->type()->as<Pi>()->dom(0);
// world.DLOG("A (var f): {}", A);
// world.DLOG("E (var g): {}", E);
// world.DLOG("X (out g = out f): {}", X);
// // auto ret_g_deriv = g_deriv->var(1);
// auto ret_g_deriv_ty = g_deriv->type()->as<Pi>()->dom(1);
// world.DLOG("ret_g_deriv_ty: {} ", ret_g_deriv_ty);
// // auto ret_f_deriv=f_diff->var(1);
// // world.DLOG("ret_f_deriv: {} : {}", ret_f_deriv, ret_f_deriv->type());

// // TODO: better debug names
// auto c1_ty=ret_g_deriv_ty->as<Pi>();
// world.DLOG("c1_ty: (cn[X, cn[X+, cn E+]]) {}", c1_ty);
// auto c2_ty=aug_cont->type()->as<Pi>()->dom(2)->as<Pi>();
// world.DLOG("c2_ty: (cn[X+, cn A+]) {}", c2_ty);
// auto c3_ty=c1_ty->dom(2)->as<Pi>()->dom(2)->as<Pi>();
// world.DLOG("c3_ty: (cn E+) {}", c3_ty);
// auto c1 = world.nom_lam(c1_ty,world.dbg("c1"));
// auto c2 = world.nom_lam(c2_ty,world.dbg("c2"));
// auto c3 = world.nom_lam(c3_ty,world.dbg("c3"));

// c1->app(true,
// aug_cont,
// {
// c1->var((nat_t)0),
// c2
// }
// );
// c2->app(true,
// c1->var(1),
// {
// c2->var((nat_t)0),
// c3
// }
// );
// c3->app(true,
// e_pb,
// {
// c3->var((nat_t)0),
// c2->var(1)
// }
// );

auto aug_app = world.app(aug_callee, {real_aug_args, c1});
world.DLOG("aug_app: {} : {}", aug_app, aug_app->type());

@@ -471,25 +464,102 @@ const Def* AutoDiffEval::augment_app(const App* app, Lam* f, Lam* f_diff) {
assert(false && "should not be reached");
}

static const Def* tuple_of_types(const Def* t) {
auto& world = t->world();
if (auto sigma = t->isa<Sigma>()) return world.tuple(sigma->ops());
if (auto arr = t->isa<Arr>()) return world.pack(arr->shape(), arr->body());
return t;
}

const Def* op_lea(const Def* ptr, const Def* index, const Def* dbg = {}) {
auto &world = ptr->world();
auto [pointee, addr_space] = match<mem::Ptr>(ptr->type())->args<2>();
auto Ts = tuple_of_types(pointee);
return world.app(world.app(world.ax<mem::lea>(), {pointee->arity(), Ts, addr_space}), {ptr, index}, dbg);
}

const Def* op_load(const Def* mem, const Def* ptr, const Def* dbg = {}) {
mem->dump();
auto &world = mem->world();
auto [T, a] = match<mem::Ptr>(ptr->type())->args<2>();
return world.app(world.app(world.ax<mem::load>(), {T, a}), {mem, ptr}, dbg);
}

const Def* op_store(const Def* mem, const Def* ptr, const Def* val, const Def* dbg = {}) {
auto &world = mem->world();
auto [T, a] = match<mem::Ptr>(ptr->type())->args<2>();
return world.app(world.app(world.ax<mem::store>(), {T, a}), {mem, ptr, val}, dbg);
}

const Def* AutoDiffEval::augment_lea(const App* lea, Lam* f, Lam* f_diff) {
auto [ptr, as] = augment(lea->arg(), f, f_diff)->projs<2>();
auto &world = lea->world();

auto [arr_ptr, idx] = lea->arg()->projs<2>();

auto aug_ptr = augment(arr_ptr, f, f_diff);
auto aug_idx = augment(idx, f, f_diff);

auto aug_arg = augment(lea->arg(), f, f_diff);

//auto [ptr, idx] = aug_arg->projs<2>();
//auto ptr_def = ptr->as<App>();

auto ptr_ty = lea->type()->as<App>();
auto elem_ty = ptr_ty->arg(0);

auto [arg_ty, ret_pi] = f->type()->doms<2>();

auto pb_type = pullback_type(elem_ty, arg_ty);

auto lam = world.nom_lam(pb_type, world.dbg("test"));

auto pb = partial_pullback[ptr];
auto pb_arg = lam->var((nat_t)0);
auto pb_ret = lam->var((nat_t)1);
auto pb_mem = pb_arg->proj(0);
auto pb_s = pb_arg->proj(1);

lea->dump();
pb->dump();
lea->dump();
return lea;
auto gradient_array = shadow_gradient_array[aug_ptr];

auto gradient_lea = op_lea(gradient_array, aug_idx, world.dbg("gradient_array_lea"));
auto [gradient_mem, gradient] = op_load(pb_mem, gradient_lea, world.dbg("gradient_array_load"))->projs<2>();
//auto [pb_s_mem, pb_s] = op_load(gradient_mem, pb_s_ptr)->projs<2>();
auto add = core::op(core::wrap::add, core::WMode::none, gradient, pb_s);
auto store_mem = op_store(gradient_mem, gradient_lea, add, world.dbg("add_to_gradient") );

//auto zero_ret = world.app(world.ax<zero>(), arg_ty);

auto default_zero = autodiff_zero(store_mem, f);
lam->set_body(world.app(pb_ret, {default_zero}));
lam->set_filter(true);

auto aug_lea = world.app(lea->callee(), aug_arg);
partial_pullback[aug_lea] = lam;
return aug_lea;
}

const Def* AutoDiffEval::augment_load(const App* load, Lam* f, Lam* f_diff) {
auto &world = load->world();

auto aug_arg = augment(load->arg(), f, f_diff);
auto [mem, lea] = load->arg()->projs<2>();

auto aug_mem = augment(mem, f, f_diff);
auto aug_lea = augment(lea, f, f_diff);

auto aug_load = op_load(aug_mem, aug_lea, world.dbg("aug_load"));

auto pb_lea = partial_pullback[aug_lea];
partial_pullback[aug_load] = pb_lea;

/*auto aug_arg = augment(load->arg(), f, f_diff);
auto pb = partial_pullback[aug_arg];
auto aug_load = world.app(load->callee(), aug_arg);
//auto [mem, ptr] = aug_arg->proj<2>();
load->dump();
load->dump();
return load;
load->dump();*/
return aug_load;
}

const Def* AutoDiffEval::augment_store(const App* store, Lam* f, Lam* f_diff) {
115 changes: 108 additions & 7 deletions dialects/autodiff/auxiliary/autodiff_rewrite_toplevel.cpp
@@ -2,6 +2,9 @@
#include "dialects/autodiff/auxiliary/autodiff_aux.h"
#include "dialects/autodiff/passes/autodiff_eval.h"

#include "dialects/mem/autogen.h"
#include "dialects/mem/mem.h"

namespace thorin::autodiff {

// void AutoDiffEval::create_shadow_id_pb(const Def* def) {
@@ -41,13 +44,74 @@ namespace thorin::autodiff {
// // // no structure => needs no structure pullback (base case also needs no str. pb because it is shallow)
// }

const Def* hoa(const Def* def, const Def* arg_ty){
if(auto arr = def->isa<Arr>()){
auto &world = def->world();
auto shape = arr->shape();
auto body = hoa(arr->body(), arg_ty);
return world.arr(shape, body);
}

auto pb_ty = pullback_type(def, arg_ty);
return pb_ty;
}

const Def* AutoDiffEval::autodiff_zero(const Def* mem, Lam* f) {
return autodiff_zero(mem, augmented[f->var()]->proj(0));
}

const Def* AutoDiffEval::autodiff_zero(const Def* mem, const Def* def) {
auto& world = def->world();

auto ty = def->type();

if (auto tup = def->isa<Tuple>()) {
DefArray ops(tup->ops(), [&](const Def* op) { return autodiff_zero(mem, op); });
return world.tuple(ops);
}
/*
if (auto var = def->isa<Var>()) {
DefArray ops(var->projs(), [&](const Def* op) { return autodiff_zero(op); });
return world.tuple(ops);
}*/

if (auto app = ty->isa<App>()) {
auto callee = app->callee();
// auto args = app->args();
world.DLOG("app callee: {} : {} <{}>", callee, callee->type(), callee->node_name());
// TODO: can you directly match Tag::Int?
if (callee->isa<Idx>()) {
// auto size = app->arg(0);
auto zero = world.lit_idx(ty, 0, world.dbg("zero"));
// world.DLOG("zero_def for int of size {} is {}", size, zero);
world.DLOG("zero_def for int is {}", zero);
return zero;
}
}

if(match<mem::M>(ty)){
return mem;
}

if(match<mem::Ptr>(ty)){
return shadow_gradient_array[def];
}

def->dump();
def->type()->dump();
return nullptr;
}

/// side effect: register pullback
const Def* AutoDiffEval::derive_(const Def* def) {
auto& world = def->world();
if (auto lam = def->isa_nom<Lam>()) {
world.DLOG("Derive lambda: {}", def);
auto deriv_ty = autodiff_type_fun_pi(lam->type());
auto deriv = world.nom_lam(deriv_ty, world.dbg(lam->name() + "_deriv"));
auto memType = mem::type_mem(world);
auto deriv_inner = world.nom_lam(world.cn({memType}), world.dbg(lam->name() + "_deriv_inner"));


// pre register derivative
// needed for recursion
@@ -78,8 +142,6 @@ const Def* AutoDiffEval::derive_(const Def* def) {
// id_pb_scalar
// );

auto deriv_all_args = deriv->var();
const Def* deriv_arg = deriv->var((nat_t)0, world.dbg("arg"));
// R auto deriv_ret = deriv->var((nat_t)1, world.dbg("ret"));
// R partial_pullback[deriv_arg] = id_pb;

@@ -102,9 +164,43 @@ const Def* AutoDiffEval::derive_(const Def* def) {
//
// but the DS/CPS special case has to be handled separately

const Def* var = deriv->var();

if(match<mem::M>(deriv->dom((nat_t) 0)->proj(0))){
auto vars = var->projs();
auto arg = vars[0];
current_mem = arg->proj(0);
auto args = arg->projs();
args[0] = deriv_inner->var();
arg = world.tuple(args);

for(auto var : deriv->var((nat_t)0)->projs()){
auto var_ty = var->type();
if(auto ptr = match<mem::Ptr>(var_ty)){

auto [mem2, gradient_ptr] = mem::op_malloc(ptr->arg(0), current_mem, world.dbg("gradient_arr"))->projs<2>();
auto pb_ty = hoa(ptr->arg(0), arg_ty);
auto [mem3, pb_ptr] = mem::op_malloc(pb_ty, mem2, world.dbg("pullback_arr"))->projs<2>();
current_mem = mem3;

shadow_gradient_array[var] = gradient_ptr;
shadow_pullback_array[var] = pb_ptr;
var_ty->dump();
}
}

vars[0] = arg;
var = world.tuple(vars);
}

augmented[lam->var()] = var;

//auto deriv_all_args = deriv->var();

// TODO: check identity
// could use identity tangent(arg_ty) = tangent(augment(arg_ty))
// with deriv_arg->type() = augment(arg_ty)
const Def* deriv_arg = deriv->var((nat_t)0, world.dbg("arg"));
auto arg_id_pb = id_pullback(arg_ty);
partial_pullback[deriv_arg] = arg_id_pb;
// set no pullback to all_arg and return
@@ -115,11 +211,11 @@ const Def* AutoDiffEval::derive_(const Def* def) {
auto ret_pb = zero_pullback(lam->var(1)->type(), arg_ty);
partial_pullback[ret_var] = ret_pb;

shadow_pullback[deriv_all_args] = world.tuple({arg_id_pb, ret_pb});
shadow_pullback[var] = world.tuple({arg_id_pb, ret_pb});
world.DLOG("pullback for argument {} : {} is {} : {}", deriv_arg, deriv_arg->type(), arg_id_pb,
arg_id_pb->type());
world.DLOG("args shadow pb is {} : {}", shadow_pullback[deriv_all_args],
shadow_pullback[deriv_all_args]->type());
world.DLOG("args shadow pb is {} : {}", shadow_pullback[var],
shadow_pullback[var]->type());

// TODO: remove as this is subsumed by lam->deriv
// R const Def* lam_ret = lam->var(1, world.dbg("ret"));
@@ -136,7 +232,7 @@ const Def* AutoDiffEval::derive_(const Def* def) {
// TODO: transively remove
// arguments (maybe not necessary)
// auto src_arg = lam->var()
augmented[lam->var()] = deriv->var();

world.DLOG("Associate vars {} with {}", lam->var(), deriv->var());

// already contains the correct application of
@@ -146,15 +242,20 @@ const Def* AutoDiffEval::derive_(const Def* def) {
// this is needed for continuations (without closure conversion)
// but also essentially for the return continuation


// reminder of types:
// expression e: B
// implicit: e_fun: A -> B
// partial pullback: e*: B* -> A*
// partial derivative: e': B' × (B* -> A*)
// implicit: e'_fun: A' -> B' × (B* -> A*)
auto new_body = augment(lam->body(), lam, deriv);
deriv_inner->set_filter(true);
deriv_inner->set_body(new_body);

deriv->set_filter(true);
deriv->set_body(new_body);
deriv->set_body(world.app(deriv_inner, {current_mem}));
deriv->dump(10);

return deriv;
}
2 changes: 1 addition & 1 deletion dialects/autodiff/normalizers.cpp
@@ -136,7 +136,7 @@ const Def* normalize_sum(const Def* type, const Def* callee, const Def* arg, con
// R if(val == 1)
// R return arg;
DefArray args = arg->projs(val);
auto sum = world.app(world.ax<zero>(), T);
auto sum = op_zero(T);
// would also be handled by add zero
if (val >= 1) { sum = args[0]; }
for (auto i = 1; i < val; ++i) { sum = world.app(world.app(world.ax<add>(), T), {sum, args[i]}); }
9 changes: 2 additions & 7 deletions dialects/autodiff/passes/autodiff_eval.cpp
@@ -35,14 +35,9 @@ const Def* AutoDiffEval::rewrite(const Def* def) {
world().DLOG("found a autodiff::autodiff of {}", arg);
// world.DLOG("found a autodiff::autodiff {} to {}",callee,arg);

if (arg->isa<Lam>()) {
// world.DLOG("found a autodiff::autodiff of a lambda");
return derive(arg);
}

assert(arg->isa<Lam>());
// TODO: handle operators analogous

assert(0);
def = derive(arg);
return def;
}

7 changes: 7 additions & 0 deletions dialects/autodiff/passes/autodiff_eval.h
@@ -40,6 +40,9 @@ class AutoDiffEval : public RWPass<AutoDiffEval, Lam> {
const Def* augment_load(const App*, Lam*, Lam*);
const Def* augment_store(const App*, Lam*, Lam*);

const Def* autodiff_zero(const Def* mem, Lam* f);
const Def* autodiff_zero(const Def* mem, const Def* def);

/// fills partial_pullback and shadow/structure pullback maps
void create_shadow_id_pb(const Def*);

@@ -73,6 +76,10 @@ class AutoDiffEval : public RWPass<AutoDiffEval, Lam> {

// TODO: remove?
Def2Def app_pb;

Def2Def shadow_pullback_array;
Def2Def shadow_gradient_array;
const Def* current_mem;
};

} // namespace thorin::autodiff
Binary file added dialects/clos/.DS_Store
Binary file not shown.
Binary file added dialects/core/.DS_Store
Binary file not shown.
16 changes: 14 additions & 2 deletions dialects/core/be/ll/ll.cpp
@@ -221,7 +221,8 @@ std::string Emitter::convert_ret_pi(const Pi* pi) {
void Emitter::start() {
Super::start();

ostream() << "declare i8* @malloc(i64)" << '\n'; // HACK
ostream() << "declare i8* @calloc(i64)" << '\n'; // HACK
ostream() << "declare void @free(i8*)" << '\n'; // HACK
// SJLJ intrinsics (GLIBC Versions)
ostream() << "declare i32 @_setjmp(i8*) returns_twice" << '\n';
ostream() << "declare void @longjmp(i8*, i32) noreturn" << '\n';
@@ -727,8 +728,16 @@ std::string Emitter::emit_bb(BB& bb, const Def* def) {
emit_unsafe(malloc->arg(0));
auto size = emit(malloc->arg(1));
auto ptr_t = convert(force<mem::Ptr>(def->proj(1)->type()));
bb.assign(name + ".i8", "call i8* @malloc(i64 {})", size);
bb.assign(name + ".i8", "call i8* @calloc(i64 {})", size);
return bb.assign(name, "bitcast i8* {} to {}", name + ".i8", ptr_t);
} else if (auto free = match<mem::free>(def)) {
emit_unsafe(free->arg(0));
auto ptr = emit(free->arg(1));
auto ptr_t = convert(force<mem::Ptr>(free->arg(1)->type()));

bb.assign(name + ".i8", "bitcast {} {} to i8*", ptr_t, ptr);
bb.tail("call void @free(i8* {})", name + ".i8");
return {};
} else if (auto mslot = match<mem::mslot>(def)) {
emit_unsafe(mslot->arg(0));
// TODO array with size
@@ -817,6 +826,9 @@ std::string Emitter::emit_bb(BB& bb, const Def* def) {
print(vars_decls_, "{} = global {} {}\n", name, convert(pointee), init);
return globals_[global] = name;
}
def->dump();
def->dump();
def->dump();

unreachable(); // not yet implemented
}
Binary file added dialects/direct/.DS_Store
Binary file not shown.
7 changes: 7 additions & 0 deletions dialects/mem/mem.h
@@ -109,6 +109,13 @@ inline const Def* op_malloc(const Def* type, const Def* mem, const Def* dbg) {
return w.app(w.app(w.ax<malloc>(), {type, w.lit_nat_0()}), {mem, size}, dbg);
}

inline const Def* op_free(const Def* type, const Def* mem, const Def* ptr, const Def* dbg) {
World& w = type->world();
auto ptr_ty = match<Ptr>(ptr->type())->as<App>();
auto pointee = ptr_ty->arg(0_s);
return w.app(w.app(w.ax<free>(), {pointee, w.lit_nat_0()}), {mem, ptr}, dbg);
}

inline const Def* op_mslot(const Def* type, const Def* mem, const Def* id, const Def* dbg) {
World& w = type->world();
auto size = core::op(core::trait::size, type);
5 changes: 5 additions & 0 deletions dialects/mem/mem.thorin
@@ -52,6 +52,11 @@
/// The difference to %mem.alloc is that the `size` is automatically inferred.
.ax %mem.malloc: Π [T: *, as: .Nat] -> [%mem.M, .Nat] -> [%mem.M, %mem.Ptr(T, as)];
///
/// ### %mem.free
///
/// Frees memory of type `T` in address space `as`.
.ax %mem.free: Π [T: *, as: .Nat] -> [%mem.M, %mem.Ptr(T, as)] -> [%mem.M];
///
/// ### %mem.mslot
///
/// Reserves a memory slot for type `T` in address space `as`.
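For orientation, a minimal C++ usage sketch of the new axiom via the `mem::op_free` helper added in `dialects/mem/mem.h` above. The function name `free_array` and the variable names are illustrative only, and the first parameter of `op_free` is only used to reach the `World`:

    #include "thorin/world.h"
    #include "dialects/mem/mem.h" // thorin::mem::op_free (added in this commit)

    // Frees a pointer obtained from %mem.alloc / %mem.malloc, mirroring the
    // surface-syntax call in lit/autodiff/arr.thorin:
    //   .let free_mem = %mem.free («100:.Nat; i32», 0) (mem, a);
    const thorin::Def* free_array(thorin::World& w, const thorin::Def* mem, const thorin::Def* ptr) {
        // op_free reads the pointee type out of ptr's %mem.Ptr type and builds
        // %mem.free (pointee, 0) (mem, ptr), yielding the next %mem.M token.
        return thorin::mem::op_free(ptr->type(), mem, ptr, w.dbg("free_arr"));
    }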
Binary file added lit/.DS_Store
Binary file not shown.
53 changes: 42 additions & 11 deletions lit/autodiff/arr.thorin
@@ -7,35 +7,66 @@
.import core;
.import autodiff;
.import mem;
.import affine;

.let i32 = .Idx 4294967296;
.let arr_size = 100:.Nat;

.cn printInteger [mem: %mem.M, val: i32, return : .Cn [%mem.M]];

.cn f [[mem : %mem.M, a: %mem.Ptr («100:.Nat; i32», 0)], ret: .Cn [mem : %mem.M, i32]] = {
.let lea = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (a, 1:(.Idx 100));
.let (load_mem, load_val) = %mem.load (i32, 0) (mem, lea);
ret (load_mem, load_val)
.let lea1 = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (a, 0:(.Idx 100));
.let lea2 = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (a, 1:(.Idx 100));
.let (load_mem1, load_val1) = %mem.load (i32, 0) (mem, lea1);
.let (load_mem2, load_val2) = %mem.load (i32, 0) (load_mem1, lea2);
.let scaled = %core.wrap.mul (0:.Nat, 4294967296:.Nat) (load_val1, load_val2);
ret (load_mem2, scaled)
};

.cn init [mem: %mem.M, arr : %mem.Ptr (<<100:.Nat; i32>>, 0:.Nat), offset : i32, return : .Cn [%mem.M]] = {
.cn for_body [i : i32, mem : %mem.M, continue : .Cn [%mem.M]] = {
.let idx_100_i = %core.bitcast ((.Idx 100), i32) i;
.let lea = %mem.lea (arr_size, <arr_size; i32>, 0) (arr, idx_100_i);
.let add = %core.wrap.add (0, 4294967296) (offset, i);
.let store_mem = %mem.store (i32, 0) (mem, lea, add);
continue(store_mem)
};
%affine.For (4294967296, 1, (%mem.M)) (0:i32, 100:i32, 1:i32, (mem), for_body, return)
};

.cn .extern main [mem : %mem.M, argc : i32, argv : %mem.Ptr (%mem.Ptr (.Idx 256, 0:.Nat), 0:.Nat), return : .Cn [%mem.M, i32]] = {

.cn ret_cont [[mem : %mem.M, a: i32], pb:.Cn[[%mem.M, i32],.Cn[[%mem.M, %mem.Ptr («100:.Nat; i32», 0)]]]] = {
.cn pb_ret_cont [mem : %mem.M, a: %mem.Ptr («100:.Nat; i32», 0)] = {
return (mem, 99:i32)
.let lea1 = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (a, 0:(.Idx 100));

.let (load_mem1, load_val) = %mem.load (i32, 0) (mem, lea1);

.cn print_integer_callback2 [mem : %mem.M] = {
.let free_mem = %mem.free («100:.Nat; i32», 0) (mem, a);
return (free_mem, 0:i32)
};

.cn print_integer_callback [mem : %mem.M] = {
.let lea2 = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (a, 1:(.Idx 100));
.let (load_mem2, load_val) = %mem.load (i32, 0) (load_mem1, lea2);
printInteger ( load_mem2, load_val, print_integer_callback2 )
};

printInteger ( load_mem1, load_val, print_integer_callback )
};
pb((mem, a), pb_ret_cont)
pb((mem, 1:i32), pb_ret_cont)
};

.let (alloc_mem, alloc_val) = %mem.alloc («100:.Nat; i32», 0) mem;

//.let (alloc_mem, alloc_val) = %mem.malloc (i32, 0) (mem, 4);
.let lea = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (alloc_val, 1:(.Idx 100));
.let store = %mem.store (i32, 0) (alloc_mem, lea, argc);

.let f_diff = %autodiff.autodiff (.Cn [[%mem.M, %mem.Ptr («100:.Nat; i32», 0)],.Cn[%mem.M, i32]]) f;

f_diff ((store, alloc_val),ret_cont)
.cn init_callback [mem : %mem.M] = {
f_diff ((mem, alloc_val),ret_cont)
};

//f ((store, alloc_val),return)
init(alloc_mem, alloc_val, 0:i32, init_callback)
};


79 changes: 79 additions & 0 deletions lit/autodiff/arr2.thorin
@@ -0,0 +1,79 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s

// a call to a autodiff style function
// ./build/bin/thorin -d debug -d direct -d autodiff ./lit/autodiff/multiply_autodiff.thorin --output-thorin - -VVVV

.import core;
.import autodiff;
.import mem;
.import affine;

.let i32 = .Idx 4294967296;

.cn printInteger [mem: %mem.M, val: i32, return : .Cn [%mem.M]];

.cn f [[mem : %mem.M, a: %mem.Ptr («100:.Nat; i32», 0), b: %mem.Ptr («100:.Nat; i32», 0)], ret: .Cn [mem : %mem.M, i32]] = {
.cn for_exit acc :: [mem : %mem.M, i32, i32] = {
.let lea = %mem.lea (arr_size, <arr_size; i32>, 0) (ptr, %core.conv.u2u (arr_size, 4294967296) (%core.wrap.sub (0, 4294967296) (argc, 4:i32)));
.let (load_mem, val) = %mem.load (i32, 0) (mem, lea);
ret (load_mem2, scaled)
};

.cn for_body [i : i32, [mem : %mem.M, acc: i32], continue : .Cn [%mem.M]] = {
.let lea1 = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (a, i);
.let lea2 = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (a, i);
.let (load_mem1, load_val1) = %mem.load (i32, 0) (mem, lea1);
.let (load_mem2, load_val2) = %mem.load (i32, 0) (load_mem1, lea2);
.let scaled = %core.wrap.mul (0:.Nat, 4294967296:.Nat) (load_val1, load_val2);
continue (load_mem2, scaled)
};

%affine.For (4294967296, 3, (%mem.M)) (0:(.Idx 100), 100:(.Idx 100), 1:(.Idx 100), (alloc_mem, 0:i32), for_body, for_exit)
};

.cn .extern main [mem : %mem.M, argc : i32, argv : %mem.Ptr (%mem.Ptr (.Idx 256, 0:.Nat), 0:.Nat), return : .Cn [%mem.M, i32]] = {

.cn ret_cont [[mem : %mem.M, a: i32], pb:.Cn[[%mem.M, i32],.Cn[[%mem.M, %mem.Ptr («100:.Nat; i32», 0)]]]] = {
.cn pb_ret_cont [mem : %mem.M, a: %mem.Ptr («100:.Nat; i32», 0)] = {
.let lea1 = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (a, 0:(.Idx 100));

.let (load_mem1, load_val) = %mem.load (i32, 0) (mem, lea1);

.cn print_integer_callback2 [mem : %mem.M] = {
.let free_mem = %mem.free («100:.Nat; i32», 0) (mem, a);
return (free_mem, 0:i32)
};

.cn print_integer_callback [mem : %mem.M] = {
.let lea2 = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (a, 1:(.Idx 100));
.let (load_mem2, load_val) = %mem.load (i32, 0) (load_mem1, lea2);
printInteger ( load_mem2, load_val, print_integer_callback2 )
};

printInteger ( load_mem1, load_val, print_integer_callback )
};
pb((mem, 1:i32), pb_ret_cont)
};

.let (alloc_mem, alloc_val) = %mem.alloc («100:.Nat; i32», 0) mem;

.let lea1 = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (alloc_val, 0:(.Idx 100));
.let lea2 = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (alloc_val, 1:(.Idx 100));
.let store_mem = %mem.store (i32, 0) (alloc_mem, lea1, 11:i32);
.let store_mem2 = %mem.store (i32, 0) (store_mem, lea2, 42:i32);

.let f_diff = %autodiff.autodiff (.Cn [[%mem.M, %mem.Ptr («100:.Nat; i32», 0)],.Cn[%mem.M, i32]]) f;

f_diff ((store_mem2, alloc_val),ret_cont)
};



// CHECK-DAG: .cn .extern main _{{[0-9_]+}}::[mem_[[memId:[_0-9]*]]: %mem.M, (i32), %mem.Ptr (%mem.Ptr ((.Idx 256), 0:.Nat), 0:.Nat), return_[[returnId:[_0-9]*]]: .Cn [%mem.M, (i32)]] = {
// CHECK-DAG: _[[appId:[_0-9]*]]: ⊥:★ = return_[[returnEtaId:[_0-9]*]] (mem_[[memId]], 42:(i32));
// CHECK-DAG: _[[appId]]

// CHECK-DAG: return_[[returnEtaId]] _[[returnEtaVarId:[0-9_]+]]: [%mem.M, (i32)] = {
// CHECK-DAG: return_[[retAppId:[_0-9]*]]: ⊥:★ = return_[[returnId]] _[[returnEtaVarId]];
// CHECK-DAG: return_[[retAppId]]
2 changes: 0 additions & 2 deletions lit/autodiff/simple_autodiff.thorin
@@ -11,8 +11,6 @@
.let I32 = .Idx 4294967296;

.cn f [[mem : %mem.M, a:I32], ret: .Cn [%mem.M, I32]] = {
// is pack
// .let b = %core.wrap.mul (0:.Nat, 4294967296:.Nat) (a, a);
.let b = %core.wrap.mul (0:.Nat, 4294967296:.Nat) (2:I32, a);
ret (mem, b)
};
49 changes: 49 additions & 0 deletions lit/autodiff/test.thorin
@@ -0,0 +1,49 @@
// RUN: rm -f %t.ll ; \
// RUN: %thorin -d autodiff %s --output-ll %t.ll --output-thorin - | FileCheck %s

// a call to a autodiff style function
// ./build/bin/thorin -d debug -d direct -d autodiff ./lit/autodiff/multiply_autodiff.thorin --output-thorin - -VVVV

.import core;
.import autodiff;
.import mem;

.let i32 = .Idx 4294967296;

.cn f [[mem : %mem.M, a: %mem.Ptr («100:.Nat; i32», 0)], ret: .Cn [mem : %mem.M, i32]] = {
.let lea = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (a, 1:(.Idx 100));
.let (load_mem, load_val) = %mem.load (i32, 0) (mem, lea);
.let scaled = %core.wrap.mul (0:.Nat, 4294967296:.Nat) (100:i32, load_val);
ret (load_mem, scaled)
};

.cn .extern main [mem : %mem.M, argc : i32, argv : %mem.Ptr (%mem.Ptr (.Idx 256, 0:.Nat), 0:.Nat), return : .Cn [%mem.M, i32]] = {

.cn ret_cont [[mem : %mem.M, a: i32], pb:.Cn[[%mem.M, i32],.Cn[[%mem.M, %mem.Ptr («100:.Nat; i32», 0)]]]] = {
.cn pb_ret_cont [mem : %mem.M, a: %mem.Ptr («100:.Nat; i32», 0)] = {
.let lea = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (a, 1:(.Idx 100));
.let (load_mem, load_val) = %mem.load (i32, 0) (mem, lea);
return (load_mem, load_val)
};
pb((mem, a), pb_ret_cont)
};

.let (alloc_mem, alloc_val) = %mem.alloc («100:.Nat; i32», 0) mem;

.let lea = %mem.lea (100:.Nat, <100:.Nat; i32>, 0) (alloc_val, 1:(.Idx 100));
.let store = %mem.store (i32, 0) (alloc_mem, lea, argc);

.let f_diff = %autodiff.autodiff (.Cn [[%mem.M, %mem.Ptr («100:.Nat; i32», 0)],.Cn[%mem.M, i32]]) f;

f_diff ((store, alloc_val),ret_cont)
};



// CHECK-DAG: .cn .extern main _{{[0-9_]+}}::[mem_[[memId:[_0-9]*]]: %mem.M, (i32), %mem.Ptr (%mem.Ptr ((.Idx 256), 0:.Nat), 0:.Nat), return_[[returnId:[_0-9]*]]: .Cn [%mem.M, (i32)]] = {
// CHECK-DAG: _[[appId:[_0-9]*]]: ⊥:★ = return_[[returnEtaId:[_0-9]*]] (mem_[[memId]], 42:(i32));
// CHECK-DAG: _[[appId]]

// CHECK-DAG: return_[[returnEtaId]] _[[returnEtaVarId:[0-9_]+]]: [%mem.M, (i32)] = {
// CHECK-DAG: return_[[retAppId:[_0-9]*]]: ⊥:★ = return_[[returnId]] _[[returnEtaVarId]];
// CHECK-DAG: return_[[retAppId]]
Binary file added thorin/.DS_Store
Binary file not shown.
2 changes: 2 additions & 0 deletions thorin/CMakeLists.txt
@@ -23,6 +23,8 @@ add_library(libthorin
tuple.h
world.cpp
world.h
builder.cpp
builder.h
analyses/cfg.cpp
analyses/cfg.h
analyses/deptree.cpp
Binary file added thorin/be/.DS_Store
Binary file not shown.
18 changes: 9 additions & 9 deletions thorin/dump.cpp
@@ -108,13 +108,13 @@ std::ostream& operator<<(std::ostream& os, Inline u) {

if (auto type = u->isa<Type>()) {
auto level = as_lit(type->level()); // TODO other levels
return print(os, level == 0 ? "★" : "□");
return print(os, level == 0 ? "*" : "□");
} else if (u->isa<Nat>()) {
return print(os, ".Nat");
} else if (auto bot = u->isa<Bot>()) {
return print(os, ":{}", bot->type());
return print(os, ".bot:{}", bot->type());
} else if (auto top = u->isa<Top>()) {
return print(os, ":{}", top->type());
return print(os, ".top:{}", top->type());
} else if (auto axiom = u->isa<Axiom>()) {
const auto& name = axiom->name();
return print(os, "{}{}", name[0] == '%' ? "" : "%", name);
@@ -129,8 +129,8 @@ std::ostream& operator<<(std::ostream& os, Inline u) {
} else if (auto pi = u->isa<Pi>()) {
if (pi->is_cn()) return print(os, ".Cn {}", pi->dom());
if (auto nom = pi->isa_nom<Pi>(); nom && nom->var())
return print(os, "Π {}: {} {}", nom->var(), pi->dom(), pi->codom());
return print(os, "Π {} {}", pi->dom(), pi->codom());
return print(os, "Π {}: {} -> {}", nom->var(), pi->dom(), pi->codom());
return print(os, "Π {} -> {}", pi->dom(), pi->codom());
} else if (auto lam = u->isa<Lam>()) {
return print(os, "{}, {}", lam->filter(), lam->body());
} else if (auto int_ = u->isa<Idx>()) {
@@ -153,12 +153,12 @@ std::ostream& operator<<(std::ostream& os, Inline u) {
return tuple->type()->isa_nom() ? print(os, ":{}", tuple->type()) : os;
} else if (auto arr = u->isa<Arr>()) {
if (auto nom = arr->isa_nom<Arr>(); nom && nom->var())
return print(os, "«{}: {}; {}»", nom->var(), nom->shape(), nom->body());
return print(os, "«{}; {}»", arr->shape(), arr->body());
return print(os, "<<{}: {}; {}>>", nom->var(), nom->shape(), nom->body());
return print(os, "<<{}; {}>>", arr->shape(), arr->body());
} else if (auto pack = u->isa<Pack>()) {
if (auto nom = pack->isa_nom<Pack>(); nom && nom->var())
return print(os, "{}: {}; {}", nom->var(), nom->shape(), nom->body());
return print(os, "{}; {}", pack->shape(), pack->body());
return print(os, "<{}: {}; {}>", nom->var(), nom->shape(), nom->body());
return print(os, "<{}; {}>", pack->shape(), pack->body());
} else if (auto proxy = u->isa<Proxy>()) {
return print(os, ".proxy#{}#{} {, }", proxy->pass(), proxy->tag(), proxy->ops());
} else if (auto bound = isa_bound(*u)) {
Binary file added thorin/pass/.DS_Store
Binary file not shown.
5 changes: 5 additions & 0 deletions thorin/world.h
@@ -14,6 +14,7 @@
#include "thorin/flags.h"
#include "thorin/lattice.h"
#include "thorin/tuple.h"
#include "thorin/builder.h"

#include "thorin/util/hash.h"
#include "thorin/util/log.h"
@@ -443,6 +444,10 @@ class World {
void write() const; ///< Same above but file name defaults to World::name.
///@}

Builder builder(){
return Builder(*this);
}

private:
/// @name put into sea of nodes
///@{
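For orientation, a minimal C++ sketch of how the new `World::builder()` entry point is used. The `Builder` interface itself (thorin/builder.h, registered in CMakeLists.txt above) is not shown in this commit, so `add` and `tuple` are inferred from the call sites in `lam_mem_wrap` and `augment_extract` earlier in this patch, and the helper name `prepend_mem` is illustrative:

    #include "thorin/world.h" // now also pulls in thorin/builder.h

    // Rebuild a tuple with a %mem.M token prepended, the pattern used by
    // lam_mem_wrap: world.builder().add(mem).add(vars).tuple().
    const thorin::Def* prepend_mem(thorin::World& world, const thorin::Def* mem, const thorin::Def* value) {
        // Builder collects operands (single defs as well as projection arrays)
        // and finalizes them into one tuple def.
        return world.builder().add(mem).add(value->projs()).tuple();
    }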