Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(grpc): handle union shared fields #2757

Merged
merged 11 commits into from
Aug 28, 2024
40 changes: 40 additions & 0 deletions src/core/blueprint/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,37 @@
pub fn get_mutation(&self) -> Option<&str> {
self.schema.mutation.as_deref()
}

pub fn is_type_implements(&self, type_name: &str, implements: &str) -> bool {
tusharmath marked this conversation as resolved.
Show resolved Hide resolved
if type_name == implements {
return true;
}

if !matches!(
self.map.get(implements),
Some((Definition::Interface(_), _))
) {
return false;
}

self.is_type_implements_rec(type_name, implements)
tusharmath marked this conversation as resolved.
Show resolved Hide resolved
}

fn is_type_implements_rec(&self, type_name: &str, implements: &str) -> bool {
tusharmath marked this conversation as resolved.
Show resolved Hide resolved
if type_name == implements {
return true;
}

if let Some((Definition::Object(obj), _)) = self.map.get(type_name) {
for interface in obj.implements.iter() {
if self.is_type_implements_rec(interface, implements) {
return true;
}

Check warning on line 94 in src/core/blueprint/index.rs

View check run for this annotation

Codecov / codecov/patch

src/core/blueprint/index.rs#L94

Added line #L94 was not covered by tests
}
}

Check warning on line 96 in src/core/blueprint/index.rs

View check run for this annotation

Codecov / codecov/patch

src/core/blueprint/index.rs#L96

Added line #L96 was not covered by tests

false

Check warning on line 98 in src/core/blueprint/index.rs

View check run for this annotation

Codecov / codecov/patch

src/core/blueprint/index.rs#L98

Added line #L98 was not covered by tests
}
}

impl From<&Blueprint> for Index {
Expand Down Expand Up @@ -232,4 +263,13 @@
index.schema.mutation = None;
assert_eq!(index.get_mutation(), None);
}

#[test]
fn test_is_type_implements() {
let index = setup();

assert!(index.is_type_implements("User", "Node"));
assert!(index.is_type_implements("Post", "Post"));
assert!(!index.is_type_implements("Node", "User"));
}
}
3 changes: 3 additions & 0 deletions src/core/config/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,9 @@ impl Config {
stack.extend(field.args.values().map(|arg| arg.type_of.clone()));
stack.push(field.type_of.clone());
}
for interface in typ.implements.iter() {
stack.push(interface.clone())
}
}
}

