Skip to content

Commit 096215e

Browse files
authored
Merge pull request #1026 from google/unsafe-dialect
Unsafe dialect
2 parents 6227d33 + 1f37faf commit 096215e

File tree

3 files changed

+190
-62
lines changed

3 files changed

+190
-62
lines changed

engine/src/conversion/codegen_rs/fun_codegen.rs

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use syn::{
2121
use super::{
2222
function_wrapper_rs::RustParamConversion,
2323
unqualify::{unqualify_params, unqualify_ret_type},
24-
ImplBlockDetails, RsCodegenResult, TraitImplBlockDetails, Use,
24+
ImplBlockDetails, MaybeUnsafeStmt, RsCodegenResult, TraitImplBlockDetails, Use,
2525
};
2626
use crate::{
2727
conversion::{
@@ -30,6 +30,7 @@ use crate::{
3030
TraitMethodDetails,
3131
},
3232
api::UnsafetyNeeded,
33+
codegen_rs::maybe_unsafes_to_tokens,
3334
},
3435
types::{Namespace, QualifiedName},
3536
};
@@ -253,24 +254,21 @@ impl<'a> FnGenerator<'a> {
253254
let mut local_variables = Vec::new();
254255
let mut arg_list = Vec::new();
255256
let mut ptr_arg_name = None;
256-
let wrap_unsafe_calls = self.should_wrap_unsafe_calls();
257257
let ret_type = Cow::Borrowed(ret_type);
258258
for pd in self.param_details {
259259
let wrapper_arg_name = if pd.self_type.is_some() && !avoid_self {
260260
parse_quote!(self)
261261
} else {
262262
pd.name.clone()
263263
};
264-
let rust_for_param = pd
265-
.conversion
266-
.rust_conversion(wrapper_arg_name.clone(), wrap_unsafe_calls);
264+
let rust_for_param = pd.conversion.rust_conversion(wrapper_arg_name.clone());
267265
let RustParamConversion {
268266
ty,
269267
conversion,
270-
local_variables: these_local_variables,
268+
local_variables: mut these_local_variables,
271269
} = rust_for_param;
272270
arg_list.push(conversion.clone());
273-
local_variables.extend(these_local_variables.into_iter());
271+
local_variables.append(&mut these_local_variables);
274272
if pd.is_placement_return_destination {
275273
ptr_arg_name = Some(conversion);
276274
} else {
@@ -280,7 +278,6 @@ impl<'a> FnGenerator<'a> {
280278
));
281279
}
282280
}
283-
let local_variables = quote! { #(#local_variables);* };
284281
if let Some(parameter_reordering) = &parameter_reordering {
285282
wrapper_params = Self::reorder_parameters(wrapper_params, parameter_reordering);
286283
}
@@ -293,34 +290,34 @@ impl<'a> FnGenerator<'a> {
293290
);
294291

295292
let cxxbridge_name = self.cxxbridge_name;
296-
let call_body = quote! {
297-
cxxbridge::#cxxbridge_name ( #(#arg_list),* )
298-
};
299-
let call_body = if let Some(ptr_arg_name) = ptr_arg_name {
293+
let call_body = MaybeUnsafeStmt::maybe_unsafe(
300294
quote! {
301-
autocxx::moveit::new::by_raw(move |#ptr_arg_name| {
302-
let #ptr_arg_name = #ptr_arg_name.get_unchecked_mut().as_mut_ptr();
303-
#call_body
295+
cxxbridge::#cxxbridge_name ( #(#arg_list),* )
296+
},
297+
!matches!(self.unsafety, UnsafetyNeeded::None),
298+
);
299+
let call_stmts = if let Some(ptr_arg_name) = ptr_arg_name {
300+
let mut closure_stmts = local_variables;
301+
closure_stmts.push(MaybeUnsafeStmt::binary(
302+
quote! { let #ptr_arg_name = unsafe { #ptr_arg_name.get_unchecked_mut().as_mut_ptr() };},
303+
quote! { let #ptr_arg_name = #ptr_arg_name.get_unchecked_mut().as_mut_ptr();},
304+
));
305+
closure_stmts.push(call_body);
306+
let closure_stmts = maybe_unsafes_to_tokens(closure_stmts, true);
307+
vec![MaybeUnsafeStmt::needs_unsafe(parse_quote! {
308+
autocxx::moveit::new::by_raw(move |#ptr_arg_name| {
309+
#closure_stmts
304310
})
305-
}
311+
})]
306312
} else {
307-
quote! {
308-
#call_body
309-
}
310-
};
311-
let call_body = if self.should_wrap_unsafe_calls() {
312-
quote! {
313-
unsafe {
314-
#call_body
315-
}
316-
}
317-
} else {
318-
call_body
319-
};
320-
let call_body = quote! {
321-
#local_variables
322-
#call_body
313+
let mut call_stmts = local_variables;
314+
call_stmts.push(call_body);
315+
call_stmts
323316
};
317+
318+
let context_is_unsafe = matches!(self.unsafety, UnsafetyNeeded::Always)
319+
|| self.always_unsafe_due_to_trait_definition;
320+
let call_body = maybe_unsafes_to_tokens(call_stmts, context_is_unsafe);
324321
(lifetime_tokens, wrapper_params, ret_type, call_body)
325322
}
326323

@@ -413,11 +410,6 @@ impl<'a> FnGenerator<'a> {
413410
})
414411
}
415412

416-
fn should_wrap_unsafe_calls(&self) -> bool {
417-
matches!(self.unsafety, UnsafetyNeeded::JustBridge)
418-
|| self.always_unsafe_due_to_trait_definition
419-
}
420-
421413
fn reorder_parameters(
422414
params: Punctuated<FnArg, Comma>,
423415
parameter_ordering: &[usize],

engine/src/conversion/codegen_rs/function_wrapper_rs.rs

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,27 @@ use crate::{
1616
use quote::quote;
1717
use syn::parse_quote;
1818

19+
use super::MaybeUnsafeStmt;
20+
1921
/// Output Rust snippets for how to deal with a given parameter.
2022
pub(super) struct RustParamConversion {
2123
pub(super) ty: Type,
22-
pub(super) local_variables: Option<TokenStream>,
24+
pub(super) local_variables: Vec<MaybeUnsafeStmt>,
2325
pub(super) conversion: TokenStream,
2426
}
2527

2628
impl TypeConversionPolicy {
2729
/// If returns `None` then this parameter should be omitted entirely.
28-
pub(super) fn rust_conversion(&self, var: Pat, wrap_in_unsafe: bool) -> RustParamConversion {
30+
pub(super) fn rust_conversion(&self, var: Pat) -> RustParamConversion {
2931
match self.rust_conversion {
3032
RustConversionType::None => RustParamConversion {
3133
ty: self.converted_rust_type(),
32-
local_variables: None,
34+
local_variables: Vec::new(),
3335
conversion: quote! { #var },
3436
},
3537
RustConversionType::FromStr => RustParamConversion {
3638
ty: parse_quote! { impl ToCppString },
37-
local_variables: None,
39+
local_variables: Vec::new(),
3840
conversion: quote! ( #var .into_cpp() ),
3941
},
4042
RustConversionType::ToBoxedUpHolder(ref sub) => {
@@ -45,7 +47,7 @@ impl TypeConversionPolicy {
4547
};
4648
RustParamConversion {
4749
ty,
48-
local_variables: None,
50+
local_variables: Vec::new(),
4951
conversion: quote! {
5052
Box::new(#holder_type(#var))
5153
},
@@ -61,7 +63,7 @@ impl TypeConversionPolicy {
6163
};
6264
RustParamConversion {
6365
ty,
64-
local_variables: None,
66+
local_variables: Vec::new(),
6567
conversion: quote! {
6668
#var.get_unchecked_mut().as_mut_ptr()
6769
},
@@ -77,7 +79,7 @@ impl TypeConversionPolicy {
7779
};
7880
RustParamConversion {
7981
ty,
80-
local_variables: None,
82+
local_variables: Vec::new(),
8183
conversion: quote! {
8284
{ let r: &mut _ = ::std::pin::Pin::into_inner_unchecked(#var.as_mut());
8385
r
@@ -93,7 +95,7 @@ impl TypeConversionPolicy {
9395
let ty = parse_quote! { &mut #ty };
9496
RustParamConversion {
9597
ty,
96-
local_variables: None,
98+
local_variables: Vec::new(),
9799
conversion: quote! {
98100
#var
99101
},
@@ -106,29 +108,28 @@ impl TypeConversionPolicy {
106108
panic!("Unexpected non-ident parameter name");
107109
};
108110
let space_var_name = make_ident(format!("{}_space", var_name));
109-
let call = quote! { #space_var_name.as_mut().populate(#var_name); };
110-
let call = if wrap_in_unsafe {
111-
quote! {
112-
unsafe {
113-
#call
114-
}
115-
}
116-
} else {
117-
call
118-
};
119111
let ty = &self.unwrapped_type;
120112
let ty = parse_quote! { impl autocxx::ValueParam<#ty> };
121113
// This is the usual trick to put something on the stack, then
122114
// immediately shadow the variable name so it can't be accessed or moved.
123115
RustParamConversion {
124116
ty,
125-
local_variables: Some(quote! {
126-
let mut #space_var_name = autocxx::ValueParamHandler::default();
127-
let mut #space_var_name = unsafe {
128-
std::pin::Pin::new_unchecked(&mut #space_var_name)
129-
};
130-
#call
131-
}),
117+
local_variables: vec![
118+
MaybeUnsafeStmt::new(
119+
quote! { let mut #space_var_name = autocxx::ValueParamHandler::default(); },
120+
),
121+
MaybeUnsafeStmt::binary(
122+
quote! { let mut #space_var_name =
123+
unsafe { std::pin::Pin::new_unchecked(&mut #space_var_name) };
124+
},
125+
quote! { let mut #space_var_name =
126+
std::pin::Pin::new_unchecked(&mut #space_var_name);
127+
},
128+
),
129+
MaybeUnsafeStmt::needs_unsafe(
130+
quote! { #space_var_name.as_mut().populate(#var_name); },
131+
),
132+
],
132133
conversion: quote! {
133134
#space_var_name.get_ptr()
134135
},

engine/src/conversion/codegen_rs/mod.rs

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,3 +1206,138 @@ struct RsCodegenResult {
12061206
trait_impl_entry: Option<Box<TraitImplBlockDetails>>,
12071207
materializations: Vec<Use>,
12081208
}
1209+
1210+
/// An [`Item`] that always needs to be in an unsafe block.
1211+
#[derive(Clone)]
1212+
enum MaybeUnsafeStmt {
1213+
// This could almost be a syn::Stmt, but that doesn't quite work
1214+
// because the last stmt in a function is actually an expression
1215+
// thus lacking a semicolon.
1216+
Normal(TokenStream),
1217+
NeedsUnsafe(TokenStream),
1218+
Binary {
1219+
in_safe_context: TokenStream,
1220+
in_unsafe_context: TokenStream,
1221+
},
1222+
}
1223+
1224+
impl MaybeUnsafeStmt {
1225+
fn new(stmt: TokenStream) -> Self {
1226+
Self::Normal(stmt)
1227+
}
1228+
1229+
fn needs_unsafe(stmt: TokenStream) -> Self {
1230+
Self::NeedsUnsafe(stmt)
1231+
}
1232+
1233+
fn maybe_unsafe(stmt: TokenStream, needs_unsafe: bool) -> Self {
1234+
if needs_unsafe {
1235+
Self::NeedsUnsafe(stmt)
1236+
} else {
1237+
Self::Normal(stmt)
1238+
}
1239+
}
1240+
1241+
fn binary(in_safe_context: TokenStream, in_unsafe_context: TokenStream) -> Self {
1242+
Self::Binary {
1243+
in_safe_context,
1244+
in_unsafe_context,
1245+
}
1246+
}
1247+
}
1248+
1249+
fn maybe_unsafes_to_tokens(
1250+
items: Vec<MaybeUnsafeStmt>,
1251+
context_is_already_unsafe: bool,
1252+
) -> TokenStream {
1253+
if context_is_already_unsafe {
1254+
let items = items.into_iter().map(|item| match item {
1255+
MaybeUnsafeStmt::Normal(stmt)
1256+
| MaybeUnsafeStmt::NeedsUnsafe(stmt)
1257+
| MaybeUnsafeStmt::Binary {
1258+
in_unsafe_context: stmt,
1259+
..
1260+
} => stmt,
1261+
});
1262+
quote! {
1263+
#(#items)*
1264+
}
1265+
} else {
1266+
let mut currently_unsafe_list = None;
1267+
let mut output = Vec::new();
1268+
for item in items {
1269+
match item {
1270+
MaybeUnsafeStmt::NeedsUnsafe(stmt) => {
1271+
if currently_unsafe_list.is_none() {
1272+
currently_unsafe_list = Some(Vec::new());
1273+
}
1274+
currently_unsafe_list.as_mut().unwrap().push(stmt);
1275+
}
1276+
MaybeUnsafeStmt::Normal(stmt)
1277+
| MaybeUnsafeStmt::Binary {
1278+
in_safe_context: stmt,
1279+
..
1280+
} => {
1281+
if let Some(currently_unsafe_list) = currently_unsafe_list.take() {
1282+
output.push(quote! {
1283+
unsafe {
1284+
#(#currently_unsafe_list)*
1285+
}
1286+
})
1287+
}
1288+
output.push(stmt);
1289+
}
1290+
}
1291+
}
1292+
if let Some(currently_unsafe_list) = currently_unsafe_list.take() {
1293+
output.push(quote! {
1294+
unsafe {
1295+
#(#currently_unsafe_list)*
1296+
}
1297+
})
1298+
}
1299+
quote! {
1300+
#(#output)*
1301+
}
1302+
}
1303+
}
1304+
1305+
#[test]
1306+
fn test_maybe_unsafes_to_tokens() {
1307+
let items = vec![
1308+
MaybeUnsafeStmt::new(quote! { use A; }),
1309+
MaybeUnsafeStmt::new(quote! { use B; }),
1310+
MaybeUnsafeStmt::needs_unsafe(quote! { use C; }),
1311+
MaybeUnsafeStmt::needs_unsafe(quote! { use D; }),
1312+
MaybeUnsafeStmt::new(quote! { use E; }),
1313+
MaybeUnsafeStmt::needs_unsafe(quote! { use F; }),
1314+
];
1315+
assert_eq!(
1316+
maybe_unsafes_to_tokens(items.clone(), false).to_string(),
1317+
quote! {
1318+
use A;
1319+
use B;
1320+
unsafe {
1321+
use C;
1322+
use D;
1323+
}
1324+
use E;
1325+
unsafe {
1326+
use F;
1327+
}
1328+
}
1329+
.to_string()
1330+
);
1331+
assert_eq!(
1332+
maybe_unsafes_to_tokens(items, true).to_string(),
1333+
quote! {
1334+
use A;
1335+
use B;
1336+
use C;
1337+
use D;
1338+
use E;
1339+
use F;
1340+
}
1341+
.to_string()
1342+
);
1343+
}

0 commit comments

Comments
 (0)