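# Fluxi-IA / app.py
# Hugging Face Space by J-LAB (commit 144ba4b, "Update app.py", ~4 kB).
# The lines above were copied from the web file viewer and are kept here as a comment so the file parses.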
import html
import io
import os
import subprocess
import sys

import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
import spaces
from PIL import Image
# Install flash-attn at startup into the interpreter running this app.
# FLASH_ATTENTION_SKIP_CUDA_BUILD is read by flash-attn's build (presumably it
# skips compiling the CUDA kernels locally; confirm against flash-attn docs).
# The variable is merged into the current environment instead of replacing it:
# passing only this one variable would strip PATH, HOME and virtualenv settings
# from the child process. Failure stays non-fatal (check=False), as before.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

# Fine-tuned Florence-2 checkpoint. Its modelling code lives in the repo, hence trust_remote_code.
model_id = 'J-LAB/Florence_2_B_FluxiAI_Product_Caption'
# Loaded once at import time, moved to the GPU and put in inference mode.
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Markdown header shown at the top of the UI. No space is allowed between
# "]" and "(", otherwise the link does not render.
DESCRIPTION = "# Product Describe by Fluxi IA\n### Base Model [Florence-2](https://huggingface.co/microsoft/Florence-2-large)"
@spaces.GPU
def run_example(task_prompt, image):
    """Run the captioning model once on *image* for a single Florence-2 task.

    Args:
        task_prompt: Task token fed to the processor as the text prompt
            (this app uses '<PC>' or '<OCR>').
        image: PIL image; its width/height are forwarded to the
            post-processor.

    Returns:
        The value of ``processor.post_process_generation`` for the task.
        ``process_image`` looks the answer up under the task token.
    """
    encoded = processor(text=task_prompt, images=image, return_tensors="pt").to("cuda")

    # Deterministic 3-beam search, capped at 1024 new tokens.
    generation_kwargs = dict(
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    output_ids = model.generate(
        input_ids=encoded["input_ids"],
        pixel_values=encoded["pixel_values"],
        **generation_kwargs,
    )

    # Special tokens are kept in the decoded text; presumably the
    # post-processor needs them to parse the answer.
    decoded_batch = processor.batch_decode(output_ids, skip_special_tokens=False)
    raw_text = decoded_batch[0]

    size = (image.width, image.height)
    return processor.post_process_generation(raw_text, task=task_prompt, image_size=size)
# Maps the task labels shown in the UI dropdown to Florence-2 task tokens.
_TASK_TOKENS = {
    "Product Caption": "<PC>",
    "OCR": "<OCR>",
}


def process_image(image, task_prompt="Product Caption"):
    """Run the selected task on an uploaded image and return HTML-ready text.

    Args:
        image: The uploaded picture, as a NumPy array (the default output
            type of ``gr.Image``) or an already-decoded PIL image. ``None``
            when the user submits without uploading anything.
        task_prompt: A dropdown label, i.e. a key of ``_TASK_TOKENS``. It
            shadows the module-level Dropdown component of the same name on
            purpose. The default keeps callers that send only the image
            working.

    Returns:
        The model's text with HTML special characters escaped and newlines
        turned into ``<br>``. Returns an empty string when there is no image
        or no text for the task.

    Raises:
        ValueError: If *task_prompt* is not a known task label.
    """
    if image is None:
        return ""
    try:
        task_token = _TASK_TOKENS[task_prompt]
    except KeyError:
        raise ValueError(f"Unknown task: {task_prompt!r}") from None
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)  # Convert NumPy array to PIL Image

    # Run inference exactly once for the selected task.
    results = run_example(task_token, image)

    # The post-processed result is keyed by the task token; drop the key
    # and keep only the text value.
    if results and task_token in results:
        output_text = results[task_token]
    else:
        output_text = ""
    # The text goes into gr.HTML, so escape it first: OCR of a user image
    # must not be able to inject markup. Then turn newlines into line breaks.
    return html.escape(output_text).replace("\n", "<br>")


css = """
#output {
overflow: auto;
border: 1px solid #ccc;
padding: 10px;
background-color: rgb(31 41 55);
color: #fff;
}
"""

# Gradio's `js=` hooks expect a single JavaScript function (see Gradio's
# custom-JS docs). The previous string was a function declaration followed by
# a separate statement, and it bound to the first <button> on the page, which
# is not necessarily the Submit button. This version watches the output
# element and re-fits its height whenever its content changes.
js = """
function setupOutputAutoHeight() {
    var outputElement = document.getElementById('output');
    if (!outputElement) {
        return;
    }
    function adjustHeight() {
        outputElement.style.height = 'auto'; // Reset height to auto to get the actual content height
        outputElement.style.height = outputElement.scrollHeight + 'px'; // Grow to fit the content
    }
    new MutationObserver(adjustHeight).observe(outputElement, {
        childList: true,
        subtree: true,
        characterData: true
    });
}
"""

# Dropdown choices come from the same mapping process_image uses, so the
# two cannot drift apart.
single_task_list = list(_TASK_TOKENS)

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Product Image Select"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Product Caption")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_text = gr.HTML(label="Output Text", elem_id="output")
        gr.Markdown("""
## How to use via API
To use this model via API, you can follow the example code below:
```python
!pip install gradio_client
from gradio_client import Client, handle_file
client = Client("J-LAB/Fluxi-IA")
result = client.predict(
image=handle_file('https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png'),
task_prompt="Product Caption",
api_name="/process_image"
)
print(result)
```
""")
        # The dropdown must be an input, or the selected task never reaches process_image.
        submit_btn.click(process_image, [input_img, task_prompt], [output_text])
    demo.load(lambda: None, inputs=None, outputs=None, js=js)

demo.launch(debug=True)