import gradio as gr
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, trim_messages
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

# Connect to a hosted Llama 3.2 Instruct model on the Hugging Face Inference API.
llm = HuggingFaceEndpoint(
    repo_id="meta-llama/Llama-3.2-3B-Instruct",
    max_new_tokens=512,
    temperature=0.7,
)
model = ChatHuggingFace(llm=llm)
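# Note: HuggingFaceEndpoint reads its API token from the HF_TOKEN /
# HUGGINGFACEHUB_API_TOKEN environment variable (or a prior
# `huggingface-cli login`); gated repos such as Llama 3.2 also require
# accepted access terms on the Hub.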

# Define message trimmer
trimmer = trim_messages(
    max_tokens=8192,
    strategy="last",
    token_counter=model,
    include_system=True,
    allow_partial=False,
    start_on="human",
)
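# Illustrative sketch of what the trimmer does (comments only, not executed):
# given a conversation over the 8192-token budget it keeps the most recent
# messages, never splits one in half, and starts the kept slice on a human turn:
#
#   trimmer.invoke([HumanMessage("hi"), AIMessage("ahoy"), HumanMessage("go on")])
#   # -> the full list if it fits the budget, otherwise a trailing slice of it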


# Handle one chat turn: rebuild the message history, trim it to the token
# budget, query the model, and return the updated history for display.
def chat_fn(system_prompt, user_input, history):
    if history is None:
        history = []

    # Append the new user turn with an empty assistant slot for now.
    history.append((user_input, ""))

    # Rebuild the conversation as LangChain messages. The system prompt is
    # injected by the prompt template below, so it is not added here as well.
    messages = []
    for user_msg, assistant_msg in history[:-1]:
        messages.append(HumanMessage(user_msg))
        messages.append(AIMessage(assistant_msg))
    messages.append(HumanMessage(user_input))

    # Drop the oldest turns if the conversation exceeds the token budget.
    trimmed_messages = trimmer.invoke(messages)

    # Build the prompt with the current (user-editable) system prompt. Using
    # SystemMessage (rather than a ("system", ...) template string) keeps any
    # braces the user types from being treated as template variables.
    prompt_template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    prompt = prompt_template.invoke({"messages": trimmed_messages})

    # Call the model and record its reply.
    response = model.invoke(prompt)
    assistant_reply = response.content

    # Fill in the assistant's reply for the pending turn.
    history[-1] = (user_input, assistant_reply)

    return history, history
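# Quick smoke test without the UI (illustrative; uncomment to try, assuming a
# valid Hugging Face token is configured):
#
#   history, _ = chat_fn("You talk like a pirate.", "Hello, who are you?", None)
#   print(history[-1][1])  # the assistant's reply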


# Build Gradio interface
with gr.Blocks() as demo:
    system_prompt = gr.Textbox(
        value="You talk like a pirate. Answer all questions to the best of your ability.",
        label="System Prompt",
        lines=2,
    )
    chatbot = gr.Chatbot()
    state = gr.State()
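    # gr.State keeps the per-session history as (user, assistant) tuples,
    # which is also the structure gr.Chatbot renders in its default tuple format.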

    with gr.Row():
        user_input = gr.Textbox(
            show_label=False,
            placeholder="Enter your message",
            container=False,
        )
        send_button = gr.Button("Send")

    # Wire both the Send button and the Enter key to the same handler, then
    # clear the input box once the message has been sent.
    def clear_input():
        return ""

    send_button.click(
        fn=chat_fn,
        inputs=[system_prompt, user_input, state],
        outputs=[chatbot, state],
    ).then(fn=clear_input, outputs=user_input)
    user_input.submit(
        fn=chat_fn,
        inputs=[system_prompt, user_input, state],
        outputs=[chatbot, state],
    ).then(fn=clear_input, outputs=user_input)

if __name__ == "__main__":
    demo.launch(share=True, ssr_mode=False)