"""Gradio app: caption an uploaded image and reduce the caption to a triplet.

The caption is generated by a GIT causal-LM checkpoint; the spaCy dependency
parse of that caption then supplies a head entity, a relation and a tail
entity, which are rendered as ``'head'---relation--->'tail'``.
"""

import re

import gradio as gr
import spacy
from PIL import Image
from spacy.matcher import Matcher
from transformers import AutoModelForCausalLM, AutoProcessor

device = 'cpu'

processor = AutoProcessor.from_pretrained("microsoft/git-base")
model = AutoModelForCausalLM.from_pretrained("nkasmanoff/git-planet").to(device)

nlp = spacy.load('en_core_web_sm')


def predict(image, max_length=64, device='cpu'):
    """Caption ``image`` and return the caption's knowledge triplet as text.

    Args:
        image: The image handed over by the Gradio input component (declared
            with ``type='pil'``). May be ``None`` because the component is
            optional.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        device: Device the preprocessed inputs are moved to. The model itself
            is placed on the module-level ``device``, so passing a different
            value here moves only the inputs.

    Returns:
        A string of the form ``'<head>'---<relation>---><tail>'`` built from
        the generated caption, or an empty string when no image was supplied.
    """
    # The input component is optional, so a submit without an upload arrives
    # here as None; there is nothing to caption in that case.
    if image is None:
        return ""

    pixel_values = processor(images=image, return_tensors="pt").to(device).pixel_values
    generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    relation = get_relation(generated_caption)
    head, tail = get_entities(generated_caption)
    return f"'{head}'---{relation}--->'{tail}'"


def _join_nonempty(*parts):
    """Join the non-empty ``parts`` with single spaces."""
    return " ".join(part for part in parts if part)


def get_entities(sent):
    """Extract a subject entity and an object entity from a sentence.

    Tokens of the spaCy parse are scanned in order. Compound tokens and
    modifier tokens (dependency label ending in ``mod``) are remembered and
    prepended to the next subject / object token that is encountered.

    Args:
        sent: The sentence text to parse.

    Returns:
        ``[ent1, ent2]`` where ``ent1`` is built from the last token whose
        dependency label contains ``subj`` and ``ent2`` from the last token
        whose label contains ``obj``. Either is ``""`` if no such token exists.
    """
    ent1 = ""
    ent2 = ""

    prv_tok_dep = ""   # dependency tag of the previous non-punctuation token
    prv_tok_text = ""  # text of the previous non-punctuation token

    prefix = ""
    modifier = ""

    for tok in nlp(sent):
        # Punctuation neither contributes text nor counts as "previous token".
        if tok.dep_ == "punct":
            continue

        if tok.dep_ == "compound":
            prefix = tok.text
            # Two compounds in a row are kept together as one prefix.
            if prv_tok_dep == "compound":
                prefix = prv_tok_text + " " + tok.text

        if tok.dep_.endswith("mod"):
            modifier = tok.text
            # A compound directly before the modifier is folded into it.
            if prv_tok_dep == "compound":
                modifier = prv_tok_text + " " + tok.text

        # Substring membership, not `find(...) == True`: find() returns an
        # index, so comparing it with True only matched labels that carry the
        # substring at position 1.
        if "subj" in tok.dep_:
            # Joining only the non-empty parts avoids a double space inside
            # the entity when there is no modifier or no prefix.
            ent1 = _join_nonempty(modifier, prefix, tok.text)
            # The subject consumed the pending words; start fresh for the
            # object. (prv_tok_* need no reset: they are reassigned below.)
            prefix = ""
            modifier = ""

        if "obj" in tok.dep_:
            ent2 = _join_nonempty(modifier, prefix, tok.text)

        prv_tok_dep = tok.dep_
        prv_tok_text = tok.text

    return [ent1.strip(), ent2.strip()]


def get_relation(sent):
    """Return the relation phrase around the root of a sentence.

    A spaCy ``Matcher`` looks for the ROOT token optionally followed by a
    ``prep`` token, an ``agent`` token and an adjective. The text of the last
    match returned by the matcher is used.

    Args:
        sent: The sentence text to parse.

    Returns:
        The matched span's text, or ``""`` when the matcher finds nothing
        (for example for an empty sentence).
    """
    doc = nlp(sent)

    matcher = Matcher(nlp.vocab)
    pattern = [
        {'DEP': 'ROOT'},
        {'DEP': 'prep', 'OP': "?"},
        {'DEP': 'agent', 'OP': "?"},
        {'POS': 'ADJ', 'OP': "?"},
    ]
    matcher.add('matching_pattern', patterns=[pattern])

    matches = matcher(doc)
    # Without this guard an empty match list would raise IndexError below.
    if not matches:
        return ""

    # Each match is (match_id, start, end); only the token bounds are needed.
    _, start, end = matches[-1]
    return doc[start:end].text


# Named so that the builtin `input` is not shadowed at module level.
image_input = gr.inputs.Image(label="Please upload an image", type='pil', optional=True)
caption_output = gr.outputs.Textbox(type="text", label="Captions")

title = "Satellite Image Knowledge Extraction"
description = "Provide an image taken from above, and receive back a corresponding head-relation-tail triplet that can be used to form a knowledge graph."

interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    theme="grass",
    outputs=caption_output,
    title=title,
    # The description text was defined but previously never shown in the UI.
    description=description,
)
interface.launch(debug=True)