Ajay12345678980 commited on
Commit
15d78e9
·
verified ·
1 Parent(s): 0280f0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -31
app.py CHANGED
@@ -1,35 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
- import torch
4
- import os
5
-
6
- # Load token from environment variable
7
- token = os.getenv('ACCESS_SECRET')
8
-
9
- # Specify the repository ID
10
- model_repo_id = "Ajay12345678980/QA_Chatbot"
11
-
12
- # Load model and tokenizer
13
- model = GPT2LMHeadModel.from_pretrained(model_repo_id, use_auth_token=token)
14
- tokenizer = GPT2Tokenizer.from_pretrained(model_repo_id, use_auth_token=token)
15
-
16
- # Define prediction function
17
- def predict(text):
18
- inputs = tokenizer.encode(text, return_tensors="pt")
19
- with torch.no_grad():
20
- outputs = model.generate(inputs, max_length=50, do_sample=True)
21
- prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
22
- return prediction
23
-
24
- # Set up Gradio interface
25
- interface = gr.Interface(
26
- fn=predict,
27
- inputs="text",
28
- outputs="text",
29
- title="GPT-2 Text Generation",
30
- description="Enter some text and see what the model generates!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  )
32
 
33
- # Launch the Gradio app
34
  if __name__ == "__main__":
35
- interface.launch()
 
1
+ # import gradio as gr
2
+ # from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
+ # import torch
4
+ # import os
5
+
6
+ # # Load token from environment variable
7
+ # token = os.getenv('ACCESS_SECRET')
8
+
9
+ # # Specify the repository ID
10
+ # model_repo_id = "Ajay12345678980/QA_Chatbot"
11
+
12
+ # # Load model and tokenizer
13
+ # model = GPT2LMHeadModel.from_pretrained(model_repo_id, use_auth_token=token)
14
+ # tokenizer = GPT2Tokenizer.from_pretrained(model_repo_id, use_auth_token=token)
15
+
16
+ # # Define prediction function
17
+ # def predict(text):
18
+ # inputs = tokenizer.encode(text, return_tensors="pt")
19
+ # with torch.no_grad():
20
+ # outputs = model.generate(inputs, max_length=50, do_sample=True)
21
+ # prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
22
+ # return prediction
23
+
24
+ # # Set up Gradio interface
25
+ # interface = gr.Interface(
26
+ # fn=predict,
27
+ # inputs="text",
28
+ # outputs="text",
29
+ # title="GPT-2 Text Generation",
30
+ # description="Enter some text and see what the model generates!"
31
+ # )
32
+
33
+ # # Launch the Gradio app
34
+ # if __name__ == "__main__":
35
+ # interface.launch()
36
  import gradio as gr
37
+ from huggingface_hub import InferenceClient
38
+
39
+ """
40
+ For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
41
+ """
42
+ client = InferenceClient("Ajay12345678980/QA_Chatbot")
43
+
44
+
45
+ def respond(
46
+ message,
47
+ history: list[tuple[str, str]],
48
+ system_message,
49
+ max_tokens,
50
+ temperature,
51
+ top_p,
52
+ ):
53
+ messages = [{"role": "system", "content": system_message}]
54
+
55
+ for val in history:
56
+ if val[0]:
57
+ messages.append({"role": "user", "content": val[0]})
58
+ if val[1]:
59
+ messages.append({"role": "assistant", "content": val[1]})
60
+
61
+ messages.append({"role": "user", "content": message})
62
+
63
+ response = ""
64
+
65
+ for message in client.chat_completion(
66
+ messages,
67
+ max_tokens=max_tokens,
68
+ stream=True,
69
+ temperature=temperature,
70
+ top_p=top_p,
71
+ ):
72
+ token = message.choices[0].delta.content
73
+
74
+ response += token
75
+ yield response
76
+
77
+ """
78
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
79
+ """
80
+ demo = gr.ChatInterface(
81
+ respond,
82
+ additional_inputs=[
83
+ gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
84
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
85
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
86
+ gr.Slider(
87
+ minimum=0.1,
88
+ maximum=1.0,
89
+ value=0.95,
90
+ step=0.05,
91
+ label="Top-p (nucleus sampling)",
92
+ ),
93
+ ],
94
  )
95
 
96
+
97
  if __name__ == "__main__":
98
+ demo.launch()