# Pixtral 12B demo Space (app.py), running on ZeroGPU.
import gradio as gr
from gradio.data_classes import FileData
from huggingface_hub import snapshot_download
from pathlib import Path
import base64
import json
import os
import requests
import spaces
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage, TextChunk, ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
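
# Download the Pixtral weights and tokenizer from the Hub; snapshot_download
# reuses any files already present in local_dir instead of re-fetching them.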
models_path = Path.home().joinpath('pixtral', 'Pixtral')
models_path.mkdir(parents=True, exist_ok=True)
snapshot_download(repo_id="mistral-community/pixtral-12b-240910",
                  allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
                  local_dir=models_path)
tokenizer = MistralTokenizer.from_file(f"{models_path}/tekken.json")
model = Transformer.from_folder(models_path)

def image_to_base64(image_path):
    """Encode a local image file as a base64 JPEG data URL."""
    with open(image_path, 'rb') as img:
        encoded_string = base64.b64encode(img.read()).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded_string}"

def url_to_base64(image_url):
    """Fetch an image from a URL and encode it as a base64 data URL."""
    response = requests.get(image_url)
    if response.status_code == 200:
        base64_image = base64.b64encode(response.content).decode('utf-8')
        return f"data:image/jpeg;base64,{base64_image}"
    # On failure, return an empty data URL rather than raising.
    return "data:image/jpeg;base64,"
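
# The try branch below accepts a full transcript pasted into the text box as
# JSON. A minimal sketch of the expected shape (the URL is illustrative):
#   [{"role": "user",
#     "content": [{"type": "image", "url": "https://example.com/cat.jpg"},
#                 {"type": "text", "text": "Describe this image."}]},
#    {"role": "assistant",
#     "content": [{"type": "text", "text": "A cat sitting on a mat."}]}]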
@spaces.GPU(duration=90)  # request a ZeroGPU slice for up to 90 s per call
def run_inference(message, history):
    try:
        # First, try to interpret the text box as a JSON transcript
        # (see the example format above).
        messages = message['text']
        print("messages", messages)
        messages = json.loads(messages)
        final_msg = []
        for x in messages:
            if x['role'] == 'user':
                chunks = []
                for y in x['content']:
                    if y['type'] == 'image':
                        chunks.append(ImageURLChunk(image_url=url_to_base64(y['url'])))
                    else:
                        chunks.append(TextChunk(text=y['text']))
                final_msg.append(UserMessage(content=chunks))
            else:
                final_msg.append(AssistantMessage(content=x['content'][0]['text']))
        print('final messages', final_msg)
        completion_request = ChatCompletionRequest(messages=final_msg)
        encoded = tokenizer.encode_chat_completion(completion_request)
        images = encoded.images
        tokens = encoded.tokens
        out_tokens, _ = generate([tokens], model, images=[images], max_tokens=512,
                                 temperature=0.45,
                                 eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
        result = tokenizer.decode(out_tokens[0])
        return result
    except Exception as e:
        # Fall back to the standard Gradio multimodal chat format:
        # history holds (user, assistant) pairs, where image uploads
        # arrive as tuples of file paths.
        print('falling back to default parsing:', e)
        messages = []
        images = []
        print('\n\nmessage ', message)
        print('\n\nhistory ', history)
        for couple in history:
            if type(couple[0]) is tuple:
                # An image-only turn: collect the uploaded file paths.
                images += couple[0]
            elif couple[0]:
                # A text turn: attach any pending images to this user message.
                messages.append(UserMessage(content=[ImageURLChunk(image_url=image_to_base64(path))
                                                     for path in images]
                                            + [TextChunk(text=couple[0])]))
                messages.append(AssistantMessage(content=couple[1]))
                images = []
        # Append the current turn, with any files attached to this message.
        messages.append(UserMessage(content=[ImageURLChunk(image_url=image_to_base64(file["path"]))
                                             for file in message["files"]]
                                    + [TextChunk(text=message["text"])]))
        print('\n\nfinal messages', messages)
        completion_request = ChatCompletionRequest(messages=messages)
        encoded = tokenizer.encode_chat_completion(completion_request)
        images = encoded.images
        tokens = encoded.tokens
        out_tokens, _ = generate([tokens], model, images=[images], max_tokens=512,
                                 temperature=0.45,
                                 eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
        result = tokenizer.decode(out_tokens[0])
        return result
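
# Wire the handler into a multimodal ChatInterface; multimodal=True lets
# users attach images alongside their text.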
demo = gr.ChatInterface(fn=run_inference, title="Pixtral 12B", multimodal=True, description="A demo chat interface with Pixtral 12B, deployed using Mistral Inference.")
demo.queue().launch()