#!/usr/bin/env python
"""Gradio app for browsing the JMMMU dataset one question at a time."""

import ast
import os

import datasets
import gradio as gr
import PIL.Image

DESCRIPTION = """\
# [JMMMU](https://huggingface.co/datasets/JMMMU/JMMMU) dataset viewer
"""

# Whether the "Answer and Explanation" / "Question Details" accordions start expanded.
SHOW_ANSWER = os.getenv("SHOW_ANSWER", "false").lower() == "true"
SHOW_QUESTION_DETAILS = os.getenv("SHOW_QUESTION_DETAILS", "false").lower() == "true"

# Dataset config names; each one is loaded as its own "test" split below.
SUBJECTS = [
    "Accounting",
    "Agriculture",
    "Architecture_and_Engineering",
    "Basic_Medical_Science",
    "Biology",
    "Chemistry",
    "Clinical_Medicine",
    "Computer_Science",
    "Design",
    "Diagnostics_and_Laboratory_Medicine",
    "Economics",
    "Electronics",
    "Energy_and_Power",
    "Finance",
    "Japanese_Art",
    "Japanese_Heritage",
    "Japanese_History",
    "Manage",
    "Marketing",
    "Materials",
    "Math",
    "Mechanical_Engineering",
    "Music",
    "Pharmacy",
    "Physics",
    "Psychology",
    "Public_Health",
    "World_History",
]

# Number of `image_N` columns scanned per question (image_1 .. image_7).
MAX_IMAGES = 7

# All subjects are loaded eagerly at import time, keyed by subject name.
ds = {subject: datasets.load_dataset("JMMMU/JMMMU", name=subject, split="test") for subject in SUBJECTS}


def set_default_subject() -> str:
    """Return the subject selected when the page loads (the first entry, "Accounting")."""
    return SUBJECTS[0]


def _extract_images(row: dict) -> list[PIL.Image.Image]:
    """Collect the leading non-None `image_1` .. `image_7` values from an already-fetched row."""
    images = []
    for image_id in range(1, MAX_IMAGES + 1):
        image = row[f"image_{image_id}"]
        if image is None:
            # Stop at the first empty slot; later slots are not inspected
            # (assumes images fill the slots contiguously).
            break
        images.append(image)
    return images


def get_images(subject: str, question_index: int) -> list[PIL.Image.Image]:
    """Return the images attached to question `question_index` of `subject`.

    The row is read from the dataset once; indexing a `datasets.Dataset` builds the
    whole row (every image column included), so it is not repeated per image slot.
    """
    return _extract_images(ds[subject][question_index])


def update_subject(
    subject: str,
) -> tuple[
    gr.Textbox,  # Number of Questions
    gr.Slider,  # Question Index
    gr.Gallery,  # Images
    gr.Textbox,  # Question
    gr.Textbox,  # Options
    gr.Textbox,  # Answer
    gr.Textbox,  # Explanation
    gr.Textbox,  # Topic Difficulty
    gr.Textbox,  # Question Type
    gr.Textbox,  # Subfield
]:
    """Reset the viewer to the first question of `subject`.

    Returns updates for the question count, the question-index slider (re-ranged to
    this subject's size and reset to 0), followed by every update produced by
    `update_question(subject, 0)`.
    """
    num_questions = len(ds[subject])
    return (
        gr.Textbox(value=str(num_questions)),  # Number of Questions
        gr.Slider(label="Question Index", minimum=0, maximum=num_questions - 1, step=1, value=0),  # Question Index
        *update_question(subject, 0),
    )


def update_question(
    subject: str, question_index: int
) -> tuple[
    gr.Gallery,  # Images
    gr.Textbox,  # Question
    gr.Textbox,  # Options
    gr.Textbox,  # Answer
    gr.Textbox,  # Explanation
    gr.Textbox,  # Topic Difficulty
    gr.Textbox,  # Question Type
    gr.Textbox,  # Subfield
]:
    """Build component updates that display question `question_index` of `subject`.

    The returned tuple is ordered to match the `outputs` list wired to the
    question-index slider: images, question, options, answer, explanation,
    topic difficulty, question type, subfield.
    """
    # Coerce defensively: the index comes from a Slider and is used as a dataset index.
    question = ds[subject][int(question_index)]  # fetch the row once and reuse it below
    # `options` holds a string containing a Python literal sequence; literal_eval parses
    # it without executing arbitrary code.
    options = ast.literal_eval(question["options"])
    # Label options "A.", "B.", "C.", ... one per line.
    options_str = "\n".join(f"{chr(65 + i)}. {option}" for i, option in enumerate(options))
    images = _extract_images(question)
    return (
        # Up to 2 columns; max(1, ...) avoids columns=0 for a question without images.
        gr.Gallery(value=images, columns=max(1, min(len(images), 2))),  # Images
        gr.Textbox(value=question["question"]),  # Question
        gr.Textbox(value=options_str),  # Options
        gr.Textbox(value=question["answer"]),  # Answer
        gr.Textbox(value=question["explanation"]),  # Explanation
        gr.Textbox(value=question["topic_difficulty"]),  # Topic Difficulty
        gr.Textbox(value=question["question_type"]),  # Question Type
        gr.Textbox(value=question["subfield"]),  # Subfield
    )


with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Deliberately starts on a subject different from set_default_subject()'s return
        # value, so the demo.load update below is a real value change and fires
        # subject.change, which populates every other component.
        # NOTE(review): relies on Gradio firing .change for programmatic updates only when
        # the value actually differs — confirm before changing either default.
        subject = gr.Dropdown(label="Subject", choices=SUBJECTS, value=SUBJECTS[-1])
        question_count = gr.Textbox(label="Number of Questions")
    with gr.Group():
        # Range (minimum/maximum/step) is set by update_subject once a subject is chosen.
        question_index = gr.Slider(label="Question Index")
        with gr.Row():
            with gr.Column():
                question = gr.Textbox(label="Question")
                options = gr.Textbox(label="Options")
            with gr.Column():
                images = gr.Gallery(label="Images", object_fit="scale-down")
    with gr.Accordion("Answer and Explanation", open=SHOW_ANSWER):
        with gr.Row():
            answer = gr.Textbox(label="Answer")
            explanation = gr.Textbox(label="Explanation")
    with gr.Accordion("Question Details", open=SHOW_QUESTION_DETAILS):
        with gr.Row():
            topic_difficulty = gr.Textbox(label="Topic Difficulty")
            question_type = gr.Textbox(label="Question Type")
            subfield = gr.Textbox(label="Subfield")

    subject.change(
        fn=update_subject,
        inputs=subject,
        outputs=[
            question_count,
            question_index,
            images,
            question,
            options,
            answer,
            explanation,
            topic_difficulty,
            question_type,
            subfield,
        ],
        queue=False,
        api_name=False,
    )
    # `.input` (not `.change`) so that update_subject resetting the slider to 0 does not
    # trigger a second, redundant update_question call.
    question_index.input(
        fn=update_question,
        inputs=[subject, question_index],
        outputs=[images, question, options, answer, explanation, topic_difficulty, question_type, subfield],
        queue=False,
        api_name=False,
    )
    demo.load(fn=set_default_subject, outputs=subject, queue=False, api_name=False)

if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False)