Skip to content

Commit a34f084

Browse files
committed
[WIP] Refactor Provider Options
1 parent dce86b9 commit a34f084

18 files changed

+635
-146
lines changed

lib/active_agent/generation_provider/_base_provider.rb

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -36,18 +36,18 @@ class GenerationProviderError < StandardError; end
3636
attr_reader :client, :config, :prompt, :response, :access_token, :model_name
3737

3838
def initialize(config)
39-
@config = config
40-
@prompt = nil
41-
@response = nil
39+
@config = config
40+
@prompt = nil
41+
@response = nil
4242
@model_name = config["model"] if config
4343
end
4444

4545
def generate(prompt)
4646
raise NotImplementedError, "Subclasses must implement the 'generate' method"
4747
end
4848

49+
# Optional embedding support - override in providers that support it
4950
def embed(prompt)
50-
# Optional embedding support - override in providers that support it
5151
raise NotImplementedError, "#{self.class.name} does not support embeddings"
5252
end
5353

lib/active_agent/generation_provider/anthropic_provider.rb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ class AnthropicProvider < BaseProvider
1111
include MessageFormatting
1212
include ToolManagement
1313

14+
attr_reader :access_token
15+
1416
def initialize(config)
1517
super
1618
@access_token ||= config["api_key"] || config["access_token"] || Anthropic.configuration.access_token || ENV["ANTHROPIC_ACCESS_TOKEN"]
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# frozen_string_literal: true
2+
3+
module ActiveAgent
4+
module GenerationProvider
5+
module Common
6+
class Options
7+
include ActiveModel::Model
8+
include ActiveModel::Attributes
9+
10+
def self.delegate_attributes(*attributes, to:)
11+
attributes.each do |attribute|
12+
define_method(attribute) do
13+
public_send(to)&.public_send(attribute)
14+
end
15+
16+
define_method("#{attribute}=") do |value|
17+
if public_send(to).nil?
18+
public_send("#{to}=", {})
19+
end
20+
21+
public_send(to).public_send("#{attribute}=", value)
22+
end
23+
end
24+
end
25+
26+
def deep_compact(hash = nil, **kwargs)
27+
(hash || kwargs).each_with_object({}) do |(key, value), result|
28+
compacted_value = case value
29+
when Hash
30+
deep_compacted = deep_compact(value)
31+
deep_compacted unless deep_compacted.empty?
32+
when Array
33+
compacted_array = value.map { |v| v.is_a?(Hash) ? deep_compact(v) : v }.compact
34+
compacted_array unless compacted_array.empty?
35+
else
36+
value
37+
end
38+
39+
result[key] = compacted_value unless compacted_value.nil?
40+
end
41+
end
42+
end
43+
end
44+
end
45+
end
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# frozen_string_literal: true

require_relative "../open_ai/options"

module ActiveAgent
  module GenerationProvider
    module Ollama
      # Option object for the Ollama provider. Ollama speaks the OpenAI
      # wire protocol, so everything is inherited and only the default
      # host and the credential lookup are repointed.
      class Options < ActiveAgent::GenerationProvider::OpenAI::Options
        # Client Options — Ollama's default local endpoint.
        attribute :uri_base, :string, default: "http://localhost:11434"

        private

        # Credential fallback chain for Ollama (usually absent for a
        # local server; only needed for authenticated remote instances).
        #
        # Fix: the parent initializer runs +deep_symbolize_keys+ before
        # calling this hook, so +settings+ carries Symbol keys — the
        # original String-key lookups ("api_key"/"access_token") could
        # never match and explicit config credentials were ignored.
        def resolve_access_token(settings)
          settings[:api_key] ||
            settings[:access_token] ||
            ENV["OLLAMA_API_KEY"] ||
            ENV["OLLAMA_ACCESS_TOKEN"]
        end

        # Not used as part of Ollama — no organization/admin concepts.
        def resolve_organization_id(settings) = nil
        def resolve_admin_token(settings) = nil
      end
    end
  end
end

lib/active_agent/generation_provider/ollama_provider.rb

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,16 @@
33
require_gem!(:openai, __FILE__)
44

55
require_relative "open_ai_provider"
6+
require_relative "ollama/options"
67

