# app.py — Hugging Face Space by J-LAB
# Last change: "Update app.py" (commit 96c662b, verified), 4.96 kB
import spaces
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
from huggingface_hub import InferenceClient
import io
from PIL import Image
import torch
import numpy as np
import subprocess
import os
# flash-attn is not installed at startup; the former runtime install was:
#   subprocess.run('pip install flash-attn --no-build-isolation',
#                  env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

# Run the vision model on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model_id = 'J-LAB/Florence-vl3'

# trust_remote_code=True lets transformers execute the model repository's own
# modeling/processing code for this checkpoint.
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model = model.to(device)
# eval() disables training-only behaviour (e.g. dropout) for inference.
model = model.eval()

processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
@spaces.GPU
def run_example(task_prompt, image):
    """Run the Florence model on *image* for a single task token.

    Args:
        task_prompt: Florence task token, e.g. '<OCR>' or
            '<MORE_DETAILED_CAPTION>'; it is used both as the text prompt and
            as the ``task`` for post-processing.
        image: image object accepted by ``processor``; its ``.width`` and
            ``.height`` are read (e.g. a PIL image).

    Returns:
        Whatever ``processor.post_process_generation`` returns for this task.
    """
    batch = processor(text=task_prompt, images=image, return_tensors="pt", padding=True)
    batch = batch.to(device)

    # Deterministic beam search (no sampling), up to 1024 new tokens.
    output_ids = model.generate(
        input_ids=batch["input_ids"],
        pixel_values=batch["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Special tokens are kept (skip_special_tokens=False), presumably because
    # post_process_generation parses them out of the raw text.
    decoded = processor.batch_decode(output_ids, skip_special_tokens=False)
    raw_text = decoded[0]
    size = (image.width, image.height)
    return processor.post_process_generation(raw_text, task=task_prompt, image_size=size)
def process_image(image, task_prompt):
    """Run one Florence task on *image* and return only its result value.

    Args:
        image: a file path (str), a NumPy array, or an image object that
            run_example can use directly.
        task_prompt: 'Product Caption' or 'OCR', which are mapped to Florence
            task tokens; any other value is passed through unchanged.

    Returns:
        The entry stored under the task token in run_example's result, or
        "" when the result is empty or does not contain that key.
    """
    # UI-facing task names -> Florence task tokens.
    task_tokens = {
        'Product Caption': '<MORE_DETAILED_CAPTION>',
        'OCR': '<OCR>',
    }

    if isinstance(image, str):
        # Treat strings as a path to an image file.
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    task = task_tokens.get(task_prompt, task_prompt)
    results = run_example(task, image)
    # The result is keyed by the task token; unwrap it.
    return results[task] if results and task in results else ""
# Hugging Face Inference API client. The API key comes from the YOUR_HF_TOKEN
# environment variable; None is passed when that variable is unset.
client = InferenceClient(api_key=os.environ.get('YOUR_HF_TOKEN'))
# Chatbot response handler and its helpers.


def _describe_image(image):
    """Return a caption + OCR description of *image* for the LLM prompt.

    Returns "" when *image* is None. Errors from image processing are not
    raised; they are returned as an error message that ends up in the prompt.
    """
    if image is None:
        return ""
    try:
        caption = process_image(image, 'Product Caption')
        ocr_text = process_image(image, 'OCR')
        return caption + " " + ocr_text
    except Exception as e:
        return f"An error occurred with image processing: {str(e)}"


def _build_messages(system_message, history, user_content):
    """Build the chat-completion message list sent to the LLM.

    The system prompt is extended with a (Portuguese) note telling the model
    that image descriptions are wrapped in <image> </image> tags. Empty user
    or assistant entries in *history* are skipped.
    """
    messages = [{"role": "system", "content": f'{system_message} a descrição das imagens enviadas pelo usuário ficam dentro da tag <image> </image>'}]
    for user, assistant in history:
        if user:
            messages.append({"role": "user", "content": user})
        if assistant:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_content})
    return messages


def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, image):
    """Stream an LLM reply to *message*, optionally grounded on an image.

    This is a generator. Every value it yields is a 3-tuple matching the
    Gradio outputs ``[chatbot, image_input, chat_input]``:
    (updated chat history, image update, textbox update).

    Args:
        message: the user's text.
        history: prior (user, assistant) pairs from the Chatbot; either side
            may be empty/None.
        system_message: base system prompt.
        max_tokens, temperature, top_p: sampling settings forwarded to the
            chat-completion call.
        image: optional image (file path, NumPy array or image object) that
            is described via the vision model and prepended to the prompt.
    """
    # Copy so the caller's list is never mutated, and accept None.
    history = list(history or [])

    # Nothing to send: leave every output unchanged.
    if not message and image is None:
        yield history, gr.update(), gr.update()
        return

    image_result = _describe_image(image)

    # The image description is sent to the LLM but is not shown in the chat.
    full_message = message
    if image_result:
        full_message = f"\n<image>{image_result}</image>\n\n{message}"

    messages = _build_messages(system_message, history, full_message)

    response = ""
    try:
        stream = client.chat.completions.create(
            model="meta-llama/Llama-3.1-8B-Instruct",
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stream=True,
        )
        for chunk in stream:
            # Defensive: a streamed chunk may carry an empty `choices` list.
            if not chunk.choices:
                continue
            token = chunk.choices[0].delta.content
            if token is not None:
                response += token
                # Yield a fresh list each time so Gradio sees a new value; keep
                # the image and textbox as they are while streaming.
                yield history + [(message, response)], gr.update(), gr.update()
    except Exception as e:
        response = f"An error occurred: {str(e)}"

    # Final state: show the full turn and clear the image and message inputs.
    yield history + [(message, response)], gr.update(value=None), gr.update(value="")
# Gradio interface.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    chat_input = gr.Textbox(placeholder="Enter message...", show_label=False)
    image_input = gr.Image(type="filepath", label="Upload an image")
    submit_btn = gr.Button("Send Message")
    system_message = gr.Textbox(value="Você é um chatbot útil que sempre responde em português", label="System message")
    max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
    temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    # Order must match respond(message, history, system_message, max_tokens,
    # temperature, top_p, image).
    respond_inputs = [chat_input, chatbot, system_message, max_tokens, temperature, top_p, image_input]
    # Components updated by respond: chat history, image upload, message box.
    respond_outputs = [chatbot, image_input, chat_input]

    submit_btn.click(respond, inputs=respond_inputs, outputs=respond_outputs)
    # Pressing Enter in the message box sends the message too.
    chat_input.submit(respond, inputs=respond_inputs, outputs=respond_outputs)

if __name__ == "__main__":
    demo.launch(debug=True, quiet=True)