Spaces:
Sleeping
Sleeping
# app.py
import logging
import os
import threading
from typing import List

from flask import Flask, request, jsonify, render_template_string
from llama_cpp import Llama
from pydantic import BaseModel, ValidationError
# Configure root logging first: INFO and above are emitted for this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# The Flask application instance.
app = Flask(__name__)
# Initialize the Llama model.
# Llama.from_pretrained resolves the GGUF file by Hugging Face repo id and
# filename (downloading it unless already cached) and loads it. This runs once,
# when the module is imported.
_MODEL_REPO_ID = "bartowski/Marco-o1-GGUF"
_MODEL_FILENAME = "Marco-o1-Q3_K_M.gguf"
llm = Llama.from_pretrained(repo_id=_MODEL_REPO_ID, filename=_MODEL_FILENAME)
# Pydantic Models
class Message(BaseModel):
    """A single chat turn received from the client and passed on to the model."""

    role: str  # speaker, e.g. 'user' or 'assistant'; any string is accepted
    content: str  # text of the turn
class ChatRequest(BaseModel):
    """JSON body accepted by the chat endpoint."""

    messages: List[Message]  # each element is validated as a Message
class ChatResponse(BaseModel):
    """JSON body returned by the chat endpoint on success."""

    response: str  # the assistant's reply text
# Route to serve the chat interface
@app.route('/')
def index():
    """Serve the single-page chat UI.

    Each message the user submits is POSTed to ``/chat`` as
    ``{"messages": [{"role": "user", "content": ...}]}``. The page renders the
    ``response`` field of a successful reply, or the ``error`` field of a
    failed one. That field may be a string, or a list of ``{field, message}``
    objects for validation errors.
    """
    html_content = """
<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <title>Llama Chat Interface</title>
    <style>
        body { font-family: Arial, sans-serif; background-color: #f4f4f4; padding: 20px; }
        #chat-container { max-width: 600px; margin: auto; background: #fff; padding: 20px; border-radius: 5px; }
        #messages { border: 1px solid #ccc; padding: 10px; height: 300px; overflow-y: scroll; }
        .message { margin-bottom: 15px; white-space: pre-wrap; }
        .user { color: blue; }
        .assistant { color: green; }
        #input-form { display: flex; margin-top: 10px; }
        #input-form input { flex: 1; padding: 10px; border: 1px solid #ccc; border-radius: 3px; }
        #input-form button { padding: 10px; border: none; background: #28a745; color: #fff; cursor: pointer; border-radius: 3px; margin-left: 5px; }
        #input-form button:hover { background: #218838; }
        #input-form button:disabled { background: #9bbfa3; cursor: default; }
    </style>
</head>
<body>
    <div id="chat-container">
        <h2>Llama Chatbot</h2>
        <div id="messages"></div>
        <form id="input-form">
            <input type="text" id="user-input" placeholder="Type your message here..." required />
            <button type="submit">Send</button>
        </form>
    </div>
    <!-- Plain JavaScript: no in-browser transpiler is needed. -->
    <script>
        const chatContainer = document.getElementById('messages');
        const inputForm = document.getElementById('input-form');
        const userInput = document.getElementById('user-input');
        const sendButton = inputForm.querySelector('button');

        // Append one chat bubble. Text goes in via textContent / text nodes,
        // never innerHTML, so neither user input nor model output can inject markup.
        function appendMessage(role, content) {
            const messageDiv = document.createElement('div');
            messageDiv.classList.add('message');
            let label = null;
            if (role === 'user') {
                messageDiv.classList.add('user');
                label = 'You:';
            } else if (role === 'assistant') {
                messageDiv.classList.add('assistant');
                label = 'Assistant:';
            }
            if (label !== null) {
                const strong = document.createElement('strong');
                strong.textContent = label;
                messageDiv.appendChild(strong);
                const text = content == null ? '' : String(content);
                messageDiv.appendChild(document.createTextNode(' ' + text));
            }
            chatContainer.appendChild(messageDiv);
            chatContainer.scrollTop = chatContainer.scrollHeight;
        }

        // Convert the server's error field to display text: validation
        // failures arrive as a list of objects with field and message keys.
        function formatError(error) {
            if (Array.isArray(error)) {
                return error.map(item => item.field + ': ' + item.message).join('; ');
            }
            return error || 'Unknown error';
        }

        // Handle form submission
        inputForm.addEventListener('submit', async (e) => {
            e.preventDefault();
            const message = userInput.value.trim();
            if (message === '') return;
            appendMessage('user', message);
            userInput.value = '';
            // Prevent overlapping requests while the model is generating.
            sendButton.disabled = true;
            try {
                const response = await fetch('/chat', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json'
                    },
                    body: JSON.stringify({
                        messages: [
                            {
                                role: 'user',
                                content: message
                            }
                        ]
                    })
                });
                if (!response.ok) {
                    // Error bodies are not guaranteed to be JSON (e.g. a proxy error page).
                    const errorData = await response.json().catch(() => ({}));
                    appendMessage('assistant', 'Error: ' + formatError(errorData.error));
                    return;
                }
                const data = await response.json();
                appendMessage('assistant', data.response);
            } catch (error) {
                appendMessage('assistant', 'Error: ' + error.message);
            } finally {
                sendButton.disabled = false;
            }
        });
    </script>
</body>
</html>
"""
    return render_template_string(html_content)
