File size: 4,116 Bytes
faa2a07
 
 
 
 
 
 
 
 
 
 
 
 
 
6d3d75e
 
d67b468
6d3d75e
faa2a07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d7c820e
faa2a07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bac9ca
a2d2c17
d67b468
583707e
8d29462
d67b468
29c70aa
 
 
 
 
 
 
 
 
de08eb3
d67b468
de08eb3
 
421058f
29c70aa
de08eb3
 
29c70aa
aedb0d3
85afbda
 
 
 
 
 
 
29c70aa
85afbda
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
Reference
- https://docs.streamlit.io/library/api-reference/layout
- https://github.com/CodingMantras/yolov8-streamlit-detection-tracking/blob/master/app.py
- https://huggingface.co/keremberke/yolov8m-valorant-detection/tree/main
- https://docs.ultralytics.com/usage/python/
"""
import time
import PIL

import streamlit as st
import torch
from ultralyticsplus import YOLO, render_result

from gtts import gTTS
import os
import pygame

from convert import convert_to_braille_unicode, parse_xywh_and_class


def load_model(model_path):
    """Construct a YOLO detector from ``model_path`` and return it.

    ``model_path`` is forwarded unchanged to ``YOLO``; this script calls it
    with the identifier ``"snoop2head/yolov8m-braille"``.
    """
    return YOLO(model_path)


def load_image(image_path):
    """Open an image with Pillow and return it.

    Args:
        image_path: a filesystem path or a binary file-like object. This
            script passes both a path string and a Streamlit upload.

    Returns:
        The opened ``PIL.Image.Image``.
    """
    # ``import PIL`` alone does not import the ``PIL.Image`` submodule, so
    # ``PIL.Image.open`` only works if another import happened to load it.
    # Import the submodule explicitly to make this function self-sufficient.
    from PIL import Image

    return Image.open(image_path)

# title
st.title("Braille Pattern Detection")

# sidebar
st.sidebar.header("Detection Config")

# Sliders are in integer percent (10-75, default 15); convert to fractions.
conf = float(st.sidebar.slider("Class Confidence", 10, 75, 15)) / 100
iou = float(st.sidebar.slider("IoU Threshold", 10, 75, 15)) / 100

model_path = "snoop2head/yolov8m-braille"

try:
    model = load_model(model_path)
    model.overrides["conf"] = conf  # NMS confidence threshold
    model.overrides["iou"] = iou  # NMS IoU threshold
    model.overrides["agnostic_nms"] = False  # NMS class-agnostic
    model.overrides["max_det"] = 1000  # maximum number of detections per image

except Exception as ex:
    print(ex)
    st.write(f"Unable to load model. Check the specified path: {model_path}")
    # Without a model, every later stage fails with a NameError that the
    # broad handlers below misreport as an image-format problem. Stop here.
    st.stop()

# Sidebar uploader; the branch below treats None as "nothing uploaded yet".
source_img = st.sidebar.file_uploader(
    "Choose an image...", type=("jpg", "jpeg", "png", "bmp", "webp")
)
col1, col2 = st.columns(2)

# Left column: show the uploaded image, or fall back to a bundled example.
with col1:
    if source_img is not None:
        image = load_image(source_img)
        st.image(source_img, caption="Uploaded Image", use_column_width=True)
    else:
        default_image_path = "./image/test_1.jpg"
        image = load_image(default_image_path)
        st.image(
            default_image_path,
            caption="Example Input Image",
            use_column_width=True,
        )

# right column of the page body
# These defaults survive a failed detection, so the sections below can tell
# whether there is anything to display or speak.
list_boxes = None
result = ""  # all decoded Braille lines, newline-separated (used for audio)

with col2:
    # Run the inference itself inside the spinner so the user sees progress.
    with st.spinner("Wait for it..."):
        start_time = time.time()
        try:
            with torch.no_grad():
                res = model.predict(
                    image, save=True, save_txt=True, exist_ok=True, conf=conf
                )
            boxes = res[0].boxes  # first image
            # Reverse the last axis of the plotted array before display
            # (presumably BGR -> RGB; confirm against the ultralytics docs).
            res_plotted = res[0].plot()[:, :, ::-1]

            list_boxes = parse_xywh_and_class(boxes)

            st.image(res_plotted, caption="Detected Image", use_column_width=True)
            IMAGE_DOWNLOAD_PATH = "runs/detect/predict/image0.jpg"

        except Exception as ex:
            print(ex)
            st.write("Please upload image with types of JPG, JPEG, PNG ...")
        inference_time = time.time() - start_time


# Only report success and decode the lines when detection actually produced boxes.
if list_boxes is not None:
    st.success(f"Done! Inference time: {inference_time:.2f} seconds")
    st.subheader("Detected Braille Patterns")
    try:
        for box_line in list_boxes:
            # The last column of each row is used as the class index
            # (presumably x, y, w, h, class from parse_xywh_and_class; confirm).
            str_left_to_right = "".join(
                convert_to_braille_unicode(model.names[int(each_class)])
                for each_class in box_line[:, -1]
            )
            result += str_left_to_right + "\n"
            st.write(str_left_to_right)
    except Exception as ex:
        print(ex)
        st.write("Please try again with images with types of JPG, JPEG, PNG ...")

import tempfile

def text_to_speech_gtts(text, lang='en'):
    """Convert text to speech with gTTS and save it to a temporary MP3 file.

    Args:
        text: the text to speak.
        lang: language code passed to ``gTTS`` (default ``'en'``).

    Returns:
        Path of the generated MP3 file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        Any exception from constructing or saving the ``gTTS`` object is
        propagated, after the temporary file has been removed.
    """
    # Only reserve a unique filename here and close the handle before gTTS
    # writes to it: on Windows an open NamedTemporaryFile cannot be reopened
    # by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio_file:
        audio_path = temp_audio_file.name

    try:
        tts = gTTS(text=text, lang=lang)
        tts.save(audio_path)
    except Exception:
        # Don't leave an empty or partial temp file behind. Cleanup is
        # best-effort, and the original error is what the caller needs.
        try:
            os.remove(audio_path)
        except OSError:
            pass
        raise
    return audio_path

try:
    # Skip speech entirely when nothing was decoded; gTTS would fail on empty text.
    if result.strip():
        # Generate the speech file
        audio_file_path = text_to_speech_gtts(result)

        # Read the MP3 into memory and delete the temp file right away, so
        # Streamlit reruns don't accumulate files in the temp directory.
        try:
            with open(audio_file_path, "rb") as audio_file:
                audio_bytes = audio_file.read()
        finally:
            os.remove(audio_file_path)

        # Play the audio in Streamlit
        st.audio(audio_bytes, format="audio/mp3")

        # Offer a download button
        st.download_button(
            label="Download Braille Audio",
            data=audio_bytes,
            file_name="detected_braille.mp3",
            mime="audio/mp3",
        )

except Exception as ex:
    print(ex)
    st.write("An error occurred while processing the audio.")