Skip to content

Commit acb55d8

Browse files
committed
wip refactor entraid
1 parent 0c4f8fb commit acb55d8

File tree

12 files changed

+314
-219
lines changed

12 files changed

+314
-219
lines changed

internal/auth/conn_reauth_credentials_listener.go

Lines changed: 0 additions & 124 deletions
This file was deleted.

internal/auth/cred_listeners.go

Lines changed: 0 additions & 41 deletions
This file was deleted.
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package streaming
2+
3+
import (
4+
"github.com/redis/go-redis/v9/auth"
5+
"github.com/redis/go-redis/v9/internal/pool"
6+
)
7+
8+
// ConnReAuthCredentialsListener is a struct that implements the CredentialsListener interface.
9+
// It is used to re-authenticate the credentials when they are updated.
10+
// It holds reference to the connection to re-authenticate and will pass it to the reAuth and onErr callbacks.
11+
// It contains:
12+
// - reAuth: a function that takes the new credentials and returns an error if any.
13+
// - onErr: a function that takes an error and handles it.
14+
// - conn: the connection to re-authenticate.
15+
type ConnReAuthCredentialsListener struct {
16+
// reAuth is called when the credentials are updated.
17+
reAuth func(conn *pool.Conn, credentials auth.Credentials) error
18+
// onErr is called when an error occurs.
19+
onErr func(conn *pool.Conn, err error)
20+
// conn is the connection to re-authenticate.
21+
conn *pool.Conn
22+
23+
manager *Manager
24+
}
25+
26+
// OnNext is called when the credentials are updated.
27+
// It calls the reAuth function with the new credentials.
28+
// If the reAuth function returns an error, it calls the onErr function with the error.
29+
func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
30+
if c.conn.IsClosed() {
31+
return
32+
}
33+
34+
if c.reAuth == nil {
35+
return
36+
}
37+
38+
c.manager.MarkForReAuth(c.conn, func(err error) {
39+
if err != nil {
40+
c.OnError(err)
41+
return
42+
}
43+
err = c.reAuth(c.conn, credentials)
44+
if err != nil {
45+
c.OnError(err)
46+
return
47+
}
48+
})
49+
50+
}
51+
52+
// OnError is called when an error occurs.
53+
// It can be called from both the credentials provider and the reAuth function.
54+
func (c *ConnReAuthCredentialsListener) OnError(err error) {
55+
if c.onErr == nil {
56+
return
57+
}
58+
59+
c.onErr(c.conn, err)
60+
}
61+
62+
// newConnReAuthCredentialsListener creates a new ConnReAuthCredentialsListener.
63+
// Implements the auth.CredentialsListener interface.
64+
func newConnReAuthCredentialsListener(
65+
conn *pool.Conn,
66+
reAuth func(conn *pool.Conn, credentials auth.Credentials) error,
67+
onErr func(conn *pool.Conn, err error),
68+
) *ConnReAuthCredentialsListener {
69+
return &ConnReAuthCredentialsListener{
70+
conn: conn,
71+
reAuth: reAuth,
72+
onErr: onErr,
73+
}
74+
}
75+
76+
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
77+
var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package streaming
2+
3+
import (
4+
"sync"
5+
6+
"github.com/redis/go-redis/v9/auth"
7+
)
8+
9+
type CredentialsListeners struct {
10+
// connid -> listener
11+
listeners map[uint64]auth.CredentialsListener
12+
lock sync.RWMutex
13+
}
14+
15+
func NewCredentialsListeners() *CredentialsListeners {
16+
return &CredentialsListeners{
17+
listeners: make(map[uint64]auth.CredentialsListener),
18+
}
19+
}
20+
21+
func (c *CredentialsListeners) Add(connID uint64, listener auth.CredentialsListener) {
22+
c.lock.Lock()
23+
defer c.lock.Unlock()
24+
if c.listeners == nil {
25+
c.listeners = make(map[uint64]auth.CredentialsListener)
26+
}
27+
c.listeners[connID] = listener
28+
}
29+
30+
func (c *CredentialsListeners) Get(connID uint64) (auth.CredentialsListener, bool) {
31+
c.lock.RLock()
32+
defer c.lock.RUnlock()
33+
if len(c.listeners) == 0 {
34+
return nil, false
35+
}
36+
listener, ok := c.listeners[connID]
37+
return listener, ok
38+
}
39+
40+
func (c *CredentialsListeners) Remove(connID uint64) {
41+
c.lock.Lock()
42+
defer c.lock.Unlock()
43+
delete(c.listeners, connID)
44+
}

internal/auth/streaming/manager.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package streaming
2+
3+
import (
4+
"time"
5+
6+
"github.com/redis/go-redis/v9/auth"
7+
"github.com/redis/go-redis/v9/internal/pool"
8+
)
9+
10+
type Manager struct {
11+
credentialsListeners *CredentialsListeners
12+
pool pool.Pooler
13+
poolHookRef *ReAuthPoolHook
14+
}
15+
16+
func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager {
17+
return &Manager{
18+
pool: pl,
19+
poolHookRef: NewReAuthPoolHook(pl.Size(), reAuthTimeout),
20+
credentialsListeners: NewCredentialsListeners(),
21+
}
22+
}
23+
24+
func (m *Manager) PoolHook() pool.PoolHook {
25+
return m.poolHookRef
26+
}
27+
28+
func (m *Manager) Listener(
29+
poolCn *pool.Conn,
30+
reAuth func(*pool.Conn, auth.Credentials) error,
31+
onErr func(*pool.Conn, error),
32+
) auth.CredentialsListener {
33+
connID := poolCn.GetID()
34+
listener, ok := m.credentialsListeners.Get(connID)
35+
if !ok {
36+
newCredListener := newConnReAuthCredentialsListener(
37+
poolCn,
38+
reAuth,
39+
onErr,
40+
)
41+
newCredListener.manager = m
42+
m.credentialsListeners.Add(connID, newCredListener)
43+
}
44+
return listener
45+
}
46+
47+
func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {
48+
connID := poolCn.GetID()
49+
m.poolHookRef.MarkForReAuth(connID, reAuthFn)
50+
}

0 commit comments

Comments
 (0)