Mix.install([
{:kino, "~> 0.9.1"},
{:poison, "~> 4.0"},
{:httpoison, "~> 1.8"}
])
defmodule NestedUpdater do
@doc ~S"""
updates nested dict
## Examples
iex> NestedUpdater.update(%{:a => %{:b => 1, :c => 2}}, %{:a => %{:b => 2}})
%{:a => %{:b => 2, :c => 2}}
"""
  def update(old_map, new_map) do
    Enum.reduce(new_map, old_map, fn {new_key, new_val}, acc ->
      # Recurse only when both the existing and the incoming values are maps;
      # otherwise the new value replaces (or adds to) the old one.
      if is_map(new_val) and is_map(Map.get(acc, new_key)) do
        Map.update!(acc, new_key, &update(&1, new_val))
      else
        Map.put(acc, new_key, new_val)
      end
    end)
  end
end
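# A quick illustration of the deep merge (values here are made up for the example):
# nested maps are merged key by key, and new top-level keys are simply added.
%{"sampling_parameters" => %{"temperature" => 1, "top_p" => 0.9}}
|> NestedUpdater.update(%{"sampling_parameters" => %{"temperature" => 0.2}, "n" => 2})
# => %{"n" => 2, "sampling_parameters" => %{"temperature" => 0.2, "top_p" => 0.9}}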
defprotocol Utility do
def type(value)
end
defimpl Utility, for: BitString do
def type(_value), do: "string"
end
defimpl Utility, for: Integer do
def type(_value), do: "integer"
end
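# Quick dispatch check for the Utility protocol:
Utility.type("hello")
# => "string"
Utility.type(42)
# => "integer"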
defprotocol LMAPI do
  @type payload :: map()
  @type completion_type :: :completion | :rellm
  @type api_response :: {:ok, String.t()} | {:error, any()}

  @spec get_completion_response(
          client :: t,
          payload :: payload(),
          completion_type :: completion_type()
        ) :: api_response()
  def get_completion_response(client, payload, completion_type)
end
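# A minimal sketch of one more implementation, useful for exercising the UI
# below without a live model server. EchoClient is hypothetical and not part
# of the original design.
defmodule EchoClient do
  defstruct []
end

defimpl LMAPI, for: EchoClient do
  def get_completion_response(_client, payload, _completion_type) do
    {:ok, "echo: #{payload["prompt"]}"}
  end
end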
defmodule LMServerClient do
  @moduledoc "Client for a self-hosted LM server exposing /generate and /rellm_generate."
  defstruct [:model, :url]

  @type t :: %__MODULE__{model: String.t(), url: String.t()}
def default_payload do
%{
"prompt" => "User: give at least three examples of fruits\n Assistant: ",
"min_length" => 3,
"max_new_tokens" => 20,
"n" => 1,
"max_length" => 512,
"stop" => "string",
"stream" => false,
"sampling_parameters" => %{
"temperature" => 1,
"top_k" => 100,
"top_p" => 0.9,
"logit_bias" => %{},
"presence_penalty" => 0,
"frequency_penalty" => 0,
"repetition_penalty" => 1,
"typical_p" => 1
},
"stop_after_match" => true,
"do_sample" => true,
"return_full_text" => false
}
end
def rellm_payload do
NestedUpdater.update(default_payload(), %{"pattern" => "(.* ){5}"})
end
def get_default_payload(:rellm), do: rellm_payload()
def get_default_payload(_), do: default_payload()
end
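# Sanity check: the rellm payload is the default payload plus a regex constraint.
LMServerClient.get_default_payload(:rellm)["pattern"]
# => "(.* ){5}"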
defimpl LMAPI, for: LMServerClient do
  @type api_response :: {:ok, String.t()} | {:error, any()}

  @spec get_completion_response(
          client :: LMServerClient.t(),
          payload :: map(),
          completion_type :: :rellm | :completion
        ) :: api_response()
  def get_completion_response(
        %LMServerClient{model: model, url: url},
        payload,
        completion_type
      ) do
updated_payload = payload |> get_updated_payload(completion_type)
{model, url, updated_payload} |> get_response(completion_type) |> get_response_text()
end
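  # Deep-merge the caller's payload over the defaults; caller values win.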
defp get_updated_payload(payload, completion_type) do
NestedUpdater.update(LMServerClient.get_default_payload(completion_type), payload)
end
  defp get_response({_model, url, payload}, completion_type) do
    request_body = Poison.encode!(payload)
    options = [recv_timeout: 10_000, timeout: 10_000]
    endpoint = get_endpoint(completion_type)

    HTTPoison.post(
      "#{url}/#{endpoint}",
      request_body,
      [{"Content-Type", "application/json"}],
      options
    )
  end
defp get_endpoint(:rellm), do: "rellm_generate"
defp get_endpoint(_), do: "generate"
  # Pull the first generated text out of a successful response body.
  defp get_response_text({:ok, response}) do
    body = Poison.decode!(response.body)
    [result | _] = body["texts"]
    {:ok, result}
  end

  # Pass HTTP-level errors through unchanged.
  defp get_response_text(other), do: other
end
defmodule OpenAIClient do
defstruct [:openai_key, :model]
  def new(model, openai_key, openai_key_path \\ "~/home/.keys") do
    key = openai_key || load_openai_key(openai_key_path)
    %OpenAIClient{model: model, openai_key: key}
  end

  defp load_openai_key(path) do
    {:ok, key} = File.read(Path.expand(path))
    String.trim(key)
  end
def default_completion_payload(model) do
%{
"model" => model,
"prompt" => "A cat walks into a bar",
"temperature" => 0.9,
"max_tokens" => 150,
"top_p" => 1,
"frequency_penalty" => 0.0,
"presence_penalty" => 0.6,
"stop" => [" Human:", " AI:"]
}
end
def default_chat_payload(model) do
%{
"model" => model,
"messages" => [%{"role" => "user", "content" => "A cat walks into a bar"}],
"temperature" => 0.9,
"max_tokens" => 15
}
end
  def get_default_payload(model) do
    if completion_model_name?(model) do
      default_completion_payload(model)
    else
      default_chat_payload(model)
    end
  end

  # Chat models (gpt-3.5*, gpt-4*) go through the chat endpoint; anything else
  # is treated as a legacy completion model.
  def completion_model_name?(model) do
    not String.contains?(model, "gpt-3.5") and not String.contains?(model, "gpt-4")
  end
end
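# Routing check (model names here are only illustrative):
OpenAIClient.completion_model_name?("gpt-3.5-turbo")
# => false, so it is routed to /v1/chat/completions
OpenAIClient.completion_model_name?("text-davinci-003")
# => true, so it is routed to /v1/completions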
defimpl LMAPI, for: OpenAIClient do
  # The OpenAI API has no rellm endpoint, so the completion type is ignored here.
  def get_completion_response(client, payload, _completion_type) do
    url = get_url(client.model)
    default_payload = OpenAIClient.get_default_payload(client.model)
    updated_payload = {client.model, default_payload} |> get_payload(payload)

    {client, updated_payload}
    |> get_api_response(url)
    |> get_response_body()
    |> get_text_result()
  end
  # Chat models take `messages` rather than a bare `prompt`, so any `prompt`
  # key in the payload is rewritten into a single user message.
  defp get_chat_payload(default_chat_payload, payload) do
    {prompt, payload_without_prompt} = Map.pop(payload, "prompt")

    prompt_payload =
      case prompt do
        nil -> %{}
        _ -> %{"messages" => [%{"role" => "user", "content" => prompt}]}
      end

    default_chat_payload
    |> NestedUpdater.update(payload_without_prompt)
    |> NestedUpdater.update(prompt_payload)
  end
  defp get_payload({model, default_payload}, payload) do
    if OpenAIClient.completion_model_name?(model) do
      NestedUpdater.update(default_payload, payload)
    else
      get_chat_payload(default_payload, payload)
    end
  end
defp get_url(model) do
    if OpenAIClient.completion_model_name?(model) do
"https://api.openai.com/v1/completions"
else
"https://api.openai.com/v1/chat/completions"
end
end
defp get_api_response({client, payload}, url) do
headers = [
{"Content-Type", "application/json"},
{"Authorization", "Bearer #{client.openai_key}"}
]
request_body = payload |> Poison.encode!()
options = [recv_timeout: 10000, timeout: 10000]
HTTPoison.post(url, request_body, headers, options)
end
defp get_response_body({flag, %HTTPoison.Response{body: encoded_body}}),
do: {flag, encoded_body |> Poison.decode!()}
defp get_response_body({flag, encoded_body}), do: {flag, encoded_body}
  # OpenAI reports failures as a JSON body with an "error" key; surface those
  # as an error tuple instead of crashing on the missing "choices" key.
  defp get_text_result({:ok, %{"error" => error}}), do: {:error, error}

  defp get_text_result({:ok, results}) do
    text =
      case results do
        %{"choices" => [%{"message" => %{"content" => text}} | _]} -> text
        %{"choices" => [%{"text" => text} | _]} -> text
      end

    {:ok, text}
  end

  defp get_text_result({:error, other}), do: {:error, other}
end
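# Assumes an LM server is listening locally on port 8765.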
client = %LMServerClient{model: "falcon40b-instruct", url: "http://localhost:8765"}
LMAPI.get_completion_response(client, %{"prompt" => "A cat walks into a bar"}, :completion)
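# The same client can also hit the constrained-generation endpoint. The pattern
# below is illustrative; it constrains the output to five space-delimited words.
LMAPI.get_completion_response(
  client,
  %{"prompt" => "A cat walks into a bar", "pattern" => "(.* ){5}"},
  :rellm
)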
client = OpenAIClient.new("gpt-3.5-turbo", nil, "/home/kuba/.keys/openai_key.txt")
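# Return :ok so the cell output does not render the struct holding the API key.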
:ok
LMAPI.get_completion_response(client, %{"prompt" => "A cat walks into a bar"}, :completion)
defmodule PromptHistoryInterface do
  def new(client) do
    {_inputs, chat_form, frame} = make_widgets()
    Kino.render(chat_form)
    Kino.render(frame)
    wire_widgets(client, {chat_form, frame})
  end
def make_widgets() do
inputs = [
prompt: Kino.Input.text("Prompt"),
pattern: Kino.Input.text("rellm pattern")
]
frame = Kino.Frame.new()
chat_form = Kino.Control.form(inputs, submit: "Send", reset_on_submit: [:prompt, :pattern])
{inputs, chat_form, frame}
end
  def wire_widgets(client, {chat_form, frame}) do
    Kino.listen(chat_form, fn %{data: %{prompt: prompt, pattern: pattern}, origin: origin} ->
      if prompt != "" do
        add_markdown_to_frame(frame, "**user**: #{prompt}")
        if pattern != "", do: add_markdown_to_frame(frame, "***pattern***: #{pattern}")
        add_markdown_to_frame(frame, "generating...")
        {payload, call_type} = make_client_args(prompt, pattern)
        {:ok, generated_text} = LMAPI.get_completion_response(client, payload, call_type)
        add_markdown_to_frame(frame, "**bot**: #{generated_text}")
      else
        content = Kino.Markdown.new("_ERROR! You need a prompt to submit..._")
        Kino.Frame.append(frame, content, to: origin)
      end
    end)
  end
def add_markdown_to_frame(frame, markdown_str) do
Kino.Frame.append(frame, Kino.Markdown.new(markdown_str))
end
  # An empty pattern means a plain completion call; otherwise use rellm.
  defp make_client_args(prompt, pattern) do
    payload = %{"prompt" => prompt}

    if pattern == "" do
      {payload, :completion}
    else
      {payload, :rellm}
    end
  end
end
PromptHistoryInterface.new(client)