"""Gradio demo: classify an uploaded image and generate a short story from it.

Two interfaces are shown side by side via ``gradio.mix.Parallel``:

* ``classify_image`` -- top-5 labels from DeepMind's Perceiver IO (conv).
* ``self_caption``   -- caption the image with a ViT-GPT2 model, then continue
  the caption into a story with the hosted BLOOM Inference API.

``create_story`` is an alternative, local (GPT-Neo) story generator that is
currently not wired into the UI.
"""
import functools
import json
import os

import gradio as gr
import matplotlib.pyplot as plt
import requests
import torch
from gradio.mix import Parallel
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    ImageClassificationPipeline,
    LogitsProcessorList,
    MaxLengthCriteria,
    MinLengthLogitsProcessor,
    PerceiverFeatureExtractor,
    PerceiverForImageClassificationConvProcessing,
    StoppingCriteriaList,
    ViTConfig,
    ViTFeatureExtractor,
    ViTForImageClassification,
    VisionEncoderDecoderModel,
)

# get from local file spaces_info.py
from spaces_info import description, examples, initial_prompt_value

# --- constants ---------------------------------------------------------------
# BLOOM Inference API endpoint (hard-coded; the API_URL env var is not used).
API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
# Only send an Authorization header when a token is configured; otherwise the
# literal string "Bearer None" would be sent.
headers = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
# Seconds to wait for the Inference API before giving up.
REQUEST_TIMEOUT = 120

STORY_MODEL = "EleutherAI/gpt-neo-125M"
CAPTION_MODEL = "ydshieh/vit-gpt2-coco-en"
CLASSIFIER_MODEL = "deepmind/vision-perceiver-conv"

# Never print HF_API_TOKEN: it is a secret and Space logs may be public.
print(API_URL)


def query(payload):
    """POST *payload* to the BLOOM Inference API and return the decoded JSON.

    Network failures and non-JSON responses are reported as
    ``{"error": <message>}`` so callers get the same shape the API itself uses
    for errors, instead of an exception.
    """
    print(payload)
    try:
        response = requests.post(
            API_URL, json=payload, headers=headers, timeout=REQUEST_TIMEOUT
        )
    except requests.RequestException as err:
        return {"error": f"request failed: {err}"}
    print(response)
    try:
        return json.loads(response.content.decode("utf-8"))
    except ValueError:  # covers UnicodeDecodeError and JSONDecodeError
        return {"error": f"HTTP {response.status_code}: {response.text[:200]}"}


def inference(input_sentence, max_length, sample_or_greedy, seed=42):
    """Continue *input_sentence* with BLOOM and return only the new text.

    Args:
        input_sentence: prompt text sent to the model.
        max_length: maximum number of *new* tokens to generate.
        sample_or_greedy: ``"Sample"`` for nucleus sampling (top_p=0.9);
            any other value selects greedy decoding.
        seed: seed forwarded to the API.

    Returns:
        The generated continuation, or a string starting with ``"ERROR: "``
        when the API reports an error or returns an unexpected payload.
    """
    parameters = {
        "max_new_tokens": max_length,
        "do_sample": False,
        "seed": seed,
        "early_stopping": False,
        "length_penalty": 0.0,
        "eos_token_id": None,
    }
    if sample_or_greedy == "Sample":
        parameters.update(top_p=0.9, do_sample=True)

    payload = {
        "inputs": input_sentence,
        "parameters": parameters,
        "options": {"use_cache": False},
    }
    data = query(payload)

    if isinstance(data, dict) and "error" in data:
        # Return a string (not a tuple) so the single Textbox output can show it.
        return f"ERROR: {data['error']} "
    try:
        generated_text = data[0]["generated_text"]
    except (IndexError, KeyError, TypeError):
        return f"ERROR: unexpected response {data!r} "

    # The API echoes the prompt; keep only what follows its first occurrence.
    # str.partition("") would raise, and a missing echo must not IndexError.
    if input_sentence:
        _, sep, tail = generated_text.partition(input_sentence)
        generation = tail if sep else generated_text
    else:
        generation = generated_text
    print(generation)
    return generation


