Skip to content

Commit 4098004

Browse files
committed
add unit test for initialize
1 parent 1495995 commit 4098004

File tree

1 file changed

+173
-0
lines changed

1 file changed

+173
-0
lines changed
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for _initialize_client error handling."""
16+
17+
import httpx
18+
import pytest
19+
from mcp import McpError
20+
from mcp.types import ErrorData, JSONRPCError, JSONRPCResponse
21+
from mcp_proxy_for_aws.server import _initialize_client
22+
from unittest.mock import AsyncMock, Mock, patch
23+
24+
25+
@pytest.mark.asyncio
26+
async def test_successful_initialization():
27+
"""Test successful client initialization."""
28+
mock_transport = Mock()
29+
mock_client = Mock()
30+
31+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
32+
mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
33+
mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None)
34+
35+
async with _initialize_client(mock_transport) as client:
36+
assert client == mock_client
37+
38+
39+
@pytest.mark.asyncio
40+
async def test_http_error_with_jsonrpc_error(capsys):
41+
"""Test HTTPStatusError with JSONRPCError response."""
42+
mock_transport = Mock()
43+
error_data = ErrorData(code=-32600, message='Invalid Request')
44+
jsonrpc_error = JSONRPCError(jsonrpc='2.0', id=1, error=error_data)
45+
46+
mock_response = Mock()
47+
mock_response.aread = AsyncMock(return_value=jsonrpc_error.model_dump_json().encode())
48+
49+
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
50+
51+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
52+
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
53+
54+
with pytest.raises(httpx.HTTPStatusError):
55+
async with _initialize_client(mock_transport):
56+
pass
57+
58+
captured = capsys.readouterr()
59+
assert 'Invalid Request' in captured.out
60+
61+
62+
@pytest.mark.asyncio
63+
async def test_http_error_with_jsonrpc_response(capsys):
64+
"""Test HTTPStatusError with JSONRPCResponse."""
65+
mock_transport = Mock()
66+
jsonrpc_response = JSONRPCResponse(jsonrpc='2.0', id=1, result={'status': 'error'})
67+
68+
mock_response = Mock()
69+
mock_response.aread = AsyncMock(return_value=jsonrpc_response.model_dump_json().encode())
70+
71+
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
72+
73+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
74+
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
75+
76+
with pytest.raises(httpx.HTTPStatusError):
77+
async with _initialize_client(mock_transport):
78+
pass
79+
80+
captured = capsys.readouterr()
81+
assert '"result":{"status":"error"}' in captured.out
82+
83+
84+
@pytest.mark.asyncio
85+
async def test_http_error_with_invalid_json():
86+
"""Test HTTPStatusError with invalid JSON response."""
87+
mock_transport = Mock()
88+
89+
mock_response = Mock()
90+
mock_response.aread = AsyncMock(return_value=b'invalid json')
91+
92+
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
93+
94+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
95+
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
96+
97+
with pytest.raises(httpx.HTTPStatusError):
98+
async with _initialize_client(mock_transport):
99+
pass
100+
101+
102+
@pytest.mark.asyncio
103+
async def test_http_error_with_non_jsonrpc_message():
104+
"""Test HTTPStatusError with non-JSONRPCError/Response message."""
105+
mock_transport = Mock()
106+
107+
mock_response = Mock()
108+
mock_response.aread = AsyncMock(return_value=b'{"jsonrpc":"2.0","method":"test"}')
109+
110+
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
111+
112+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
113+
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
114+
115+
with pytest.raises(httpx.HTTPStatusError):
116+
async with _initialize_client(mock_transport):
117+
pass
118+
119+
120+
@pytest.mark.asyncio
121+
async def test_http_error_response_read_failure():
122+
"""Test HTTPStatusError when response.aread() fails."""
123+
mock_transport = Mock()
124+
125+
mock_response = Mock()
126+
mock_response.aread = AsyncMock(side_effect=Exception('Read failed'))
127+
128+
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
129+
130+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
131+
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
132+
133+
with pytest.raises(httpx.HTTPStatusError):
134+
async with _initialize_client(mock_transport):
135+
pass
136+
137+
138+
@pytest.mark.asyncio
139+
async def test_generic_error_with_mcp_error_cause(capsys):
140+
"""Test generic exception with McpError as cause."""
141+
mock_transport = Mock()
142+
error_data = ErrorData(code=-32601, message='Method not found')
143+
mcp_error = McpError(error_data)
144+
generic_error = Exception('Wrapper error')
145+
generic_error.__cause__ = mcp_error
146+
147+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
148+
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=generic_error)
149+
150+
with pytest.raises(Exception):
151+
async with _initialize_client(mock_transport):
152+
pass
153+
154+
captured = capsys.readouterr()
155+
assert 'Method not found' in captured.out
156+
assert '"code":-32601' in captured.out
157+
158+
159+
@pytest.mark.asyncio
160+
async def test_generic_error_without_mcp_error_cause(capsys):
161+
"""Test generic exception without McpError cause."""
162+
mock_transport = Mock()
163+
generic_error = Exception('Generic error')
164+
165+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
166+
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=generic_error)
167+
168+
with pytest.raises(Exception):
169+
async with _initialize_client(mock_transport):
170+
pass
171+
172+
captured = capsys.readouterr()
173+
assert captured.out.strip() == ''

0 commit comments

Comments
 (0)