78
module ActiveAgent
89
module GenerationProvider
910
class OllamaProvider < OpenAIProvider
10-
def initialize(config)
11-
@config = config
12-
@access_token ||= config["api_key"] || config["access_token"] || ENV["OLLAMA_API_KEY"] || ENV["OLLAMA_ACCESS_TOKEN"]
13-
@model_name = config["model"]
14-
@host = config["host"] || "http://localhost:11434"
15-
@api_version = config["api_version"] || "v1"
16-
@client = OpenAI::Client.new(uri_base: @host, access_token: @access_token, log_errors: Rails.env.development?, api_version: @api_version)
17-
end
1811

1912
protected
2013

14+
def namespace = Ollama
15+
2116
def format_error_message(error)
2217
# Check for various connection-related errors
2318
connection_errors = [
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# frozen_string_literal: true

require_relative "../common/options"

module ActiveAgent
  module GenerationProvider
    module OpenAI
      # Typed option object for the OpenAI provider. Resolves credentials
      # from explicit settings, the OpenAI gem configuration, and the
      # environment, then exposes ready-to-use hashes for client
      # construction (#client_options) and chat requests (#chat_parameters).
      class Options < Common::Options
        # Client Options
        attribute :access_token, :string
        attribute :uri_base, :string
        attribute :request_timeout, :integer
        attribute :organization_id, :string
        attribute :admin_token, :string

        # Prompting Options
        attribute :model, :string, default: "gpt-4o-mini"
        attribute :temperature, :float, default: 0.7
        attribute :stream, :boolean, default: false

        validates :model, presence: true
        validates :temperature, numericality: { greater_than_or_equal_to: 0, less_than_or_equal_to: 2 }, allow_nil: true
        validates :access_token, presence: true, if: :require_access_token?

        # Backwards Compatibility
        alias_attribute :host, :uri_base
        alias_attribute :api_key, :access_token
        alias_attribute :model_name, :model

        # Initialize from a hash (settings) with fallback to environment
        # variables and OpenAI gem configuration. nil/empty values are
        # pruned via deep_compact so attribute defaults still apply.
        def initialize(**settings)
          settings = settings.deep_symbolize_keys if settings.respond_to?(:deep_symbolize_keys)

          super(**deep_compact(settings.except(:default_url_options).merge(
            access_token: settings[:access_token] || resolve_access_token(settings),
            organization_id: settings[:organization_id] || resolve_organization_id(settings),
            admin_token: settings[:admin_token] || resolve_admin_token(settings),
          )))
        end

        # Returns a hash suitable for OpenAI::Client initialization.
        # NOTE(review): request_timeout and admin_token are resolved but
        # not forwarded here — confirm whether that is intentional (WIP).
        def client_options
          deep_compact(
            access_token:,
            uri_base:,
            organization_id:,
            extra_headers: client_options_extra_headers,
            log_errors: true
          )
        end

        # Returns parameters for chat completion requests
        def chat_parameters
          deep_compact(
            model:,
            temperature:
          )
        end

        # Convert to hash for compatibility with existing code
        # (string keys match the legacy config-hash shape).
        def to_h
          deep_compact(
            "host" => host,
            "api_key" => api_key,
            "access_token" => access_token,
            "organization_id" => organization_id,
            "admin_token" => admin_token,
            "model" => model,
            "temperature" => temperature,
            "stream" => stream
          )
        end

        alias_method :to_hash, :to_h

        protected

        # Hook for subclasses that need extra HTTP headers.
        def client_options_extra_headers = nil

        private

        # Fallback chain used when no explicit :access_token was supplied.
        # +settings+ is symbol-keyed (see #initialize); subclasses override.
        def resolve_access_token(settings)
          settings[:api_key] ||
            openai_settings_access_token ||
            ENV["OPENAI_ACCESS_TOKEN"]
        end

        def resolve_organization_id(settings)
          openai_settings_organization_id ||
            ENV["OPENAI_ORGANIZATION_ID"]
        end

        def resolve_admin_token(settings)
          openai_settings_admin_token ||
            ENV["OPENAI_ADMIN_TOKEN"]
        end

        # The ::OpenAI configuration reads are wrapped in rescue because
        # the gem may be absent or not yet configured.
        def openai_settings_access_token
          return nil unless defined?(::OpenAI)
          ::OpenAI.configuration.access_token
        rescue
          nil
        end

        def openai_settings_organization_id
          return nil unless defined?(::OpenAI)
          ::OpenAI.configuration.organization_id
        rescue
          nil
        end

        def openai_settings_admin_token
          return nil unless defined?(::OpenAI)
          ::OpenAI.configuration.admin_token
        rescue
          nil
        end

        # Only require access token if no other authentication method is available
        def require_access_token?
          resolved_access_token.blank?
        end

        # Fix: +resolved_access_token+ was referenced by the validation
        # above but never defined, so every #valid? call raised
        # NoMethodError. The token assigned at initialize already folds in
        # all fallbacks; re-check gem config/ENV in case they were set
        # after construction.
        def resolved_access_token
          access_token.presence || resolve_access_token({})
        end
      end
    end
  end
end

lib/active_agent/generation_provider/open_ai_provider.rb

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
require_relative "_base_provider"
2+
require_relative "open_ai/options"
23

34
require_gem!(:openai, __FILE__)
45

@@ -9,21 +10,20 @@ class OpenAIProvider < BaseProvider
910
include MessageFormatting
1011
include ToolManagement
1112

12-
def initialize(config)
13+
attr_reader :options
14+
15+
def initialize(options)
1316
super
14-
@host = config["host"] || nil
15-
@access_token ||= config["api_key"] || config["access_token"] || OpenAI.configuration.access_token || ENV["OPENAI_ACCESS_TOKEN"]
16-
@organization_id = config["organization_id"] || OpenAI.configuration.organization_id || ENV["OPENAI_ORGANIZATION_ID"]
17-
@admin_token = config["admin_token"] || OpenAI.configuration.admin_token || ENV["OPENAI_ADMIN_TOKEN"]
18-
@client = OpenAI::Client.new(
19-
access_token: @access_token,
20-
uri_base: @host,
21-
organization_id: @organization_id,
22-
admin_token: @admin_token,
23-
log_errors: Rails.env.development?
24-
)
17+
@options = namespace::Options.new(**options.except("service"))
18+
end
2519

26-
@model_name = config["model"] || "gpt-4o-mini"
20+
# @return [OpenAI::Client]
21+
def client(prompt_options = nil)
22+
if prompt_options
23+
::OpenAI::Client.new(namespace::Options.new(prompt_options).client_options)
24+
else
25+
@client ||= ::OpenAI::Client.new(@options.client_options)
26+
end
2727
end
2828

2929
def generate(prompt)
@@ -48,6 +48,8 @@ def embed(prompt)
4848

4949
protected
5050

51+
def namespace = OpenAI
52+
5153
# Override from StreamProcessing module
5254
def process_stream_chunk(chunk, message, agent_stream)
5355
new_content = chunk.dig("choices", 0, "delta", "content")
@@ -167,17 +169,19 @@ def handle_message(message_json)
167169

168170
# handle_actions is now provided by ToolManagement module
169171

172+
# @todo prompt_parameters client options overriding
170173
def chat_prompt(parameters: prompt_parameters)
171-
if prompt.options[:stream] || config["stream"]
174+
if prompt.options[:stream] || client.stream
172175
parameters[:stream] = provider_stream
173176
@streaming_request_params = parameters
174177
end
175-
chat_response(@client.chat(parameters: parameters), parameters)
178+
chat_response(client.chat(parameters: parameters), parameters)
176179
end
177180

181+
# @todo prompt_parameters client options overriding
178182
def responses_prompt(parameters: responses_parameters)
179183
# parameters[:stream] = provider_stream if prompt.options[:stream] || config["stream"]
180-
responses_response(@client.responses.create(parameters: parameters), parameters)
184+
responses_response(client.responses.create(parameters: parameters), parameters)
181185
end
182186

183187
def responses_parameters(model: @prompt.options[:model] || @model_name, messages: @prompt.messages, temperature: @prompt.options[:temperature] || @config["temperature"] || 0.7, tools: @prompt.actions, structured_output: @prompt.output_schema)
@@ -258,9 +262,10 @@ def embeddings_response(response, request_params = nil)
258262
)
259263
end
260264

265+
# @todo prompt_parameters client options overriding
261266
def embeddings_prompt(parameters:)
262267
params = embeddings_parameters
263-
embeddings_response(@client.embeddings(parameters: params), params)
268+
embeddings_response(client.embeddings(parameters: params), params)
264269
end
265270
end
266271
end

0 commit comments

Comments (0)