# GptVision / app.py — Hugging Face Space by Rooni.
# Provenance: commit c92287b ("Create app.py"), 2.89 kB, marked "No virus" by the Hub file scan.
import importlib
import sys
from io import BytesIO
from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, TextStreamer

# The LLaVA/Phi model, config and processor classes used by this app are not
# part of `transformers`; they are plain .py files stored in the model
# repository. They must be downloaded BEFORE the `from processing_llava ...` /
# `from modeling_llava ...` imports at the bottom of this block, otherwise a
# fresh checkout fails with ModuleNotFoundError.
_MODEL_REPO = "OEvortex/HelpingAI-Vision"
_REMOTE_CODE_FILES = (
    "configuration_llava.py",
    "configuration_phi.py",
    "modeling_llava.py",
    "modeling_phi.py",
    "processing_llava.py",
)

# Download next to this script instead of into the current working directory,
# and make sure that directory is on sys.path, so the imports below work no
# matter which directory the app is started from.
_APP_DIR = Path(__file__).resolve().parent

for _filename in _REMOTE_CODE_FILES:
    # force_download=True re-downloads each file on every start, even if a
    # local copy already exists.
    hf_hub_download(
        repo_id=_MODEL_REPO,
        filename=_filename,
        local_dir=str(_APP_DIR),
        force_download=True,
    )

if str(_APP_DIR) not in sys.path:
    sys.path.insert(0, str(_APP_DIR))
# These modules were created after the interpreter started; clear the import
# system's cached directory listings so the new files are found.
importlib.invalidate_caches()

from processing_llava import LlavaProcessor, OpenCLIPImageProcessor  # noqa: E402
from modeling_llava import LlavaForConditionalGeneration  # noqa: E402
# Load the vision-language model. LlavaForConditionalGeneration comes from the
# downloaded modeling_llava.py, so this must run after that download/import.
# float16 halves GPU memory, but many float16 ops are unsupported or very slow
# on CPU, so use float32 when no CUDA device is available instead of crashing
# on a hard-coded .to("cuda").
_device = "cuda" if torch.cuda.is_available() else "cpu"
_dtype = torch.float16 if _device == "cuda" else torch.float32
model = LlavaForConditionalGeneration.from_pretrained(
    "OEvortex/HelpingAI-Vision", torch_dtype=_dtype
)
model = model.to(_device)

# Build the text tokenizer and the image preprocessor (configured from the
# preprocessing settings stored in the model config), then combine them into
# the processor that generate_text() calls.
tokenizer = AutoTokenizer.from_pretrained("OEvortex/HelpingAI-Vision")
image_processor = OpenCLIPImageProcessor(model.config.preprocess_config)
processor = LlavaProcessor(image_processor, tokenizer)
# Text generation for the Gradio callback.
def _build_prompt(user_text):
    """Wrap the user's request in the chat prompt the model expects.

    NOTE(review): the ChatML layout and the ``<image>`` placeholder (marking
    where the image features are merged in) follow the usage example on the
    OEvortex/HelpingAI-Vision model card -- confirm against the current card.
    Text that already contains ``<image>`` is treated as a complete prompt and
    passed through unchanged. An empty request falls back to a generic
    description request.
    """
    text = (user_text or "").strip()
    if "<image>" in text:
        return text
    if not text:
        text = "Describe the image."
    return (
        "<|im_start|>system\n"
        "A chat.<|im_end|>\n"
        "<|im_start|>user\n"
        "<image>\n"
        f"{text}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


def generate_text(image, initial_text):
    """Answer a text request about an uploaded image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user has not uploaded anything.
        initial_text: The user's request from the textbox.

    Returns:
        The decoded answer only -- the echoed prompt is removed.

    Raises:
        gr.Error: If no image was uploaded; Gradio shows the message in the UI.
    """
    if image is None:
        raise gr.Error("Пожалуйста, загрузите изображение.")

    prompt = _build_prompt(initial_text)
    with torch.inference_mode():
        inputs = processor(prompt, image, model, return_tensors='pt')
        inputs['input_ids'] = inputs['input_ids'].to(model.device)
        inputs['attention_mask'] = inputs['attention_mask'].to(model.device)
        # Echo the answer to the server console while it is being generated.
        streamer = TextStreamer(tokenizer, skip_prompt=True)
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            top_p=0.9,
            temperature=1.2,
            eos_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        )

    # For decoder-only generation, generate() returns the prompt ids followed
    # by the new ids; decoding output[0] as-is would show the whole prompt in
    # the UI. Drop the prompt only when it really is a prefix, so this stays
    # correct if the model ever returns new tokens alone.
    sequence = output[0]
    prompt_ids = inputs['input_ids'][0]
    prompt_len = prompt_ids.shape[0]
    if sequence.shape[0] >= prompt_len and torch.equal(sequence[:prompt_len], prompt_ids):
        sequence = sequence[prompt_len:]
    answer = tokenizer.decode(sequence, skip_special_tokens=True)
    # In case <|im_end|> is not registered as a special token.
    return answer.replace("<|im_end|>", "").strip()
# Gradio UI: two side-by-side columns -- image and request on one side,
# generated answer and the trigger button on the other.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Загрузите изображение", type="pil")
            text_input = gr.Textbox(label="Введите текст запроса")
        with gr.Column():
            output_text = gr.Textbox(label="Сгенерированный текст")
            generate_button = gr.Button(value="Генерировать текст")
    # A click calls generate_text(image, request) and writes the returned
    # string into the output textbox.
    generate_button.click(
        fn=generate_text,
        inputs=[image_input, text_input],
        outputs=output_text,
    )

# Start the web app.
demo.launch()