|
import re |
|
|
|
import openai |
|
import pandas as pd |
|
from tqdm import tqdm |
|
|
|
|
|
class AutoTab: |
|
def __init__( |
|
self, |
|
in_file_path: str, |
|
out_file_path: str, |
|
max_examples: int, |
|
model_name: str, |
|
api_key: str, |
|
base_url: str, |
|
generation_config: dict, |
|
save_every: int, |
|
instruction: str, |
|
): |
|
self.in_file_path = in_file_path |
|
self.out_file_path = out_file_path |
|
self.max_examples = max_examples |
|
self.model_name = model_name |
|
self.api_key = api_key |
|
self.base_url = base_url |
|
self.generation_config = generation_config |
|
self.save_every = save_every |
|
self.instruction = instruction |
|
|
|
|
|
|
|
def load_excel(self) -> tuple[pd.DataFrame, list, list]: |
|
"""Load the Excel file and identify input and output fields.""" |
|
df = pd.read_excel(self.in_file_path) |
|
input_fields = [col for col in df.columns if col.startswith("[Input] ")] |
|
output_fields = [col for col in df.columns if col.startswith("[Output] ")] |
|
return df, input_fields, output_fields |
|
|
|
|
|
|
|
def openai_request(self, query: str) -> str: |
|
"""Make a request to an OpenAI-format API.""" |
|
client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) |
|
response = client.chat.completions.create( |
|
model=self.model_name, |
|
messages=[{"role": "user", "content": query}], |
|
**self.generation_config, |
|
) |
|
str_response = response.choices[0].message.content.strip() |
|
return str_response |
|
|
|
|
|
|
|
def derive_incontext( |
|
self, data: pd.DataFrame, input_columns: list[str], output_columns: list[str] |
|
) -> str: |
|
"""Derive the in-context prompt with angle brackets.""" |
|
n = min(self.max_examples, len(data.dropna(subset=output_columns))) |
|
in_context = "" |
|
for i in range(n): |
|
in_context += "".join( |
|
f"<{col.replace('[Input] ', '')}>{data[col].iloc[i]}</{col.replace('[Input] ', '')}>\n" |
|
for col in input_columns |
|
) |
|
in_context += "".join( |
|
f"<{col.replace('[Output] ', '')}>{data[col].iloc[i]}</{col.replace('[Output] ', '')}>\n" |
|
for col in output_columns |
|
) |
|
in_context += "\n" |
|
self.in_context = in_context |
|
return in_context |
|
|
|
def predict_output( |
|
self, in_context: str, input_data: pd.DataFrame, input_fields: str |
|
): |
|
"""Predict the output values for the given input data using the API.""" |
|
query = ( |
|
self.instruction |
|
+ "\n\n" |
|
+ in_context |
|
+ "".join( |
|
f"<{col.replace('[Input] ', '')}>{input_data[col]}</{col.replace('[Input] ', '')}>\n" |
|
for col in input_fields |
|
) |
|
) |
|
self.query_example = query |
|
output = self.openai_request(query) |
|
return output |
|
|
|
def extract_fields( |
|
self, response: str, output_columns: list[str] |
|
) -> dict[str, str]: |
|
"""Extract fields from the response text based on output columns.""" |
|
extracted = {} |
|
for col in output_columns: |
|
field = col.replace("[Output] ", "") |
|
match = re.search(f"<{field}>(.*?)</{field}>", response) |
|
extracted[col] = match.group(1) if match else "" |
|
return extracted |
|
|
|
|
|
|
|
def run(self): |
|
data, input_fields, output_fields = self.load_excel() |
|
in_context = self.derive_incontext(data, input_fields, output_fields) |
|
|
|
num_existed_examples = len(data.dropna(subset=output_fields)) |
|
|
|
for i in tqdm(range(num_existed_examples, len(data))): |
|
prediction = self.predict_output(in_context, data.iloc[i], input_fields) |
|
extracted_fields = self.extract_fields(prediction, output_fields) |
|
for field_name in output_fields: |
|
data.at[i, field_name] = extracted_fields.get(field_name, "") |
|
if i % self.save_every == 0: |
|
data.to_excel(self.out_file_path, index=False) |
|
self.data = data |
|
data.to_excel(self.out_file_path, index=False) |
|
print(f"Results saved to {self.out_file_path}") |
|
|