@@ -4,7 +4,7 @@ use std::str::Chars;
44
55use pyo3:: intern;
66use pyo3:: prelude:: * ;
7- use pyo3:: types:: { PyDict , PyList } ;
7+ use pyo3:: types:: { PyDict , PyList , PyType } ;
88
99use ahash:: AHashSet ;
1010use url:: { ParseError , SyntaxViolation , Url } ;
@@ -26,6 +26,7 @@ type AllowedSchemas = Option<(AHashSet<String>, String)>;
2626#[ derive( Debug , Clone ) ]
2727pub struct UrlValidator {
2828 strict : bool ,
29+ cls : Option < Py < PyType > > ,
2930 max_length : Option < usize > ,
3031 allowed_schemes : AllowedSchemas ,
3132 host_required : bool ,
@@ -47,6 +48,7 @@ impl BuildValidator for UrlValidator {
4748
4849 Ok ( Self {
4950 strict : is_strict ( schema, config) ?,
51+ cls : schema. get_as ( intern ! ( schema. py( ) , "cls" ) ) ?,
5052 max_length : schema. get_as ( intern ! ( schema. py( ) , "max_length" ) ) ?,
5153 host_required : schema. get_as ( intern ! ( schema. py( ) , "host_required" ) ) ?. unwrap_or ( false ) ,
5254 default_host : schema. get_as ( intern ! ( schema. py( ) , "default_host" ) ) ?,
@@ -59,7 +61,7 @@ impl BuildValidator for UrlValidator {
5961 }
6062}
6163
62- impl_py_gc_traverse ! ( UrlValidator { } ) ;
64+ impl_py_gc_traverse ! ( UrlValidator { cls } ) ;
6365
6466impl Validator for UrlValidator {
6567 fn validate < ' py > (
@@ -93,7 +95,31 @@ impl Validator for UrlValidator {
9395 Ok ( ( ) ) => {
9496 // Lax rather than strict to preserve V2.4 semantic that str wins over url in union
9597 state. floor_exactness ( Exactness :: Lax ) ;
96- Ok ( either_url. into_py ( py) )
98+
99+ if let Some ( url_subclass) = & self . cls {
100+ // TODO: we do an extra build for a subclass here, we should avoid this
101+ // in v2.11 for perf reasons, but this is a worthwhile patch for now
102+ // given that we want isinstance to work properly for subclasses of Url
103+ let py_url = match either_url {
104+ EitherUrl :: Py ( py_url) => py_url. get ( ) . clone ( ) ,
105+ EitherUrl :: Rust ( rust_url) => PyUrl :: new ( rust_url) ,
106+ } ;
107+
108+ let py_url = PyUrl :: build (
109+ url_subclass. bind ( py) ,
110+ py_url. scheme ( ) ,
111+ py_url. host ( ) ,
112+ py_url. username ( ) ,
113+ py_url. password ( ) ,
114+ py_url. port ( ) ,
115+ py_url. path ( ) . filter ( |path| * path != "/" ) ,
116+ py_url. query ( ) ,
117+ py_url. fragment ( ) ,
118+ ) ?;
119+ Ok ( py_url. into_py ( py) )
120+ } else {
121+ Ok ( either_url. into_py ( py) )
122+ }
97123 }
98124 Err ( error_type) => Err ( ValError :: new ( error_type, input) ) ,
99125 }
@@ -186,6 +212,7 @@ impl CopyFromPyUrl for EitherUrl<'_> {
186212#[ derive( Debug , Clone ) ]
187213pub struct MultiHostUrlValidator {
188214 strict : bool ,
215+ cls : Option < Py < PyType > > ,
189216 max_length : Option < usize > ,
190217 allowed_schemes : AllowedSchemas ,
191218 host_required : bool ,
@@ -213,6 +240,7 @@ impl BuildValidator for MultiHostUrlValidator {
213240 }
214241 Ok ( Self {
215242 strict : is_strict ( schema, config) ?,
243+ cls : schema. get_as ( intern ! ( schema. py( ) , "cls" ) ) ?,
216244 max_length : schema. get_as ( intern ! ( schema. py( ) , "max_length" ) ) ?,
217245 allowed_schemes,
218246 host_required : schema. get_as ( intern ! ( schema. py( ) , "host_required" ) ) ?. unwrap_or ( false ) ,
@@ -225,7 +253,7 @@ impl BuildValidator for MultiHostUrlValidator {
225253 }
226254}
227255
228- impl_py_gc_traverse ! ( MultiHostUrlValidator { } ) ;
256+ impl_py_gc_traverse ! ( MultiHostUrlValidator { cls } ) ;
229257
230258impl Validator for MultiHostUrlValidator {
231259 fn validate < ' py > (
@@ -258,7 +286,38 @@ impl Validator for MultiHostUrlValidator {
258286 Ok ( ( ) ) => {
259287 // Lax rather than strict to preserve V2.4 semantic that str wins over url in union
260288 state. floor_exactness ( Exactness :: Lax ) ;
261- Ok ( multi_url. into_py ( py) )
289+
290+ if let Some ( url_subclass) = & self . cls {
291+ // TODO: we do an extra build for a subclass here, we should avoid this
292+ // in v2.11 for perf reasons, but this is a worthwhile patch for now
293+ // given that we want isinstance to work properly for subclasses of Url
294+ let py_url = match multi_url {
295+ EitherMultiHostUrl :: Py ( py_url) => py_url. get ( ) . clone ( ) ,
296+ EitherMultiHostUrl :: Rust ( rust_url) => rust_url,
297+ } ;
298+
299+ let hosts = py_url
300+ . hosts ( py) ?
301+ . into_iter ( )
302+ . map ( |host| host. extract ( ) . expect ( "host should be a valid UrlHostParts" ) )
303+ . collect ( ) ;
304+
305+ let py_url = PyMultiHostUrl :: build (
306+ url_subclass. bind ( py) ,
307+ py_url. scheme ( ) ,
308+ Some ( hosts) ,
309+ py_url. path ( ) . filter ( |path| * path != "/" ) ,
310+ py_url. query ( ) ,
311+ py_url. fragment ( ) ,
312+ None ,
313+ None ,
314+ None ,
315+ None ,
316+ ) ?;
317+ Ok ( py_url. into_py ( py) )
318+ } else {
319+ Ok ( multi_url. into_py ( py) )
320+ }
262321 }
263322 Err ( error_type) => Err ( ValError :: new ( error_type, input) ) ,
264323 }
0 commit comments