Skip to content

Commit c31fe51

Browse files
authored
Support RichComparison, hash and deepcopy for Url and MultiHostUrl (#558)
1 parent ae4cb28 commit c31fe51

File tree

3 files changed

+133
-3
lines changed

3 files changed

+133
-3
lines changed

pydantic_core/_pydantic_core.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ if sys.version_info < (3, 11):
1515
else:
1616
from typing import Literal, NotRequired, TypeAlias
1717

18+
from _typeshed import SupportsAllComparisons
19+
1820
__all__ = (
1921
'__version__',
2022
'build_profile',
@@ -126,7 +128,7 @@ def to_jsonable_python(
126128
fallback: 'Callable[[Any], Any] | None' = None,
127129
) -> Any: ...
128130

129-
class Url:
131+
class Url(SupportsAllComparisons):
130132
@property
131133
def scheme(self) -> str: ...
132134
@property
@@ -156,7 +158,7 @@ class MultiHostHost(TypedDict):
156158
query: 'str | None'
157159
fragment: 'str | None'
158160

159-
class MultiHostUrl:
161+
class MultiHostUrl(SupportsAllComparisons):
160162
@property
161163
def scheme(self) -> str: ...
162164
@property

src/url.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
use std::collections::hash_map::DefaultHasher;
2+
use std::hash::{Hash, Hasher};
3+
14
use idna::punycode::decode_to_string;
25
use pyo3::once_cell::GILOnceCell;
36
use pyo3::prelude::*;
7+
use pyo3::pyclass::CompareOp;
48
use pyo3::types::PyDict;
59
use url::Url;
610

@@ -116,6 +120,31 @@ impl PyUrl {
116120
pub fn __repr__(&self) -> String {
117121
format!("Url('{}')", self.lib_url)
118122
}
123+
124+
fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
125+
match op {
126+
CompareOp::Lt => Ok(self.lib_url < other.lib_url),
127+
CompareOp::Le => Ok(self.lib_url <= other.lib_url),
128+
CompareOp::Eq => Ok(self.lib_url == other.lib_url),
129+
CompareOp::Ne => Ok(self.lib_url != other.lib_url),
130+
CompareOp::Gt => Ok(self.lib_url > other.lib_url),
131+
CompareOp::Ge => Ok(self.lib_url >= other.lib_url),
132+
}
133+
}
134+
135+
fn __hash__(&self) -> u64 {
136+
let mut s = DefaultHasher::new();
137+
self.lib_url.to_string().hash(&mut s);
138+
s.finish()
139+
}
140+
141+
fn __bool__(&self) -> bool {
142+
true // an empty string is not a valid URL
143+
}
144+
145+
pub fn __deepcopy__(&self, py: Python, _memo: &PyDict) -> Py<PyAny> {
146+
self.clone().into_py(py)
147+
}
119148
}
120149

121150
#[pyclass(name = "MultiHostUrl", module = "pydantic_core._pydantic_core")]
@@ -250,6 +279,32 @@ impl PyMultiHostUrl {
250279
pub fn __repr__(&self) -> String {
251280
format!("Url('{}')", self.__str__())
252281
}
282+
283+
fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
284+
match op {
285+
CompareOp::Lt => Ok(self.unicode_string() < other.unicode_string()),
286+
CompareOp::Le => Ok(self.unicode_string() <= other.unicode_string()),
287+
CompareOp::Eq => Ok(self.unicode_string() == other.unicode_string()),
288+
CompareOp::Ne => Ok(self.unicode_string() != other.unicode_string()),
289+
CompareOp::Gt => Ok(self.unicode_string() > other.unicode_string()),
290+
CompareOp::Ge => Ok(self.unicode_string() >= other.unicode_string()),
291+
}
292+
}
293+
294+
fn __hash__(&self) -> u64 {
295+
let mut s = DefaultHasher::new();
296+
self.ref_url.clone().into_url().to_string().hash(&mut s);
297+
self.extra_urls.hash(&mut s);
298+
s.finish()
299+
}
300+
301+
fn __bool__(&self) -> bool {
302+
true // an empty string is not a valid URL
303+
}
304+
305+
pub fn __deepcopy__(&self, py: Python, _memo: &PyDict) -> Py<PyAny> {
306+
self.clone().into_py(py)
307+
}
253308
}
254309

255310
fn host_to_dict<'a>(py: Python<'a>, lib_url: &Url) -> PyResult<&'a PyDict> {

