# app.py
from flask import Flask, request, jsonify, render_template_string
from llama_cpp import Llama
from pydantic import BaseModel, ValidationError
from typing import List
import logging
app = Flask(__name__)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize the Llama model
llm = Llama.from_pretrained(
    repo_id="bartowski/Marco-o1-GGUF",
    filename="Marco-o1-Q3_K_M.gguf",
)
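# Note: from_pretrained fetches the GGUF file from the Hugging Face Hub on
# first run and caches it locally, so the initial startup can take a while.
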
# Pydantic Models
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: List[Message]

class ChatResponse(BaseModel):
    response: str
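
# For reference, a request body that ChatRequest accepts looks like:
#   {"messages": [{"role": "user", "content": "Hello"}]}
# and ChatResponse serializes to: {"response": "..."}
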
# Route to serve the chat interface
@app.route('/')
def index():
    # Minimal placeholder page; the full chat UI markup is omitted here.
    html_content = """
    Llama Chat Interface
    """
    return render_template_string(html_content)

# Chat API Endpoint
@app.route('/chat', methods=['POST'])
def chat():
    try:
        # Parse and validate the JSON request using Pydantic.
        # silent=True makes get_json() return None instead of raising on a
        # malformed or non-JSON body, so the 400 branch below is reachable.
        json_data = request.get_json(silent=True)
        if not json_data:
            logger.warning("Invalid JSON payload received.")
            return jsonify({'error': 'Invalid JSON payload'}), 400
        chat_request = ChatRequest(**json_data)
        logger.info(f"Received messages: {chat_request.messages}")
        # Convert the Pydantic models to the plain dicts expected by llama-cpp.
        # model_dump() is the Pydantic v2 API; on v1 use .dict() instead.
        messages = [message.model_dump() for message in chat_request.messages]
        # Generate the chat completion
        completion = llm.create_chat_completion(messages=messages)
        logger.info(f"Generated completion: {completion}")
        # create_chat_completion returns an OpenAI-style dict, so pull the
        # reply text out of choices[0]["message"]["content"]; passing the
        # whole dict to ChatResponse would fail validation (response: str).
        response_text = completion["choices"][0]["message"]["content"]
        # Create the response using Pydantic
        chat_response = ChatResponse(response=response_text)
        return jsonify(chat_response.model_dump())
    except ValidationError as ve:
        # Handle validation errors from Pydantic. Join the full location
        # tuple so nested fields report correctly and an empty loc does
        # not raise an IndexError.
        logger.error(f"Pydantic validation error: {ve.errors()}")
        errors = [
            {"field": ".".join(str(part) for part in error["loc"]), "message": error["msg"]}
            for error in ve.errors()
        ]
        return jsonify({'error': errors}), 422
    except Exception as e:
        # Handle unexpected errors
        logger.error(f"Unexpected error: {str(e)}")
        return jsonify({'error': str(e)}), 500
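
# A validation failure in the handler above returns HTTP 422 with a body like:
#   {"error": [{"field": "messages", "message": "Field required"}]}
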
if __name__ == '__main__':
    # debug=True is convenient for development but should be disabled in production.
    app.run(host='0.0.0.0', port=7860, debug=True)
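
# Example request once the server is running (port 7860, as configured above):
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hello"}]}'
# Expected response shape: {"response": "..."}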