diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index 912bfd539..67fd37595 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -482,8 +482,6 @@ [(symbol? input) ; expansive rules (define itype (dict-ref (rule-itypes ru) input)) - (unless (type-name? itype) - (error 'rule->egg-rules "expansive rules over impls is unsound ~a" input)) (for/list ([op (all-operators)] #:when (eq? (operator-info op 'otype) itype)) (define itypes (operator-info op 'itype)) diff --git a/src/core/mainloop.rkt b/src/core/mainloop.rkt index b69c7a24a..a63388244 100644 --- a/src/core/mainloop.rkt +++ b/src/core/mainloop.rkt @@ -366,7 +366,7 @@ (timeline-event! 'simplify) ; egg schedule (only mathematical rewrites) - (define rules (platform-impl-rules (*fp-safe-simplify-rules*))) + (define rules (append (*fp-safe-simplify-rules*) (real-rules (*simplify-rules*)))) (define schedule `((,rules . ((node . ,(*node-limit*)) (const-fold? . #f))))) ; egg runner diff --git a/src/core/rules.rkt b/src/core/rules.rkt index aad2f0d42..cfb842212 100644 --- a/src/core/rules.rkt +++ b/src/core/rules.rkt @@ -10,7 +10,6 @@ (provide *rules* *simplify-rules* - *fp-safe-simplify-rules* (struct-out rule) real-rules) @@ -20,8 +19,7 @@ register-ruleset! *rulesets* *unsound-rules* - register-ruleset*! - (struct-out rule))) + register-ruleset*!)) ;; A rule represents a "find-and-replace" pattern where `input` and `output` ;; are patterns, `itypes` is a mapping from variable name to type @@ -75,16 +73,6 @@ (for ([rule (in-list rules)]) (sow rule)))))) -(define (*fp-safe-simplify-rules*) - (reap [sow] - (for ([(_ ruleset) (in-dict (*rulesets*))]) - (match-define (list rules groups _) ruleset) - (when (and (ormap (curry flag-set? 'rules) groups) - (set-member? groups 'simplify) - (set-member? groups 'fp-safe)) - (for ([rule (in-list rules)]) - (sow rule)))))) - (define (*unsound-rules*) (reap [sow] (for ([(_ ruleset) (in-dict (*rulesets*))]) @@ -151,7 +139,7 @@ ; Commutativity (define-ruleset* commutativity - (arithmetic simplify fp-safe sound) + (arithmetic simplify sound) #:type ([a real] [b real]) [+-commutative (+ a b) (+ b a)] [*-commutative (* a b) (* b a)]) @@ -191,12 +179,7 @@ [distribute-rgt-out (+ (* b a) (* c a)) (* a (+ b c))] [distribute-rgt-out-- (- (* b a) (* c a)) (* a (- b c))] [distribute-lft1-in (+ (* b a) a) (* (+ b 1) a)] - [distribute-rgt1-in (+ a (* c a)) (* (+ c 1) a)]) - -; Safe Distributiviity -(define-ruleset* distributivity-fp-safe - (arithmetic simplify fp-safe sound) - #:type ([a real] [b real]) + [distribute-rgt1-in (+ a (* c a)) (* (+ c 1) a)] [distribute-lft-neg-in (neg (* a b)) (* (neg a) b)] [distribute-rgt-neg-in (neg (* a b)) (* a (neg b))] [distribute-lft-neg-out (* (neg a) b) (neg (* a b))] @@ -208,8 +191,8 @@ [distribute-neg-frac (neg (/ a b)) (/ (neg a) b)] [distribute-neg-frac2 (neg (/ a b)) (/ a (neg b))]) -(define-ruleset* cancel-sign-fp-safe - (arithmetic simplify fp-safe sound) +(define-ruleset* cancel-sign + (arithmetic simplify sound) #:type ([a real] [b real] [c real]) [cancel-sign-sub (- a (* (neg b) c)) (+ a (* b c))] [cancel-sign-sub-inv (- a (* b c)) (+ a (* (neg b) c))]) @@ -242,20 +225,12 @@ #:type ([a real]) [remove-double-div (/ 1 (/ 1 a)) a] [rgt-mult-inverse (* a (/ 1 a)) 1] - [lft-mult-inverse (* (/ 1 a) a) 1]) - -(define-ruleset* id-reduce-fp-safe-nan - (arithmetic simplify fp-safe-nan sound) - #:type ([a real]) + [lft-mult-inverse (* (/ 1 a) a) 1] [+-inverses (- a a) 0] [div0 (/ 0 a) 0] [mul0-lft (* 0 a) 0] [mul0-rgt (* a 0) 0] - [*-inverses (/ a a) 1]) - -(define-ruleset* id-reduce-fp-safe - (arithmetic simplify fp-safe sound) - #:type ([a real]) + [*-inverses (/ a a) 1] [+-lft-identity (+ 0 a) a] [+-rgt-identity (+ a 0) a] [--rgt-identity (- a 0) a] @@ -266,8 +241,8 @@ [/-rgt-identity (/ a 1) a] [mul-1-neg (* -1 a) (neg a)]) -(define-ruleset* nan-transform-fp-safe - (arithmetic simplify fp-safe sound) +(define-ruleset* nan-transform + (arithmetic simplify sound) #:type ([a real] [b real]) [sub-neg (- a b) (+ a (neg b))] [unsub-neg (+ a (neg b)) (- a b)] @@ -285,10 +260,7 @@ #:type ([a real] [b real]) [clear-num (/ a b) (/ 1 (/ b a))]) -(define-ruleset* id-transform-fp-safe - (arithmetic fp-safe sound) - #:type ([a real]) - [*-un-lft-identity a (* 1 a)]) +(define-ruleset* id-transform (arithmetic sound) #:type ([a real]) [*-un-lft-identity a (* 1 a)]) ; Difference of cubes (define-ruleset* @@ -324,13 +296,13 @@ [rem-sqrt-square (sqrt (* x x)) (fabs x)]) (define-ruleset* squares-reduce-fp-sound - (arithmetic simplify fp-safe sound) + (arithmetic simplify sound) #:type ([x real]) [sqr-neg (* (neg x) (neg x)) (* x x)] [sqr-abs (* (fabs x) (fabs x)) (* x x)]) (define-ruleset* fabs-reduce - (arithmetic simplify fp-safe sound) + (arithmetic simplify sound) #:type ([x real] [a real] [b real]) [fabs-fabs (fabs (fabs x)) (fabs x)] [fabs-sub (fabs (- a b)) (fabs (- b a))] @@ -340,7 +312,7 @@ [fabs-div (fabs (/ a b)) (/ (fabs a) (fabs b))]) (define-ruleset* fabs-expand - (arithmetic fp-safe sound) + (arithmetic sound) #:type ([x real] [a real] [b real]) [neg-fabs (fabs x) (fabs (neg x))] [mul-fabs (* (fabs a) (fabs b)) (fabs (* a b))] @@ -405,7 +377,7 @@ [rem-log-exp (log (exp x)) x]) (define-ruleset* exp-constants - (exponents simplify fp-safe sound) + (exponents simplify sound) [exp-0 (exp 0) 1] [exp-1-e (exp 1) (E)] [1-exp 1 (exp 0)] @@ -431,20 +403,15 @@ [exp-lft-cube (exp (* a 3)) (pow (exp a) 3)]) ; Powers -(define-ruleset* pow-reduce (exponents simplify sound) #:type ([a real]) [unpow-1 (pow a -1) (/ 1 a)]) - -(define-ruleset* pow-reduce-fp-safe - (exponents simplify fp-safe sound) - #:type ([a real]) - [unpow1 (pow a 1) a]) - -(define-ruleset* pow-reduce-fp-safe-nan - (exponents simplify fp-safe-nan sound) +(define-ruleset* pow-reduce + (exponents simplify sound) #:type ([a real]) + [unpow-1 (pow a -1) (/ 1 a)] + [unpow1 (pow a 1) a] [unpow0 (pow a 0) 1] [pow-base-1 (pow 1 a) 1]) -(define-ruleset* pow-expand-fp-safe (exponents fp-safe sound) #:type ([a real]) [pow1 a (pow a 1)]) +(define-ruleset* pow-expand (exponents sound) #:type ([a real]) [pow1 a (pow a 1)]) (define-ruleset* pow-canonicalize (exponents simplify sound) @@ -482,16 +449,8 @@ [pow-pow (pow (pow a b) c) (pow a (* b c))] [pow-unpow (pow a (* b c)) (pow (pow a b) c)] [unpow-prod-up (pow a (+ b c)) (* (pow a b) (pow a c))] - [unpow-prod-down (pow (* b c) a) (* (pow b a) (pow c a))]) - -(define-ruleset* pow-transform-fp-safe-nan - (exponents simplify fp-safe-nan sound) - #:type ([a real]) - [pow-base-0 (pow 0 a) 0]) - -(define-ruleset* pow-transform-fp-safe - (exponents fp-safe sound) - #:type ([a real]) + [unpow-prod-down (pow (* b c) a) (* (pow b a) (pow c a))] + [pow-base-0 (pow 0 a) 0] [inv-pow (/ 1 a) (pow a -1)]) (define-ruleset* log-distribute-sound @@ -517,24 +476,18 @@ ; Trigonometry (define-ruleset* trig-reduce-fp-sound - (trigonometry simplify fp-safe sound) + (trigonometry simplify sound) [sin-0 (sin 0) 0] [cos-0 (cos 0) 1] [tan-0 (tan 0) 0]) (define-ruleset* trig-reduce-fp-sound-nan - (trigonometry simplify fp-safe-nan sound) + (trigonometry simplify sound) #:type ([x real]) [sin-neg (sin (neg x)) (neg (sin x))] [cos-neg (cos (neg x)) (cos x)] [tan-neg (tan (neg x)) (neg (tan x))]) -(define-ruleset* trig-expand-fp-safe - (trignometry fp-safe sound) - #:type ([x real]) - [sqr-sin-b (* (sin x) (sin x)) (- 1 (* (cos x) (cos x)))] - [sqr-cos-b (* (cos x) (cos x)) (- 1 (* (sin x) (sin x)))]) - (define-ruleset* trig-inverses (trigonometry sound) @@ -597,6 +550,8 @@ (define-ruleset* trig-expand-sound (trigonometry sound) #:type ([x real] [y real] [a real] [b real]) + [sqr-sin-b (* (sin x) (sin x)) (- 1 (* (cos x) (cos x)))] + [sqr-cos-b (* (cos x) (cos x)) (- 1 (* (sin x) (sin x)))] [sin-sum (sin (+ x y)) (+ (* (sin x) (cos y)) (* (cos x) (sin y)))] [cos-sum (cos (+ x y)) (- (* (cos x) (cos y)) (* (sin x) (sin y)))] [tan-sum (tan (+ x y)) (/ (+ (tan x) (tan y)) (- 1 (* (tan x) (tan y))))] @@ -667,6 +622,10 @@ (define-ruleset* htrig-expand-sound (hyperbolic sound) #:type ([x real] [y real]) + [sinh-neg (sinh (neg x)) (neg (sinh x))] + [sinh-0 (sinh 0) 0] + [cosh-neg (cosh (neg x)) (cosh x)] + [cosh-0 (cosh 0) 1] [sinh-undef (- (exp x) (exp (neg x))) (* 2 (sinh x))] [cosh-undef (+ (exp x) (exp (neg x))) (* 2 (cosh x))] [tanh-undef (/ (- (exp x) (exp (neg x))) (+ (exp x) (exp (neg x)))) (tanh x)] @@ -691,14 +650,6 @@ #:type ([x real] [y real]) [tanh-1/2* (tanh (/ x 2)) (/ (- (cosh x) 1) (sinh x))]) -(define-ruleset* htrig-expand-fp-safe - (hyperbolic fp-safe sound) - #:type ([x real]) - [sinh-neg (sinh (neg x)) (neg (sinh x))] - [sinh-0 (sinh 0) 0] - [cosh-neg (cosh (neg x)) (cosh x)] - [cosh-0 (cosh 0) 1]) - (define-ruleset* ahtrig-expand-sound (hyperbolic sound) #:type ([x real]) diff --git a/src/platforms/binary32.rkt b/src/platforms/binary32.rkt index e7f5142db..aea3f1213 100644 --- a/src/platforms/binary32.rkt +++ b/src/platforms/binary32.rkt @@ -33,10 +33,10 @@ (define-syntax (define-libm-impl/binary32 stx) (syntax-case stx (real) - [(_ op (itype ...) otype [key value] ...) + [(_ op (itype ...) otype attrib ...) (with-syntax ([impl (string->symbol (format "~a.f32" (syntax->datum #'op)))] [cname (string->symbol (format "~af" (syntax->datum #'op)))]) - #'(define-libm-impl cname (op impl itype ...) otype [key value] ...))])) + #'(define-libm-impl cname (op impl itype ...) otype attrib ...))])) (define-syntax-rule (define-libm-impls/binary32* (itype ... otype) name ...) (begin @@ -46,35 +46,136 @@ (begin (define-libm-impls/binary32* (itype ... otype) name ...) ...)) -(define-operator-impl (neg.f32 [x : binary32]) - binary32 - #:spec (neg x) - #:fpcore (! :precision binary32 (- x)) - #:fl fl32-) +(define-operator-impl + (neg.f32 [x : binary32]) + binary32 + #:spec (neg x) + #:fpcore (! :precision binary32 (- x)) + #:fl fl32- + #:identities (#:exact (neg.f32 x) + [distribute-lft-neg-in (neg.f32 (*.f32 a b)) (*.f32 (neg.f32 a) b)] + [distribute-rgt-neg-in (neg.f32 (*.f32 a b)) (*.f32 a (neg.f32 b))] + [distribute-neg-in (neg.f32 (+.f32 a b)) (+.f32 (neg.f32 a) (neg.f32 b))] + [distribute-neg-frac (neg.f32 (/.f32 a b)) (/.f32 (neg.f32 a) b)] + [distribute-neg-frac2 (neg.f32 (/.f32 a b)) (/.f32 a (neg.f32 b))] + [remove-double-neg (neg.f32 (neg.f32 a)) a] + [neg-sub0 (neg.f32 b) (-.f32 0 b)] + [neg-mul-1 (neg.f32 a) (*.f32 -1 a)])) (define-operator-impl (+.f32 [x : binary32] [y : binary32]) binary32 #:spec (+ x y) #:fpcore (! :precision binary32 (+ x y)) - #:fl fl32+) + #:fl fl32+ + #:commutes + #:identities + ([distribute-neg-out (+.f32 (neg.f32 a) (neg.f32 b)) (neg.f32 (+.f32 a b))] + [+-lft-identity (+.f32 0 a) a] + [+-rgt-identity (+.f32 a 0) a] + [unsub-neg (+.f32 a (neg.f32 b)) (-.f32 a b)])) (define-operator-impl (-.f32 [x : binary32] [y : binary32]) binary32 #:spec (- x y) #:fpcore (! :precision binary32 (- x y)) - #:fl fl32-) + #:fl fl32- + #:identities + ([cancel-sign-sub (-.f32 a (*.f32 (neg.f32 b) c)) (+.f32 a (*.f32 b c))] + [cancel-sign-sub-inv (-.f32 a (*.f32 b c)) (+.f32 a (*.f32 (neg.f32 b) c))] + #:exact (-.f32 a a) + #:exact (-.f32 a 0) + [sub0-neg (-.f32 0 a) (neg.f32 a)] + [sub-neg (-.f32 a b) (+.f32 a (neg.f32 b))])) -(define-operator-impl (*.f32 [x : binary32] [y : binary32]) - binary32 - #:spec (* x y) - #:fpcore (! :precision binary32 (* x y)) - #:fl fl32*) +(define-operator-impl + (*.f32 [x : binary32] [y : binary32]) + binary32 + #:spec (* x y) + #:fpcore (! :precision binary32 (* x y)) + #:fl fl32* + #:commutes + #:identities ([distribute-lft-neg-out (*.f32 (neg.f32 x) y) (neg.f32 (*.f32 x y))] + [distribute-rgt-neg-out (*.f32 x (neg.f32 y)) (neg.f32 (*.f32 x y))] + [mul0-lft (*.f32 0 a) 0] + [mul0-rgt (*.f32 a 0) 0] + [*-lft-identity (*.f32 1 a) a] + [*-rgt-identity (*.f32 a 1) a] + [mul-1-neg (*.f32 -1 a) (neg.f32 a)] + [*-un-lft-identity a (*.f32 1 a)] + [sqr-neg (*.f32 (neg.f32 x) (neg.f32 x)) (*.f32 x x)] + [sqr-abs (*.f32 (fabs.f32 x) (fabs.f32 x)) (*.f32 x x)] + [mul-fabs (*.f32 (fabs.f32 a) (fabs.f32 b)) (fabs.f32 (*.f32 a b))] + [sqr-sin-b (*.f32 (sin.f32 x) (sin.f32 x)) (-.f32 1 (*.f32 (cos.f32 x) (cos.f32 x)))] + [sqr-cos-b (*.f32 (cos.f32 x) (cos.f32 x)) (-.f32 1 (*.f32 (sin.f32 x) (sin.f32 x)))])) (define-operator-impl (/.f32 [x : binary32] [y : binary32]) binary32 #:spec (/ x y) #:fpcore (! :precision binary32 (/ x y)) - #:fl fl32/) + #:fl fl32/ + #:identities ([distribute-frac-neg (/.f32 (neg.f32 x) y) (neg.f32 (/.f32 x y))] + [distribute-frac-neg2 (/.f32 x (neg.f32 y)) (neg.f32 (/.f32 x y))] + [div0 (/.f32 0 a) 0] + [*-inverses (/.f32 a a) 1] + [/-rgt-identity (/.f32 a 1) a] + [inv-pow (/.f32 1 a) (pow.f32 a -1)])) + +(define-libm-impl/binary32 fabs + (binary32) + binary32 + #:identities + ([fabs-fabs (fabs.f32 (fabs.f32 a)) (fabs.f32 a)] + [fabs-sub (fabs.f32 (-.f32 a b)) (fabs.f32 (-.f32 b a))] + [fabs-neg (fabs.f32 (neg.f32 a)) (fabs.f32 a)] + [fabs-sqr (fabs.f32 (*.f32 a a)) (*.f32 a a)] + [fabs-mul (fabs.f32 (*.f32 a b)) (*.f32 (fabs.f32 a) (fabs.f32 b))] + [fabs-div (fabs.f32 (/.f32 a b)) (/.f32 (fabs.f32 a) (fabs.f32 b))] + [neg-fabs (fabs.f32 x) (fabs.f32 (neg.f32 x))])) + +(define-libm-impl/binary32 exp + (binary32) + binary32 + #:identities ([exp-0 (exp.f32 0) 1] [exp-1-e (exp.f32 1) (E)] + [1-exp 1 (exp.f32 0)] + [e-exp-1 (E) (exp.f32 1)])) + +(define-libm-impl/binary32 pow + (binary32 binary32) + binary32 + #:identities ([unpow1 (pow.f32 a 1) a] [unpow0 (pow.f32 a 0) 1] + [pow-base-1 (pow.f32 1 a) 1] + [pow1 a (pow.f32 a 1)] + [pow-base-0 (pow.f32 0 a) 0])) + +(define-libm-impl/binary32 + sin + (binary32) + binary32 + #:identities ([sin-0 (sin.f32 0) 0] [sin-neg (sin.f32 (neg.f32 x)) (neg.f32 (sin.f32 x))])) + +(define-libm-impl/binary32 cos + (binary32) + binary32 + #:identities + ([cos-0 (cos.f32 0) 1] [cos-neg (cos.f32 (neg.f32 x)) (cos.f32 x)])) + +(define-libm-impl/binary32 + tan + (binary32) + binary32 + #:identities ([tan-0 (tan.f32 0) 0] [tan-neg (tan.f32 (neg.f32 x)) (neg.f32 (tan.f32 x))])) + +(define-libm-impl/binary32 sinh + (binary32) + binary32 + #:identities ([sinh-neg (sinh.f32 (neg.f32 x)) (neg.f32 (sinh.f32 x))] + [sinh-0 (sinh.f32 0) 0])) + +(define-libm-impl/binary32 cosh + (binary32) + binary32 + #:identities + ([cosh-neg (cosh.f32 (neg.f32 x)) (cosh.f32 x)] [cosh-0 (cosh.f32 0) 1])) (define-comparator-impls binary32 [== ==.f32 =] @@ -92,12 +193,9 @@ atanh cbrt ceil - cos - cosh erf exp exp2 - fabs floor lgamma log @@ -106,10 +204,7 @@ logb rint round - sin - sinh sqrt - tan tanh tgamma trunc)] diff --git a/src/platforms/binary64.rkt b/src/platforms/binary64.rkt index e42705515..8d6abe73e 100644 --- a/src/platforms/binary64.rkt +++ b/src/platforms/binary64.rkt @@ -34,9 +34,9 @@ (define-syntax (define-libm-impl/binary64 stx) (syntax-case stx (real) - [(_ op (itype ...) otype [key value] ...) + [(_ op (itype ...) otype attrib ...) (with-syntax ([impl (string->symbol (format "~a.f64" (syntax->datum #'op)))]) - #'(define-libm-impl op (op impl itype ...) otype [key value] ...))])) + #'(define-libm-impl op (op impl itype ...) otype attrib ...))])) (define-syntax-rule (define-libm-impls/binary64* (itype ... otype) name ...) (begin @@ -46,35 +46,134 @@ (begin (define-libm-impls/binary64* (itype ... otype) name ...) ...)) -(define-operator-impl (neg.f64 [x : binary64]) - binary64 - #:spec (neg x) - #:fpcore (! :precision binary64 (- x)) - #:fl -) - +(define-operator-impl + (neg.f64 [x : binary64]) + binary64 + #:spec (neg x) + #:fpcore (! :precision binary64 (- x)) + #:fl - + #:identities (#:exact (neg.f64 a) + [distribute-lft-neg-in (neg.f64 (*.f64 a b)) (*.f64 (neg.f64 a) b)] + [distribute-rgt-neg-in (neg.f64 (*.f64 a b)) (*.f64 a (neg.f64 b))] + [distribute-neg-in (neg.f64 (+.f64 a b)) (+.f64 (neg.f64 a) (neg.f64 b))] + [distribute-neg-frac (neg.f64 (/.f64 a b)) (/.f64 (neg.f64 a) b)] + [distribute-neg-frac2 (neg.f64 (/.f64 a b)) (/.f64 a (neg.f64 b))] + [remove-double-neg (neg.f64 (neg.f64 a)) a] + [neg-sub0 (neg.f64 b) (-.f64 0 b)] + [neg-mul-1 (neg.f64 a) (*.f64 -1 a)])) (define-operator-impl (+.f64 [x : binary64] [y : binary64]) binary64 #:spec (+ x y) #:fpcore (! :precision binary64 (+ x y)) - #:fl +) - + #:fl + + #:commutes + #:identities + ([distribute-neg-out (+.f64 (neg.f64 a) (neg.f64 b)) (neg.f64 (+.f64 a b))] + [+-lft-identity (+.f64 0 a) a] + [+-rgt-identity (+.f64 a 0) a] + [unsub-neg (+.f64 a (neg.f64 b)) (-.f64 a b)])) (define-operator-impl (-.f64 [x : binary64] [y : binary64]) binary64 #:spec (- x y) #:fpcore (! :precision binary64 (- x y)) - #:fl -) - -(define-operator-impl (*.f64 [x : binary64] [y : binary64]) - binary64 - #:spec (* x y) - #:fpcore (! :precision binary64 (* x y)) - #:fl *) - + #:fl - + #:identities + ([cancel-sign-sub (-.f64 a (*.f64 (neg.f64 b) c)) (+.f64 a (*.f64 b c))] + [cancel-sign-sub-inv (-.f64 a (*.f64 b c)) (+.f64 a (*.f64 (neg.f64 b) c))] + #:exact (-.f64 a a) + #:exact (-.f64 a 0) + [sub0-neg (-.f64 0 a) (neg.f64 a)] + [sub-neg (-.f64 a b) (+.f64 a (neg.f64 b))])) +(define-operator-impl + (*.f64 [x : binary64] [y : binary64]) + binary64 + #:spec (* x y) + #:fpcore (! :precision binary64 (* x y)) + #:fl * + #:commutes + #:identities ([distribute-lft-neg-out (*.f64 (neg.f64 a) b) (neg.f64 (*.f64 a b))] + [distribute-rgt-neg-out (*.f64 a (neg.f64 b)) (neg.f64 (*.f64 a b))] + [mul0-lft (*.f64 0 a) 0] + [mul0-rgt (*.f64 a 0) 0] + [*-lft-identity (*.f64 1 a) a] + [*-rgt-identity (*.f64 a 1) a] + [mul-1-neg (*.f64 -1 a) (neg.f64 a)] + [*-un-lft-identity a (*.f64 1 a)] + [sqr-neg (*.f64 (neg.f64 a) (neg.f64 a)) (*.f64 a a)] + [sqr-abs (*.f64 (fabs.f64 a) (fabs.f64 a)) (*.f64 a a)] + [mul-fabs (*.f64 (fabs.f64 a) (fabs.f64 b)) (fabs.f64 (*.f64 a b))] + [sqr-sin-b (*.f64 (sin.f64 x) (sin.f64 x)) (-.f64 1 (*.f64 (cos.f64 x) (cos.f64 x)))] + [sqr-cos-b (*.f64 (cos.f64 x) (cos.f64 x)) (-.f64 1 (*.f64 (sin.f64 x) (sin.f64 x)))])) (define-operator-impl (/.f64 [x : binary64] [y : binary64]) binary64 #:spec (/ x y) #:fpcore (! :precision binary64 (/ x y)) - #:fl /) + #:fl / + #:identities + ([distribute-frac-neg (/.f64 (neg.f64 a) b) (neg.f64 (/.f64 a b))] + [distribute-frac-neg2 (/.f64 a (neg.f64 b)) (neg.f64 (/.f64 a b))] + [div0 (/.f64 0 a) 0] + [*-inverses (/.f64 a a) 1] + [/-rgt-identity (/.f64 a 1) a] + [div-fabs (/.f64 (fabs.f64 a) (fabs.f64 b)) (fabs.f64 (/.f64 a b))] + [inv-pow (/.f64 1 a) (pow.f64 a -1)])) + +(define-libm-impl/binary64 fabs + (binary64) + binary64 + #:identities + ([fabs-fabs (fabs.f64 (fabs.f64 a)) (fabs.f64 a)] + [fabs-sub (fabs.f64 (-.f64 a b)) (fabs.f64 (-.f64 b a))] + [fabs-neg (fabs.f64 (neg.f64 a)) (fabs.f64 a)] + [fabs-sqr (fabs.f64 (*.f64 a a)) (*.f64 a a)] + [fabs-mul (fabs.f64 (*.f64 a b)) (*.f64 (fabs.f64 a) (fabs.f64 b))] + [fabs-div (fabs.f64 (/.f64 a b)) (/.f64 (fabs.f64 a) (fabs.f64 b))] + [neg-fabs (fabs.f64 x) (fabs.f64 (neg.f64 x))])) + +(define-libm-impl/binary64 exp + (binary64) + binary64 + #:identities ([exp-0 (exp.f64 0) 1] [exp-1-e (exp.f64 1) (E)] + [1-exp 1 (exp.f64 0)] + [e-exp-1 (E) (exp.f64 1)])) + +(define-libm-impl/binary64 pow + (binary64 binary64) + binary64 + #:identities ([unpow1 (pow.f64 a 1) a] [unpow0 (pow.f64 a 0) 1] + [pow-base-1 (pow.f64 1 a) 1] + [pow1 a (pow.f64 a 1)] + [pow-base-0 (pow.f64 0 a) 0])) + +(define-libm-impl/binary64 + sin + (binary64) + binary64 + #:identities ([sin-0 (sin.f64 0) 0] [sin-neg (sin.f64 (neg.f64 x)) (neg.f64 (sin.f64 x))])) + +(define-libm-impl/binary64 cos + (binary64) + binary64 + #:identities + ([cos-0 (cos.f64 0) 1] [cos-neg (cos.f64 (neg.f64 x)) (cos.f64 x)])) + +(define-libm-impl/binary64 + tan + (binary64) + binary64 + #:identities ([tan-0 (tan.f64 0) 0] [tan-neg (tan.f64 (neg.f64 x)) (neg.f64 (tan.f64 x))])) + +(define-libm-impl/binary64 sinh + (binary64) + binary64 + #:identities ([sinh-neg (sinh.f64 (neg.f64 x)) (neg.f64 (sinh.f64 x))] + [sinh-0 (sinh.f64 0) 0])) + +(define-libm-impl/binary64 cosh + (binary64) + binary64 + #:identities + ([cosh-neg (cosh.f64 (neg.f64 x)) (cosh.f64 x)] [cosh-0 (cosh.f64 0) 1])) (define-libm-impls/binary64 [(binary64 binary64) (acos acosh @@ -84,12 +183,10 @@ atanh cbrt ceil - cos cosh erf exp exp2 - fabs floor lgamma log @@ -98,10 +195,8 @@ logb rint round - sin sinh sqrt - tan tanh tgamma trunc)] diff --git a/src/platforms/bool.rkt b/src/platforms/bool.rkt index ce92e500b..9ebef0160 100644 --- a/src/platforms/bool.rkt +++ b/src/platforms/bool.rkt @@ -29,29 +29,16 @@ (define (or-fn . as) (ormap identity as)) -(define-operator-impl (not [x : bool]) bool #:spec (not x) #:fpcore (! (not x)) #:fl not) - -(define-operator-impl (and [x : bool] [y : bool]) bool #:spec (and x y) #:fl and-fn) - -(define-operator-impl (or [x : bool] [y : bool]) bool #:spec (or x y) #:fl or-fn) - -;;;;;;;;;;;;;;;;;;;;;;;;;;;;; rules ;;;;;;;;;;;;;;;;;;;;;;;;;;;;; - -(define-ruleset* bool-reduce - (bools simplify fp-safe) - #:type ([a bool] [b bool]) - [not-true (not (TRUE)) (FALSE)] - [not-false (not (FALSE)) (TRUE)] - [not-not (not (not a)) a] - [not-and (not (and a b)) (or (not a) (not b))] - [not-or (not (or a b)) (and (not a) (not b))] - [and-true-l (and (TRUE) a) a] - [and-true-r (and a (TRUE)) a] - [and-false-l (and (FALSE) a) (FALSE)] - [and-false-r (and a (FALSE)) (FALSE)] - [and-same (and a a) a] - [or-true-l (or (TRUE) a) (TRUE)] - [or-true-r (or a (TRUE)) (TRUE)] - [or-false-l (or (FALSE) a) a] - [or-false-r (or a (FALSE)) a] - [or-same (or a a) a]) +(define-operator-impl (not [x : bool]) bool #:spec (not x) #:fl not #:identities (#:exact (not a))) + +(define-operator-impl (and [x : bool] [y : bool]) + bool + #:spec (and x y) + #:fl and-fn + #:identities (#:exact (and a b))) + +(define-operator-impl (or [x : bool] [y : bool]) + bool + #:spec (or x y) + #:fl or-fn + #:identities (#:exact (or a b))) diff --git a/src/platforms/runtime/utils.rkt b/src/platforms/runtime/utils.rkt index d58f1bbe3..741844b26 100644 --- a/src/platforms/runtime/utils.rkt +++ b/src/platforms/runtime/utils.rkt @@ -22,6 +22,10 @@ (begin (define-operator-impl (impl-name) repr #:spec (name) #:fl (const value)) ...)) -(define-syntax-rule (define-comparator-impls repr [name impl-name impl-fn] ...) +(define-syntax-rule (define-comparator-impls repr [name impl-name impl-fn attrib ...] ...) (begin - (define-operator-impl (impl-name [x : repr] [y : repr]) bool #:spec (name x y) #:fl impl-fn) ...)) + (define-operator-impl (impl-name [x : repr] [y : repr]) + bool + #:spec (name x y) + #:fl impl-fn + attrib ...) ...)) diff --git a/src/syntax/matcher.rkt b/src/syntax/matcher.rkt index 84fd68b73..e9a3c6269 100644 --- a/src/syntax/matcher.rkt +++ b/src/syntax/matcher.rkt @@ -2,7 +2,8 @@ #lang racket -(provide pattern-match +(provide merge-bindings + pattern-match pattern-substitute) ;; Unions two bindings. Returns #f if they disagree. diff --git a/src/syntax/platform.rkt b/src/syntax/platform.rkt index c8751f640..fd3b39c78 100644 --- a/src/syntax/platform.rkt +++ b/src/syntax/platform.rkt @@ -5,6 +5,7 @@ "../core/programs.rkt" "../core/rules.rkt" "matcher.rkt" + "sugar.rkt" "syntax.rkt" "types.rkt") @@ -12,6 +13,7 @@ get-platform *active-platform* activate-platform! + *fp-safe-simplify-rules* platform-lifting-rules platform-lowering-rules platform-impl-rules @@ -547,3 +549,101 @@ (representation-name repr) (string-join (map (lambda (subst) (~a (cdr subst))) isubst) "-")))) (sow (rule name* input* output* itypes* repr)))))])))) + +(define (expr-otype expr) + (match expr + [(? number?) #f] + [(? variable?) #f] + [(list 'if cond ift iff) (expr-otype ift)] + [(list op args ...) (impl-info op 'otype)])) + +(define (type-verify expr otype) + (match expr + [(? number?) '()] + [(? variable?) (list (cons expr otype))] + [(list 'if cond ift iff) + (define bool-repr (get-representation 'bool)) + (define combined + (merge-bindings (type-verify cond bool-repr) + (merge-bindings (type-verify ift otype) (type-verify iff otype)))) + (unless combined + (error 'type-verify "Variable types do not match in ~a" expr)) + combined] + [(list op args ...) + (define op-otype (impl-info op 'otype)) + (when (not (equal? op-otype otype)) + (error 'type-verify "Operator ~a has type ~a, expected ~a" op op-otype otype)) + (define bindings '()) + (for ([arg (in-list args)] + [itype (in-list (impl-info op 'itype))]) + (define combined (merge-bindings bindings (type-verify arg itype))) + (unless combined + (error 'type-verify "Variable types do not match in ~a" expr)) + (set! bindings combined)) + bindings])) + +(define (expr->prog expr repr) + (match expr + [(? number?) (literal expr (representation-name repr))] + [(? variable?) expr] + [`(if ,cond ,ift ,iff) + `(if ,(expr->prog cond (get-representation 'bool)) + ,(expr->prog ift repr) + ,(expr->prog iff repr))] + [`(,impl ,args ...) + `(,impl ,@(for/list ([arg (in-list args)] + [itype (in-list (impl-info impl 'itype))]) + (expr->prog arg itype)))])) + +(define (*fp-safe-simplify-rules*) + (reap [sow] + (for ([impl (in-list (platform-impls (*active-platform*)))]) + (define rules (impl-info impl 'identities)) + (for ([identity (in-list rules)]) + (match identity + [(list 'exact name expr) + (when (not (expr-otype expr)) + (error "Exact identity expr cannot infer type")) + (define otype (expr-otype expr)) + (define var-types (type-verify expr otype)) + (define prog (expr->prog expr otype)) + (define r + (rule name + prog + (prog->spec prog) + (for/hash ([binding (in-list var-types)]) + (values (car binding) (cdr binding))) + (impl-info impl 'otype))) + (sow r)] + [(list 'commutes name expr rev-expr) + (define vars (impl-info impl 'vars)) + (define itype (car (impl-info impl 'itype))) + (define otype (impl-info impl 'otype)) + (define r + (rule name + (expr->prog expr otype) + (expr->prog rev-expr otype) + (for/hash ([v (in-list vars)]) + (values v itype)) + otype)) ; Commutes by definition the types are matching + (sow r)] + [(list 'directed name lhs rhs) + (define lotype (expr-otype lhs)) + (define rotype (expr-otype rhs)) + (when (and (not lotype) (not rotype)) + (error "Could not find type for lhs ~a and rhs ~a" lhs rhs)) + (when (not lotype) + (set! lotype rotype)) + (when (not rotype) + (set! rotype lotype)) + (when (not (equal? lotype rotype)) + (error "Incompatible types for lhs ~a and rhs ~a" lhs rhs)) + (define var-types (merge-bindings (type-verify lhs lotype) (type-verify rhs rotype))) + (define r + (rule name + (expr->prog lhs lotype) + (expr->prog rhs rotype) + (for/hash ([binding (in-list var-types)]) + (values (car binding) (cdr binding))) + (impl-info impl 'otype))) + (sow r)]))))) diff --git a/src/syntax/sugar.rkt b/src/syntax/sugar.rkt index 36870bc8d..c6576b404 100644 --- a/src/syntax/sugar.rkt +++ b/src/syntax/sugar.rkt @@ -177,7 +177,11 @@ (define ift* (loop ift prop-dict)) (define iff* (loop iff prop-dict)) (list 'if cond* ift* iff*)] - [(list '! props ... body) (loop body (apply dict-set prop-dict props))] + [(list '! props ... body) + (loop body + (if (not (null? props)) + (apply dict-set prop-dict props) + prop-dict))] [(list 'neg arg) ; non-standard but useful [TODO: remove] (define arg* (loop arg prop-dict)) (fpcore->impl-app '- prop-dict (list arg*) ctx)] diff --git a/src/syntax/syntax.rkt b/src/syntax/syntax.rkt index 11fd6702b..83f306cab 100644 --- a/src/syntax/syntax.rkt +++ b/src/syntax/syntax.rkt @@ -223,7 +223,7 @@ ;; - its FPCore representation ;; - a floating-point implementation ;; -(struct operator-impl (name ctx spec fpcore fl)) +(struct operator-impl (name ctx spec fpcore fl identities)) ;; Operator implementation table ;; Tracks implementations that are loaded into Racket's runtime @@ -237,7 +237,7 @@ ;; Looks up a property `field` of an real operator `op`. ;; Panics if the operator is not found. (define/contract (impl-info impl field) - (-> symbol? (or/c 'vars 'itype 'otype 'spec 'fpcore 'fl) any/c) + (-> symbol? (or/c 'vars 'itype 'otype 'spec 'fpcore 'fl 'identities) any/c) (unless (hash-has-key? operator-impls impl) (error 'impl-info "Unknown operator implementation ~a" impl)) (define info (hash-ref operator-impls impl)) @@ -247,7 +247,8 @@ [(otype) (context-repr (operator-impl-ctx info))] [(spec) (operator-impl-spec info)] [(fpcore) (operator-impl-fpcore info)] - [(fl) (operator-impl-fl info)])) + [(fl) (operator-impl-fl info)] + [(identities) (operator-impl-identities info)])) ;; Returns all operator implementations. (define (all-operator-impls) @@ -327,8 +328,16 @@ ; Registers an operator implementation `name` with context `ctx` and spec `spec. ; Can optionally specify a floating-point implementation and fpcore translation. -(define/contract (register-operator-impl! name ctx spec #:fl [fl-proc #f] #:fpcore [fpcore #f]) - (->* (symbol? context? any/c) (#:fl (or/c procedure? #f) #:fpcore any/c) void?) +(define/contract (register-operator-impl! name + ctx + spec + #:commutes? [commutes? #f] + #:fl [fl-proc #f] + #:fpcore [fpcore #f] + #:identities [identities #f]) + (->* (symbol? context? any/c) + (#:commutes? boolean? #:fl (or/c procedure? #f) #:fpcore any/c #:identities any/c) + void?) ; check specification (check-spec! name ctx spec) (define vars (context-vars ctx)) @@ -383,10 +392,62 @@ (first exs) fail)) name)])) + + ; make hash table + (define rules '()) + (define rule-names (make-hasheq)) + (define commutes? #f) + (when identities + (when commutes? + (cons (list 'commutes) identities)) + (set! rules + (for/list ([ident (in-list identities)] + [i (in-naturals)]) + (match ident + [(list ident-name lhs-expr rhs-expr) + (cond + [(hash-has-key? rule-names ident-name) + (raise-herbie-syntax-error "Duplicate identity ~a" ident-name)] + [(not (well-formed? lhs-expr)) + (raise-herbie-syntax-error "Ill-formed identity expression ~a" lhs-expr)] + [(not (well-formed? rhs-expr)) + (raise-herbie-syntax-error "Ill-formed identity expression ~a" rhs-expr)] + [else + (define rule-name (string->symbol (format "~a-~a" ident-name name))) + (hash-set! rule-names rule-name #f) + (list 'directed rule-name lhs-expr rhs-expr)])] + [(list 'exact expr) + (cond + [(not (well-formed? expr)) + (raise-herbie-syntax-error "Ill-formed identity expression ~a" expr)] + [else + (define rule-name (gensym (string->symbol (format "~a-exact-~a" name i)))) + (hash-set! rule-names rule-name #f) + (list 'exact rule-name expr)])] + [(list 'commutes) + (cond + [commutes? (error "Commutes identity already defined")] + [(hash-has-key? rule-names (string->symbol (format "~a-commutes" name))) + (error "Commutes identity already manually defined")] + [(not (equal? (length vars) 2)) + (raise-herbie-syntax-error "Cannot commute a non 2-ary operator")] + [else + (set! commutes? #t) + (define rule-name (string->symbol (format "~a-commutes" name))) + (hash-set! rule-names rule-name #f) + (list 'commutes rule-name `(,name ,@vars) `(,name ,@(reverse vars)))])])))) + ; update tables - (define impl (operator-impl name ctx spec fpcore* fl-proc*)) + (define impl (operator-impl name ctx spec fpcore* fl-proc* rules)) (hash-set! operator-impls name impl)) +(define (well-formed? expr) + (match expr + [(? number?) #t] + [(? variable?) #t] + [`(,impl ,args ...) (andmap well-formed? args)] + [_ #f])) + (define-syntax (define-operator-impl stx) (define (oops! why [sub-stx #f]) (raise-syntax-error 'define-operator-impl why stx sub-stx)) @@ -400,9 +461,11 @@ (for ([var (in-list vars)]) (unless (identifier? var) (oops! "expected identifier" var))) + (define commutes? #f) (define spec #f) (define core #f) (define fl-expr #f) + (define identities #f) (let loop ([fields fields]) (syntax-case fields () [() @@ -411,14 +474,18 @@ (with-syntax ([id id] [spec spec] [core core] - [fl-expr fl-expr]) + [commutes? commutes?] + [fl-expr fl-expr] + [identities identities]) #'(register-operator-impl! 'id (context '(var ...) (get-representation 'rtype) (list (get-representation 'repr) ...)) 'spec + #:commutes? 'commutes? #:fl fl-expr - #:fpcore 'core))] + #:fpcore 'core + #:identities 'identities))] [(#:spec expr rest ...) (cond [spec (oops! "multiple #:spec clauses" stx)] @@ -440,6 +507,27 @@ (set! fl-expr #'expr) (loop #'(rest ...))])] [(#:fl) (oops! "expected value after keyword `#:fl`" stx)] + [(#:commutes rest ...) + (cond + [commutes? (oops! "multiple #:commutes clauses" stx)] + [else + (set! commutes? #t) + (loop #'(rest ...))])] + [(#:identities (ident-exprs ...) rest ...) + (cond + [identities (oops! "multiple #:identities clauses" stx)] + [else + (set! identities + (let ident-loop ([ident-exprs #'(ident-exprs ...)]) + (syntax-case ident-exprs () + [() '()] + [([name lhs-expr rhs-expr] rem ...) + (cons (list #'name #'lhs-expr #'rhs-expr) (ident-loop #'(rem ...)))] + [(#:exact expr rem ...) (cons (list 'exact #'expr) (ident-loop #'(rem ...)))] + [_ (oops! "bad syntax" ident-exprs)]))) + (loop #'(rest ...))])] + [(#:identities rest ...) (oops! "expected list of impl identities" stx)] + [(#:identities) (oops! "expected value after keyword #:identities clause" stx)] ; bad [_ (oops! "bad syntax" fields)])))] [_ (oops! "bad syntax")]))