import gradio as gr
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
import os
import pandas as pd
import numpy as np


def _output_file_name(url, excel_file):
    """Return the path the updated workbook is written to.

    Uses the second-to-last ``/``-separated segment of *url* plus ``.xlsx``.
    Falls back to *excel_file* when *url* is not a string or has fewer than
    two segments (e.g. an empty string).
    """
    try:
        return url.split("/")[-2] + ".xlsx"
    except (AttributeError, IndexError, TypeError):
        return excel_file


def _format_content(values, source_columns):
    """Render ``"<column>: <value>"`` for each source column, separated by blank lines.

    *values* is anything indexable by column name (a row Series or a dict).
    """
    return "\n\n".join(f"{column_name}: {values[column_name]}" for column_name in source_columns)


def _ask_mistral(client, model, content, prompt):
    """Send *content* and then *prompt* as two user messages; return the first reply's text."""
    messages = [
        ChatMessage(role="user", content=f"Using the following content: {content}"),
        ChatMessage(role="user", content=prompt),
    ]
    chat_response = client.chat(model=model, messages=messages)
    return chat_response.choices[0].message.content


def chat_with_mistral(source_cols, dest_col, prompt, tdoc_name, excel_file, url):
    """Fill *dest_col* of an Excel sheet with Mistral replies and save the result.

    Each processed row's *source_cols* cells are formatted as
    ``"<column>: <value>"`` lines and sent to ``mistral-small`` together with
    *prompt*. The reply text is stored in *dest_col* of that row.

    Args:
        source_cols: Column names whose values are sent as context. ``None``
            or an empty selection means no columns.
        dest_col: Column that receives the replies. It is reset to ``""`` for
            every row before processing.
        prompt: Instruction sent after the content.
        tdoc_name: If not ``""``, only the first row whose ``File`` column
            equals this value is processed. Otherwise, every row whose
            source cells are not all missing is processed.
        excel_file: Input workbook. It is also the output path when *url*
            yields no name.
        url: Its second-to-last ``/`` segment names the output file.

    Returns:
        ``(file_name, df.head(5))``: the output path and a preview of the
        updated sheet.

    Raises:
        KeyError: If ``MISTRAL_API_KEY`` is not set, or a named column is
            missing.
    """
    df = pd.read_excel(excel_file)
    api_key = os.environ["MISTRAL_API_KEY"]
    model = "mistral-small"  # Use "Mistral-7B-v0.2" for "mistral-tiny"
    client = MistralClient(api_key=api_key)

    # A cleared multiselect may arrive as None; treat that as "no columns".
    source_columns = list(source_cols) if source_cols else []
    # NOTE(review): this blanks dest_col for every row before any source cell
    # is read. A dest_col that is also a source column is therefore sent as "".
    # In single-document mode, the other rows' existing values are discarded.
    # Confirm this is intended.
    df[dest_col] = ""
    file_name = _output_file_name(url, excel_file)

    if tdoc_name != '':
        matches = df.index[df['File'] == tdoc_name]
        if len(matches) == 0:
            # Nothing matched: no API call and no file is written.
            return file_name, df.head(5)
        target = matches[0]
        first_match = {column_name: df.at[target, column_name] for column_name in source_columns}
        content = _format_content(first_match, source_columns)
        # Write straight into df (no copy + df.update), avoiding the
        # SettingWithCopy warning and update()'s dtype side effects.
        df.at[target, dest_col] = _ask_mistral(client, model, content, prompt)
        df.to_excel(file_name, index=False)
        return file_name, df.head(5)

    for index, row in df.iterrows():
        # Skip rows whose selected source cells are all missing (NaN/None/NaT).
        if row[source_columns].isna().all():
            continue
        content = _format_content(row, source_columns)
        df.at[index, dest_col] = _ask_mistral(client, model, content, prompt)

    df.to_excel(file_name, index=False)
    return file_name, df.head(5)


def get_columns(file):
    """Populate the column pickers and a preview from an uploaded Excel file.

    Returns five values:

    - three ``gr.update`` objects whose choices are the sheet's column names;
    - a fourth ``gr.update`` whose choices are the column names plus ``""``;
    - the sheet's first five rows.

    When *file* is None, all four pickers get empty choices and the preview
    is an empty DataFrame.
    """
    if file is None:
        return (
            gr.update(choices=[]),
            gr.update(choices=[]),
            gr.update(choices=[]),
            gr.update(choices=[]),
            pd.DataFrame(),
        )
    df = pd.read_excel(file)
    columns = list(df.columns)
    return (
        gr.update(choices=columns),
        gr.update(choices=columns),
        gr.update(choices=columns),
        gr.update(choices=columns + [""]),
        df.head(5),
    )