# Hugging Face Spaces: Running on Zero
import base64
import logging
import mimetypes
import os
from pathlib import Path

import gradio as gr
from gradio.data_classes import FileData
from huggingface_hub import snapshot_download
# Keep `spaces` ahead of the torch-backed mistral_* imports (as in the original
# ordering); ZeroGPU expects it to be imported before CUDA is touched.
import spaces
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage, TextChunk, ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
# Local cache directory for the Pixtral weights: ~/pixtral/Pixtral.
models_path = Path.home() / "pixtral" / "Pixtral"
models_path.mkdir(parents=True, exist_ok=True)

# Fetch only the files matched by allow_patterns from the Hub repo.
snapshot_download(
    repo_id="mistral-community/pixtral-12b-240910",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=models_path,
)

# Tokenizer comes from tekken.json; the model is loaded from the same folder.
tokenizer = MistralTokenizer.from_file(f"{models_path}/tekken.json")
model = Transformer.from_folder(models_path)
def image_to_base64(image_path):
    """Read an image file and return it as a base64 ``data:`` URL.

    The MIME type in the URL is guessed from the file extension. When the
    extension is unknown or not an image type, it falls back to
    ``image/jpeg`` (the previous hard-coded label).

    Args:
        image_path: Path (str or path-like) of the image file to read.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    mime_type, _ = mimetypes.guess_type(str(image_path))
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"
    with open(image_path, "rb") as img:
        encoded_string = base64.b64encode(img.read()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded_string}"
import requests | |
import base64 | |
def url_to_base64(image_url, timeout=10):
    """Download an image and return its body as a bare base64 string.

    Unlike ``image_to_base64`` the result has no ``data:`` URL prefix.

    Args:
        image_url: URL fetched with ``requests.get``.
        timeout: Seconds ``requests`` waits for the server (connect and
            read) before raising. Without a timeout, a stalled server
            would block the caller indefinitely.

    Returns:
        The base64-encoded response body, or ``''`` when the server answers
        with any status other than 200.

    Raises:
        requests.RequestException: on network failure or timeout.
    """
    response = requests.get(image_url, timeout=timeout)
    if response.status_code != 200:
        return ''
    # Encode the raw image bytes to Base64 text.
    return base64.b64encode(response.content).decode('utf-8')
_logger = logging.getLogger(__name__)


def _file_path(file):
    """Return the filesystem path of a Gradio file entry.

    Accepts either a plain path (as stored in tuple-format history rows) or
    a mapping with a ``"path"`` key (as read from ``message["files"]``).
    """
    if isinstance(file, dict):
        return file["path"]
    return file


def _user_message(image_paths, text):
    """Build a ``UserMessage`` holding the given images followed by ``text``."""
    chunks = [ImageURLChunk(image_url=image_to_base64(path)) for path in image_paths]
    chunks.append(TextChunk(text=text))
    return UserMessage(content=chunks)


def _history_to_messages(history):
    """Convert tuple-format Gradio history into user/assistant message pairs.

    Rows whose user part is a tuple are file uploads. Their paths are
    buffered and attached to the next row that carries text and/or a reply.
    """
    messages = []
    pending_images = []
    for row in history:
        user_part, assistant_part = row[0], row[1]
        if isinstance(user_part, tuple):
            pending_images.extend(_file_path(f) for f in user_part)
            continue
        # The user part is the prompt string itself (use all of it, not one
        # indexed character); ``None`` marks a turn that only sent images.
        messages.append(_user_message(pending_images, user_part or ""))
        messages.append(AssistantMessage(content=assistant_part))
        pending_images = []
    return messages


# On ZeroGPU hardware a GPU is only attached inside functions decorated with
# spaces.GPU; outside ZeroGPU the decorator has no effect.
@spaces.GPU
def run_inference(message, history):
    """Gradio multimodal chat callback: answer ``message`` given ``history``.

    Args:
        message: Current turn from ``gr.ChatInterface(multimodal=True)``,
            read as a mapping with ``"text"`` and ``"files"`` keys. Each file
            is a path or a dict with a ``"path"`` key.
        history: Tuple-format history ``[[user, assistant], ...]``. An
            uploaded file appears as a row whose user part is a tuple of
            paths.

    Returns:
        The decoded model reply (at most 512 new tokens, temperature 0.45).
    """
    # Log only the raw inputs. The built messages embed base64 images and
    # would flood the logs.
    _logger.debug("message=%r history=%r", message, history)

    messages = _history_to_messages(history)
    current_images = [_file_path(f) for f in message["files"]]
    messages.append(_user_message(current_images, message["text"]))

    request = ChatCompletionRequest(messages=messages)
    encoded = tokenizer.encode_chat_completion(request)
    out_tokens, _ = generate(
        [encoded.tokens],
        model,
        images=[encoded.images],
        max_tokens=512,
        temperature=0.45,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.decode(out_tokens[0])
# Multimodal chat UI: each turn may carry uploaded images plus text.
demo = gr.ChatInterface(
    fn=run_inference,
    multimodal=True,
    title="Pixtral 12B",
    description="A demo chat interface with Pixtral 12B, deployed using Mistral Inference.",
)
# queue() returns the same object, so this matches demo.queue().launch().
demo.queue()
demo.launch()