elapt1c committed
Commit dfb5da6
1 Parent(s): 9160705
Update app.py
app.py
CHANGED
@@ -1,15 +1,19 @@
 from typing import List, Tuple, Dict, Generator
-from transformers import
+from transformers import GPT2LMHeadModel, GPT2TokenizerFast
 import torch
 import gradio as gr
 
-# Load
-
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+# Load the GPT-2 tokenizer
+tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
 
+# Load the saved GPT-2 model from the local checkpoint
+model_path = "DuckyPolice/ElapticAI-1a"  # Adjust to your specific model path if needed
+model = GPT2LMHeadModel.from_pretrained(model_path)
+
+# Move model to appropriate device (GPU if available, otherwise CPU)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
+model.eval()
 
 def create_history_messages(history: List[Tuple[str, str]]) -> List[dict]:
     history_messages = [{"role": "user", "content": m[0]} for m in history]
@@ -34,7 +38,7 @@ def create_formatted_history(history_messages: List[dict]) -> List[Tuple[str, st
             user_messages = []
             assistant_messages = []
 
-    #
+    # Append any remaining messages
    if user_messages:
        formatted_history.append(("".join(user_messages), None))
    elif assistant_messages:
@@ -49,26 +53,39 @@ def chat(message: str, state: List[Dict[str, str]]) -> Generator[Tuple[List[Tupl
     history_messages.append({"role": "system", "content": "A helpful assistant."})
 
     history_messages.append({"role": "user", "content": message})
-    # We have no content for the assistant's response yet but we will update this:
     history_messages.append({"role": "assistant", "content": ""})
 
-    #
-
+    # Tokenize user input and prepare input tensor
+    input_ids = tokenizer.encode(message, return_tensors='pt').to(device)
+
+    if input_ids.size(-1) == 0:
+        response_message = "Input was empty after tokenization. Please try again."
+    else:
+        # Generate tokens one by one
+        with torch.no_grad():
+            for _ in range(50):  # Limit generation to 50 tokens
+                outputs = model(input_ids)
+                next_token_logits = outputs.logits[:, -1, :]
+                next_token_id = torch.argmax(next_token_logits, dim=-1)
+                input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
+
+                # Decode and append the latest token
+                decoded_token = tokenizer.decode(next_token_id)
+                history_messages[-1]["content"] += decoded_token
 
-
-
-
+                # Stop if the model generates the end-of-sequence token
+                if next_token_id.item() == tokenizer.eos_token_id:
+                    break
 
-
-    history_messages[-1]["content"] = response_message
+        response_message = history_messages[-1]["content"]
 
     formatted_history = create_formatted_history(history_messages)
     yield formatted_history, history_messages
 
-chatbot = gr.Chatbot(label="Chat")
+chatbot = gr.Chatbot(label="Chat")
 iface = gr.Interface(
     fn=chat,
-    inputs=[gr.Textbox(placeholder="Hello! How are you?
+    inputs=[gr.Textbox(placeholder="Hello! How are you?", label="Message"), "state"],
     outputs=[chatbot, "state"],
     allow_flagging="never",
 )
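The decoding loop added in this commit performs greedy generation by hand: at each of up to 50 steps it takes the argmax of the logits at the last position, appends that token to input_ids, and stops early on the end-of-sequence token. For comparison, a rough sketch of the same greedy, 50-token generation written against the standard transformers generate API follows; it is not part of the commit, it assumes the model, tokenizer, and device objects defined in app.py, and generate_reply is a hypothetical helper name.

import torch

def generate_reply(message: str, max_new_tokens: int = 50) -> str:
    # Encode the prompt and move it to the same device as the model
    input_ids = tokenizer.encode(message, return_tensors="pt").to(device)
    if input_ids.size(-1) == 0:
        return "Input was empty after tokenization. Please try again."
    with torch.no_grad():
        # do_sample=False selects greedy decoding, matching the argmax loop in the diff
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated continuation
    return tokenizer.decode(output_ids[0, input_ids.size(-1):], skip_special_tokens=True)

One advantage of the manual loop in the diff is that the partial reply is written into history_messages token by token, so per-token streaming in the UI would only require moving the yield inside the for loop.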
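On the Gradio side, the Interface passes the Textbox value and the session state into chat, and the two values yielded by chat map back onto the Chatbot component and the state output, which is how history_messages is carried from one turn to the next. The diff ends at the gr.Interface(...) call, so the launch step is outside the changed region; in a typical Gradio Space, app.py would finish with something like the following (hypothetical, not shown in this commit):

if __name__ == "__main__":
    # Start the Gradio server for the Space
    iface.launch()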