Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
novitae committed Oct 4, 2024
1 parent 303da69 commit 81cf97f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion curl_cffi/requests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def __init__(

if response_class is None:
response_class = Response
elif issubclass(response_class, Response) is False:
elif not issubclass(response_class, Response):
raise TypeError( "`response_class` must be a subclass of `curl_cffi.requests.models.Response`"
f" not of type `{response_class}`" )
self.response_class = response_class
Expand Down
13 changes: 11 additions & 2 deletions examples/custom_response_class.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from curl_cffi import requests
from curl_cffi.curl import Curl, CurlInfo
from typing import cast

class CustomResponse(requests.Response):
def __init__(self, curl: Curl | None = None, request: requests.Request | None = None):
super().__init__(curl, request)
self.local_port = cast(int, curl.getinfo(CurlInfo.LOCAL_PORT))
self.connect_time = cast(float, curl.getinfo(CurlInfo.CONNECT_TIME))

@property
def status(self):
return self.status_code

def custom_method():
def custom_method(self):
return "this is a custom method"

session = requests.Session(response_class=CustomResponse)
response: CustomResponse = session.get("http://example.com")
print(response.status)
print(f"{response.status=}")
print(response.custom_method())
print(f"{response.local_port=}")
print(f"{response.connect_time=}")
8 changes: 3 additions & 5 deletions tests/integration/test_response_class.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from curl_cffi import requests

def test_default_response():
Expand All @@ -20,8 +21,5 @@ def test_custom_response():
class WrongTypeResponse: pass

def test_wrong_type_custom_response():
try:
requests.Session(response_class=WrongTypeResponse)
assert False, "session was created without raising issue for wrong response class type"
except TypeError:
print("Wrong response class type detected")
with pytest.raises(TypeError):
requests.Session(response_class=WrongTypeResponse)

0 comments on commit 81cf97f

Please sign in to comment.