DCWIR-Offcial-Demo / utils.py
PFEemp2024's picture
solving GPU error for previous version
4a1df2e
raw
history blame
6.36 kB
import random
from difflib import Differ
from textattack.attack_recipes import BAEGarg2019
from textattack.datasets import Dataset
from textattack.models.wrappers import HuggingFaceModelWrapper
from findfile import find_files
from flask import Flask
from textattack import Attacker
class ModelWrapper(HuggingFaceModelWrapper):
    """Expose a model with an ``infer`` method through TextAttack's wrapper API.

    Calling the wrapper on a batch of texts runs
    ``model.infer(text, print_result=False, **kwargs)`` once per text and
    collects the ``"probs"`` entry of each result, in input order.
    """

    def __init__(self, model):
        # NOTE(review): the HuggingFaceModelWrapper initializer is not called,
        # so any attributes it would set (e.g. a tokenizer) are absent here;
        # confirm no base-class method that relies on them is ever used.
        self.model = model

    def __call__(self, text_inputs, **kwargs):
        """Return the ``"probs"`` entry of ``infer`` for each text in *text_inputs*."""
        return [
            self.model.infer(text, print_result=False, **kwargs)["probs"]
            for text in text_inputs
        ]
class SentAttacker:
    """Hold a TextAttack ``Attacker`` built for *model* from an attack recipe.

    Args:
        model: object with an ``infer(text, print_result=False, **kwargs)``
            method returning a mapping that has a ``"probs"`` entry; it is
            wrapped in ``ModelWrapper`` before being handed to the recipe.
        recipe_class: attack-recipe class whose ``build`` classmethod accepts
            the wrapped model. Defaults to ``BAEGarg2019``.

    Attributes:
        attacker: the ``textattack.Attacker`` combining the built recipe with
            a one-example placeholder dataset.
    """

    def __init__(self, model, recipe_class=BAEGarg2019):
        model_wrapper = ModelWrapper(model)
        recipe = recipe_class.build(model_wrapper)
        # A single placeholder example (empty text, label 0) is passed as the
        # dataset. NOTE(review): presumably Attacker needs a dataset at
        # construction even though real inputs are attacked elsewhere via
        # self.attacker -- confirm against the callers.
        placeholder_dataset = Dataset([("", 0)])
        self.attacker = Attacker(recipe, placeholder_dataset)
def diff_texts(text1, text2):
    """Word-level diff of *text1* against *text2*.

    Both texts are split on whitespace and compared with ``difflib.Differ``.

    Returns:
        A list of ``(word, tag)`` pairs in diff order, where ``tag`` is
        ``"-"`` for a word present only in *text1*, ``"+"`` for a word present
        only in *text2*, and ``None`` for a word common to both.
    """
    differ = Differ()
    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in differ.compare(text1.split(), text2.split())
        # Differ emits "? ..." guide lines (caret markers) when two compared
        # words are similar, e.g. "hello" vs "hallo"; they are not words of
        # either text and must not leak into the result.
        if not token.startswith("?")
    ]
def get_ensembled_tad_results(results):
    """Return the most frequent ``"label"`` value across *results* (majority vote).

    Args:
        results: iterable of mappings, each with a hashable ``"label"`` entry.
            It is consumed once.

    Returns:
        The label with the highest count. When several labels tie for the
        highest count, the tied label whose first occurrence comes latest is
        returned (tie-break kept for backward compatibility).

    Raises:
        ValueError: if *results* is empty.
    """
    from collections import Counter

    counts = Counter(r["label"] for r in results)
    if not counts:
        raise ValueError("get_ensembled_tad_results() received no results")
    top = max(counts.values())
    # Counter keeps first-seen insertion order, so the last tied label is the
    # one first encountered latest.
    return [label for label, n in counts.items() if n == top][-1]
# Path substrings that mark files which are not plain-text dataset splits
# (source code, docs, logs, archives, checkpoints, images, derived outputs).
_FILTER_KEY_WORDS = [
    ".py",
    ".md",
    "readme",
    "log",
    "result",
    "zip",
    ".state_dict",
    ".model",
    ".png",
    "acc_",
    "f1_",
    ".origin",
    ".adv",
    ".csv",
]


def _load_test_examples(dataset, search_path="./", task="text_defense"):
    """Read all ``(text, label)`` pairs from the test files of *dataset*.

    Files are located with ``find_files(search_path, [dataset, "test", task],
    exclude_key=...)``. Adversarial/defense/inference outputs, training files
    and everything matching ``_FILTER_KEY_WORDS`` are excluded.

    Each non-blank line must look like ``<text>$LABEL$<integer label>``.
    Text and label are whitespace-stripped, and the label is parsed as ``int``.

    Raises:
        ValueError: on a non-blank line without ``$LABEL$`` or with a
            non-integer label.
    """
    data_files = find_files(
        search_path,
        [dataset, "test", task],
        exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
        + _FILTER_KEY_WORDS,
    )
    examples = []
    for data_file in data_files:
        with open(data_file, mode="r", encoding="utf8") as fin:
            for line in fin:
                if not line.strip():
                    # Skip blank lines (e.g. a trailing newline at EOF) instead
                    # of failing the two-way unpacking below.
                    continue
                # Split on the LAST marker so a "$LABEL$" occurring inside the
                # text itself does not break parsing.
                text, label = line.rsplit("$LABEL$", 1)
                examples.append((text.strip(), int(label.strip())))
    return examples


def _get_random_example(dataset, search_path="./", task="text_defense"):
    """Return a uniformly random ``(text, label)`` test example of *dataset*.

    Raises:
        IndexError: if no test examples were found (the same exception type
            ``random.choice`` raises on an empty sequence).
    """
    examples = _load_test_examples(dataset, search_path, task)
    if not examples:
        raise IndexError(
            f"No test examples found for dataset {dataset!r} "
            f"(task {task!r}) under {search_path!r}"
        )
    return random.choice(examples)


def get_sst2_example():
    """Return a random ``(text, label)`` pair from the local ``sst2`` test files."""
    return _get_random_example("sst2")


def get_agnews_example():
    """Return a random ``(text, label)`` pair from the local ``agnews`` test files."""
    return _get_random_example("agnews")


def get_amazon_example():
    """Return a random ``(text, label)`` pair from the local ``amazon`` test files."""
    return _get_random_example("amazon")


def get_imdb_example():
    """Return a random ``(text, label)`` pair from the local ``imdb`` test files."""
    return _get_random_example("imdb")