Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(grpc): handle union shared fields #2757

Merged
merged 11 commits into from
Aug 28, 2024
21 changes: 21 additions & 0 deletions src/core/blueprint/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,18 @@ impl Index {
pub fn get_mutation(&self) -> Option<&str> {
self.schema.mutation.as_deref()
}

pub fn is_type_implements(&self, type_name: &str, type_or_interface: &str) -> bool {
if type_name == type_or_interface {
return true;
}

if let Some((Definition::Object(obj), _)) = self.map.get(type_name) {
obj.implements.contains(type_or_interface)
} else {
false
}
}
}

impl From<&Blueprint> for Index {
Expand Down Expand Up @@ -232,4 +244,13 @@ mod test {
index.schema.mutation = None;
assert_eq!(index.get_mutation(), None);
}

#[test]
fn test_is_type_implements() {
let index = setup();

assert!(index.is_type_implements("User", "Node"));
assert!(index.is_type_implements("Post", "Post"));
assert!(!index.is_type_implements("Node", "User"));
}
}
3 changes: 3 additions & 0 deletions src/core/config/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,9 @@ impl Config {
stack.extend(field.args.values().map(|arg| arg.type_of.clone()));
stack.push(field.type_of.clone());
}
for interface in typ.implements.iter() {
stack.push(interface.clone())
}
}
}

