File size: 2,956 Bytes
189d2e5
 
 
 
 
 
c4dec28
189d2e5
 
 
e277252
189d2e5
 
 
 
 
 
 
 
 
 
 
 
 
e277252
 
 
 
 
189d2e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e277252
 
 
 
 
189d2e5
 
 
 
 
 
 
 
 
e277252
189d2e5
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
import base64
import requests
import io
from PIL import Image
import numpy as np
import os

# Endpoint the transcription request is POSTed to. Read from the environment
# at import time, so a missing `URL` variable raises KeyError on startup.
URL = os.environ['URL']

# Message shown in both output components when there is nothing to transcribe.
_EMPTY_CANVAS_MESSAGE = "Please write something first."

# Upper bound (seconds) on the HTTP call so a stalled endpoint cannot hang the UI.
_REQUEST_TIMEOUT_SECONDS = 60


def _encode_png_base64(image_data):
    """Return *image_data* (a NumPy array) encoded as a base64 PNG string."""
    pil_image = Image.fromarray(image_data.astype(np.uint8))
    buffered = io.BytesIO()
    pil_image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def sketch_to_text(image, api_key):
    """Transcribe a handwritten formula image into LaTeX via the remote model.

    Args:
        image: The editor value; expected to be a dict whose ``'composite'``
            entry is a NumPy array (it is converted with ``astype(np.uint8)``).
            ``None``, a non-dict, or a missing/``None`` composite is treated
            as "nothing drawn yet".
        api_key: API key supplied by the user. When empty, the ``API_KEY``
            environment variable is used instead.

    Returns:
        A 2-tuple ``(raw, rendered)`` holding the same text twice, one value
        per output component the handler is wired to. On any failure the text
        is a human-readable message instead of LaTeX.
    """
    # Always return two values: the handler feeds two output components, and
    # returning a bare string here makes the UI fail on an empty canvas.
    if not isinstance(image, dict) or image.get('composite') is None:
        return _EMPTY_CANVAS_MESSAGE, _EMPTY_CANVAS_MESSAGE

    img_str = _encode_png_base64(image['composite'])

    # Prefer the user's key; fall back to the server-side key if configured.
    API_KEY = api_key if api_key else os.environ.get('API_KEY')
    if not API_KEY:
        error = "Error: no API key available. Please enter your API key."
        return error, error

    # Prepare the API request
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}"
    }
    payload = {
        "model": "Llama-3.2-11B-Vision-Instruct",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Please read the forumla in the image and transcribe it into latex code, respond with code only, make sure the code is enclosed within a pair of $$"
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{img_str}"
                        }
                    }
                ]
            }
        ],
        "max_tokens": 300
    }

    # Make the API request; network failures are reported, not raised, so the
    # UI shows a message instead of a stack trace.
    try:
        response = requests.post(
            URL,
            headers=headers,
            json=payload,
            timeout=_REQUEST_TIMEOUT_SECONDS,
        )
    except requests.RequestException as exc:
        error = f"Error: request failed, {exc}"
        return error, error

    if response.status_code != 200:
        error = f"Error: {response.status_code}, {response.text}"
        return error, error

    # Parse the body once and guard against an unexpected shape.
    try:
        latex = response.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError):
        error = f"Error: unexpected response format, {response.text}"
        return error, error

    return latex, latex

# Build the Gradio UI: a title, an optional API-key field, a drawing canvas,
# and two side-by-side outputs (raw text and Markdown-rendered).
with gr.Blocks() as iface:
    gr.Markdown("# Pix2Latex")
    gr.Markdown("Transcribing handwritten forumla into latex with Llama3.2 instruct. [Powered by SambaNova Cloud, Get Your API Key Here](https://cloud.sambanova.ai/apis)")

    # Optional user-supplied key; sketch_to_text falls back to the API_KEY
    # environment variable when this is left empty.
    with gr.Row():
        api_key = gr.Textbox(label="API Key", type="password", placeholder="(Optional) Enter your API key here for more availability. ")
        
    # Canvas the user writes the formula on; its value is passed to
    # sketch_to_text, which reads the 'composite' entry.
    with gr.Column(scale=1):
        input_image = gr.ImageEditor()

    # Both outputs receive the same text returned by sketch_to_text:
    # output1 shows it verbatim, output2 passes it through a Markdown component.
    with gr.Row():
        with gr.Column(scale=1):
            output1 = gr.Textbox(label="Raw")
        with gr.Column(scale=1):
            output2 = gr.Markdown(label="Rendered")
    
    # Run the transcription on every change event of the canvas; the handler
    # must return one value per component listed in `outputs`.
    input_image.change(fn=sketch_to_text, inputs=[input_image, api_key], outputs=[output1, output2])
    
    gr.Markdown("How to use: 1. write your formula in the box above. 2. See it in real time, have fun doing math?")

# Launch the app (runs at import time; there is no __main__ guard).
iface.launch()