diff --git a/podcastfy/content_generator.py b/podcastfy/content_generator.py index f3cfd91..d688151 100644 --- a/podcastfy/content_generator.py +++ b/podcastfy/content_generator.py @@ -14,7 +14,7 @@ from langchain_community.chat_models import ChatLiteLLM from langchain_google_genai import ChatGoogleGenerativeAI from langchain_community.llms.llamafile import Llamafile -from langchain_core.prompts import ChatPromptTemplate +from langchain_core.prompts import ChatPromptTemplate, PromptTemplate from langchain_core.output_parsers import StrOutputParser from langchain import hub from podcastfy.utils.config_conversation import load_conversation_config @@ -503,6 +503,7 @@ def clean(self, # Then apply additional long-form specific cleaning return self._clean_transcript_response(standard_clean, config) + def _clean_transcript_response(self, transcript: str, config: Dict[str, Any]) -> str: """ Clean transcript using a two-step process with LLM-based cleaning. @@ -522,7 +523,28 @@ def _clean_transcript_response(self, transcript: str, config: Dict[str, Any]) -> """ logger.debug("Starting transcript cleaning process") - final_transcript = self._fix_alternating_tags(transcript) + # Run rewriting chain + llm = self.llm + rewrite_prompt = PromptTemplate( + input_variables=["transcript"], + template=config.get("rewrite_prompt_template", "Clean and improve this podcast transcript by deduping any repeated sections and improving conversational flow. Just output the improved conversation in the same format and nothing else. Do not add or omit any information.: \n\n{transcript}") + ) + logger.debug("Executing rewriting chain") + rewrite_chain = rewrite_prompt | llm | StrOutputParser() + + try: + rewritten_response = rewrite_chain.invoke({"transcript": transcript}) + if not rewritten_response: + logger.warning("Rewriting chain returned empty response") + # Fall back to original + rewritten_response = transcript + logger.debug("Successfully rewrote transcript") + logger.debug("Successfully rewrote transcript, BEFORE = ", transcript, "AFTER = ", rewritten_response) + except Exception as e: + logger.error(f"Error in rewriting chain: {str(e)}") + rewritten_response = transcript # Fall back to original + + final_transcript = self._fix_alternating_tags(rewritten_response) logger.debug("Completed transcript cleaning process")