"""LangChain pipeline that turns a playlist text description into song descriptions."""
import re
from operator import itemgetter

from langchain.llms.base import LLM
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnableLambda

from .output_parser import SongDescriptions

# Matches a backslash together with the (optional) character after it, so each
# escape pair is examined as a unit and an escaped backslash is never split.
_BACKSLASH_SEQUENCE = re.compile(r"\\(.?)", re.DOTALL)
# Characters that may legally follow a backslash inside a JSON string.
_JSON_ESCAPE_CHARS = frozenset('"\\/bfnrtu')


class LLMChain:
    """Prompt an LLM with a playlist description and parse its reply.

    The chain is: input mapping -> ``ChatPromptTemplate`` -> ``llm_model``.
    The model's raw text output is then parsed into a ``SongDescriptions``
    object by a ``PydanticOutputParser``.
    """

    def __init__(self, llm_model: LLM) -> None:
        """Build the chain around ``llm_model``.

        Args:
            llm_model: The language model invoked as the last step of the chain.
        """
        self.llm_model = llm_model
        self.parser = PydanticOutputParser(pydantic_object=SongDescriptions)
        self.full_chain = self._create_llm_chain()

    def _get_output_format(self, _) -> str:
        """Return the parser's format instructions; the chain input is ignored."""
        return self.parser.get_format_instructions()

    def _create_llm_chain(self):
        """Assemble the runnable: input mapping | prompt | model."""
        prompt_response = ChatPromptTemplate.from_messages([
            ("system", "You are an AI assistant, helping the user to turn a music playlist text description into four separate song descriptions that are probably contained in the playlist. Try to be specific with descriptions. Make sure all 4 song descriptions are similar.\n"),
            ("system", "{format_instructions}\n"),
            ("human", "Playlist description: {description}.\n"),
        ])

        full_chain = (
            {
                # Injected as a variable value (not template text), so the
                # braces inside the JSON schema are not treated as placeholders.
                "format_instructions": RunnableLambda(self._get_output_format),
                "description": itemgetter("description"),
            }
            | prompt_response
            | self.llm_model
        )
        return full_chain

    @staticmethod
    def _strip_invalid_escapes(text: str) -> str:
        r"""Drop backslashes that do not begin a valid JSON escape sequence.

        Escapes JSON allows inside strings (``\"``, ``\\``, ``\/``, ``\b``,
        ``\f``, ``\n``, ``\r``, ``\t``, ``\u``) are kept intact. Any other
        backslash, such as the one in ``\_``, is removed while the character
        after it is kept. A trailing lone backslash is removed.
        """
        def _fix(match: "re.Match[str]") -> str:
            following = match.group(1)
            if following and following in _JSON_ESCAPE_CHARS:
                return match.group(0)
            return following

        return _BACKSLASH_SEQUENCE.sub(_fix, text)

    def process_user_description(self, user_input):
        """Run the chain on a playlist description and parse the result.

        Args:
            user_input: Free-text playlist description inserted into the prompt.

        Returns:
            The ``SongDescriptions`` object produced by the Pydantic parser.

        Raises:
            Whatever ``PydanticOutputParser.parse`` raises when the model output
            does not match the schema (LangChain's ``OutputParserException``).
        """
        raw = self.full_chain.invoke({"description": user_input})
        # Plain LLMs return ``str``; chat models return a message with ``.content``.
        output = getattr(raw, "content", raw)
        # Only strip backslashes that are invalid JSON escapes. Removing all of
        # them would turn an escaped quote (\") into a bare quote and break parsing.
        output = self._strip_invalid_escapes(output)
        return self.parser.parse(output)