"""Gradio app that serves cached Stable Diffusion images or generates new ones.

Prompts are embedded with the Stable Diffusion text encoder and looked up in a
Pinecone index. If a stored prompt is similar enough (see ``threshold``) the
cached images are shown; otherwise a new image is generated, uploaded to Cloud
Storage and indexed in Pinecone before the gallery is refreshed.
"""
import io
import json
import os
import uuid

import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
from PIL import Image
from cryptography.fernet import Fernet
from google.cloud import storage
import pinecone
import pandas as pd

# decrypt Storage Cloud credentials
fernet = Fernet(os.environ['DECRYPTION_KEY'])

with open('cloud-storage.encrypted', 'rb') as fp:
    encrypted = fp.read()
creds = json.loads(fernet.decrypt(encrypted).decode())

# then save creds to file, because the storage client below locates them via
# the GOOGLE_APPLICATION_CREDENTIALS environment variable
with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
    fp.write(json.dumps(creds, indent=4))

# connect to Cloud Storage
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'
storage_client = storage.Client()
bucket = storage_client.get_bucket('hf-diffusion-images')

# get api key for pinecone auth
PINECONE_KEY = os.environ['PINECONE_KEY']

index_id = "hf-diffusion"

# init connection to pinecone
pinecone.init(
    api_key=PINECONE_KEY,
    environment="us-west1-gcp"
)
if index_id not in pinecone.list_indexes():
    raise ValueError(f"Index '{index_id}' not found")

index = pinecone.Index(index_id)

device = 'cpu'

# init the diffusion pipeline and move it to the configured device
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=os.environ['HF_AUTH']
)
pipe.to(device)

# placeholder shown in the gallery for images that failed validation
missing_im = Image.open('missing.png')
# minimum similarity score for a cached result to count as a match
threshold = 0.85


def encode_text(text: str):
    """Embed *text* with the pipeline's text encoder.

    Returns the pooled encoder output for the single input as a plain Python
    list, ready to be sent to Pinecone.
    """
    # truncate to the tokenizer's maximum length so that very long prompts do
    # not overflow the encoder's position embeddings
    text_inputs = pipe.tokenizer(
        text, truncation=True, return_tensors='pt'
    ).to(device)
    # inference only: no autograd graph is needed for the embedding
    with torch.no_grad():
        text_embeds = pipe.text_encoder(**text_inputs)
    return text_embeds.pooler_output.cpu().tolist()[0]


def _query_index(embeds, top_k: int):
    """Query Pinecone for *embeds*, reconnecting once if the query fails.

    Raises:
        ValueError: if the query still fails after reinitializing the
            connection (the underlying error is chained as the cause).
    """
    try:
        return index.query(embeds, top_k=top_k, include_metadata=True)
    except Exception as e:
        print(f"Error during query: {e}")
        # reinitialize connection
        pinecone.init(api_key=PINECONE_KEY, environment='us-west1-gcp')
        retry_index = pinecone.Index(index_id)
        try:
            return retry_index.query(
                embeds, top_k=top_k, include_metadata=True
            )
        except Exception as retry_err:
            raise ValueError(retry_err) from retry_err


def prompt_query(text: str) -> pd.DataFrame:
    """Return up to five stored prompts similar to *text*.

    The result has the columns ``Similarity`` (score rounded to two decimals)
    and ``Prompt``. Duplicate prompts are dropped (first occurrence wins) and
    prompts of seven characters or fewer are filtered out.

    Raises:
        ValueError: if the Pinecone query fails even after a reconnect.
    """
    embeds = encode_text(text)
    xc = _query_index(embeds, top_k=30)
    matches = xc['matches']
    prompts = [match['metadata']['prompt'] for match in matches]
    scores = [round(match['score'], 2) for match in matches]
    # deduplicate while preserving order
    df = pd.DataFrame({'Similarity': scores, 'Prompt': prompts})
    df = df.drop_duplicates(subset='Prompt', keep='first')
    df = df[df['Prompt'].str.len() > 7].head()
    return df


