Skip to content

Commit d93f4dd

Browse files
committed
move manage session into transport
1 parent 4b0f561 commit d93f4dd

File tree

4 files changed

+49
-46
lines changed

4 files changed

+49
-46
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,7 @@ def __init__(
5757
client_headers: Headers to include in each request sent through this client.
5858
"""
5959

60-
# If no aiohttp.ClientSession is provided, make our own
61-
manage_session = False
62-
if session is None:
63-
manage_session = True
64-
session = ClientSession()
65-
self.__transport = ToolboxTransport(url, session, manage_session)
60+
self.__transport = ToolboxTransport(url, session)
6661
self.__client_headers = client_headers if client_headers is not None else {}
6762

6863
def __parse_tool(

packages/toolbox-core/src/toolbox_core/toolbox_transport.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,16 @@
2323
class ToolboxTransport(ITransport):
2424
"""Transport for the native Toolbox protocol."""
2525

26-
def __init__(self, base_url: str, session: ClientSession, manage_session: bool):
26+
def __init__(self, base_url: str, session: Optional[ClientSession]):
2727
self.__base_url = base_url
28-
self.__session = session
29-
self.__manage_session = manage_session
28+
29+
# If no aiohttp.ClientSession is provided, make our own
30+
self.__manage_session = False
31+
if session is not None:
32+
self.__session = session
33+
else:
34+
self.__manage_session = True
35+
self.__session = ClientSession()
3036

3137
@property
3238
def base_url(self) -> str:

packages/toolbox-core/tests/test_tool.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def toolbox_tool(
112112
sample_tool_description: str,
113113
) -> ToolboxTool:
114114
"""Fixture for a ToolboxTool instance with common test setup."""
115-
transport = ToolboxTransport(TEST_BASE_URL, http_session, False)
115+
transport = ToolboxTransport(TEST_BASE_URL, http_session)
116116
return ToolboxTool(
117117
transport=transport,
118118
name=TEST_TOOL_NAME,
@@ -231,7 +231,7 @@ async def test_tool_creation_callable_and_run(
231231

232232
with aioresponses() as m:
233233
m.post(invoke_url, status=200, payload=mock_server_response_body)
234-
transport = ToolboxTransport(base_url, http_session, False)
234+
transport = ToolboxTransport(base_url, http_session)
235235

236236
tool_instance = ToolboxTool(
237237
transport=transport,
@@ -277,7 +277,7 @@ async def test_tool_run_with_pydantic_validation_error(
277277

278278
with aioresponses() as m:
279279
m.post(invoke_url, status=200, payload={"result": "Should not be called"})
280-
transport = ToolboxTransport(base_url, http_session, False)
280+
transport = ToolboxTransport(base_url, http_session)
281281

282282
tool_instance = ToolboxTool(
283283
transport=transport,
@@ -368,7 +368,7 @@ def test_tool_init_basic(http_session, sample_tool_params, sample_tool_descripti
368368
"""Tests basic tool initialization without headers or auth."""
369369
with catch_warnings(record=True) as record:
370370
simplefilter("always")
371-
transport = ToolboxTransport(HTTPS_BASE_URL, http_session, False)
371+
transport = ToolboxTransport(HTTPS_BASE_URL, http_session)
372372

373373
tool_instance = ToolboxTool(
374374
transport=transport,
@@ -398,7 +398,7 @@ def test_tool_init_with_client_headers(
398398
http_session, sample_tool_params, sample_tool_description, static_client_header
399399
):
400400
"""Tests tool initialization *with* client headers."""
401-
transport = ToolboxTransport(HTTPS_BASE_URL, http_session, False)
401+
transport = ToolboxTransport(HTTPS_BASE_URL, http_session)
402402
tool_instance = ToolboxTool(
403403
transport=transport,
404404
name=TEST_TOOL_NAME,
@@ -422,7 +422,7 @@ def test_tool_init_header_auth_conflict(
422422
):
423423
"""Tests ValueError on init if client header conflicts with auth token."""
424424
conflicting_client_header = {auth_header_key: "some-client-value"}
425-
transport = ToolboxTransport(HTTPS_BASE_URL, http_session, False)
425+
transport = ToolboxTransport(HTTPS_BASE_URL, http_session)
426426

427427
with pytest.raises(
428428
ValueError, match=f"Client header\\(s\\) `{auth_header_key}` already registered"
@@ -449,7 +449,7 @@ def test_tool_add_auth_token_getters_conflict_with_existing_client_header(
449449
Tests ValueError when add_auth_token_getters introduces an auth service
450450
whose token name conflicts with an existing client header.
451451
"""
452-
transport = ToolboxTransport(HTTPS_BASE_URL, http_session, False)
452+
transport = ToolboxTransport(HTTPS_BASE_URL, http_session)
453453
tool_instance = ToolboxTool(
454454
transport=transport,
455455
name="tool_with_client_header",
@@ -485,7 +485,7 @@ def test_add_auth_token_getters_unused_token(
485485
Tests ValueError when add_auth_token_getters is called with a getter for
486486
an unused authentication service.
487487
"""
488-
transport = ToolboxTransport(HTTPS_BASE_URL, http_session, False)
488+
transport = ToolboxTransport(HTTPS_BASE_URL, http_session)
489489
tool_instance = ToolboxTool(
490490
transport=transport,
491491
name=TEST_TOOL_NAME,
@@ -514,7 +514,7 @@ def test_add_auth_token_getter_unused_token(
514514
Tests ValueError when add_auth_token_getters is called with a getter for
515515
an unused authentication service.
516516
"""
517-
transport = ToolboxTransport(HTTPS_BASE_URL, http_session, False)
517+
transport = ToolboxTransport(HTTPS_BASE_URL, http_session)
518518
tool_instance = ToolboxTool(
519519
transport=transport,
520520
name=TEST_TOOL_NAME,
@@ -673,7 +673,7 @@ def test_tool_init_http_warning_when_sensitive_info_over_http(
673673
"Sending ID token over HTTP. User data may be exposed. "
674674
"Use HTTPS for secure communication."
675675
)
676-
transport = ToolboxTransport(TEST_BASE_URL, http_session, False)
676+
transport = ToolboxTransport(TEST_BASE_URL, http_session)
677677
init_kwargs = {
678678
"transport": transport,
679679
"name": "http_warning_tool",
@@ -704,7 +704,7 @@ def test_tool_init_no_http_warning_if_https(
704704
"""
705705
with catch_warnings(record=True) as record:
706706
simplefilter("always")
707-
transport = ToolboxTransport(HTTPS_BASE_URL, http_session, False)
707+
transport = ToolboxTransport(HTTPS_BASE_URL, http_session)
708708

709709
ToolboxTool(
710710
transport=transport,
@@ -733,7 +733,7 @@ def test_tool_init_no_http_warning_if_no_sensitive_info_on_http(
733733
"""
734734
with catch_warnings(record=True) as record:
735735
simplefilter("always")
736-
transport = ToolboxTransport(TEST_BASE_URL, http_session, False)
736+
transport = ToolboxTransport(TEST_BASE_URL, http_session)
737737

738738
ToolboxTool(
739739
transport=transport,

packages/toolbox-core/tests/test_toolbox_transport.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import AsyncGenerator, Union
15+
from typing import AsyncGenerator, Optional, Union
1616
from unittest.mock import AsyncMock
1717

1818
import pytest
@@ -58,7 +58,7 @@ def mock_manifest_dict() -> dict:
5858
@pytest.mark.asyncio
5959
async def test_base_url_property(http_session: ClientSession):
6060
"""Tests that the base_url property returns the correct URL."""
61-
transport = ToolboxTransport(TEST_BASE_URL, http_session, False)
61+
transport = ToolboxTransport(TEST_BASE_URL, http_session)
6262
assert transport.base_url == TEST_BASE_URL
6363

6464

@@ -67,7 +67,7 @@ async def test_tool_get_success(http_session: ClientSession, mock_manifest_dict:
6767
"""Tests a successful tool_get call."""
6868
url = f"{TEST_BASE_URL}/api/tool/{TEST_TOOL_NAME}"
6969
headers = {"X-Test-Header": "value"}
70-
transport = ToolboxTransport(TEST_BASE_URL, http_session, False)
70+
transport = ToolboxTransport(TEST_BASE_URL, http_session)
7171

7272
with aioresponses() as m:
7373
m.get(url, status=200, payload=mock_manifest_dict)
@@ -84,7 +84,7 @@ async def test_tool_get_success(http_session: ClientSession, mock_manifest_dict:
8484
async def test_tool_get_failure(http_session: ClientSession):
8585
"""Tests a failing tool_get call and ensures it raises RuntimeError."""
8686
url = f"{TEST_BASE_URL}/api/tool/{TEST_TOOL_NAME}"
87-
transport = ToolboxTransport(TEST_BASE_URL, http_session, False)
87+
transport = ToolboxTransport(TEST_BASE_URL, http_session)
8888

8989
with aioresponses() as m:
9090
m.get(url, status=500, body="Internal Server Error")
@@ -111,7 +111,7 @@ async def test_tools_list_success(
111111
):
112112
"""Tests successful tools_list calls with and without a toolset name."""
113113
url = f"{TEST_BASE_URL}{expected_path}"
114-
transport = ToolboxTransport(TEST_BASE_URL, http_session, False)
114+
transport = ToolboxTransport(TEST_BASE_URL, http_session)
115115

116116
with aioresponses() as m:
117117
m.get(url, status=200, payload=mock_manifest_dict)
@@ -129,7 +129,7 @@ async def test_tool_invoke_success(http_session: ClientSession):
129129
args = {"param1": "value1"}
130130
headers = {"Authorization": "Bearer token"}
131131
response_payload = {"result": "success"}
132-
transport = ToolboxTransport(TEST_BASE_URL, http_session, False)
132+
transport = ToolboxTransport(TEST_BASE_URL, http_session)
133133

134134
with aioresponses() as m:
135135
m.post(url, status=200, payload=response_payload)
@@ -144,7 +144,7 @@ async def test_tool_invoke_failure(http_session: ClientSession):
144144
"""Tests a failing tool_invoke call where the server returns an error payload."""
145145
url = f"{TEST_BASE_URL}/api/tool/{TEST_TOOL_NAME}/invoke"
146146
response_payload = {"error": "Invalid arguments"}
147-
transport = ToolboxTransport(TEST_BASE_URL, http_session, False)
147+
transport = ToolboxTransport(TEST_BASE_URL, http_session)
148148

149149
with aioresponses() as m:
150150
m.post(url, status=400, payload=response_payload)
@@ -155,25 +155,27 @@ async def test_tool_invoke_failure(http_session: ClientSession):
155155

156156

157157
@pytest.mark.asyncio
158-
@pytest.mark.parametrize(
159-
"manage_session, is_closed, should_call_close",
160-
[
161-
(True, False, True),
162-
(False, False, False),
163-
(True, True, False),
164-
],
165-
)
166-
async def test_close_behavior(
167-
manage_session: bool, is_closed: bool, should_call_close: bool
168-
):
169-
"""Tests the close method under different conditions."""
158+
async def test_close_does_not_close_unmanaged_session():
159+
"""
160+
Tests that close() does NOT affect a session that was provided externally
161+
(i.e., an unmanaged session).
162+
"""
170163
mock_session = AsyncMock(spec=ClientSession)
171-
mock_session.closed = is_closed
172-
transport = ToolboxTransport(TEST_BASE_URL, mock_session, manage_session)
164+
mock_session.closed = False
173165

166+
transport = ToolboxTransport(TEST_BASE_URL, mock_session)
174167
await transport.close()
168+
mock_session.close.assert_not_called()
175169

176-
if should_call_close:
177-
mock_session.close.assert_awaited_once()
178-
else:
179-
mock_session.close.assert_not_awaited()
170+
171+
@pytest.mark.asyncio
172+
async def test_close_closes_managed_session():
173+
"""
174+
Tests that close() successfully closes a session that was created and
175+
managed internally by the transport.
176+
"""
177+
transport = ToolboxTransport(TEST_BASE_URL, session=None)
178+
179+
await transport.close()
180+
internal_session = transport._ToolboxTransport__session
181+
assert internal_session.closed is True

0 commit comments

Comments
 (0)