# NOTE: `spaces` is kept as the first import; Hugging Face ZeroGPU's `spaces`
# package is meant to be imported before torch/CUDA are initialised.
import spaces
import gradio as gr
import torch
import random
import os
from typing import List, Tuple
from config_generator import generate_complete_game
from dataset import get_processor, joint_speaker_input, joint_listener_input, get_index_to_token
import transformers
from transformers import Idefics2ForConditionalGeneration
from peft import LoraConfig, get_peft_model
from joint_inference import IdeficsJointInferenceModel

# ---------------------------------------------------------------------------
# Model set-up (runs once, at import time)
# ---------------------------------------------------------------------------
repo = 'lil-lab/cogen'
checkpoint = "HuggingFaceM4/idefics2-8b"
model = Idefics2ForConditionalGeneration.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

# LoRA target regex: out/fc/MLP/attention projections inside the vision model,
# modality projection and perceiver resampler, plus any k/q/v projection.
target_modules = r'(.*(vision_model|modality_projection|perceiver_resampler).*(out_proj|fc1|fc2|down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)|(.*(k_proj|q_proj|v_proj).*$)'
lora_config = LoraConfig(
    r=16,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=target_modules,
    init_lora_weights="gaussian",
)
# "initial" adapter: weights from revision "r0_full" of `repo`.
model = get_peft_model(model, lora_config, adapter_name="initial")
model.load_adapter(repo, "initial", revision="r0_full")

# Add the "final" adapter on exactly the modules that received LoRA layers above.
# n[17:...] drops a 17-character prefix (presumably PEFT's "base_model.model."
# wrapper prefix -- verify if the PEFT version changes) and the ".lora..." suffix.
new_targets = list({
    n[17:n.find('lora') - 1]
    for n, _ in model.named_parameters()
    if 'lora' in n
})
lora_config = LoraConfig(
    r=16,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=new_targets,
    init_lora_weights="gaussian",
)
# "final" adapter: weights from revision "r3_full" of `repo`.
model.add_adapter('final', lora_config)
model.load_adapter(repo, "final", revision="r3_full")

# Wrap in the joint speaker/listener inference model and move it to the GPU.
# (0.5 and 0 are IdeficsJointInferenceModel-specific positional arguments whose
# meaning is not visible from this file.)
model = IdeficsJointInferenceModel(0.5, 0, model=model).cuda()
model.eval()

# Lay the listener's radio choices (elem_classes="radio-group") out as a grid.
css = """
.radio-group .wrap {
    display: grid;
    grid-template-columns: repeat(5, 1fr);
    grid-template-rows: repeat(5, 1fr);
    width: 100%;
    height: 100%
}
"""


def initialize_game() -> List[Tuple[List[str], List[str], str, str]]:
    """Build the 12 rounds of one game.

    Four contexts are generated with ``generate_complete_game``; each one is
    used for three consecutive rounds (one per entry of its ``"targets"``).
    The model's role alternates in blocks of three rounds:
    listener, speaker, listener, speaker.

    Returns:
        One ``(speaker_context, listener_context, target, model_role)`` tuple
        per round.
    """
    context_dicts = [generate_complete_game() for _ in range(4)]
    roles = ["listener"] * 3 + ["speaker"] * 3 + ["listener"] * 3 + ["speaker"] * 3

    speaker_images = []
    listener_images = []
    targets = []
    for context_dict in context_dicts:
        for i in range(3):
            speaker_images.append(context_dict["speaker_context"])
            listener_images.append(context_dict["listener_context"])
            targets.append(context_dict["targets"][i])

    return list(zip(speaker_images, listener_images, targets, roles))


def get_model_response(
    model, adapter_name, processor, index_to_token, role: str,
    image_paths: List[str], user_message: str = "", target_image: str = ""
) -> str:
    """Run the model for one turn in the given role.

    Args:
        model: The joint inference model (the module-level ``model``).
        adapter_name: PEFT adapter to activate ("initial" or "final").
        processor: Processor from ``dataset.get_processor``.
        index_to_token: Mapping from ``dataset.get_index_to_token``.
        role: "speaker" to describe ``target_image``; any other value runs
            the model as a listener on ``user_message``.
        image_paths: The model's context (image file names).
        user_message: The human's description (listener role only).
        target_image: The image to describe (speaker role only).

    Returns:
        The first generated caption (speaker), or the entry of
        ``image_paths`` the model picked (listener).
    """
    if role == "speaker":
        img_dir = "tangram_pngs"
        print("Starting processing")
        input_tokens, attn_mask, images, image_attn_mask, label = joint_speaker_input(
            processor, image_paths, target_image, model.get_listener().device
        )
        # Batch of one context.
        image_paths = [image_paths]
        print("Starting inference")
        captions = get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths,
                                        processor, img_dir, index_to_token, adapter_name)
        print("Done")
        response = captions[0]
    else:  # listener
        print("Starting processing")
        images, l_input_tokens, l_attn_mask, l_image_attn_mask, s_input_tokens, s_attn_mask, \
            s_image_attn_mask, s_target_mask, s_target_label = joint_listener_input(
                processor, image_paths, user_message, model.get_listener().device
            )

        print("Starting inference")
        response = get_listener_response(
            model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
            s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths, adapter_name
        )
        print("Done")

    return response


