ttengwang
clean up code, add langchain for chatbox
9a84ec8
raw
history blame
4.07 kB
from langchain.llms.openai import OpenAI
import torch
from PIL import Image, ImageDraw, ImageOps
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
import pdb
class TextRefiner:
    """Rewrite an image caption with an LLM according to user-selected controls.

    The refiner builds a natural-language instruction from a set of prompt
    templates, sends it to the language model stored in ``self.llm`` and
    returns the stripped response. Optionally a second, Wikipedia-style
    summary of the caption's main object is requested as well.
    """

    def __init__(self, device, api_key=""):
        """Create the LLM client and the prompt templates.

        Args:
            device: Only echoed in the start-up message; nothing in this
                class is placed on it.
            api_key: Forwarded to the ``OpenAI`` wrapper as
                ``openai_api_key``.
        """
        print(f"Initializing TextRefiner to {device}")
        self.llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
        # Not consulted by ``inference``; kept because it is a public
        # attribute that other code may read.
        self.prompt_tag = {
            "imagination": {"True": "could",
                            "False": "could not"}
        }
        # Controls whose value is substituted into a short phrase; the
        # placeholder name equals the control name.
        self.short_prompts = {
            "length": "around {length} words",
            "sentiment": "of {sentiment} sentiment",
            "language": "in {language}",
        }
        # On/off controls: the full sentence is appended when the control is
        # switched on.
        self.long_prompts = {
            "imagination": "The new sentence could extend the original description by using your imagination to create additional details, or think about what might have happened before or after the scene in the image, but should not conflict with the original sentence",
        }
        self.wiki_prompts = "I want you to act as a Wikipedia page. I will give you a sentence and you will parse the single main object in the sentence and provide a summary of that object in the format of a Wikipedia page. Your summary should be informative and factual, covering the most important aspects of the object. Start your summary with an introductory paragraph that gives an overview of the object. The overall length of the response should be around 100 words. You should not describe the parsing process and only provide the final summary. The sentence is \"{query}\"."
        self.control_prompts = "As a text reviser, you will convert an image description into a new sentence or long paragraph. The new text is {prompts}. {long_prompts} The sentence is \"{query}\" (give me the revised sentence only)"

    def parse(self, response):
        """Return the refined-caption response with surrounding whitespace removed."""
        return response.strip()

    def parse2(self, response):
        """Return the wiki response with surrounding whitespace removed."""
        return response.strip()

    def prepare_input(self, query, short_prompts, long_prompts):
        """Fill the control template and return the final prompt.

        Args:
            query: The caption to be revised.
            short_prompts: Already formatted short phrases; joined with ", ".
            long_prompts: Full-sentence instructions; joined with ". ".

        Returns:
            The prompt string (also printed for debugging).
        """
        prompt = self.control_prompts.format(
            prompts=', '.join(short_prompts),
            long_prompts='. '.join(long_prompts),
            query=query,
        )
        print('prompt: ', prompt)
        return prompt

    def inference(self, query: str, controls: dict, context: list = None, enable_wiki=False):
        """Refine ``query`` according to ``controls``.

        Args:
            query: The caption of the region of interest, generated by the
                captioner.
            controls: A dict of control signals, e.g.
                ``{"length": 5, "sentiment": "positive"}``. Keys found in
                ``self.short_prompts`` are formatted with their value; any
                other key is treated as an on/off switch and contributes its
                ``self.long_prompts`` sentence when the value is ``True``,
                ``"True"`` or ``"true"``. Keys known to neither table are
                ignored.
            context: Currently unused by this method. Defaults to ``None``
                instead of a shared mutable list.
            enable_wiki: When true, additionally query the LLM with the wiki
                template for the same caption.

        Returns:
            A dict with the keys ``'raw_caption'`` (the input ``query``),
            ``'caption'`` (the refined text) and ``'wiki'`` (the wiki summary,
            or ``""`` when ``enable_wiki`` is false).
        """
        short_parts = []
        long_parts = []
        for control, value in controls.items():
            template = self.short_prompts.get(control)
            if template is not None:
                short_parts.append(template.format(**{control: value}))
                continue
            if value not in (True, "True", "true"):
                continue
            # An enabled switch without a registered sentence is skipped
            # rather than raising KeyError, matching how the same unknown
            # key is already ignored when it is switched off.
            long_prompt = self.long_prompts.get(control)
            if long_prompt is not None:
                long_parts.append(long_prompt)

        prompt = self.prepare_input(query, short_parts, long_parts)
        response = self.parse(self.llm(prompt))

        response_wiki = ""
        if enable_wiki:
            prompt_wiki = self.wiki_prompts.format(query=query)
            response_wiki = self.parse2(self.llm(prompt_wiki))

        out = {
            'raw_caption': query,
            'caption': response,
            'wiki': response_wiki
        }
        print(out)
        return out
if __name__ == "__main__":
    # Manual smoke run: refine one sample caption with a fixed set of controls.
    # No api_key is passed, so the constructor's default empty string is used
    # (presumably the key must then come from the environment -- verify).
    refiner = TextRefiner(device="cpu")
    demo_controls = dict(
        length="30",
        sentiment="negative",
        imagination="False",  # switch to "True" to enable the imagination prompt
        language="English",
    )
    refiner.inference(query="a cat is sleeping", controls=demo_controls)