Spaces:
Sleeping
Sleeping
import os, yaml | |
import gradio as gr | |
import requests | |
import argparse | |
from PIL import Image | |
import numpy as np | |
import torch | |
from transformers import AutoModelForCausalLM | |
from huggingface_hub import hf_hub_download | |
## InstructIR Plugin ## | |
from insir_models import instructir | |
from insir_text.models import LanguageModel, LMHead | |
# InstructIR assets. The two checkpoints are fetched from the Hub into the
# working directory; the YAML config is read from the working directory and is
# not downloaded here.
MODEL_NAME = "im_instructir-7d.pt"  # image-restoration network weights
LM_MODEL = "lm_instructir-7d.pt"  # language-model head weights
CONFIG = "eval5d.yml"

for _checkpoint_file in (MODEL_NAME, LM_MODEL):
    hf_hub_download(repo_id="marcosv/InstructIR", filename=_checkpoint_file, local_dir="./")
def dict2namespace(config):
    """Recursively convert a mapping into an ``argparse.Namespace``.

    Values that are themselves dicts become nested namespaces, so a parsed
    YAML config can be read with attribute access (e.g. ``cfg.model.width``).
    Non-dict values (including lists) are stored unchanged.
    """
    return argparse.Namespace(**{
        key: dict2namespace(value) if isinstance(value, dict) else value
        for key, value in config.items()
    })
