Skip to content

Commit 28a725b

Browse files
authored
Merge pull request #716 from ydb-platform/fix_token_rotation
Refactor auth token refresh logic
2 parents 40ac692 + f1712bc commit 28a725b

File tree

6 files changed

+388
-172
lines changed

6 files changed

+388
-172
lines changed

tests/aio/test_credentials.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import tempfile
66
import os
77
import json
8+
import asyncio
9+
from unittest.mock import patch, AsyncMock, MagicMock
810

911
import tests.auth.test_credentials
1012
import tests.oauth2_token_exchange
@@ -112,3 +114,152 @@ def serve(s):
112114
except Exception:
113115
os.remove(cfg_file_name)
114116
raise
117+
118+
119+
@pytest.mark.asyncio
120+
async def test_token_lazy_refresh():
121+
credentials = ServiceAccountCredentialsForTest(
122+
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
123+
tests.auth.test_credentials.ACCESS_KEY_ID,
124+
tests.auth.test_credentials.PRIVATE_KEY,
125+
"localhost:0",
126+
)
127+
128+
credentials._tp.submit = MagicMock()
129+
130+
mock_response = {"access_token": "token_v1", "expires_in": 3600}
131+
credentials._make_token_request = AsyncMock(return_value=mock_response)
132+
133+
with patch("time.time") as mock_time:
134+
mock_time.return_value = 1000
135+
136+
token1 = await credentials.token()
137+
assert token1 == "token_v1"
138+
assert credentials._make_token_request.call_count == 1
139+
140+
token2 = await credentials.token()
141+
assert token2 == "token_v1"
142+
assert credentials._make_token_request.call_count == 1
143+
144+
mock_time.return_value = 1000 + 3600 - 30 + 1
145+
credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600}
146+
147+
token3 = await credentials.token()
148+
assert token3 == "token_v2"
149+
assert credentials._make_token_request.call_count == 2
150+
151+
152+
@pytest.mark.asyncio
153+
async def test_token_double_check_locking():
154+
credentials = ServiceAccountCredentialsForTest(
155+
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
156+
tests.auth.test_credentials.ACCESS_KEY_ID,
157+
tests.auth.test_credentials.PRIVATE_KEY,
158+
"localhost:0",
159+
)
160+
161+
credentials._tp.submit = MagicMock()
162+
163+
call_count = 0
164+
165+
async def mock_make_request():
166+
nonlocal call_count
167+
call_count += 1
168+
await asyncio.sleep(0.01)
169+
return {"access_token": f"token_v{call_count}", "expires_in": 3600}
170+
171+
credentials._make_token_request = mock_make_request
172+
173+
with patch("time.time") as mock_time:
174+
mock_time.return_value = 1000
175+
176+
tasks = [credentials.token() for _ in range(10)]
177+
results = await asyncio.gather(*tasks)
178+
179+
assert len(set(results)) == 1
180+
assert call_count == 1
181+
182+
183+
@pytest.mark.asyncio
184+
async def test_token_expiration_calculation():
185+
credentials = ServiceAccountCredentialsForTest(
186+
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
187+
tests.auth.test_credentials.ACCESS_KEY_ID,
188+
tests.auth.test_credentials.PRIVATE_KEY,
189+
"localhost:0",
190+
)
191+
192+
credentials._tp.submit = MagicMock()
193+
194+
with patch("time.time") as mock_time:
195+
mock_time.return_value = 1000
196+
197+
credentials._make_token_request = AsyncMock(return_value={"access_token": "token", "expires_in": 3600})
198+
199+
await credentials.token()
200+
201+
expected_expires = 1000 + 3600 - 30
202+
assert credentials._expires_in == expected_expires
203+
204+
205+
@pytest.mark.asyncio
206+
async def test_token_refresh_error_handling():
207+
credentials = ServiceAccountCredentialsForTest(
208+
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
209+
tests.auth.test_credentials.ACCESS_KEY_ID,
210+
tests.auth.test_credentials.PRIVATE_KEY,
211+
"localhost:0",
212+
)
213+
214+
credentials._tp.submit = MagicMock()
215+
216+
credentials._make_token_request = AsyncMock(side_effect=Exception("Network error"))
217+
218+
with pytest.raises(Exception) as exc_info:
219+
await credentials.token()
220+
221+
assert "Network error" in str(exc_info.value)
222+
assert credentials.last_error == "Network error"
223+
224+
225+
@pytest.mark.asyncio
226+
async def test_hybrid_background_and_sync_refresh():
227+
credentials = ServiceAccountCredentialsForTest(
228+
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
229+
tests.auth.test_credentials.ACCESS_KEY_ID,
230+
tests.auth.test_credentials.PRIVATE_KEY,
231+
"localhost:0",
232+
)
233+
234+
call_count = 0
235+
background_calls = []
236+
237+
async def mock_make_request():
238+
nonlocal call_count
239+
call_count += 1
240+
return {"access_token": f"token_v{call_count}", "expires_in": 3600}
241+
242+
def mock_submit(callback):
243+
background_calls.append(callback)
244+
245+
credentials._make_token_request = mock_make_request
246+
credentials._tp.submit = mock_submit
247+
248+
with patch("time.time") as mock_time:
249+
mock_time.return_value = 1000
250+
251+
token1 = await credentials.token()
252+
assert token1 == "token_v1"
253+
assert call_count == 1
254+
assert len(background_calls) == 0
255+
256+
mock_time.return_value = 1000 + min(1800, 3600 / 10) + 1
257+
token2 = await credentials.token()
258+
assert token2 == "token_v1"
259+
assert call_count == 1
260+
assert len(background_calls) == 1
261+
262+
mock_time.return_value = 1000 + 3600 - 30 + 1
263+
token3 = await credentials.token()
264+
assert token3 == "token_v2"
265+
assert call_count == 2

