Spaces:
Running
Running
import streamlit as st | |
import os | |
import json | |
import re | |
import datasets | |
import tiktoken | |
import zipfile | |
from pathlib import Path | |
# Shared tokenizer: MGTHuman.truncate_text counts and cuts text in tokens of
# this module-level encoder.
_ENCODING_NAME = "cl100k_base"
encoding = tiktoken.get_encoding(_ENCODING_NAME)
# MGTHuman class
class MGTHuman(datasets.GeneratorBasedBuilder):
    """Dataset builder declaring one config per data source.

    The "human" config covers human-written data; every other config covers
    text from the named model. The viewer UI only uses the two text helpers
    ``truncate_text`` and ``get_text_by_index``, which need no builder state.

    NOTE(review): the datasets hooks ``_info``, ``_split_generators`` and
    ``_generate_examples`` are not defined here, so this class declares
    configs but cannot build a dataset by itself -- confirm this is intended.
    """

    VERSION = datasets.Version("1.0.0")
    BUILDER_CONFIGS = [
        datasets.BuilderConfig(name="human", version=VERSION, description="This part of human data"),
        datasets.BuilderConfig(name="Moonshot", version=VERSION, description="Data from the Moonshot model"),
        datasets.BuilderConfig(name="gpt35", version=VERSION, description="Data from the gpt-3.5-turbo model"),
        datasets.BuilderConfig(name="Llama3", version=VERSION, description="Data from the Llama3 model"),
        datasets.BuilderConfig(name="Mixtral", version=VERSION, description="Data from the Mixtral model"),
        datasets.BuilderConfig(name="Qwen", version=VERSION, description="Data from the Qwen model"),
    ]
    DEFAULT_CONFIG_NAME = "human"

    def truncate_text(self, text, max_tokens=2048):
        """Cut *text* to at most *max_tokens* tokens of the module-level ``encoding``.

        When truncation happens, the result is trimmed back to the last
        Chinese full stop ("。"), or, if there is none, to the last ".", so
        that it ends on a sentence boundary where possible.

        Args:
            text: Text to truncate.
            max_tokens: Token budget.

        Returns:
            *text* unchanged if it fits in the budget, otherwise the
            truncated text.
        """
        tokens = encoding.encode(text, allowed_special={'<|endoftext|>'})
        if len(tokens) <= max_tokens:
            return text
        # Byte-level BPE can split one multi-byte character (common for CJK)
        # across tokens. errors="ignore" drops the dangling partial bytes at
        # the cut instead of emitting a U+FFFD replacement character. The
        # prefix is otherwise valid UTF-8, so nothing else is lost.
        truncated_text = encoding.decode(tokens[:max_tokens], errors="ignore")
        # Prefer the Chinese full stop; fall back to "." only when no "。"
        # is present (a "." may also be a decimal point in Chinese text).
        cut = truncated_text.rfind('。')
        if cut == -1:
            cut = truncated_text.rfind('.')
        if cut != -1:
            truncated_text = truncated_text[:cut + 1]
        return truncated_text

    def get_text_by_index(self, filepath, index):
        """Return the *index*-th non-blank text of a JSON file, truncated.

        The file is expected to hold a JSON array of objects with a "text"
        field. Rows whose text is missing, null or whitespace-only are
        skipped and do not count towards *index*.

        Args:
            filepath: Path of the JSON file.
            index: Zero-based position among the non-blank rows.

        Returns:
            The selected text truncated to 2048 tokens, or an out-of-range
            message (in Chinese) when no row has that position.
        """
        # Read explicitly as UTF-8: the data holds Chinese text, and the
        # platform default encoding (e.g. cp1252/cp936 on Windows) would fail
        # or garble it.
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        texts = (row.get("text") or "" for row in data)
        non_blank = (t for t in texts if t.strip())
        for position, text in enumerate(non_blank):
            if position == index:
                return self.truncate_text(text, max_tokens=2048)
        return "Index 超出范围,请输入有效的数字。"
# Streamlit UI
import shutil
import tempfile


def _group_json_by_domain(directory):
    """Group the JSON files under *directory* by domain.

    The domain is the part of the file name before "_task3". Files are found
    recursively, because archives often wrap everything in one top-level
    folder. They are sorted, so the first path listed for a domain is
    deterministic. macOS archive metadata (``__MACOSX/`` entries and ``._*``
    files) is skipped because it is not JSON.

    Returns:
        dict mapping domain name -> list of file paths (as strings).
    """
    category = {}
    for json_file in sorted(Path(directory).rglob("*.json")):
        if "__MACOSX" in json_file.parts or json_file.name.startswith("._"):
            continue
        domain = json_file.stem.split('_task3')[0]
        category.setdefault(domain, []).append(str(json_file))
    return category


st.title("MGTHuman Dataset Viewer")

# Upload a ZIP archive that contains the JSON files.
uploaded_folder = st.file_uploader("上传包含 JSON 文件的 ZIP 文件夹", type=["zip"])
if uploaded_folder:
    folder_path = Path("temp")
    folder_path.mkdir(exist_ok=True)
    # Each browser session extracts into its own directory under temp/. The
    # directory is emptied before every extraction, so JSON files from earlier
    # uploads (or from other sessions of the same app) are not listed.
    if "extract_dir" not in st.session_state:
        st.session_state["extract_dir"] = tempfile.mkdtemp(dir=folder_path)
    extract_dir = Path(st.session_state["extract_dir"])
    shutil.rmtree(extract_dir, ignore_errors=True)
    extract_dir.mkdir(parents=True, exist_ok=True)

    category = {}
    try:
        # The uploaded file is file-like, so the archive is read directly
        # rather than first being saved under a client-supplied file name.
        with zipfile.ZipFile(uploaded_folder, 'r') as zip_ref:
            zip_ref.extractall(extract_dir)
    except zipfile.BadZipFile:
        st.error("无法解压上传的文件,请确认它是有效的 ZIP 文件。")
    else:
        category = _group_json_by_domain(extract_dir)
        if not category:
            st.warning("ZIP 文件中没有找到 JSON 文件。")

    if category:
        # Show the available domains and let the user choose one.
        st.write("可用的数据种类:", list(category.keys()))
        selected_domain = st.selectbox("选择数据种类", options=list(category.keys()))
        # Index of the text to view.
        index_to_view = st.number_input("输入要查看的文本序号", min_value=0, step=1)
        if st.button("显示文本"):
            # Only the first (sorted) file of the selected domain is shown.
            file_to_display = category[selected_domain][0]
            # The text helpers use no builder state, so __init__ is bypassed.
            # The datasets builder __init__ rejects config names missing from
            # BUILDER_CONFIGS (domains come from file names and need not
            # match) and calls _info(), which MGTHuman does not define.
            reader = MGTHuman.__new__(MGTHuman)
            try:
                text = reader.get_text_by_index(file_to_display, index=int(index_to_view))
            except (OSError, ValueError, KeyError, TypeError, AttributeError) as err:
                # Malformed or unreadable JSON: report it instead of crashing.
                st.error(f"读取 {Path(file_to_display).name} 失败: {err}")
            else:
                st.write("对应的文本内容为:", text)

# Remove the temporary upload directory, if it exists.
if st.button("清除文件"):
    if Path("temp").exists():
        shutil.rmtree("temp")
    st.write("临时文件已清除。")