diff --git a/infra/testApi.mjs b/infra/testApi.mjs index 48ebbf427..9348cbeaf 100644 --- a/infra/testApi.mjs +++ b/infra/testApi.mjs @@ -6,6 +6,7 @@ import { strict as assert } from 'node:assert'; // use strict equality everywhe const SAMPLE_SIZE = 8000 const FPCoreFormula = '(FPCore (x) (- (sqrt (+ x 1)) (sqrt x)))' const FPCoreFormula2 = '(FPCore (x) (- (sqrt (+ x 1))))' +const FPCoreFormula3 = '(FPCore (x) (if (<= (- (sqrt (+ x 1.0)) (sqrt x)) 0.05) (* 0.5 (sqrt (/ 1.0 x))) (fma (fma (- 0.125) x 0.5) x (- 1.0 (sqrt x)))))' const eval_sample = [[[1], -1.4142135623730951]] // improve endpoint @@ -88,7 +89,7 @@ assert.deepEqual(sample.points[1], sample2.points[1]) const explainBody = { method: 'POST', body: JSON.stringify({ - formula: FPCoreFormula, sample: sample2.points + formula: FPCoreFormula, sample: sample.points }) } const explain = await (await fetch(makeEndpoint("/api/explanations"), explainBody)).json() @@ -142,7 +143,7 @@ assert.deepEqual(calculateAsyncResult.points, [[[1], -1.4142135623730951]]) // Local error endpoint const localErrorBody = { method: 'POST', body: JSON.stringify({ - formula: FPCoreFormula, sample: sample2.points + formula: FPCoreFormula, sample: sample.points }) } const localError = await (await fetch(makeEndpoint("/api/localerror"), localErrorBody)).json() @@ -164,6 +165,64 @@ const localError2 = await (await fetch(makeEndpoint("/api/localerror"), { })).json() // Test that different sample points produce different job ids ensuring that different results are served for these inputs. assert.notEqual(localError1.job, localError2.job) +// Assert local error works for default example. +const ignoredValue = 1e+308 +'(FPCore (1e-100) (- (sqrt (+ x 1)) (sqrt x)))' +const localError5 = await (await fetch(makeEndpoint("/api/localerror"), { + method: 'POST', body: JSON.stringify({ + formula: FPCoreFormula, sample: [[[1e-100], ignoredValue]], seed: 5 + }) +})).json() +const rootMinusNode = localError5.tree +const leftSQRT = localError5.tree['children'][0] +const rightSQRT = localError5.tree['children'][1] +const plusNode = localError5.tree['children'][0]['children'][0] +const xNode = localError5.tree['children'][0]['children'][0]['children'][0] +const oneNode = localError5.tree['children'][0]['children'][0]['children'][1] + +// node, name, approx_value, avg_error, exact_value, true_error_value, ulps_error +assertCheckNode(rootMinusNode, '-', '1.0', '0.0', '1.0', '-1e-50', 1) +assertCheckNode(leftSQRT, 'sqrt', '1.0', '0.0', '1.0', '5e-101', 1) +assertCheckNode(rightSQRT, 'sqrt', '1e-50', '0.0', '1e-50', '2.379726195519099e-68', 1) +assertCheckNode(plusNode, '+', '1.0', '0.0', '1.0', '1e-100', 1) +assertCheckNode(xNode, 'x', '1e-100', '0.0', '1e-100', '0', 1) +assertCheckNode(oneNode, '1.0', '1.0', '0.0', '1.0', '-0.0', 1) + +// '(FPCore (1e100) (- (sqrt (+ x 1)) (sqrt x)))' +const localError6 = await (await fetch(makeEndpoint("/api/localerror"), { + method: 'POST', body: JSON.stringify({ + formula: FPCoreFormula, sample: [[[1e100], ignoredValue]], seed: 5 + }) +})).json() +const rootMinusNode6 = localError6.tree +const leftSQRT6 = localError6.tree['children'][0] +const rightSQRT6 = localError6.tree['children'][1] +const plusNode6 = localError6.tree['children'][0]['children'][0] +const xNode6 = localError6.tree['children'][0]['children'][0]['children'][0] +const oneNode6 = localError6.tree['children'][0]['children'][0]['children'][1] +// node, name, approx_value, avg_error, exact_value, true_error_value, ulps_error +assertCheckNode(rootMinusNode6, '-', '0.0', '61.7', '5e-51', '-7.78383463033115e-68', 3854499065107888000) +assertCheckNode(leftSQRT6, 'sqrt', '1e+50', '0.0', '1e+50', '-6.834625285603891e+33', 1) +assertCheckNode(rightSQRT6, 'sqrt', '1e+50', '0.0', '1e+50', '-6.834625285603891e+33', 1) +assertCheckNode(plusNode6, '+', '1e+100', '0.0', '1e+100', '1.0', 1) +assertCheckNode(xNode6, 'x', '1e+100', '0.0', '1e+100', '0', 1) +assertCheckNode(oneNode6, '1.0', '1.0', '0.0', '1.0', '-0.0', 1) + +function assertCheckNode(node, name, approx, avg_error, exact_value, true_error_value, ulps_error) { + assert.equal(node['e'], name) + assert.equal(node['approx-value'][0], approx) + assert.equal(node['avg-error'], avg_error) + assert.equal(node['exact-value'][0], exact_value) + assert.equal(node['true-error-value'][0], true_error_value) + assert.equal(node['ulps-error'][0], ulps_error) +} + +// TODO if statements +// const localError7 = await (await fetch(makeEndpoint("/api/localerror"), { +// method: 'POST', body: JSON.stringify({ +// formula: FPCoreFormula3, sample: [[[1e100], ignoredValue]], seed: 5 +// }) +// })).json() // Alternatives endpoint const altBody = { diff --git a/src/core/explain.rkt b/src/core/explain.rkt index 283061eb0..512adfcc2 100644 --- a/src/core/explain.rkt +++ b/src/core/explain.rkt @@ -38,16 +38,10 @@ [else #t])) (define (actual-errors expr pcontext) - - (define errs + (match-define (cons subexprs pt-errorss) (parameterize ([*pcontext* pcontext]) - (first (compute-local-errors (list (all-subexpressions expr)) (*context*))))) - - (define pruned (make-hash)) - (for ([(k v) (in-hash errs)]) - (hash-set! pruned k (hash-ref v 'errs))) - (define idk (flip-lists (hash->list pruned))) - (match-define (cons subexprs pt-errorss) idk) + (flip-lists (hash->list (first (compute-local-errors (list (all-subexpressions expr)) + (*context*))))))) (define pt-worst-subexpr (append* (reap [sow] diff --git a/src/core/localize.rkt b/src/core/localize.rkt index f616985e0..c4277bbfd 100644 --- a/src/core/localize.rkt +++ b/src/core/localize.rkt @@ -1,5 +1,7 @@ #lang racket +(require math/bigfloat + rival) (require "../syntax/sugar.rkt" "../syntax/syntax.rkt" "../syntax/types.rkt" @@ -9,7 +11,7 @@ "../syntax/platform.rkt" "../syntax/read.rkt" "../syntax/read.rkt" - "rival.rkt" + "../core/rival.rkt" "points.rkt" "programs.rkt" "sampling.rkt" @@ -111,15 +113,8 @@ (define subexprss (map all-subexpressions exprs)) (define errss (compute-local-errors subexprss ctx)) - (define pruned-list - (for/list ([h (in-list errss)]) - (define pruned (make-hash)) - (for ([(k v) (in-hash h)]) - (hash-set! pruned k (hash-ref v 'errs))) - pruned)) - (for/list ([_ (in-list exprs)] - [errs (in-list pruned-list)]) + [errs (in-list errss)]) (sort (sort (for/list ([(subexpr err) (in-hash errs)] #:when (or (list? subexpr) (approx? subexpr))) (cons err subexpr)) @@ -139,37 +134,169 @@ (define nodes (batch-nodes expr-batch)) (define roots (batch-roots expr-batch)) - ; TODO don't ignore the status code from make-real-compiler in eval-progs-real (define subexprs-fn (eval-progs-real (map prog->spec exprs-list) ctx-list)) - (define actual-value-fn (compile-progs exprs-list ctx)) (define errs (for/vector #:length (vector-length roots) ([node (in-vector roots)]) (make-vector (pcontext-length (*pcontext*))))) + (for ([(pt ex) (in-pcontext (*pcontext*))] + [pt-idx (in-naturals)]) + (define exacts (list->vector (apply subexprs-fn pt))) + (for ([expr (in-list exprs-list)] + [root (in-vector roots)] + [exact (in-vector exacts)] + [expr-idx (in-naturals)]) + (define err + (match (vector-ref nodes root) + [(? literal?) 1] + [(? variable?) 1] + [(approx _ impl) + (define repr (repr-of expr ctx)) + (ulp-difference exact (vector-ref exacts (vector-member impl roots)) repr)] + [`(if ,c ,ift ,iff) 1] + [(list f args ...) + (define repr (impl-info f 'otype)) + (define argapprox + (for/list ([idx (in-list args)]) + (vector-ref exacts (vector-member idx roots)))) ; arg's index mapping to exact + (define approx (apply (impl-info f 'fl) argapprox)) + (ulp-difference exact approx repr)])) + (vector-set! (vector-ref errs expr-idx) pt-idx err))) + + (define n 0) + (for/list ([subexprs (in-list subexprss)]) + (for*/hash ([subexpr (in-list subexprs)]) + (begin0 (values subexpr (vector->list (vector-ref errs n))) + (set! n (add1 n)))))) + +(define (check-for-invalid-exact input) + (match input + ['+inf.0 #t] + ['-inf.0 #t] + ['+nan.0 #t] + [value (or (bfnan? value) (boolean? value))])) + +; Compute local error or each sampled point at each node in `prog`. +(define (compute-errors subexprss ctx) + (define exprs-list (append* subexprss)) ; unroll subexprss + (define spec-list (map prog->spec exprs-list)) + (define ctx-list + (for/list ([subexpr (in-list exprs-list)]) + (struct-copy context ctx [repr (repr-of subexpr ctx)]))) + (define exact-var-name '__exact) + (define extended + (for/list ([ctx (in-list ctx-list)]) + (context-extend ctx exact-var-name (context-repr ctx)))) + (define compare-specs + (for/list ([spec (in-list spec-list)]) + `(- ,spec ,exact-var-name))) + + (define expr-batch (progs->batch exprs-list)) + (define nodes (batch-nodes expr-batch)) + (define roots (batch-roots expr-batch)) + + (define subexprs-fn (eval-progs-real spec-list ctx-list)) + (define actual-value-fn (compile-progs exprs-list ctx)) + + (define ulp-errs + (for/vector #:length (vector-length roots) + ([node (in-vector roots)]) + (make-vector (pcontext-length (*pcontext*))))) + (define exacts-out (for/vector #:length (vector-length roots) ([node (in-vector roots)]) (make-vector (pcontext-length (*pcontext*))))) - (define actuals-out + (define approx-out (for/vector #:length (vector-length roots) ([node (in-vector roots)]) (make-vector (pcontext-length (*pcontext*))))) + (define true-error-out + (for/vector #:length (vector-length roots) + ([node (in-vector roots)]) + (make-vector (pcontext-length (*pcontext*))))) + + (define (error-for cur-ctx cur-sepc pt exact) + (define extended (context-append cur-ctx exact-var-name (context-repr cur-ctx))) + (define compare-specs `(- ,cur-sepc ,exact-var-name)) + (define new-compare (eval-progs-real (list compare-specs) (list extended))) + (define inputs (append pt (list exact))) + (define true-errors (list->vector (apply new-compare inputs))) + (vector-ref true-errors 0)) + + (define spec-vec (list->vector spec-list)) + (define ctx-vec (list->vector ctx-list)) (for ([(pt ex) (in-pcontext (*pcontext*))] [pt-idx (in-naturals)]) (define exacts (list->vector (apply subexprs-fn pt))) (define actuals (apply actual-value-fn pt)) + (define index 0) + ; (for ([i (in-naturals)] + ; [node (in-vector nodes)]) + ; (eprintf "node[~a] ~a\n" i node)) - (for ([expr (in-list exprs-list)] + (define (parse-true-error i pt pt-idx) + (define root (vector-ref roots i)) + (define node (vector-ref nodes root)) + (define cur-ctx (vector-ref ctx-vec i)) + (define cur-sepc (vector-ref spec-vec i)) + (define exact (vector-ref exacts i)) + (define true-error + (match node + [(? literal?) + ; (eprintf "literal?: ~a\n" node) + ; (define inputs (append (list exact) (vector->list (make-vector (length pt) 0)))) + ; (define compare-fn (eval-progs-real compare-specs extended)) + ; (define true-errors (list->vector (apply compare-fn inputs))) + ; TODO not correct because of eval-progs-real and literals being different. + ; (define extended (context-append cur-ctx exact-var-name (context-repr cur-ctx))) + ; (define compare-specs `(- ,cur-sepc ,exact-var-name)) + ; (define new-compare (eval-progs-real (list compare-specs) (list extended))) + ; (define inputs (append pt (list exact))) + ; (define true-errors (list->vector (apply new-compare inputs))) + ; (vector-ref true-errors 0) + (if (check-for-invalid-exact exact) + #f + (error-for cur-ctx cur-sepc pt exact))] + [(? variable?) + ; (eprintf "variable?: ~a\n" node) + 0] + [(approx approx-spec impl) + ; (green "approx") + 0] + [`(if ,c ,ift ,iff) + ; (eprintf "if: ~a\n" node) + (parse-true-error (vector-member c roots) pt pt-idx) + (parse-true-error (vector-member ift roots) pt pt-idx) + (parse-true-error (vector-member iff roots) pt pt-idx) + 0] + [(list f args-roots ...) + ; (eprintf "func[~a]: ~a, count: ~a\n" root node (length args-roots)) + (for ([idx (in-list args-roots)]) + (define node (vector-member idx roots)) + ; (eprintf "root: ~a, idx: ~a, node: ~a\n" root idx node) + (parse-true-error node pt pt-idx)) + ; (eprintf "~a: ~a ~a\n" root cur-sepc exact) + (if (check-for-invalid-exact exact) + #f + (error-for cur-ctx cur-sepc pt exact))])) + (vector-set! (vector-ref true-error-out i) pt-idx true-error)) + + (parse-true-error index pt pt-idx) + + (for ([cur-sepc (in-list spec-list)] + [cur-ctx (in-list ctx-list)] + [expr (in-list exprs-list)] [root (in-vector roots)] [exact (in-vector exacts)] [actual (in-vector actuals)] [expr-idx (in-naturals)]) - (define err + (define ulp-err (match (vector-ref nodes root) [(? literal?) 1] [(? variable?) 1] @@ -177,39 +304,45 @@ (define repr (repr-of expr ctx)) (ulp-difference exact (vector-ref exacts (vector-member impl roots)) repr)] [`(if ,c ,ift ,iff) 1] - [(list f args ...) + [(list f args-roots ...) (define repr (impl-info f 'otype)) (define argapprox - (for/list ([idx (in-list args)]) + (for/list ([idx (in-list args-roots)]) (vector-ref exacts (vector-member idx roots)))) ; arg's index mapping to exact (define approx (apply (impl-info f 'fl) argapprox)) (ulp-difference exact approx repr)])) + (vector-set! (vector-ref exacts-out expr-idx) pt-idx exact) - (vector-set! (vector-ref errs expr-idx) pt-idx err) - (vector-set! (vector-ref actuals-out expr-idx) pt-idx actual))) + (vector-set! (vector-ref approx-out expr-idx) pt-idx actual) + (vector-set! (vector-ref ulp-errs expr-idx) pt-idx ulp-err))) (define n 0) (for/list ([subexprs (in-list subexprss)]) (for*/hash ([subexpr (in-list subexprs)]) (begin0 (values subexpr - (hasheq 'errs - (vector->list (vector-ref errs n)) + (hasheq 'ulp-errs + (vector->list (vector-ref ulp-errs n)) 'exact-values (vector->list (vector-ref exacts-out n)) - 'actual-values - (vector->list (vector-ref actuals-out n)))) + 'approx-values + (vector->list (vector-ref approx-out n)) + 'true-error-values + (vector->list (vector-ref true-error-out n)))) (set! n (add1 n)))))) +(define (green msg) + (printf "\e[~a;1m~a\e[0m\n" 32 msg)) + ;; Compute the local error of every subexpression of `prog` ;; and returns the error information as an S-expr in the ;; same shape as `prog` (define (local-error-as-tree test ctx) - (define errs (first (compute-local-errors (list (all-subexpressions (test-input test))) ctx))) + (define errs (first (compute-errors (list (all-subexpressions (test-input test))) ctx))) (define local-error (let loop ([expr (test-input test)]) (define expr-info (hash-ref errs expr)) - (define err-list (hash-ref expr-info 'errs)) + (define err-list (hash-ref expr-info 'ulp-errs)) (match expr [(list op args ...) (cons err-list (map loop args))] [_ (list err-list)]))) @@ -222,42 +355,59 @@ [(list op args ...) (cons exacts-list (map loop args))] [_ (list exacts-list)]))) - (define actual-values + (define approx-values + (let loop ([expr (test-input test)]) + (define expr-info (hash-ref errs expr)) + (define exacts-list (hash-ref expr-info 'approx-values)) + (match expr + [(list op args ...) (cons exacts-list (map loop args))] + [_ (list exacts-list)]))) + + (define true-error-values (let loop ([expr (test-input test)]) (define expr-info (hash-ref errs expr)) - (define actual-list (hash-ref expr-info 'actual-values)) + (define actual-list (hash-ref expr-info 'true-error-values)) (match expr [(list op args ...) (cons actual-list (map loop args))] [_ (list actual-list)]))) (define tree (let loop ([expr (prog->fpcore (test-input test) (test-context test))] - [err local-error] + [ulp-err local-error] [exact exact-values] - [actual actual-values]) + [approx approx-values] + [t-err true-error-values]) (match expr [(list op args ...) ;; err => (List (listof Integer) List ...) (hasheq 'e (~a op) + 'ulps-error + (first ulp-err) 'avg-error - (format-bits (errors-score (first err))) + (format-bits (errors-score (first ulp-err))) 'exact-value (map ~s (first exact)) - 'actual-value - (map ~s (first actual)) + 'approx-value + (map ~s (first approx)) + 'true-error-value + (map ~s (first t-err)) 'children - (map loop args (rest err) (rest exact) (rest actual)))] + (map loop args (rest ulp-err) (rest exact) (rest approx) (rest t-err)))] ;; err => (List (listof Integer)) [_ (hasheq 'e (~a expr) + 'ulps-error + (first ulp-err) 'avg-error - (format-bits (errors-score (first err))) + (format-bits (errors-score (first ulp-err))) 'exact-value (map ~s (first exact)) - 'actual-value - (map ~s (first actual)) + 'approx-value + (map ~s (first approx)) + 'true-error-value + (map ~s (first t-err)) 'children '())]))) tree) diff --git a/src/syntax/types.rkt b/src/syntax/types.rkt index e0c8d98e3..86186375b 100644 --- a/src/syntax/types.rkt +++ b/src/syntax/types.rkt @@ -12,6 +12,7 @@ (struct-out context) *context* context-extend + context-append context-lookup) (module+ internals @@ -134,5 +135,11 @@ [vars (cons var (context-vars ctx))] [var-reprs (cons repr (context-var-reprs ctx))])) +(define (context-append ctx var repr) + (struct-copy context + ctx + [vars (append (context-vars ctx) (list var))] + [var-reprs (append (context-var-reprs ctx) (list repr))])) + (define (context-lookup ctx var) (dict-ref (map cons (context-vars ctx) (context-var-reprs ctx)) var))