zaeemzafar commited on
Commit
039f5a8
·
verified ·
1 Parent(s): 3dd7cda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -7
app.py CHANGED
@@ -1,17 +1,31 @@
1
- from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer, AutoModel
2
  import torch
 
3
 
4
  # Load the tokenizer and model
5
  repo_name = "nvidia/Hymba-1.5B-Base"
6
 
 
7
  tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
8
  model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)
9
  model = model.cuda().to(torch.bfloat16)
10
 
11
- # Chat with Hymba
12
- prompt = input()
13
- inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
14
- outputs = model.generate(**inputs, max_length=64, do_sample=True, temperature=0.7, use_cache=True)
15
- response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
16
 
17
- print(f"Model response: {response}")
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
  import torch
3
+ import gradio as gr
4
 
5
  # Load the tokenizer and model
6
  repo_name = "nvidia/Hymba-1.5B-Base"
7
 
8
+ # Load the tokenizer and model
9
  tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
10
  model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)
11
  model = model.cuda().to(torch.bfloat16)
12
 
13
+ # Define the chatbot function
14
+ def chat_with_hymba(prompt):
15
+ inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
16
+ outputs = model.generate(**inputs, max_length=64, do_sample=True, temperature=0.7, use_cache=True)
17
+ response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
18
+ return response
19
+
20
+ # Create Gradio Interface
21
+ interface = gr.Interface(
22
+ fn=chat_with_hymba,
23
+ inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
24
+ outputs="text",
25
+ title="Chat with Hymba",
26
+ description="Interact with the Hymba-1.5B model in real-time!"
27
+ )
28
 
29
+ # Launch the interface
30
+ if __name__ == "__main__":
31
+ interface.launch()