Skip to content

Commit 309207a

Browse files
committed
Adds support for zero-copy FromSql derive
Similar in spirit and implementation to how `serde` supports borrowed deserialization lifetimes. Like `serde`, it uses `'de` as the lifetime of the trait (`FromoSql<'de>`) and it adds all borrowed liftimes as lifetime bounds on `'de`. Also like `serde` all top level references automatically have their lifetimes borrowed, but container types carrying references must explicity be borrowed with the `#[postgres(borrow)]` annotation.
1 parent c5ff8cf commit 309207a

File tree

8 files changed

+164
-27
lines changed

8 files changed

+164
-27
lines changed

postgres-derive-test/src/composites.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,3 +346,32 @@ fn generics() {
346346
},
347347
);
348348
}
349+
350+
#[test]
351+
fn struct_with_borrowed_fields() {
352+
#[derive(FromSql, ToSql, Debug, PartialEq)]
353+
#[postgres(name = "item")]
354+
struct Item<'a, 'b: 'a> {
355+
name: &'a str,
356+
data: &'b [u8],
357+
}
358+
359+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
360+
conn.batch_execute(
361+
"CREATE TYPE pg_temp.item AS (
362+
name TEXT,
363+
data BYTEA
364+
);",
365+
)
366+
.unwrap();
367+
368+
let item = Item {
369+
name: "foobar",
370+
data: b"12345",
371+
};
372+
373+
let row = conn.query_one("SELECT $1::item", &[&item]).unwrap();
374+
let result: Item<'_, '_> = row.get(0);
375+
assert_eq!(item.name, result.name);
376+
assert_eq!(item.data, result.data);
377+
}

postgres-derive-test/src/transparent.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,38 @@ fn round_trip() {
1616
UserId(123)
1717
);
1818
}
19+
20+
#[test]
21+
fn struct_with_reference() {
22+
#[derive(FromSql, ToSql, Debug, PartialEq)]
23+
#[postgres(transparent)]
24+
struct UserName<'a>(&'a str);
25+
26+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
27+
28+
let user_name = "tester";
29+
let row = conn
30+
.query_one("SELECT $1", &[&UserName(user_name)])
31+
.unwrap();
32+
let result: UserName<'_> = row.get(0);
33+
assert_eq!(user_name, result.0);
34+
}
35+
36+
#[test]
37+
fn nested_struct_with_reference() {
38+
#[derive(FromSql, ToSql, Debug, PartialEq)]
39+
#[postgres(transparent)]
40+
struct Inner<'a>(&'a str);
41+
42+
#[derive(FromSql, ToSql, Debug, PartialEq)]
43+
#[postgres(transparent)]
44+
struct UserName<'a>(#[postgres(borrow)] Inner<'a>);
45+
46+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
47+
48+
let user_name = "tester";
49+
let inner = Inner(user_name);
50+
let row = conn.query_one("SELECT $1", &[&UserName(inner)]).unwrap();
51+
let result: UserName<'_> = row.get(0);
52+
assert_eq!(user_name, result.0 .0);
53+
}

postgres-derive/src/accepts.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use quote::quote;
33
use std::iter;
44
use syn::Ident;
55

6-
use crate::composites::Field;
6+
use crate::composites::NamedField;
77
use crate::enums::Variant;
88

99
pub fn transparent_body(field: &syn::Field) -> TokenStream {
@@ -66,7 +66,7 @@ pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> Toke
6666
}
6767
}
6868

