Ashrafb committed on
Commit
71b54be
1 Parent(s): 007094d

Rename app.py to main.py

Files changed (2)
  1. app.py +0 -105
  2. main.py +60 -0
app.py DELETED
@@ -1,105 +0,0 @@
- from huggingface_hub import InferenceClient
- import gradio as gr
- import random
-
- API_URL = "https://api-inference.huggingface.co/models/"
-
- client = InferenceClient(
-     "mistralai/Mistral-7B-Instruct-v0.1"
- )
-
- def format_prompt(message, history):
-     prompt = "<s>"
-     for user_prompt, bot_response in history:
-         prompt += f"[INST] {user_prompt} [/INST]"
-         prompt += f" {bot_response}</s> "
-     prompt += f"[INST] {message} [/INST]"
-     return prompt
-
- def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
-     temperature = float(temperature)
-     if temperature < 1e-2:
-         temperature = 1e-2
-     top_p = float(top_p)
-
-     generate_kwargs = dict(
-         temperature=temperature,
-         max_new_tokens=max_new_tokens,
-         top_p=top_p,
-         repetition_penalty=repetition_penalty,
-         do_sample=True,
-         seed=random.randint(0, 10**7),
-     )
-
-     formatted_prompt = format_prompt(prompt, history)
-
-     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-     output = ""
-
-     for response in stream:
-         output += response.token.text
-         yield output
-     return output
-
- chatbot = gr.Chatbot(label='Chatbot', show_share_button=False)
-
- additional_inputs = [
-     gr.Slider(
-         label="Temperature",
-         value=0.9,
-         minimum=0.0,
-         maximum=1.0,
-         step=0.05,
-         interactive=True,
-         info="Higher values produce more diverse outputs",
-     ),
-     gr.Slider(
-         label="Max new tokens",
-         value=512,
-         minimum=64,
-         maximum=1024,
-         step=64,
-         interactive=True,
-         info="The maximum numbers of new tokens",
-     ),
-     gr.Slider(
-         label="Top-p (nucleus sampling)",
-         value=0.90,
-         minimum=0.0,
-         maximum=1,
-         step=0.05,
-         interactive=True,
-         info="Higher values sample more low-probability tokens",
-     ),
-     gr.Slider(
-         label="Repetition penalty",
-         value=1.2,
-         minimum=1.0,
-         maximum=2.0,
-         step=0.05,
-         interactive=True,
-         info="Penalize repeated tokens",
-     )
- ]
-
- customCSS = """
- #component-7 { # this is the default element ID of the chat component
-     height: 1000px; # adjust the height as needed
-     flex-grow: 1;
- }
- footer{display:none !important;}
- .gr-share {display:none !important;} /* Hide the share button */
- """
-
- with gr.Blocks(title="<span style='color: crimson ;'>Aiconvert.online</span>", css=customCSS, theme=gr.themes.Base()) as demo:
-     gr.ChatInterface(
-         generate,
-         chatbot = chatbot,
-         additional_inputs=additional_inputs,
-         title="<span style='color: crimson ;'>Aiconvert.online</span>",
-     )
-
- demo.queue().launch(debug=True)
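Note: format_prompt is carried over unchanged into main.py below. As a reference, a minimal sketch of the Mistral-instruct string it builds for one prior turn (the example turns here are illustrative only, not part of the repository):

# Illustrative only: the string format_prompt builds for one prior (user, bot) turn.
history = [("Hi", "Hello! How can I help?")]
formatted = format_prompt("What does this Space do?", history)
# formatted == "<s>[INST] Hi [/INST] Hello! How can I help?</s> [INST] What does this Space do? [/INST]"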
main.py ADDED
@@ -0,0 +1,60 @@
+ from fastapi import FastAPI, Request, Form
+ from fastapi.responses import HTMLResponse
+ from fastapi.staticfiles import StaticFiles
+ from fastapi.templating import Jinja2Templates
+ from huggingface_hub import InferenceClient
+ import random
+
+ API_URL = "https://api-inference.huggingface.co/models/"
+
+ client = InferenceClient(
+     "mistralai/Mistral-7B-Instruct-v0.1"
+ )
+
+ app = FastAPI()
+
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+
+ templates = Jinja2Templates(directory="templates")
+
+ def format_prompt(message, history):
+     prompt = "<s>"
+     for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
+ def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
+     top_p = float(top_p)
+
+     generate_kwargs = dict(
+         temperature=temperature,
+         max_new_tokens=max_new_tokens,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+         seed=random.randint(0, 10**7),
+     )
+
+     formatted_prompt = format_prompt(prompt, history)
+
+     # With stream=False and details=False, text_generation returns the generated text as a plain string.
+     output = client.text_generation(formatted_prompt, **generate_kwargs, stream=False, details=False)
+     return output
+
+ @app.get("/", response_class=HTMLResponse)
+ async def home(request: Request):
+     return templates.TemplateResponse("index.html", {"request": request})
+
+ @app.post("/generate/")
+ async def generate_chat(request: Request, prompt: str = Form(...), history: str = Form(...), temperature: float = Form(0.9), max_new_tokens: int = Form(512), top_p: float = Form(0.95), repetition_penalty: float = Form(1.0)):
+     history = eval(history)  # Convert the submitted history string back to a list of (user, bot) turns
+     response = generate(prompt, history, temperature, max_new_tokens, top_p, repetition_penalty)
+     return {"response": response}
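A minimal client sketch for the new /generate/ endpoint; the base URL and port (e.g. from running `uvicorn main:app`) and the example conversation are assumptions, not part of the commit:

# Hypothetical client call; the base URL/port and the example turns are assumptions.
import requests

history = [("Hello", "Hi! How can I help you?")]  # prior (user, bot) turns

resp = requests.post(
    "http://127.0.0.1:8000/generate/",
    data={
        "prompt": "Summarize our conversation so far.",
        "history": repr(history),      # the handler parses this string back into a list
        "temperature": 0.9,
        "max_new_tokens": 256,
        "top_p": 0.95,
        "repetition_penalty": 1.0,
    },
)
print(resp.json()["response"])

Because the handler rebuilds history from the form field as a Python literal, the sketch sends the repr of a list of (user, bot) tuples.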