Skip to content

Commit

Permalink
Adapters, reasoning.
Browse files Browse the repository at this point in the history
  • Loading branch information
ljuti committed Aug 2, 2023
1 parent 1c7d306 commit 4f2f7f0
Show file tree
Hide file tree
Showing 44 changed files with 1,296 additions and 49 deletions.
5 changes: 4 additions & 1 deletion lib/roseflow.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
require "roseflow/action"
require "roseflow/ai/model"
require "roseflow/ai/provider"
require "roseflow/embeddings/base"
require "roseflow/chat/dialogue"
require "roseflow/finite_machine"
require "roseflow/interaction"
require "roseflow/interaction/with_cli"
require "roseflow/interaction/with_documentation"
require "roseflow/interaction_context"
require "roseflow/interactions/ai/initialize_llm"
require "roseflow/vector_stores/base"
require "roseflow/registry"

module Roseflow
class Error < StandardError; end
Expand Down
21 changes: 21 additions & 0 deletions lib/roseflow/actions/ai/chat.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# frozen_string_literal: true

module Roseflow
module Actions
module AI
class Chat
extend Roseflow::Action

expects :messages
expects :model
expects :options, default: {}

promises :response

executed do |context|
context[:response] = context.model.chat(context.options.merge(messages: context.messages))
end
end
end
end
end
4 changes: 2 additions & 2 deletions lib/roseflow/actions/ai/resolve_model.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ class ResolveModel
promises :llm

executed do |context|
model = context.provider.models.find(context[:model])
model = Registry.get(:models).find(context[:model])

unless model
context.fail_and_return!("Model #{context[:model]} not found")
end

context[:llm] = Roseflow::AI::Model.new(name: model.name, provider: context.provider)
context[:llm] = model
end
end
end
Expand Down
6 changes: 1 addition & 5 deletions lib/roseflow/actions/ai/resolve_provider.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

require "roseflow/action"
require "roseflow/ai/provider"
require "roseflow/openai/config"

module Roseflow
module Actions
Expand All @@ -20,10 +19,7 @@ class ResolveProvider
private_class_method

def self.resolve_provider(provider)
case provider
when :openai
Roseflow::AI::Provider.new(name: :openai, credentials: Roseflow::OpenAI::Config.new)
end
Registry.get(:providers).find(provider)
end
end
end
Expand Down
55 changes: 49 additions & 6 deletions lib/roseflow/ai/model.rb
Original file line number Diff line number Diff line change
@@ -1,18 +1,61 @@
# frozen_string_literal: true

require "roseflow/model_repository"
require "roseflow/provider_repository"
require "roseflow/ai/models/instance_factory"
require "roseflow/ai/models/openai_adapter"
require "roseflow/ai/models/openrouter_adapter"

module Roseflow
module AI
class ModelInstanceNotFoundError < StandardError; end

class Model
attr_reader :name, :provider
attr_reader :name

def initialize(name:, provider:)
delegate :chat, to: :instance

def initialize(name:, provider: nil)
raise ArgumentError, "Name must be provided" if name.nil?
provider = resolve_provider(name, provider)
instance = create_adapted_instance(name, provider)
@name = name
@provider = provider
@model_ = provider.models.find(name)
@instance = instance
@_provider = provider
end

def provider
_provider.name
end

def call(operation, options, &block)
instance.call(operation, options, &block)
end

def operations
instance.operations
end

def self.load(name)
Registry.get(:models).find(name)
end

private

attr_reader :instance, :_provider

def resolve_provider(name, provider)
if provider.nil?
provider_name = Registry.get(:models).find(name).provider
Registry.get(:providers).find(provider_name)
else
return provider if provider.instance_of?(Provider)
Registry.get(:providers).find(provider)
end
end

def call(operation, input)
@model_.call(operation, input)
def create_adapted_instance(name, provider)
Models::InstanceFactory.create(name, provider)
end
end # Model
end # AI
Expand Down
23 changes: 23 additions & 0 deletions lib/roseflow/ai/model_interface.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# frozen_string_literal: true

module Roseflow
module AI
module ModelInterface
def call(**options)
raise NotImplementedError, "Model must implement #call"
end

def chat(**options)
raise NotImplementedError, "Model must implement #chat"
end

def completion(**options)
raise NotImplementedError, "Model must implement #completion"
end

def operations
raise NotImplementedError, "Model must implement #operations"
end
end
end
end
13 changes: 13 additions & 0 deletions lib/roseflow/ai/models/base_adapter.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# frozen_string_literal: true

