import io

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import spaces
from PIL import Image

from scorer import DSGPromptProcessor

# Marker size used for the graph nodes. The same value is handed to the edge
# drawing call so that arrowheads stop at the node boundary instead of being
# positioned for networkx's much smaller default node size.
_NODE_SIZE = 2000


def draw_colored_graph(dependencies, questions, answers):
    """Render the question dependency graph as a PIL image.

    Args:
        dependencies: Mapping of node id -> iterable of node ids it depends
            on. Keys are converted with ``int()``; dependency ids are used
            as given, so they are expected to already match the integer
            node ids.
        questions: Mapping of node id -> question text used as node label.
            Keys are converted with ``int()`` to form the node ids.
        answers: Mapping with the same keys as ``questions``; a truthy value
            colors the node green, a falsy value colors it red.

    Returns:
        A ``PIL.Image.Image`` opened from an in-memory PNG of the graph.

    Raises:
        KeyError: If a question key is missing from ``answers``, or if a
            node appears only in ``dependencies`` (it then has no color).
    """
    graph = nx.DiGraph()

    # Nodes carry their label and pass/fail color as attributes.
    for node, question in questions.items():
        color = "green" if answers[node] else "red"
        graph.add_node(int(node), label=question, color=color)

    # An edge points from a dependency to the node that depends on it.
    for node, deps in dependencies.items():
        for dep in deps:
            graph.add_edge(dep, int(node))

    # Other layouts such as nx.shell_layout or nx.circular_layout also work.
    pos = nx.spring_layout(graph)
    node_colors = [graph.nodes[node]["color"] for node in graph.nodes()]
    labels = nx.get_node_attributes(graph, "label")

    buf = io.BytesIO()
    # Draw on a dedicated figure: pyplot's implicit "current figure" lives
    # for the whole process, so reusing it would overlay every new graph on
    # the previous ones and never release the figure's memory.
    fig, ax = plt.subplots()
    try:
        nx.draw_networkx_nodes(
            graph,
            pos,
            node_color=node_colors,
            node_size=_NODE_SIZE,
            edgecolors="white",
            ax=ax,
        )
        nx.draw_networkx_edges(
            graph,
            pos,
            arrowstyle="-|>",
            arrows=True,
            arrowsize=20,
            connectionstyle="arc3,rad=0.1",
            node_size=_NODE_SIZE,
            ax=ax,
        )
        nx.draw_networkx_labels(
            graph,
            pos,
            labels,
            font_size=10,
            font_color="black",
            ax=ax,
        )
        ax.axis("off")
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)


processor = DSGPromptProcessor("mistralai/Mixtral-8x7B-Instruct-v0.1")


@spaces.GPU()
def process_image(image, prompt):
    """Score how well ``image`` matches ``prompt`` and visualize the result.

    Runs the processor pipeline (tuples -> dependencies -> questions ->
    reward), thresholds each reward value at 0.5 to decide the node color,
    and draws the dependency graph.

    Args:
        image: Input image; Gradio supplies it as a PIL image
            (``type="pil"`` on the input component).
        prompt: Text prompt the image is checked against.

    Returns:
        A tuple ``(graph_image, text)`` where ``text`` lists the generated
        questions and the reward per question.
    """
    tuples, _ = processor.generate_tuples(prompt)
    dependencies, _ = processor.generate_dependencies(tuples)
    questions, _ = processor.generate_questions(
        prompt, tuples.tuples, dependencies
    )
    reward, sorted_questions = processor.get_reward(
        questions, dependencies, [image]
    )
    # Only one image is passed in, so keep the first entry
    # (presumably the per-question rewards for that image - TODO confirm).
    reward = reward[0]
    print(reward)

    answers = {str(i): v > 0.5 for i, v in enumerate(reward)}
    sorted_questions = {str(i): v for i, v in enumerate(sorted_questions)}
    print(answers, sorted_questions)

    graph_img = draw_colored_graph(dependencies, sorted_questions, answers)
    return (
        graph_img,
        f" Question: \n{questions}. Reward per question: {reward}",
    )


description = """

[Original Paper] [My Github] [Binary VQA Model - Query Answering] [Mixtral 7x8 - Query Generating]

"""

# Custom CSS handed to Gradio; kept as a single whitespace-separated string.
css = (
    " #gen_btn{height: 100%}"
    " #title{text-align: center}"
    " #title h1{font-size: 3em; display:inline-flex; align-items:center}"
    " #title img{width: 100px; margin-right: 0.5em}"
    " #gallery .grid-wrap{height: 10vh} "
)

# Define the Gradio interface
interface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Enter your prompt"),
    ],
    outputs=[
        gr.Image(type="pil", label="Graph Score Image", format="png"),
        gr.Textbox(label="Analyzed Result"),
    ],
    theme=gr.themes.Soft(),
    description=description,
    examples=[
        [
            "examples/input_image.png",
            "A cat with red eyes in the jungle. All tree in the jungle has blue color.",
        ],
    ],
    css=css,
    title="T2I Adherence Scorer based on Davidsonian Scene Graph",
    cache_examples=True,
)

# Launch the Gradio app
interface.launch()