Expand Down
8 changes: 6 additions & 2 deletions src/core/generator/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl Context {

collect_types(
type_name.clone(),
base_type,
base_type.clone(),
&oneof_fields,
&mut union_types,
);
Expand All @@ -141,13 +141,17 @@ impl Context {
}

let mut union_ = Union::default();
let interface_name = format!("{type_name}__Interface");

for (type_name, ty) in union_types {
for (type_name, mut ty) in union_types {
ty.implements.insert(interface_name.clone());
union_.types.insert(type_name.clone());

self = self.insert_type(type_name, ty);
}

// base interface type
self.config.types.insert(interface_name, base_type);
self.config.unions.insert(type_name, union_);

self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ input oneof__Request__Var__Var1 {
usual: String
}

interface oneof__Request__Interface {
usual: String
}

interface oneof__Response__Interface {
usual: Int
}

union oneof__Request = oneof__Request__Var0__Var | oneof__Request__Var0__Var0 | oneof__Request__Var0__Var1 | oneof__Request__Var1__Var | oneof__Request__Var1__Var0 | oneof__Request__Var1__Var1 | oneof__Request__Var__Var | oneof__Request__Var__Var0 | oneof__Request__Var__Var1

union oneof__Response = oneof__Response__Var | oneof__Response__Var0 | oneof__Response__Var1 | oneof__Response__Var2
Expand All @@ -78,21 +86,21 @@ type oneof__Payload {
payload: String
}

type oneof__Response__Var {
type oneof__Response__Var implements oneof__Response__Interface {
usual: Int
}

type oneof__Response__Var0 {
type oneof__Response__Var0 implements oneof__Response__Interface {
payload: oneof__Payload!
usual: Int
}

type oneof__Response__Var1 {
type oneof__Response__Var1 implements oneof__Response__Interface {
command: oneof__Command!
usual: Int
}

type oneof__Response__Var2 {
type oneof__Response__Var2 implements oneof__Response__Interface {
response: String!
usual: Int
}
3 changes: 2 additions & 1 deletion src/core/ir/resolver_context_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ impl SelectionField {
field: &crate::core::jit::Field<Nested<ConstValue>, ConstValue>,
) -> SelectionField {
let name = field.output_name.to_string();
let type_name = field.type_of.name();
let selection_set = field
.nested_iter(field.type_of.name())
.iter_only(|field| field.type_condition == type_name)
.map(Self::from_jit_field)
.collect();
let args = field
Expand Down
34 changes: 20 additions & 14 deletions src/core/jit/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,16 @@ where
// Check if the value is an array
if let Some(array) = value.as_array() {
join_all(array.iter().enumerate().map(|(index, value)| {
let type_name = value.get_type_name().unwrap_or(field.type_of.name());

join_all(field.nested_iter(type_name).map(|field| {
let ctx = ctx.with_value_and_field(value, field);
let data_path = data_path.clone().with_index(index);
async move { self.execute(&ctx, data_path).await }
}))
join_all(
self.request
.plan()
.field_iter_only(field, value)
.map(|field| {
let ctx = ctx.with_value_and_field(value, field);
let data_path = data_path.clone().with_index(index);
async move { self.execute(&ctx, data_path).await }
}),
)
}))
.await;
}
Expand All @@ -111,13 +114,16 @@ where
// TODO: Validate if the value is an Object
// Has to be an Object, we don't do anything while executing if its a Scalar
else {
let type_name = value.get_type_name().unwrap_or(field.type_of.name());

join_all(field.nested_iter(type_name).map(|child| {
let ctx = ctx.with_value_and_field(value, child);
let data_path = data_path.clone();
async move { self.execute(&ctx, data_path).await }
}))
join_all(
self.request
.plan()
.field_iter_only(field, value)
.map(|child| {
let ctx = ctx.with_value_and_field(value, child);
let data_path = data_path.clone();
async move { self.execute(&ctx, data_path).await }
}),
)
.await;
}

Expand Down
61 changes: 43 additions & 18 deletions src/core/jit/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize};
use super::Error;
use crate::core::blueprint::Index;
use crate::core::ir::model::IR;
use crate::core::ir::TypedValue;
use crate::core::json::JsonLike;

#[derive(Debug, Deserialize, Clone)]
Expand Down Expand Up @@ -65,6 +66,14 @@ impl<Extensions, Input> Field<Extensions, Input> {

skip == include
}

/// Returns the __typename of the value related to this field
pub fn value_type<'a, Output>(&'a self, value: &'a Output) -> &'a str
where
Output: TypedValue<'a>,
{
value.get_type_name().unwrap_or(self.type_of.name())
}
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -234,27 +243,15 @@ impl<Input> Field<Flat, Input> {
}

impl<Input> Field<Nested<Input>, Input> {
/// iters over children fields that are
/// related to passed `type_name` either
/// as direct field of the queried type or
/// field from fragment on type `type_name`
pub fn nested_iter<'a>(
/// iters over children fields that satisfies
/// passed filter_fn
pub fn iter_only<'a>(
&'a self,
type_name: &'a str,
mut filter_fn: impl FnMut(&'a Field<Nested<Input>, Input>) -> bool + 'a,
) -> impl Iterator<Item = &Field<Nested<Input>, Input>> + 'a {
self.extensions
.as_ref()
.map(move |nested| {
nested
.0
.iter()
// TODO: handle Interface and Union types here
// Right now only exact type name is used to check the set of fields
// but with Interfaces/Unions we need to check if that specific type
// is member of some Interface/Union and if so call the fragments for
// the related Interfaces/Unions
.filter(move |field| field.type_condition == type_name)
})
.map(move |nested| nested.0.iter().filter(move |&field| filter_fn(field)))
.into_iter()
.flatten()
}
Expand Down Expand Up @@ -351,7 +348,6 @@ pub struct OperationPlan<Input> {
flat: Vec<Field<Flat, Input>>,
operation_type: OperationType,
nested: Vec<Field<Nested<Input>, Input>>,

// TODO: drop index from here. Embed all the necessary information in each field of the plan.
pub index: Arc<Index>,
}
Expand Down Expand Up @@ -409,30 +405,37 @@ impl<Input> OperationPlan<Input> {
Self { flat: fields, nested, operation_type, index }
}

/// Returns a graphQL operation type
pub fn operation_type(&self) -> OperationType {
self.operation_type
}

/// Check if current graphQL operation is query
pub fn is_query(&self) -> bool {
self.operation_type == OperationType::Query
}

/// Returns a nested [Field] representation
pub fn as_nested(&self) -> &[Field<Nested<Input>, Input>] {
&self.nested
}

/// Returns an owned version of [Field] representation
pub fn into_nested(self) -> Vec<Field<Nested<Input>, Input>> {
self.nested
}

/// Returns a flat [Field] representation
pub fn as_parent(&self) -> &[Field<Flat, Input>] {
&self.flat
}

/// Search for a field with a specified [FieldId]
pub fn find_field(&self, id: FieldId) -> Option<&Field<Flat, Input>> {
self.flat.iter().find(|field| field.id == id)
}

/// Search for a field by specified path of nested fields
pub fn find_field_path<S: AsRef<str>>(&self, path: &[S]) -> Option<&Field<Flat, Input>> {
match path.split_first() {
None => None,
Expand All @@ -447,25 +450,47 @@ impl<Input> OperationPlan<Input> {
}
}

/// Returns number of fields in plan
pub fn size(&self) -> usize {
self.flat.len()
}

/// Check if the field is of scalar type
pub fn field_is_scalar<Extensions>(&self, field: &Field<Extensions, Input>) -> bool {
self.index.type_is_scalar(field.type_of.name())
}

/// Check if the field is of enum type
pub fn field_is_enum<Extensions>(&self, field: &Field<Extensions, Input>) -> bool {
self.index.type_is_enum(field.type_of.name())
}

/// Validate the value against enum variants of the field
pub fn field_validate_enum_value<Extensions>(
&self,
field: &Field<Extensions, Input>,
value: &str,
) -> bool {
self.index.validate_enum_value(field.type_of.name(), value)
}

/// Iterate over nested fields that are related to the __typename of the
/// value
pub fn field_iter_only<'a, Output>(
&'a self,
field: &'a Field<Nested<Input>, Input>,
value: &'a Output,
) -> impl Iterator<Item = &'a Field<Nested<Input>, Input>>
where
Output: TypedValue<'a>,
{
let value_type = field.value_type(value);

field.iter_only(move |field| {
self.index
.is_type_implements(value_type, &field.type_condition)
})
}
}

#[derive(Clone, Debug)]
Expand Down
5 changes: 1 addition & 4 deletions src/core/jit/synth/synth.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::core::ir::TypedValue;
use crate::core::jit::model::{Field, Nested, OperationPlan, Variables};
use crate::core::jit::store::{Data, DataPath, Store};
use crate::core::jit::{Error, PathSegment, Positioned, ValidationError};
Expand Down Expand Up @@ -164,9 +163,7 @@ where
(_, Some(obj)) => {
let mut ans = Value::JsonObject::new();

let type_name = value.get_type_name().unwrap_or(node.type_of.name());

for child in node.nested_iter(type_name) {
for child in self.plan.field_iter_only(node, value) {
// all checks for skip must occur in `iter_inner`
// and include be checked before calling `iter` or recursing.
let include = self.include(child);
Expand Down
12 changes: 8 additions & 4 deletions tests/core/snapshots/grpc-oneof.md_client.snap
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,25 @@ input oneof__Request__Var__Var1 {

union oneof__Response = oneof__Response__Var | oneof__Response__Var0 | oneof__Response__Var1 | oneof__Response__Var2

type oneof__Response__Var {
interface oneof__Response__Interface {
usual: Int
}

type oneof__Response__Var0 {
type oneof__Response__Var implements oneof__Response__Interface {
usual: Int
}

type oneof__Response__Var0 implements oneof__Response__Interface {
payload: oneof__Payload!
usual: Int
}

type oneof__Response__Var1 {
type oneof__Response__Var1 implements oneof__Response__Interface {
command: oneof__Command!
usual: Int
}

type oneof__Response__Var2 {
type oneof__Response__Var2 implements oneof__Response__Interface {
response: String!
usual: Int
}
Expand Down
Loading
Loading