@spaces.GPU(duration=20)
def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths, processor, img_dir,
                         index_to_token, adapter_name):
    """Generate captions for the target with ``adapter_name`` active.

    Runs under ``spaces.GPU(duration=20)`` and moves the tensor inputs to CUDA.
    Sampling: nucleus, temperature 0.7, top_k 50, top_p 1, up to 30 steps,
    5 samples.

    Returns:
        The ``captions`` produced by ``model.generate`` (the caller uses
        ``captions[0]``).
    """
    # Switch PEFT adapters only when a different one is requested.
    if model.model.active_adapter != adapter_name:
        model.model.set_adapter(adapter_name)

    with torch.no_grad():
        captions, _, _, _, _ = model.generate(
            images.cuda(), input_tokens.cuda(), attn_mask.cuda(), image_attn_mask.cuda(), label.cuda(),
            image_paths, processor, img_dir, index_to_token,
            max_steps=30, sampling_type="nucleus", temperature=0.7,
            top_k=50, top_p=1, repetition_penalty=1, num_samples=5
        )
    return captions


@spaces.GPU(duration=20)
def get_listener_response(model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
                          s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths,
                          adapter_name):
    """Pick the image the description refers to, with ``adapter_name`` active.

    Runs under ``spaces.GPU(duration=20)``. The choice is the argmax of the
    first row of the joint log-probabilities from
    ``model.comprehension_side``.

    Returns:
        The chosen entry of ``image_paths``.
    """
    # Switch PEFT adapters only when a different one is requested.
    if model.model.active_adapter != adapter_name:
        model.model.set_adapter(adapter_name)

    with torch.no_grad():
        _, _, joint_log_probs = model.comprehension_side([
            images.cuda(), l_input_tokens.cuda(), l_attn_mask.cuda(), l_image_attn_mask.cuda(), index_to_token,
            s_input_tokens.cuda(), s_attn_mask.cuda(), s_image_attn_mask.cuda(), s_target_mask.cuda(),
            s_target_label.cuda(),
        ])
        target_idx = joint_log_probs[0].argmax().item()
        response = image_paths[target_idx]

    return response


def initialize_interaction(model_iteration):
    """Create a fresh game-history dict and the prompt for the first turn.

    Args:
        model_iteration: "Initial System" selects the "initial" adapter; any
            other value selects "final".

    Returns:
        Dict with keys ``adapter_name``, ``image_role_pairs`` (from
        ``initialize_game``), ``conversation`` (list of display lines),
        ``turn`` (0-based) and ``num_correct``.
    """
    new_history = {
        'adapter_name': 'initial' if model_iteration == "Initial System" else "final",
        'image_role_pairs': initialize_game(),
        'conversation': [],
        'turn': 0,
        'num_correct': 0,
    }

    # The first turn always has the model as listener, so the human speaks.
    turn = new_history['turn']
    image_role_pairs = new_history['image_role_pairs']
    speaker_image, listener_image, target_image, _ = image_role_pairs[turn]
    target_idx = speaker_image.index(target_image)
    new_history['conversation'].extend([
        f"TURN: {turn + 1}/12",
        f"Generate a description for the target image. Your target is Image {target_idx + 1}"
    ])

    return new_history


