jmourad commited on
Commit
c8c7ed1
•
1 Parent(s): 160e4d7

Create your Gradio app.py

Browse files
Files changed (1)
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, set_seed
+ import torch
+ from clip_interrogator import Config, Interrogator
+ import random
+ import re
+ import requests
+ import shutil
+ from PIL import Image
+
+
+ # Define the function that generates a prompt from an image
+ config = Config()
+ config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ config.blip_offload = False if torch.cuda.is_available() else True
+ config.chunk_size = 2048
+ config.flavor_intermediate_count = 512
+ config.blip_num_beams = 64
+ config.clip_model_name = "ViT-H-14/laion2b_s32b_b79k"
+ ci = Interrogator(config)
+
+ def get_prompt_from_image(image, mode):
+     image = image.convert('RGB')
+     if mode == 'best':
+         prompt = ci.interrogate(image)
+     elif mode == 'classic':
+         prompt = ci.interrogate_classic(image)
+     elif mode == 'fast':
+         prompt = ci.interrogate_fast(image)
+     elif mode == 'negative':
+         prompt = ci.interrogate_negative(image)
+     return prompt
+
+
+ # Define the function that generates a prompt from text
+ model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
+ tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
+
+ def translate(text):
+     with torch.no_grad():
+         encoded = tokenizer([text], return_tensors='pt')
+         sequences = model.generate(**encoded)
+         return tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+
+ text_pipe = pipeline('text-generation', model='succinctly/text2image-prompt-generator')
+
+ def text_generate(input):
+     seed = random.randint(100, 1000000)
+     set_seed(seed)
+     text_in_english = translate(input)
+     for count in range(6):
+         sequences = text_pipe(text_in_english, max_length=random.randint(60, 90), num_return_sequences=8)
+         lines = []
+         for sequence in sequences:
+             line = sequence['generated_text'].strip()
+             if line != text_in_english and len(line) > (len(text_in_english) + 4) and line.endswith((':', '-', '—')) is False:
+                 lines.append(line)
+
+         result = "\n".join(lines)
+         result = re.sub(r'[^ ]+\.[^ ]+', '', result)
+         result = result.replace('<', '').replace('>', '')
+         if result != '':
+             return result
+         if count == 5:
+             return result
+
+
+ # Define the function that lets the user load an image from a URL
+ def load_image_from_url(url):
+     response = requests.get(url, stream=True)
+     if response.status_code == 200:
+         with open('./image.jpg', 'wb') as f:
+             response.raw.decode_content = True
+             shutil.copyfileobj(response.raw, f)
+         return Image.open('./image.jpg')
+     else:
+         raise ValueError("Could not load the image")
+
+
+ # Create the Gradio user interface
+ with gr.Interface(
+     [get_prompt_from_image, text_generate],
+     [
+         gr.inputs.Image(type='pil', label='Image'),
+         gr.inputs.Radio(['best', 'fast', 'classic', 'negative'], value='best', label='Mode'),
+         gr.inputs.Textbox(lines=6, label='Input text'),
+     ],
+     [
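The diff shown above cuts off inside the gr.Interface(...) call, so the output components and the launch() call are not visible in this commit view. For orientation only, the two functions could typically be wired up as separate interfaces and launched roughly as follows; this is a hypothetical sketch, not the remainder of this file, and it assumes standard Gradio components (gr.Image, gr.Radio, gr.Textbox, gr.TabbedInterface) rather than whatever the truncated code actually does. Note that gr.Interface takes a single callable, so passing a list of functions as above would not work as written.

# Hypothetical wiring, not part of this commit: one gr.Interface per function,
# combined into a TabbedInterface and launched.
image_tab = gr.Interface(
    fn=get_prompt_from_image,
    inputs=[gr.Image(type='pil', label='Image'),
            gr.Radio(['best', 'fast', 'classic', 'negative'], value='best', label='Mode')],
    outputs=gr.Textbox(label='Prompt'),
)
text_tab = gr.Interface(
    fn=text_generate,
    inputs=gr.Textbox(lines=6, label='Input text'),
    outputs=gr.Textbox(label='Prompt'),
)
gr.TabbedInterface([image_tab, text_tab], ['Image to prompt', 'Text to prompt']).launch()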