Skip to content

Commit

Permalink
#115 integrates with litellm
Browse files Browse the repository at this point in the history
  • Loading branch information
souzatharsis committed Nov 6, 2024
1 parent d8562bf commit 61d52d5
Show file tree
Hide file tree
Showing 9 changed files with 252 additions and 18 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ This sample collection is also [available at audio.com](https://audio.com/thatup

- Generate conversational content from multiple sources and formats (images, websites, YouTube, and PDFs).
- Customize transcript and audio generation (e.g., style, language, structure, length).
- Create podcasts from pre-existing or edited transcripts.
- Leverage cloud-based and local LLMs for transcript generation (increased privacy and control).
- Integrate with advanced text-to-speech models (OpenAI, Google,ElevenLabs, and Microsoft Edge).
- Generate transcripts using 100+ LLM models (OpenAI, Anthropic, Google, etc.).
- Leverage local LLMs for transcript generation for increased privacy and control.
- Integrate with advanced text-to-speech models (OpenAI, Google, ElevenLabs, and Microsoft Edge).
- Provide multi-language support for global content creation.
- Integrate seamlessly with CLI and Python packages for automated workflows.

Expand Down
43 changes: 42 additions & 1 deletion podcastfy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"- Multilingual Support\n",
" - French (fr)\n",
" - Portuguese (pt-br)\n",
"- Local LLM Support"
"- Custom LLM Support"
]
},
{
Expand Down Expand Up @@ -818,6 +818,47 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom LLM Support\n",
"\n",
"Podcastfy supports a range of LLMs for generating transcripts, including models from OpenAI, Anthropic, and Google, as well as local LLMs.\n",
"\n",
"### Cloud-based LLMs\n",
"\n",
"To select a particular cloud-based LLM model, users can pass the `llm_model_name` and `api_key_label` parameters to the `generate_podcast` function.\n",
"\n",
"For example, to use OpenAI's `gpt-4-turbo` model, users can pass `llm_model_name=\"gpt-4-turbo\"` and `api_key_label=\"OPENAI_API_KEY\"`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"Test generating a podcast with a custom LLM model.\"\"\"\n",
"urls = [\"https://en.wikipedia.org/wiki/Artificial_intelligence\"]\n",
"\n",
"audio_file = generate_podcast(\n",
" urls=urls,\n",
" tts_model=\"edge\",\n",
" llm_model_name=\"gpt-4-turbo\",\n",
" api_key_label=\"OPENAI_API_KEY\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Remember to have the correct API key label and value in your environment variables (`.env` file)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Local LLM Support\n",
"\n",
"We enable serving local LLMs with llamafile. In the API, local LLM support is available through the `is_local` parameter. If `is_local=True`, then a local (llamafile) LLM model is used to generate the podcast transcript. Llamafiles of LLM models can be found on [HuggingFace today offering 156+ models](https://huggingface.co/models?library=llamafile).\n",
"\n",
"All you need to do is:\n",
Expand Down
32 changes: 26 additions & 6 deletions podcastfy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@
from podcastfy.content_generator import ContentGenerator
from podcastfy.text_to_speech import TextToSpeech
from podcastfy.utils.config import Config, load_config
from podcastfy.utils.config_conversation import (
ConversationConfig,
load_conversation_config,
)
from podcastfy.utils.config_conversation import load_conversation_config
from podcastfy.utils.logger import setup_logger
from typing import List, Optional, Dict, Any
import copy
Expand All @@ -40,6 +37,8 @@ def process_content(
image_paths: Optional[List[str]] = None,
is_local: bool = False,
text: Optional[str] = None,
model_name: Optional[str] = None,
api_key_label: Optional[str] = None,
):
"""
Process URLs, a transcript file, image paths, or raw text to generate a podcast or transcript.
Expand Down Expand Up @@ -90,6 +89,8 @@ def process_content(
image_file_paths=image_paths or [],
output_filepath=transcript_filepath,
is_local=is_local,
model_name=model_name,
api_key_label=api_key_label,
)

if generate_audio:
Expand All @@ -98,8 +99,8 @@ def process_content(
api_key = getattr(config, f"{tts_model.upper()}_API_KEY")

text_to_speech = TextToSpeech(
api_key=api_key,
model=tts_model,
api_key=api_key,
conversation_config=conv_config.to_dict(),
)

Expand Down Expand Up @@ -155,6 +156,12 @@ def main(
text: str = typer.Option(
None, "--text", "-txt", help="Raw text input to be processed"
),
llm_model_name: str = typer.Option(
None, "--llm-model-name", "-m", help="LLM model name for transcript generation"
),
api_key_label: str = typer.Option(
None, "--api-key-label", "-k", help="Environment variable name for LLMAPI key"
),
):
"""
Generate a podcast or transcript from a list of URLs, a file containing URLs, a transcript file, image files, or raw text.
Expand Down Expand Up @@ -185,6 +192,8 @@ def main(
config=config,
is_local=is_local,
text=text,
model_name=llm_model_name,
api_key_label=api_key_label,
)
else:
urls_list = urls or []
Expand All @@ -205,6 +214,8 @@ def main(
image_paths=image_paths,
is_local=is_local,
text=text,
model_name=llm_model_name,
api_key_label=api_key_label,
)

if transcript_only:
Expand Down Expand Up @@ -234,6 +245,8 @@ def generate_podcast(
image_paths: Optional[List[str]] = None,
is_local: bool = False,
text: Optional[str] = None,
llm_model_name: Optional[str] = None,
api_key_label: Optional[str] = None,
) -> Optional[str]:
"""
Generate a podcast or transcript from a list of URLs, a file containing URLs, a transcript file, or image files.
Expand All @@ -242,13 +255,15 @@ def generate_podcast(
urls (Optional[List[str]]): List of URLs to process.
url_file (Optional[str]): Path to a file containing URLs, one per line.
transcript_file (Optional[str]): Path to a transcript file.
tts_model (Optional[str]): TTS model to use ('openai' [default], 'elevenlabs' or 'edge').
tts_model (Optional[str]): TTS model to use ('openai' [default], 'elevenlabs', 'edge', or 'gemini').
transcript_only (bool): Generate only a transcript without audio. Defaults to False.
config (Optional[Dict[str, Any]]): User-provided configuration dictionary.
conversation_config (Optional[Dict[str, Any]]): User-provided conversation configuration dictionary.
image_paths (Optional[List[str]]): List of image file paths to process.
is_local (bool): Whether to use a local LLM. Defaults to False.
text (Optional[str]): Raw text input to be processed.
llm_model_name (Optional[str]): LLM model name for content generation.
api_key_label (Optional[str]): Environment variable name for LLM API key.
Returns:
Optional[str]: Path to the final podcast audio file, or None if only generating a transcript.
Expand All @@ -272,6 +287,7 @@ def generate_podcast(
raise ValueError(
"Config must be either a dictionary or a Config object"
)

if not conversation_config:
conversation_config = load_conversation_config().to_dict()

Expand All @@ -292,6 +308,8 @@ def generate_podcast(
conversation_config=conversation_config,
is_local=is_local,
text=text,
model_name=llm_model_name,
api_key_label=api_key_label,
)
else:
urls_list = urls or []
Expand All @@ -313,6 +331,8 @@ def generate_podcast(
image_paths=image_paths,
is_local=is_local,
text=text,
model_name=llm_model_name,
api_key_label=api_key_label,
)

except Exception as e:
Expand Down
26 changes: 18 additions & 8 deletions podcastfy/content_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Optional, Dict, Any, List
import re

from langchain_community.chat_models import ChatLiteLLM
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.llms.llamafile import Llamafile
from langchain_core.prompts import ChatPromptTemplate
Expand All @@ -30,6 +31,7 @@ def __init__(
temperature: float,
max_output_tokens: int,
model_name: str,
api_key_label: str = "OPENAI_API_KEY",
):
"""
Initialize the LLMBackend.
Expand All @@ -48,12 +50,16 @@ def __init__(

if is_local:
self.llm = Llamafile()
else:
elif "gemini" in self.model_name.lower(): #keeping original gemini as a special case while we build confidence on LiteLLM
self.llm = ChatGoogleGenerativeAI(
model=model_name,
temperature=temperature,
max_output_tokens=max_output_tokens,
)
else: # user should set api_key_label from input
self.llm = ChatLiteLLM(model=self.model_name,
temperature=temperature,
api_key=os.environ[api_key_label])


class ContentGenerator:
Expand Down Expand Up @@ -217,6 +223,8 @@ def generate_qa_content(
image_file_paths: List[str] = [],
output_filepath: Optional[str] = None,
is_local: bool = False,
model_name: str = None,
api_key_label: str = "OPENAI_API_KEY"
) -> str:
"""
Generate Q&A content based on input texts.
Expand All @@ -234,19 +242,21 @@ def generate_qa_content(
Exception: If there's an error in generating content.
"""
try:
if not model_name:
model_name = self.content_generator_config.get(
"gemini_model", "gemini-1.5-pro-latest"
)
if is_local:
model_name = "User provided local model"

llmbackend = LLMBackend(
is_local=is_local,
temperature=self.config_conversation.get("creativity", 0),
max_output_tokens=self.content_generator_config.get(
"max_output_tokens", 8192
),
model_name=(
self.content_generator_config.get(
"gemini_model", "gemini-1.5-pro-latest"
)
if not is_local
else "User provided model"
),
model_name=model_name,
api_key_label=api_key_label
)

num_images = 0 if is_local else len(image_file_paths)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ ffmpeg = "^1.4"
pytest = "^8.3.3"
pytest-xdist = "^3.6.1"
google-cloud-texttospeech = "^2.21.0"
litellm = "^1.52.0"


[tool.poetry.group.dev.dependencies]
Expand Down
62 changes: 62 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,5 +261,67 @@ def test_no_input_provided():
assert "No input provided" in result.stdout


def test_generate_podcast_with_custom_llm():
    """Test generating a podcast with a custom LLM model using CLI.

    Invokes the CLI app with the ``edge`` TTS model and the ``gpt-4-turbo``
    LLM (API key read from ``OPENAI_API_KEY``), then checks that the
    reported audio file exists, is an ``.mp3`` and is larger than 1KB.
    The generated file is removed afterwards, even if a check fails.
    """
    result = runner.invoke(
        app,
        [
            "--url", MOCK_URLS[0],
            "--tts-model", "edge",
            "--llm-model-name", "gpt-4-turbo",
            "--api-key-label", "OPENAI_API_KEY",
        ],
    )

    assert result.exit_code == 0
    assert "Podcast generated successfully using edge TTS model" in result.stdout

    # Extract the audio file path: taken as whatever follows the last ": "
    # in the CLI output (assumes the path is the final thing printed).
    audio_path = result.stdout.split(": ")[-1].strip()
    try:
        assert os.path.exists(audio_path)
        assert audio_path.endswith(".mp3")
        assert os.path.getsize(audio_path) > 1024  # Check if larger than 1KB
    finally:
        # Clean up in ``finally`` so a failed assertion does not leave the
        # generated audio behind; guard with isfile so a bad path cannot
        # mask the original assertion error with a removal error.
        if os.path.isfile(audio_path):
            os.remove(audio_path)

def test_generate_transcript_only_with_custom_llm():
    """Test generating only a transcript with a custom LLM model using CLI.

    Invokes the CLI app with ``--transcript-only`` and the ``gpt-4-turbo``
    LLM (API key read from ``OPENAI_API_KEY``), then checks that the
    reported ``.txt`` transcript exists, contains both speaker tags and is
    longer than 500 characters. The generated file is removed afterwards,
    even if a check fails.
    """
    result = runner.invoke(
        app,
        [
            "--url", MOCK_URLS[0],
            "--transcript-only",
            "--llm-model-name", "gpt-4-turbo",
            "--api-key-label", "OPENAI_API_KEY",
        ],
    )

    assert result.exit_code == 0
    assert "Transcript generated successfully" in result.stdout

    # Extract the transcript path: taken as whatever follows the last ": "
    # in the CLI output (assumes the path is the final thing printed).
    transcript_path = result.stdout.split(": ")[-1].strip()
    try:
        assert os.path.exists(transcript_path)
        assert transcript_path.endswith(".txt")

        # Verify transcript content
        with open(transcript_path, "r") as f:
            content = f.read()

        assert content != ""
        assert isinstance(content, str)
        assert "<Person1>" in content
        assert "<Person2>" in content
        assert len(content.split("<Person1>")) > 1  # At least one question
        assert len(content.split("<Person2>")) > 1  # At least one answer

        # Verify content is substantial
        min_length = 500  # Minimum expected length in characters
        assert len(content) > min_length, \
            f"Content length ({len(content)}) is less than minimum expected ({min_length})"
    finally:
        # Clean up in ``finally`` so a failed assertion does not leave the
        # generated transcript behind; guard with isfile so a bad path cannot
        # mask the original assertion error with a removal error.
        if os.path.isfile(transcript_path):
            os.remove(transcript_path)

if __name__ == "__main__":
pytest.main()
19 changes: 19 additions & 0 deletions tests/test_genai_podcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,25 @@ def test_generate_qa_content_from_raw_text(self):
self.assertNotEqual(result, "")
self.assertIsInstance(result, str)

def test_generate_qa_content_with_custom_model(self):
    """Check that Q&A content can be generated with a caller-selected model.

    Builds a ContentGenerator from the test's API key and the sample
    conversation config, requests content for a short input text using
    the OpenAI ``gpt-4-turbo`` model with the key read from
    ``OPENAI_API_KEY``, and verifies a non-empty string comes back.
    """
    generator = ContentGenerator(
        self.api_key, conversation_config=sample_conversation_config()
    )
    topic = "United States of America"

    # Exercise the OpenAI path rather than the default model.
    qa_text = generator.generate_qa_content(
        topic, model_name="gpt-4-turbo", api_key_label="OPENAI_API_KEY"
    )

    self.assertIsNotNone(qa_text)
    self.assertNotEqual(qa_text, "")
    self.assertIsInstance(qa_text, str)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 61d52d5

Please sign in to comment.