Spaces:
Runtime error
Runtime error
import io
import random
import re
import shutil

import gradio as gr
import requests
import torch
from clip_interrogator import Config, Interrogator
from PIL import Image
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline, set_seed
# Image-to-prompt setup: configure and build the CLIP Interrogator instance
# used by get_prompt_from_image.
config = Config()
if torch.cuda.is_available():
    # GPU present: run on CUDA and do not offload BLIP.
    config.device = 'cuda'
    config.blip_offload = False
else:
    # CPU only: enable BLIP offloading.
    config.device = 'cpu'
    config.blip_offload = True
config.clip_model_name = "ViT-H-14/laion2b_s32b_b79k"
config.blip_num_beams = 64
config.chunk_size = 2048
config.flavor_intermediate_count = 512
ci = Interrogator(config)
def get_prompt_from_image(image, mode):
    """Generate a text prompt for ``image`` using the module-level ``ci``.

    Args:
        image: Image object exposing ``convert``; it is converted to RGB
            before being handed to the interrogator.
        mode: One of ``'best'``, ``'classic'``, ``'fast'`` or ``'negative'``,
            selecting which interrogation method of ``ci`` is called.

    Returns:
        Whatever the selected ``ci`` interrogation method returns.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if mode == 'best':
        interrogate = ci.interrogate
    elif mode == 'classic':
        interrogate = ci.interrogate_classic
    elif mode == 'fast':
        interrogate = ci.interrogate_fast
    elif mode == 'negative':
        interrogate = ci.interrogate_negative
    else:
        # Previously an unknown mode fell through every branch and crashed
        # with UnboundLocalError on the return; fail with a clear message.
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'best', 'classic', 'fast' "
            "or 'negative'"
        )
    return interrogate(image.convert('RGB'))
# Text-to-prompt setup: tokenizer and seq2seq model used by translate().
# NOTE(review): the checkpoint name suggests Chinese-to-English translation.
tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
model.eval()  # inference only
def translate(text):
    """Run ``text`` through the module-level seq2seq model and return the decoded string.

    Args:
        text: A single input string.

    Returns:
        The first (and only) decoded output sequence, with special tokens
        stripped.
    """
    batch = tokenizer([text], return_tensors='pt')
    # Gradients are not needed for generation.
    with torch.no_grad():
        generated = model.generate(**batch)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]
# Text-generation pipeline sampled by text_generate() to produce prompt
# candidates.
text_pipe = pipeline(
    task='text-generation',
    model='succinctly/text2image-prompt-generator',
)
def text_generate(input):
    """Expand a short text into several generated prompt candidates.

    The input is first passed through ``translate``; the result is fed to
    ``text_pipe``, which is sampled up to six times until at least one usable
    candidate survives filtering.

    Args:
        input: Source text. The name shadows the builtin but is kept so that
            existing keyword callers continue to work.

    Returns:
        Newline-separated candidate prompts, or ``''`` when none of the six
        attempts produced a usable candidate.
    """
    # Pick a fresh random seed for this call before sampling.
    seed = random.randint(100, 1000000)
    set_seed(seed)
    text_in_english = translate(input)
    # A candidate must add more than four characters to the translated text.
    min_length = len(text_in_english) + 4

    for _ in range(6):
        sequences = text_pipe(
            text_in_english,
            max_length=random.randint(60, 90),
            num_return_sequences=8,
        )
        candidates = []
        for sequence in sequences:
            line = sequence['generated_text'].strip()
            # Drop echoes of the input, barely extended outputs and lines
            # that stop on dangling punctuation. The last entry is an em
            # dash; it had been corrupted to a Greek beta by an encoding
            # round-trip.
            if (
                line != text_in_english
                and len(line) > min_length
                and not line.endswith((':', '-', '—'))
            ):
                candidates.append(line)

        result = "\n".join(candidates)
        # Remove dotted tokens (word.word, e.g. URLs or file names). Newlines
        # are excluded from the character class so a line ending in "." can
        # no longer swallow the first word of the following candidate.
        result = re.sub(r'[^ \n]+\.[^ \n]+', '', result)
        result = result.replace('<', '').replace('>', '')
        if result:
            return result

    # Every attempt came back empty.
    return ''
# Let the user load an image from a URL.
def load_image_from_url(url):
    """Download the image at ``url`` and return it as a PIL image.

    The payload is decoded from memory instead of a shared ``./image.jpg``
    file, so concurrent requests cannot overwrite each other's download.

    Args:
        url: Address of the image to fetch.

    Returns:
        A ``PIL.Image.Image`` whose pixel data has already been loaded.

    Raises:
        ValueError: If the server does not answer with HTTP 200.
        requests.RequestException: On connection failure or timeout.
    """
    # NOTE(review): ``url`` is user-supplied; fetching arbitrary URLs from the
    # server can expose internal services (SSRF). Consider restricting the
    # allowed schemes/hosts.
    # The timeout keeps a stalled server from blocking the worker forever, and
    # the context manager releases the connection on every path.
    with requests.get(url, timeout=30) as response:
        if response.status_code != 200:
            raise ValueError("No se pudo cargar la imagen")
        payload = response.content

    image = Image.open(io.BytesIO(payload))
    # Image.open is lazy; decode now so a corrupt download fails here rather
    # than later in the caller.
    image.load()
    return image
# Crear la interfaz de usuario de Gradio | |
with gr.Interface( | |
[get_prompt_from_image, text_generate], | |
[ | |
gr.inputs.Image(type='pil', label='Imagen'), | |
gr.inputs.Radio(['best', 'fast', 'classic', 'negative'], value='best', label='Modo'), | |
gr.inputs.Textbox(lines=6, label='Texto de entrada'), | |
], | |
[ |