MGTbenchmark / app.py
Evan73's picture
modify app.py
732150f
raw
history blame
3.61 kB
import streamlit as st
import os
import json
import re
import datasets
import tiktoken
import zipfile
from pathlib import Path
# Module-level tiktoken encoder ("cl100k_base"); MGTHuman.truncate_text uses it
# to encode text into tokens and decode the truncated token list back to text.
encoding = tiktoken.get_encoding("cl100k_base")
# MGTHuman class
class MGTHuman(datasets.GeneratorBasedBuilder):
    """Dataset builder subclass providing helpers to fetch and truncate texts.

    Declares one builder config per data source ("human" plus several model
    names) and two helper methods used by the viewer: ``truncate_text`` and
    ``get_text_by_index``.
    """

    VERSION = datasets.Version("1.0.0")
    BUILDER_CONFIGS = [
        datasets.BuilderConfig(name="human", version=VERSION, description="This part of human data"),
        datasets.BuilderConfig(name="Moonshot", version=VERSION, description="Data from the Moonshot model"),
        datasets.BuilderConfig(name="gpt35", version=VERSION, description="Data from the gpt-3.5-turbo model"),
        datasets.BuilderConfig(name="Llama3", version=VERSION, description="Data from the Llama3 model"),
        datasets.BuilderConfig(name="Mixtral", version=VERSION, description="Data from the Mixtral model"),
        datasets.BuilderConfig(name="Qwen", version=VERSION, description="Data from the Qwen model"),
    ]
    DEFAULT_CONFIG_NAME = "human"

    def truncate_text(self, text, max_tokens=2048):
        """Return ``text`` limited to at most ``max_tokens`` tokens.

        Args:
            text: The string to truncate.
            max_tokens: Maximum number of tokens (per the module-level
                ``encoding``) to keep.

        Returns:
            ``text`` unchanged when it fits within ``max_tokens``. Otherwise
            the decoded first ``max_tokens`` tokens, cut just after the last
            '。' if one is present, else just after the last '.'; if neither
            occurs, the decoded prefix is returned as is.
        """
        tokens = encoding.encode(text, allowed_special={'<|endoftext|>'})
        if len(tokens) <= max_tokens:
            return text

        truncated_text = encoding.decode(tokens[:max_tokens])
        # Prefer the full-width period; fall back to the ASCII period only
        # when no full-width period exists in the truncated text.
        last_period_idx = truncated_text.rfind('。')
        if last_period_idx == -1:
            last_period_idx = truncated_text.rfind('.')
        if last_period_idx != -1:
            truncated_text = truncated_text[:last_period_idx + 1]
        return truncated_text

    def get_text_by_index(self, filepath, index):
        """Return the ``index``-th non-empty text stored in a JSON file.

        Args:
            filepath: Path to a JSON file whose top-level value is iterated
                row by row; each usable row is a dict with a "text" string.
            index: Zero-based position among the rows that have non-blank
                text. Blank rows and rows without a string "text" field are
                not counted.

        Returns:
            The selected text truncated to 2048 tokens, or a fixed
            out-of-range message when fewer than ``index + 1`` usable rows
            exist.
        """
        # Read as UTF-8 explicitly: the platform default encoding is
        # locale-dependent and may fail to decode non-ASCII text.
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        count = 0
        for row in data:
            text = row.get("text") if isinstance(row, dict) else None
            # Skip malformed rows the same way blank texts are skipped,
            # instead of raising KeyError/AttributeError.
            if not isinstance(text, str) or not text.strip():
                continue
            if count == index:
                return self.truncate_text(text, max_tokens=2048)
            count += 1
        return "Index 超出范围,请输入有效的数字。"
# Streamlit UI
st.title("MGTHuman Dataset Viewer")

# Upload a ZIP archive that contains the JSON data files.
uploaded_folder = st.file_uploader("上传包含 JSON 文件的 ZIP 文件夹", type=["zip"])

if uploaded_folder:
    folder_path = Path("temp")
    folder_path.mkdir(exist_ok=True)

    # Keep only the final path component of the client-supplied name so the
    # archive is always written inside the temp directory.
    zip_path = folder_path / Path(uploaded_folder.name).name
    with open(zip_path, "wb") as f:
        f.write(uploaded_folder.getbuffer())

    archive_ok = True
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(folder_path)
    except zipfile.BadZipFile:
        # A file with a .zip extension is not necessarily a valid archive.
        archive_ok = False
        st.error("上传的文件不是有效的 ZIP 文件。")

    # Collect every JSON file and group the paths by domain, where the domain
    # is the part of the file stem before "_task3".
    category = {}
    if archive_ok:
        # rglob: a zipped folder extracts into a sub-directory, so the JSON
        # files are usually not directly under temp/. Sorting makes the
        # "first file" of each domain deterministic.
        for json_file in sorted(folder_path.rglob("*.json")):
            # Skip macOS archive metadata, which is not real JSON data.
            if "__MACOSX" in json_file.parts or json_file.name.startswith("._"):
                continue
            domain = json_file.stem.split('_task3')[0]
            category.setdefault(domain, []).append(str(json_file))

        if not category:
            st.warning("ZIP 文件中未找到 JSON 文件。")

    # Only render the selection widgets when there is something to select;
    # with no options the selectbox yields None and the lookup below would
    # raise KeyError.
    if category:
        # Show the available domains.
        st.write("可用的数据种类:", list(category.keys()))

        # Let the user pick a domain.
        selected_domain = st.selectbox("选择数据种类", options=list(category.keys()))

        # Zero-based index of the text to display.
        index_to_view = st.number_input("输入要查看的文本序号", min_value=0, step=1)

        if st.button("显示文本"):
            # Display from the first file of the selected domain.
            file_to_display = category[selected_domain][0]
            # NOTE(review): selected_domain is derived from file names, while
            # MGTHuman.BUILDER_CONFIGS lists source/model names — confirm the
            # builder accepts this value as `name`.
            mgt_human = MGTHuman(name=selected_domain)
            text = mgt_human.get_text_by_index(file_to_display, index=index_to_view)
            st.write("对应的文本内容为:", text)
# Remove the temporary directory that holds the uploaded files.
if st.button("清除文件"):
    import shutil

    # The directory only exists once an upload has been processed; guard the
    # removal so a click with nothing to delete does not raise
    # FileNotFoundError.
    if Path("temp").is_dir():
        shutil.rmtree("temp")
    st.write("临时文件已清除。")