diff --git a/crates/rustc_codegen_spirv-types/src/compile_result.rs b/crates/rustc_codegen_spirv-types/src/compile_result.rs index 7a19591532..c8465729f6 100644 --- a/crates/rustc_codegen_spirv-types/src/compile_result.rs +++ b/crates/rustc_codegen_spirv-types/src/compile_result.rs @@ -1,42 +1,83 @@ use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; +use std::convert::Infallible; use std::fmt::Write; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; + +pub type ModuleResult = GenericModuleResult; #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] -pub enum ModuleResult { - SingleModule(PathBuf), - MultiModule(BTreeMap), +pub enum GenericModuleResult { + SingleModule(T), + MultiModule(BTreeMap), } -impl ModuleResult { - pub fn unwrap_single(&self) -> &Path { +impl GenericModuleResult { + pub fn unwrap_single(&self) -> &T { match self { - ModuleResult::SingleModule(result) => result, - ModuleResult::MultiModule(_) => { + GenericModuleResult::SingleModule(result) => result, + GenericModuleResult::MultiModule(_) => { panic!("called `ModuleResult::unwrap_single()` on a `MultiModule` result") } } } - pub fn unwrap_multi(&self) -> &BTreeMap { + pub fn unwrap_multi(&self) -> &BTreeMap { match self { - ModuleResult::MultiModule(result) => result, - ModuleResult::SingleModule(_) => { + GenericModuleResult::MultiModule(result) => result, + GenericModuleResult::SingleModule(_) => { panic!("called `ModuleResult::unwrap_multi()` on a `SingleModule` result") } } } } +pub type CompileResult = GenericCompileResult; + #[derive(Debug, Serialize, Deserialize)] -pub struct CompileResult { +pub struct GenericCompileResult { pub entry_points: Vec, - pub module: ModuleResult, + pub module: GenericModuleResult, } -impl CompileResult { +impl GenericCompileResult { + pub fn try_map( + &self, + mut map_entry_point: impl FnMut(&String) -> Result, + mut map_module: impl FnMut(&T) -> Result, + ) -> Result, E> { + Ok(match &self.module { + GenericModuleResult::SingleModule(t) => GenericCompileResult { + entry_points: self + .entry_points + .iter() + .map(map_entry_point) + .collect::>()?, + module: GenericModuleResult::SingleModule(map_module(t)?), + }, + GenericModuleResult::MultiModule(map) => { + let new_map: BTreeMap = map + .iter() + .map(|(entry_point, t)| Ok((map_entry_point(entry_point)?, map_module(t)?))) + .collect::>()?; + GenericCompileResult { + entry_points: new_map.keys().cloned().collect(), + module: GenericModuleResult::MultiModule(new_map), + } + } + }) + } + + pub fn map( + &self, + mut map_entry_point: impl FnMut(&String) -> String, + mut map_module: impl FnMut(&T) -> R, + ) -> GenericCompileResult { + self.try_map::<_, Infallible>(|e| Ok(map_entry_point(e)), |e| Ok(map_module(e))) + .unwrap() + } + pub fn codegen_entry_point_strings(&self) -> String { let trie = Trie::create_from(self.entry_points.iter().map(|x| x as &str)); let mut builder = String::new(); @@ -110,9 +151,6 @@ impl<'a> Trie<'a> { } } -#[allow(non_upper_case_globals)] -pub const a: &str = "x::a"; - #[cfg(test)] mod test { use super::*; diff --git a/crates/spirv-builder/src/lib.rs b/crates/spirv-builder/src/lib.rs index 80cc8ad246..cbeaa12c21 100644 --- a/crates/spirv-builder/src/lib.rs +++ b/crates/spirv-builder/src/lib.rs @@ -88,8 +88,7 @@ use std::path::{Path, PathBuf}; use std::process::{Command, Stdio}; use thiserror::Error; -pub use rustc_codegen_spirv_types::Capability; -pub use rustc_codegen_spirv_types::{CompileResult, ModuleResult}; +pub use rustc_codegen_spirv_types::*; #[cfg(feature = "include-target-specs")] pub use rustc_codegen_spirv_target_specs::TARGET_SPEC_DIR_PATH;