module Roseflow
module AI
module Models
class BaseAdapter
def initialize(model)
@model = model
end
end
end
end
end
22 changes: 22 additions & 0 deletions lib/roseflow/ai/models/instance_factory.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# frozen_string_literal: true

module Roseflow
module AI
module Models
class InstanceFactory
def self.create(name, provider)
provider_model = provider.models.find(name)
raise ModelInstanceNotFoundError, "Instance for #{name} not found" if provider_model.nil?

begin
provider_config = PROVIDER_GEMS.find { |gem, settings| settings[:name] == provider.name.to_sym }.last
adapter_class = Models.const_get(provider_config.fetch(:adapter_class))
adapter_class.new(provider_model)
rescue => exception
raise NotImplementedError, "Model adapter for provider '#{provider.name}' not implemented"
end
end
end
end
end
end
22 changes: 22 additions & 0 deletions lib/roseflow/ai/models/openai_adapter.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# frozen_string_literal: true

require "roseflow/ai/model_interface"
require "roseflow/ai/models/base_adapter"

module Roseflow
module AI
module Models
class OpenAIAdapter < BaseAdapter
include ModelInterface

def call(operation, options, &block)
@model.call(operation, options, &block)
end

def chat(options, &block)
@model.chat(options.delete(:messages), options, &block)
end
end
end
end
end
22 changes: 22 additions & 0 deletions lib/roseflow/ai/models/openrouter_adapter.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# frozen_string_literal: true

require "roseflow/ai/model_interface"
require "roseflow/ai/models/base_adapter"

module Roseflow
module AI
module Models
class OpenRouterAdapter < OpenAIAdapter
include ModelInterface

def call(operation, options, &block)
@model.call(operation, options, &block)
end

def chat(options, &block)
@model.chat(options.delete(:messages), options, &block)
end
end
end
end
end
24 changes: 13 additions & 11 deletions lib/roseflow/ai/provider.rb
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
# frozen_string_literal: true

require "roseflow/openai/provider"
require "roseflow/provider_repository"
require "roseflow/ai/providers/instance_factory"

module Roseflow
module AI
class Provider
def initialize(name:, credentials:)
attr_reader :name

def initialize(name:, config:)
instance = create_adapted_instance(config)
@name = name
@credentials = credentials
initialize_provider
@instance = instance
end

def models
@models ||= provider.models
@models ||= instance.models
end

private

attr_reader :name, :credentials, :provider
attr_reader :instance

def initialize_provider
case name
when :openai
@provider = Roseflow::OpenAI::Provider.new(credentials: credentials)
end
def create_adapted_instance(config)
Providers::InstanceFactory.create(config)
end
end # Provider

ProviderNotFoundError = ProviderRepository::ProviderNotFoundError
end # AI
end # Roseflow
23 changes: 23 additions & 0 deletions lib/roseflow/ai/provider_interface.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# frozen_string_literal: true

module Roseflow
module AI
module ProviderInterface
def call
raise NotImplementedError, "Provider must implement #call"
end

def chat
raise NotImplementedError, "Provider must implement #chat"
end

def completion
raise NotImplementedError, "Provider must implement #completion"
end

def models
raise NotImplementedError, "Provider must implement #models"
end
end
end
end
17 changes: 17 additions & 0 deletions lib/roseflow/ai/providers/base_adapter.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# frozen_string_literal: true

require "roseflow/ai/provider_interface"

module Roseflow
module AI
module Providers
class BaseAdapter
include ProviderInterface

def initialize(provider)
@provider = provider
end
end
end
end
end
22 changes: 22 additions & 0 deletions lib/roseflow/ai/providers/instance_factory.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# frozen_string_literal: true

require_relative "./openai_adapter"
require_relative "./openrouter_adapter"

module Roseflow
module AI
module Providers
class InstanceFactory
def self.create(config)
begin
adapter = Providers.const_get(config.fetch(:adapter_class))
klass = Object.const_get("#{config.fetch(:namespace)}::Provider")
adapter.new(klass.new)
rescue => exception
raise NotImplementedError, "Adapter for provider #{config.fetch(:name)} not implemented"
end
end
end
end
end
end
15 changes: 15 additions & 0 deletions lib/roseflow/ai/providers/openai_adapter.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# frozen_string_literal: true

require "roseflow/ai/providers/base_adapter"

module Roseflow
module AI
module Providers
class OpenAIAdapter < BaseAdapter
def models
@provider.models
end
end
end
end
end
Loading

0 comments on commit 4f2f7f0

Please sign in to comment.