elapt1c commited on
Commit
dfb5da6
1 Parent(s): 9160705

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -16
app.py CHANGED
@@ -1,15 +1,19 @@
1
  from typing import List, Tuple, Dict, Generator
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
  import gradio as gr
5
 
6
- # Load your safetensors model and tokenizer
7
- model_name = "DuckyPolice/ElapticAI-1a"
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
10
 
 
 
 
 
 
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  model.to(device)
 
13
 
14
  def create_history_messages(history: List[Tuple[str, str]]) -> List[dict]:
15
  history_messages = [{"role": "user", "content": m[0]} for m in history]
@@ -34,7 +38,7 @@ def create_formatted_history(history_messages: List[dict]) -> List[Tuple[str, st
34
  user_messages = []
35
  assistant_messages = []
36
 
37
- # append any remaining messages
38
  if user_messages:
39
  formatted_history.append(("".join(user_messages), None))
40
  elif assistant_messages:
@@ -49,26 +53,39 @@ def chat(message: str, state: List[Dict[str, str]]) -> Generator[Tuple[List[Tupl
49
  history_messages.append({"role": "system", "content": "A helpful assistant."})
50
 
51
  history_messages.append({"role": "user", "content": message})
52
- # We have no content for the assistant's response yet but we will update this:
53
  history_messages.append({"role": "assistant", "content": ""})
54
 
55
- # Prepare input for the model
56
- inputs = tokenizer(message, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # Generate response from model
59
- response_ids = model.generate(inputs['input_ids'], max_length=200)
60
- response_message = tokenizer.decode(response_ids[0], skip_special_tokens=True)
61
 
62
- # Update the assistant's response in our model
63
- history_messages[-1]["content"] = response_message
64
 
65
  formatted_history = create_formatted_history(history_messages)
66
  yield formatted_history, history_messages
67
 
68
- chatbot = gr.Chatbot(label="Chat").style(color_map=("yellow", "purple"))
69
  iface = gr.Interface(
70
  fn=chat,
71
- inputs=[gr.Textbox(placeholder="Hello! How are you? etc.", label="Message"), "state"],
72
  outputs=[chatbot, "state"],
73
  allow_flagging="never",
74
  )
 
1
  from typing import List, Tuple, Dict, Generator
2
+ from transformers import GPT2LMHeadModel, GPT2TokenizerFast
3
  import torch
4
  import gradio as gr
5
 
6
+ # Load the GPT-2 tokenizer
7
+ tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
 
 
8
 
9
+ # Load the saved GPT-2 model from the local checkpoint
10
+ model_path = "DuckyPolice/ElapticAI-1a" # Adjust to your specific model path if needed
11
+ model = GPT2LMHeadModel.from_pretrained(model_path)
12
+
13
+ # Move model to appropriate device (GPU if available, otherwise CPU)
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  model.to(device)
16
+ model.eval()
17
 
18
  def create_history_messages(history: List[Tuple[str, str]]) -> List[dict]:
19
  history_messages = [{"role": "user", "content": m[0]} for m in history]
 
38
  user_messages = []
39
  assistant_messages = []
40
 
41
+ # Append any remaining messages
42
  if user_messages:
43
  formatted_history.append(("".join(user_messages), None))
44
  elif assistant_messages:
 
53
  history_messages.append({"role": "system", "content": "A helpful assistant."})
54
 
55
  history_messages.append({"role": "user", "content": message})
 
56
  history_messages.append({"role": "assistant", "content": ""})
57
 
58
+ # Tokenize user input and prepare input tensor
59
+ input_ids = tokenizer.encode(message, return_tensors='pt').to(device)
60
+
61
+ if input_ids.size(-1) == 0:
62
+ response_message = "Input was empty after tokenization. Please try again."
63
+ else:
64
+ # Generate tokens one by one
65
+ with torch.no_grad():
66
+ for _ in range(50): # Limit generation to 50 tokens
67
+ outputs = model(input_ids)
68
+ next_token_logits = outputs.logits[:, -1, :]
69
+ next_token_id = torch.argmax(next_token_logits, dim=-1)
70
+ input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
71
+
72
+ # Decode and append the latest token
73
+ decoded_token = tokenizer.decode(next_token_id)
74
+ history_messages[-1]["content"] += decoded_token
75
 
76
+ # Stop if the model generates the end-of-sequence token
77
+ if next_token_id.item() == tokenizer.eos_token_id:
78
+ break
79
 
80
+ response_message = history_messages[-1]["content"]
 
81
 
82
  formatted_history = create_formatted_history(history_messages)
83
  yield formatted_history, history_messages
84
 
85
+ chatbot = gr.Chatbot(label="Chat")
86
  iface = gr.Interface(
87
  fn=chat,
88
+ inputs=[gr.Textbox(placeholder="Hello! How are you?", label="Message"), "state"],
89
  outputs=[chatbot, "state"],
90
  allow_flagging="never",
91
  )