diff --git a/src/core/explain.rkt b/src/core/explain.rkt index bed6dc05f..512adfcc2 100644 --- a/src/core/explain.rkt +++ b/src/core/explain.rkt @@ -38,15 +38,10 @@ [else #t])) (define (actual-errors expr pcontext) - - (define errs + (match-define (cons subexprs pt-errorss) (parameterize ([*pcontext* pcontext]) - (first (compute-local-errors (list (all-subexpressions expr)) (*context*) #f)))) - - (define pruned (make-hash)) - (for ([(k v) (in-hash errs)]) - (hash-set! pruned k (hash-ref v 'ulp-errs))) - (match-define (cons subexprs pt-errorss) (flip-lists (hash->list pruned))) + (flip-lists (hash->list (first (compute-local-errors (list (all-subexpressions expr)) + (*context*))))))) (define pt-worst-subexpr (append* (reap [sow] diff --git a/src/core/localize.rkt b/src/core/localize.rkt index 8b3f0fb80..7d25d52ae 100644 --- a/src/core/localize.rkt +++ b/src/core/localize.rkt @@ -111,17 +111,10 @@ (define (batch-localize-errors exprs ctx) (define subexprss (map all-subexpressions exprs)) - (define errss (compute-local-errors subexprss ctx #f)) - - (define pruned-list - (for/list ([h (in-list errss)]) - (define pruned (make-hash)) - (for ([(k v) (in-hash h)]) - (hash-set! pruned k (hash-ref v 'ulp-errs))) - pruned)) + (define errss (compute-local-errors subexprss ctx)) (for/list ([_ (in-list exprs)] - [errs (in-list pruned-list)]) + [errs (in-list errss)]) (sort (sort (for/list ([(subexpr err) (in-hash errs)] #:when (or (list? subexpr) (approx? subexpr))) (cons err subexpr)) @@ -131,7 +124,55 @@ #:key (compose errors-score car)))) ; Compute local error or each sampled point at each node in `prog`. -(define (compute-local-errors subexprss ctx true-err?) +(define (compute-local-errors subexprss ctx) + (define exprs-list (append* subexprss)) ; unroll subexprss + (define ctx-list + (for/list ([subexpr (in-list exprs-list)]) + (struct-copy context ctx [repr (repr-of subexpr ctx)]))) + + (define expr-batch (progs->batch exprs-list)) + (define nodes (batch-nodes expr-batch)) + (define roots (batch-roots expr-batch)) + + (define subexprs-fn (eval-progs-real (map prog->spec exprs-list) ctx-list)) + + (define errs + (for/vector #:length (vector-length roots) + ([node (in-vector roots)]) + (make-vector (pcontext-length (*pcontext*))))) + + (for ([(pt ex) (in-pcontext (*pcontext*))] + [pt-idx (in-naturals)]) + (define exacts (list->vector (apply subexprs-fn pt))) + (for ([expr (in-list exprs-list)] + [root (in-vector roots)] + [exact (in-vector exacts)] + [expr-idx (in-naturals)]) + (define err + (match (vector-ref nodes root) + [(? literal?) 1] + [(? variable?) 1] + [(approx _ impl) + (define repr (repr-of expr ctx)) + (ulp-difference exact (vector-ref exacts (vector-member impl roots)) repr)] + [`(if ,c ,ift ,iff) 1] + [(list f args ...) + (define repr (impl-info f 'otype)) + (define argapprox + (for/list ([idx (in-list args)]) + (vector-ref exacts (vector-member idx roots)))) ; arg's index mapping to exact + (define approx (apply (impl-info f 'fl) argapprox)) + (ulp-difference exact approx repr)])) + (vector-set! (vector-ref errs expr-idx) pt-idx err))) + + (define n 0) + (for/list ([subexprs (in-list subexprss)]) + (for*/hash ([subexpr (in-list subexprs)]) + (begin0 (values subexpr (vector->list (vector-ref errs n))) + (set! n (add1 n)))))) + +; Compute local error or each sampled point at each node in `prog`. +(define (compute-errors subexprss ctx) ; true-err? is a flag used weather we should run true error or not. ; TODO split out local error for `core` use vs Odyessy? (define our_repr (context-repr ctx)) @@ -184,31 +225,30 @@ [actual (in-vector actuals)] [expr-idx (in-naturals)]) (define true-err - (if true-err? ;; ??? Whats the default values for true error literal, variable approx and if? - (match (vector-ref nodes root) - [(? literal?) 0] - [(? variable?) 0] - [(approx aprx-spec impl) exact] ;; TODO not sure what to do here. - [`(if ,c ,ift ,iff) 0] - [(list f args-roots ...) - ;; Find the index of the variables we need to substitute. - (match exact - [`+nan.0 `+nan.0] - [`-nan.0 `-nan.0] - [`+inf.0 `+inf.0] - [`-inf.0 `-inf.0] - [value - ; __exact double underscore to avoid conflicts with user provided - ; variables. Could use name mangling long term. - (define modifed-vars (append all-vars `(__exact))) - (define true-error-expr (list `(- ,spec __exact))) - (define diffMachine - (rival-compile true-error-expr modifed-vars (list flonum-discretization))) - (define inputs (map (representation-repr->bf our_repr) (append pt (list exact)))) - ;; ??? Is this always length 1, as we are asking about exact? - (define true-error (vector-ref (rival-apply diffMachine (list->vector inputs)) 0)) - true-error])]) - #f)) + ;; ??? Whats the default values for true error literal, variable approx and if? + (match (vector-ref nodes root) + [(? literal?) 0] + [(? variable?) 0] + [(approx aprx-spec impl) exact] ;; TODO not sure what to do here. + [`(if ,c ,ift ,iff) 0] + [(list f args-roots ...) + ;; Find the index of the variables we need to substitute. + (match exact + [`+nan.0 `+nan.0] + [`-nan.0 `-nan.0] + [`+inf.0 `+inf.0] + [`-inf.0 `-inf.0] + [value + ; __exact double underscore to avoid conflicts with user provided + ; variables. Could use name mangling long term. + (define modifed-vars (append all-vars `(__exact))) + (define true-error-expr (list `(- ,spec __exact))) + (define diffMachine + (rival-compile true-error-expr modifed-vars (list flonum-discretization))) + (define inputs (map (representation-repr->bf our_repr) (append pt (list exact)))) + ;; ??? Is this always length 1, as we are asking about exact? + (define true-error (vector-ref (rival-apply diffMachine (list->vector inputs)) 0)) + true-error])])) (define ulp-err (match (vector-ref nodes root) @@ -244,12 +284,11 @@ 'true-error-values (vector->list (vector-ref true-error-out n)))) (set! n (add1 n)))))) - ;; Compute the local error of every subexpression of `prog` ;; and returns the error information as an S-expr in the ;; same shape as `prog` (define (local-error-as-tree test ctx) - (define errs (first (compute-local-errors (list (all-subexpressions (test-input test))) ctx #t))) + (define errs (first (compute-errors (list (all-subexpressions (test-input test))) ctx))) (define local-error (let loop ([expr (test-input test)])