Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for identities to Impls #968

Open
wants to merge 45 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
eb0e70f
Added support for identities
varun10p Aug 20, 2024
8adc504
Syntax checking back in the macro
Aug 20, 2024
c77537e
remove debug
Aug 21, 2024
6f61a0a
fixing cond
Aug 21, 2024
aa4576e
Fixing comments
varun10p Aug 21, 2024
dd71ba6
Moving fp-safe to impl
Aug 27, 2024
fd4079c
Continued
varun10p Aug 27, 2024
643053e
Only comparators left
varun10p Aug 28, 2024
f5f47d0
Removed fp-safe rulesets
varun10p Aug 28, 2024
3b7f1ad
Exp rules
varun10p Aug 28, 2024
f1d66d1
Adding rules back
varun10p Aug 28, 2024
a22241a
Cleanup rules
varun10p Aug 28, 2024
fef3e8c
Fixed up rules
varun10p Sep 9, 2024
da5bd20
Condense sets
varun10p Sep 9, 2024
ff472ae
More condensing
varun10p Sep 9, 2024
477f2f5
Fixed some compilation
varun10p Sep 9, 2024
c016f83
typos
varun10p Sep 9, 2024
f747780
impl-infos
varun10p Sep 9, 2024
336c5e7
typo
varun10p Sep 9, 2024
f0f738b
contract
varun10p Sep 9, 2024
9e632dc
fmt
Sep 9, 2024
460027a
making dictionary
varun10p Sep 9, 2024
ba1f1ca
Remove expansive check
varun10p Sep 10, 2024
69991e6
Fixing fmt
Sep 10, 2024
71aa259
Merge remote-tracking branch 'origin/main' into varun10p-identities
varun10p Sep 10, 2024
8b310a1
Fixing merges
varun10p Sep 10, 2024
d905bda
Fixing contract
varun10p Sep 11, 2024
505583c
fmt
Sep 11, 2024
430dc24
Additional fixes
Sep 11, 2024
9a69987
Impl names in identity
Sep 11, 2024
83fbca6
More impl names
Sep 12, 2024
534ba00
Added f32
Sep 12, 2024
b89be15
Found bug + fmt
Sep 12, 2024
a4efa40
Exact rules
Sep 12, 2024
3a50b18
Type-checking logic added
Sep 25, 2024
09c9ffc
Fixed bugs.
Sep 27, 2024
8027e3b
More fixes
Sep 27, 2024
65556e8
Found last bugs
Sep 27, 2024
fc1064b
Fpcore bug fix
Sep 27, 2024
e975732
Better fix
Sep 27, 2024
c2e5e39
Merge remote-tracking branch 'origin/main' into varun10p-identities
Sep 27, 2024
1c76a18
Last change + fmt
Oct 16, 2024
95bc0e1
Merge remote-tracking branch 'origin/main' into varun10p-identities
Oct 16, 2024
3bcc196
Fixing contract + fmt
Oct 16, 2024
dcd5694
Merge branch 'main' into varun10p-identities
varun10p Oct 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/core/egg-herbie.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,6 @@
[(symbol? input)
; expansive rules
(define itype (dict-ref (rule-itypes ru) input))
(unless (type-name? itype)
(error 'rule->egg-rules "expansive rules over impls is unsound ~a" input))
Comment on lines -485 to -486
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still strange. What is the story here?

(for/list ([op (all-operators)]
#:when (eq? (operator-info op 'otype) itype))
(define itypes (operator-info op 'itype))
Expand Down
2 changes: 1 addition & 1 deletion src/core/mainloop.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@
(timeline-event! 'simplify)

; egg schedule (only mathematical rewrites)
(define rules (platform-impl-rules (*fp-safe-simplify-rules*)))
(define rules (append (*fp-safe-simplify-rules*) (real-rules (*simplify-rules*))))
(define schedule `((,rules . ((node . ,(*node-limit*)) (const-fold? . #f)))))

; egg runner
Expand Down
107 changes: 29 additions & 78 deletions src/core/rules.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

(provide *rules*
*simplify-rules*
*fp-safe-simplify-rules*
(struct-out rule)
real-rules)

Expand All @@ -20,8 +19,7 @@
register-ruleset!
*rulesets*
*unsound-rules*
register-ruleset*!
(struct-out rule)))
register-ruleset*!))

;; A rule represents a "find-and-replace" pattern where `input` and `output`
;; are patterns, `itypes` is a mapping from variable name to type
Expand Down Expand Up @@ -75,16 +73,6 @@
(for ([rule (in-list rules)])
(sow rule))))))

(define (*fp-safe-simplify-rules*)
(reap [sow]
(for ([(_ ruleset) (in-dict (*rulesets*))])
(match-define (list rules groups _) ruleset)
(when (and (ormap (curry flag-set? 'rules) groups)
(set-member? groups 'simplify)
(set-member? groups 'fp-safe))
(for ([rule (in-list rules)])
(sow rule))))))

(define (*unsound-rules*)
(reap [sow]
(for ([(_ ruleset) (in-dict (*rulesets*))])
Expand Down Expand Up @@ -151,7 +139,7 @@

; Commutativity
(define-ruleset* commutativity
(arithmetic simplify fp-safe sound)
(arithmetic simplify sound)
#:type ([a real] [b real])
[+-commutative (+ a b) (+ b a)]
[*-commutative (* a b) (* b a)])
Expand Down Expand Up @@ -191,12 +179,7 @@
[distribute-rgt-out (+ (* b a) (* c a)) (* a (+ b c))]
[distribute-rgt-out-- (- (* b a) (* c a)) (* a (- b c))]
[distribute-lft1-in (+ (* b a) a) (* (+ b 1) a)]
[distribute-rgt1-in (+ a (* c a)) (* (+ c 1) a)])

; Safe Distributiviity
(define-ruleset* distributivity-fp-safe
(arithmetic simplify fp-safe sound)
#:type ([a real] [b real])
[distribute-rgt1-in (+ a (* c a)) (* (+ c 1) a)]
[distribute-lft-neg-in (neg (* a b)) (* (neg a) b)]
[distribute-rgt-neg-in (neg (* a b)) (* a (neg b))]
[distribute-lft-neg-out (* (neg a) b) (neg (* a b))]
Expand All @@ -208,8 +191,8 @@
[distribute-neg-frac (neg (/ a b)) (/ (neg a) b)]
[distribute-neg-frac2 (neg (/ a b)) (/ a (neg b))])

(define-ruleset* cancel-sign-fp-safe
(arithmetic simplify fp-safe sound)
(define-ruleset* cancel-sign
(arithmetic simplify sound)
#:type ([a real] [b real] [c real])
[cancel-sign-sub (- a (* (neg b) c)) (+ a (* b c))]
[cancel-sign-sub-inv (- a (* b c)) (+ a (* (neg b) c))])
Expand Down Expand Up @@ -242,20 +225,12 @@
#:type ([a real])
[remove-double-div (/ 1 (/ 1 a)) a]
[rgt-mult-inverse (* a (/ 1 a)) 1]
[lft-mult-inverse (* (/ 1 a) a) 1])

(define-ruleset* id-reduce-fp-safe-nan
(arithmetic simplify fp-safe-nan sound)
#:type ([a real])
[lft-mult-inverse (* (/ 1 a) a) 1]
[+-inverses (- a a) 0]
[div0 (/ 0 a) 0]
[mul0-lft (* 0 a) 0]
[mul0-rgt (* a 0) 0]
[*-inverses (/ a a) 1])

(define-ruleset* id-reduce-fp-safe
(arithmetic simplify fp-safe sound)
#:type ([a real])
[*-inverses (/ a a) 1]
[+-lft-identity (+ 0 a) a]
[+-rgt-identity (+ a 0) a]
[--rgt-identity (- a 0) a]
Expand All @@ -266,8 +241,8 @@
[/-rgt-identity (/ a 1) a]
[mul-1-neg (* -1 a) (neg a)])

(define-ruleset* nan-transform-fp-safe
(arithmetic simplify fp-safe sound)
(define-ruleset* nan-transform
(arithmetic simplify sound)
#:type ([a real] [b real])
[sub-neg (- a b) (+ a (neg b))]
[unsub-neg (+ a (neg b)) (- a b)]
Expand All @@ -285,10 +260,7 @@
#:type ([a real] [b real])
[clear-num (/ a b) (/ 1 (/ b a))])

(define-ruleset* id-transform-fp-safe
(arithmetic fp-safe sound)
#:type ([a real])
[*-un-lft-identity a (* 1 a)])
(define-ruleset* id-transform (arithmetic sound) #:type ([a real]) [*-un-lft-identity a (* 1 a)])

; Difference of cubes
(define-ruleset*
Expand Down Expand Up @@ -324,13 +296,13 @@
[rem-sqrt-square (sqrt (* x x)) (fabs x)])

(define-ruleset* squares-reduce-fp-sound
(arithmetic simplify fp-safe sound)
(arithmetic simplify sound)
#:type ([x real])
[sqr-neg (* (neg x) (neg x)) (* x x)]
[sqr-abs (* (fabs x) (fabs x)) (* x x)])

(define-ruleset* fabs-reduce
(arithmetic simplify fp-safe sound)
(arithmetic simplify sound)
#:type ([x real] [a real] [b real])
[fabs-fabs (fabs (fabs x)) (fabs x)]
[fabs-sub (fabs (- a b)) (fabs (- b a))]
Expand All @@ -340,7 +312,7 @@
[fabs-div (fabs (/ a b)) (/ (fabs a) (fabs b))])

(define-ruleset* fabs-expand
(arithmetic fp-safe sound)
(arithmetic sound)
#:type ([x real] [a real] [b real])
[neg-fabs (fabs x) (fabs (neg x))]
[mul-fabs (* (fabs a) (fabs b)) (fabs (* a b))]
Expand Down Expand Up @@ -405,7 +377,7 @@
[rem-log-exp (log (exp x)) x])

(define-ruleset* exp-constants
(exponents simplify fp-safe sound)
(exponents simplify sound)
[exp-0 (exp 0) 1]
[exp-1-e (exp 1) (E)]
[1-exp 1 (exp 0)]
Expand All @@ -431,20 +403,15 @@
[exp-lft-cube (exp (* a 3)) (pow (exp a) 3)])

; Powers
(define-ruleset* pow-reduce (exponents simplify sound) #:type ([a real]) [unpow-1 (pow a -1) (/ 1 a)])

(define-ruleset* pow-reduce-fp-safe
(exponents simplify fp-safe sound)
#:type ([a real])
[unpow1 (pow a 1) a])

(define-ruleset* pow-reduce-fp-safe-nan
(exponents simplify fp-safe-nan sound)
(define-ruleset* pow-reduce
(exponents simplify sound)
#:type ([a real])
[unpow-1 (pow a -1) (/ 1 a)]
[unpow1 (pow a 1) a]
[unpow0 (pow a 0) 1]
[pow-base-1 (pow 1 a) 1])

(define-ruleset* pow-expand-fp-safe (exponents fp-safe sound) #:type ([a real]) [pow1 a (pow a 1)])
(define-ruleset* pow-expand (exponents sound) #:type ([a real]) [pow1 a (pow a 1)])

(define-ruleset* pow-canonicalize
(exponents simplify sound)
Expand Down Expand Up @@ -482,16 +449,8 @@
[pow-pow (pow (pow a b) c) (pow a (* b c))]
[pow-unpow (pow a (* b c)) (pow (pow a b) c)]
[unpow-prod-up (pow a (+ b c)) (* (pow a b) (pow a c))]
[unpow-prod-down (pow (* b c) a) (* (pow b a) (pow c a))])

(define-ruleset* pow-transform-fp-safe-nan
(exponents simplify fp-safe-nan sound)
#:type ([a real])
[pow-base-0 (pow 0 a) 0])

(define-ruleset* pow-transform-fp-safe
(exponents fp-safe sound)
#:type ([a real])
[unpow-prod-down (pow (* b c) a) (* (pow b a) (pow c a))]
[pow-base-0 (pow 0 a) 0]
[inv-pow (/ 1 a) (pow a -1)])

(define-ruleset* log-distribute-sound
Expand All @@ -517,24 +476,18 @@

; Trigonometry
(define-ruleset* trig-reduce-fp-sound
(trigonometry simplify fp-safe sound)
(trigonometry simplify sound)
[sin-0 (sin 0) 0]
[cos-0 (cos 0) 1]
[tan-0 (tan 0) 0])

(define-ruleset* trig-reduce-fp-sound-nan
(trigonometry simplify fp-safe-nan sound)
(trigonometry simplify sound)
#:type ([x real])
[sin-neg (sin (neg x)) (neg (sin x))]
[cos-neg (cos (neg x)) (cos x)]
[tan-neg (tan (neg x)) (neg (tan x))])

(define-ruleset* trig-expand-fp-safe
(trignometry fp-safe sound)
#:type ([x real])
[sqr-sin-b (* (sin x) (sin x)) (- 1 (* (cos x) (cos x)))]
[sqr-cos-b (* (cos x) (cos x)) (- 1 (* (sin x) (sin x)))])

(define-ruleset*
trig-inverses
(trigonometry sound)
Expand Down Expand Up @@ -597,6 +550,8 @@
(define-ruleset* trig-expand-sound
(trigonometry sound)
#:type ([x real] [y real] [a real] [b real])
[sqr-sin-b (* (sin x) (sin x)) (- 1 (* (cos x) (cos x)))]
[sqr-cos-b (* (cos x) (cos x)) (- 1 (* (sin x) (sin x)))]
[sin-sum (sin (+ x y)) (+ (* (sin x) (cos y)) (* (cos x) (sin y)))]
[cos-sum (cos (+ x y)) (- (* (cos x) (cos y)) (* (sin x) (sin y)))]
[tan-sum (tan (+ x y)) (/ (+ (tan x) (tan y)) (- 1 (* (tan x) (tan y))))]
Expand Down Expand Up @@ -667,6 +622,10 @@
(define-ruleset* htrig-expand-sound
(hyperbolic sound)
#:type ([x real] [y real])
[sinh-neg (sinh (neg x)) (neg (sinh x))]
[sinh-0 (sinh 0) 0]
[cosh-neg (cosh (neg x)) (cosh x)]
[cosh-0 (cosh 0) 1]
[sinh-undef (- (exp x) (exp (neg x))) (* 2 (sinh x))]
[cosh-undef (+ (exp x) (exp (neg x))) (* 2 (cosh x))]
[tanh-undef (/ (- (exp x) (exp (neg x))) (+ (exp x) (exp (neg x)))) (tanh x)]
Expand All @@ -691,14 +650,6 @@
#:type ([x real] [y real])
[tanh-1/2* (tanh (/ x 2)) (/ (- (cosh x) 1) (sinh x))])

(define-ruleset* htrig-expand-fp-safe
(hyperbolic fp-safe sound)
#:type ([x real])
[sinh-neg (sinh (neg x)) (neg (sinh x))]
[sinh-0 (sinh 0) 0]
[cosh-neg (cosh (neg x)) (cosh x)]
[cosh-0 (cosh 0) 1])

(define-ruleset* ahtrig-expand-sound
(hyperbolic sound)
#:type ([x real])
Expand Down
Loading