# --- InstructIR setup -------------------------------------------------------
# Parse the YAML config into a namespace for attribute access (cfg.model.*, cfg.llm.*).
with open(CONFIG, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
cfg = dict2namespace(config)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Image-restoration network; every architecture hyper-parameter comes from cfg.model.
ir_model = instructir.create_model(
    input_channels=cfg.model.in_ch,
    width=cfg.model.width,
    enc_blks=cfg.model.enc_blks,
    middle_blk_num=cfg.model.middle_blk_num,
    dec_blks=cfg.model.dec_blks,
    txtdim=cfg.model.textdim,
)
ir_model = ir_model.to(device)
print("IMAGE MODEL CKPT:", MODEL_NAME)
# NOTE(review): torch.load unpickles the file; acceptable for this trusted Hub
# checkpoint, but consider weights_only=True if the installed torch supports it.
ir_model.load_state_dict(torch.load(MODEL_NAME, map_location="cpu"), strict=True)
ir_model.eval()  # inference only: disable any training-mode layers (dropout etc.)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Text side: a language model that encodes the prompt, plus a trained head that
# maps its embedding to the conditioning vector consumed by ir_model.
LMODEL = cfg.llm.model
language_model = LanguageModel(model=LMODEL)
lm_head = LMHead(embedding_dim=cfg.llm.model_dim, hidden_dim=cfg.llm.embd_dim, num_classes=cfg.llm.nclasses)
lm_head = lm_head.to(device)
print("LMHEAD MODEL CKPT:", LM_MODEL)
lm_head.load_state_dict(torch.load(LM_MODEL, map_location="cpu"), strict=True)
lm_head.eval()  # inference only
def process_img(image, prompt=None):
    """Restore ``image`` with InstructIR, guided by a text instruction.

    Args:
        image: PIL image to restore (what ``gr.Image(type="pil")`` delivers).
        prompt: Restoration instruction. When ``None`` (the Gradio UI never
            passes one), Co-Instruct is asked how to improve the image, and
            its answer, followed by a request to improve quality, is used as
            the instruction.

    Returns:
        The restored image as an 8-bit RGB ``PIL.Image``.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image to restore.")
    # Force 3 channels: grayscale, RGBA or palette images would otherwise give
    # the wrong rank / channel count for the HWC -> CHW permute below.
    image = image.convert("RGB")

    if prompt is None:
        diagnosis = chat("How to improve the quality of the image?", [], image, None, None, None)
        # Drop any leftover end-of-sequence marker and keep the two sentences apart.
        prompt = diagnosis.replace("</s>", "").strip() + " Please help me improve its quality!"
    print(prompt)

    # HWC uint8 -> 1xCxHxW float32 in [0, 1].
    img = (np.asarray(image) / 255.0).astype(np.float32)
    y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        lm_embd = language_model(prompt).to(device)
        text_embd, _deg_pred = lm_head(lm_embd)  # degradation prediction is unused here
        x_hat = ir_model(y, text_embd)

    # squeeze(0) removes only the batch axis; a bare squeeze() would also
    # collapse a spatial axis of size 1.
    restored_img = x_hat.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    restored_img = (restored_img * 255.0).round().astype(np.uint8)  # float32 -> uint8
    return Image.fromarray(restored_img)
## InstructIR Plugin ##
# Co-Instruct chat model used by `chat`. It is loaded in float16 and placed
# entirely on "cuda:0", so this part of the app requires a GPU.
_coinstruct_load_kwargs = {
    "trust_remote_code": True,
    "torch_dtype": torch.float16,
    "attn_implementation": "eager",
    "device_map": {"": "cuda:0"},
}
model = AutoModelForCausalLM.from_pretrained("q-future/co-instruct-preview", **_coinstruct_load_kwargs)
# Prompt prefix per number of images; each <|image|> placeholder corresponds,
# in order, to one entry of the image list passed to model.chat. The strings
# are the original templates verbatim (including "third image:" without a space).
_IMAGE_PREFIXES = {
    1: "USER: The input image: <|image|>",
    2: "USER: The first image: <|image|>\nThe second image: <|image|>",
    3: "USER: The first image: <|image|>\nThe second image: <|image|>\nThe third image:<|image|>",
    4: "USER: The first image: <|image|>\nThe second image: <|image|>\nThe third image:<|image|>\nThe fourth image:<|image|>",
}


def _select_images(*candidates):
    """Return the leading run of non-None images; slots after the first gap are ignored."""
    images = []
    for img in candidates:
        if img is None:
            break
        images.append(img)
    return images


def _build_prompt(message, history, n_images):
    """Assemble the full conversation prompt for ``n_images`` images.

    The image prefix is attached to the first user turn. Every finished
    assistant turn is terminated with "</s>", and the prompt ends with an open
    " ASSISTANT:" slot for the model to complete.
    """
    prefix = _IMAGE_PREFIXES[n_images]
    if not history:
        return prefix + message + " ASSISTANT:"
    (first_user, first_bot), later_turns = history[0], history[1:]
    prompt = prefix + first_user + " ASSISTANT:" + (first_bot or "") + "</s>"
    for user_turn, bot_turn in later_turns:
        prompt += "USER:" + user_turn + " ASSISTANT:" + (bot_turn or "") + "</s>"
    return prompt + "USER:" + message + " ASSISTANT:"


def chat(message, history, image_1, image_2, image_3, image_4):
    """Answer ``message`` about the uploaded images with Co-Instruct.

    Args:
        message: The new user question.
        history: Previous turns as ``[user, assistant]`` pairs (Gradio
            ChatInterface tuple format).
        image_1..image_4: Images from the four upload slots. Only the
            consecutive filled slots starting at ``image_1`` are used.

    Returns:
        The model's reply text: the decoded output after the last
        "ASSISTANT:", with a trailing "</s>" removed.

    Raises:
        gr.Error: if ``image_1`` is empty.
    """
    print(history)
    images = _select_images(image_1, image_2, image_3, image_4)
    if not images:
        raise gr.Error("Please upload at least one image (starting with Image 1).")

    prompt = _build_prompt(message, history, len(images))
    print(prompt)

    output_ids = model.chat(prompt, images, max_new_tokens=600)
    # NOTE(review): the clamp presumably maps negative image-placeholder ids into
    # a decodable range before batch_decode -- confirm against the model code.
    decoded = model.tokenizer.batch_decode(output_ids.clamp(0, 100000))[0]
    reply = decoded.split("ASSISTANT:")[-1]
    # batch_decode keeps special tokens by default, so the reply can end with EOS.
    if reply.endswith("</s>"):
        reply = reply[: -len("</s>")]
    return reply
#### Chat examples
# Each row is [message, image_1, image_2, image_3, image_4], one value per
# ChatInterface input (message box + the four additional image inputs). Unused
# slots are explicitly None so loading an example clears stale uploads.
examples = [
    ["Which part of the image is relatively clearer, the upper part or the lower part? Please analyze in details.", "examples/sausage.jpg", None, None, None],
    ["Which image is noisy, and which one is with motion blur? Please analyze in details.", "examples/211.jpg", "examples/frog.png", None, None],
    ["What is the problem in this image, and how to fix it? Please answer my questions one by one.", "examples/lol_748.png", None, None, None],
]
# Gradio UI: four shared image slots feed the Co-Instruct chat (left column);
# the right column runs the InstructIR auto-restoration plugin.
title = "Co-Instruct-Plus🧑🏫🖌️"
with gr.Blocks(title=title) as demo:
    title_markdown = ("""
<h1 align="center"><a href="https://github.com/Q-Future/Co-Instruct"><img src="https://raw.githubusercontent.com/Q-Future/Co-Instruct/main/co-instruct.png" alt="Co-Instruct" border="0" style="margin: 0 auto; height: 85px;" /></a> </h1>
<div align="center">Built upon <strong>Q-Instruct: Improving Low-level Visual Abilities for Multi-modality Foundation Models (CVPR 2024)</strong></div>
<div align="center">Built upon the Upgraded Version, Co-Instruct, supporting up to 4 images: <strong>Towards Open-ended Visual Quality Comparison (Arxiv 2024)</strong></div>
<div align="center">We also support <a href='https://huggingface.co/marcosv/InstructIR'>InstructIR</a> as PLUGIN to restore image!</div>
<h5 align="center"> Please find our more accurate visual scoring demo on <a href='https://huggingface.co/spaces/teowu/OneScorer'>[OneScorer]</a> (Q-Align)!</h5>
<div align="center">
<div style="display:flex; gap: 0.25rem;" align="center">
<strong>Q-Instruct Resources:</strong>
<a href='https://github.com/Q-Future/Q-Instruct'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
<a href="https://Q-Instruct.github.io/Q-Instruct/fig/Q_Instruct_v0_1_preview.pdf"><img src="https://img.shields.io/badge/Technical-Report-red"></a>
<a href='https://github.com/Q-Future/Q-Instruct/stargazers'><img src='https://img.shields.io/github/stars/Q-Future/Q-Instruct.svg?style=social'></a>
</div>
</div>
<div align="center">
<div style="display:flex; gap: 0.25rem;" align="center">
<strong>Co-Instruct Resources:</strong>
<a href='https://github.com/Q-Future/Co-Instruct'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
<a href="https://arxiv.org/pdf/2402.16641.pdf"><img src="https://img.shields.io/badge/Technical-Report-red"></a>
<a href='https://github.com/Q-Future/Co-Instruct/stargazers'><img src='https://img.shields.io/github/stars/Q-Future/Co-Instruct.svg?style=social'></a>
</div>
</div>
""")
    gr.Markdown(title_markdown)

    # Shared image slots, passed to `chat` as additional inputs (in this order).
    with gr.Row():
        input_img_1 = gr.Image(type='pil', label="Image 1 (First image)")
        input_img_2 = gr.Image(type='pil', label="Image 2 (Second image)")
        input_img_3 = gr.Image(type='pil', label="Image 3 (Third image)")
        input_img_4 = gr.Image(type='pil', label="Image 4 (Fourth image)")

    with gr.Row():
        with gr.Column(scale=2):
            gr.ChatInterface(
                fn=chat,
                additional_inputs=[input_img_1, input_img_2, input_img_3, input_img_4],
                theme="Soft",
                examples=examples,
            )
        with gr.Column(scale=1):
            input_image_ir = gr.Image(type="pil", label="Image for Auto Restoration")
            output_image_ir = gr.Image(type="pil", label="Output of Auto Restoration")
            gr.Interface(
                fn=process_img,
                inputs=[input_image_ir],
                outputs=[output_image_ir],
                examples=["examples/gopro.png", "examples/noise50.png", "examples/lol_748.png"],
            )

demo.launch(share=True)