Spestly committed · verified
Commit e1807a7 · 1 Parent(s): 1c501c8

Update app.py

Files changed (1)
  1. app.py +137 -31
app.py CHANGED
@@ -4,55 +4,161 @@ from huggingface_hub import login
  import torch
  import os

  HF_TOKEN = os.getenv("HF_TOKEN")
  login(token=HF_TOKEN)

- model_name = "Spestly/Atlas-Pro-1.5B-Preview"
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, low_cpu_mem_usage=True)
-
- model.eval()
-
- def generate_response(message, history):
-     instruction = (
-         "You are an LLM called Atlas. You are finetuned by Aayan Mishra. You are NOT trained by Anthropic. "
-         "You are a Qwen 2.5 fine-tune. Your purpose is to help the user accomplish their request to the best of your abilities. "
-         "Below is an instruction that describes a task. Answer it clearly and concisely.\n\n"
-         f"### Instruction:\n{message}\n\n### Response:"
      )
      inputs = tokenizer(instruction, return_tensors="pt")

      with torch.no_grad():
          outputs = model.generate(
              **inputs,
-             max_new_tokens=1000,
              num_return_sequences=1,
-             temperature=0.7,
-             top_p=0.9,
              do_sample=True
          )

      response = tokenizer.decode(outputs[0], skip_special_tokens=True)
      response = response.split("### Response:")[-1].strip()
-
      return response

- iface = gr.ChatInterface(
-     generate_response,
-     chatbot=gr.Chatbot(height=600, type="messages"),
-     textbox=gr.Textbox(placeholder="Type your message here...", container=False, scale=7),
-     title="🦁 Atlas-Pro",
-     description="Chat with Atlas-Pro",
-     theme="soft",
-     examples=[
-         "Can you give me a good salsa recipe?",
-         "Write an engaging two-line horror story.",
-         "What is the capital of Australia?",
-     ],
-     type="messages"
- )
-
- iface.launch()
  import torch
  import os

+ # Hugging Face token login
  HF_TOKEN = os.getenv("HF_TOKEN")
  login(token=HF_TOKEN)

+ # Define models
+ MODELS = {
+     "atlas-flash-1215": {
+         "name": "🦁 Atlas-Flash 1215",
+         "sizes": {
+             "1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
+         },
+         "emoji": "🦁",
+         "experimental": True,
+         "is_vision": False,
+         "system_prompt_env": "ATLAS_FLASH_1215",
+     },
+     "atlas-pro-0403": {
+         "name": "🏆 Atlas-Pro 0403",
+         "sizes": {
+             "1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
+         },
+         "emoji": "🏆",
+         "experimental": True,
+         "is_vision": False,
+         "system_prompt_env": "ATLAS_PRO_0403",
+     },
+ }

+ # Load default model
+ default_model_key = "atlas-pro-0403"
+ default_size = "1.5B"
+ default_model = MODELS[default_model_key]["sizes"][default_size]

+ def load_model(model_name):
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.float32,
+         low_cpu_mem_usage=True
      )
+     model.eval()
+     return tokenizer, model
+
+ tokenizer, model = load_model(default_model)
+ current_model = default_model  # track which checkpoint is currently in memory
+
+ # Generate response function
+ def generate_response(message, image, history, model_key, model_size, temperature, top_p, max_new_tokens):
+     global tokenizer, model, current_model
+     # Load the selected model only if it is not the one already in memory
+     selected_model = MODELS[model_key]["sizes"][model_size]
+     if selected_model != current_model:
+         tokenizer, model = load_model(selected_model)
+         current_model = selected_model
+
+     # Get the system prompt from the environment
+     system_prompt_env = MODELS[model_key]["system_prompt_env"]
+     system_prompt = os.getenv(system_prompt_env, "You are an advanced AI system. Help the user as best as you can.")
+
+     # Construct instruction
+     if MODELS[model_key]["is_vision"]:
+         # If a vision model, include the image information
+         image_info = "An image has been provided as input."
+         instruction = (
+             f"{system_prompt}\n\n"
+             f"### Instruction:\n{message}\n{image_info}\n\n### Response:"
+         )
+     else:
+         # For non-vision models
+         instruction = (
+             f"{system_prompt}\n\n"
+             f"### Instruction:\n{message}\n\n### Response:"
+         )
+
+     # Tokenize input
      inputs = tokenizer(instruction, return_tensors="pt")

+     # Generate response
      with torch.no_grad():
          outputs = model.generate(
              **inputs,
+             max_new_tokens=max_new_tokens,
              num_return_sequences=1,
+             temperature=temperature,
+             top_p=top_p,
              do_sample=True
          )

+     # Decode response
      response = tokenizer.decode(outputs[0], skip_special_tokens=True)
      response = response.split("### Response:")[-1].strip()
      return response

+ # User interface
+ def create_interface():
+     # Define input components
+     message_input = gr.Textbox(label="Message", placeholder="Type your message here...")
+     model_key_selector = gr.Dropdown(
+         label="Model",
+         choices=list(MODELS.keys()),
+         value=default_model_key
+     )
+     model_size_selector = gr.Dropdown(
+         label="Model Size",
+         choices=list(MODELS[default_model_key]["sizes"].keys()),
+         value=default_size
+     )
+     temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)
+     top_p_slider = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
+     max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=50, maximum=2000, value=1000, step=50)
+     image_input = gr.Image(label="Upload Image (if applicable)", type="filepath", visible=False)
+
+     # Function to toggle visibility of image input
+     def toggle_image_input(model_key):
+         return MODELS[model_key]["is_vision"]
+
+     # Output components
+     chat_output = gr.Chatbot(label="Chatbot")
+
+     # Function to process inputs and generate output
+     # NOTE: the mutable default means a single history list is shared across all calls
+     def process_inputs(message, image, model_key, model_size, temperature, top_p, max_new_tokens, history=[]):
+         response = generate_response(
+             message=message,
+             image=image,
+             history=history,
+             model_key=model_key,
+             model_size=model_size,
+             temperature=temperature,
+             top_p=top_p,
+             max_new_tokens=max_new_tokens
+         )
+         history.append((message, response))
+         return history
+
+     # Interface layout
+     iface = gr.Interface(
+         fn=process_inputs,
+         inputs=[
+             message_input,
+             image_input,
+             model_key_selector,
+             model_size_selector,
+             temperature_slider,
+             top_p_slider,
+             max_tokens_slider
+         ],
+         outputs=chat_output,
+         title="🌟 Atlas-Pro/Flash/Vision Interface",
+         description="Interact with multiple models like Atlas-Pro, Atlas-Flash, and AtlasV-Pro (Coming Soon!). Upload images for vision models!",
+         theme="soft",
+         live=True  # NOTE: live=True reruns generation on every input change
+     )
+
+     # Toggle image input visibility when the selected model changes
+     # (Gradio components have no set_visibility method; register a change event
+     # inside the Interface's Blocks context instead)
+     with iface:
+         model_key_selector.change(
+             lambda key: gr.update(visible=toggle_image_input(key)),
+             inputs=model_key_selector,
+             outputs=image_input,
+         )
+     return iface
+
+ # Launch the app
+ create_interface().launch()
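
Because generate_response() swaps the global tokenizer/model whenever the dropdown selection changes, switching back and forth between Atlas-Pro and Atlas-Flash reloads each checkpoint from disk every time. A per-checkpoint cache on top of load_model() avoids that; a minimal sketch (load_model_cached is a hypothetical helper, not in the commit, and the memory note is an assumption: two 1.5B float32 models need roughly 12 GB of RAM):

    # Cache one (tokenizer, model) pair per checkpoint so repeated switches are cheap.
    _model_cache = {}

    def load_model_cached(model_name):
        if model_name not in _model_cache:
            _model_cache[model_name] = load_model(model_name)
        return _model_cache[model_name]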
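The generation path can also be exercised without launching the UI; a minimal smoke test against the new generate_response() signature, assuming HF_TOKEN is set and the Spestly/Atlas-Pro-1.5B-Preview checkpoint downloads successfully:

    if __name__ == "__main__":
        reply = generate_response(
            message="What is the capital of Australia?",
            image=None,
            history=[],
            model_key="atlas-pro-0403",
            model_size="1.5B",
            temperature=0.7,
            top_p=0.9,
            max_new_tokens=200,
        )
        print(reply)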
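The image-input toggle is simpler if the layout is built with gr.Blocks, where change events are first-class. A sketch of that alternative, reusing MODELS and default_model_key from app.py; this is not the committed implementation, just one possible design:

    import gradio as gr

    with gr.Blocks(theme="soft") as demo:
        model_key = gr.Dropdown(label="Model", choices=list(MODELS.keys()), value=default_model_key)
        image_in = gr.Image(label="Upload Image (if applicable)", type="filepath", visible=False)

        # Show the image upload only when the selected model is a vision model.
        model_key.change(
            lambda k: gr.update(visible=MODELS[k]["is_vision"]),
            inputs=model_key,
            outputs=image_in,
        )

    demo.launch()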