@functools.lru_cache(maxsize=None)
def _load_story_model():
    """Load (once per process) the GPT-Neo tokenizer and model for create_story."""
    tokenizer = AutoTokenizer.from_pretrained(STORY_MODEL)
    model = AutoModelForCausalLM.from_pretrained(STORY_MODEL)
    # GPT-2-style models define an EOS token but no PAD token; reuse EOS.
    model.config.pad_token_id = model.config.eos_token_id
    return tokenizer, model


def create_story(text_seed):
    """Greedily continue *text_seed* locally with GPT-Neo-125M.

    Returns the list produced by ``tokenizer.batch_decode`` (one string:
    prompt + continuation). Total length, prompt included, is kept between
    10 and 100 tokens.
    """
    tokenizer, model = _load_story_model()
    input_ids = tokenizer(text_seed, return_tensors="pt").input_ids
    # Public generate() API instead of the deprecated greedy_search();
    # min_length/max_length count prompt tokens, like MinLengthLogitsProcessor(10)
    # and MaxLengthCriteria(100) did.
    outputs = model.generate(
        input_ids,
        do_sample=False,
        num_beams=1,
        min_length=10,
        max_length=100,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


@functools.lru_cache(maxsize=None)
def _load_caption_model():
    """Load (once per process) the ViT-GPT2 captioning components."""
    feature_extractor = ViTFeatureExtractor.from_pretrained(CAPTION_MODEL)
    tokenizer = AutoTokenizer.from_pretrained(CAPTION_MODEL)
    model = VisionEncoderDecoderModel.from_pretrained(CAPTION_MODEL)
    return feature_extractor, tokenizer, model


def self_caption(image):
    """Caption *image* and turn the caption into a short BLOOM-generated story.

    Args:
        image: a PIL image (the Gradio input uses ``type="pil"``), or None.

    Returns:
        The story text from ``inference`` (may be an ``"ERROR: ..."`` string).
    """
    if image is None:
        return "ERROR: no image provided "
    feature_extractor, tokenizer, model = _load_caption_model()

    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    print("Pixel Values shape:", tuple(pixel_values.shape))

    # Beam search captioning; return_dict_in_generate gives a ModelOutput.
    generated = model.generate(
        pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True
    )
    preds = [
        pred.strip()
        for pred in tokenizer.batch_decode(generated.sequences, skip_special_tokens=True)
    ]
    print("Predictions")
    print(preds)

    caption = " ".join(preds)
    return inference(caption, 32, "Sample", 42)


@functools.lru_cache(maxsize=None)
def _load_classifier():
    """Build (once per process) the Perceiver IO image-classification pipeline."""
    feature_extractor = PerceiverFeatureExtractor.from_pretrained(CLASSIFIER_MODEL)
    model = PerceiverForImageClassificationConvProcessing.from_pretrained(CLASSIFIER_MODEL)
    return ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)


def classify_image(image):
    """Classify *image* and return ``{label: score}`` for ``gr.outputs.Label``."""
    if image is None:
        return {}
    results = _load_classifier()(image)
    print("RESULTS")
    print(results)
    # Convert the pipeline's list of {"label", "score"} dicts to Gradio's format.
    output = {prediction["label"]: prediction["score"] for prediction in results}
    print("OUTPUT")
    print(output)
    return output


# --- Gradio UI ---------------------------------------------------------------
image = gr.inputs.Image(type="pil")
label = gr.outputs.Label(num_top_classes=5)

# examples = [["cats.jpg"], ["batter.jpg"], ["drinkers.jpg"]]
examples = [["batter.jpg"]]
title = "Generate a Story from an Image"
description = "Demo for classifying images with Perceiver IO. To use it, simply upload an image and click 'submit', a story is autogenerated as well"
# Placeholder (the literal was previously left unterminated across lines).
article = ""

img_info1 = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=label,
)

img_info2 = gr.Interface(
    fn=self_caption,
    inputs=image,
    outputs=[gr.outputs.Textbox(label="Story")],
)

Parallel(
    img_info1,
    img_info2,
    inputs=image,
    title=title,
    description=description,
    examples=examples,
    enable_queue=True,
).launch(debug=True)