Skip to content

Commit

Permalink
test: adds tests for Tx and propagation.
Browse files Browse the repository at this point in the history
  • Loading branch information
jcchavezs committed Nov 15, 2018
1 parent 79cca89 commit 24b349c
Show file tree
Hide file tree
Showing 2 changed files with 282 additions and 65 deletions.
142 changes: 82 additions & 60 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ var (
_ driver.Driver = &zDriver{}
_ conn = &zConn{}
_ driver.Result = &zResult{}
_ driver.Rows = &zRows{}
)

var (
Expand Down Expand Up @@ -79,11 +78,11 @@ func wrapDriver(d driver.Driver, t *zipkin.Tracer, o TraceOptions) driver.Driver
}

func wrapConn(c driver.Conn, t *zipkin.Tracer, options TraceOptions) driver.Conn {
return &zConn{driver: c, tracer: t, options: options}
return &zConn{conn: c, tracer: t, options: options}
}

func wrapStmt(stmt driver.Stmt, query string, tracer *zipkin.Tracer, options TraceOptions) driver.Stmt {
s := zStmt{driver: stmt, query: query, options: options, tracer: tracer}
s := zStmt{stmt: stmt, query: query, options: options, tracer: tracer}
_, hasExeCtx := stmt.(driver.StmtExecContext)
_, hasQryCtx := stmt.(driver.StmtQueryContext)
c, hasColCnv := stmt.(driver.ColumnConverter)
Expand Down Expand Up @@ -147,28 +146,28 @@ func (d zDriver) Open(name string) (driver.Conn, error) {

// zConn implements driver.Conn
type zConn struct {
driver driver.Conn
conn driver.Conn
tracer *zipkin.Tracer
options TraceOptions
}

func (c zConn) Ping(ctx context.Context) (err error) {
if pinger, ok := c.driver.(driver.Pinger); ok {
if pinger, ok := c.conn.(driver.Pinger); ok {
err = pinger.Ping(ctx)
}
return
}

func (c zConn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
if exec, ok := c.driver.(driver.Execer); ok {
if exec, ok := c.conn.(driver.Execer); ok {
return exec.Exec(query, args)
}

return nil, driver.ErrSkip
}

func (c zConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (res driver.Result, err error) {
if execCtx, ok := c.driver.(driver.ExecerContext); ok {
if execCtx, ok := c.conn.(driver.ExecerContext); ok {
if zipkin.SpanFromContext(ctx) == nil && !c.options.AllowRootSpan {
return execCtx.ExecContext(ctx, query, args)
}
Expand All @@ -187,22 +186,22 @@ func (c zConn) ExecContext(ctx context.Context, query string, args []driver.Name
return nil, err
}

return zResult{driver: res, tracer: c.tracer, ctx: ctx, options: c.options}, nil
return zResult{result: res, tracer: c.tracer, ctx: ctx, options: c.options}, nil
}

return nil, driver.ErrSkip
}

func (c zConn) Query(query string, args []driver.Value) (rows driver.Rows, err error) {
if queryer, ok := c.driver.(driver.Queryer); ok {
if queryer, ok := c.conn.(driver.Queryer); ok {
return queryer.Query(query, args)
}

return nil, driver.ErrSkip
}

func (c zConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
if queryerCtx, ok := c.driver.(driver.QueryerContext); ok {
if queryerCtx, ok := c.conn.(driver.QueryerContext); ok {
if zipkin.SpanFromContext(ctx) == nil && !c.options.AllowRootSpan {
return queryerCtx.QueryContext(ctx, query, args)
}
Expand All @@ -221,14 +220,14 @@ func (c zConn) QueryContext(ctx context.Context, query string, args []driver.Nam
return nil, err
}

return zRows{driver: rows, ctx: ctx, options: c.options}, nil
return rows, nil
}

return nil, driver.ErrSkip
}

func (c zConn) Prepare(query string) (stmt driver.Stmt, err error) {
stmt, err = c.driver.Prepare(query)
stmt, err = c.conn.Prepare(query)
if err != nil {
return nil, err
}
Expand All @@ -238,72 +237,72 @@ func (c zConn) Prepare(query string) (stmt driver.Stmt, err error) {
}

func (c *zConn) Close() error {
return c.driver.Close()
return c.conn.Close()
}

func (c *zConn) Begin() (driver.Tx, error) {
return c.Begin()
}

func (c *zConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if prepCtx, ok := c.driver.(driver.ConnPrepareContext); ok {
if prepCtx, ok := c.conn.(driver.ConnPrepareContext); ok {
return prepCtx.PrepareContext(ctx, query)
}

return c.driver.Prepare(query)
return c.conn.Prepare(query)
}

func (c *zConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if zipkin.SpanFromContext(ctx) == nil && !c.options.AllowRootSpan {
if connBeginTx, ok := c.driver.(driver.ConnBeginTx); ok {
if connBeginTx, ok := c.conn.(driver.ConnBeginTx); ok {
return connBeginTx.BeginTx(ctx, opts)
}

return c.driver.Begin()
return c.conn.Begin()
}

span, _ := c.tracer.StartSpanFromContext(ctx, "sql/begin_transaction", zipkin.Kind(zipkinmodel.Client))
defer span.Finish()

setSpanDefaultTags(span, c.options.DefaultTags)

if connBeginTx, ok := c.driver.(driver.ConnBeginTx); ok {
if connBeginTx, ok := c.conn.(driver.ConnBeginTx); ok {
tx, err := connBeginTx.BeginTx(ctx, opts)
setSpanError(span, err)
if err != nil {
return nil, err
}
return zTx{driver: tx, ctx: ctx}, nil
return zTx{tx: tx, ctx: ctx, tracer: c.tracer, options: c.options}, nil
}

tx, err := c.driver.Begin()
tx, err := c.conn.Begin()
setSpanError(span, err)
if err != nil {
return nil, err
}

return zTx{driver: tx, ctx: ctx, tracer: c.tracer}, nil
return zTx{tx: tx, ctx: ctx, tracer: c.tracer, options: c.options}, nil
}

// zResult implements driver.Result
type zResult struct {
driver driver.Result
result driver.Result
ctx context.Context
tracer *zipkin.Tracer
options TraceOptions
}

func (r zResult) LastInsertId() (int64, error) {
if !r.options.LastInsertIDSpan {
return r.driver.LastInsertId()
return r.result.LastInsertId()
}

span, _ := r.tracer.StartSpanFromContext(r.ctx, "sql/last_insert_id", zipkin.Kind(zipkinmodel.Client))
defer span.Finish()

setSpanDefaultTags(span, r.options.DefaultTags)

id, err := r.driver.LastInsertId()
id, err := r.result.LastInsertId()
setSpanError(span, err)

return id, err
Expand All @@ -321,37 +320,80 @@ func (r zResult) RowsAffected() (cnt int64, err error) {
}()
}

cnt, err = r.driver.RowsAffected()
cnt, err = r.result.RowsAffected()
return
}

// zStmt implements driver.Stmt
type zStmt struct {
driver driver.Stmt
stmt driver.Stmt
query string
tracer *zipkin.Tracer
options TraceOptions
}

func (s zStmt) Exec(args []driver.Value) (driver.Result, error) {
return s.driver.Exec(args)
func (s zStmt) Exec(args []driver.Value) (res driver.Result, err error) {
if !s.options.AllowRootSpan {
return s.stmt.Exec(args)
}

span, ctx := s.tracer.StartSpanFromContext(context.Background(), "sql:exec", zipkin.Kind(zipkinmodel.Client))
setSpanDefaultTags(span, s.options.DefaultTags)

if s.options.TagQuery {
span.Tag("sql.query", s.query)
}

defer func() {
setSpanError(span, err)
span.Finish()
}()

res, err = s.stmt.Exec(args)
if err != nil {
return nil, err
}

res, err = zResult{result: res, ctx: ctx, tracer: s.tracer, options: s.options}, nil
return
}

func (s zStmt) Close() error {
return s.driver.Close()
return s.stmt.Close()
}

func (s zStmt) NumInput() int {
return s.driver.NumInput()
return s.stmt.NumInput()
}

func (s zStmt) Query(args []driver.Value) (driver.Rows, error) {
return s.driver.Query(args)
func (s zStmt) Query(args []driver.Value) (rows driver.Rows, err error) {
if !s.options.AllowRootSpan {
return s.stmt.Query(args)
}

span, _ := s.tracer.StartSpanFromContext(context.Background(), "sql:query", zipkin.Kind(zipkinmodel.Client))
setSpanDefaultTags(span, s.options.DefaultTags)

if s.options.TagQuery {
span.Tag("sql.query", s.query)
}

defer func() {
setSpanError(span, err)
span.Finish()
}()

rows, err = s.stmt.Query(args)
if err != nil {
return nil, err
}

return
}

func (s zStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) {
if zipkin.SpanFromContext(ctx) == nil && !s.options.AllowRootSpan {
return s.driver.(driver.StmtExecContext).ExecContext(ctx, args)
return s.stmt.(driver.StmtExecContext).ExecContext(ctx, args)
}

span, ctx := s.tracer.StartSpanFromContext(ctx, "sql/exec", zipkin.Kind(zipkinmodel.Client))
Expand All @@ -366,7 +408,7 @@ func (s zStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res d

setSpanDefaultTags(span, s.options.DefaultTags)

execContext := s.driver.(driver.StmtExecContext)
execContext := s.stmt.(driver.StmtExecContext)
res, err = execContext.ExecContext(ctx, args)
if err != nil {
return nil, err
Expand All @@ -378,13 +420,13 @@ func (s zStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res d
}
}

res, err = zResult{driver: res, tracer: s.tracer, ctx: ctx, options: s.options}, nil
res, err = zResult{result: res, tracer: s.tracer, ctx: ctx, options: s.options}, nil
return
}

func (s zStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) {
if zipkin.SpanFromContext(ctx) == nil && !s.options.AllowRootSpan {
return s.driver.(driver.StmtQueryContext).QueryContext(ctx, args)
return s.stmt.(driver.StmtQueryContext).QueryContext(ctx, args)
}

span, ctx := s.tracer.StartSpanFromContext(ctx, "sql/query", zipkin.Kind(zipkinmodel.Client))
Expand All @@ -405,38 +447,18 @@ func (s zStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows
}()

// we already tested driver to implement StmtQueryContext
queryContext := s.driver.(driver.StmtQueryContext)
queryContext := s.stmt.(driver.StmtQueryContext)
rows, err = queryContext.QueryContext(ctx, args)
if err != nil {
return nil, err
}

rows, err = zRows{driver: rows, ctx: ctx, options: s.options}, nil
return
}

// zRows implements driver.Rows.
type zRows struct {
driver driver.Rows
ctx context.Context
options TraceOptions
}

func (r zRows) Columns() []string {
return r.driver.Columns()
}

func (r zRows) Close() error {
return r.driver.Close()
}

func (r zRows) Next(dest []driver.Value) error {
return r.driver.Next(dest)
}

// zTx implemens driver.Tx
type zTx struct {
driver driver.Tx
tx driver.Tx
ctx context.Context
tracer *zipkin.Tracer
options TraceOptions
Expand All @@ -451,7 +473,7 @@ func (t zTx) Commit() (err error) {
span.Finish()
}()
}
err = t.driver.Commit()
err = t.tx.Commit()
return
}

Expand All @@ -464,7 +486,7 @@ func (t zTx) Rollback() (err error) {
span.Finish()
}()
}
err = t.driver.Rollback()
err = t.tx.Rollback()
return
}

Expand Down
Loading

0 comments on commit 24b349c

Please sign in to comment.