From 89b95182ad8ea99b6f30b18c3ab44ce7622f8729 Mon Sep 17 00:00:00 2001 From: ihan211 Date: Tue, 15 Apr 2025 15:57:49 -0700 Subject: [PATCH 1/6] feat(go/genkit): add DefineSchema and Dotprompt reference support --- go/core/schema.go | 114 +++++++ go/genkit/genkit.go | 20 ++ go/genkit/schema.go | 118 ++++++++ go/genkit/schema_test.go | 92 ++++++ go/go.mod | 6 +- go/go.sum | 6 - go/internal/registry/schema.go | 240 +++++++++++++++ go/samples/schema/main.go | 284 ++++++++++++++++++ .../_schema_ProductSchema.partial.prompt | 40 +++ .../schema/prompts/product_generator.prompt | 9 + 10 files changed, 918 insertions(+), 11 deletions(-) create mode 100644 go/core/schema.go create mode 100644 go/genkit/schema.go create mode 100644 go/genkit/schema_test.go create mode 100644 go/internal/registry/schema.go create mode 100644 go/samples/schema/main.go create mode 100644 go/samples/schema/prompts/_schema_ProductSchema.partial.prompt create mode 100644 go/samples/schema/prompts/product_generator.prompt diff --git a/go/core/schema.go b/go/core/schema.go new file mode 100644 index 0000000000..8a499699ad --- /dev/null +++ b/go/core/schema.go @@ -0,0 +1,114 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +// Package core provides core functionality for the genkit framework. + +package core + +import ( + "fmt" + "sync" +) + +// Schema represents a schema definition that can be of any type. +type Schema any + +// SchemaType is the type identifier for schemas in the registry. +const SchemaType = "schema" + +// schemaRegistry maintains registry of schemas. +var ( + schemasMu sync.RWMutex + schemas = make(map[string]any) + schemaLookups []func(string) any + // Keep track of schemas to register with Dotprompt + pendingSchemas = make(map[string]Schema) +) + +// RegisterSchema registers a schema with the given name. +// This is intended to be called by higher-level packages like ai. +func RegisterSchema(name string, schema any) { + if name == "" { + panic("core.RegisterSchema: schema name cannot be empty") + } + + if schema == nil { + panic("core.RegisterSchema: schema definition cannot be nil") + } + + schemasMu.Lock() + defer schemasMu.Unlock() + + if _, exists := schemas[name]; exists { + panic(fmt.Sprintf("core.RegisterSchema: schema with name %q already exists", name)) + } + + schemas[name] = schema +} + +// LookupSchema looks up a schema by name. +// It first checks the local registry, and if not found, +// it calls each registered lookup function until one returns a non-nil result. +func LookupSchema(name string) any { + schemasMu.RLock() + defer schemasMu.RUnlock() + + // First check local registry + if schema, ok := schemas[name]; ok { + return schema + } + + // Then try lookup functions + for _, lookup := range schemaLookups { + if schema := lookup(name); schema != nil { + return schema + } + } + + return nil +} + +// RegisterSchemaLookup registers a function that can look up schemas by name. +// This allows different packages to provide schemas while maintaining a +// unified lookup mechanism. +func RegisterSchemaLookup(lookup func(string) any) { + schemasMu.Lock() + defer schemasMu.Unlock() + + schemaLookups = append(schemaLookups, lookup) +} + +// GetSchemas returns a copy of all registered schemas. +func GetSchemas() map[string]any { + schemasMu.RLock() + defer schemasMu.RUnlock() + + result := make(map[string]any, len(schemas)) + for name, schema := range schemas { + result[name] = schema + } + + return result +} + +// ClearSchemas removes all registered schemas. +// This is primarily for testing purposes. +func ClearSchemas() { + schemasMu.Lock() + defer schemasMu.Unlock() + + schemas = make(map[string]any) + schemaLookups = nil +} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 2473b0dd68..d4b1be0b8a 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -210,6 +210,11 @@ func Init(ctx context.Context, opts ...GenkitOption) (*Genkit, error) { g := &Genkit{reg: r} + // Register schemas with Dotprompt before loading plugins or prompt files + if err := registerPendingSchemas(r); err != nil { + return nil, fmt.Errorf("genkit.Init: error registering schemas: %w", err) + } + for _, plugin := range gOpts.Plugins { if err := plugin.Init(ctx, g); err != nil { return nil, fmt.Errorf("genkit.Init: plugin %T initialization failed: %w", plugin, err) @@ -245,6 +250,21 @@ func Init(ctx context.Context, opts ...GenkitOption) (*Genkit, error) { return g, nil } +// Internal function called during Init to register pending schemas +func registerPendingSchemas(reg *registry.Registry) error { + schemasMu.Lock() + defer schemasMu.Unlock() + + for name, schema := range pendingSchemas { + if err := reg.RegisterSchemaWithDotprompt(name, schema); err != nil { + return fmt.Errorf("failed to register schema %s: %w", name, err) + } + } + // Clear pending schemas + pendingSchemas = make(map[string]Schema) + return nil +} + // DefineFlow defines a non-streaming flow, registers it as a [core.Action] of type Flow, // and returns a [core.Flow] runner. // The provided function `fn` takes an input of type `In` and returns an output of type `Out`. diff --git a/go/genkit/schema.go b/go/genkit/schema.go new file mode 100644 index 0000000000..bb6521bb14 --- /dev/null +++ b/go/genkit/schema.go @@ -0,0 +1,118 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package genkit + +import ( + "fmt" + "sync" + + "github.com/firebase/genkit/go/core" + "github.com/google/dotprompt/go/dotprompt" + "github.com/invopop/jsonschema" +) + +// Schema is an alias for core.Schema to maintain compatibility with existing type definitions +type Schema = core.Schema + +// schemasMu and pendingSchemas are maintained for backward compatibility +var ( + schemasMu sync.RWMutex + pendingSchemas = make(map[string]Schema) +) + +// DefineSchema registers a schema that can be referenced by name in genkit. +// This allows schemas to be defined once and used across the AI generation pipeline. +// +// Example usage: +// +// type Person struct { +// Name string `json:"name"` +// Age int `json:"age"` +// } +// +// personSchema := genkit.DefineSchema("Person", Person{}) +func DefineSchema(name string, schema Schema) Schema { + if name == "" { + panic("genkit.DefineSchema: schema name cannot be empty") + } + + if schema == nil { + panic("genkit.DefineSchema: schema cannot be nil") + } + + // Register with core registry + core.RegisterSchema(name, schema) + + // Also track for Dotprompt integration + schemasMu.Lock() + defer schemasMu.Unlock() + pendingSchemas[name] = schema + + return schema +} + +// LookupSchema retrieves a registered schema by name. +// It returns nil and false if no schema exists with that name. +func LookupSchema(name string) (Schema, bool) { + schema := core.LookupSchema(name) + return schema, schema != nil +} + +// GetSchema retrieves a registered schema by name. +// It returns an error if no schema exists with that name. +func GetSchema(name string) (Schema, error) { + schema, exists := LookupSchema(name) + if !exists { + return nil, fmt.Errorf("genkit: schema '%s' not found", name) + } + return schema, nil +} + +// registerSchemaResolver registers a schema resolver with Dotprompt to handle schema lookups +func registerSchemaResolver(dp *dotprompt.Dotprompt) { + // Create a schema resolver that can look up schemas from the Genkit registry + schemaResolver := func(name string) any { + schema, exists := LookupSchema(name) + if !exists { + fmt.Printf("Schema '%s' not found in registry\n", name) + return nil + } + + // Convert the schema to a JSON schema + reflector := jsonschema.Reflector{} + jsonSchema := reflector.Reflect(schema) + return jsonSchema + } + + // Register the resolver with Dotprompt + dp.RegisterExternalSchemaLookup(schemaResolver) +} + +// RegisterGlobalSchemaResolver exports the schema lookup capabilities for use in other packages +func RegisterGlobalSchemaResolver(dp *dotprompt.Dotprompt) { + dp.RegisterExternalSchemaLookup(func(name string) any { + schema, exists := LookupSchema(name) + if !exists { + return nil + } + + // Convert the schema to a JSON schema + reflector := jsonschema.Reflector{} + jsonSchema := reflector.Reflect(schema) + return jsonSchema + }) +} diff --git a/go/genkit/schema_test.go b/go/genkit/schema_test.go new file mode 100644 index 0000000000..9fbdde7ac0 --- /dev/null +++ b/go/genkit/schema_test.go @@ -0,0 +1,92 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package genkit + +import ( + "testing" +) + +type TestStruct struct { + Name string + Age int +} + +func TestDefineAndLookupSchema(t *testing.T) { + schemaName := "TestStruct" + testSchema := TestStruct{Name: "Alice", Age: 30} + + // Define the schema + DefineSchema(schemaName, testSchema) + + // Lookup the schema + schema, found := LookupSchema(schemaName) + if !found { + t.Fatalf("Expected schema '%s' to be found", schemaName) + } + + // Assert the type + typedSchema, ok := schema.(TestStruct) + if !ok { + t.Fatalf("Expected schema to be of type TestStruct") + } + + if typedSchema.Name != "Alice" || typedSchema.Age != 30 { + t.Errorf("Unexpected schema contents: %+v", typedSchema) + } +} + +func TestGetSchemaSuccess(t *testing.T) { + schemaName := "GetStruct" + testSchema := TestStruct{Name: "Bob", Age: 25} + DefineSchema(schemaName, testSchema) + + schema, err := GetSchema(schemaName) + if err != nil { + t.Fatalf("Expected schema '%s' to be retrieved without error", schemaName) + } + + typedSchema := schema.(TestStruct) + if typedSchema.Name != "Bob" || typedSchema.Age != 25 { + t.Errorf("Unexpected schema contents: %+v", typedSchema) + } +} + +func TestGetSchemaNotFound(t *testing.T) { + _, err := GetSchema("NonExistentSchema") + if err == nil { + t.Fatal("Expected error when retrieving a non-existent schema") + } +} + +func TestDefineSchemaPanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("Expected panic for empty schema name") + } + }() + DefineSchema("", TestStruct{}) +} + +func TestDefineSchemaNilPanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("Expected panic for nil schema") + } + }() + var nilSchema Schema + DefineSchema("NilSchema", nilSchema) +} diff --git a/go/go.mod b/go/go.mod index 2df94ef6b5..5dbced3b15 100644 --- a/go/go.mod +++ b/go/go.mod @@ -23,6 +23,7 @@ require ( github.com/pgvector/pgvector-go v0.3.0 github.com/weaviate/weaviate v1.26.0-rc.1 github.com/weaviate/weaviate-go-client/v4 v4.15.0 + github.com/wk8/go-ordered-map/v2 v2.1.8 github.com/xeipuuv/gojsonschema v1.2.0 go.opentelemetry.io/otel v1.29.0 go.opentelemetry.io/otel/metric v1.29.0 @@ -54,7 +55,6 @@ require ( github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect - github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/analysis v0.21.2 // indirect @@ -80,17 +80,13 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect go.mongodb.org/mongo-driver v1.14.0 // indirect go.opencensus.io v0.24.0 // indirect - go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/net v0.37.0 // indirect golang.org/x/oauth2 v0.23.0 // indirect - golang.org/x/sync v0.12.0 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect golang.org/x/time v0.6.0 // indirect diff --git a/go/go.sum b/go/go.sum index 48450c251f..d1f5e20556 100644 --- a/go/go.sum +++ b/go/go.sum @@ -60,8 +60,6 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= -github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -300,10 +298,6 @@ go.mongodb.org/mongo-driver v1.14.0 h1:P98w8egYRjYe3XDjxhYJagTokP/H6HzlsnojRgZRd go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8= go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= diff --git a/go/internal/registry/schema.go b/go/internal/registry/schema.go new file mode 100644 index 0000000000..9f04690fa4 --- /dev/null +++ b/go/internal/registry/schema.go @@ -0,0 +1,240 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package registry + +import ( + "fmt" + "reflect" + "strings" + + "github.com/google/dotprompt/go/dotprompt" + "github.com/invopop/jsonschema" + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +// DefineSchema registers a Go struct as a schema with the given name. +func (r *Registry) DefineSchema(name string, structType any) error { + jsonSchema, err := convertStructToJsonSchema(structType) + if err != nil { + return err + } + + if r.Dotprompt == nil { + r.Dotprompt = dotprompt.NewDotprompt(&dotprompt.DotpromptOptions{ + Schemas: map[string]*jsonschema.Schema{}, + }) + } + + r.Dotprompt.DefineSchema(name, jsonSchema) + r.RegisterValue("schema/"+name, structType) + fmt.Printf("Registered schema '%s' with registry and Dotprompt\n", name) + return nil +} + +// RegisterSchemaWithDotprompt registers a schema with the Dotprompt instance +func (r *Registry) RegisterSchemaWithDotprompt(name string, schema any) error { + if r.Dotprompt == nil { + r.Dotprompt = dotprompt.NewDotprompt(&dotprompt.DotpromptOptions{ + Schemas: map[string]*jsonschema.Schema{}, + }) + } + + jsonSchema, err := convertStructToJsonSchema(schema) + if err != nil { + return err + } + + r.Dotprompt.DefineSchema(name, jsonSchema) + r.RegisterValue("schema/"+name, schema) + r.setupSchemaLookupFunction() + + return nil +} + +// setupSchemaLookupFunction registers the external schema lookup function with Dotprompt +// This function bridges between Dotprompt's schema resolution and the registry's values +func (r *Registry) setupSchemaLookupFunction() { + if r.Dotprompt == nil { + fmt.Println("Warning: No Dotprompt instance to set up schema lookup for") + return + } + + fmt.Println("Registering external schema lookup function with Dotprompt") + r.Dotprompt.RegisterExternalSchemaLookup(func(schemaName string) any { + fmt.Printf("External schema lookup for '%s'\n", schemaName) + + schemaValue := r.LookupValue("schema/" + schemaName) + if schemaValue != nil { + fmt.Printf("Found schema '%s' in registry values\n", schemaName) + return schemaValue + } else { + fmt.Printf("Schema '%s' not found in registry values\n", schemaName) + } + return nil + }) +} + +// DumpRegistrySchemas prints all schemas stored in the registry +func (r *Registry) DumpRegistrySchemas() { + fmt.Println("=== Registry Schemas ===") + + for k, v := range r.values { + if strings.HasPrefix(k, "schema/") { + schemaName := strings.TrimPrefix(k, "schema/") + fmt.Printf("Schema: %s, Type: %T\n", schemaName, v) + } + } + + fmt.Println("=======================") +} + +// convertStructToJsonSchema converts a Go struct to a JSON schema +func convertStructToJsonSchema(structType any) (*jsonschema.Schema, error) { + fmt.Printf("Converting schema of type %T to JSON Schema\n", structType) + + t := reflect.TypeOf(structType) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("expected struct type, got %s", t.Kind()) + } + + schema := &jsonschema.Schema{ + Type: "object", + Properties: orderedmap.New[string, *jsonschema.Schema](), + Required: []string{}, + } + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + if field.PkgPath != "" { + continue + } + + jsonTag := field.Tag.Get("json") + parts := strings.Split(jsonTag, ",") + propName := parts[0] + if propName == "" { + propName = field.Name + } + + if propName == "-" { + continue + } + + isRequired := true + for _, opt := range parts[1:] { + if opt == "omitempty" { + isRequired = false + break + } + } + + if isRequired { + schema.Required = append(schema.Required, propName) + } + + description := field.Tag.Get("description") + + fieldSchema := fieldToSchema(field.Type, description) + schema.Properties.Set(propName, fieldSchema) + } + + return schema, nil +} + +// fieldToSchema converts a field type to a JSON Schema. +func fieldToSchema(t reflect.Type, description string) *jsonschema.Schema { + schema := &jsonschema.Schema{} + + if description != "" { + schema.Description = description + } + + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + switch t.Kind() { + case reflect.String: + schema.Type = "string" + case reflect.Bool: + schema.Type = "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + schema.Type = "integer" + case reflect.Float32, reflect.Float64: + schema.Type = "number" + case reflect.Slice, reflect.Array: + schema.Type = "array" + itemSchema := fieldToSchema(t.Elem(), "") + schema.Items = itemSchema + case reflect.Map: + schema.Type = "object" + if t.Key().Kind() == reflect.String { + valueSchema := fieldToSchema(t.Elem(), "") + schema.AdditionalProperties = valueSchema + } + case reflect.Struct: + schema.Type = "object" + schema.Properties = orderedmap.New[string, *jsonschema.Schema]() + schema.Required = []string{} + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + if field.PkgPath != "" { + continue + } + + jsonTag := field.Tag.Get("json") + parts := strings.Split(jsonTag, ",") + propName := parts[0] + if propName == "" { + propName = field.Name + } + + if propName == "-" { + continue + } + + isRequired := true + for _, opt := range parts[1:] { + if opt == "omitempty" { + isRequired = false + break + } + } + + if isRequired { + schema.Required = append(schema.Required, propName) + } + + fieldDescription := field.Tag.Get("description") + + fieldSchema := fieldToSchema(field.Type, fieldDescription) + schema.Properties.Set(propName, fieldSchema) + } + default: + schema.Type = "string" + } + + return schema +} diff --git a/go/samples/schema/main.go b/go/samples/schema/main.go new file mode 100644 index 0000000000..fe127e142e --- /dev/null +++ b/go/samples/schema/main.go @@ -0,0 +1,284 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +/* +Product Generator using Genkit and Dotprompt + +This application demonstrates a structured product generation system that uses: +- Genkit: A framework for managing AI model interactions and prompts +- Dotprompt: A library for working with structured prompts and JSON schemas +- JSON Schema: For defining the structure of generated product data + +The program: +1. Defines a ProductSchema struct for structured product data +2. Creates a mock AI model plugin that returns predefined product data +3. Generates and saves JSON schema files in a prompts directory +4. Creates a prompt template that takes a theme as input and outputs a product +5. Initializes Dotprompt with schema resolution capabilities +6. Executes the prompt with an "eco-friendly" theme +7. Parses the structured response and displays the generated product + +The mock implementation simulates what would happen with a real AI model +by returning different products based on detected themes in the input. +This provides a testable framework for structured AI outputs conforming +to the defined schema. +*/ + +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "strings" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/google/dotprompt/go/dotprompt" + "github.com/invopop/jsonschema" +) + +// ProductSchema defines our product output structure +// This schema will be used for structured outputs from AI models +type ProductSchema struct { + Name string `json:"name"` + Description string `json:"description"` + Price float64 `json:"price"` + Category string `json:"category"` + InStock bool `json:"inStock"` +} + +// MockPlugin implements the genkit.Plugin interface +// It provides a custom model implementation for testing purposes +type MockPlugin struct{} + +// Name returns the unique identifier for the plugin +func (p *MockPlugin) Name() string { + return "mock" +} + +// Init initializes the plugin with the Genkit instance +// It registers a mock model that returns predefined product data +func (p *MockPlugin) Init(ctx context.Context, g *genkit.Genkit) error { + genkit.DefineModel(g, "mock", "product-model", + &ai.ModelInfo{ + Label: "Mock Product Model", + Supports: &ai.ModelSupports{}, + }, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + product := ProductSchema{ + Name: "Eco-Friendly Bamboo Cutting Board", + Description: "A sustainable cutting board made from 100% bamboo. Features a juice groove and handle.", + Price: 29.99, + Category: "Kitchen Accessories", + InStock: true, + } + + jsonBytes, err := json.Marshal(product) + if err != nil { + return nil, err + } + + resp := &ai.ModelResponse{ + Message: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ai.NewTextPart(string(jsonBytes))}, + }, + FinishReason: ai.FinishReasonStop, + } + + return resp, nil + }) + + return nil +} + +func main() { + ctx := context.Background() + + cwd, _ := os.Getwd() + promptDir := filepath.Join(cwd, "prompts") + + if _, err := os.Stat(promptDir); os.IsNotExist(err) { + if err := os.MkdirAll(promptDir, 0755); err != nil { + log.Fatalf("Failed to create prompt directory: %v", err) + } + } + + schemaFilePath := filepath.Join(promptDir, "_schema_ProductSchema.partial.prompt") + + reflector := jsonschema.Reflector{} + schema := reflector.Reflect(ProductSchema{}) + + // Structure the schema according to what Dotprompt expects + schemaWrapper := struct { + Schema string `json:"$schema"` + Ref string `json:"$ref"` + Definitions map[string]*jsonschema.Schema `json:"$defs"` + }{ + Schema: "https://json-schema.org/draft/2020-12/schema", + Ref: "#/$defs/ProductSchema", + Definitions: map[string]*jsonschema.Schema{ + "ProductSchema": schema, + }, + } + + schemaJSON, err := json.MarshalIndent(schemaWrapper, "", " ") + if err != nil { + log.Fatalf("Failed to marshal schema: %v", err) + } + + if err := os.WriteFile(schemaFilePath, schemaJSON, 0644); err != nil { + log.Fatalf("Failed to write schema file: %v", err) + } + + // Create prompt file with schema reference + promptFilePath := filepath.Join(promptDir, "product_generator.prompt") + promptContent := "---\n" + + "input:\n" + + " schema:\n" + + " theme: string\n" + + "output:\n" + + " schema: ProductSchema\n" + + "---\n" + + "Generate a product that fits the {{theme}} theme.\n" + + "Make sure to provide a detailed description and appropriate pricing." + + if err := os.WriteFile(promptFilePath, []byte(promptContent), 0644); err != nil { + log.Fatalf("Failed to write prompt file: %v", err) + } + + // Testing with dotprompt directly + dp := dotprompt.NewDotprompt(&dotprompt.DotpromptOptions{ + Schemas: map[string]*jsonschema.Schema{}, + }) + + // Register external schema lookup function + dp.RegisterExternalSchemaLookup(func(schemaName string) any { + if schemaName == "ProductSchema" { + return schema + } + return nil + }) + + metadata := map[string]any{ + "output": map[string]any{ + "schema": "ProductSchema", + }, + } + + if err = dp.ResolveSchemaReferences(metadata); err != nil { + log.Fatalf("Schema resolution failed: %v", err) + } + + // Define our schema with Genkit + genkit.DefineSchema("ProductSchema", ProductSchema{}) + + // Initialize Genkit with our prompt directory + g, err := genkit.Init(ctx, + genkit.WithPromptDir(promptDir), + genkit.WithDefaultModel("mock/default-model")) + if err != nil { + log.Fatalf("Failed to initialize Genkit: %v", err) + } + + // Define a mock model to respond to prompts + genkit.DefineModel(g, "mock", "default-model", + &ai.ModelInfo{ + Label: "Mock Default Model", + Supports: &ai.ModelSupports{}, + }, + func(ctx context.Context, req *ai.ModelRequest, cb ai.ModelStreamCallback) (*ai.ModelResponse, error) { + // Extract theme from the request to customize the response + theme := "generic" + if len(req.Messages) > 0 { + lastMsg := req.Messages[len(req.Messages)-1] + if lastMsg.Role == ai.RoleUser { + for _, part := range lastMsg.Content { + if part.IsText() && strings.Contains(part.Text, "eco-friendly") { + theme = "eco-friendly" + } + } + } + } + + // Generate appropriate product based on theme + var product ProductSchema + if theme == "eco-friendly" { + product = ProductSchema{ + Name: "Eco-Friendly Bamboo Cutting Board", + Description: "A sustainable cutting board made from 100% bamboo. Features a juice groove and handle.", + Price: 29.99, + Category: "Kitchen Accessories", + InStock: true, + } + } else { + product = ProductSchema{ + Name: "Classic Stainless Steel Water Bottle", + Description: "Durable 24oz water bottle with vacuum insulation. Keeps drinks cold for 24 hours.", + Price: 24.99, + Category: "Drinkware", + InStock: true, + } + } + + jsonBytes, err := json.Marshal(product) + if err != nil { + return nil, err + } + + resp := &ai.ModelResponse{ + Message: &ai.Message{ + Role: ai.RoleModel, + Content: []*ai.Part{ai.NewTextPart(string(jsonBytes))}, + }, + FinishReason: ai.FinishReasonStop, + } + + return resp, nil + }) + + // Look up and execute the prompt + productPrompt := genkit.LookupPrompt(g, "local", "product_generator") + if productPrompt == nil { + log.Fatalf("Prompt 'product_generator' not found") + } + + input := map[string]any{ + "theme": "eco-friendly kitchen gadgets", + } + + resp, err := productPrompt.Execute(ctx, ai.WithInput(input)) + if err != nil { + log.Fatalf("Failed to execute prompt: %v", err) + } + + // Parse the structured response into our Go struct + var product ProductSchema + if err := resp.Output(&product); err != nil { + log.Fatalf("Failed to parse response: %v", err) + } + + fmt.Println("\nGenerated Product:") + fmt.Printf("Name: %s\n", product.Name) + fmt.Printf("Description: %s\n", product.Description) + fmt.Printf("Price: $%.2f\n", product.Price) + fmt.Printf("Category: %s\n", product.Category) + fmt.Printf("In Stock: %v\n", product.InStock) +} diff --git a/go/samples/schema/prompts/_schema_ProductSchema.partial.prompt b/go/samples/schema/prompts/_schema_ProductSchema.partial.prompt new file mode 100644 index 0000000000..bd3a4cdc31 --- /dev/null +++ b/go/samples/schema/prompts/_schema_ProductSchema.partial.prompt @@ -0,0 +1,40 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "#/$defs/ProductSchema", + "$defs": { + "ProductSchema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$ref": "#/$defs/ProductSchema", + "$defs": { + "ProductSchema": { + "properties": { + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "price": { + "type": "number" + }, + "category": { + "type": "string" + }, + "inStock": { + "type": "boolean" + } + }, + "additionalProperties": false, + "type": "object", + "required": [ + "name", + "description", + "price", + "category", + "inStock" + ] + } + } + } + } +} \ No newline at end of file diff --git a/go/samples/schema/prompts/product_generator.prompt b/go/samples/schema/prompts/product_generator.prompt new file mode 100644 index 0000000000..2c1a127ff4 --- /dev/null +++ b/go/samples/schema/prompts/product_generator.prompt @@ -0,0 +1,9 @@ +--- +input: + schema: + theme: string +output: + schema: ProductSchema +--- +Generate a product that fits the {{theme}} theme. +Make sure to provide a detailed description and appropriate pricing. \ No newline at end of file From d75d77ee894d521a06d3fbdb03880dc441e52afa Mon Sep 17 00:00:00 2001 From: Iven Han Date: Mon, 21 Apr 2025 12:00:41 -0700 Subject: [PATCH 2/6] feat(go/genkit): Refactor schema handling to resolve import cycles and improve code organization - Move pendingSchemas management to core package to avoid validation duplication - Add RegisterGlobalSchemaResolver that follows DRY by calling registerSchemaResolver - Create recursive helper function for field-to-schema conversion to support nested structs - Fix import cycles between genkit, core, and registry packages - Standardize schema path construction with consistent separators - Add functions to retrieve and clear pending schemas - Update schema registration during initialization process - Fix LookupPrompt function signature to match usage This refactoring addresses code review feedback by: 1. Eliminating duplicate validation between genkit.DefineSchema and core.RegisterSchema 2. Following DRY principle for schema resolver registration 3. Improving support for nested struct types with recursive field conversion 4. Creating a cleaner architecture with proper separation of concerns --- go/core/schema.go | 37 ++++++++++++++++++++++++++++---- go/genkit/genkit.go | 9 ++++---- go/go.mod | 4 ++++ go/go.sum | 6 ++++++ go/internal/registry/registry.go | 1 + go/internal/registry/schema.go | 20 +++++++---------- go/samples/schema/main.go | 2 +- 7 files changed, 58 insertions(+), 21 deletions(-) diff --git a/go/core/schema.go b/go/core/schema.go index 8a499699ad..c2d94b0572 100644 --- a/go/core/schema.go +++ b/go/core/schema.go @@ -25,9 +25,6 @@ import ( // Schema represents a schema definition that can be of any type. type Schema any -// SchemaType is the type identifier for schemas in the registry. -const SchemaType = "schema" - // schemaRegistry maintains registry of schemas. var ( schemasMu sync.RWMutex @@ -39,7 +36,10 @@ var ( // RegisterSchema registers a schema with the given name. // This is intended to be called by higher-level packages like ai. -func RegisterSchema(name string, schema any) { +// It validates that the name is not empty and the schema is not nil, +// then registers the schema in the core schemas map. +// Returns the schema for convenience in chaining operations. +func RegisterSchema(name string, schema any) Schema { if name == "" { panic("core.RegisterSchema: schema name cannot be empty") } @@ -56,6 +56,9 @@ func RegisterSchema(name string, schema any) { } schemas[name] = schema + pendingSchemas[name] = schema + + return schema } // LookupSchema looks up a schema by name. @@ -110,5 +113,31 @@ func ClearSchemas() { defer schemasMu.Unlock() schemas = make(map[string]any) + pendingSchemas = make(map[string]Schema) + schemaLookups = nil +} + +// GetPendingSchemas returns a copy of pending schemas that need to be +// registered with Dotprompt. +func GetPendingSchemas() map[string]Schema { + schemasMu.RLock() + defer schemasMu.RUnlock() + + result := make(map[string]Schema, len(pendingSchemas)) + for name, schema := range pendingSchemas { + result[name] = schema + } + + return result +} + +// ClearPendingSchemas clears the pending schemas map. +// This is called after the schemas have been registered with Dotprompt. +func ClearPendingSchemas() { + schemasMu.Lock() + defer schemasMu.Unlock() + + schemas = make(map[string]any) + pendingSchemas = make(map[string]Schema) schemaLookups = nil } diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index d4b1be0b8a..c5eca209b5 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -252,16 +252,17 @@ func Init(ctx context.Context, opts ...GenkitOption) (*Genkit, error) { // Internal function called during Init to register pending schemas func registerPendingSchemas(reg *registry.Registry) error { - schemasMu.Lock() - defer schemasMu.Unlock() + // Get pending schemas from core + pendingSchemas := core.GetPendingSchemas() for name, schema := range pendingSchemas { if err := reg.RegisterSchemaWithDotprompt(name, schema); err != nil { return fmt.Errorf("failed to register schema %s: %w", name, err) } } - // Clear pending schemas - pendingSchemas = make(map[string]Schema) + + // Clear pending schemas after registration + core.ClearPendingSchemas() return nil } diff --git a/go/go.mod b/go/go.mod index 5dbced3b15..4b8278f304 100644 --- a/go/go.mod +++ b/go/go.mod @@ -55,6 +55,7 @@ require ( github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/analysis v0.21.2 // indirect @@ -84,9 +85,12 @@ require ( github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect go.mongodb.org/mongo-driver v1.14.0 // indirect go.opencensus.io v0.24.0 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/net v0.37.0 // indirect golang.org/x/oauth2 v0.23.0 // indirect + golang.org/x/sync v0.12.0 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect golang.org/x/time v0.6.0 // indirect diff --git a/go/go.sum b/go/go.sum index d1f5e20556..48450c251f 100644 --- a/go/go.sum +++ b/go/go.sum @@ -60,6 +60,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -298,6 +300,10 @@ go.mongodb.org/mongo-driver v1.14.0 h1:P98w8egYRjYe3XDjxhYJagTokP/H6HzlsnojRgZRd go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8= go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= diff --git a/go/internal/registry/registry.go b/go/internal/registry/registry.go index 8cf0ef80c8..2ba4622039 100644 --- a/go/internal/registry/registry.go +++ b/go/internal/registry/registry.go @@ -36,6 +36,7 @@ import ( const ( DefaultModelKey = "genkit/defaultModel" PromptDirKey = "genkit/promptDir" + SchemaType = "schema" ) type Registry struct { diff --git a/go/internal/registry/schema.go b/go/internal/registry/schema.go index 9f04690fa4..c2c9281a0f 100644 --- a/go/internal/registry/schema.go +++ b/go/internal/registry/schema.go @@ -40,12 +40,13 @@ func (r *Registry) DefineSchema(name string, structType any) error { } r.Dotprompt.DefineSchema(name, jsonSchema) - r.RegisterValue("schema/"+name, structType) + r.RegisterValue(SchemaType+"/"+name, structType) fmt.Printf("Registered schema '%s' with registry and Dotprompt\n", name) return nil } // RegisterSchemaWithDotprompt registers a schema with the Dotprompt instance +// This is used during Init to register schemas that were defined before the registry was created. func (r *Registry) RegisterSchemaWithDotprompt(name string, schema any) error { if r.Dotprompt == nil { r.Dotprompt = dotprompt.NewDotprompt(&dotprompt.DotpromptOptions{ @@ -59,7 +60,9 @@ func (r *Registry) RegisterSchemaWithDotprompt(name string, schema any) error { } r.Dotprompt.DefineSchema(name, jsonSchema) - r.RegisterValue("schema/"+name, schema) + r.RegisterValue(SchemaType+"/"+name, schema) + + // Set up schema lookup if not already done r.setupSchemaLookupFunction() return nil @@ -69,20 +72,13 @@ func (r *Registry) RegisterSchemaWithDotprompt(name string, schema any) error { // This function bridges between Dotprompt's schema resolution and the registry's values func (r *Registry) setupSchemaLookupFunction() { if r.Dotprompt == nil { - fmt.Println("Warning: No Dotprompt instance to set up schema lookup for") return } - fmt.Println("Registering external schema lookup function with Dotprompt") r.Dotprompt.RegisterExternalSchemaLookup(func(schemaName string) any { - fmt.Printf("External schema lookup for '%s'\n", schemaName) - - schemaValue := r.LookupValue("schema/" + schemaName) + schemaValue := r.LookupValue(SchemaType + "/" + schemaName) if schemaValue != nil { - fmt.Printf("Found schema '%s' in registry values\n", schemaName) return schemaValue - } else { - fmt.Printf("Schema '%s' not found in registry values\n", schemaName) } return nil }) @@ -93,8 +89,8 @@ func (r *Registry) DumpRegistrySchemas() { fmt.Println("=== Registry Schemas ===") for k, v := range r.values { - if strings.HasPrefix(k, "schema/") { - schemaName := strings.TrimPrefix(k, "schema/") + if strings.HasPrefix(k, SchemaType+"/") { + schemaName := strings.TrimPrefix(k, SchemaType+"/") fmt.Printf("Schema: %s, Type: %T\n", schemaName, v) } } diff --git a/go/samples/schema/main.go b/go/samples/schema/main.go index fe127e142e..78525ed19c 100644 --- a/go/samples/schema/main.go +++ b/go/samples/schema/main.go @@ -255,7 +255,7 @@ func main() { }) // Look up and execute the prompt - productPrompt := genkit.LookupPrompt(g, "local", "product_generator") + productPrompt := genkit.LookupPrompt(g, "product_generator") if productPrompt == nil { log.Fatalf("Prompt 'product_generator' not found") } From e18f609c23f0db49f0cddb7615d19cdaa73e87b1 Mon Sep 17 00:00:00 2001 From: Iven Han Date: Mon, 21 Apr 2025 16:47:20 -0700 Subject: [PATCH 3/6] feat(go/genkit): Fix dependency conflict with dotprompt by adding replace directive --- go/go.mod | 2 ++ go/go.sum | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/go/go.mod b/go/go.mod index 4b8278f304..a54110a0a5 100644 --- a/go/go.mod +++ b/go/go.mod @@ -101,3 +101,5 @@ require ( google.golang.org/grpc v1.66.2 // indirect google.golang.org/protobuf v1.34.2 // indirect ) + +replace github.com/google/dotprompt/go => github.com/ihan211/dotprompt/go v0.0.0-20250421000000-71f278572f8dd5a73b72814f93c774cddb5f9562 diff --git a/go/go.sum b/go/go.sum index 48450c251f..7d342a630a 100644 --- a/go/go.sum +++ b/go/go.sum @@ -145,8 +145,6 @@ github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/dotprompt/go v0.0.0-20250415074656-072d95deb01d h1:ChKKjq8F7GcNKViCCB/vRoU6joR7IDsZgu1I4wg6RjQ= -github.com/google/dotprompt/go v0.0.0-20250415074656-072d95deb01d/go.mod h1:dnIk+MSMnipm9uZyPIgptq7I39aDxyjBiaev/OG0W0Y= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -170,6 +168,8 @@ github.com/googleapis/gax-go/v2 v2.13.0 h1:yitjD5f7jQHhyDsnhKEBU52NdvvdSeGzlAnDP github.com/googleapis/gax-go/v2 v2.13.0/go.mod h1:Z/fvTZXF8/uw7Xu5GuslPw+bplx6SS338j1Is2S+B7A= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/ihan211/dotprompt/go v0.0.0-20250421213448-303a393ed401 h1:MldUWpgA5fV4SdzfM/N9RYGlz7lB2Hb37FGlKhU5btk= +github.com/ihan211/dotprompt/go v0.0.0-20250421213448-303a393ed401/go.mod h1:bJnw7xUiojzhNtAYYnp+I6wMI03WOPhdYWVahrwvk6g= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= From 50d7d6e9411e96e6ea2203df789cf2fa20febce9 Mon Sep 17 00:00:00 2001 From: Iven Han Date: Tue, 22 Apr 2025 13:59:32 -0700 Subject: [PATCH 4/6] Update dotprompt dependency to use add-schema-support branch --- go/go.mod | 4 ++-- go/go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go/go.mod b/go/go.mod index a54110a0a5..ae1d8a1824 100644 --- a/go/go.mod +++ b/go/go.mod @@ -53,6 +53,7 @@ require ( github.com/PuerkitoBio/purell v1.1.1 // indirect github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect + github.com/aymerick/raymond v2.0.2+incompatible // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect @@ -77,7 +78,6 @@ require ( github.com/gorilla/websocket v1.5.3 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/mailru/easyjson v0.9.0 // indirect - github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -102,4 +102,4 @@ require ( google.golang.org/protobuf v1.34.2 // indirect ) -replace github.com/google/dotprompt/go => github.com/ihan211/dotprompt/go v0.0.0-20250421000000-71f278572f8dd5a73b72814f93c774cddb5f9562 +replace github.com/google/dotprompt/go => github.com/google/dotprompt/go v0.0.0-20250422204256-6029fef7a2fd diff --git a/go/go.sum b/go/go.sum index 7d342a630a..e6efdac1d5 100644 --- a/go/go.sum +++ b/go/go.sum @@ -43,6 +43,8 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/aymerick/raymond v2.0.2+incompatible h1:VEp3GpgdAnv9B2GFyTvqgcKvY+mfKMjPOA3SbKLtnU0= +github.com/aymerick/raymond v2.0.2+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/blues/jsonata-go v1.5.4 h1:XCsXaVVMrt4lcpKeJw6mNJHqQpWU751cnHdCFUq3xd8= @@ -145,6 +147,8 @@ github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/dotprompt/go v0.0.0-20250422204256-6029fef7a2fd h1:LmVYfpTt3dDDYoBqziibAZf2lfMOcOf5MfkFDyoDrPg= +github.com/google/dotprompt/go v0.0.0-20250422204256-6029fef7a2fd/go.mod h1:wVZXOPYuasZIfPu6UQvYxODdVUR2nIligI4SWs47GVs= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -168,8 +172,6 @@ github.com/googleapis/gax-go/v2 v2.13.0 h1:yitjD5f7jQHhyDsnhKEBU52NdvvdSeGzlAnDP github.com/googleapis/gax-go/v2 v2.13.0/go.mod h1:Z/fvTZXF8/uw7Xu5GuslPw+bplx6SS338j1Is2S+B7A= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/ihan211/dotprompt/go v0.0.0-20250421213448-303a393ed401 h1:MldUWpgA5fV4SdzfM/N9RYGlz7lB2Hb37FGlKhU5btk= -github.com/ihan211/dotprompt/go v0.0.0-20250421213448-303a393ed401/go.mod h1:bJnw7xUiojzhNtAYYnp+I6wMI03WOPhdYWVahrwvk6g= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= @@ -213,8 +215,6 @@ github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4 github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= -github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a h1:v2cBA3xWKv2cIOVhnzX/gNgkNXqiHfUgJtA3r61Hf7A= -github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a/go.mod h1:Y6ghKH+ZijXn5d9E7qGGZBmjitx7iitZdQiIW97EpTU= github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= From 0aefcf3b2117cfb84cceb6cc61125b3462299d67 Mon Sep 17 00:00:00 2001 From: Iven Han Date: Tue, 22 Apr 2025 16:55:13 -0700 Subject: [PATCH 5/6] feat(go/genkit): Refactor schema registry: rename Get prefixed functions and improve naming convention --- go/core/schema.go | 9 ++++----- go/genkit/genkit.go | 2 +- go/genkit/schema.go | 5 ++--- go/genkit/schema_test.go | 8 ++++---- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/go/core/schema.go b/go/core/schema.go index c2d94b0572..86fda44a95 100644 --- a/go/core/schema.go +++ b/go/core/schema.go @@ -25,7 +25,6 @@ import ( // Schema represents a schema definition that can be of any type. type Schema any -// schemaRegistry maintains registry of schemas. var ( schemasMu sync.RWMutex schemas = make(map[string]any) @@ -93,8 +92,8 @@ func RegisterSchemaLookup(lookup func(string) any) { schemaLookups = append(schemaLookups, lookup) } -// GetSchemas returns a copy of all registered schemas. -func GetSchemas() map[string]any { +// Schemas returns a copy of all registered schemas. +func Schemas() map[string]any { schemasMu.RLock() defer schemasMu.RUnlock() @@ -117,9 +116,9 @@ func ClearSchemas() { schemaLookups = nil } -// GetPendingSchemas returns a copy of pending schemas that need to be +// PendingSchemas returns a copy of pending schemas that need to be // registered with Dotprompt. -func GetPendingSchemas() map[string]Schema { +func PendingSchemas() map[string]Schema { schemasMu.RLock() defer schemasMu.RUnlock() diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index c5eca209b5..81c174fc70 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -253,7 +253,7 @@ func Init(ctx context.Context, opts ...GenkitOption) (*Genkit, error) { // Internal function called during Init to register pending schemas func registerPendingSchemas(reg *registry.Registry) error { // Get pending schemas from core - pendingSchemas := core.GetPendingSchemas() + pendingSchemas := core.PendingSchemas() for name, schema := range pendingSchemas { if err := reg.RegisterSchemaWithDotprompt(name, schema); err != nil { diff --git a/go/genkit/schema.go b/go/genkit/schema.go index bb6521bb14..f6e7f37899 100644 --- a/go/genkit/schema.go +++ b/go/genkit/schema.go @@ -72,9 +72,9 @@ func LookupSchema(name string) (Schema, bool) { return schema, schema != nil } -// GetSchema retrieves a registered schema by name. +// FindSchema retrieves a registered schema by name. // It returns an error if no schema exists with that name. -func GetSchema(name string) (Schema, error) { +func FindSchema(name string) (Schema, error) { schema, exists := LookupSchema(name) if !exists { return nil, fmt.Errorf("genkit: schema '%s' not found", name) @@ -92,7 +92,6 @@ func registerSchemaResolver(dp *dotprompt.Dotprompt) { return nil } - // Convert the schema to a JSON schema reflector := jsonschema.Reflector{} jsonSchema := reflector.Reflect(schema) return jsonSchema diff --git a/go/genkit/schema_test.go b/go/genkit/schema_test.go index 9fbdde7ac0..28241640b1 100644 --- a/go/genkit/schema_test.go +++ b/go/genkit/schema_test.go @@ -49,12 +49,12 @@ func TestDefineAndLookupSchema(t *testing.T) { } } -func TestGetSchemaSuccess(t *testing.T) { +func TestSchemaSuccess(t *testing.T) { schemaName := "GetStruct" testSchema := TestStruct{Name: "Bob", Age: 25} DefineSchema(schemaName, testSchema) - schema, err := GetSchema(schemaName) + schema, err := FindSchema(schemaName) if err != nil { t.Fatalf("Expected schema '%s' to be retrieved without error", schemaName) } @@ -65,8 +65,8 @@ func TestGetSchemaSuccess(t *testing.T) { } } -func TestGetSchemaNotFound(t *testing.T) { - _, err := GetSchema("NonExistentSchema") +func TestSchemaNotFound(t *testing.T) { + _, err := FindSchema("NonExistentSchema") if err == nil { t.Fatal("Expected error when retrieving a non-existent schema") } From f7f435ae42773144d63beeb9c24c52ff09b316f1 Mon Sep 17 00:00:00 2001 From: Iven Han Date: Fri, 25 Apr 2025 11:52:56 -0700 Subject: [PATCH 6/6] fix(go/genkit): address issues raised by @hugoaguirre - Move ClearSchemas to schema_test.go as clearSchemasForTest() - Replace panic with error returns in RegisterSchema and DefineSchema - Improve logging using structured slog instead of fmt.Printf - Remove debug prints and unused functions - Add proper error handling for schema not found cases --- go/core/schema.go | 22 +-- go/core/schema_test.go | 282 +++++++++++++++++++++++++++++++++ go/genkit/genkit.go | 2 - go/genkit/schema.go | 15 +- go/genkit/schema_test.go | 38 ++--- go/internal/registry/schema.go | 16 -- 6 files changed, 313 insertions(+), 62 deletions(-) create mode 100644 go/core/schema_test.go diff --git a/go/core/schema.go b/go/core/schema.go index 86fda44a95..09c9789533 100644 --- a/go/core/schema.go +++ b/go/core/schema.go @@ -14,7 +14,6 @@ // // SPDX-License-Identifier: Apache-2.0 // Package core provides core functionality for the genkit framework. - package core import ( @@ -38,26 +37,26 @@ var ( // It validates that the name is not empty and the schema is not nil, // then registers the schema in the core schemas map. // Returns the schema for convenience in chaining operations. -func RegisterSchema(name string, schema any) Schema { +func RegisterSchema(name string, schema any) (Schema, error) { if name == "" { - panic("core.RegisterSchema: schema name cannot be empty") + return nil, fmt.Errorf("core.RegisterSchema: schema name cannot be empty") } if schema == nil { - panic("core.RegisterSchema: schema definition cannot be nil") + return nil, fmt.Errorf("core.RegisterSchema: schema definition cannot be nil") } schemasMu.Lock() defer schemasMu.Unlock() if _, exists := schemas[name]; exists { - panic(fmt.Sprintf("core.RegisterSchema: schema with name %q already exists", name)) + return nil, fmt.Errorf("core.RegisterSchema: schema with name %q already exists", name) } schemas[name] = schema pendingSchemas[name] = schema - return schema + return schema, nil } // LookupSchema looks up a schema by name. @@ -105,17 +104,6 @@ func Schemas() map[string]any { return result } -// ClearSchemas removes all registered schemas. -// This is primarily for testing purposes. -func ClearSchemas() { - schemasMu.Lock() - defer schemasMu.Unlock() - - schemas = make(map[string]any) - pendingSchemas = make(map[string]Schema) - schemaLookups = nil -} - // PendingSchemas returns a copy of pending schemas that need to be // registered with Dotprompt. func PendingSchemas() map[string]Schema { diff --git a/go/core/schema_test.go b/go/core/schema_test.go new file mode 100644 index 0000000000..932ea775c5 --- /dev/null +++ b/go/core/schema_test.go @@ -0,0 +1,282 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package core + +import ( + "fmt" + "reflect" + "sync" + "testing" +) + +// clearSchemasForTest removes all registered schemas. +// This is exclusively for testing purposes. +func clearSchemasForTest() { + schemasMu.Lock() + defer schemasMu.Unlock() + + schemas = make(map[string]any) + pendingSchemas = make(map[string]Schema) + schemaLookups = nil +} + +// TestRegisterSchema tests schema registration functionality +func TestRegisterSchema(t *testing.T) { + clearSchemasForTest() + t.Cleanup(clearSchemasForTest) + + t.Run("RegisterValidSchema", func(t *testing.T) { + schema := map[string]interface{}{"type": "object"} + result, err := RegisterSchema("test", schema) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if result == nil { + t.Fatal("Expected RegisterSchema to return the schema, got nil") + } + + retrieved := LookupSchema("test") + if retrieved == nil { + t.Fatal("Failed to retrieve registered schema") + } + + if !reflect.DeepEqual(retrieved, schema) { + t.Fatalf("Retrieved schema doesn't match registered schema. Got %v, want %v", retrieved, schema) + } + }) + + t.Run("RegisterDuplicateName", func(t *testing.T) { + clearSchemasForTest() + _, err := RegisterSchema("duplicate", "first") + if err != nil { + t.Fatalf("Unexpected error registering first schema: %v", err) + } + + _, err = RegisterSchema("duplicate", "second") + if err == nil { + t.Fatal("Expected error when registering duplicate schema name, but no error occurred") + } + expectedErrMsg := `core.RegisterSchema: schema with name "duplicate" already exists` + if err.Error() != expectedErrMsg { + t.Fatalf("Expected error message %q, got %q", expectedErrMsg, err.Error()) + } + }) + + t.Run("RegisterEmptyName", func(t *testing.T) { + _, err := RegisterSchema("", "schema") + if err == nil { + t.Fatal("Expected error when registering schema with empty name, but no error occurred") + } + expectedErrMsg := "core.RegisterSchema: schema name cannot be empty" + if err.Error() != expectedErrMsg { + t.Fatalf("Expected error message %q, got %q", expectedErrMsg, err.Error()) + } + }) + + t.Run("RegisterNilSchema", func(t *testing.T) { + _, err := RegisterSchema("nil_schema", nil) + if err == nil { + t.Fatal("Expected error when registering nil schema, but no error occurred") + } + expectedErrMsg := "core.RegisterSchema: schema definition cannot be nil" + if err.Error() != expectedErrMsg { + t.Fatalf("Expected error message %q, got %q", expectedErrMsg, err.Error()) + } + }) +} + +// TestLookupSchema tests schema lookup functionality +func TestLookupSchema(t *testing.T) { + clearSchemasForTest() + t.Cleanup(clearSchemasForTest) + + t.Run("LookupExistingSchema", func(t *testing.T) { + expectedSchema := "test_schema" + _, err := RegisterSchema("existing", expectedSchema) + if err != nil { + t.Fatalf("Failed to register schema: %v", err) + } + + result := LookupSchema("existing") + if result != expectedSchema { + t.Fatalf("Expected schema %v, got %v", expectedSchema, result) + } + }) + + t.Run("LookupNonExistentSchema", func(t *testing.T) { + result := LookupSchema("nonexistent") + if result != nil { + t.Fatalf("Expected nil for non-existent schema, got %v", result) + } + }) + + t.Run("LookupViaCustomFunction", func(t *testing.T) { + expectedSchema := "custom_schema" + RegisterSchemaLookup(func(name string) any { + if name == "custom" { + return expectedSchema + } + return nil + }) + + result := LookupSchema("custom") + if result != expectedSchema { + t.Fatalf("Expected schema %v from custom lookup, got %v", expectedSchema, result) + } + }) + + t.Run("PreferLocalRegistryOverLookup", func(t *testing.T) { + localSchema := "local_schema" + _, err := RegisterSchema("preference_test", localSchema) + if err != nil { + t.Fatalf("Failed to register schema: %v", err) + } + + lookupSchema := "lookup_schema" + RegisterSchemaLookup(func(name string) any { + if name == "preference_test" { + return lookupSchema + } + return nil + }) + + result := LookupSchema("preference_test") + if result != localSchema { + t.Fatalf("Expected local schema %v to be preferred, got %v", localSchema, result) + } + }) +} + +// TestPendingSchemas tests handling of pending schemas +func TestPendingSchemas(t *testing.T) { + clearSchemasForTest() + t.Cleanup(clearSchemasForTest) + + t.Run("GetPendingSchemas", func(t *testing.T) { + _, err := RegisterSchema("pending1", "test1") + if err != nil { + t.Fatalf("Failed to register first schema: %v", err) + } + + _, err = RegisterSchema("pending2", "test2") + if err != nil { + t.Fatalf("Failed to register second schema: %v", err) + } + + pending := PendingSchemas() + if len(pending) != 2 { + t.Fatalf("Expected 2 pending schemas, got %d", len(pending)) + } + + if pending["pending1"] != "test1" || pending["pending2"] != "test2" { + t.Fatal("Pending schemas don't match expected values") + } + }) + + t.Run("ClearPendingSchemas", func(t *testing.T) { + _, err := RegisterSchema("pending3", "test3") + if err != nil { + t.Fatalf("Failed to register schema: %v", err) + } + + ClearPendingSchemas() + + pending := PendingSchemas() + if len(pending) != 0 { + t.Fatalf("Expected 0 pending schemas after clearing, got %d", len(pending)) + } + }) +} + +// TestSchemas tests the Schemas function that returns all registered schemas +func TestSchemas(t *testing.T) { + clearSchemasForTest() + t.Cleanup(clearSchemasForTest) + + _, err := RegisterSchema("schema1", "value1") + if err != nil { + t.Fatalf("Failed to register first schema: %v", err) + } + + _, err = RegisterSchema("schema2", "value2") + if err != nil { + t.Fatalf("Failed to register second schema: %v", err) + } + + schemasMap := Schemas() + if len(schemasMap) != 2 { + t.Fatalf("Expected 2 schemas, got %d", len(schemasMap)) + } + + if schemasMap["schema1"] != "value1" || schemasMap["schema2"] != "value2" { + t.Fatal("Retrieved schemas don't match expected values") + } + + schemasMap["schema3"] = "value3" + + internalSchemas := Schemas() + if len(internalSchemas) != 2 { + t.Fatalf("Expected internal schemas count to remain 2, got %d", len(internalSchemas)) + } + + if _, exists := internalSchemas["schema3"]; exists { + t.Fatal("Modifying returned schemas map should not affect internal state") + } +} + +// TestConcurrentAccess tests thread safety of schema operations +func TestConcurrentAccess(t *testing.T) { + clearSchemasForTest() + t.Cleanup(clearSchemasForTest) + + const numGoroutines = 10 + const schemasPerGoroutine = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(routineID int) { + defer wg.Done() + + for j := 0; j < schemasPerGoroutine; j++ { + name := fmt.Sprintf("schema_r%d_s%d", routineID, j) + _, err := RegisterSchema(name, j) + if err != nil { + t.Errorf("Unexpected error registering schema %s: %v", name, err) + } + } + + for j := 0; j < schemasPerGoroutine; j++ { + name := fmt.Sprintf("schema_r%d_s%d", routineID, j) + value := LookupSchema(name) + if value != j { + t.Errorf("Expected schema value %d for %s, got %v", j, name, value) + } + } + }(i) + } + + wg.Wait() + + schemasMap := Schemas() + expectedCount := numGoroutines * schemasPerGoroutine + if len(schemasMap) != expectedCount { + t.Fatalf("Expected %d total schemas, got %d", expectedCount, len(schemasMap)) + } +} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 81c174fc70..7f858c1627 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -252,7 +252,6 @@ func Init(ctx context.Context, opts ...GenkitOption) (*Genkit, error) { // Internal function called during Init to register pending schemas func registerPendingSchemas(reg *registry.Registry) error { - // Get pending schemas from core pendingSchemas := core.PendingSchemas() for name, schema := range pendingSchemas { @@ -261,7 +260,6 @@ func registerPendingSchemas(reg *registry.Registry) error { } } - // Clear pending schemas after registration core.ClearPendingSchemas() return nil } diff --git a/go/genkit/schema.go b/go/genkit/schema.go index f6e7f37899..8ddf9d687f 100644 --- a/go/genkit/schema.go +++ b/go/genkit/schema.go @@ -18,6 +18,7 @@ package genkit import ( "fmt" + "log/slog" "sync" "github.com/firebase/genkit/go/core" @@ -45,24 +46,22 @@ var ( // } // // personSchema := genkit.DefineSchema("Person", Person{}) -func DefineSchema(name string, schema Schema) Schema { +func DefineSchema(name string, schema Schema) (Schema, error) { if name == "" { - panic("genkit.DefineSchema: schema name cannot be empty") + return nil, fmt.Errorf("genkit.DefineSchema: schema name cannot be empty") } if schema == nil { - panic("genkit.DefineSchema: schema cannot be nil") + return nil, fmt.Errorf("genkit.DefineSchema: schema cannot be nil") } - // Register with core registry core.RegisterSchema(name, schema) - // Also track for Dotprompt integration schemasMu.Lock() defer schemasMu.Unlock() pendingSchemas[name] = schema - return schema + return schema, nil } // LookupSchema retrieves a registered schema by name. @@ -88,7 +87,7 @@ func registerSchemaResolver(dp *dotprompt.Dotprompt) { schemaResolver := func(name string) any { schema, exists := LookupSchema(name) if !exists { - fmt.Printf("Schema '%s' not found in registry\n", name) + slog.Error("schema not found in registry", "name", name) return nil } @@ -97,7 +96,6 @@ func registerSchemaResolver(dp *dotprompt.Dotprompt) { return jsonSchema } - // Register the resolver with Dotprompt dp.RegisterExternalSchemaLookup(schemaResolver) } @@ -109,7 +107,6 @@ func RegisterGlobalSchemaResolver(dp *dotprompt.Dotprompt) { return nil } - // Convert the schema to a JSON schema reflector := jsonschema.Reflector{} jsonSchema := reflector.Reflect(schema) return jsonSchema diff --git a/go/genkit/schema_test.go b/go/genkit/schema_test.go index 28241640b1..15fa4cc514 100644 --- a/go/genkit/schema_test.go +++ b/go/genkit/schema_test.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -13,7 +13,6 @@ // limitations under the License. // // SPDX-License-Identifier: Apache-2.0 - package genkit import ( @@ -30,7 +29,10 @@ func TestDefineAndLookupSchema(t *testing.T) { testSchema := TestStruct{Name: "Alice", Age: 30} // Define the schema - DefineSchema(schemaName, testSchema) + schema, err := DefineSchema(schemaName, testSchema) + if err != nil { + t.Fatalf("Unexpected error defining schema: %v", err) + } // Lookup the schema schema, found := LookupSchema(schemaName) @@ -52,7 +54,11 @@ func TestDefineAndLookupSchema(t *testing.T) { func TestSchemaSuccess(t *testing.T) { schemaName := "GetStruct" testSchema := TestStruct{Name: "Bob", Age: 25} - DefineSchema(schemaName, testSchema) + + _, err := DefineSchema(schemaName, testSchema) + if err != nil { + t.Fatalf("Unexpected error defining schema: %v", err) + } schema, err := FindSchema(schemaName) if err != nil { @@ -72,21 +78,17 @@ func TestSchemaNotFound(t *testing.T) { } } -func TestDefineSchemaPanics(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Fatal("Expected panic for empty schema name") - } - }() - DefineSchema("", TestStruct{}) +func TestDefineSchemaEmptyName(t *testing.T) { + _, err := DefineSchema("", TestStruct{}) + if err == nil { + t.Fatal("Expected error for empty schema name") + } } -func TestDefineSchemaNilPanics(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Fatal("Expected panic for nil schema") - } - }() +func TestDefineSchemaNil(t *testing.T) { var nilSchema Schema - DefineSchema("NilSchema", nilSchema) + _, err := DefineSchema("NilSchema", nilSchema) + if err == nil { + t.Fatal("Expected error for nil schema") + } } diff --git a/go/internal/registry/schema.go b/go/internal/registry/schema.go index c2c9281a0f..a24da9fde7 100644 --- a/go/internal/registry/schema.go +++ b/go/internal/registry/schema.go @@ -84,24 +84,8 @@ func (r *Registry) setupSchemaLookupFunction() { }) } -// DumpRegistrySchemas prints all schemas stored in the registry -func (r *Registry) DumpRegistrySchemas() { - fmt.Println("=== Registry Schemas ===") - - for k, v := range r.values { - if strings.HasPrefix(k, SchemaType+"/") { - schemaName := strings.TrimPrefix(k, SchemaType+"/") - fmt.Printf("Schema: %s, Type: %T\n", schemaName, v) - } - } - - fmt.Println("=======================") -} - // convertStructToJsonSchema converts a Go struct to a JSON schema func convertStructToJsonSchema(structType any) (*jsonschema.Schema, error) { - fmt.Printf("Converting schema of type %T to JSON Schema\n", structType) - t := reflect.TypeOf(structType) if t.Kind() == reflect.Ptr { t = t.Elem()