diff --git a/lex/gen.go b/lex/gen.go index af8dab70d..29447bb53 100644 --- a/lex/gen.go +++ b/lex/gen.go @@ -232,9 +232,16 @@ func writeMethods(typename string, ts *TypeSchema, w io.Writer) error { case "record": return nil case "query": - return ts.WriteRPC(w, typename) + return ts.WriteRPC(w, typename, fmt.Sprintf("%s_Input", typename)) case "procedure": - return ts.WriteRPC(w, typename) + if ts.Input == nil || ts.Input.Schema == nil || ts.Input.Schema.Type == "object" { + return ts.WriteRPC(w, typename, fmt.Sprintf("%s_Input", typename)) + } else if ts.Input.Schema.Type == "ref" { + inputname, _ := ts.namesFromRef(ts.Input.Schema.Ref) + return ts.WriteRPC(w, typename, inputname) + } else { + return fmt.Errorf("unhandled input type: %s", ts.Input.Schema.Type) + } case "object", "string": return nil case "subscription": diff --git a/lex/type_schema.go b/lex/type_schema.go index aeeb7389d..fcec3575a 100644 --- a/lex/type_schema.go +++ b/lex/type_schema.go @@ -50,7 +50,7 @@ type TypeSchema struct { Maximum any `json:"maximum"` } -func (s *TypeSchema) WriteRPC(w io.Writer, typename string) error { +func (s *TypeSchema) WriteRPC(w io.Writer, typename, inputname string) error { pf := printerf(w) fname := typename @@ -65,7 +65,7 @@ func (s *TypeSchema) WriteRPC(w io.Writer, typename string) error { case EncodingCBOR, EncodingCAR, EncodingANY, EncodingMP4: params = fmt.Sprintf("%s, input io.Reader", params) case EncodingJSON: - params = fmt.Sprintf("%s, input *%s_Input", params, fname) + params = fmt.Sprintf("%s, input *%s", params, inputname) default: return fmt.Errorf("unsupported input encoding (RPC input): %q", s.Input.Encoding)