Mahavaury2 commited on
Commit
308a509
·
verified ·
1 Parent(s): 6f01f8a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+
5
+ # MODEL REPO
6
+ MODEL_NAME = "mistralai/Mistral-7B-v0.1"
7
+
8
+ # Load tokenizer
9
+ print("Loading tokenizer...")
10
+ tokenizer = AutoTokenizer.from_pretrained(
11
+ MODEL_NAME,
12
+ trust_remote_code=True
13
+ )
14
+
15
+ # Load model in 4-bit on CPU
16
+ # (Even though we set device_map="auto", on a free Space there's no GPU, so it stays on CPU.)
17
+ print("Loading model in 4-bit...")
18
+ model = AutoModelForCausalLM.from_pretrained(
19
+ MODEL_NAME,
20
+ torch_dtype=torch.float16,
21
+ device_map="auto", # auto-detect available devices
22
+ load_in_4bit=True, # bitsandbytes for quantization
23
+ trust_remote_code=True # Mistral uses custom code
24
+ )
25
+
26
+ model.eval()
27
+
28
+ def chat_mistral(prompt):
29
+ """
30
+ Generates a response from Mistral 7B given a user prompt.
31
+ """
32
+ # Tokenize
33
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
34
+
35
+ # Generate
36
+ with torch.no_grad():
37
+ outputs = model.generate(
38
+ **inputs,
39
+ max_new_tokens=128, # limit output length to avoid OOM
40
+ temperature=0.7,
41
+ repetition_penalty=1.1
42
+ )
43
+
44
+ # Decode
45
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
46
+ return response
47
+
48
+ # Create a Gradio interface
49
+ demo = gr.Interface(
50
+ fn=chat_mistral,
51
+ inputs=gr.Textbox(lines=3, label="Your Prompt"),
52
+ outputs=gr.Textbox(label="Mistral 7B Response"),
53
+ title="Mistral 7B (4-bit) Chat",
54
+ description=(
55
+ "A minimal Mistral-7B demo running on free CPU. "
56
+ "Inference will be slow and might run out of memory. "
57
+ "Use short prompts!"
58
+ )
59
+ )
60
+
61
+ # Launch the Gradio app
62
+ if __name__ == "__main__":
63
+ demo.launch()