Skip to content

Commit

Permalink
Use serde-untagged to improve some untagged enum error messages (as…
Browse files Browse the repository at this point in the history
…tral-sh#7822)

## Summary

This is related to astral-sh#7817, but
doesn't close it.
  • Loading branch information
charliermarsh authored Sep 30, 2024
1 parent 67769a4 commit b6de417
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 44 deletions.
34 changes: 32 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ same-file = { version = "1.0.6" }
schemars = { version = "0.8.21", features = ["url"] }
seahash = { version = "4.1.0" }
serde = { version = "1.0.210", features = ["derive"] }
serde-untagged = { version = "0.1.6" }
serde_json = { version = "1.0.128" }
sha2 = { version = "0.10.8" }
smallvec = { version = "1.13.2" }
Expand Down
1 change: 1 addition & 0 deletions crates/pypi-types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ mailparse = { workspace = true }
regex = { workspace = true }
rkyv = { workspace = true }
serde = { workspace = true }
serde-untagged = { workspace = true }
thiserror = { workspace = true }
toml = { workspace = true }
toml_edit = { workspace = true }
Expand Down
43 changes: 27 additions & 16 deletions crates/pypi-types/src/simple_json.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::str::FromStr;

use jiff::Timestamp;
use serde::{Deserialize, Deserializer, Serialize};

use pep440_rs::{VersionSpecifiers, VersionSpecifiersParseError};
use serde::{Deserialize, Deserializer, Serialize};

use crate::lenient_requirement::LenientVersionSpecifiers;

Expand Down Expand Up @@ -71,13 +70,24 @@ where
))
}

#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
#[derive(Debug, Clone)]
pub enum CoreMetadata {
Bool(bool),
Hashes(Hashes),
}

impl<'de> Deserialize<'de> for CoreMetadata {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
serde_untagged::UntaggedEnumVisitor::new()
.bool(|bool| Ok(CoreMetadata::Bool(bool)))
.map(|map| map.deserialize().map(CoreMetadata::Hashes))
.deserialize(deserializer)
}
}

impl CoreMetadata {
pub fn is_available(&self) -> bool {
match self {
Expand All @@ -87,24 +97,25 @@ impl CoreMetadata {
}
}

#[derive(
Debug,
Clone,
PartialEq,
Eq,
Hash,
Deserialize,
rkyv::Archive,
rkyv::Deserialize,
rkyv::Serialize,
)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)]
#[rkyv(derive(Debug))]
#[serde(untagged)]
pub enum Yanked {
Bool(bool),
Reason(String),
}

impl<'de> Deserialize<'de> for Yanked {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
serde_untagged::UntaggedEnumVisitor::new()
.bool(|bool| Ok(Yanked::Bool(bool)))
.string(|string| Ok(Yanked::Reason(string.to_owned())))
.deserialize(deserializer)
}
}

impl Yanked {
pub fn is_yanked(&self) -> bool {
match self {
Expand Down
1 change: 1 addition & 0 deletions crates/uv-configuration/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ fs-err = { workspace = true }
rustc-hash = { workspace = true }
schemars = { workspace = true, optional = true }
serde = { workspace = true }
serde-untagged = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
Expand Down
40 changes: 20 additions & 20 deletions crates/uv-configuration/src/trusted_host.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use serde::{Deserialize, Deserializer};
use std::str::FromStr;

use url::Url;

/// A trusted host, which could be a host or a host-port pair.
Expand Down Expand Up @@ -33,28 +33,28 @@ impl TrustedHost {
}
}

#[derive(serde::Deserialize)]
#[serde(untagged)]
enum TrustHostWire {
String(String),
Struct {
scheme: Option<String>,
host: String,
port: Option<u16>,
},
}

impl<'de> serde::de::Deserialize<'de> for TrustedHost {
fn deserialize<D>(deserializer: D) -> Result<TrustedHost, D::Error>
impl<'de> Deserialize<'de> for TrustedHost {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
D: Deserializer<'de>,
{
let helper = TrustHostWire::deserialize(deserializer)?;

match helper {
TrustHostWire::String(s) => TrustedHost::from_str(&s).map_err(serde::de::Error::custom),
TrustHostWire::Struct { scheme, host, port } => Ok(TrustedHost { scheme, host, port }),
#[derive(Deserialize)]
struct Inner {
scheme: Option<String>,
host: String,
port: Option<u16>,
}

serde_untagged::UntaggedEnumVisitor::new()
.string(|string| TrustedHost::from_str(string).map_err(serde::de::Error::custom))
.map(|map| {
map.deserialize::<Inner>().map(|inner| TrustedHost {
scheme: inner.scheme,
host: inner.host,
port: inner.port,
})
})
.deserialize(deserializer)
}
}

Expand Down
29 changes: 26 additions & 3 deletions crates/uv-distribution/src/metadata/requires_dist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,29 @@ mod test {
"###);
}

#[tokio::test]
async fn wrong_type() {
let input = indoc! {r#"
[project]
name = "foo"
version = "0.0.0"
dependencies = [
"tqdm",
]
[tool.uv.sources]
tqdm = true
"#};

assert_snapshot!(format_err(input).await, @r###"
error: TOML parse error at line 8, column 8
|
8 | tqdm = true
| ^^^^
invalid type: boolean `true`, expected an array or map
"###);
}

#[tokio::test]
async fn too_many_git_specs() {
let input = indoc! {r#"
Expand Down Expand Up @@ -264,7 +287,7 @@ mod test {
|
8 | tqdm = { git = "https://github.com/tqdm/tqdm", ref = "baaaaaab" }
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
data did not match any variant of untagged enum SourcesWire
data did not match any variant of untagged enum Source
"###);
}
Expand All @@ -288,7 +311,7 @@ mod test {
|
8 | tqdm = { path = "tqdm", index = "torch" }
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
data did not match any variant of untagged enum SourcesWire
data did not match any variant of untagged enum Source
"###);
}
Expand Down Expand Up @@ -348,7 +371,7 @@ mod test {
|
8 | tqdm = { url = "§invalid#+#*Ä" }
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
data did not match any variant of untagged enum SourcesWire
data did not match any variant of untagged enum Source
"###);
}
Expand Down
1 change: 1 addition & 0 deletions crates/uv-workspace/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ rustc-hash = { workspace = true }
same-file = { workspace = true }
schemars = { workspace = true, optional = true }
serde = { workspace = true, features = ["derive"] }
serde-untagged = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
toml = { workspace = true }
Expand Down
17 changes: 14 additions & 3 deletions crates/uv-workspace/src/pyproject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,15 +444,26 @@ impl IntoIterator for Sources {
}
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "kebab-case", untagged)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema), schemars(untagged))]
#[allow(clippy::large_enum_variant)]
enum SourcesWire {
One(Source),
Many(Vec<Source>),
}

impl<'de> serde::de::Deserialize<'de> for SourcesWire {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
serde_untagged::UntaggedEnumVisitor::new()
.map(|map| map.deserialize().map(SourcesWire::One))
.seq(|seq| seq.deserialize().map(SourcesWire::Many))
.deserialize(deserializer)
}
}

impl TryFrom<SourcesWire> for Sources {
type Error = SourceError;

Expand Down

0 comments on commit b6de417

Please sign in to comment.