def progress_game(user_message, processor, index_to_token, current_state):
    """Resolve the current turn with the human's action and set up the next one.

    Args:
        user_message: The human's description (when the model is listener) or
            the human's 1-based image guess (when the model is speaker).
        processor: Passed through to ``get_model_response``.
        index_to_token: Passed through to ``get_model_response``.
        current_state: Game-history dict from ``initialize_interaction``;
            mutated in place.

    Returns:
        ``(human_context, conversation, human_role, turn_message,
        acc_message, state)``. ``state`` is ``current_state`` while the game
        continues and ``{}`` once the last turn has been played; in that case
        the context and role are those of the turn just played.
    """
    # First get the game state
    turn = current_state['turn']
    image_role_pairs = current_state['image_role_pairs']
    speaker_image, listener_image, target_image, model_role = image_role_pairs[turn]
    human_role = "Speaker" if model_role == "listener" else "Listener"

    # Resolve the current turn
    if model_role == "listener":
        human_context = speaker_image
        model_context = listener_image

        # The human is the speaker, so user_message is their description.
        current_state['conversation'].append(f"You: {user_message}")
        model_message = get_model_response(
            model, current_state['adapter_name'], processor, index_to_token,
            model_role, model_context, user_message=user_message
        )

        # The model answers with an image name; compare positions in the human's view.
        model_idx = human_context.index(model_message)
        target_idx = human_context.index(target_image)
        if model_idx == target_idx:
            current_state['conversation'].append("The model guessed correctly!\n")
            current_state['num_correct'] += 1
        else:
            current_state['conversation'].append("The model guessed incorrectly.\n")
    else:
        human_context = listener_image
        model_context = speaker_image

        # The human is the listener, so user_message is their 1-based guess;
        # complete the pending "You: The target is Image " line with it.
        target_idx = human_context.index(target_image)
        current_state['conversation'][-1] += f"{user_message}"
        if int(user_message) == target_idx + 1:
            current_state['conversation'].append("Correct!\n")
            current_state['num_correct'] += 1
        else:
            current_state['conversation'].append("Incorrect!\n")

    # Move on to the next turn
    current_state['turn'] += 1
    acc_message = f"{current_state['num_correct']}/{current_state['turn']}"

    if current_state['turn'] == len(image_role_pairs):
        current_state['conversation'].append('The game is over!')
        # Report the last turn played instead of a non-existent "13/12".
        turn_message = f"{current_state['turn']}/12"
        return human_context, current_state['conversation'], human_role, turn_message, acc_message, {}

    turn_message = f"{current_state['turn'] + 1}/12"
    speaker_image, listener_image, target_image, model_role = image_role_pairs[current_state['turn']]
    human_role = "Listener" if model_role == "speaker" else "Speaker"

    if model_role == "speaker":
        # The model describes the target now; the human will guess next.
        human_context = listener_image
        model_context = speaker_image
        current_state['conversation'].extend([
            f"TURN: {current_state['turn'] + 1}/12",
            f"Guess the target image given the speaker's description. ",
        ])
        model_message = get_model_response(model, current_state['adapter_name'], processor, index_to_token,
                                           model_role, model_context, target_image=target_image)
        current_state['conversation'].append(f"Model: {model_message}")
        current_state['conversation'].append("You: The target is Image ")
    else:
        # The human describes next; tell them which image is the target.
        human_context = speaker_image
        model_context = listener_image
        target_idx = human_context.index(target_image)
        current_state['conversation'].extend([
            f"TURN: {current_state['turn'] + 1}/12",
            f"Generate a description for the target image. Your target is Image {target_idx + 1}",
        ])

    return human_context, current_state['conversation'], human_role, turn_message, acc_message, current_state


def get_current_images(current_history):
    """Return the human's context for the current turn.

    That is the listener context when the model speaks, otherwise the
    speaker context.
    """
    turn = current_history['turn']
    image_role_pairs = current_history['image_role_pairs']
    speaker_image, listener_image, target_image, model_role = image_role_pairs[turn]
    human_context = listener_image if model_role == "speaker" else speaker_image
    return human_context


def get_human_role(current_history):
    """Return the human's role for the current turn ("Listener" or "Speaker")."""
    turn = current_history['turn']
    image_role_pairs = current_history['image_role_pairs']
    speaker_image, listener_image, target_image, model_role = image_role_pairs[turn]
    return "Listener" if model_role == "speaker" else "Speaker"


