# import streamlit as st
# import os
# import json
# import re
# import datasets
# import tiktoken
# import zipfile
# from pathlib import Path

# # Define the tiktoken encoder
# encoding = tiktoken.get_encoding("cl100k_base")

# _CITATION = """\
# @InProceedings{huggingface:dataset,
# title = {MGT detection},
# author={Trustworthy AI Lab},
# year={2024}
# }
# """

# _DESCRIPTION = """\
# For detecting machine generated text.
# """

# _HOMEPAGE = ""
# _LICENSE = ""

# # MGTHuman class
# class MGTHuman(datasets.GeneratorBasedBuilder):
#     VERSION = datasets.Version("1.0.0")
#     BUILDER_CONFIGS = [
#         datasets.BuilderConfig(name="human", version=VERSION, description="This part of human data"),
#         datasets.BuilderConfig(name="Moonshot", version=VERSION, description="Data from the Moonshot model"),
#         datasets.BuilderConfig(name="gpt35", version=VERSION, description="Data from the gpt-3.5-turbo model"),
#         datasets.BuilderConfig(name="Llama3", version=VERSION, description="Data from the Llama3 model"),
#         datasets.BuilderConfig(name="Mixtral", version=VERSION, description="Data from the Mixtral model"),
#         datasets.BuilderConfig(name="Qwen", version=VERSION, description="Data from the Qwen model"),
#     ]
#     DEFAULT_CONFIG_NAME = "human"

#     def _info(self):
#         features = datasets.Features(
#             {
#                 "id": datasets.Value("int32"),
#                 "text": datasets.Value("string"),
#                 "file": datasets.Value("string"),
#             }
#         )
#         return datasets.DatasetInfo(
#             description=_DESCRIPTION,
#             features=features,
#             homepage=_HOMEPAGE,
#             license=_LICENSE,
#             citation=_CITATION,
#         )

#     def truncate_text(self, text, max_tokens=2048):
#         tokens = encoding.encode(text, allowed_special={'<|endoftext|>'})
#         if len(tokens) > max_tokens:
#             tokens = tokens[:max_tokens]
#             truncated_text = encoding.decode(tokens)
#             last_period_idx = truncated_text.rfind('。')
#             if last_period_idx == -1:
#                 last_period_idx = truncated_text.rfind('.')
#             if last_period_idx != -1:
#                 truncated_text = truncated_text[:last_period_idx + 1]
#             return truncated_text
#         else:
#             return text

#     def get_text_by_index(self, filepath, index, cut_tokens=False, max_tokens=2048):
#         count = 0
#         with open(filepath, 'r') as f:
#             data = json.load(f)
#         for row in data:
#             if not row["text"].strip():
#                 continue
#             if count == index:
#                 text = row["text"]
#                 if cut_tokens:
#                     text = self.truncate_text(text, max_tokens)
#                 return text
#             count += 1
#         return "Index 超出范围,请输入有效的数字。"
    
#     def count_entries(self, filepath):
#         """返回文件中的总条数,用于动态生成索引范围"""
#         count = 0
#         with open(filepath, 'r') as f:
#             data = json.load(f)
#             for row in data:
#                 if row["text"].strip():
#                     count += 1
#         return count

# # Streamlit UI
# st.title("MGTHuman Dataset Viewer")

# # Upload a ZIP archive containing JSON files
# uploaded_folder = st.file_uploader("Upload a ZIP archive containing JSON files", type=["zip"])
# if uploaded_folder:
#     folder_path = Path("temp")
#     folder_path.mkdir(exist_ok=True)
#     zip_path = folder_path / uploaded_folder.name
#     with open(zip_path, "wb") as f:
#         f.write(uploaded_folder.getbuffer())

#     with zipfile.ZipFile(zip_path, 'r') as zip_ref:
#         zip_ref.extractall(folder_path)

#     # Recursively collect all JSON files and group them by domain
#     category = {}
#     for json_file in folder_path.rglob("*.json"):  # use rglob to find all JSON files recursively
#         domain = json_file.stem.split('_task3')[0]
#         category.setdefault(domain, []).append(str(json_file))

#     # Show a dropdown of the available domains
#     if category:
#         selected_domain = st.selectbox("Select a data category", options=list(category.keys()))
        
#         # Take the first file for this domain and count its entries
#         file_to_display = category[selected_domain][0]
#         mgt_human = MGTHuman(name=selected_domain)
#         total_entries = mgt_human.count_entries(file_to_display)
#         st.write(f"可用的索引范围: 0 到 {total_entries - 1}")
        
#         # Enter an index to view the corresponding text
#         index_to_view = st.number_input("Enter the index of the text to view", min_value=0, max_value=total_entries - 1, step=1)

#         # Checkbox to choose whether to truncate the text by tokens
#         cut_tokens = st.checkbox("Truncate the text by tokens", value=False)
        
#         if st.button("显示文本"):
#             text = mgt_human.get_text_by_index(file_to_display, index=index_to_view, cut_tokens=cut_tokens)
#             st.write("对应的文本内容为:", text)
#     else:
#         st.write("未找到任何 JSON 文件,请检查 ZIP 文件结构。")

# # Clean up the temporary directory for uploaded files
# if st.button("Clear files"):
#     import shutil
#     shutil.rmtree("temp")
#     st.write("临时文件已清除。")

import streamlit as st
from transformers import pipeline

# Initialize Hugging Face text classifier
@st.cache_resource  # Cache the model to avoid reloading
def load_model():
    # Pre-trained RoBERTa-based detector, originally trained to flag GPT-2 output;
    # swap in a different text-classification model here if needed.
    classifier = pipeline("text-classification", model="roberta-base-openai-detector")
    return classifier

st.title("Machine-Generated Text Detector")
st.write("Enter a text snippet, and I will analyze it to determine if it is likely written by a human or generated by a machine.")

# Load the model
classifier = load_model()

# Input text
input_text = st.text_area("Enter text here:", height=150)

# Button to trigger detection
if st.button("Analyze"):
    if input_text.strip():
        # Run the detector; truncate long inputs to the model's maximum sequence length
        result = classifier(input_text, truncation=True)

        # Extract label and confidence score
        label = result[0]['label']
        score = result[0]['score'] * 100  # Convert to a percentage for readability

        # This detector reports "Fake" for machine-generated text and "Real" for
        # human-written text ("LABEL_0" is kept as a fallback for checkpoints that
        # do not define readable label names).
        if label in ("Fake", "LABEL_0"):
            st.write("**Result:** This text is likely **Machine-Generated**.")
        else:
            st.write("**Result:** This text is likely **Human-Written**.")

        # Display confidence score
        st.write(f"**Confidence Score:** {score:.2f}%")
    else:
        st.write("Please enter some text for analysis.")