diff --git a/go/plugins/googlegenai/gemini.go b/go/plugins/googlegenai/gemini.go index 5b609afd5a..a0a8d232f8 100644 --- a/go/plugins/googlegenai/gemini.go +++ b/go/plugins/googlegenai/gemini.go @@ -140,11 +140,17 @@ func defineModel(g *genkit.Genkit, client *genai.Client, name string, info ai.Mo provider = vertexAIProvider } + var config any + config = &genai.GenerateContentConfig{} + if mi, found := supportedImagenModels[name]; found { + config = &genai.GenerateImagesConfig{} + info = mi + } meta := &ai.ModelInfo{ Label: info.Label, Supports: info.Supports, Versions: info.Versions, - ConfigSchema: configToMap(genai.GenerateContentConfig{}), + ConfigSchema: configToMap(config), } fn := func( @@ -152,7 +158,12 @@ func defineModel(g *genkit.Genkit, client *genai.Client, name string, info ai.Mo input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error, ) (*ai.ModelResponse, error) { - return generate(ctx, client, name, input, cb) + switch config.(type) { + case *genai.GenerateImagesConfig: + return generateImage(ctx, client, name, input, cb) + default: + return generate(ctx, client, name, input, cb) + } } // the gemini api doesn't support downloading media from http(s) if info.Supports.Media { diff --git a/go/plugins/googlegenai/imagen.go b/go/plugins/googlegenai/imagen.go new file mode 100644 index 0000000000..aabe7ec490 --- /dev/null +++ b/go/plugins/googlegenai/imagen.go @@ -0,0 +1,117 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package googlegenai + +import ( + "context" + "encoding/base64" + "fmt" + + "github.com/firebase/genkit/go/ai" + "google.golang.org/genai" +) + +// Media describes model capabilities for Gemini models with media and text +// input and image only output +var Media = ai.ModelSupports{ + Media: true, + Multiturn: false, + Tools: false, + ToolChoice: false, + SystemRole: false, +} + +// imagenConfigFromRequest translates an [*ai.ModelRequest] configuration to [*genai.GenerateImagesConfig] +func imagenConfigFromRequest(input *ai.ModelRequest) (*genai.GenerateImagesConfig, error) { + var result genai.GenerateImagesConfig + + switch config := input.Config.(type) { + case genai.GenerateImagesConfig: + result = config + case *genai.GenerateImagesConfig: + result = *config + case map[string]any: + if err := mapToStruct(config, &result); err != nil { + return nil, err + } + case nil: + // empty but valid config + default: + return nil, fmt.Errorf("unexpected config type: %T", input.Config) + } + + return &result, nil +} + +// translateImagenCandidates translates the image generation response to [*ai.ModelResponse] +func translateImagenCandidates(images []*genai.GeneratedImage) *ai.ModelResponse { + m := &ai.ModelResponse{} + m.FinishReason = ai.FinishReasonStop + + msg := &ai.Message{} + msg.Role = ai.RoleModel + + for _, img := range images { + msg.Content = append(msg.Content, ai.NewMediaPart(img.Image.MIMEType, "data:"+img.Image.MIMEType+";base64,"+base64.StdEncoding.EncodeToString(img.Image.ImageBytes))) + } + + m.Message = msg + return m +} + +// translateImagenResponse translates [*genai.GenerateImagesResponse] to an [*ai.ModelResponse] +func translateImagenResponse(resp *genai.GenerateImagesResponse) *ai.ModelResponse { + return translateImagenCandidates(resp.GeneratedImages) +} + +// generateImage requests a generate call to the specified imagen model with the +// provided configuration +func generateImage( + ctx context.Context, + client *genai.Client, + model string, + input *ai.ModelRequest, + cb func(context.Context, *ai.ModelResponseChunk) error, +) (*ai.ModelResponse, error) { + gic, err := imagenConfigFromRequest(input) + if err != nil { + return nil, err + } + + var userPrompt string + for _, m := range input.Messages { + if m.Role == ai.RoleUser { + userPrompt += m.Text() + } + } + if userPrompt == "" { + return nil, fmt.Errorf("error generating images: empty prompt detected") + } + + if cb != nil { + return nil, fmt.Errorf("streaming mode not supported for image generation") + } + + resp, err := client.Models.GenerateImages(ctx, model, userPrompt, gic) + if err != nil { + return nil, err + } + + r := translateImagenResponse(resp) + r.Request = input + return r, nil +} diff --git a/go/plugins/googlegenai/models.go b/go/plugins/googlegenai/models.go index 4876f50c0c..f701481618 100644 --- a/go/plugins/googlegenai/models.go +++ b/go/plugins/googlegenai/models.go @@ -32,6 +32,10 @@ const ( gemini25ProExp0325 = "gemini-2.5-pro-exp-03-25" gemini25ProPreview0325 = "gemini-2.5-pro-preview-03-25" gemini25ProPreview0506 = "gemini-2.5-pro-preview-05-06" + + imagen3Generate001 = "imagen-3.0-generate-001" + imagen3Generate002 = "imagen-3.0-generate-002" + imagen3FastGenerate001 = "imagen-3.0-fast-generate-001" ) var ( @@ -50,6 +54,10 @@ var ( gemini25ProExp0325, gemini25ProPreview0325, gemini25ProPreview0506, + + imagen3Generate001, + imagen3Generate002, + imagen3FastGenerate001, } googleAIModels = []string{ @@ -66,9 +74,11 @@ var ( gemini25ProExp0325, gemini25ProPreview0325, gemini25ProPreview0506, + + imagen3Generate002, } - // models with native image support generation + // Gemini models with native image support generation imageGenModels = []string{ gemini20FlashPrevImageGen, } @@ -175,6 +185,27 @@ var ( }, } + supportedImagenModels = map[string]ai.ModelInfo{ + imagen3Generate001: { + Label: "Imagen 3 Generate 001", + Versions: []string{}, + Supports: &Media, + Stage: ai.ModelStageStable, + }, + imagen3Generate002: { + Label: "Imagen 3 Generate 002", + Versions: []string{}, + Supports: &Media, + Stage: ai.ModelStageStable, + }, + imagen3FastGenerate001: { + Label: "Imagen 3 Fast Generate 001", + Versions: []string{}, + Supports: &Media, + Stage: ai.ModelStageStable, + }, + } + googleAIEmbedders = []string{ "text-embedding-004", "embedding-001", @@ -194,7 +225,7 @@ var ( // listModels returns a map of supported models and their capabilities // based on the detected backend func listModels(provider string) (map[string]ai.ModelInfo, error) { - names := []string{} + var names []string var prefix string switch provider { @@ -210,7 +241,13 @@ func listModels(provider string) (map[string]ai.ModelInfo, error) { models := make(map[string]ai.ModelInfo, 0) for _, n := range names { - m, ok := supportedGeminiModels[n] + var m ai.ModelInfo + var ok bool + if strings.HasPrefix(n, "image") { + m, ok = supportedImagenModels[n] + } else { + m, ok = supportedGeminiModels[n] + } if !ok { return nil, fmt.Errorf("model %s not found for provider %s", n, provider) } @@ -227,7 +264,7 @@ func listModels(provider string) (map[string]ai.ModelInfo, error) { // listEmbedders returns a list of supported embedders based on the // detected backend func listEmbedders(backend genai.Backend) ([]string, error) { - embedders := []string{} + var embedders []string switch backend { case genai.BackendGeminiAPI: @@ -242,9 +279,10 @@ func listEmbedders(backend genai.Backend) ([]string, error) { } // genaiModels collects all the available models in go-genai SDK -// TODO: add imagen and veo models +// TODO: add veo models type genaiModels struct { gemini []string + imagen []string embedders []string } @@ -253,6 +291,7 @@ type genaiModels struct { func listGenaiModels(ctx context.Context, client *genai.Client) (genaiModels, error) { models := genaiModels{} allowedModels := []string{"gemini", "gemma"} + allowedImagenModels := []string{"imagen"} for item, err := range client.Models.All(ctx) { var name string @@ -283,7 +322,15 @@ func listGenaiModels(ctx context.Context, client *genai.Client) (genaiModels, er continue } - // TODO: add imagen and veo models + found = slices.ContainsFunc(allowedImagenModels, func(s string) bool { + return strings.Contains(name, s) + }) + // filter out: Aqa, Text-bison, Chat, learnlm + if found { + models.imagen = append(models.imagen, name) + continue + } } + return models, nil } diff --git a/go/samples/imagen/main.go b/go/samples/imagen/main.go new file mode 100644 index 0000000000..93b0de290e --- /dev/null +++ b/go/samples/imagen/main.go @@ -0,0 +1,61 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "log" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "google.golang.org/genai" +) + +func main() { + ctx := context.Background() + g, err := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.VertexAI{})) + if err != nil { + log.Fatal(err) + } + + genkit.DefineFlow(g, "image-generation", func(ctx context.Context, input string) ([]string, error) { + r, err := genkit.Generate(ctx, g, + ai.WithModelName("vertexai/imagen-3.0-generate-001"), + ai.WithPrompt("Generate an image of %s", input), + ai.WithConfig(&genai.GenerateImagesConfig{ + NumberOfImages: 2, + NegativePrompt: "night", + AspectRatio: "9:16", + SafetyFilterLevel: genai.SafetyFilterLevelBlockLowAndAbove, + PersonGeneration: genai.PersonGenerationAllowAll, + Language: genai.ImagePromptLanguageEn, + AddWatermark: true, + OutputMIMEType: "image/jpeg", + }), + ) + if err != nil { + log.Fatal(err) + } + + var images []string + for _, m := range r.Message.Content { + images = append(images, m.Text) + } + return images, nil + }) + + <-ctx.Done() +}