# Following https://python.langchain.com/docs/tutorials/chatbot/
# Missing: trimming, streaming with memory, use multiple threads
from langchain_mistralai import ChatMistralAI
from langchain_core.rate_limiters import InMemoryRateLimiter
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
import gradio as gr
# Prompt template
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You talk like a person of the Middle Ages. Answer all questions to the best of your ability.",
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)
# Rate limiter
rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.1,  # <-- MistralAI free tier: we can only make a request once every 10 seconds
    check_every_n_seconds=0.01,  # Wake up every 10 ms to check whether a request is allowed
    max_bucket_size=10,  # Controls the maximum burst size
)
model = ChatMistralAI(model="mistral-large-latest", rate_limiter=rate_limiter)
# Define a new graph
workflow = StateGraph(state_schema=MessagesState)
# Define the function that calls the model
def call_model(state: MessagesState):
    chain = prompt | model
    response = chain.invoke(state)
    return {"messages": response}
# Define the (single) node in the graph
workflow.add_edge(START, "model")
workflow.add_node("model", call_model)
# Add memory
memory = MemorySaver()
app = workflow.compile(checkpointer=memory)
# Config with thread
config = {"configurable": {"thread_id": "abc345"}}
def handle_prompt(query, history):
    input_messages = [HumanMessage(query)]
    try:
        # Stream output
        # out = ""
        # for chunk, metadata in app.stream(
        #     {"messages": input_messages},
        #     config,
        #     stream_mode="messages",
        # ):
        #     if isinstance(chunk, AIMessage):  # Filter to just model responses
        #         out += chunk.content
        #         yield out
        output = app.invoke({"messages": input_messages}, config)
        return output["messages"][-1].content
    except Exception as e:
        raise gr.Error("Requests rate limit exceeded") from e
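# Streaming with memory (also listed as missing above): gr.ChatInterface accepts
# generator functions, so re-enabling the commented app.stream() block and yielding
# the growing `out` string instead of returning should stream tokens while still
# writing to the same checkpointed thread.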
description = "A MistralAI-powered chatbot that talks in the manner of the Middle Ages, built with LangChain and deployed with Gradio."
demo = gr.ChatInterface(handle_prompt, type="messages", title="Medieval ChatBot", theme=gr.themes.Citrus(), description=description)
demo.launch()