def create_app():
    """Build the Gradio Blocks UI for the tangram reference game.

    Returns:
        The ``gr.Blocks`` app (not yet queued or launched).
    """
    with gr.Blocks(css=css) as app:
        # Per-session game history (see initialize_interaction); {} when no game is active.
        game_history = gr.State(value={})

        gr.Markdown("# Tangram Reference Game")
        gr.Markdown(
            '### You will be playing a sequence of reference games against a model. To start a game, first select whether ' +
            'you wish to play against our initial trained model ("Initial System") or our model at the end of deployment ("Final System") ' +
            'and press the "Start Game" button. There will be 12 rounds of reference games. You will take on a "listener" or a "speaker" role at each round.'
        )
        gr.Markdown(
            '### In the speaker role, you will be assigned a target image. Your goal will be to describe this image (via a message in the textbox) ' +
            'so that your partner can guess what it is.'
        )
        gr.Markdown(
            '### In the listener role, you will be given a description. Your goal will be ' +
            'to select the image that the description best describes (by clicking on the relevant button).'
        )
        gr.Markdown(
            '### Press "Send" to submit your action in either role and make the game proceed.'
        )

        with gr.Row():
            model_iteration = gr.Radio(["Initial System", "Final System"], label="Model Iteration")
            start_btn = gr.Button("Start Game")

        with gr.Row():
            current_role = gr.Textbox(label="YOUR ROLE")
            current_turn = gr.Textbox(label="TURN")
            accuracy = gr.Textbox(label="FINAL ACCURACY")

        with gr.Row():
            image_output = gr.Gallery(
                label="CONTEXT",
                show_label=False,
                elem_id="gallery",
                columns=5,
                rows=2,
                object_fit="contain",
                height="250px",
                allow_preview=False,
                container=True
            )

        with gr.Row():
            conversation_output = gr.Textbox(label="Interaction History")
            with gr.Column():
                user_input = gr.Textbox(label="Your Message as Speaker", interactive=False)
                radio_buttons = gr.Radio(
                    label="Your Guess as Listener",
                    elem_classes="radio-group",
                    choices=list(range(1, 11)),
                    interactive=False,
                )
                send_btn = gr.Button("Send", interactive=False)

        processor = get_processor()
        index_to_token = get_index_to_token()

        def _gallery_items(images):
            """Map image file names to (path, caption) pairs for the gallery."""
            return [(f"tangram_pngs/{img}", f"Image {i+1}") for i, img in enumerate(images)]

        def start_interaction(model_iteration):
            """Start a new game; returns one value per output wired to start_btn."""
            if model_iteration is None:
                return [], "Please select a model iteration.", "", "", "", gr.update(interactive=False), \
                    gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), {}

            current_history = initialize_interaction(model_iteration)

            # Unpack the relevant items
            images = get_current_images(current_history)
            conversation = current_history["conversation"]
            role = get_human_role(current_history)
            human_listener = role == "Listener"

            turn_number = current_history['turn'] + 1
            turn_msg = f"{turn_number}/12"
            acc_msg = "0/0"
            return _gallery_items(images), "\n".join(conversation), role, turn_msg, acc_msg, \
                gr.update(interactive=not human_listener), gr.update(interactive=human_listener), \
                gr.update(interactive=True), gr.update(interactive=False), current_history

        def send_message(message, radio_choice, current_state):
            """Submit the human's action and advance the game.

            Returns one value per output wired to send_btn.
            """
            no_change = gr.update()

            # No active game: the state is {} before a game starts and after
            # progress_game() has played the last turn.
            if not current_state or current_state['turn'] == len(current_state['image_role_pairs']):
                return [], no_change, no_change, no_change, no_change, gr.update(interactive=False), \
                    gr.update(interactive=False), gr.update(interactive=False), \
                    gr.update(interactive=True, value=None), {}

            # Ignore a Send with nothing to submit (int("") would otherwise
            # crash for a listener, and an empty description is meaningless).
            human_listener = get_human_role(current_state) == "Listener"
            if human_listener and radio_choice is None:
                return (no_change,) * 9 + (current_state,)
            if not human_listener and not (message or "").strip():
                return (no_change,) * 9 + (current_state,)

            # Regular game progress
            user_output = message if radio_choice is None else radio_choice
            images, conversation, role, turn, acc_message, current_state = progress_game(
                user_output, processor, index_to_token, current_state
            )
            human_listener = role == "Listener"
            # progress_game returns {} once the final turn has been played.
            game_over = not current_state

            return _gallery_items(images), "\n".join(conversation), role, turn, acc_message, \
                gr.update(interactive=not human_listener and not game_over, value=""), \
                gr.update(interactive=human_listener and not game_over, value=None), \
                gr.update(interactive=not game_over), gr.update(interactive=game_over), current_state

        start_btn.click(
            start_interaction,
            inputs=[model_iteration],
            outputs=[image_output, conversation_output, current_role, current_turn, accuracy,
                     user_input, radio_buttons, send_btn, model_iteration, game_history],
            queue=False
        )

        send_btn.click(
            send_message,
            inputs=[user_input, radio_buttons, game_history],
            outputs=[image_output, conversation_output, current_role, current_turn, accuracy,
                     user_input, radio_buttons, send_btn, model_iteration, game_history],
            queue=True
        )

    return app


app = create_app()
app.queue()
app.launch()