69-
pub fn composite_body(name: &str, trait_: &str, fields: &[Field]) -> TokenStream {
69+
pub fn composite_body(name: &str, trait_: &str, fields: &[NamedField]) -> TokenStream {
7070
let num_fields = fields.len();
7171
let trait_ = Ident::new(trait_, Span::call_site());
7272
let traits = iter::repeat(&trait_);

postgres-derive/src/composites.rs

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
11
use proc_macro2::Span;
2+
use std::collections::HashSet;
23
use syn::{
3-
punctuated::Punctuated, Error, GenericParam, Generics, Ident, Path, PathSegment, Type,
4-
TypeParamBound,
4+
punctuated::Punctuated, AngleBracketedGenericArguments, Error, GenericArgument, GenericParam,
5+
Generics, Ident, Lifetime, Path, PathArguments, PathSegment, Type, TypeParamBound,
56
};
67

78
use crate::{case::RenameRule, overrides::Overrides};
89

9-
pub struct Field {
10+
pub struct NamedField {
1011
pub name: String,
1112
pub ident: Ident,
1213
pub type_: Type,
14+
pub borrowed_lifetimes: HashSet<Lifetime>,
1315
}
1416

15-
impl Field {
16-
pub fn parse(raw: &syn::Field, rename_all: Option<RenameRule>) -> Result<Field, Error> {
17+
impl NamedField {
18+
pub fn parse(raw: &syn::Field, rename_all: Option<RenameRule>) -> Result<NamedField, Error> {
1719
let overrides = Overrides::extract(&raw.attrs, false)?;
1820
let ident = raw.ident.as_ref().unwrap().clone();
1921

20-
// field level name override takes precendence over container level rename_all override
22+
let borrowed_lifetimes = extract_borrowed_lifetimes(raw, &overrides);
23+
24+
// field level name override takes precedence over container level rename_all override
2125
let name = match overrides.name {
2226
Some(n) => n,
2327
None => {
@@ -31,14 +35,60 @@ impl Field {
3135
}
3236
};
3337

34-
Ok(Field {
38+
Ok(NamedField {
3539
name,
3640
ident,
3741
type_: raw.ty.clone(),
42+
borrowed_lifetimes,
3843
})
3944
}
4045
}
4146

47+
pub struct UnnamedField {
48+
pub borrowed_lifetimes: HashSet<Lifetime>,
49+
}
50+
51+
impl UnnamedField {
52+
pub fn parse(raw: &syn::Field) -> Result<UnnamedField, Error> {
53+
let overrides = Overrides::extract(&raw.attrs, false)?;
54+
let borrowed_lifetimes = extract_borrowed_lifetimes(raw, &overrides);
55+
Ok(UnnamedField { borrowed_lifetimes })
56+
}
57+
}
58+
59+
pub(crate) fn extract_borrowed_lifetimes(
60+
raw: &syn::Field,
61+
overrides: &Overrides,
62+
) -> HashSet<Lifetime> {
63+
let mut borrowed_lifetimes = HashSet::new();
64+
65+
// If the field is a reference, it's lifetime should be implicitly borrowed. Serde does
66+
// the same thing
67+
if let Type::Reference(ref_type) = &raw.ty {
68+
borrowed_lifetimes.insert(ref_type.lifetime.to_owned().unwrap());
69+
}
70+
71+
// Borrow all generic lifetimes of fields marked with #[postgres(borrow)]
72+
if overrides.borrows {
73+
if let Type::Path(type_path) = &raw.ty {
74+
for segment in &type_path.path.segments {
75+
if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
76+
args, ..
77+
}) = &segment.arguments
78+
{
79+
let lifetimes = args.iter().filter_map(|a| match a {
80+
GenericArgument::Lifetime(lifetime) => Some(lifetime.to_owned()),
81+
_ => None,
82+
});
83+
borrowed_lifetimes.extend(lifetimes);
84+
}
85+
}
86+
}
87+
}
88+
89+
borrowed_lifetimes
90+
}
91+
4292
pub(crate) fn append_generic_bound(mut generics: Generics, bound: &TypeParamBound) -> Generics {
4393
for param in &mut generics.params {
4494
if let GenericParam::Type(param) = param {

postgres-derive/src/fromsql.rs

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
use std::collections::{BTreeSet, HashSet};
12
use proc_macro2::{Span, TokenStream};
23
use quote::{format_ident, quote};
34
use std::iter;
5+
use std::iter::FromIterator;
46
use syn::{
57
punctuated::Punctuated, token, AngleBracketedGenericArguments, Data, DataStruct, DeriveInput,
68
Error, Fields, GenericArgument, GenericParam, Generics, Ident, Lifetime, PathArguments,
@@ -9,8 +11,8 @@ use syn::{
911
use syn::{LifetimeParam, TraitBound, TraitBoundModifier, TypeParamBound};
1012

1113
use crate::accepts;
12-
use crate::composites::Field;
1314
use crate::composites::{append_generic_bound, new_derive_path};
15+
use crate::composites::{NamedField, UnnamedField};
1416
use crate::enums::Variant;
1517
use crate::overrides::Overrides;
1618

@@ -29,16 +31,18 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
2931
.clone()
3032
.unwrap_or_else(|| input.ident.to_string());
3133

32-
let (accepts_body, to_sql_body) = if overrides.transparent {
34+
let (accepts_body, to_sql_body, borrowed_lifetimes) = if overrides.transparent {
3335
match input.data {
3436
Data::Struct(DataStruct {
3537
fields: Fields::Unnamed(ref fields),
3638
..
3739
}) if fields.unnamed.len() == 1 => {
3840
let field = fields.unnamed.first().unwrap();
41+
let parsed_field = UnnamedField::parse(field)?;
3942
(
4043
accepts::transparent_body(field),
4144
transparent_body(&input.ident, field),
45+
parsed_field.borrowed_lifetimes,
4246
)
4347
}
4448
_ => {
@@ -59,6 +63,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
5963
(
6064
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
6165
enum_body(&input.ident, &variants),
66+
HashSet::new(),
6267
)
6368
}
6469
_ => {
@@ -79,16 +84,19 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
7984
(
8085
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
8186
enum_body(&input.ident, &variants),
87+
HashSet::new(),
8288
)
8389
}
8490
Data::Struct(DataStruct {
8591
fields: Fields::Unnamed(ref fields),
8692
..
8793
}) if fields.unnamed.len() == 1 => {
8894
let field = fields.unnamed.first().unwrap();
95+
let parsed_field = UnnamedField::parse(field)?;
8996
(
9097
domain_accepts_body(&name, field),
9198
domain_body(&input.ident, field),
99+
parsed_field.borrowed_lifetimes,
92100
)
93101
}
94102
Data::Struct(DataStruct {
@@ -98,11 +106,16 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
98106
let fields = fields
99107
.named
100108
.iter()
101-
.map(|field| Field::parse(field, overrides.rename_all))
109+
.map(|field| NamedField::parse(field, overrides.rename_all))
102110
.collect::<Result<Vec<_>, _>>()?;
111+
let borrowed_lifetimes: HashSet<_> = fields
112+
.iter()
113+
.flat_map(|f| f.borrowed_lifetimes.to_owned())
114+
.collect();
103115
(
104116
accepts::composite_body(&name, "FromSql", &fields),
105117
composite_body(&input.ident, &fields),
118+
borrowed_lifetimes
106119
)
107120
}
108121
_ => {
@@ -115,7 +128,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
115128
};
116129

117130
let ident = &input.ident;
118-
let (generics, lifetime) = build_generics(&input.generics);
131+
let (generics, lifetime) = build_generics(&input.generics, borrowed_lifetimes);
119132
let (impl_generics, _, _) = generics.split_for_impl();
120133
let (_, ty_generics, where_clause) = input.generics.split_for_impl();
121134
let out = quote! {
@@ -183,7 +196,7 @@ fn domain_body(ident: &Ident, field: &syn::Field) -> TokenStream {
183196
}
184197
}
185198

186-
fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
199+
fn composite_body(ident: &Ident, fields: &[NamedField]) -> TokenStream {
187200
let temp_vars = &fields
188201
.iter()
189202
.map(|f| format_ident!("__{}", f.ident))
@@ -233,16 +246,15 @@ fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
233246
}
234247
}
235248

236-
fn build_generics(source: &Generics) -> (Generics, Lifetime) {
237-
// don't worry about lifetime name collisions, it doesn't make sense to derive FromSql on a struct with a lifetime
238-
let lifetime = Lifetime::new("'a", Span::call_site());
239-
249+
fn build_generics(source: &Generics, borrowed_lifetimes: HashSet<Lifetime>) -> (Generics, Lifetime) {
250+
// This is the same parent lifetime name serde uses
251+
let lifetime = Lifetime::new("'de", Span::call_site());
252+
// Sort lifetimes for deterministic code-gen
253+
let sorted_lifetimes = BTreeSet::from_iter(borrowed_lifetimes);
254+
let mut lifetime_param = LifetimeParam::new(lifetime.to_owned());
255+
lifetime_param.bounds.extend(sorted_lifetimes);
240256
let mut out = append_generic_bound(source.to_owned(), &new_fromsql_bound(&lifetime));
241-
out.params.insert(
242-
0,
243-
GenericParam::Lifetime(LifetimeParam::new(lifetime.to_owned())),
244-
);
245-
257+
out.params.insert(0, GenericParam::Lifetime(lifetime_param));
246258
(out, lifetime)
247259
}
248260

postgres-derive/src/overrides.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pub struct Overrides {
88
pub rename_all: Option<RenameRule>,
99
pub transparent: bool,
1010
pub allow_mismatch: bool,
11+
pub borrows: bool,
1112
}
1213

1314
impl Overrides {
@@ -17,6 +18,7 @@ impl Overrides {
1718
rename_all: None,
1819
transparent: false,
1920
allow_mismatch: false,
21+
borrows: false,
2022
};
2123

2224
for attr in attrs {
@@ -92,6 +94,14 @@ impl Overrides {
9294
));
9395
}
9496
overrides.allow_mismatch = true;
97+
} else if path.is_ident("borrow") {
98+
if container_attr {
99+
return Err(Error::new_spanned(
100+
path,
101+
"#[postgres(borrow)] is a field attribute",
102+
));
103+
}
104+
overrides.borrows = true;
95105
} else {
96106
return Err(Error::new_spanned(path, "unknown override"));
97107
}

postgres-derive/src/tosql.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use syn::{
77
};
88

99
use crate::accepts;
10-
use crate::composites::Field;
10+
use crate::composites::NamedField;
1111
use crate::composites::{append_generic_bound, new_derive_path};
1212
use crate::enums::Variant;
1313
use crate::overrides::Overrides;
@@ -92,7 +92,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
9292
let fields = fields
9393
.named
9494
.iter()
95-
.map(|field| Field::parse(field, overrides.rename_all))
95+
.map(|field| NamedField::parse(field, overrides.rename_all))
9696
.collect::<Result<Vec<_>, _>>()?;
9797
(
9898
accepts::composite_body(&name, "ToSql", &fields),
@@ -168,7 +168,7 @@ fn domain_body() -> TokenStream {
168168
}
169169
}
170170

171-
fn composite_body(fields: &[Field]) -> TokenStream {
171+
fn composite_body(fields: &[NamedField]) -> TokenStream {
172172
let field_names = fields.iter().map(|f| &f.name);
173173
let field_idents = fields.iter().map(|f| &f.ident);
174174

postgres-types/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ with-time-0_3 = ["time-03"]
3131
bytes = "1.0"
3232
fallible-iterator = "0.2"
3333
postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" }
34-
postgres-derive = { version = "0.4.5", optional = true, path = "../postgres-derive" }
34+
#postgres-derive = { version = "0.4.5", optional = true, path = "../postgres-derive" }
35+
postgres-derive = { optional = true, path = "../postgres-derive" }
3536

3637
array-init = { version = "2", optional = true }
3738
bit-vec-06 = { version = "0.6", package = "bit-vec", optional = true }

0 commit comments

Comments
 (0)