From 56da067ca10d49a724d478ecfec991cfb4a52f25 Mon Sep 17 00:00:00 2001 From: Wesley <100464352+ologbonowiwi@users.noreply.github.com> Date: Tue, 16 Jan 2024 13:12:24 -0300 Subject: [PATCH] feat: throw error in all resolvers for missing arguments (#952) --- .../from_config/operators/graphql.rs | 6 +- src/blueprint/from_config/operators/grpc.rs | 3 +- src/blueprint/from_config/operators/http.rs | 128 +------------ src/blueprint/mod.rs | 1 + src/blueprint/validate.rs | 172 ++++++++++++++++++ ...-missing-argument-on-all-resolvers.graphql | 49 +++++ tests/graphql/test-graphqlsource.graphql | 2 + 7 files changed, 231 insertions(+), 130 deletions(-) create mode 100644 src/blueprint/validate.rs create mode 100644 tests/graphql/errors/test-missing-argument-on-all-resolvers.graphql diff --git a/src/blueprint/from_config/operators/graphql.rs b/src/blueprint/from_config/operators/graphql.rs index bd5735d410..04e0ce505e 100644 --- a/src/blueprint/from_config/operators/graphql.rs +++ b/src/blueprint/from_config/operators/graphql.rs @@ -33,12 +33,14 @@ pub fn update_graphql<'a>( operation_type: &'a GraphQLOperationType, ) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { TryFold::<(&Config, &Field, &config::Type, &'a str), FieldDefinition, String>::new( - |(config, field, _, _), b_field| { + |(config, field, type_of, _), b_field| { let Some(graphql) = &field.graphql else { return Valid::succeed(b_field); }; - compile_graphql(config, operation_type, graphql).map(|resolver| b_field.resolver(Some(resolver))) + compile_graphql(config, operation_type, graphql) + .map(|resolver| b_field.resolver(Some(resolver))) + .and_then(|b_field| b_field.validate_field(type_of, config).map_to(b_field)) }, ) } diff --git a/src/blueprint/from_config/operators/grpc.rs b/src/blueprint/from_config/operators/grpc.rs index 01c824b8dc..19ecc656de 100644 --- a/src/blueprint/from_config/operators/grpc.rs +++ b/src/blueprint/from_config/operators/grpc.rs @@ -151,13 +151,14 @@ pub fn update_grpc<'a>( operation_type: &'a GraphQLOperationType, ) -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type, &'a str), FieldDefinition, String> { TryFold::<(&Config, &Field, &config::Type, &'a str), FieldDefinition, String>::new( - |(config, field, _type_of, _name), b_field| { + |(config, field, type_of, _name), b_field| { let Some(grpc) = &field.grpc else { return Valid::succeed(b_field); }; compile_grpc(CompileGrpc { config, operation_type, field, grpc, validate_with_schema: true }) .map(|resolver| b_field.resolver(Some(resolver))) + .and_then(|b_field| b_field.validate_field(type_of, config).map_to(b_field)) }, ) } diff --git a/src/blueprint/from_config/operators/http.rs b/src/blueprint/from_config/operators/http.rs index c6769d65bb..ee7935da3b 100644 --- a/src/blueprint/from_config/operators/http.rs +++ b/src/blueprint/from_config/operators/http.rs @@ -1,4 +1,3 @@ -use crate::blueprint::from_config::to_type; use crate::blueprint::*; use crate::config::group_by::GroupBy; use crate::config::{Config, Field}; @@ -9,131 +8,6 @@ use crate::try_fold::TryFold; use crate::valid::{Valid, ValidationError}; use crate::{config, helpers}; -struct MustachePartsValidator<'a> { - type_of: &'a config::Type, - config: &'a Config, - field: &'a FieldDefinition, -} - -impl<'a> MustachePartsValidator<'a> { - fn new(type_of: &'a config::Type, config: &'a Config, field: &'a FieldDefinition) -> Self { - Self { type_of, config, field } - } - - fn validate_type(&self, parts: &[String], is_query: bool) -> Result<(), String> { - let mut len = parts.len(); - let mut type_of = self.type_of; - for item in parts { - let field = type_of.fields.get(item).ok_or_else(|| { - format!( - "no value '{}' found", - parts[0..parts.len() - len + 1].join(".").as_str() - ) - })?; - let val_type = to_type(field, None); - - if !is_query && val_type.is_nullable() { - return Err(format!("value '{}' is a nullable type", item.as_str())); - } else if len == 1 && !is_scalar(val_type.name()) { - return Err(format!("value '{}' is not of a scalar type", item.as_str())); - } else if len == 1 { - break; - } - - type_of = self - .config - .find_type(&field.type_of) - .ok_or_else(|| format!("no type '{}' found", parts.join(".").as_str()))?; - - len -= 1; - } - - Ok(()) - } - - fn validate(&self, parts: &[String], is_query: bool) -> Valid<(), String> { - let config = self.config; - let args = &self.field.args; - - if parts.len() < 2 { - return Valid::fail("too few parts in template".to_string()); - } - - let head = parts[0].as_str(); - let tail = parts[1].as_str(); - - match head { - "value" => { - // all items on parts except the first one - let tail = &parts[1..]; - - if let Err(e) = self.validate_type(tail, is_query) { - return Valid::fail(e); - } - } - "args" => { - // XXX this is a linear search but it's cost is less than that of - // constructing a HashMap since we'd have 3-4 arguments at max in - // most cases - if let Some(arg) = args.iter().find(|arg| arg.name == tail) { - if let Type::ListType { .. } = arg.of_type { - return Valid::fail(format!("can't use list type '{tail}' here")); - } - - // we can use non-scalar types in args - - if !is_query && arg.default_value.is_none() && arg.of_type.is_nullable() { - return Valid::fail(format!("argument '{tail}' is a nullable type")); - } - } else { - return Valid::fail(format!("no argument '{tail}' found")); - } - } - "vars" => { - if config.server.vars.get(tail).is_none() { - return Valid::fail(format!("var '{tail}' is not set in the server config")); - } - } - "headers" | "env" => { - // "headers" and "env" refers to values known at runtime, which we can't - // validate here - } - _ => { - return Valid::fail(format!("unknown template directive '{head}'")); - } - } - - Valid::succeed(()) - } -} - -fn validate_field(type_of: &config::Type, config: &Config, field: &FieldDefinition) -> Valid<(), String> { - // XXX we could use `Mustache`'s `render` method with a mock - // struct implementing the `PathString` trait encapsulating `validation_map` - // but `render` simply falls back to the default value for a given - // type if it doesn't exist, so we wouldn't be able to get enough - // context from that method alone - // So we must duplicate some of that logic here :( - - let parts_validator = MustachePartsValidator::new(type_of, config, field); - - if let Some(Expression::Unsafe(Unsafe::Http { req_template, .. })) = &field.resolver { - Valid::from_iter(req_template.root_url.expression_segments(), |parts| { - parts_validator.validate(parts, false).trace("path") - }) - .and(Valid::from_iter(req_template.query.clone(), |query| { - let (_, mustache) = query; - - Valid::from_iter(mustache.expression_segments(), |parts| { - parts_validator.validate(parts, true).trace("query") - }) - })) - .unit() - } else { - Valid::succeed(()) - } -} - pub fn compile_http(config: &config::Config, field: &config::Field, http: &config::Http) -> Valid { Valid::<(), String>::fail("GroupBy is only supported for GET requests".to_string()) .when(|| !http.group_by.is_empty() && http.method != Method::GET) @@ -189,7 +63,7 @@ pub fn update_http<'a>() -> TryFold<'a, (&'a Config, &'a Field, &'a config::Type compile_http(config, field, http) .map(|resolver| b_field.resolver(Some(resolver))) - .and_then(|b_field| validate_field(type_of, config, &b_field).map_to(b_field)) + .and_then(|b_field| b_field.validate_field(type_of, config).map_to(b_field)) }, ) } diff --git a/src/blueprint/mod.rs b/src/blueprint/mod.rs index b00d620ee1..49883f5eeb 100644 --- a/src/blueprint/mod.rs +++ b/src/blueprint/mod.rs @@ -5,6 +5,7 @@ mod from_config; mod into_schema; mod operation; mod timeout; +mod validate; pub use blueprint::*; pub use const_utils::*; diff --git a/src/blueprint/validate.rs b/src/blueprint/validate.rs new file mode 100644 index 0000000000..3f4eb88024 --- /dev/null +++ b/src/blueprint/validate.rs @@ -0,0 +1,172 @@ +use super::{is_scalar, to_type, FieldDefinition, Type}; +use crate::config::{self, Config}; +use crate::lambda::{Expression, Unsafe}; +use crate::valid::Valid; + +struct MustachePartsValidator<'a> { + type_of: &'a config::Type, + config: &'a Config, + field: &'a FieldDefinition, +} + +impl<'a> MustachePartsValidator<'a> { + fn new(type_of: &'a config::Type, config: &'a Config, field: &'a FieldDefinition) -> Self { + Self { type_of, config, field } + } + + fn validate_type(&self, parts: &[String], is_query: bool) -> Result<(), String> { + let mut len = parts.len(); + let mut type_of = self.type_of; + for item in parts { + let field = type_of.fields.get(item).ok_or_else(|| { + format!( + "no value '{}' found", + parts[0..parts.len() - len + 1].join(".").as_str() + ) + })?; + let val_type = to_type(field, None); + + if !is_query && val_type.is_nullable() { + return Err(format!("value '{}' is a nullable type", item.as_str())); + } else if len == 1 && !is_scalar(val_type.name()) { + return Err(format!("value '{}' is not of a scalar type", item.as_str())); + } else if len == 1 { + break; + } + + type_of = self + .config + .find_type(&field.type_of) + .ok_or_else(|| format!("no type '{}' found", parts.join(".").as_str()))?; + + len -= 1; + } + + Ok(()) + } + + fn validate(&self, parts: &[String], is_query: bool) -> Valid<(), String> { + let config = self.config; + let args = &self.field.args; + + if parts.len() < 2 { + return Valid::fail("too few parts in template".to_string()); + } + + let head = parts[0].as_str(); + let tail = parts[1].as_str(); + + match head { + "value" => { + // all items on parts except the first one + let tail = &parts[1..]; + + if let Err(e) = self.validate_type(tail, is_query) { + return Valid::fail(e); + } + } + "args" => { + // XXX this is a linear search but it's cost is less than that of + // constructing a HashMap since we'd have 3-4 arguments at max in + // most cases + if let Some(arg) = args.iter().find(|arg| arg.name == tail) { + if let Type::ListType { .. } = arg.of_type { + return Valid::fail(format!("can't use list type '{tail}' here")); + } + + // we can use non-scalar types in args + if !is_query && arg.default_value.is_none() && arg.of_type.is_nullable() { + return Valid::fail(format!("argument '{tail}' is a nullable type")); + } + } else { + return Valid::fail(format!("no argument '{tail}' found")); + } + } + "vars" => { + if config.server.vars.get(tail).is_none() { + return Valid::fail(format!("var '{tail}' is not set in the server config")); + } + } + "headers" | "env" => { + // "headers" and "env" refers to values known at runtime, which we can't + // validate here + } + _ => { + return Valid::fail(format!("unknown template directive '{head}'")); + } + } + + Valid::succeed(()) + } +} + +impl FieldDefinition { + pub fn validate_field(&self, type_of: &config::Type, config: &Config) -> Valid<(), String> { + // XXX we could use `Mustache`'s `render` method with a mock + // struct implementing the `PathString` trait encapsulating `validation_map` + // but `render` simply falls back to the default value for a given + // type if it doesn't exist, so we wouldn't be able to get enough + // context from that method alone + // So we must duplicate some of that logic here :( + let parts_validator = MustachePartsValidator::new(type_of, config, self); + + match &self.resolver { + Some(Expression::Unsafe(Unsafe::Http { req_template, .. })) => { + Valid::from_iter(req_template.root_url.expression_segments(), |parts| { + parts_validator.validate(parts, false).trace("path") + }) + .and(Valid::from_iter(req_template.query.clone(), |query| { + let (_, mustache) = query; + + Valid::from_iter(mustache.expression_segments(), |parts| { + parts_validator.validate(parts, true).trace("query") + }) + })) + .unit() + } + Some(Expression::Unsafe(Unsafe::GraphQLEndpoint { req_template, .. })) => { + Valid::from_iter(req_template.headers.clone(), |(_, mustache)| { + Valid::from_iter(mustache.expression_segments(), |parts| { + parts_validator.validate(parts, true).trace("headers") + }) + }) + .and_then(|_| { + if let Some(args) = &req_template.operation_arguments { + Valid::from_iter(args, |(_, mustache)| { + Valid::from_iter(mustache.expression_segments(), |parts| { + parts_validator.validate(parts, true).trace("args") + }) + }) + } else { + Valid::succeed(Default::default()) + } + }) + .unit() + } + Some(Expression::Unsafe(Unsafe::Grpc { req_template, .. })) => { + Valid::from_iter(req_template.url.expression_segments(), |parts| { + parts_validator.validate(parts, false).trace("path") + }) + .and( + Valid::from_iter(req_template.headers.clone(), |(_, mustache)| { + Valid::from_iter(mustache.expression_segments(), |parts| { + parts_validator.validate(parts, true).trace("headers") + }) + }) + .unit(), + ) + .and_then(|_| { + if let Some(body) = &req_template.body { + Valid::from_iter(body.expression_segments(), |parts| { + parts_validator.validate(parts, true).trace("body") + }) + } else { + Valid::succeed(Default::default()) + } + }) + .unit() + } + _ => Valid::succeed(()), + } + } +} diff --git a/tests/graphql/errors/test-missing-argument-on-all-resolvers.graphql b/tests/graphql/errors/test-missing-argument-on-all-resolvers.graphql new file mode 100644 index 0000000000..f4be891df2 --- /dev/null +++ b/tests/graphql/errors/test-missing-argument-on-all-resolvers.graphql @@ -0,0 +1,49 @@ +#> server-sdl +schema @upstream(baseURL: "http://jsonplaceholder.typicode.com") { + query: Query +} + +type Post { + id: Int! +} + +type News { + body: String + id: Int + postImage: String + title: String +} + +type NewsData { + news: [News]! +} + +type Query { + postGraphQLArgs: Post @graphQL(name: "post", args: [{key: "id", value: "{{args.id}}"}]) + postGraphQLHeaders: Post @graphQL(name: "post", headers: [{key: "id", value: "{{args.id}}"}]) + postHttp: Post @http(path: "/posts/{{args.id}}") + newsGrpcHeaders: NewsData! + @grpc( + method: "GetAllNews" + protoPath: "src/grpc/tests/news.proto" + service: "NewsService" + headers: [{key: "id", value: "{{args.id}}"}] + ) + newsGrpcUrl: NewsData! + @grpc(method: "GetAllNews", protoPath: "src/grpc/tests/news.proto", service: "NewsService", baseURL: "{{args.url}}") + newsGrpcBody: NewsData! + @grpc(method: "GetAllNews", protoPath: "src/grpc/tests/news.proto", service: "NewsService", body: "{{args.id}}") +} + +type User { + id: Int + name: String +} + +#> client-sdl +type Failure @error(message: "no argument 'id' found", trace: ["Query", "newsGrpcBody", "@grpc", "body"]) +type Failure @error(message: "no argument 'id' found", trace: ["Query", "newsGrpcHeaders", "@grpc", "headers"]) +type Failure @error(message: "no argument 'url' found", trace: ["Query", "newsGrpcUrl", "@grpc", "path"]) +type Failure @error(message: "no argument 'id' found", trace: ["Query", "postGraphQLArgs", "@graphQL", "args"]) +type Failure @error(message: "no argument 'id' found", trace: ["Query", "postGraphQLHeaders", "@graphQL", "headers"]) +type Failure @error(message: "no argument 'id' found", trace: ["Query", "postHttp", "@http", "path"]) diff --git a/tests/graphql/test-graphqlsource.graphql b/tests/graphql/test-graphqlsource.graphql index 7b737601a7..ed0d6662ba 100644 --- a/tests/graphql/test-graphqlsource.graphql +++ b/tests/graphql/test-graphqlsource.graphql @@ -6,6 +6,7 @@ schema @server @upstream(baseURL: "http://localhost:8000/graphql") { type Post { id: Int! user: User @graphQL(args: [{key: "id", value: "{{value.userId}}"}], name: "user") + userId: Int! } type Query { @@ -21,6 +22,7 @@ type User { type Post { id: Int! user: User + userId: Int! } type Query {