VenkateshRoshan committed on
Commit 6823dec · 1 Parent(s): 6dab482

Add app and Dockerfile for HF

Files changed (2)
  1. app_hf.py +200 -0
  2. dockerfile_hf +15 -0
app_hf.py ADDED
@@ -0,0 +1,200 @@
+ import psutil
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import gradio as gr
+ import os
+ import tarfile
+ from typing import List, Tuple
+ import boto3
+ import logging
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class CustomerSupportBot:
+     def __init__(self, model_path="models/customer_support_gpt"):
+         """
+         Initialize the customer support bot with the fine-tuned model.
+
+         Args:
+             model_path (str): Path to the saved model and tokenizer
+         """
+         self.process = psutil.Process(os.getpid())
+         self.model_path = model_path
+         self.model_file_path = os.path.join(self.model_path, "model.tar.gz")
+         self.s3 = boto3.client("s3")
+         self.model_key = "models/model.tar.gz"
+         self.bucket_name = "customer-support-gpt"
+
+         # Download and load the model
+         self.download_and_load_model()
+
+     def download_and_load_model(self):
+         # Create the model directory if it does not exist
+         if not os.path.exists(self.model_path):
+             os.makedirs(self.model_path)
+
+         # Download model.tar.gz from S3 if not already downloaded
+         if not os.path.exists(self.model_file_path):
+             print("Downloading model from S3...")
+             self.s3.download_file(self.bucket_name, self.model_key, self.model_file_path)
+             print("Download complete. Extracting model files...")
+
+             # Extract the model files
+             with tarfile.open(self.model_file_path, "r:gz") as tar:
+                 tar.extractall(self.model_path)
+
+         # Load the model and tokenizer from the extracted files
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+         self.model = AutoModelForCausalLM.from_pretrained(self.model_path)
+
+         # GPT-style tokenizers often lack a pad token; fall back to EOS so
+         # generate() receives a valid pad_token_id.
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         print("Model and tokenizer loaded successfully.")
+
+         # Run on CPU; swap in the commented expression to use a GPU when available
+         self.device = "cpu"  # "cuda" if torch.cuda.is_available() else "cpu"
+         self.model = self.model.to(self.device)
+
+         print(f'Model loaded on device: {self.device}')
+
+     def generate_response(self, message: str, max_length=100, temperature=0.7) -> str:
+         try:
+             input_text = f"Instruction: {message}\nResponse:"
+
+             # Tokenize input text
+             inputs = self.tokenizer(input_text, return_tensors="pt").to(self.device)
+
+             # Generate a response with sampling
+             with torch.no_grad():
+                 outputs = self.model.generate(
+                     **inputs,
+                     max_length=max_length,
+                     temperature=temperature,
+                     num_return_sequences=1,
+                     pad_token_id=self.tokenizer.pad_token_id,
+                     eos_token_id=self.tokenizer.eos_token_id,
+                     do_sample=True,
+                     top_p=0.95,
+                     top_k=50
+                 )
+
+             # Decode and keep only the text after the "Response:" marker
+             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+             response = response.split("Response:")[-1].strip()
+             return response
+         except Exception as e:
+             return f"An error occurred: {str(e)}"
+
+     def monitor_resources(self) -> dict:
+         usage = {
+             "CPU (%)": self.process.cpu_percent(interval=1),
+             "RAM (GB)": self.process.memory_info().rss / (1024 ** 3)
+         }
+         return usage
+
+
+ def create_chat_interface():
+     bot = CustomerSupportBot(model_path="/app/models")
+
+     def predict(message: str, history: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[str, str]]]:
+         if not message:
+             return "", history
+
+         bot_response = bot.generate_response(message)
+
+         # Log resource usage
+         usage = bot.monitor_resources()
+         print("Resource Usage:", usage)
+
+         history.append((message, bot_response))
+         return "", history
+
+     # Create the Gradio interface with custom CSS
+     with gr.Blocks(css="""
+         .message-box {
+             margin-bottom: 10px;
+         }
+         .button-row {
+             display: flex;
+             gap: 10px;
+             margin-top: 10px;
+         }
+     """) as interface:
+         gr.Markdown("# Customer Support Chatbot")
+         gr.Markdown("Welcome! How can I assist you today?")
+
+         chatbot = gr.Chatbot(
+             label="Chat History",
+             height=500,
+             elem_classes="message-box",
+             # type="messages"
+         )
+
+         with gr.Row():
+             msg = gr.Textbox(
+                 label="Your Message",
+                 placeholder="Type your message here...",
+                 lines=2,
+                 elem_classes="message-box"
+             )
+
+         with gr.Row(elem_classes="button-row"):
+             submit = gr.Button("Send Message", variant="primary")
+             clear = gr.ClearButton([msg, chatbot], value="Clear Chat")
+
+         # Add example queries in a separate row
+         with gr.Row():
+             gr.Examples(
+                 examples=[
+                     "How do I reset my password?",
+                     "What are your shipping policies?",
+                     "I want to return a product.",
+                     "How can I track my order?",
+                     "What payment methods do you accept?"
+                 ],
+                 inputs=msg,
+                 label="Example Questions"
+             )
+
+         # Set up event handlers
+         submit_click = submit.click(
+             predict,
+             inputs=[msg, chatbot],
+             outputs=[msg, chatbot]
+         )
+
+         msg.submit(
+             predict,
+             inputs=[msg, chatbot],
+             outputs=[msg, chatbot]
+         )
+
+         # Enable the send button only while the textbox has content
+         msg.change(lambda x: gr.update(interactive=bool(x.strip())), inputs=[msg], outputs=[submit])
+
+         print("Interface created successfully.")
+
+         # Warm-up query to verify the model responds at startup
+         print(predict("How are you", []))
+
+         # Log the bot's per-process resource usage
+         print(f'Bot Resource Usage: {bot.monitor_resources()}')
+
+         # Log system-wide resource usage
+         print(f'CPU Percentage: {psutil.cpu_percent()}')
+         print(f'RAM Usage: {psutil.virtual_memory()}')
+         print(f'Swap Memory: {psutil.swap_memory()}')
+
+     return interface
+
+ if __name__ == "__main__":
+     demo = create_chat_interface()
+     print("Starting Gradio server...")
+     demo.launch(
+         share=True,
+         server_name="0.0.0.0",
+         server_port=7860,  # Default Gradio port; matches EXPOSE in the Dockerfile
+         debug=True,
+         inline=False
+     )
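
A quick smoke test outside the Gradio UI can confirm that the S3 download and generation paths work. This is a minimal sketch, not part of the commit: it assumes AWS credentials with read access to the customer-support-gpt bucket are available in the environment, and smoke_test.py is a hypothetical file name.

    # smoke_test.py -- minimal sketch; assumes AWS credentials and network
    # access to the "customer-support-gpt" S3 bucket (hypothetical usage).
    from app_hf import CustomerSupportBot

    bot = CustomerSupportBot(model_path="models/customer_support_gpt")

    # One round trip through the fine-tuned model.
    print(bot.generate_response("How do I reset my password?"))

    # Per-process CPU and RAM, as reported by monitor_resources().
    print(bot.monitor_resources())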
dockerfile_hf ADDED
@@ -0,0 +1,15 @@
+ FROM python:3.10-slim
+
+ WORKDIR /app
+
+ COPY app_hf.py /app/app_hf.py
+ COPY src/ /app/src/
+
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir --upgrade pip
+ RUN pip install --no-cache-dir torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ EXPOSE 7860
+
+ CMD ["python", "app_hf.py"]
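
To try the image locally before pushing, something like the commands below should work. They are a sketch, not part of the commit: "customer-support-bot" is an arbitrary tag, -f is needed because the file is not named Dockerfile, and the -e flags forward host AWS credentials so the app can fetch the model from S3. Since the app currently pins the device to "cpu", installing torch from the CPU wheel index (https://download.pytorch.org/whl/cpu) instead of cu121 would also shrink the image considerably.

    # Build from the non-default Dockerfile name and run on the Gradio port.
    docker build -f dockerfile_hf -t customer-support-bot .
    docker run -p 7860:7860 \
      -e AWS_ACCESS_KEY_ID -e AWS_SECRET_ACCESS_KEY \
      customer-support-bot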