def diffuse(text: str):
    """Generate an image for *text*, store it and index it.

    The image is uploaded to the bucket as ``images/<uuid>.png`` and its
    prompt embedding is upserted into Pinecone under the same id.

    Returns the generated image, or an empty dict when the pipeline flags the
    output as NSFW (nothing is stored in that case).
    """
    # diffuse
    out = pipe(text)
    if any(out.nsfw_content_detected):
        return {}

    _id = str(uuid.uuid4())
    im = out.images[0]
    local_path = f'{_id}.png'
    try:
        # add image to Cloud Storage
        im.save(local_path, format='png')
        # push to storage
        blob = bucket.blob(f'images/{_id}.png')
        blob.upload_from_filename(local_path)
    finally:
        # delete local file, even when saving or uploading failed
        if os.path.exists(local_path):
            os.remove(local_path)
    # add embedding and metadata to Pinecone
    embeds = encode_text(text)
    meta = {
        'prompt': text,
        'image_url': f'images/{_id}.png'
    }
    index.upsert([(_id, embeds, meta)])
    return im


def get_image(url: str):
    """Download the blob at *url* from the bucket and open it as an image."""
    blob = bucket.blob(url).download_as_string()
    blob_bytes = io.BytesIO(blob)
    im = Image.open(blob_bytes)
    return im


def test_image(_id, image) -> bool:
    """Check that *image* can be fully decoded and re-encoded.

    On failure the entry is treated as corrupted: its vector is deleted from
    Pinecone and ``images/<_id>.png`` is deleted from the bucket.

    Returns True when the image is usable, False when it was deleted.
    """
    try:
        # encode into memory instead of a shared 'tmp.png': concurrent
        # requests cannot clobber each other, and a filesystem error can no
        # longer be mistaken for a corrupted image
        image.save(io.BytesIO(), format='png')
        return True
    except OSError:
        # delete corrupted file from pinecone and cloud
        index.delete(ids=[_id])
        bucket.blob(f"images/{_id}.png").delete()
        print(f"DELETED '{_id}'")
        return False


def prompt_image(text: str):
    """Fetch the cached images most similar to *text*.

    Returns a tuple ``(images, scores)``. ``scores`` holds the similarity of
    every match; ``images`` holds the downloaded image, or ``missing_im`` for
    entries that failed validation. Entries whose download raises ValueError
    are skipped, so ``images`` may be shorter than ``scores``.

    Raises:
        ValueError: if the Pinecone query fails even after a reconnect.
    """
    embeds = encode_text(text)
    xc = _query_index(embeds, top_k=9)
    matches = xc['matches']
    image_urls = [match['metadata']['image_url'] for match in matches]
    scores = [match['score'] for match in matches]
    ids = [match['id'] for match in matches]
    images = []
    for _id, image_url in zip(ids, image_urls):
        try:
            im = get_image(image_url)
            if test_image(_id, im):
                images.append(im)
            else:
                images.append(missing_im)
        except ValueError:
            print(f"ValueError: '{image_url}'")
    return images, scores


# __APP FUNCTIONS__

def set_suggestion(text: str):
    """Return a TextArea update whose value is the first element of *text*."""
    # NOTE(review): for a plain string this is only its first character;
    # presumably a sequence (e.g. a selected row) was intended - confirm.
    return gr.TextArea.update(value=text[0])


def set_images(text: str):
    """Return a Gallery update for *text*, generating an image if needed.

    When no cached result scores above ``threshold`` a new image is diffused
    and stored first, then the cache is queried again.
    """
    images, scores = prompt_image(text)
    if any(score > threshold for score in scores):
        print("MATCH FOUND")
        return gr.Gallery.update(value=images)

    print("NO MATCH FOUND")
    diffuse(text)
    images, scores = prompt_image(text)
    return gr.Gallery.update(value=images)


# __CREATE APP__
demo = gr.Blocks()

with demo:
    gr.Markdown(
        """ # Dream Cacher """
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.TextArea(
                value="A person surfing",
                placeholder="Enter a prompt to dream about",
                interactive=True
            )
            search = gr.Button(value="Search!")
            # NOTE(review): the keyword is spelled `values` here; presumably
            # `value` was intended - confirm against the installed Gradio
            # version before changing it.
            suggestions = gr.Dataframe(
                values=[],
                headers=['Similarity', 'Prompt']
            )
            # event listener for change in prompt
            prompt.change(
                prompt_query,
                prompt,
                suggestions,
                show_progress=False
            )
        # results column
        with gr.Column():
            pics = gr.Gallery()
            pics.style(grid=3)
            # search event listening; the try only wraps the registration
            # call itself
            try:
                search.click(set_images, prompt, pics)
            except OSError:
                print("OSError")

demo.launch()