Skip to content

Commit

Permalink
Split out core local error from Odyssey.
Browse files Browse the repository at this point in the history
  • Loading branch information
zaneenders committed Oct 14, 2024
1 parent 7d38bd3 commit a94bb07
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 45 deletions.
11 changes: 3 additions & 8 deletions src/core/explain.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,10 @@
[else #t]))

(define (actual-errors expr pcontext)

(define errs
(match-define (cons subexprs pt-errorss)
(parameterize ([*pcontext* pcontext])
(first (compute-local-errors (list (all-subexpressions expr)) (*context*) #f))))

(define pruned (make-hash))
(for ([(k v) (in-hash errs)])
(hash-set! pruned k (hash-ref v 'ulp-errs)))
(match-define (cons subexprs pt-errorss) (flip-lists (hash->list pruned)))
(flip-lists (hash->list (first (compute-local-errors (list (all-subexpressions expr))
(*context*)))))))

(define pt-worst-subexpr
(append* (reap [sow]
Expand Down
113 changes: 76 additions & 37 deletions src/core/localize.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,10 @@

(define (batch-localize-errors exprs ctx)
(define subexprss (map all-subexpressions exprs))
(define errss (compute-local-errors subexprss ctx #f))

(define pruned-list
(for/list ([h (in-list errss)])
(define pruned (make-hash))
(for ([(k v) (in-hash h)])
(hash-set! pruned k (hash-ref v 'ulp-errs)))
pruned))
(define errss (compute-local-errors subexprss ctx))

(for/list ([_ (in-list exprs)]
[errs (in-list pruned-list)])
[errs (in-list errss)])
(sort (sort (for/list ([(subexpr err) (in-hash errs)]
#:when (or (list? subexpr) (approx? subexpr)))
(cons err subexpr))
Expand All @@ -131,7 +124,55 @@
#:key (compose errors-score car))))

; Compute local error or each sampled point at each node in `prog`.
(define (compute-local-errors subexprss ctx true-err?)
(define (compute-local-errors subexprss ctx)
(define exprs-list (append* subexprss)) ; unroll subexprss
(define ctx-list
(for/list ([subexpr (in-list exprs-list)])
(struct-copy context ctx [repr (repr-of subexpr ctx)])))

(define expr-batch (progs->batch exprs-list))
(define nodes (batch-nodes expr-batch))
(define roots (batch-roots expr-batch))

(define subexprs-fn (eval-progs-real (map prog->spec exprs-list) ctx-list))

(define errs
(for/vector #:length (vector-length roots)
([node (in-vector roots)])
(make-vector (pcontext-length (*pcontext*)))))

(for ([(pt ex) (in-pcontext (*pcontext*))]
[pt-idx (in-naturals)])
(define exacts (list->vector (apply subexprs-fn pt)))
(for ([expr (in-list exprs-list)]
[root (in-vector roots)]
[exact (in-vector exacts)]
[expr-idx (in-naturals)])
(define err
(match (vector-ref nodes root)
[(? literal?) 1]
[(? variable?) 1]
[(approx _ impl)
(define repr (repr-of expr ctx))
(ulp-difference exact (vector-ref exacts (vector-member impl roots)) repr)]
[`(if ,c ,ift ,iff) 1]
[(list f args ...)
(define repr (impl-info f 'otype))
(define argapprox
(for/list ([idx (in-list args)])
(vector-ref exacts (vector-member idx roots)))) ; arg's index mapping to exact
(define approx (apply (impl-info f 'fl) argapprox))
(ulp-difference exact approx repr)]))
(vector-set! (vector-ref errs expr-idx) pt-idx err)))

(define n 0)
(for/list ([subexprs (in-list subexprss)])
(for*/hash ([subexpr (in-list subexprs)])
(begin0 (values subexpr (vector->list (vector-ref errs n)))
(set! n (add1 n))))))

; Compute local error or each sampled point at each node in `prog`.
(define (compute-errors subexprss ctx)
; true-err? is a flag used weather we should run true error or not.
; TODO split out local error for `core` use vs Odyessy?
(define our_repr (context-repr ctx))
Expand Down Expand Up @@ -184,31 +225,30 @@
[actual (in-vector actuals)]
[expr-idx (in-naturals)])
(define true-err
(if true-err? ;; ??? Whats the default values for true error literal, variable approx and if?
(match (vector-ref nodes root)
[(? literal?) 0]
[(? variable?) 0]
[(approx aprx-spec impl) exact] ;; TODO not sure what to do here.
[`(if ,c ,ift ,iff) 0]
[(list f args-roots ...)
;; Find the index of the variables we need to substitute.
(match exact
[`+nan.0 `+nan.0]
[`-nan.0 `-nan.0]
[`+inf.0 `+inf.0]
[`-inf.0 `-inf.0]
[value
; __exact double underscore to avoid conflicts with user provided
; variables. Could use name mangling long term.
(define modifed-vars (append all-vars `(__exact)))
(define true-error-expr (list `(- ,spec __exact)))
(define diffMachine
(rival-compile true-error-expr modifed-vars (list flonum-discretization)))
(define inputs (map (representation-repr->bf our_repr) (append pt (list exact))))
;; ??? Is this always length 1, as we are asking about exact?
(define true-error (vector-ref (rival-apply diffMachine (list->vector inputs)) 0))
true-error])])
#f))
;; ??? Whats the default values for true error literal, variable approx and if?
(match (vector-ref nodes root)
[(? literal?) 0]
[(? variable?) 0]
[(approx aprx-spec impl) exact] ;; TODO not sure what to do here.
[`(if ,c ,ift ,iff) 0]
[(list f args-roots ...)
;; Find the index of the variables we need to substitute.
(match exact
[`+nan.0 `+nan.0]
[`-nan.0 `-nan.0]
[`+inf.0 `+inf.0]
[`-inf.0 `-inf.0]
[value
; __exact double underscore to avoid conflicts with user provided
; variables. Could use name mangling long term.
(define modifed-vars (append all-vars `(__exact)))
(define true-error-expr (list `(- ,spec __exact)))
(define diffMachine
(rival-compile true-error-expr modifed-vars (list flonum-discretization)))
(define inputs (map (representation-repr->bf our_repr) (append pt (list exact))))
;; ??? Is this always length 1, as we are asking about exact?
(define true-error (vector-ref (rival-apply diffMachine (list->vector inputs)) 0))
true-error])]))

(define ulp-err
(match (vector-ref nodes root)
Expand Down Expand Up @@ -244,12 +284,11 @@
'true-error-values
(vector->list (vector-ref true-error-out n))))
(set! n (add1 n))))))

;; Compute the local error of every subexpression of `prog`
;; and returns the error information as an S-expr in the
;; same shape as `prog`
(define (local-error-as-tree test ctx)
(define errs (first (compute-local-errors (list (all-subexpressions (test-input test))) ctx #t)))
(define errs (first (compute-errors (list (all-subexpressions (test-input test))) ctx)))

(define local-error
(let loop ([expr (test-input test)])
Expand Down

0 comments on commit a94bb07

Please sign in to comment.