LMS

Mix.install([
  {:kino, "~> 0.9.1"},
  {:poison, "~> 4.0"},
  {:httpoison, "~> 1.8"}
])

Section

defmodule NestedUpdater do
  @doc ~S"""
  Recursively merges `new_map` into `old_map`, updating nested maps in place
  instead of replacing them wholesale.

  ## Examples

      iex> NestedUpdater.update(%{:a => %{:b => 1, :c => 2}}, %{:a => %{:b => 2}})
      %{:a => %{:b => 2, :c => 2}}
  """
  def update(old_map, new_map) do
    Enum.reduce(new_map, old_map, fn {new_key, new_val}, acc ->
      # Recurse only when both the old and the new value are maps;
      # otherwise the new value replaces the old one outright.
      if is_map(new_val) and is_map(acc[new_key]) do
        Map.update!(acc, new_key, &update(&1, new_val))
      else
        Map.put(acc, new_key, new_val)
      end
    end)
  end
end
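
For a quick contrast with the standard library: a shallow Map.merge/2 replaces the nested map under :a wholesale and loses :c, while the deep merge above keeps it.

old = %{:a => %{:b => 1, :c => 2}}
new = %{:a => %{:b => 2}}

Map.merge(old, new)
# => %{a: %{b: 2}}

NestedUpdater.update(old, new)
# => %{a: %{b: 2, c: 2}}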
defprotocol Utility do
  def type(value)
end

defimpl Utility, for: BitString do
  def type(_value), do: "string"
end

defimpl Utility, for: Integer do
  def type(_value), do: "integer"
end
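
The Utility protocol is not used elsewhere in this notebook; it only demonstrates dispatch by value type. Types without an implementation (a float, say) raise Protocol.UndefinedError.

Utility.type("hello")
# => "string"

Utility.type(42)
# => "integer"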
defprotocol LMAPI do
  @type payload :: map()
  @type completion_type :: :completion | :rellm
  @type api_response :: {:ok, String.t()} | {:error, any()}

  @spec get_completion_response(
          client :: t(),
          payload :: payload(),
          completion_type :: completion_type()
        ) :: api_response()
  def get_completion_response(client, payload, completion_type)
end
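
Any client struct can plug into the rest of the notebook by implementing this protocol. As a minimal sketch (EchoClient is a hypothetical stub, not part of the original setup), an offline implementation that just echoes the prompt back:

defmodule EchoClient do
  defstruct []
end

defimpl LMAPI, for: EchoClient do
  # Hypothetical stub: returns the prompt unchanged instead of calling
  # a model server, handy for testing the UI without network access.
  def get_completion_response(_client, payload, _completion_type) do
    {:ok, "echo: " <> Map.get(payload, "prompt", "")}
  end
end

LMAPI.get_completion_response(%EchoClient{}, %{"prompt" => "hi"}, :completion)
# => {:ok, "echo: hi"}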
defmodule LMServerClient do
  defstruct [:model, :url]
  @type t :: %__MODULE__{model: String.t(), url: String.t()}

  def default_payload do
    %{
      "prompt" => "User: give at least three examples of fruits\n Assistant: ",
      "min_length" => 3,
      "max_new_tokens" => 20,
      "n" => 1,
      "max_length" => 512,
      "stop" => "string",
      "stream" => false,
      "sampling_parameters" => %{
        "temperature" => 1,
        "top_k" => 100,
        "top_p" => 0.9,
        "logit_bias" => %{},
        "presence_penalty" => 0,
        "frequency_penalty" => 0,
        "repetition_penalty" => 1,
        "typical_p" => 1
      },
      "stop_after_match" => true,
      "do_sample" => true,
      "return_full_text" => false
    }
  end

  def rellm_payload do
    NestedUpdater.update(default_payload(), %{"pattern" => "(.* ){5}"})
  end

  def get_default_payload(:rellm), do: rellm_payload()
  def get_default_payload(_), do: default_payload()
end
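
rellm_payload/0 just layers the regex "pattern" key onto the defaults via the deep merge, leaving everything else untouched:

LMServerClient.rellm_payload()["pattern"]
# => "(.* ){5}"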

defimpl LMAPI, for: LMServerClient do
  @spec get_completion_response(
          client :: LMServerClient.t(),
          payload :: map(),
          completion_type :: :completion | :rellm
        ) :: {:ok, String.t()} | {:error, any()}
  def get_completion_response(
        %LMServerClient{model: model, url: url},
        payload,
        completion_type
      ) do
    updated_payload = payload |> get_updated_payload(completion_type)
    {model, url, updated_payload} |> get_response(completion_type) |> get_response_text()
  end

  defp get_updated_payload(payload, completion_type) do
    NestedUpdater.update(LMServerClient.get_default_payload(completion_type), payload)
  end

  defp get_response({_model, url, payload}, completion_type) do
    request_body = payload |> Poison.encode!()
    options = [recv_timeout: 10000, timeout: 10000]
    endpoint = get_endpoint(completion_type)

    HTTPoison.post(
      "#{url}/#{endpoint}",
      request_body,
      [{"Content-type", "application/json"}],
      options
    )
  end

  defp get_endpoint(:rellm), do: "rellm_generate"
  defp get_endpoint(_), do: "generate"

  # The LM server is expected to answer with a JSON body of the shape
  # %{"texts" => [generated_text, ...]}.
  defp get_response_text({:ok, response}) do
    body = Poison.decode!(response.body)
    [result | _] = body["texts"]
    {:ok, result}
  end

  defp get_response_text(other), do: other
end
defmodule OpenAIClient do
  defstruct [:openai_key, :model]
  @type t :: %__MODULE__{openai_key: String.t(), model: String.t()}

  def new(model, openai_key, openai_key_path \\ "~/home/.keys") do
    key = openai_key || load_openai_key(openai_key_path)
    %OpenAIClient{model: model, openai_key: key}
  end

  defp load_openai_key(path) do
    {:ok, key} = File.read(Path.expand(path))
    String.trim(key)
  end

  def default_completion_payload(model) do
    %{
      "model" => model,
      "prompt" => "A cat walks into a bar",
      "temperature" => 0.9,
      "max_tokens" => 150,
      "top_p" => 1,
      "frequency_penalty" => 0.0,
      "presence_penalty" => 0.6,
      "stop" => [" Human:", " AI:"]
    }
  end

  def default_chat_payload(model) do
    %{
      "model" => model,
      "messages" => [%{"role" => "user", "content" => "A cat walks into a bar"}],
      "temperature" => 0.9,
      "max_tokens" => 15
    }
  end

  def get_default_payload(model) do
    if completion_model_name?(model) do
      default_completion_payload(model)
    else
      default_chat_payload(model)
    end
  end

  # Predicate-style name per Elixir convention; chat-only models
  # ("gpt-3.5*", "gpt-4*") must go through the chat endpoint.
  def completion_model_name?(model) do
    not String.contains?(model, "gpt-3.5") and not String.contains?(model, "gpt-4")
  end
end
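
Model names containing "gpt-3.5" or "gpt-4" are routed to the chat API; anything else falls back to the legacy completions API:

OpenAIClient.completion_model_name?("text-davinci-003")
# => true, uses /v1/completions

OpenAIClient.completion_model_name?("gpt-3.5-turbo")
# => false, uses /v1/chat/completions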

defimpl LMAPI, for: OpenAIClient do
  # OpenAI derives the endpoint from the model name, so the
  # completion_type hint used by LMServerClient is ignored here.
  def get_completion_response(client, payload, _completion_type) do
    url = get_url(client.model)
    default_payload = OpenAIClient.get_default_payload(client.model)
    updated_payload = {client.model, default_payload} |> get_payload(payload)

    {client, updated_payload}
    |> get_api_response(url)
    |> get_response_body()
    |> get_text_result()
  end

  defp get_chat_payload(default_chat_payload, payload) do
    {prompt, payload_without_prompt} = payload |> Map.pop("prompt")

    prompt_payload =
      case prompt do
        nil ->
          %{}

        _ ->
          %{
            "messages" => [%{"role" => "user", "content" => prompt}]
          }
      end

    default_chat_payload
    |> NestedUpdater.update(payload_without_prompt)
    |> NestedUpdater.update(prompt_payload)
  end

  defp get_payload({model, default_payload}, payload) do
    if OpenAIClient.completion_model_name?(model) do
      NestedUpdater.update(default_payload, payload)
    else
      get_chat_payload(default_payload, payload)
    end
  end

  defp get_url(model) do
    if OpenAIClient.completion_model_name?(model) do
      "https://api.openai.com/v1/completions"
    else
      "https://api.openai.com/v1/chat/completions"
    end
  end

  defp get_api_response({client, payload}, url) do
    headers = [
      {"Content-Type", "application/json"},
      {"Authorization", "Bearer #{client.openai_key}"}
    ]

    request_body = payload |> Poison.encode!()
    options = [recv_timeout: 10000, timeout: 10000]
    HTTPoison.post(url, request_body, headers, options)
  end

  defp get_response_body({flag, %HTTPoison.Response{body: encoded_body}}),
    do: {flag, Poison.decode!(encoded_body)}

  # Errors carry an %HTTPoison.Error{} instead of a response body.
  defp get_response_body({flag, other}), do: {flag, other}

  defp get_text_result({:ok, results}) do
    case results do
      %{"choices" => [%{"message" => %{"content" => text}} | _]} -> {:ok, text}
      %{"choices" => [%{"text" => text} | _]} -> {:ok, text}
      # HTTP-level failures (401, 429, ...) still come back as {:ok, body}
      # from HTTPoison, carrying an "error" object instead of "choices".
      %{"error" => error} -> {:error, error}
    end
  end

  defp get_text_result({:error, other}) do
    {:error, other}
  end
end
client = %LMServerClient{model: "falcon40b-instruct", url: "http://localhost:8765"}
LMAPI.get_completion_response(client, %{"prompt" => "A cat walks into a bar"}, :completion)
client = OpenAIClient.new("gpt-3.5-turbo", nil, "/home/kuba/.keys/openai_key.txt")
# Return :ok so the struct (and the API key inside it) is not rendered.
:ok
LMAPI.get_completion_response(client, %{"prompt" => "A cat walks into a bar"}, :completion)
defmodule PromptHistoryInterface do
  def new(client) do
    {_inputs, chat_form, frame} = make_widgets()
    chat_form |> Kino.render()
    frame |> Kino.render()
    client |> wire_widgets({chat_form, frame})
  end

  def make_widgets() do
    inputs = [
      prompt: Kino.Input.text("Prompt"),
      pattern: Kino.Input.text("rellm pattern")
    ]

    frame = Kino.Frame.new()
    chat_form = Kino.Control.form(inputs, submit: "Send", reset_on_submit: [:prompt, :pattern])
    {inputs, chat_form, frame}
  end

  def wire_widgets(client, {chat_form, frame}) do
    Kino.listen(chat_form, fn %{data: %{prompt: prompt, pattern: pattern}, origin: origin} ->
      if prompt != "" do
        add_markdown_to_frame(frame, "**user**: #{prompt}")

        if pattern != "" do
          add_markdown_to_frame(frame, "***pattern***: #{pattern}")
        end

        add_markdown_to_frame(frame, "generating...")
        {payload, call_type} = make_client_args(prompt, pattern)

        case LMAPI.get_completion_response(client, payload, call_type) do
          {:ok, generated_text} -> add_markdown_to_frame(frame, "**bot**: #{generated_text}")
          {:error, error} -> add_markdown_to_frame(frame, "**error**: #{inspect(error)}")
        end
      else
        content = Kino.Markdown.new("_ERROR! You need a prompt to submit..._")
        Kino.Frame.append(frame, content, to: origin)
      end
    end)
  end

  def add_markdown_to_frame(frame, markdown_str) do
    Kino.Frame.append(frame, Kino.Markdown.new(markdown_str))
  end

  defp make_client_args(prompt, pattern) do
    payload = %{"prompt" => prompt}

    if pattern == "" do
      {payload, :completion}
    else
      {payload, :rellm}
    end
  end
end
PromptHistoryInterface.new(client)