tests/validators/test_url.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
2-
from typing import Optional, Union
2+
from copy import deepcopy
3+
from typing import Dict, Optional, Union
34

45
import pytest
56
from dirty_equals import HasRepr, IsInstance
@@ -1140,3 +1141,75 @@ def test_url_vulnerabilities(url_validator, url, expected):
11401141
else:
11411142
output_parts[key] = getattr(output_url, key)
11421143
assert output_parts == expected
1144+
1145+
1146+
def test_multi_host_url_comparison() -> None:
1147+
assert MultiHostUrl('http://example.com,www.example.com') == MultiHostUrl('http://example.com,www.example.com')
1148+
assert MultiHostUrl('http://example.com,www.example.com') == MultiHostUrl('http://example.com,www.example.com/')
1149+
assert MultiHostUrl('http://example.com,www.example.com') != MultiHostUrl('http://example.com,www.example.com/123')
1150+
assert MultiHostUrl('http://example.com,www.example.com/123') > MultiHostUrl('http://example.com,www.example.com')
1151+
assert MultiHostUrl('http://example.com,www.example.com/123') >= MultiHostUrl('http://example.com,www.example.com')
1152+
assert MultiHostUrl('http://example.com,www.example.com') >= MultiHostUrl('http://example.com,www.example.com')
1153+
assert MultiHostUrl('http://example.com,www.example.com') < MultiHostUrl('http://example.com,www.example.com/123')
1154+
assert MultiHostUrl('http://example.com,www.example.com') <= MultiHostUrl('http://example.com,www.example.com/123')
1155+
assert MultiHostUrl('http://example.com,www.example.com') <= MultiHostUrl('http://example.com')
1156+
1157+
1158+
def test_multi_host_url_bool() -> None:
1159+
assert bool(MultiHostUrl('http://example.com,www.example.com')) is True
1160+
1161+
1162+
def test_multi_host_url_hash() -> None:
1163+
data: Dict[MultiHostUrl, int] = {}
1164+
1165+
data[MultiHostUrl('http://example.com,www.example.com')] = 1
1166+
assert data == {MultiHostUrl('http://example.com,www.example.com/'): 1}
1167+
1168+
data[MultiHostUrl('http://example.com,www.example.com/123')] = 2
1169+
assert data == {
1170+
MultiHostUrl('http://example.com,www.example.com/'): 1,
1171+
MultiHostUrl('http://example.com,www.example.com/123'): 2,
1172+
}
1173+
1174+
data[MultiHostUrl('http://example.com,www.example.com')] = 3
1175+
assert data == {
1176+
MultiHostUrl('http://example.com,www.example.com/'): 3,
1177+
MultiHostUrl('http://example.com,www.example.com/123'): 2,
1178+
}
1179+
1180+
1181+
def test_multi_host_url_deepcopy() -> None:
1182+
assert deepcopy(MultiHostUrl('http://example.com')) == MultiHostUrl('http://example.com/')
1183+
1184+
1185+
def test_url_comparison() -> None:
1186+
assert Url('http://example.com') == Url('http://example.com')
1187+
assert Url('http://example.com') == Url('http://example.com/')
1188+
assert Url('http://example.com') != Url('http://example.com/123')
1189+
assert Url('http://example.com/123') > Url('http://example.com')
1190+
assert Url('http://example.com/123') >= Url('http://example.com')
1191+
assert Url('http://example.com') >= Url('http://example.com')
1192+
assert Url('http://example.com') < Url('http://example.com/123')
1193+
assert Url('http://example.com') <= Url('http://example.com/123')
1194+
assert Url('http://example.com') <= Url('http://example.com')
1195+
1196+
1197+
def test_url_bool() -> None:
1198+
assert bool(Url('http://example.com')) is True
1199+
1200+
1201+
def test_url_hash() -> None:
1202+
data: Dict[Url, int] = {}
1203+
1204+
data[Url('http://example.com')] = 1
1205+
assert data == {Url('http://example.com/'): 1}
1206+
1207+
data[Url('http://example.com/123')] = 2
1208+
assert data == {Url('http://example.com/'): 1, Url('http://example.com/123'): 2}
1209+
1210+
data[Url('http://example.com')] = 3
1211+
assert data == {Url('http://example.com/'): 3, Url('http://example.com/123'): 2}
1212+
1213+
1214+
def test_url_deepcopy() -> None:
1215+
assert deepcopy(Url('http://example.com')) == Url('http://example.com/')

0 commit comments

Comments
 (0)