Expand Down
8 changes: 6 additions & 2 deletions src/core/generator/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl Context {

collect_types(
type_name.clone(),
base_type,
base_type.clone(),
&oneof_fields,
&mut union_types,
);
Expand All @@ -141,13 +141,17 @@ impl Context {
}

let mut union_ = Union::default();
let interface_name = format!("{type_name}__Interface");

for (type_name, ty) in union_types {
for (type_name, mut ty) in union_types {
ty.implements.insert(interface_name.clone());
union_.types.insert(type_name.clone());

self = self.insert_type(type_name, ty);
}

// base interface type
self.config.types.insert(interface_name, base_type);
self.config.unions.insert(type_name, union_);

self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ input oneof__Request__Var__Var1 {
usual: String
}

interface oneof__Request__Interface {
usual: String
}

interface oneof__Response__Interface {
usual: Int
}

union oneof__Request = oneof__Request__Var0__Var | oneof__Request__Var0__Var0 | oneof__Request__Var0__Var1 | oneof__Request__Var1__Var | oneof__Request__Var1__Var0 | oneof__Request__Var1__Var1 | oneof__Request__Var__Var | oneof__Request__Var__Var0 | oneof__Request__Var__Var1

union oneof__Response = oneof__Response__Var | oneof__Response__Var0 | oneof__Response__Var1 | oneof__Response__Var2
Expand All @@ -78,21 +86,21 @@ type oneof__Payload {
payload: String
}

type oneof__Response__Var {
type oneof__Response__Var implements oneof__Response__Interface {
usual: Int
}

type oneof__Response__Var0 {
type oneof__Response__Var0 implements oneof__Response__Interface {
payload: oneof__Payload!
usual: Int
}

type oneof__Response__Var1 {
type oneof__Response__Var1 implements oneof__Response__Interface {
command: oneof__Command!
usual: Int
}

type oneof__Response__Var2 {
type oneof__Response__Var2 implements oneof__Response__Interface {
response: String!
usual: Int
}
3 changes: 2 additions & 1 deletion src/core/ir/resolver_context_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ impl SelectionField {
field: &crate::core::jit::Field<Nested<ConstValue>, ConstValue>,
) -> SelectionField {
let name = field.output_name.to_string();
let type_name = field.type_of.name();
let selection_set = field
.nested_iter(field.type_of.name())
.nested_iter(|field| field.type_condition == type_name)
.map(Self::from_jit_field)
.collect();
let args = field
Expand Down
34 changes: 20 additions & 14 deletions src/core/jit/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,16 @@ where
// Check if the value is an array
if let Some(array) = value.as_array() {
join_all(array.iter().enumerate().map(|(index, value)| {
let type_name = value.get_type_name().unwrap_or(field.type_of.name());

join_all(field.nested_iter(type_name).map(|field| {
let ctx = ctx.with_value_and_field(value, field);
let data_path = data_path.clone().with_index(index);
async move { self.execute(&ctx, data_path).await }
}))
join_all(
self.request
.plan()
.field_nested_iter(field, value)
.map(|field| {
let ctx = ctx.with_value_and_field(value, field);
let data_path = data_path.clone().with_index(index);
async move { self.execute(&ctx, data_path).await }
}),
)
}))
.await;
}
Expand All @@ -111,13 +114,16 @@ where
// TODO: Validate if the value is an Object
// Has to be an Object, we don't do anything while executing if its a Scalar
else {
let type_name = value.get_type_name().unwrap_or(field.type_of.name());

join_all(field.nested_iter(type_name).map(|child| {
let ctx = ctx.with_value_and_field(value, child);
let data_path = data_path.clone();
async move { self.execute(&ctx, data_path).await }
}))
join_all(
self.request
.plan()
.field_nested_iter(field, value)
.map(|child| {
let ctx = ctx.with_value_and_field(value, child);
let data_path = data_path.clone();
async move { self.execute(&ctx, data_path).await }
}),
)
.await;
}

Expand Down
47 changes: 30 additions & 17 deletions src/core/jit/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize};
use super::Error;
use crate::core::blueprint::Index;
use crate::core::ir::model::IR;
use crate::core::ir::TypedValue;

#[derive(Debug, Deserialize, Clone)]
pub struct Variables<Value>(HashMap<String, Value>);
Expand Down Expand Up @@ -143,6 +144,15 @@ impl Variable {
}
}

impl<Extensions, Input> Field<Extensions, Input> {
pub fn value_type<'a, Output>(&'a self, value: &'a Output) -> &'a str
where
Output: TypedValue<'a>,
{
value.get_type_name().unwrap_or(self.type_of.name())
}
}

impl<Input> Field<Nested<Input>, Input> {
pub fn try_map<Output, Error>(
self,
Expand Down Expand Up @@ -215,27 +225,15 @@ impl<Input> Field<Flat, Input> {
}

impl<Input> Field<Nested<Input>, Input> {
/// iters over children fields that are
/// related to passed `type_name` either
/// as direct field of the queried type or
/// field from fragment on type `type_name`
/// iters over children fields that satisfies
/// passed filter_fn
pub fn nested_iter<'a>(
&'a self,
type_name: &'a str,
mut filter_fn: impl FnMut(&'a Field<Nested<Input>, Input>) -> bool + 'a,
) -> impl Iterator<Item = &Field<Nested<Input>, Input>> + 'a {
meskill marked this conversation as resolved.
Show resolved Hide resolved
self.extensions
.as_ref()
.map(move |nested| {
nested
.0
.iter()
// TODO: handle Interface and Union types here
// Right now only exact type name is used to check the set of fields
// but with Interfaces/Unions we need to check if that specific type
// is member of some Interface/Union and if so call the fragments for
// the related Interfaces/Unions
.filter(move |field| field.type_condition == type_name)
})
.map(move |nested| nested.0.iter().filter(move |&field| filter_fn(field)))
.into_iter()
.flatten()
}
Expand Down Expand Up @@ -332,7 +330,6 @@ pub struct OperationPlan<Input> {
flat: Vec<Field<Flat, Input>>,
operation_type: OperationType,
nested: Vec<Field<Nested<Input>, Input>>,

// TODO: drop index from here. Embed all the necessary information in each field of the plan.
pub index: Arc<Index>,
}
Expand Down Expand Up @@ -447,6 +444,22 @@ impl<Input> OperationPlan<Input> {
) -> bool {
self.index.validate_enum_value(field.type_of.name(), value)
}

pub fn field_nested_iter<'a, Output>(
meskill marked this conversation as resolved.
Show resolved Hide resolved
&'a self,
field: &'a Field<Nested<Input>, Input>,
value: &'a Output,
) -> impl Iterator<Item = &'a Field<Nested<Input>, Input>>
where
Output: TypedValue<'a>,
{
let value_type = field.value_type(value);

field.nested_iter(move |field| {
self.index
.is_type_implements(value_type, &field.type_condition)
})
}
}

#[derive(Clone, Debug)]
Expand Down
5 changes: 1 addition & 4 deletions src/core/jit/synth/synth.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::core::ir::TypedValue;
use crate::core::jit::model::{Field, Nested, OperationPlan, Variable, Variables};
use crate::core::jit::store::{Data, DataPath, Store};
use crate::core::jit::{Error, PathSegment, Positioned, ValidationError};
Expand Down Expand Up @@ -166,9 +165,7 @@ where
(_, Some(obj)) => {
let mut ans = Value::JsonObject::new();

let type_name = value.get_type_name().unwrap_or(node.type_of.name());

for child in node.nested_iter(type_name) {
for child in self.plan.field_nested_iter(node, value) {
// all checks for skip must occur in `iter_inner`
// and include be checked before calling `iter` or recursing.
let include = self.include(child);
Expand Down
12 changes: 8 additions & 4 deletions tests/core/snapshots/grpc-oneof.md_client.snap
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,25 @@ input oneof__Request__Var__Var1 {

union oneof__Response = oneof__Response__Var | oneof__Response__Var0 | oneof__Response__Var1 | oneof__Response__Var2

type oneof__Response__Var {
interface oneof__Response__Interface {
usual: Int
}

type oneof__Response__Var0 {
type oneof__Response__Var implements oneof__Response__Interface {
usual: Int
}

type oneof__Response__Var0 implements oneof__Response__Interface {
payload: oneof__Payload!
usual: Int
}

type oneof__Response__Var1 {
type oneof__Response__Var1 implements oneof__Response__Interface {
command: oneof__Command!
usual: Int
}

type oneof__Response__Var2 {
type oneof__Response__Var2 implements oneof__Response__Interface {
response: String!
usual: Int
}
Expand Down
12 changes: 8 additions & 4 deletions tests/core/snapshots/grpc-oneof.md_merged.snap
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ input oneof__Request__Var__Var1 {
usual: String
}

interface oneof__Response__Interface {
usual: Int
}

union oneof__Response = oneof__Response__Var | oneof__Response__Var0 | oneof__Response__Var1 | oneof__Response__Var2

type Query {
Expand Down Expand Up @@ -96,21 +100,21 @@ type oneof__Payload {
payload: String
}

type oneof__Response__Var {
type oneof__Response__Var implements oneof__Response__Interface {
usual: Int
}

type oneof__Response__Var0 {
type oneof__Response__Var0 implements oneof__Response__Interface {
payload: oneof__Payload!
usual: Int
}

type oneof__Response__Var1 {
type oneof__Response__Var1 implements oneof__Response__Interface {
command: oneof__Command!
usual: Int
}

type oneof__Response__Var2 {
type oneof__Response__Var2 implements oneof__Response__Interface {
response: String!
usual: Int
}
Loading
Loading