from fastapi import FastAPI, Query
from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from transformers import Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import torch
import logging

logging.basicConfig(level=logging.INFO)

app = FastAPI()

# Qwen2.5-VL Model Setup
# Alternative 7B-Instruct setup (commented out):
# qwen_checkpoint = "Qwen/Qwen2.5-VL-7B-Instruct"
# min_pixels = 256 * 28 * 28
# max_pixels = 1280 * 28 * 28
# processor = AutoProcessor.from_pretrained(
#     qwen_checkpoint,
#     min_pixels=min_pixels,
#     max_pixels=max_pixels,
# )
# qwen_model = AutoModelForCausalLM.from_pretrained(
#     qwen_checkpoint,
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
# )

checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor = AutoProcessor.from_pretrained(
    checkpoint,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)
qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # attn_implementation="flash_attention_2",
)

# LLaMA Model Setup
llama_model_name = "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2"
llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_name)
llama_model = AutoModelForCausalLM.from_pretrained(
    llama_model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)


@app.get("/")
def read_root():
    return {"message": "API is live. Use the /predict, /chat, or /llama_chat endpoints."}


# Vision-language endpoint: answer a text prompt about an image (given by URL) with Qwen2.5-VL.
@app.get("/predict")
def predict(image_url: str = Query(...), prompt: str = Query(...)):
    messages = [
        {"role": "system", "content": "You are a helpful assistant with vision abilities."},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_url},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(qwen_model.device)
    with torch.no_grad():
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=128)
    # Strip the prompt tokens so only the newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return {"response": output_texts[0]}


# Text-only chat endpoint served by the Qwen2.5-VL model.
@app.get("/chat")
def chat(prompt: str = Query(...)):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [{"type": "text", "text": prompt}]},
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text],
        padding=True,
        return_tensors="pt",
    ).to(qwen_model.device)
    with torch.no_grad():
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return {"response": output_texts[0]}


# Text-only completion endpoint served by the Llama model (raw prompt, no chat template).
@app.get("/llama_chat")
def llama_chat(prompt: str = Query(...)):
    inputs = llama_tokenizer(prompt, return_tensors="pt").to(llama_model.device)
    with torch.no_grad():
        outputs = llama_model.generate(**inputs, max_new_tokens=128)
    response = llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"response": response}
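

# --- Running the service: a minimal sketch, not part of the original code ---
# The app above can be served with uvicorn. The host/port below and the example
# query values in the curl commands are assumptions; the image URL is a placeholder.
#
# Example requests once the server is up:
#   curl "http://localhost:8000/predict?image_url=https://example.com/cat.jpg&prompt=Describe%20the%20image"
#   curl "http://localhost:8000/chat?prompt=Hello"
#   curl "http://localhost:8000/llama_chat?prompt=Hello"

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)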