tests/auth/test_static_credentials.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
import ydb
3+
from unittest.mock import patch, MagicMock
34

45

56
USERNAME = "root"
@@ -45,3 +46,131 @@ def test_static_credentials_wrong_creds(endpoint, database):
4546
with pytest.raises(ydb.ConnectionFailure):
4647
with ydb.Driver(driver_config=driver_config) as driver:
4748
driver.wait(5, fail_fast=True)
49+
50+
51+
def test_token_lazy_refresh():
52+
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
53+
54+
credentials._tp.submit = MagicMock()
55+
56+
mock_response = {"access_token": "token_v1", "expires_in": 3600}
57+
credentials._make_token_request = MagicMock(return_value=mock_response)
58+
59+
with patch("time.time") as mock_time:
60+
mock_time.return_value = 1000
61+
62+
token1 = credentials.token
63+
assert token1 == "token_v1"
64+
assert credentials._make_token_request.call_count == 1
65+
66+
token2 = credentials.token
67+
assert token2 == "token_v1"
68+
assert credentials._make_token_request.call_count == 1
69+
70+
mock_time.return_value = 1000 + 3600 - 30 + 1
71+
credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600}
72+
73+
token3 = credentials.token
74+
assert token3 == "token_v2"
75+
assert credentials._make_token_request.call_count == 2
76+
77+
78+
def test_token_double_check_locking():
79+
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
80+
credentials._tp.submit = MagicMock()
81+
82+
call_count = 0
83+
84+
def mock_make_request():
85+
nonlocal call_count
86+
call_count += 1
87+
return {"access_token": f"token_v{call_count}", "expires_in": 3600}
88+
89+
credentials._make_token_request = mock_make_request
90+
91+
with patch("time.time") as mock_time:
92+
mock_time.return_value = 1000
93+
94+
import threading
95+
96+
results = []
97+
98+
def get_token():
99+
results.append(credentials.token)
100+
101+
threads = [threading.Thread(target=get_token) for _ in range(10)]
102+
for t in threads:
103+
t.start()
104+
for t in threads:
105+
t.join()
106+
107+
assert len(set(results)) == 1
108+
assert call_count == 1
109+
110+
111+
def test_token_expiration_calculation():
112+
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
113+
114+
credentials._tp.submit = MagicMock()
115+
116+
with patch("time.time") as mock_time:
117+
mock_time.return_value = 1000
118+
119+
credentials._make_token_request = MagicMock(return_value={"access_token": "token", "expires_in": 3600})
120+
121+
credentials.token
122+
123+
expected_expires = 1000 + 3600 - 30
124+
assert credentials._expires_in == expected_expires
125+
126+
127+
def test_token_refresh_error_handling():
128+
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
129+
credentials._tp.submit = MagicMock()
130+
credentials._make_token_request = MagicMock(side_effect=Exception("Network error"))
131+
132+
with patch("time.time") as mock_time:
133+
mock_time.return_value = 1000 + 3600
134+
135+
with pytest.raises(ydb.ConnectionError) as exc_info:
136+
credentials.token
137+
138+
assert "Network error" in str(exc_info.value)
139+
assert credentials.last_error == "Network error"
140+
141+
142+
def test_hybrid_background_and_sync_refresh():
143+
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
144+
145+
call_count = 0
146+
background_calls = []
147+
148+
def mock_make_request():
149+
nonlocal call_count
150+
call_count += 1
151+
return {"access_token": f"token_v{call_count}", "expires_in": 3600}
152+
153+
def mock_submit(callback):
154+
background_calls.append(callback)
155+
156+
credentials._make_token_request = mock_make_request
157+
credentials._tp.submit = mock_submit
158+
159+
with patch("time.time") as mock_time:
160+
mock_time.return_value = 1000
161+
162+
token1 = credentials.token
163+
assert token1 == "token_v1"
164+
assert call_count == 1
165+
assert len(background_calls) == 0
166+
167+
mock_time.return_value = 1000 + min(1800, 3600 / 10) + 1
168+
token2 = credentials.token
169+
assert token2 == "token_v1"
170+
assert call_count == 1
171+
assert len(background_calls) == 1
172+
173+
mock_time.return_value = 1000 + 3600 - 30 + 1
174+
token3 = credentials.token
175+
assert token3 == "token_v2"
176+
assert call_count == 2

0 commit comments

Comments
 (0)