gokaygokay commited on
Commit
4b1a870
·
verified ·
1 Parent(s): 5d5d47d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -0
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import json
3
+ import subprocess
4
+ from llama_cpp import Llama
5
+ from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
6
+ from llama_cpp_agent.providers import LlamaCppPythonProvider
7
+ from llama_cpp_agent.chat_history import BasicChatHistory
8
+ from llama_cpp_agent.chat_history.messages import Roles
9
+ import gradio as gr
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ hf_hub_download(
13
+ repo_id="bartowski/gemma-2-9b-it-GGUF",
14
+ filename="gemma-2-9b-it-Q5_K_M.gguf",
15
+ local_dir="./models"
16
+ )
17
+
18
+ def get_messages_formatter_type(model_name):
19
+ if "gemma" in model_name:
20
+ return MessagesFormatterType.GEMMA_2
21
+ else:
22
+ raise ValueError(f"Unsupported model: {model_name}")
23
+
24
+ @spaces.GPU(duration=120)
25
+ def respond(
26
+ message,
27
+ history: list[tuple[str, str]],
28
+ model,
29
+ system_message,
30
+ max_tokens,
31
+ temperature,
32
+ top_p,
33
+ top_k,
34
+ repeat_penalty,
35
+ ):
36
+ chat_template = get_messages_formatter_type(model)
37
+
38
+ llm = Llama(
39
+ model_path=f"models/{model}",
40
+ flash_attn=True,
41
+ n_gpu_layers=81,
42
+ n_batch=1024,
43
+ n_ctx=8192,
44
+ )
45
+ provider = LlamaCppPythonProvider(llm)
46
+
47
+ agent = LlamaCppAgent(
48
+ provider,
49
+ system_prompt=f"{system_message}",
50
+ predefined_messages_formatter_type=chat_template,
51
+ debug_output=True
52
+ )
53
+
54
+ settings = provider.get_provider_default_settings()
55
+ settings.temperature = temperature
56
+ settings.top_k = top_k
57
+ settings.top_p = top_p
58
+ settings.max_tokens = max_tokens
59
+ settings.repeat_penalty = repeat_penalty
60
+ settings.stream = True
61
+
62
+ messages = BasicChatHistory()
63
+
64
+ for msn in history:
65
+ user = {
66
+ 'role': Roles.user,
67
+ 'content': msn[0]
68
+ }
69
+ assistant = {
70
+ 'role': Roles.assistant,
71
+ 'content': msn[1]
72
+ }
73
+ messages.add_message(user)
74
+ messages.add_message(assistant)
75
+
76
+ stream = agent.get_chat_response(
77
+ message,
78
+ llm_sampling_settings=settings,
79
+ chat_history=messages,
80
+ returns_streaming_generator=True,
81
+ print_output=False
82
+ )
83
+
84
+ outputs = ""
85
+ for output in stream:
86
+ outputs += output
87
+ yield outputs
88
+
89
+ demo = gr.ChatInterface(
90
+ respond,
91
+ additional_inputs=[
92
+ gr.Dropdown([
93
+ 'gemma-2-9b-it-Q5_K_M.gguf'
94
+ ],
95
+ value="gemma-2-9b-it-Q5_K_M.gguf",
96
+ label="Model"
97
+ ),
98
+ gr.Textbox(value="You are a helpful assistant.", label="System message"),
99
+ gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens"),
100
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
101
+ gr.Slider(
102
+ minimum=0.1,
103
+ maximum=1.0,
104
+ value=0.95,
105
+ step=0.05,
106
+ label="Top-p",
107
+ ),
108
+ gr.Slider(
109
+ minimum=0,
110
+ maximum=100,
111
+ value=40,
112
+ step=1,
113
+ label="Top-k",
114
+ ),
115
+ gr.Slider(
116
+ minimum=0.0,
117
+ maximum=2.0,
118
+ value=1.1,
119
+ step=0.1,
120
+ label="Repetition penalty",
121
+ ),
122
+ ],
123
+ retry_btn="Retry",
124
+ undo_btn="Undo",
125
+ clear_btn="Clear",
126
+ submit_btn="Send",
127
+ description="Llama-cpp-agent: Chat with gemma-2-9b-it model",
128
+ chatbot=gr.Chatbot(
129
+ scale=1,
130
+ likeable=False,
131
+ show_copy_button=True
132
+ )
133
+ )
134
+
135
+ if __name__ == "__main__":
136
+ demo.launch()