TXT2IMG-MJ-Desc / app.py
jmourad's picture
Create your Gradio app.py
c8c7ed1
raw
history blame
3.05 kB
import io
import random
import re
import shutil

import gradio as gr
import requests
import torch
from clip_interrogator import Config, Interrogator
from PIL import Image
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline, set_seed
# Definir la función para generar prompt desde imagen
# Configure the CLIP Interrogator once at import time; the resulting `ci`
# instance is shared by every call to get_prompt_from_image.
_use_cuda = torch.cuda.is_available()
config = Config()
config.device = 'cuda' if _use_cuda else 'cpu'
# Offload BLIP only when no GPU is available.
config.blip_offload = not _use_cuda
# NOTE(review): these tuning knobs are interpreted by clip_interrogator's
# Config; blip_num_beams=64 presumably requests a very wide beam search for
# BLIP captioning (slow) -- confirm against the installed library version.
config.chunk_size = 2048
config.flavor_intermediate_count = 512
config.blip_num_beams = 64
config.clip_model_name = "ViT-H-14/laion2b_s32b_b79k"
ci = Interrogator(config)
def get_prompt_from_image(image, mode):
    """Describe *image* as a text prompt using the module-level CLIP Interrogator.

    Args:
        image: a PIL image; it is converted to RGB before interrogation.
        mode: one of 'best', 'classic', 'fast' or 'negative'; selects which
            ``Interrogator`` method of the module-level ``ci`` is called.

    Returns:
        The prompt string produced by the selected interrogation method.

    Raises:
        ValueError: if *mode* is not one of the supported modes, or if no
            image was provided (``image`` is None).
    """
    method_names = {
        'best': 'interrogate',
        'classic': 'interrogate_classic',
        'fast': 'interrogate_fast',
        'negative': 'interrogate_negative',
    }
    # Without this check an unknown mode would leave nothing to return and
    # surface as an UnboundLocalError.
    if mode not in method_names:
        raise ValueError(
            f"Unknown mode {mode!r}; expected one of: {', '.join(method_names)}"
        )
    # UIs typically pass None when no image was uploaded.
    if image is None:
        raise ValueError("No image was provided")
    interrogate = getattr(ci, method_names[mode])
    return interrogate(image.convert('RGB'))
# Definir la función para generar prompt desde texto
# Translation model used by translate(). By its checkpoint name ("zh-en") it
# translates Chinese to English. It is switched to eval mode because, in the
# code visible here, it is only used for generation under torch.no_grad().
_TRANSLATION_CHECKPOINT = 'Helsinki-NLP/opus-mt-zh-en'
model = AutoModelForSeq2SeqLM.from_pretrained(_TRANSLATION_CHECKPOINT).eval()
tokenizer = AutoTokenizer.from_pretrained(_TRANSLATION_CHECKPOINT)
@torch.no_grad()
def translate(text):
    """Return the translation of *text* produced by the module-level model.

    Uses the module-level ``tokenizer``/``model`` pair (checkpoint
    ``Helsinki-NLP/opus-mt-zh-en``, i.e. Chinese to English by its name).
    The decorator disables gradient tracking for the whole call.
    """
    batch = tokenizer([text], return_tensors='pt')
    generated_ids = model.generate(**batch)
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded[0]
# Text-generation pipeline used by text_generate() to expand a seed phrase
# into longer text-to-image prompts.
text_pipe = pipeline(
    task='text-generation',
    model='succinctly/text2image-prompt-generator',
)
# Dotted tokens (presumably URLs / domain names such as "artstation.com") are
# stripped from generated prompts. \S is used instead of [^ ] so a match can
# never cross a newline: previously a candidate ending in "." swallowed the
# newline and first word of the next candidate, merging the two lines.
_DOTTED_TOKEN_RE = re.compile(r'\S+\.\S+')


def text_generate(input):
    """Generate text-to-image prompt suggestions that extend *input*.

    *input* is passed through ``translate`` and then fed to the module-level
    ``text_pipe`` text-generation pipeline. Up to six rounds of eight samples
    are drawn; the first round whose cleaned output is non-empty is returned.

    Args:
        input: the user's seed text. (The name shadows the builtin but is kept
            so keyword callers keep working.)

    Returns:
        Newline-separated prompt suggestions, or ``''`` if every round was
        filtered out entirely.
    """
    set_seed(random.randint(100, 1000000))
    text_in_english = translate(input)
    result = ''
    for _ in range(6):
        sequences = text_pipe(
            text_in_english,
            max_length=random.randint(60, 90),
            num_return_sequences=8,
        )
        candidates = []
        for sequence in sequences:
            line = sequence['generated_text'].strip()
            # Keep only samples that actually extend the seed text and do not
            # end on a dangling separator.
            if (
                line != text_in_english
                and len(line) > len(text_in_english) + 4
                and not line.endswith((':', '-', '—'))
            ):
                candidates.append(line)
        result = '\n'.join(candidates)
        result = _DOTTED_TOKEN_RE.sub('', result)
        result = result.replace('<', '').replace('>', '')
        if result:
            return result
    return result
# Definir la función que permite al usuario cargar una imagen desde una URL
def load_image_from_url(url, timeout=30):
    """Download an image from *url* and open it with PIL.

    The downloaded bytes are kept in memory. Writing them to a shared
    './image.jpg' let concurrent requests overwrite each other's file before
    PIL had lazily read it, and left the file behind.

    Args:
        url: HTTP(S) address of the image.
        timeout: seconds to wait for the server, forwarded to
            ``requests.get``; without a timeout a stalled server blocks forever.

    Returns:
        A ``PIL.Image.Image`` opened from the downloaded bytes.

    Raises:
        ValueError: if the server does not answer with HTTP 200.
        requests.RequestException: on connection errors or timeouts.
    """
    # The with-block guarantees the connection is released even on error.
    with requests.get(url, timeout=timeout) as response:
        if response.status_code != 200:
            raise ValueError("No se pudo cargar la imagen")
        # .content transparently decodes gzip/deflate transfer encodings.
        data = response.content
    return Image.open(io.BytesIO(data))
# Crear la interfaz de usuario de Gradio
with gr.Interface(
[get_prompt_from_image, text_generate],
[
gr.inputs.Image(type='pil', label='Imagen'),
gr.inputs.Radio(['best', 'fast', 'classic', 'negative'], value='best', label='Modo'),
gr.inputs.Textbox(lines=6, label='Texto de entrada'),
],
[