Skip to content

Commit

Permalink
Improve codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
zshipko committed Aug 2, 2022
1 parent 08c773e commit 4e72040
Show file tree
Hide file tree
Showing 12 changed files with 107 additions and 62 deletions.
5 changes: 5 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# 0.1.1

- Fix codegen for entries that return no value
- Fix codegen for structs and enums

# 0.1.0

- Initial release
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "futhark-bindgen"
version = "0.1.0"
version = "0.1.1"
edition = "2021"
authors = ["Zach Shipko <[email protected]>"]
license = "ISC"
Expand Down
11 changes: 11 additions & 0 deletions src/generate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ pub(crate) fn first_uppercase(s: &str) -> String {
s
}

pub(crate) fn convert_struct_name(s: &str) -> &str {
s.strip_prefix("struct")
.unwrap()
.strip_suffix('*')
.unwrap()
.strip_prefix(|x: char| x.is_ascii_whitespace())
.unwrap()
.strip_suffix(|x: char| x.is_ascii_whitespace())
.unwrap()
}

/*pub(crate) fn first_lowercase(s: &str) -> String {
let mut s = s.to_string();
if let Some(r) = s.get_mut(0..1) {
Expand Down
67 changes: 48 additions & 19 deletions src/generate/ocaml.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::io::Write;

use crate::generate::first_uppercase;
use crate::generate::{convert_struct_name, first_uppercase};
use crate::*;

pub struct OCaml {
Expand All @@ -15,9 +15,9 @@ const OCAML_CTYPES_MAP: &[(&str, &str)] = &[
("i16", "int16_t"),
("u16", "uint16_t"),
("i32", "int32_t"),
("u32", "uint32_t"),
("u32", "int32_t"),
("i64", "int64_t"),
("u64", "uint64_t"),
("u64", "int64_t"),
("f16", ""),
("f32", "float"),
("f64", "double"),
Expand Down Expand Up @@ -178,7 +178,7 @@ impl Generate for OCaml {
self.ctypes_map.insert(name.clone(), ocaml_name.clone());
let elem_ptr = format!("ptr {ctypes_elemtype}");
generated_foreign_functions.push(format!(
" let {ocaml_name} = typedef (ptr void) \"array_{elemtype}_{rank}d\""
" let {ocaml_name} = typedef (ptr void) \"{ocaml_name}\""
));
let mut new_args = vec!["context", &elem_ptr];
new_args.resize(rank as usize + 2, "int64_t");
Expand Down Expand Up @@ -208,13 +208,26 @@ impl Generate for OCaml {
));
}
manifest::Type::Opaque(ty) => {
generated_foreign_functions
.push(format!(" let {name} = typedef (ptr void) \"{name}\""));
let futhark_name = convert_struct_name(&ty.ctype);
let mut ocaml_name = futhark_name
.strip_prefix("futhark_opaque_")
.unwrap()
.to_string();
if ocaml_name.chars().next().unwrap().is_numeric() || name.contains(' ') {
ocaml_name = format!("type_{ocaml_name}");
}

self.typemap
.insert(name.clone(), format!("{}.t", first_uppercase(&ocaml_name)));
self.ctypes_map.insert(name.to_string(), ocaml_name.clone());
generated_foreign_functions.push(format!(
" let {ocaml_name} = typedef (ptr void) \"{futhark_name}\""
));

let free_fn = &ty.ops.free;
generated_foreign_functions.push(format!(
" {}",
self.foreign_function(free_fn, "int", vec!["context", name])
self.foreign_function(free_fn, "int", vec!["context", &ocaml_name])
));

let record = match &ty.record {
Expand All @@ -223,7 +236,7 @@ impl Generate for OCaml {
};

let new_fn = &record.new;
let mut args = vec!["context".to_string(), format!("ptr {name}")];
let mut args = vec!["context".to_string(), format!("ptr {ocaml_name}")];
for f in record.fields.iter() {
let cty = self
.ctypes_map
Expand All @@ -247,7 +260,7 @@ impl Generate for OCaml {
self.foreign_function(
&f.project,
"int",
vec!["context", &format!("ptr {cty}"), name]
vec!["context", &format!("ptr {cty}"), &ocaml_name]
)
));
}
Expand Down Expand Up @@ -349,9 +362,17 @@ impl Generate for OCaml {
)?;
}
manifest::Type::Opaque(ty) => {
let module_name = first_uppercase(name);
let futhark_name = convert_struct_name(&ty.ctype);
let mut ocaml_name = futhark_name
.strip_prefix("futhark_opaque_")
.unwrap()
.to_string();
if ocaml_name.chars().next().unwrap().is_numeric() || name.contains(' ') {
ocaml_name = format!("type_{ocaml_name}");
}
let module_name = first_uppercase(&ocaml_name);
self.typemap
.insert(name.clone(), format!("{module_name}.t"));
.insert(ocaml_name.clone(), format!("{module_name}.t"));

let free_fn = &ty.ops.free;

Expand All @@ -362,7 +383,7 @@ impl Generate for OCaml {
config.output_file,
include_str!("templates/ocaml/opaque.ml"),
free_fn = free_fn,
name = name,
name = ocaml_name,
)?;
writeln!(mli_file, include_str!("templates/ocaml/opaque.mli"),)?;

Expand Down Expand Up @@ -414,12 +435,6 @@ impl Generate for OCaml {
let name = &f.name;
let project = &f.project;

let s = if type_is_array(&t) {
format!("Bindings.{t}")
} else {
t.clone()
};

let out = if type_is_opaque(&t) {
let call = t.replace(".t", ".of_ptr");
format!("{call} t.opaque_ctx !@out")
Expand All @@ -436,6 +451,14 @@ impl Generate for OCaml {
t.to_string()
};

let s = if type_is_array(&t) {
format!("Bindings.{t}")
} else if !t.ends_with(".t") {
self.get_ctype(&f.r#type)
} else {
t
};

writeln!(
config.output_file,
include_str!("templates/ocaml/record_project.ml"),
Expand Down Expand Up @@ -554,12 +577,18 @@ impl Generate for OCaml {
call_args = call_args.join(" "),
out_return = out_return.join(", ")
)?;

let return_type = if return_type.is_empty() {
"unit".to_string()
} else {
return_type.join(", ")
};
writeln!(
mli_file,
include_str!("templates/ocaml/entry.mli"),
name = name,
arg_types = arg_types.join(" -> "),
return_type = return_type.join(", "),
return_type = return_type,
)?;
}

Expand Down
68 changes: 32 additions & 36 deletions src/generate/rust.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::generate::first_uppercase;
use crate::generate::{convert_struct_name, first_uppercase};
use crate::*;
use std::io::Write;

Expand Down Expand Up @@ -42,7 +42,7 @@ impl Rust {
let elemtype = a.elemtype.to_str();
let rank = a.rank;

let futhark_type = format!("futhark_{elemtype}_{rank}d");
let futhark_type = convert_struct_name(&a.ctype).to_string();
let rust_type = format!("Array{}D{rank}", elemtype.to_ascii_uppercase(),);
let info = ArrayInfo {
futhark_type,
Expand Down Expand Up @@ -96,27 +96,31 @@ impl Rust {
name: &str,
ty: &manifest::OpaqueType,
config: &mut Config,
) -> Result<String, Error> {
let futhark_type = format!("futhark_opaque_{name}");
let rust_type = first_uppercase(name);
) -> Result<(String, String), Error> {
let futhark_type = convert_struct_name(&ty.ctype).to_string();
let mut rust_type = first_uppercase(futhark_type.strip_prefix("futhark_opaque_").unwrap());
if rust_type.chars().next().unwrap().is_numeric() || name.contains(' ') {
rust_type = format!("Type{}", rust_type);
}

writeln!(
config.output_file,
include_str!("templates/rust/opaque.rs"),
futhark_type = futhark_type,
rust_type = rust_type,
name = name,
free_fn = ty.ops.free,
)?;

let record = match &ty.record {
Some(r) => r,
None => return Ok(rust_type),
None => return Ok((futhark_type, rust_type)),
};

let mut new_call_args = vec![];
let mut new_params = vec![];
let mut new_extern_params = vec![];
for field in record.fields.iter() {
// Build new function
let a = Self::get_type(&self.typemap, &field.r#type);
let t = Self::get_type(&self.typemap, &a);

Expand All @@ -138,24 +142,8 @@ impl Rust {
}

new_params.push(format!("field{}: {u}", field.name));
}

writeln!(
config.output_file,
include_str!("templates/rust/record.rs"),
rust_type = rust_type,
futhark_type = futhark_type,
new_fn = record.new,
new_params = new_params.join(", "),
new_call_args = new_call_args.join(", "),
new_extern_params = new_extern_params.join(", "),
name = name,
)?;

// Implement get functions
for field in record.fields.iter() {
let a = Self::get_type(&self.typemap, &field.r#type);
let t = Self::get_type(&self.typemap, &a);
// Implement get function

// If the output type is an array or opaque type then we need to wrap the return value
let (output, futhark_field_type) = if type_is_opaque(&a) || type_is_array(&t) {
Expand All @@ -176,12 +164,22 @@ impl Rust {
field_name = field.name,
futhark_field_type = futhark_field_type,
rust_field_type = t,
name = name,
output = output
)?;
}

Ok(rust_type)
writeln!(
config.output_file,
include_str!("templates/rust/record.rs"),
rust_type = rust_type,
futhark_type = futhark_type,
new_fn = record.new,
new_params = new_params.join(", "),
new_call_args = new_call_args.join(", "),
new_extern_params = new_extern_params.join(", "),
)?;

Ok((futhark_type, rust_type))
}

fn generate_entry_function(
Expand Down Expand Up @@ -254,13 +252,13 @@ impl Rust {
}
}

let (entry_return_type, entry_return) = if entry.outputs.len() <= 1 {
(return_type.join(", "), entry_return.join(", "))
} else {
(
let (entry_return_type, entry_return) = match entry.outputs.len() {
0 => ("()".to_string(), "()".to_string()),
1 => (return_type.join(", "), entry_return.join(", ")),
_ => (
format!("({})", return_type.join(", ")),
format!("({})", entry_return.join(", ")),
)
),
};

writeln!(
Expand Down Expand Up @@ -336,11 +334,9 @@ impl Generate for Rust {
self.typemap.insert(info.futhark_type, info.rust_type);
}
manifest::Type::Opaque(ty) => {
let rust_type = self.generate_opaque_type(name, ty, config)?;
self.typemap
.insert(name.clone(), format!("futhark_opaque_{name}"));
self.typemap
.insert(format!("futhark_opaque_{name}"), rust_type);
let (futhark_type, rust_type) = self.generate_opaque_type(name, ty, config)?;
self.typemap.insert(name.clone(), futhark_type.clone());
self.typemap.insert(futhark_type, rust_type);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/generate/templates/ocaml/record.mli
Original file line number Diff line number Diff line change
@@ -1 +1 @@
val v: Context.t -> {new_arg_types} -> t
val v: Context.t -> {new_arg_types} -> t
1 change: 1 addition & 0 deletions src/generate/templates/rust/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ impl Context {{
}};
if rc != 0 {{ return Err(Error::Code(rc)); }}

#[allow(unused_unsafe)]
unsafe {{
Ok({entry_return})
}}
Expand Down
4 changes: 2 additions & 2 deletions src/generate/templates/rust/opaque.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ struct {futhark_type} {{
}}

extern "C" {{
fn futhark_free_opaque_{name}(
fn {free_fn}(
_: *mut futhark_context,
_: *mut {futhark_type}
) -> std::os::raw::c_int;
Expand All @@ -26,7 +26,7 @@ impl<'a> {rust_type}<'a> {{
impl<'a> Drop for {rust_type}<'a> {{
fn drop(&mut self) {{
unsafe {{
futhark_free_opaque_{name}(self.ctx.context, self.data);
{free_fn}(self.ctx.context, self.data);
}}
}}
}}
2 changes: 1 addition & 1 deletion src/generate/templates/rust/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ impl<'a> {rust_type}<'a> {{
pub fn new(ctx: &'a Context, {new_params}) -> std::result::Result<Self, Error> {{
unsafe {{
let mut out = std::ptr::null_mut();
let rc = futhark_new_opaque_{name}(ctx.context, &mut out, {new_call_args});
let rc = {new_fn}(ctx.context, &mut out, {new_call_args});
if rc != 0 {{ return Err(Error::Code(rc)); }}
ctx.auto_sync();
Ok(Self {{ data: out, ctx }})
Expand Down
2 changes: 1 addition & 1 deletion src/generate/templates/rust/record_project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ impl<'a> {rust_type}<'a> {{
pub fn get_{field_name}(&self) -> Result<{rust_field_type}, Error> {{
let mut out = std::mem::MaybeUninit::zeroed();
let rc = unsafe {{
futhark_project_opaque_{name}_{field_name}(
{project_fn}(
self.ctx.context,
out.as_mut_ptr(),
self.data
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl Backend {
match self {
Backend::CUDA => &["cuda", "cudart", "nvrtc", "m"],
Backend::OpenCL => &["OpenCL", "m"],
Backend::Multicore => &["pthread", "m"],
Backend::Multicore | Backend::ISPC => &["pthread", "m"],
_ => &[],
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/library.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,15 @@ impl Library {
.flag("-pthread")
.flag("-lm")
.flag("-std=c99")
.extra_warnings(false)
.warnings(false)
.compile(&name);
} else {
cc::Build::new()
.flag("-std=c99")
.flag("-Wno-unused-parameter")
.file(&self.c_file)
.extra_warnings(false)
.warnings(false)
.compile(&name);
}
Expand Down

0 comments on commit 4e72040

Please sign in to comment.