Evan73 committed on
Commit
6f96319
1 Parent(s): 479384b

update the app.py

Files changed (1)
  1. app.py +178 -137
app.py CHANGED
@@ -1,143 +1,184 @@
- import streamlit as st
- import os
- import json
- import re
- import datasets
- import tiktoken
- import zipfile
- from pathlib import Path
-
- # Define the tiktoken encoder
- encoding = tiktoken.get_encoding("cl100k_base")
-
- _CITATION = """\
- @InProceedings{huggingface:dataset,
- title = {MGT detection},
- author={Trustworthy AI Lab},
- year={2024}
- }
- """
-
- _DESCRIPTION = """\
- For detecting machine generated text.
- """
-
- _HOMEPAGE = ""
- _LICENSE = ""
-
- # MGTHuman class
- class MGTHuman(datasets.GeneratorBasedBuilder):
-     VERSION = datasets.Version("1.0.0")
-     BUILDER_CONFIGS = [
-         datasets.BuilderConfig(name="human", version=VERSION, description="This part of human data"),
-         datasets.BuilderConfig(name="Moonshot", version=VERSION, description="Data from the Moonshot model"),
-         datasets.BuilderConfig(name="gpt35", version=VERSION, description="Data from the gpt-3.5-turbo model"),
-         datasets.BuilderConfig(name="Llama3", version=VERSION, description="Data from the Llama3 model"),
-         datasets.BuilderConfig(name="Mixtral", version=VERSION, description="Data from the Mixtral model"),
-         datasets.BuilderConfig(name="Qwen", version=VERSION, description="Data from the Qwen model"),
-     ]
-     DEFAULT_CONFIG_NAME = "human"
-
-     def _info(self):
-         features = datasets.Features(
-             {
-                 "id": datasets.Value("int32"),
-                 "text": datasets.Value("string"),
-                 "file": datasets.Value("string"),
-             }
-         )
-         return datasets.DatasetInfo(
-             description=_DESCRIPTION,
-             features=features,
-             homepage=_HOMEPAGE,
-             license=_LICENSE,
-             citation=_CITATION,
-         )
-
-     def truncate_text(self, text, max_tokens=2048):
-         tokens = encoding.encode(text, allowed_special={'<|endoftext|>'})
-         if len(tokens) > max_tokens:
-             tokens = tokens[:max_tokens]
-             truncated_text = encoding.decode(tokens)
-             last_period_idx = truncated_text.rfind('。')
-             if last_period_idx == -1:
-                 last_period_idx = truncated_text.rfind('.')
-             if last_period_idx != -1:
-                 truncated_text = truncated_text[:last_period_idx + 1]
-             return truncated_text
-         else:
-             return text
-
-     def get_text_by_index(self, filepath, index, cut_tokens=False, max_tokens=2048):
-         count = 0
-         with open(filepath, 'r') as f:
-             data = json.load(f)
-             for row in data:
-                 if not row["text"].strip():
-                     continue
-                 if count == index:
-                     text = row["text"]
-                     if cut_tokens:
-                         text = self.truncate_text(text, max_tokens)
-                     return text
-                 count += 1
-         return "Index out of range; please enter a valid number."
-
-     def count_entries(self, filepath):
-         """Return the total number of entries in the file, used to build the index range dynamically."""
-         count = 0
-         with open(filepath, 'r') as f:
-             data = json.load(f)
-             for row in data:
-                 if row["text"].strip():
-                     count += 1
-         return count
-
- # Streamlit UI
- st.title("MGTHuman Dataset Viewer")
-
- # Upload a ZIP archive containing JSON files
- uploaded_folder = st.file_uploader("Upload a ZIP archive containing JSON files", type=["zip"])
- if uploaded_folder:
-     folder_path = Path("temp")
-     folder_path.mkdir(exist_ok=True)
-     zip_path = folder_path / uploaded_folder.name
-     with open(zip_path, "wb") as f:
-         f.write(uploaded_folder.getbuffer())
-
-     with zipfile.ZipFile(zip_path, 'r') as zip_ref:
-         zip_ref.extractall(folder_path)
-
-     # Recursively collect all JSON files and group them by domain
-     category = {}
-     for json_file in folder_path.rglob("*.json"):  # use rglob to find all JSON files recursively
-         domain = json_file.stem.split('_task3')[0]
-         category.setdefault(domain, []).append(str(json_file))
-
-     # Show a dropdown of the available domains
-     if category:
-         selected_domain = st.selectbox("Select a data category", options=list(category.keys()))
-
-         # Take the first file for this domain and count its entries
-         file_to_display = category[selected_domain][0]
-         mgt_human = MGTHuman(name=selected_domain)
-         total_entries = mgt_human.count_entries(file_to_display)
-         st.write(f"Valid index range: 0 to {total_entries - 1}")
-
-         # Enter an index to view the corresponding text
-         index_to_view = st.number_input("Enter the index of the text to view", min_value=0, max_value=total_entries - 1, step=1)
-
-         # Checkbox to choose whether to truncate the text by tokens
-         cut_tokens = st.checkbox("Truncate text by token count", value=False)
-
-         if st.button("Show text"):
-             text = mgt_human.get_text_by_index(file_to_display, index=index_to_view, cut_tokens=cut_tokens)
-             st.write("Corresponding text:", text)
-     else:
-         st.write("No JSON files found; please check the ZIP file structure.")
-
- # Clean up the temporary directory for uploaded files
- if st.button("Clear files"):
-     import shutil
-     shutil.rmtree("temp")
-     st.write("Temporary files cleared.")
+ # import streamlit as st
+ # import os
+ # import json
+ # import re
+ # import datasets
+ # import tiktoken
+ # import zipfile
+ # from pathlib import Path
+
+ # # Define the tiktoken encoder
+ # encoding = tiktoken.get_encoding("cl100k_base")
+
+ # _CITATION = """\
+ # @InProceedings{huggingface:dataset,
+ # title = {MGT detection},
+ # author={Trustworthy AI Lab},
+ # year={2024}
+ # }
+ # """
+
+ # _DESCRIPTION = """\
+ # For detecting machine generated text.
+ # """
+
+ # _HOMEPAGE = ""
+ # _LICENSE = ""
+
+ # # MGTHuman class
+ # class MGTHuman(datasets.GeneratorBasedBuilder):
+ #     VERSION = datasets.Version("1.0.0")
+ #     BUILDER_CONFIGS = [
+ #         datasets.BuilderConfig(name="human", version=VERSION, description="This part of human data"),
+ #         datasets.BuilderConfig(name="Moonshot", version=VERSION, description="Data from the Moonshot model"),
+ #         datasets.BuilderConfig(name="gpt35", version=VERSION, description="Data from the gpt-3.5-turbo model"),
+ #         datasets.BuilderConfig(name="Llama3", version=VERSION, description="Data from the Llama3 model"),
+ #         datasets.BuilderConfig(name="Mixtral", version=VERSION, description="Data from the Mixtral model"),
+ #         datasets.BuilderConfig(name="Qwen", version=VERSION, description="Data from the Qwen model"),
+ #     ]
+ #     DEFAULT_CONFIG_NAME = "human"
+
+ #     def _info(self):
+ #         features = datasets.Features(
+ #             {
+ #                 "id": datasets.Value("int32"),
+ #                 "text": datasets.Value("string"),
+ #                 "file": datasets.Value("string"),
+ #             }
+ #         )
+ #         return datasets.DatasetInfo(
+ #             description=_DESCRIPTION,
+ #             features=features,
+ #             homepage=_HOMEPAGE,
+ #             license=_LICENSE,
+ #             citation=_CITATION,
+ #         )
+
+ #     def truncate_text(self, text, max_tokens=2048):
+ #         tokens = encoding.encode(text, allowed_special={'<|endoftext|>'})
+ #         if len(tokens) > max_tokens:
+ #             tokens = tokens[:max_tokens]
+ #             truncated_text = encoding.decode(tokens)
+ #             last_period_idx = truncated_text.rfind('。')
+ #             if last_period_idx == -1:
+ #                 last_period_idx = truncated_text.rfind('.')
+ #             if last_period_idx != -1:
+ #                 truncated_text = truncated_text[:last_period_idx + 1]
+ #             return truncated_text
+ #         else:
+ #             return text
+
+ #     def get_text_by_index(self, filepath, index, cut_tokens=False, max_tokens=2048):
+ #         count = 0
+ #         with open(filepath, 'r') as f:
+ #             data = json.load(f)
+ #             for row in data:
+ #                 if not row["text"].strip():
+ #                     continue
+ #                 if count == index:
+ #                     text = row["text"]
+ #                     if cut_tokens:
+ #                         text = self.truncate_text(text, max_tokens)
+ #                     return text
+ #                 count += 1
+ #         return "Index out of range; please enter a valid number."
+
+ #     def count_entries(self, filepath):
+ #         """Return the total number of entries in the file, used to build the index range dynamically."""
+ #         count = 0
+ #         with open(filepath, 'r') as f:
+ #             data = json.load(f)
+ #             for row in data:
+ #                 if row["text"].strip():
+ #                     count += 1
+ #         return count
+
+ # # Streamlit UI
+ # st.title("MGTHuman Dataset Viewer")
+
+ # # Upload a ZIP archive containing JSON files
+ # uploaded_folder = st.file_uploader("Upload a ZIP archive containing JSON files", type=["zip"])
+ # if uploaded_folder:
+ #     folder_path = Path("temp")
+ #     folder_path.mkdir(exist_ok=True)
+ #     zip_path = folder_path / uploaded_folder.name
+ #     with open(zip_path, "wb") as f:
+ #         f.write(uploaded_folder.getbuffer())
+
+ #     with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+ #         zip_ref.extractall(folder_path)
+
+ #     # Recursively collect all JSON files and group them by domain
+ #     category = {}
+ #     for json_file in folder_path.rglob("*.json"):  # use rglob to find all JSON files recursively
+ #         domain = json_file.stem.split('_task3')[0]
+ #         category.setdefault(domain, []).append(str(json_file))
+
+ #     # Show a dropdown of the available domains
+ #     if category:
+ #         selected_domain = st.selectbox("Select a data category", options=list(category.keys()))
+
+ #         # Take the first file for this domain and count its entries
+ #         file_to_display = category[selected_domain][0]
+ #         mgt_human = MGTHuman(name=selected_domain)
+ #         total_entries = mgt_human.count_entries(file_to_display)
+ #         st.write(f"Valid index range: 0 to {total_entries - 1}")
+
+ #         # Enter an index to view the corresponding text
+ #         index_to_view = st.number_input("Enter the index of the text to view", min_value=0, max_value=total_entries - 1, step=1)
+
+ #         # Checkbox to choose whether to truncate the text by tokens
+ #         cut_tokens = st.checkbox("Truncate text by token count", value=False)
+
+ #         if st.button("Show text"):
+ #             text = mgt_human.get_text_by_index(file_to_display, index=index_to_view, cut_tokens=cut_tokens)
+ #             st.write("Corresponding text:", text)
+ #     else:
+ #         st.write("No JSON files found; please check the ZIP file structure.")
+
+ # # Clean up the temporary directory for uploaded files
+ # if st.button("Clear files"):
+ #     import shutil
+ #     shutil.rmtree("temp")
+ #     st.write("Temporary files cleared.")
+
+ import streamlit as st
+ from transformers import pipeline
+
+ # Initialize the Hugging Face text classifier
+ @st.cache_resource  # cache the model so it is not reloaded on every rerun
+ def load_model():
+     # Use a pre-trained text-classification model from the Hugging Face Hub;
+     # replace it with a more suitable model if necessary.
+     classifier = pipeline("text-classification", model="roberta-base-openai-detector")
+     return classifier
+
+ st.title("Machine-Generated Text Detector")
+ st.write("Enter a text snippet, and I will analyze it to determine whether it was likely written by a human or generated by a machine.")
+
+ # Load the model
+ classifier = load_model()
+
+ # Input text
+ input_text = st.text_area("Enter text here:", height=150)
+
+ # Button to trigger detection
+ if st.button("Analyze"):
+     if input_text:
+         # Make a prediction (the underlying RoBERTa model reads at most 512 tokens)
+         result = classifier(input_text)
+
+         # Extract the label and confidence score
+         label = result[0]['label']
+         score = result[0]['score'] * 100  # convert to a percentage for readability
+
+         # Display the result. NOTE: label names come from the model's id2label
+         # config; this detector is generally reported to emit "Real"/"Fake",
+         # so "Fake" is treated as the machine-generated class in addition to
+         # the raw "LABEL_1" id assumed by the original check.
+         if label in ("LABEL_1", "Fake"):
+             st.write("**Result:** This text is likely **Machine-Generated**.")
+         else:
+             st.write("**Result:** This text is likely **Human-Written**.")
+
+         # Display the confidence score
+         st.write(f"**Confidence Score:** {score:.2f}%")
+     else:
+         st.write("Please enter some text for analysis.")