# Chat API Endpoint
# `llm` wraps a single llama.cpp context, and Flask's development server
# handles requests on several threads by default, so generation calls are
# serialised through this lock.
_llm_lock = threading.Lock()


@app.route('/chat', methods=['POST'])
def chat():
    """Generate an assistant reply for a list of chat messages.

    Expects a JSON object of the form
    ``{"messages": [{"role": "...", "content": "..."}, ...]}``.

    Returns:
        On success, ``{"response": "<assistant text>"}`` with status 200.
        On failure, a JSON body with an ``error`` key and one of:

        * 400 if the body is not a non-empty JSON object, or contains no
          messages (``error`` is a string);
        * 422 if Pydantic validation fails (``error`` is a list of
          ``{"field", "message"}`` objects);
        * 500 if generation fails (``error`` is a generic string; details
          are only logged).
    """
    # silent=True makes a malformed or non-JSON body yield None instead of
    # raising BadRequest. A non-object payload (e.g. a JSON list) would
    # otherwise fail in ChatRequest(**json_data) with a TypeError.
    json_data = request.get_json(silent=True)
    if not isinstance(json_data, dict) or not json_data:
        logger.warning("Invalid JSON payload received.")
        return jsonify({'error': 'Invalid JSON payload'}), 400

    try:
        chat_request = ChatRequest(**json_data)
    except ValidationError as ve:
        logger.error("Pydantic validation error: %s", ve.errors())
        errors = [
            {
                # Join the whole location so nested problems are reported
                # precisely, e.g. "messages.0.role" rather than "messages".
                "field": ".".join(str(part) for part in error['loc']),
                "message": error['msg'],
            }
            for error in ve.errors()
        ]
        return jsonify({'error': errors}), 422

    if not chat_request.messages:
        logger.warning("Chat request contained no messages.")
        return jsonify({'error': 'At least one message is required'}), 400

    logger.info("Received messages: %s", chat_request.messages)
    # Build plain dicts explicitly; .dict()/.model_dump() differ between
    # Pydantic v1 and v2.
    messages = [
        {'role': message.role, 'content': message.content}
        for message in chat_request.messages
    ]

    try:
        with _llm_lock:
            completion = llm.create_chat_completion(messages=messages)
        logger.info("Generated completion: %s", completion)
        # create_chat_completion returns an OpenAI-style response dict, not a
        # string; the reply text is choices[0].message.content.
        reply_text = completion['choices'][0]['message']['content'] or ''
        chat_response = ChatResponse(response=reply_text)
    except Exception:
        # Top-level boundary: log the full traceback, but do not leak
        # internal exception details to the client.
        logger.exception("Unexpected error while generating a chat completion")
        return jsonify({'error': 'Internal server error'}), 500

    return jsonify({'response': chat_response.response})
if __name__ == '__main__':
    # The server listens on every interface, so the interactive Werkzeug
    # debugger (which can execute arbitrary code) must be opt-in only.
    # Set FLASK_DEBUG=1 for local development.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes'}
    # The reloader re-executes this script in a child process, which would
    # load the module-level model a second time; keep it off even when
    # debugging.
    app.run(host='0.0.0.0', port=7860, debug=debug_enabled, use_reloader=False)