Spaces:
Sleeping
Sleeping
File size: 4,192 Bytes
563ac92 658af68 b9e87be 91633ba 2bb32bf b9e87be 658af68 91633ba e3d517e 91633ba c35e301 91633ba e3d517e 91633ba c35e301 91633ba c35e301 91633ba ee7c5db 91633ba ee7c5db 832c4a9 ee7c5db 91633ba ee7c5db 91633ba ee7c5db 91633ba ee7c5db 91633ba 658af68 ee7c5db 91633ba ee7c5db 658af68 ee7c5db 91633ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import spaces
import gradio as gr
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlavaNextProcessor, LlavaNextForConditionalGeneration
from PIL import Image
# Hugging Face access token (needed for gated/private model downloads).
hf_token = os.getenv("HF_API_TOKEN")
# Model identifiers: a LLaVA vision-language model for image captioning
# and a fine-tuned causal LM that produces the final answer.
vqa_model_name = "llava-hf/llava-v1.6-mistral-7b-hf"
language_model_name = "larry1129/WooWoof_AI_Vision_merged_16bit_3b"
# Module-level caches so each model/tokenizer is loaded at most once.
vqa_processor = None
vqa_model = None
language_tokenizer = None
language_model = None
# Lazily initialize the LLaVA vision-language ("look at picture and talk") model.
def load_vqa_model():
    """Load and cache the LLaVA processor and model.

    Returns:
        tuple: ``(vqa_processor, vqa_model)`` — loaded on first call and
        reused afterwards via the module-level cache globals.
    """
    global vqa_processor, vqa_model
    if vqa_processor is None or vqa_model is None:
        # `token=` replaces the deprecated `use_auth_token=` argument of
        # `from_pretrained` in recent transformers releases.
        vqa_processor = LlavaNextProcessor.from_pretrained(vqa_model_name, token=hf_token)
        vqa_model = LlavaNextForConditionalGeneration.from_pretrained(
            vqa_model_name,
            torch_dtype=torch.float16,  # half precision to fit the 7B model on one GPU
            low_cpu_mem_usage=True,
        ).to("cuda:0")  # NOTE(review): assumes a GPU is present (Spaces ZeroGPU) — confirm
    return vqa_processor, vqa_model
# Lazily initialize the text-only instruction-following model.
def load_language_model():
    """Load and cache the WooWoof causal LM and its tokenizer.

    Returns:
        tuple: ``(language_tokenizer, language_model)`` — the model is put
        in eval mode and its pad token is configured, then both are cached
        in module globals for reuse.
    """
    global language_tokenizer, language_model
    if language_tokenizer is None or language_model is None:
        # `token=` replaces the deprecated `use_auth_token=` argument of
        # `from_pretrained` in recent transformers releases.
        language_tokenizer = AutoTokenizer.from_pretrained(language_model_name, token=hf_token)
        language_model = AutoModelForCausalLM.from_pretrained(
            language_model_name,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        # The model ships without a dedicated pad token; reuse EOS so
        # generate() does not warn / misbehave on padded batches.
        language_tokenizer.pad_token = language_tokenizer.eos_token
        language_model.config.pad_token_id = language_tokenizer.pad_token_id
        language_model.eval()  # inference only — disable dropout etc.
    return language_tokenizer, language_model
# Generate a textual description from an image.
# Decorated with @spaces.GPU so Spaces allocates a GPU for the call.
@spaces.GPU(duration=40)  # consider raising duration to 120 for slow cold starts
def generate_image_description(image):
    """Describe *image* in English using the LLaVA model.

    Args:
        image: PIL image supplied by the Gradio UI.

    Returns:
        str: Only the newly generated description text. (Decoding the
        full output sequence would prepend the chat-template prompt and
        the "What is shown in this image?" question — that was a bug.)
    """
    processor, model = load_vqa_model()
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is shown in this image?"},
                {"type": "image"},
            ],
        },
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda:0")
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=100)
    # BUG FIX: generate() returns prompt tokens + new tokens; slice off
    # the prompt so the caller receives the description alone.
    prompt_len = inputs["input_ids"].shape[1]
    image_description = processor.decode(output[0][prompt_len:], skip_special_tokens=True)
    return image_description.strip()
# Produce the final answer with the text-only model, conditioned on the
# instruction and the image description.
# Decorated with @spaces.GPU so Spaces allocates a GPU for the call.
@spaces.GPU(duration=40)  # consider raising duration to 120 for slow cold starts
def generate_language_response(instruction, image_description):
    """Return the model's sampled response for *instruction* given *image_description*."""
    tokenizer, model = load_language_model()
    prompt = f"""### Instruction:
{instruction}
### Input:
{image_description}
### Response:
"""
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    sampling_kwargs = {
        "max_new_tokens": 128,
        "temperature": 0.7,
        "top_p": 0.95,
        "do_sample": True,
    }
    with torch.no_grad():
        generated = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded.get("attention_mask"),
            **sampling_kwargs,
        )
    full_text = tokenizer.decode(generated[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only what follows the
    # final "### Response:" marker.
    return full_text.split("### Response:")[-1].strip()
# Combined Gradio callback: caption the image first, then answer the
# instruction with the language model.
def process_image_and_text(image, instruction):
    """Run the two-stage pipeline and format both results for display."""
    description = generate_image_description(image)
    answer = generate_language_response(instruction, description)
    return f"图片描述: {description}\n\n最终回答: {answer}"
# Build the Gradio UI: an image upload plus an instruction textbox,
# wired to the two-stage pipeline above.
iface = gr.Interface(
    fn=process_image_and_text,
    inputs=[
        gr.Image(type="pil", label="上传图片"),
        gr.Textbox(lines=2, placeholder="Instruction", label="Instruction")
    ],
    outputs="text",
    title="WooWoof AI - 图片和文本交互",
    description="输入图片并添加指令,生成基于图片描述的回答。",
    # NOTE(review): `allow_flagging` is deprecated in Gradio 4 in favor of
    # `flagging_mode` — confirm the Space's pinned Gradio version.
    allow_flagging="never"
)
# Launch the Gradio app (module-level side effect — runs on import).
iface.launch()
|