thorunna
commited on
Commit
•
5cdaab3
1
Parent(s):
62753ff
Script to run model updated
Browse files- run_model.py +29 -183
run_model.py
CHANGED
@@ -1,188 +1,34 @@
|
|
1 |
-
"""
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
import
|
6 |
-
import
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
-
logging.info(f"Device: {device}")
|
17 |
-
|
18 |
-
# Prompts for the different tasks
|
19 |
-
START_PROMPT_TASK1 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
|
20 |
-
END_PROMPT_TASK1 = "Sérðu eitthvað sem mætti betur fara í textanum? Búðu til lista af öllum slíkum tilvikum þar sem hver lína tilgreinir hver villan er, hvar hún er, og hvað væri gert í staðinn fyrir villuna.\n\n"
|
21 |
-
|
22 |
-
START_PROMPT_TASK2 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.Ég er með tvær útgáfur af textanum, A og B, og önnur þeirra gæti verið betri en hin á einhvern hátt, t.d. hvað varðar stafsetningu, málfræði o.s.frv.\nHér er texti A:\n\n"
|
23 |
-
MIDDLE_PROMPT_TASK2 = "Hér er texti B:\n\n"
|
24 |
-
END_PROMPT_TASK2 = "Hvorn textann líst þér betur á?\n\n"
|
25 |
-
|
26 |
-
START_PROMPT_TASK3 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
|
27 |
-
END_PROMPT_TASK3 = "Reyndu nú að laga textann þannig að hann líti betur út, eins og þér finnst best við hæfi.\n\n"
|
28 |
-
|
29 |
-
START_PROMPT_TASK = {
|
30 |
-
1: START_PROMPT_TASK1,
|
31 |
-
2: START_PROMPT_TASK2,
|
32 |
-
3: START_PROMPT_TASK3,
|
33 |
}
|
34 |
-
END_PROMPT_TASK = {1: END_PROMPT_TASK1, 2: END_PROMPT_TASK2, 3: END_PROMPT_TASK3}
|
35 |
-
|
36 |
-
SEP = "\n\n"
|
37 |
-
|
38 |
-
|
39 |
-
def set_seed(seed):
|
40 |
-
"""Set the random seed for reproducibility."""
|
41 |
-
torch.manual_seed(seed)
|
42 |
-
if torch.cuda.is_available():
|
43 |
-
torch.cuda.manual_seed_all(seed)
|
44 |
-
torch.backends.cudnn.deterministic = True
|
45 |
-
torch.backends.cudnn.benchmark = False
|
46 |
-
random.seed(seed)
|
47 |
-
|
48 |
-
|
49 |
-
def tokenize_data(tokenizer, data, task, max_length):
|
50 |
-
"""Tokenize the data and return the input_ids and attention_mask."""
|
51 |
-
tokenized_start = tokenizer(START_PROMPT_TASK[task])["input_ids"]
|
52 |
-
tokenized_end = tokenizer(END_PROMPT_TASK[task])["input_ids"]
|
53 |
-
if task == 2:
|
54 |
-
tokenized_middle = tokenizer(MIDDLE_PROMPT_TASK2)["input_ids"]
|
55 |
-
|
56 |
-
# Tokenize the data
|
57 |
-
tokenized_data = []
|
58 |
-
if task == 1 or task == 3:
|
59 |
-
for sentence in data:
|
60 |
-
tokenized_sentence = tokenizer(sentence + SEP)["input_ids"]
|
61 |
-
|
62 |
-
# Concatenate the tokenized data
|
63 |
-
concatted_data = (
|
64 |
-
[tokenizer.bos_token_id]
|
65 |
-
+ tokenized_start
|
66 |
-
+ tokenized_sentence
|
67 |
-
+ tokenized_end
|
68 |
-
)
|
69 |
-
|
70 |
-
# Truncate the data
|
71 |
-
concatted_data = concatted_data[:max_length]
|
72 |
-
|
73 |
-
tokenized_data.append(concatted_data)
|
74 |
-
elif task == 2:
|
75 |
-
for line in data:
|
76 |
-
data_a = line["a"]
|
77 |
-
data_b = line["b"]
|
78 |
-
tokenized_sentence_a = tokenizer(data_a + SEP)["input_ids"]
|
79 |
-
tokenized_sentence_b = tokenizer(data_b + SEP)["input_ids"]
|
80 |
-
|
81 |
-
# Concatenate the tokenized data
|
82 |
-
concatted_data = (
|
83 |
-
[tokenizer.bos_token_id]
|
84 |
-
+ tokenized_start
|
85 |
-
+ tokenized_sentence_a
|
86 |
-
+ tokenized_middle
|
87 |
-
+ tokenized_sentence_b
|
88 |
-
+ tokenized_end
|
89 |
-
)
|
90 |
-
|
91 |
-
# Truncate the data
|
92 |
-
concatted_data = concatted_data[:max_length]
|
93 |
-
|
94 |
-
tokenized_data.append(concatted_data)
|
95 |
-
|
96 |
-
return tokenized_data
|
97 |
-
|
98 |
-
|
99 |
-
def run_model_on_data(model_path, tokenizer_name, arguments):
|
100 |
-
"""Run the model on the data and save the predictions to a file."""
|
101 |
-
# Load the model and tokenizer
|
102 |
-
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
|
103 |
-
model.to(device)
|
104 |
-
model.eval()
|
105 |
-
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
106 |
-
|
107 |
-
# Load the data
|
108 |
-
if arguments.task == 1 or arguments.task == 3:
|
109 |
-
with open(arguments.input_file, "r") as file:
|
110 |
-
data = file.read().splitlines()
|
111 |
-
elif arguments.task == 2:
|
112 |
-
with open(arguments.input_file, "r") as file:
|
113 |
-
data = file.read().splitlines()
|
114 |
-
data = [json.loads(line) for line in data]
|
115 |
-
|
116 |
-
# Tokenize the data
|
117 |
-
data_tokenized = tokenize_data(
|
118 |
-
tokenizer, data, arguments.task, tokenizer.model_max_length
|
119 |
-
)
|
120 |
-
logging.info(f"Number of examples: {len(data_tokenized)}")
|
121 |
-
|
122 |
-
# Run the model on the data
|
123 |
-
predictions = []
|
124 |
-
progress_bar = tqdm.tqdm(total=len(data_tokenized), desc="Running model on data")
|
125 |
-
|
126 |
-
for input_ids in data_tokenized:
|
127 |
-
progress_bar.update(1)
|
128 |
-
|
129 |
-
# Generate the predictions
|
130 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
131 |
-
input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).to(device)
|
132 |
-
output = model.generate(
|
133 |
-
input_ids=input_ids_tensor, max_new_tokens=500, num_return_sequences=1
|
134 |
-
)
|
135 |
-
|
136 |
-
# Only get the part of the prediction that was generated
|
137 |
-
prediction = tokenizer.decode(
|
138 |
-
output[0][len(input_ids) :], skip_special_tokens=True
|
139 |
-
)
|
140 |
-
predictions.append(prediction)
|
141 |
-
|
142 |
-
progress_bar.close()
|
143 |
-
|
144 |
-
# Save the predictions to a file
|
145 |
-
with open(arguments.output_file, "w") as file:
|
146 |
-
if arguments.task == 1:
|
147 |
-
# We want to include the original text in the output file
|
148 |
-
counter = 0
|
149 |
-
for prediction in predictions:
|
150 |
-
file.write(data[counter] + "\n")
|
151 |
-
file.write(prediction.split("\n\n")[0] + "\n\n")
|
152 |
-
counter += 1
|
153 |
-
else:
|
154 |
-
for prediction in predictions:
|
155 |
-
file.write(prediction.split("\n\n")[0] + "\n")
|
156 |
-
|
157 |
-
logging.info(f"Predictions written to file: {arguments.output_file}")
|
158 |
|
159 |
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
parser.add_argument(
|
164 |
-
"--task",
|
165 |
-
type=int,
|
166 |
-
choices=(1,2,3),
|
167 |
-
required=True,
|
168 |
-
help="The task type (1, 2, or 3)",
|
169 |
-
)
|
170 |
-
parser.add_argument(
|
171 |
-
"--input-file",
|
172 |
-
type=str,
|
173 |
-
required=True,
|
174 |
-
help="The path to the input file with data to be corrected",
|
175 |
-
)
|
176 |
-
parser.add_argument(
|
177 |
-
"--output-file",
|
178 |
-
type=str,
|
179 |
-
required=True,
|
180 |
-
help="The path to the output file where the corrected data will be saved",
|
181 |
-
)
|
182 |
-
args = parser.parse_args()
|
183 |
|
184 |
-
model_path = "."
|
185 |
-
tokenizer_name = "AI-Sweden-Models/gpt-sw3-6.7b"
|
186 |
|
187 |
-
|
188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Script for running the model using the Hugging Face endpoint. An authorized Hugging Face API key is required.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import requests
|
6 |
+
import os
|
7 |
+
|
8 |
+
API_URL = "https://otaf5w2ge8huxngl.eu-west-1.aws.endpoints.huggingface.cloud"
|
9 |
+
# Set your Hugging Face API key as an environment variable
|
10 |
+
api_key = os.environ.get("HF_API_KEY")
|
11 |
+
headers = {
|
12 |
+
"Accept": "application/json",
|
13 |
+
"Authorization": f"Bearer {api_key}",
|
14 |
+
"Content-Type": "application/json",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
+
def query(payload):
|
19 |
+
response = requests.post(API_URL, headers=headers, json=payload)
|
20 |
+
return response.json()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
|
|
|
|
22 |
|
23 |
+
output = query(
|
24 |
+
{
|
25 |
+
"inputs": "", # Can be left empty.
|
26 |
+
"input_a": "<text A>", # Required for all tasks.
|
27 |
+
"input_b": "<text B>", # Required for task 2 but not for task 1 or 3.
|
28 |
+
"task": 1 | 2 | 3, # Choose the task number.
|
29 |
+
"parameters": {
|
30 |
+
# Can be left empty
|
31 |
+
},
|
32 |
+
}
|
33 |
+
)
|
34 |
+
print(output)
|