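# Krypton: a Gradio multimodal chat demo that streams answers about an
# uploaded image from a LLaVA-Llama-3 model.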
import threading

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import TextIteratorStreamer, AutoProcessor, LlavaForConditionalGeneration
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Krypton π</h1>
<p>This demo is powered by the open-source model <a href="https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers"><b>xtuner/llava-llama-3-8b-v1_1-transformers</b></a>.</p>
</div>
'''
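# Load the model in half precision to reduce GPU memory use; low_cpu_mem_usage
# avoids materialising a full fp32 copy on the CPU while loading.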
model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model.to('cuda')
processor = AutoProcessor.from_pretrained(model_id)
# Confirming and setting the eos_token_id (if necessary)
model.generation_config.eos_token_id = processor.tokenizer.eos_token_id
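# On ZeroGPU Spaces, @spaces.GPU attaches a GPU to each call of this function
# for at most `duration` seconds; on a dedicated GPU it is effectively a no-op.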
@spaces.GPU(duration=120)
def krypton(input, history):
    """Answer a question about the uploaded image, streaming tokens as they arrive."""
    print(f"Input: {input}")  # Debug input
    print(f"History: {history}")  # Debug history
    if input["files"]:
        print("Found the image\n")
        image_path = input["files"][-1]["path"] if isinstance(input["files"][-1], dict) else input["files"][-1]
        print(f"Image path: {image_path}")
    else:
        # No file in this turn: fall back to the most recent image in the history.
        image_path = None
        for hist in history:
            if isinstance(hist[0], tuple):
                image_path = hist[0][0]
    if not image_path:
        raise gr.Error("You need to upload an image for Krypton to work.")
    try:
        image = Image.open(image_path)
        print(f"Image open: {image}")
    except Exception as e:
        print(f"Error opening image: {e}")
        raise gr.Error("Failed to open the image.")
    # Build the prompt in the Llama-3 chat format this model was fine-tuned on;
    # the processor expands the <image> placeholder into image token positions.
    prompt = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"<image>\n{input['text']}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    print("Made the prompt")
    try:
        inputs = processor(text=prompt, images=image, return_tensors='pt').to('cuda', torch.float16)
        print(f"Processed inputs: {inputs}")
    except Exception as e:
        print(f"Error processing inputs: {e}")
        raise gr.Error("Failed to process the inputs.")
    # Streamer: generation runs in a background thread while this function
    # iterates over the streamer and yields partial text for live updates.
    print('About to init streamer')
    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True, skip_prompt=True)
    # Forward all processor outputs (input_ids, attention_mask and, crucially,
    # pixel_values) to generate; without pixel_values the model never sees the image.
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=False,
    )
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    print('Thread about to start')
    thread.start()
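    # Yield the growing buffer after each chunk; Gradio re-renders the chat
    # message so the reply appears to stream token by token.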
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
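# UI wiring: a multimodal textbox (text plus image upload) feeds krypton
# through a streaming ChatInterface.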
chatbot = gr.Chatbot(height=600, label="Krypt AI")
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter your question or upload an image.", show_label=False)
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=krypton,
        chatbot=chatbot,
        fill_height=True,
        multimodal=True,
        textbox=chat_input,
    )
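# queue() enables request queuing, which generator-based (streaming) handlers
# rely on; api_open=False / show_api=False keep the HTTP API private.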
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)