diff --git a/neo4j/internal/bolt/chunker.go b/neo4j/internal/bolt/chunker.go index b4d42f9b..1907f250 100644 --- a/neo4j/internal/bolt/chunker.go +++ b/neo4j/internal/bolt/chunker.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package bolt @@ -22,6 +22,7 @@ package bolt import ( "context" "encoding/binary" + "errors" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" rio "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing" "io" @@ -118,7 +119,7 @@ func processWriteError(err error, ctx context.Context) error { Err: err, } } - if err == context.Canceled { + if errors.Is(err, context.Canceled) { return &errorutil.ConnectionWriteCanceled{ Err: err, } diff --git a/neo4j/internal/bolt/dechunker.go b/neo4j/internal/bolt/dechunker.go index 86589d0f..4117fd34 100644 --- a/neo4j/internal/bolt/dechunker.go +++ b/neo4j/internal/bolt/dechunker.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package bolt @@ -22,6 +22,7 @@ package bolt import ( "context" "encoding/binary" + "errors" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" rio "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing" "net" @@ -94,7 +95,7 @@ func processReadError(err error, ctx context.Context, readTimeout time.Duration) Err: err, } } - if err == context.Canceled { + if errors.Is(err, context.Canceled) { return &errorutil.ConnectionReadCanceled{ Err: err, } diff --git a/neo4j/internal/errorutil/bolt.go b/neo4j/internal/errorutil/bolt.go index d31bd665..5826a2f5 100644 --- a/neo4j/internal/errorutil/bolt.go +++ b/neo4j/internal/errorutil/bolt.go @@ -21,6 +21,7 @@ package errorutil import ( "context" + "errors" "fmt" idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" "strings" @@ -81,7 +82,7 @@ type timeout interface { } func IsTimeoutError(err error) bool { - if err == context.DeadlineExceeded { + if errors.Is(err, context.DeadlineExceeded) { return true } timeoutErr, ok := err.(timeout) @@ -89,25 +90,27 @@ func IsTimeoutError(err error) bool { } func IsFatalDuringDiscovery(err error) bool { - if _, ok := err.(*idb.FeatureNotSupportedError); ok { + var featureNotSupportedError *idb.FeatureNotSupportedError + if errors.As(err, &featureNotSupportedError) { return true } - if err, ok := err.(*idb.Neo4jError); ok { - if err.Code == "Neo.ClientError.Database.DatabaseNotFound" || - err.Code == "Neo.ClientError.Transaction.InvalidBookmark" || - err.Code == "Neo.ClientError.Transaction.InvalidBookmarkMixture" || - err.Code == "Neo.ClientError.Statement.TypeError" || - err.Code == "Neo.ClientError.Statement.ArgumentError" || - err.Code == "Neo.ClientError.Request.Invalid" { + var neo4jErr *idb.Neo4jError + if errors.As(err, &neo4jErr) { + if neo4jErr.Code == "Neo.ClientError.Database.DatabaseNotFound" || + neo4jErr.Code == "Neo.ClientError.Transaction.InvalidBookmark" || + neo4jErr.Code == "Neo.ClientError.Transaction.InvalidBookmarkMixture" || + neo4jErr.Code == "Neo.ClientError.Statement.TypeError" || + neo4jErr.Code == "Neo.ClientError.Statement.ArgumentError" || + neo4jErr.Code == "Neo.ClientError.Request.Invalid" { return true } - if strings.HasPrefix(err.Code, "Neo.ClientError.Security.") && - err.Code != "Neo.ClientError.Security.AuthorizationExpired" { + if strings.HasPrefix(neo4jErr.Code, "Neo.ClientError.Security.") && + neo4jErr.Code != "Neo.ClientError.Security.AuthorizationExpired" { return true } } - if err == context.DeadlineExceeded || - err == context.Canceled { + if errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) { return true } return false diff --git a/neo4j/internal/pool/pool.go b/neo4j/internal/pool/pool.go index eaa685f7..5de9b818 100644 --- a/neo4j/internal/pool/pool.go +++ b/neo4j/internal/pool/pool.go @@ -204,12 +204,15 @@ serverLoop: return nil, nil } -func (p *Pool) Borrow(ctx context.Context, getServerNames func() []string, wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) { +func (p *Pool) Borrow(ctx context.Context, getServerNames func(context.Context) ([]string, error), wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) { for { if p.closed { return nil, &errorutil.PoolClosed{} } - serverNames := getServerNames() + serverNames, err := getServerNames(ctx) + if err != nil { + return nil, err + } if len(serverNames) == 0 { return nil, &errorutil.PoolOutOfServers{} } @@ -221,8 +224,6 @@ func (p *Pool) Borrow(ctx context.Context, getServerNames func() []string, wait return penalties[i].penalty < penalties[j].penalty }) - var err error - var conn idb.Connection for _, s := range penalties { conn, err = p.tryBorrow(ctx, s.name, boltLogger, idlenessThreshold, auth) diff --git a/neo4j/internal/pool/pool_test.go b/neo4j/internal/pool/pool_test.go index 9bea4863..6dc845e5 100644 --- a/neo4j/internal/pool/pool_test.go +++ b/neo4j/internal/pool/pool_test.go @@ -673,9 +673,9 @@ func deadConnectionAfterForceReset(name string, idleness time.Time) *ConnFake { return result } -func getServers(servers []string) func() []string { - return func() []string { - return servers +func getServers(servers []string) func(context.Context) ([]string, error) { + return func(context.Context) ([]string, error) { + return servers, nil } } diff --git a/neo4j/internal/router/no_test.go b/neo4j/internal/router/no_test.go index 587957a6..57fb8024 100644 --- a/neo4j/internal/router/no_test.go +++ b/neo4j/internal/router/no_test.go @@ -27,14 +27,17 @@ import ( ) type poolFake struct { - borrow func(names []string, cancel context.CancelFunc, logger log.BoltLogger) (db.Connection, error) + borrow func(ctx context.Context, names []string, cancel context.CancelFunc, logger log.BoltLogger) (db.Connection, error) returned []db.Connection cancel context.CancelFunc } -func (p *poolFake) Borrow(_ context.Context, getServers func() []string, _ bool, logger log.BoltLogger, _ time.Duration, _ *db.ReAuthToken) (db.Connection, error) { - servers := getServers() - return p.borrow(servers, p.cancel, logger) +func (p *poolFake) Borrow(ctx context.Context, getServers func(context.Context) ([]string, error), _ bool, logger log.BoltLogger, _ time.Duration, _ *db.ReAuthToken) (db.Connection, error) { + servers, err := getServers(ctx) + if err != nil { + return nil, err + } + return p.borrow(ctx, servers, p.cancel, logger) } func (p *poolFake) Return(_ context.Context, c db.Connection) { diff --git a/neo4j/internal/router/readtable.go b/neo4j/internal/router/readtable.go index 15f12907..ae243dce 100644 --- a/neo4j/internal/router/readtable.go +++ b/neo4j/internal/router/readtable.go @@ -49,10 +49,6 @@ func readTable( for _, router := range routers { var conn db.Connection if conn, err = connectionPool.Borrow(ctx, getStaticServer(router), true, boltLogger, pool.DefaultLivenessCheckThreshold, auth); err != nil { - // Check if failed due to context timing out - if ctx.Err() != nil { - return nil, wrapError(router, ctx.Err()) - } if errorutil.IsFatalDuringDiscovery(err) { return nil, err } @@ -75,8 +71,8 @@ func readTable( return nil, err } -func getStaticServer(server string) func() []string { - return func() []string { - return []string{server} +func getStaticServer(server string) func(context.Context) ([]string, error) { + return func(context.Context) ([]string, error) { + return []string{server}, nil } } diff --git a/neo4j/internal/router/readtable_test.go b/neo4j/internal/router/readtable_test.go index 48e63c4d..d15a1f72 100644 --- a/neo4j/internal/router/readtable_test.go +++ b/neo4j/internal/router/readtable_test.go @@ -8,13 +8,13 @@ * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package router @@ -27,6 +27,7 @@ import ( "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil" "github.com/neo4j/neo4j-go-driver/v5/neo4j/log" "testing" + "time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/db" "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/testutil" @@ -71,6 +72,18 @@ func TestReadTableTable(ot *testing.T) { } } + assertCancelledError := func(t *testing.T, err error) { + if !errors.Is(err, context.Canceled) { + t.Errorf("Error should be %T but was %T", context.Canceled, err) + } + } + + assertDeadlineExceededError := func(t *testing.T, err error) { + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Error should be %T but was %T", context.DeadlineExceeded, err) + } + } + cases := []struct { name string routers []string @@ -93,8 +106,9 @@ func TestReadTableTable(ot *testing.T) { assert: assertNoTable, assertErr: assertRoutingTableError, pool: &poolFake{ - borrow: func(names []string, cancel context.CancelFunc, - _ log.BoltLogger) (idb.Connection, error) { + borrow: func( + ctx context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (idb.Connection, error) { return nil, errors.New("borrow fail") }, }, @@ -106,8 +120,9 @@ func TestReadTableTable(ot *testing.T) { assert: assertNoTable, assertErr: assertNeo4jError, pool: &poolFake{ - borrow: func(names []string, cancel context.CancelFunc, - _ log.BoltLogger) (idb.Connection, error) { + borrow: func( + ctx context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (idb.Connection, error) { return nil, &db.Neo4jError{Code: "Neo.ClientError.Security.Unauthorized"} }, }, @@ -118,8 +133,9 @@ func TestReadTableTable(ot *testing.T) { routers: standardRouters, assert: assertTable, pool: &poolFake{ - borrow: func(names []string, cancel context.CancelFunc, - _ log.BoltLogger) (idb.Connection, error) { + borrow: func( + ctx context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (idb.Connection, error) { return &testutil.ConnFake{Table: &idb.RoutingTable{}}, nil }, }, @@ -130,8 +146,9 @@ func TestReadTableTable(ot *testing.T) { routers: standardRouters, assert: assertTable, pool: &poolFake{ - borrow: func(names []string, cancel context.CancelFunc, - _ log.BoltLogger) (idb.Connection, error) { + borrow: func( + ctx context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (idb.Connection, error) { if names[0] == "router2" { return &testutil.ConnFake{Table: &idb. RoutingTable{}}, nil @@ -147,8 +164,9 @@ func TestReadTableTable(ot *testing.T) { assert: assertNoTable, assertErr: assertRoutingTableError, pool: &poolFake{ - borrow: func(names []string, cancel context.CancelFunc, - _ log.BoltLogger) (idb.Connection, error) { + borrow: func( + ctx context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (idb.Connection, error) { return &testutil.ConnFake{Err: errors.New("GetRoutingTable fail")}, nil }, }, @@ -158,17 +176,38 @@ func TestReadTableTable(ot *testing.T) { name: "Cancel context", routers: standardRouters, pool: &poolFake{ - borrow: func(names []string, cancel context.CancelFunc, - _ log.BoltLogger) (idb.Connection, error) { + borrow: func( + ctx context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (idb.Connection, error) { if names[0] == "router2" { panic("Should not be called") } cancel() - return nil, errors.New("cancelled") + return nil, ctx.Err() }, }, assert: assertNoTable, - assertErr: assertRoutingTableError, + assertErr: assertCancelledError, + numReturns: 0, + }, + { + name: "Deadline exceeded context", + routers: standardRouters, + pool: &poolFake{ + borrow: func( + ctx context.Context, names []string, _ context.CancelFunc, _ log.BoltLogger, + ) (idb.Connection, error) { + if names[0] == "router2" { + panic("Should not be called") + } + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(-1*time.Second)) + err := ctx.Err() + cancel() + return nil, err + }, + }, + assert: assertNoTable, + assertErr: assertDeadlineExceededError, numReturns: 0, }, } diff --git a/neo4j/internal/router/router.go b/neo4j/internal/router/router.go index 8d7177ba..d75d3628 100644 --- a/neo4j/internal/router/router.go +++ b/neo4j/internal/router/router.go @@ -60,7 +60,7 @@ type Pool interface { // If all connections are busy and the pool is full, calls to Borrow may wait for a connection to become idle // If a connection has been idle for longer than idlenessThreshold, it will be reset // to check if it's still alive. - Borrow(ctx context.Context, getServers func() []string, wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) + Borrow(ctx context.Context, getServers func(context.Context) ([]string, error), wait bool, boltLogger log.BoltLogger, idlenessThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) Return(ctx context.Context, c idb.Connection) } @@ -156,7 +156,7 @@ func (r *Router) getOrUpdateTable(ctx context.Context, bookmarksFn func(context. defer unlock.Do(r.dbRoutersMut.Unlock) for { dbRouter := r.dbRouters[database] - if table := r.getTableLocked(dbRouter); table != nil { + if table := r.getTableIfFreshLocked(dbRouter); table != nil { return table, nil } waiters, ok := r.updating[database] @@ -190,7 +190,19 @@ func (r *Router) getOrUpdateTable(ctx context.Context, bookmarksFn func(context. } } +// Keep using the routing table for up to 30 seconds after it expires. +// This gives the driver time to hopefully get a new routing table in the meantime. +const ttlBuffer = (int64)(30 * time.Second) + func (r *Router) getTableLocked(dbRouter *databaseRouter) *idb.RoutingTable { + now := (*r.now)() + if dbRouter != nil && now.Unix() < dbRouter.dueUnix+ttlBuffer { + return dbRouter.table + } + return nil +} + +func (r *Router) getTableIfFreshLocked(dbRouter *databaseRouter) *idb.RoutingTable { now := (*r.now)() if dbRouter != nil && now.Unix() < dbRouter.dueUnix { return dbRouter.table diff --git a/neo4j/internal/router/router_test.go b/neo4j/internal/router/router_test.go index 98eda90d..f431dff1 100644 --- a/neo4j/internal/router/router_test.go +++ b/neo4j/internal/router/router_test.go @@ -44,7 +44,9 @@ func TestMultithreading(t *testing.T) { num := 0 table := &db.RoutingTable{Readers: []string{"rd1", "rd2"}, Writers: []string{"wr"}, TimeToLive: 1} pool := &poolFake{ - borrow: func(names []string, cancel context.CancelFunc, _ log.BoltLogger) (db.Connection, error) { + borrow: func( + _ context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (db.Connection, error) { num++ return &testutil.ConnFake{Table: table}, nil }, @@ -100,7 +102,9 @@ func TestRespectsTimeToLiveAndInvalidate(t *testing.T) { numfetch := 0 table := &db.RoutingTable{TimeToLive: 1, Readers: []string{"router1"}} pool := &poolFake{ - borrow: func(names []string, cancel context.CancelFunc, _ log.BoltLogger) (db.Connection, error) { + borrow: func( + _ context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (db.Connection, error) { numfetch++ return &testutil.ConnFake{Table: table}, nil }, @@ -157,7 +161,9 @@ func TestUsesRootRouterWhenPreviousRoutersFails(t *testing.T) { }} var err error pool := &poolFake{ - borrow: func(names []string, cancel context.CancelFunc, _ log.BoltLogger) (db.Connection, error) { + borrow: func( + _ context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (db.Connection, error) { borrows = append(borrows, names) return conn, err }, @@ -189,7 +195,9 @@ func TestUsesRootRouterWhenPreviousRoutersFails(t *testing.T) { // rootRouter requestedOther := false requestedRoot := false - pool.borrow = func(names []string, cancel context.CancelFunc, _ log.BoltLogger) (db.Connection, error) { + pool.borrow = func( + _ context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (db.Connection, error) { if !requestedOther { if names[0] != "otherRouter" { t.Errorf("Expected request for otherRouter") @@ -223,7 +231,9 @@ func TestUsesRootRouterWhenPreviousRoutersFails(t *testing.T) { func TestUseGetRoutersHookWhenInitialRouterFails(t *testing.T) { var tried []string pool := &poolFake{ - borrow: func(names []string, cancel context.CancelFunc, _ log.BoltLogger) (db.Connection, error) { + borrow: func( + _ context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (db.Connection, error) { tried = append(tried, names...) return nil, errors.New("fail") }, @@ -250,7 +260,9 @@ func TestWritersFailAfterNRetries(t *testing.T) { numfetch := 0 tableNoWriters := &db.RoutingTable{TimeToLive: 1, Routers: []string{"rt1", "rt2"}, Readers: []string{"rd1"}} pool := &poolFake{ - borrow: func(names []string, cancel context.CancelFunc, _ log.BoltLogger) (db.Connection, error) { + borrow: func( + _ context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (db.Connection, error) { // Return no writers first time and writers the second time numfetch++ return &testutil.ConnFake{Table: tableNoWriters}, nil @@ -285,7 +297,9 @@ func TestWritersRetriesWhenNoWriters(t *testing.T) { tableNoWriters := &db.RoutingTable{TimeToLive: 1, Routers: []string{"rt1", "rt2"}, Readers: []string{"rd1"}} tableWriters := &db.RoutingTable{TimeToLive: 1, Routers: []string{"rt1", "rt2"}, Readers: []string{"rd1"}, Writers: []string{"wr1"}} pool := &poolFake{ - borrow: func(names []string, cancel context.CancelFunc, _ log.BoltLogger) (db.Connection, error) { + borrow: func( + _ context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (db.Connection, error) { // Return no writers first time and writers the second time numfetch++ if numfetch == 1 { @@ -324,7 +338,9 @@ func TestReadersRetriesWhenNoReaders(t *testing.T) { tableNoReaders := &db.RoutingTable{TimeToLive: 1, Routers: []string{"rt1", "rt2"}, Writers: []string{"wd1"}} tableReaders := &db.RoutingTable{TimeToLive: 1, Routers: []string{"rt1", "rt2"}, Writers: []string{"wd1"}, Readers: []string{"wr1"}} pool := &poolFake{ - borrow: func(names []string, cancel context.CancelFunc, _ log.BoltLogger) (db.Connection, error) { + borrow: func( + _ context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (db.Connection, error) { // Return no readers first time and readers the second time numfetch++ if numfetch == 1 { @@ -363,7 +379,9 @@ func TestReadersRetriesWhenNoReaders(t *testing.T) { func TestCleanUp(t *testing.T) { table := &db.RoutingTable{TimeToLive: 1, Readers: []string{"router1"}} pool := &poolFake{ - borrow: func(names []string, cancel context.CancelFunc, _ log.BoltLogger) (db.Connection, error) { + borrow: func( + _ context.Context, names []string, cancel context.CancelFunc, _ log.BoltLogger, + ) (db.Connection, error) { return &testutil.ConnFake{Table: table}, nil }, } diff --git a/neo4j/internal/testutil/poolfake.go b/neo4j/internal/testutil/poolfake.go index 0c569ca4..ba6eb216 100644 --- a/neo4j/internal/testutil/poolfake.go +++ b/neo4j/internal/testutil/poolfake.go @@ -34,13 +34,17 @@ type PoolFake struct { BorrowHook func() (db.Connection, error) } -func (p *PoolFake) Borrow(context.Context, func() []string, bool, log.BoltLogger, time.Duration, *db.ReAuthToken) (db.Connection, error) { +func (p *PoolFake) Borrow(ctx context.Context, getServerNames func(context.Context) ([]string, error), _ bool, _ log.BoltLogger, _ time.Duration, _ *db.ReAuthToken) (db.Connection, error) { if p.BorrowHook != nil && (p.BorrowConn != nil || p.BorrowErr != nil) { panic("either use the hook or the desired return values, but not both") } if p.BorrowHook != nil { return p.BorrowHook() } + _, err := getServerNames(ctx) + if err != nil { + return nil, err + } return p.BorrowConn, p.BorrowErr } diff --git a/neo4j/session_with_context.go b/neo4j/session_with_context.go index a7784407..41d1d548 100644 --- a/neo4j/session_with_context.go +++ b/neo4j/session_with_context.go @@ -191,7 +191,7 @@ const FetchDefault = 0 // Connection pool as seen by the session. type sessionPool interface { - Borrow(ctx context.Context, getServerNames func() []string, wait bool, boltLogger log.BoltLogger, livenessCheckThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) + Borrow(ctx context.Context, getServerNames func(context.Context) ([]string, error), wait bool, boltLogger log.BoltLogger, livenessCheckThreshold time.Duration, auth *idb.ReAuthToken) (idb.Connection, error) Return(ctx context.Context, c idb.Connection) CleanUp(ctx context.Context) Now() time.Time @@ -505,20 +505,21 @@ func (s *sessionWithContext) executeTransactionFunction( return true, x } -func (s *sessionWithContext) getOrUpdateServers(ctx context.Context, mode idb.AccessMode) ([]string, error) { - if mode == idb.ReadMode { - return s.router.GetOrUpdateReaders(ctx, s.getBookmarks, s.config.DatabaseName, s.auth, s.config.BoltLogger) - } else { - return s.router.GetOrUpdateWriters(ctx, s.getBookmarks, s.config.DatabaseName, s.auth, s.config.BoltLogger) - } -} - -func (s *sessionWithContext) getServers(mode idb.AccessMode) func() []string { - return func() []string { +func (s *sessionWithContext) getServers(mode idb.AccessMode) func(ctx context.Context) ([]string, error) { + update := true + return func(ctx context.Context) ([]string, error) { + if update { + update = false + if mode == idb.ReadMode { + return s.router.GetOrUpdateReaders(ctx, s.getBookmarks, s.config.DatabaseName, s.auth, s.config.BoltLogger) + } else { + return s.router.GetOrUpdateWriters(ctx, s.getBookmarks, s.config.DatabaseName, s.auth, s.config.BoltLogger) + } + } if mode == idb.ReadMode { - return s.router.Readers(s.config.DatabaseName) + return s.router.Readers(s.config.DatabaseName), nil } else { - return s.router.Writers(s.config.DatabaseName) + return s.router.Writers(s.config.DatabaseName), nil } } } @@ -540,10 +541,6 @@ func (s *sessionWithContext) getConnection(ctx context.Context, mode idb.AccessM if err := s.resolveHomeDatabase(ctx); err != nil { return nil, errorutil.WrapError(err) } - _, err := s.getOrUpdateServers(ctx, mode) - if err != nil { - return nil, errorutil.WrapError(err) - } conn, err := s.pool.Borrow( ctx, @@ -693,10 +690,6 @@ func (s *sessionWithContext) getServerInfo(ctx context.Context) (ServerInfo, err if err := s.resolveHomeDatabase(ctx); err != nil { return nil, errorutil.WrapError(err) } - _, err := s.getOrUpdateServers(ctx, idb.ReadMode) - if err != nil { - return nil, errorutil.WrapError(err) - } conn, err := s.pool.Borrow( ctx, s.getServers(idb.ReadMode), @@ -716,10 +709,6 @@ func (s *sessionWithContext) getServerInfo(ctx context.Context) (ServerInfo, err } func (s *sessionWithContext) verifyAuthentication(ctx context.Context) error { - _, err := s.getOrUpdateServers(ctx, idb.ReadMode) - if err != nil { - return errorutil.WrapError(err) - } conn, err := s.pool.Borrow( ctx, s.getServers(idb.ReadMode),