File size: 2,636 Bytes
be0233b
5913991
be0233b
 
 
 
 
 
 
7661979
 
be0233b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7661979
be0233b
 
7661979
 
be0233b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7661979
16f3588
7661979
16f3588
7661979
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Following https://python.langchain.com/docs/tutorials/chatbot/
# Missing: trimming, streaming with memory, use multiple threads

import logging

import gradio as gr
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_mistralai import ChatMistralAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph

# Persona instruction sent as the system message at the start of every request.
_SYSTEM_PERSONA = "You talk like a person of the Middle Ages. Answer all questions to the best of your ability."

# Prompt template: the system persona first, then the running conversation,
# which is filled in from the "messages" key of the input state.
prompt = ChatPromptTemplate.from_messages(
    [("system", _SYSTEM_PERSONA), MessagesPlaceholder(variable_name="messages")]
)

# Rate limiter
rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.1,  # <-- MistralAI free. We can only make a request once every second
    check_every_n_seconds=0.01,  # Wake up every 100 ms to check whether allowed to make a request,
    max_bucket_size=10,  # Controls the maximum burst size.
)

model = ChatMistralAI(model="mistral-large-latest", rate_limiter=rate_limiter)

# Define a new graph
workflow = StateGraph(state_schema=MessagesState)

# Compose the prompt template and the model once, instead of rebuilding the
# same chain on every graph step.
_chain = prompt | model


# Define the function that calls the model
def call_model(state: MessagesState):
    """Run the prompt + model chain on the current conversation state.

    Args:
        state: Graph state. Its ``messages`` list fills the ``"messages"``
            placeholder of ``prompt``, after the system persona.

    Returns:
        A partial state update ``{"messages": response}``, where ``response``
        is the model's reply message; MessagesState's ``messages`` reducer
        adds it to the stored conversation.
    """
    response = _chain.invoke(state)
    return {"messages": response}

# Single-node graph: execution enters at START and runs the "model" node.
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# In-memory checkpointer: graph state is saved per thread_id in this process's
# memory, so the conversation history is lost when the app restarts.
memory = MemorySaver()
app = workflow.compile(checkpointer=memory)

# Default run configuration; every call made with it uses the thread "abc345".
config = {"configurable": {"thread_id": "abc345"}}


def handle_prompt(query, history, request: gr.Request = None):
    """Send one user message through the LangGraph app and return the reply text.

    Called by ``gr.ChatInterface`` for every submitted message.

    Args:
        query: The text the user just submitted.
        history: Chat history passed in by Gradio. It is not used here: the
            conversation context comes from the graph's checkpointer, keyed
            by ``thread_id``.
        request: Injected by Gradio because of the ``gr.Request`` annotation.
            Its ``session_hash`` becomes the LangGraph ``thread_id``, so each
            browser session gets its own conversation memory. When it is
            absent (e.g. a direct call), the module-level ``config`` is used.

    Returns:
        The content of the last message in the graph's output state, which is
        the model's reply.

    Raises:
        gr.Error: Shown to the user in the UI when the model call fails.
    """
    # Give each browser session its own thread; a single shared thread_id
    # would mix every visitor's messages into one conversation.
    session_hash = getattr(request, "session_hash", None)
    if session_hash:
        run_config = {"configurable": {"thread_id": session_hash}}
    else:
        run_config = config

    input_messages = [HumanMessage(query)]
    # TODO: stream the reply with app.stream(..., stream_mode="messages"),
    # keeping only AIMessage chunks, and yield the accumulated text.
    try:
        output = app.invoke({"messages": input_messages}, run_config)
    except Exception as err:
        logging.exception("Model call failed")
        # Only report a rate limit when the API actually answered HTTP 429.
        # Assumes HTTP errors from the Mistral client expose
        # `.response.status_code` (as httpx.HTTPStatusError does) — confirm.
        status_code = getattr(getattr(err, "response", None), "status_code", None)
        if status_code == 429:
            raise gr.Error("Requests rate limit exceeded") from err
        raise gr.Error("The model request failed, please try again later.") from err
    return output["messages"][-1].content

description = "A MistralAI powered chatbot which talks in the way of ancient times, using Langchain and deployed with Gradio."

# Chat UI: each submitted message is handled by handle_prompt; type="messages"
# selects the OpenAI-style list-of-dicts history format.
demo = gr.ChatInterface(
    handle_prompt,
    type="messages",
    title="Medieval ChatBot",
    theme=gr.themes.Citrus(),
    description=description,
)

demo.launch()