@@ -2,10 +2,13 @@ package v1alpha1
22
33import (
44 "context"
5+ "encoding/json"
56 "fmt"
67 "regexp"
8+ "strings"
79 "text/template"
810
11+ "github.com/santhosh-tekuri/jsonschema/v6"
912 "k8s.io/apimachinery/pkg/runtime"
1013 ctrl "sigs.k8s.io/controller-runtime"
1114 "sigs.k8s.io/controller-runtime/pkg/webhook"
@@ -100,22 +103,29 @@ func (r *VirtualMCPCompositeToolDefinition) Validate() error {
100103
101104// validateParameters validates the parameter schema
102105func (r * VirtualMCPCompositeToolDefinition ) validateParameters () error {
106+ // Valid JSON Schema primitive types
107+ // Reference: https://json-schema.org/understanding-json-schema/reference/type.html
108+ validTypes := map [string ]bool {
109+ "string" : true ,
110+ "integer" : true ,
111+ "number" : true ,
112+ "boolean" : true ,
113+ "array" : true ,
114+ "object" : true ,
115+ "null" : true , // null is a valid JSON Schema type
116+ }
117+
103118 for paramName , param := range r .Spec .Parameters {
104119 if param .Type == "" {
105120 return fmt .Errorf ("spec.parameters[%s].type is required" , paramName )
106121 }
107122
108- // Validate parameter type
109- validTypes := map [string ]bool {
110- "string" : true ,
111- "integer" : true ,
112- "number" : true ,
113- "boolean" : true ,
114- "array" : true ,
115- "object" : true ,
116- }
123+ // Validate that the type is a valid JSON Schema type
117124 if ! validTypes [param .Type ] {
118- return fmt .Errorf ("spec.parameters[%s].type must be one of: string, integer, number, boolean, array, object" , paramName )
125+ return fmt .Errorf (
126+ "spec.parameters[%s].type must be a valid JSON Schema type (string, integer, number, boolean, array, object, null), got: %s" ,
127+ paramName , param .Type ,
128+ )
119129 }
120130 }
121131
@@ -232,6 +242,13 @@ func (*VirtualMCPCompositeToolDefinition) validateStepTemplates(index int, step
232242 }
233243 }
234244
245+ // Validate JSON Schema for elicitation steps
246+ if step .Schema != nil {
247+ if err := validateJSONSchema (step .Schema .Raw ); err != nil {
248+ return fmt .Errorf ("spec.steps[%d].schema: invalid JSON Schema: %v" , index , err )
249+ }
250+ }
251+
235252 return nil
236253}
237254
@@ -321,6 +338,69 @@ func validateTemplate(tmpl string) error {
321338 return nil
322339}
323340
341+ // validateJSONSchema validates that the provided bytes contain a valid JSON Schema
342+ func validateJSONSchema (schemaBytes []byte ) error {
343+ if len (schemaBytes ) == 0 {
344+ return nil // Empty schema is allowed
345+ }
346+
347+ // Parse the schema JSON
348+ var schemaDoc interface {}
349+ if err := json .Unmarshal (schemaBytes , & schemaDoc ); err != nil {
350+ return fmt .Errorf ("failed to parse JSON: %v" , err )
351+ }
352+
353+ // Compile the schema to validate it's a valid JSON Schema
354+ compiler := jsonschema .NewCompiler ()
355+ schemaID := "schema://validation"
356+ if err := compiler .AddResource (schemaID , schemaDoc ); err != nil {
357+ return formatJSONSchemaError (err )
358+ }
359+
360+ if _ , err := compiler .Compile (schemaID ); err != nil {
361+ return formatJSONSchemaError (err )
362+ }
363+
364+ return nil
365+ }
366+
367+ // formatJSONSchemaError formats JSON Schema validation errors for better readability
368+ func formatJSONSchemaError (err error ) error {
369+ if validationErr , ok := err .(* jsonschema.ValidationError ); ok {
370+ var errorMessages []string
371+ collectJSONSchemaErrors (validationErr , & errorMessages )
372+ if len (errorMessages ) > 0 {
373+ return fmt .Errorf ("%s" , strings .Join (errorMessages , "; " ))
374+ }
375+ }
376+ return err
377+ }
378+
379+ // collectJSONSchemaErrors recursively collects all validation error messages
380+ func collectJSONSchemaErrors (err * jsonschema.ValidationError , messages * []string ) {
381+ if err == nil {
382+ return
383+ }
384+
385+ // If this error has causes, recurse into them
386+ if len (err .Causes ) > 0 {
387+ for _ , cause := range err .Causes {
388+ collectJSONSchemaErrors (cause , messages )
389+ }
390+ return
391+ }
392+
393+ // This is a leaf error - format it
394+ output := err .BasicOutput ()
395+ if output != nil && output .Error != nil {
396+ errorMsg := output .Error .String ()
397+ if output .InstanceLocation != "" {
398+ errorMsg = fmt .Sprintf ("%s at '%s'" , errorMsg , output .InstanceLocation )
399+ }
400+ * messages = append (* messages , errorMsg )
401+ }
402+ }
403+
324404// validateDuration validates duration format (e.g., "30s", "5m", "1h")
325405func validateDuration (duration string ) error {
326406 // Pattern: one or more segments of number + unit (ms, s, m, h)
0 commit comments