"""
High-Performance Chat Interface for LM Studio

This script creates a robust and efficient chat interface using Gradio,
facilitating seamless interactions with the LM Studio API. It leverages
GPU capabilities for accelerated processing and adheres to best practices
in modern Python programming. Comprehensive logging and error handling
ensure reliability and ease of maintenance.

Author: Your Name
Date: YYYY-MM-DD
"""

import json
import logging
import os

import gradio as gr
import httpx
import numpy as np
import torch

# Configure logging early so every subsequent log call uses the same format.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Base URL of LM Studio's OpenAI-compatible server; override via LMSTUDIO_API_BASE_URL.
BASE_URL = os.getenv("LMSTUDIO_API_BASE_URL", "http://localhost:1234/v1")
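
# Example (host and port are assumptions, shown only for illustration):
# point the client at a remote LM Studio instance before launching the script:
#   export LMSTUDIO_API_BASE_URL="http://192.168.1.50:1234/v1"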

# Use the GPU for tensor work when one is available.
USE_GPU = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_GPU else "cpu")
logger.info(f"Using device: {DEVICE}")

# Token-budget heuristics for the chat model.
MODEL_MAX_TOKENS = 32768        # Total context window assumed for the model.
AVERAGE_CHARS_PER_TOKEN = 4     # Rough characters-per-token estimate.
BUFFER_TOKENS = 2000            # Reserved for system prompts and overhead.
MIN_OUTPUT_TOKENS = 1000        # Floor so responses are never starved.

# Keep only the most recent embeddings/messages to bound memory use.
MAX_EMBEDDINGS = 100

# Generous HTTP timeout (seconds) for long streaming completions.
HTTPX_TIMEOUT = 300


def calculate_max_tokens(message, model_max_tokens=MODEL_MAX_TOKENS,
                         buffer=BUFFER_TOKENS, avg_chars_per_token=AVERAGE_CHARS_PER_TOKEN,
                         min_tokens=MIN_OUTPUT_TOKENS):
    """
    Calculate the maximum number of tokens for the output based on the input message length.

    Args:
        message (str): The input message from the user.
        model_max_tokens (int): The total token capacity of the model.
        buffer (int): Reserved tokens for system prompts and overhead.
        avg_chars_per_token (int): Approximate number of characters per token.
        min_tokens (int): Minimum number of tokens to ensure a meaningful response.

    Returns:
        int: The calculated maximum tokens for the output.
    """
    input_length = len(message)
    input_tokens = input_length / avg_chars_per_token
    max_tokens = model_max_tokens - int(input_tokens) - buffer
    calculated_max = max(max_tokens, min_tokens)
    logger.debug(f"Input length (chars): {input_length}, "
                 f"Estimated input tokens: {input_tokens}, "
                 f"Max tokens for output: {calculated_max}")
    return calculated_max
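
# Worked example (illustrative): a 2,000-character prompt is estimated at
# 2000 / 4 = 500 input tokens, so the output budget becomes
# 32768 - 500 - 2000 = 30268 tokens, comfortably above MIN_OUTPUT_TOKENS.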


async def get_embeddings(text):
    """
    Retrieve embeddings for the given text from the LM Studio API.

    Args:
        text (str): The input text to generate embeddings for.

    Returns:
        list or None: The embedding vector as a list if successful, else None.
    """
    url = f"{BASE_URL}/embeddings"
    payload = {"model": "nomad_embed_text_v1_5_Q8_0", "input": text}
    logger.info(f"Requesting embeddings for input: {text[:100]}...")
    async with httpx.AsyncClient(timeout=HTTPX_TIMEOUT) as client:
        try:
            response = await client.post(
                url,
                json=payload,
                headers={"Content-Type": "application/json"},
            )
            logger.info(f"Embeddings response status code: {response.status_code}")
            response.raise_for_status()
            data = response.json()
            logger.debug(f"Embeddings response data: {data}")
            if "data" in data and len(data["data"]) > 0:
                # Return a plain Python list; the similarity math moves vectors
                # onto DEVICE later, so no GPU round-trip is needed here.
                return np.array(data["data"][0]["embedding"]).tolist()
            logger.error("Invalid response structure for embeddings.")
            return None
        except httpx.RequestError as e:
            logger.error(f"Failed to retrieve embeddings: {e}")
            return None
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error while retrieving embeddings: {e}")
            return None
        except json.JSONDecodeError as e:
            logger.error(f"JSON decode error: {e}")
            return None
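
# For reference, the embeddings endpoint is assumed to return an OpenAI-style
# payload along these lines (shape assumed, values illustrative):
#   {"object": "list",
#    "data": [{"object": "embedding", "index": 0, "embedding": [0.013, -0.087, ...]}],
#    "model": "nomad_embed_text_v1_5_Q8_0"}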


def calculate_similarity(vec1, vec2):
    """
    Calculate the cosine similarity between two vectors using GPU acceleration.

    Args:
        vec1 (list or torch.Tensor): The first embedding vector.
        vec2 (list or torch.Tensor): The second embedding vector.

    Returns:
        float: The cosine similarity score.
    """
    if vec1 is None or vec2 is None:
        logger.warning("One or both vectors for similarity calculation are None.")
        return 0.0
    logger.debug("Calculating similarity between vectors.")
    vec1_tensor = torch.tensor(vec1, device=DEVICE) if not isinstance(vec1, torch.Tensor) else vec1.to(DEVICE)
    vec2_tensor = torch.tensor(vec2, device=DEVICE) if not isinstance(vec2, torch.Tensor) else vec2.to(DEVICE)
    similarity = torch.nn.functional.cosine_similarity(
        vec1_tensor.unsqueeze(0), vec2_tensor.unsqueeze(0)
    ).item()
    logger.debug(f"Calculated similarity: {similarity}")
    return similarity
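
# Quick sanity checks (illustrative): identical vectors score 1.0 and
# orthogonal vectors score 0.0, e.g.
#   calculate_similarity([1.0, 0.0], [1.0, 0.0])  -> 1.0
#   calculate_similarity([1.0, 0.0], [0.0, 1.0])  -> 0.0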


async def chat_with_lmstudio(messages, max_tokens):
    """
    Handle chat completions with the LM Studio API using streaming.

    Args:
        messages (list): A list of message dictionaries following OpenAI's format.
        max_tokens (int): The maximum number of tokens to generate in the response.

    Yields:
        str: Chunks of the generated response.
    """
    url = f"{BASE_URL}/chat/completions"
    payload = {
        "model": "Qwen2.5-Coder-32B-Instruct",
        "messages": messages,
        "temperature": 0.7,
        "max_tokens": max_tokens,
        "stream": True,
    }
    logger.info(f"Sending request to chat/completions with max_tokens: {max_tokens}")
    async with httpx.AsyncClient(timeout=HTTPX_TIMEOUT) as client:
        try:
            async with client.stream(
                "POST", url, json=payload, headers={"Content-Type": "application/json"}
            ) as response:
                logger.info(f"chat/completions response status code: {response.status_code}")
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if not line:
                        continue
                    decoded_line = line.strip()
                    if not decoded_line.startswith("data: "):
                        continue
                    body = decoded_line[6:]
                    if body == "[DONE]":
                        # End-of-stream sentinel; it is not JSON, so stop here.
                        break
                    try:
                        data = json.loads(body)
                        logger.debug(f"Received chunk: {data}")
                        content = data.get("choices", [{}])[0].get("delta", {}).get("content", "")
                        if content:
                            yield content
                    except json.JSONDecodeError as e:
                        logger.error(f"JSON decode error: {e}")
        except httpx.RequestError as e:
            logger.error(f"LM Studio chat/completions request failed: {e}")
            yield "An error occurred while generating a response."
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error during chat/completions: {e}")
            yield "An HTTP error occurred while generating a response."
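
# The stream is assumed to follow the OpenAI server-sent-events convention,
# which LM Studio's compatible endpoint mirrors: each chunk arrives as a line
# such as
#   data: {"choices": [{"index": 0, "delta": {"content": "Hel"}}]}
# and the stream terminates with the sentinel line
#   data: [DONE]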


def gradio_chat_interface():
    """
    Create and launch the Gradio Blocks interface for the chat application.
    """
    with gr.Blocks() as interface:
        gr.Markdown("# 🚀 High-Performance Chat Interface for LM Studio")

        chatbot = gr.Chatbot(label="Conversation", type="messages")

        user_input = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here...",
            lines=2,
            interactive=True,
        )

        file_input = gr.File(
            label="Upload Context File (.txt)",
            type="binary",
            interactive=True,
        )

        context_display = gr.Textbox(
            label="Relevant Context",
            interactive=False,
        )

        embeddings_state = gr.State({"embeddings": [], "messages_history": []})

        async def chat_handler(message, file, state):
            """
            Handle user input, process embeddings, retrieve context, and generate responses.

            Args:
                message (str): The user's input message.
                file (bytes or None): Raw contents of the uploaded context file
                    (gr.File with type="binary" passes bytes).
                state (dict): State holding embedded user messages and the chat history.

            Yields:
                list: Updated chatbot messages, new state, and context display text.
            """
            embeddings = state.get("embeddings", [])
            messages_history = state.get("messages_history", [])

            # Fold an uploaded text file into the user message, if provided.
            if file:
                try:
                    file_content = file.decode("utf-8") if isinstance(file, bytes) else file.read().decode("utf-8")
                    message += f"\n[File Content]:\n{file_content}"
                    logger.info("Successfully processed uploaded file.")
                except Exception as e:
                    error_msg = f"Error reading file: {e}"
                    logger.error(error_msg)
                    yield [messages_history + [{"role": "assistant", "content": error_msg}], state, ""]
                    return

            user_embedding = await get_embeddings(message)
            if user_embedding is not None:
                # Keep each embedding next to its source text so context retrieval
                # stays aligned even after assistant turns enter messages_history.
                embeddings.append({"embedding": user_embedding, "content": message})
                messages_history.append({"role": "user", "content": message})
                logger.info("Embeddings generated and appended to state.")
            else:
                error_msg = "Failed to generate embeddings."
                logger.error(error_msg)
                yield [messages_history + [{"role": "assistant", "content": error_msg}], state, ""]
                return

            # Bound memory by keeping only the most recent entries.
            if len(embeddings) > MAX_EMBEDDINGS:
                embeddings = embeddings[-MAX_EMBEDDINGS:]
                messages_history = messages_history[-MAX_EMBEDDINGS:]

            # Build the prompt: the most similar earlier messages become system context.
            history = [{"role": "user", "content": message}]
            context_text = ""
            if len(embeddings) > 1:
                similarities = [
                    (calculate_similarity(user_embedding, entry["embedding"]), idx)
                    for idx, entry in enumerate(embeddings[:-1])
                ]
                similarities.sort(reverse=True, key=lambda x: x[0])
                for similarity, idx in similarities[:3]:
                    context_content = embeddings[idx]["content"]
                    history.insert(0, {"role": "system", "content": context_content})
                    context_text += f"Context: {context_content[:100]}...\n"
                logger.info("Relevant context retrieved based on similarity.")

            max_tokens = calculate_max_tokens(message)
            logger.info(f"Calculated max_tokens for output: {max_tokens}")

            # Stream the completion and accumulate it into a single response.
            response = ""
            try:
                async for chunk in chat_with_lmstudio(history, max_tokens):
                    response += chunk
            except Exception as e:
                error_msg = f"An error occurred while generating a response: {e}"
                logger.error(error_msg)
                yield [messages_history + [{"role": "assistant", "content": error_msg}], state, ""]
                return

            if not response.strip():
                response = "Sorry, I couldn't process your request."

            messages_history.append({"role": "assistant", "content": response})
            new_state = {"embeddings": embeddings, "messages_history": messages_history}

            # The Chatbot component (type="messages") renders the full history.
            updated_chat = list(messages_history)
            logger.debug(f"Updated Chat: {updated_chat}")
            yield [updated_chat, new_state, context_text]
            logger.info("Response generation completed.")

        send_button = gr.Button("Send")
        send_button.click(
            chat_handler,
            inputs=[user_input, file_input, embeddings_state],
            outputs=[chatbot, embeddings_state, context_display],
            show_progress=True,
        )

    interface.launch(share=True, server_name="0.0.0.0", server_port=7860)
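
# Design note: share=True additionally requests a temporary public gradio.live
# URL on top of serving at 0.0.0.0:7860; drop it to keep the interface local-only.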


if __name__ == "__main__":
    # gradio_chat_interface() is synchronous (Gradio drives its own event loop
    # for async handlers), so it is called directly rather than via asyncio.run().
    gradio_chat_interface()