
feat(go/plugins/googlegenai): add image-generation models #2903

Open · wants to merge 11 commits into main
23 changes: 17 additions & 6 deletions go/plugins/googlegenai/gemini.go
@@ -206,8 +206,8 @@ type GeminiConfig struct {
ResponseModalities []Modality `json:"responseModalities,omitempty"`
}

// configFromRequest converts any supported config type to [GeminiConfig].
func configFromRequest(input *ai.ModelRequest) (*GeminiConfig, error) {
// geminiConfigFromRequest converts any supported config type to [GeminiConfig].
func geminiConfigFromRequest(input *ai.ModelRequest) (*GeminiConfig, error) {
var result GeminiConfig

switch config := input.Config.(type) {
@@ -236,19 +236,30 @@ func defineModel(g *genkit.Genkit, client *genai.Client, name string, info ai.Mo
provider = vertexAIProvider
}

var config any
config = &GeminiConfig{}
if mi, found := supportedImagenModels[name]; found {
config = &ImagenConfig{}
info = mi
}
meta := &ai.ModelInfo{
Label: info.Label,
Supports: info.Supports,
Versions: info.Versions,
ConfigSchema: configToMap(&GeminiConfig{}),
ConfigSchema: configToMap(config),
}

fn := func(
ctx context.Context,
input *ai.ModelRequest,
cb func(context.Context, *ai.ModelResponseChunk) error,
) (*ai.ModelResponse, error) {
return generate(ctx, client, name, input, cb)
switch config.(type) {
case *ImagenConfig:
return generateImage(ctx, client, name, input, cb)
default:
return generate(ctx, client, name, input, cb)
}
}
// the gemini api doesn't support downloading media from http(s)
if info.Supports.Media {
@@ -325,7 +336,7 @@ func generate(
cb func(context.Context, *ai.ModelResponseChunk) error,
) (*ai.ModelResponse, error) {
// Extract configuration to get the model version
config, err := configFromRequest(input)
config, err := geminiConfigFromRequest(input)
if err != nil {
return nil, err
}
@@ -455,7 +466,7 @@ func toGeminiRequest(input *ai.ModelRequest, cache *genai.CachedContent) (*genai
CandidateCount: 1,
}

c, err := configFromRequest(input)
c, err := geminiConfigFromRequest(input)
if err != nil {
return nil, err
}
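For context while reviewing, here is a minimal caller-side sketch of the path this routing change enables once an Imagen model is registered. It is not part of the diff; the registered model name and the generate options (ai.WithModelName, ai.WithConfig, ai.WithPrompt) are assumptions based on the existing Genkit Go API.

// Hypothetical caller-side sketch, not part of this PR. Assumes a configured
// *genkit.Genkit instance g with the Vertex AI plugin loaded and the model
// registered as "vertexai/imagen-3.0-generate-002".
resp, err := genkit.Generate(ctx, g,
	ai.WithModelName("vertexai/imagen-3.0-generate-002"),
	ai.WithConfig(&googlegenai.ImagenConfig{
		NumberOfImages: 2,
		AspectRatio:    "16:9",
	}),
	ai.WithPrompt("a watercolor lighthouse at dawn"),
)
if err != nil {
	log.Fatal(err)
}
// Each generated image is returned as a media part holding a base64 data URI
// (see translateImagenCandidates in imagen.go).
for _, part := range resp.Message.Content {
	fmt.Println(part.ContentType, len(part.Text))
}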
218 changes: 218 additions & 0 deletions go/plugins/googlegenai/imagen.go
@@ -0,0 +1,218 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

package googlegenai

import (
"context"
"encoding/base64"
"fmt"

"github.com/firebase/genkit/go/ai"
"google.golang.org/genai"
)

var (
// BasicMedia describes model capabilities for image-only output Gemini models.
BasicMedia = ai.ModelSupports{
Media: false,
Multiturn: false,
Tools: false,
ToolChoice: false,
SystemRole: false,
}

// Media describes model capabilities for Gemini models with media and text
// input and image-only output.
Media = ai.ModelSupports{
Media: true,
Multiturn: false,
Tools: false,
ToolChoice: false,
SystemRole: false,
}
)

type PersonGeneration string

const (
// Disallow the inclusion of people or faces in images
DontAllowPerson PersonGeneration = "dont_allow"
// Allow generation of adults only
AllowAdultPerson PersonGeneration = "allow_adult"
// Allow generation of people of all ages
AllowAllPerson PersonGeneration = "allow_all"
)

// Enum that specifies the language of the text in the prompt.
type ImagePromptLanguage string

const (
ImagePromptLanguageAuto ImagePromptLanguage = "auto"
ImagePromptLanguageEn ImagePromptLanguage = "en"
ImagePromptLanguageJa ImagePromptLanguage = "ja"
ImagePromptLanguageKo ImagePromptLanguage = "ko"
ImagePromptLanguageHi ImagePromptLanguage = "hi"
)

// Imagen generation configuration
// VertexAI API default values: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api
// GeminiAPI: https://ai.google.dev/gemini-api/docs/imagen#imagen-model
type ImagenConfig struct {
// Number of images to generate. Defaults to 4
NumberOfImages int32 `json:"numberOfImages,omitempty"`
// Random seed for image generation
Seed *int32 `json:"seed,omitempty"`
// A description of what to discourage in the generated images
NegativePrompt string `json:"negativePrompt,omitempty"`
// Aspect ratio for the image. Defaults to 1:1
AspectRatio string `json:"aspectRatio,omitempty"`
// Allow generation of people by the model. Defaults to [AllowAdultPerson]
PersonGeneration PersonGeneration `json:"personGeneration,omitempty"`
// Language of the text in the prompt
Language string `json:"language,omitempty"`
// Filter level for safety filtering
SafetySetting HarmBlockThreshold `json:"safetySetting,omitempty"`
// Adds an invisible watermark to the generated images.
// Only available in the Vertex AI API
AddWatermark bool `json:"addWatermark,omitempty"`
// Cloud Storage URI used to store the generated images.
OutputGCSURI string `json:"outputGcsUri,omitempty"`
// MIME type of the generated image.
OutputMIMEType string `json:"outputMimeType,omitempty"`
}

// imagenConfigFromRequest translates an [*ai.ModelRequest] configuration to [*ImagenConfig]
func imagenConfigFromRequest(input *ai.ModelRequest) (*ImagenConfig, error) {
var result ImagenConfig

switch config := input.Config.(type) {
case ImagenConfig:
result = config
case *ImagenConfig:
result = *config
case map[string]any:
if err := mapToStruct(config, &result); err != nil {
return nil, err
}
case nil:
// empty but valid config
default:
return nil, fmt.Errorf("unexpected config type: %T", input.Config)
}

return &result, nil
}

// toImageRequest translates an [*ai.ModelRequest] to a [*genai.GenerateImagesConfig]
func toImageRequest(input *ai.ModelRequest) (*genai.GenerateImagesConfig, error) {
config, err := imagenConfigFromRequest(input)
if err != nil {
return nil, err
}

gic := genai.GenerateImagesConfig{
AddWatermark: config.AddWatermark,
}
if config.NumberOfImages > 0 {
gic.NumberOfImages = config.NumberOfImages
}
if config.Seed != nil {
gic.Seed = config.Seed
}
if config.NegativePrompt != "" {
gic.NegativePrompt = config.NegativePrompt
}
if config.AspectRatio != "" {
gic.AspectRatio = config.AspectRatio
}
if config.PersonGeneration != "" {
gic.PersonGeneration = genai.PersonGeneration(config.PersonGeneration)
}
if config.Language != "" {
gic.Language = genai.ImagePromptLanguage(config.Language)
}
if config.SafetySetting != "" {
gic.SafetyFilterLevel = genai.SafetyFilterLevel(config.SafetySetting)
}
if config.OutputGCSURI != "" {
gic.OutputGCSURI = config.OutputGCSURI
}
if config.OutputMIMEType != "" {
gic.OutputMIMEType = config.OutputMIMEType
}

return &gic, nil
}

// translateImagenCandidates translates the image generation response to [*ai.ModelResponse]
func translateImagenCandidates(images []*genai.GeneratedImage) *ai.ModelResponse {
m := &ai.ModelResponse{}
m.FinishReason = ai.FinishReasonStop

msg := &ai.Message{}
msg.Role = ai.RoleModel

for _, img := range images {
msg.Content = append(msg.Content, ai.NewMediaPart(img.Image.MIMEType, "data:"+img.Image.MIMEType+";base64,"+base64.StdEncoding.EncodeToString(img.Image.ImageBytes)))
}

m.Message = msg
return m
}

// translateImagenResponse translates [*genai.GenerateImagesResponse] to an [*ai.ModelResponse]
func translateImagenResponse(resp *genai.GenerateImagesResponse) *ai.ModelResponse {
return translateImagenCandidates(resp.GeneratedImages)
}

// generateImage sends an image generation request to the specified Imagen model
// with the provided configuration
func generateImage(
ctx context.Context,
client *genai.Client,
model string,
input *ai.ModelRequest,
cb func(context.Context, *ai.ModelResponseChunk) error,
) (*ai.ModelResponse, error) {
gic, err := toImageRequest(input)
if err != nil {
return nil, err
}

var userPrompt string
for _, m := range input.Messages {
if m.Role == ai.RoleUser {
userPrompt += m.Text()
}
}
if userPrompt == "" {
return nil, fmt.Errorf("error generating images: empty prompt detected")
}

if cb != nil {
return nil, fmt.Errorf("streaming mode not supported for image generation")
}

resp, err := client.Models.GenerateImages(ctx, model, userPrompt, gic)
if err != nil {
return nil, err
}

r := translateImagenResponse(resp)
r.Request = input
return r, nil
}
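To illustrate the config plumbing above, a small sketch of how a request config reaches genai.GenerateImagesConfig. The helpers are unexported, so this would live inside the googlegenai package (for example in a test); the request literal is made up for illustration.

// Illustration only: the config may arrive as a typed *ImagenConfig or as a
// plain map (for example decoded from JSON); both paths go through
// imagenConfigFromRequest before the fields are copied onto the genai config.
req := &ai.ModelRequest{
	Config: map[string]any{
		"numberOfImages":   2,
		"aspectRatio":      "16:9",
		"personGeneration": "allow_adult",
	},
}
gic, err := toImageRequest(req) // only non-zero fields are set on the result
if err != nil {
	log.Fatal(err)
}
fmt.Println(gic.NumberOfImages, gic.AspectRatio, gic.PersonGeneration)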
54 changes: 50 additions & 4 deletions go/plugins/googlegenai/models.go
@@ -5,6 +5,7 @@ package googlegenai

import (
"fmt"
"strings"

"github.com/firebase/genkit/go/ai"
"google.golang.org/genai"
@@ -27,6 +28,11 @@ const (
gemini25ProExp0325 = "gemini-2.5-pro-exp-03-25"
gemini25ProPreview0325 = "gemini-2.5-pro-preview-03-25"
gemini25ProPreview0506 = "gemini-2.5-pro-preview-05-06"

imagen2 = "imagegeneration@006"
imagen3Generate001 = "imagen-3.0-generate-001"
imagen3Generate002 = "imagen-3.0-generate-002"
imagen3FastGenerate001 = "imagen-3.0-fast-generate-001"
)

var (
@@ -44,6 +50,11 @@ var (
gemini25ProExp0325,
gemini25ProPreview0325,
gemini25ProPreview0506,

imagen2,
imagen3Generate001,
imagen3Generate002,
imagen3FastGenerate001,
}

googleAIModels = []string{
@@ -59,9 +70,11 @@ var (
gemini25ProExp0325,
gemini25ProPreview0325,
gemini25ProPreview0506,

imagen3Generate002,
}

// models with native image support generation
// Gemini models with native image generation support
imageGenModels = []string{
gemini20FlashExp,
}
@@ -162,6 +175,33 @@ var (
},
}

supportedImagenModels = map[string]ai.ModelInfo{
imagen2: {
Label: "Imagen 2",
Versions: []string{},
Supports: &Media,
Stage: ai.ModelStageStable,
},
imagen3Generate001: {
Label: "Imagen 3 Generate 001",
Versions: []string{},
Supports: &Media,
Stage: ai.ModelStageStable,
},
imagen3Generate002: {
Label: "Imagen 3 Generate 002",
Versions: []string{},
Supports: &Media,
Stage: ai.ModelStageStable,
},
imagen3FastGenerate001: {
Label: "Imagen 3 Fast Generate 001",
Versions: []string{},
Supports: &Media,
Stage: ai.ModelStageStable,
},
}

googleAIEmbedders = []string{
"text-embedding-004",
"embedding-001",
@@ -181,7 +221,7 @@ var (
// listModels returns a map of supported models and their capabilities
// based on the detected backend
func listModels(provider string) (map[string]ai.ModelInfo, error) {
names := []string{}
var names []string
var prefix string

switch provider {
@@ -197,7 +237,13 @@ func listModels(provider string) (map[string]ai.ModelInfo, error) {

models := make(map[string]ai.ModelInfo, 0)
for _, n := range names {
m, ok := supportedGeminiModels[n]
var m ai.ModelInfo
var ok bool
if strings.HasPrefix(n, "image") {
m, ok = supportedImagenModels[n]
} else {
m, ok = supportedGeminiModels[n]
}
if !ok {
return nil, fmt.Errorf("model %s not found for provider %s", n, provider)
}
@@ -213,7 +259,7 @@ func listModels(provider string) (map[string]ai.ModelInfo, error) {
// listEmbedders returns a list of supported embedders based on the
// detected backend
func listEmbedders(backend genai.Backend) ([]string, error) {
embedders := []string{}
var embedders []string

switch backend {
case genai.BackendGeminiAPI:
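Finally, a tiny standalone sketch of the name dispatch added to listModels above; the model names come from the constants in this file, and the routing rule is just the strings.HasPrefix check.

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Mirrors the dispatch in listModels: names starting with "image" are
	// looked up in supportedImagenModels, everything else in supportedGeminiModels.
	for _, n := range []string{"gemini-2.0-flash", "imagen-3.0-generate-002", "imagegeneration@006"} {
		if strings.HasPrefix(n, "image") {
			fmt.Printf("%s -> supportedImagenModels\n", n)
		} else {
			fmt.Printf("%s -> supportedGeminiModels\n", n)
		}
	}
}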