import tushare as ts
import matplotlib.pyplot as plt
import pandas as pd
import os
import json
from matplotlib.ticker import MaxNLocator
import matplotlib.font_manager as fm
from lab_gpt4_call import send_chat_request, send_chat_request_Azure, send_official_call
# import ast
import re
from tool import *
import tiktoken
import concurrent.futures
import datetime
from PIL import Image
from io import BytesIO
import queue
from threading import Thread
# plt.rcParams['font.family'] = 'Times New Roman'
# plt.rcParams['axes.unicode_minus'] = False
import openai
# `time.sleep` and `np.zeros` are used below but were never imported
# explicitly; import them after `from tool import *` so a star-exported name
# cannot shadow them.
import time
import numpy as np

# Font file loaded for matplotlib text rendering.
font_path = './fonts/SimHei.ttf'
font_prop = fm.FontProperties(fname=font_path)

# The planner answers with "task1={...},task2={...}" and the tool /
# visualization calls with "step1={...},step2={...}". Each {...} must be a
# flat dict: [^}]* stops at the first closing brace.
_TASK_PATTERN = re.compile(r"(task\d+=)(\{[^}]*\})")
_STEP_PATTERN = re.compile(r"(step\d+=)(\{[^}]*\})")

# Few-shot instructions for the final summary request. The "\\{" escapes
# produce a literal backslash + brace, which is what the prompt contains.
_SUMMARY_PROMPT = (
    "请用第一人称总结一下整个任务规划和解决过程,并且输出结果,用[Task]表示每个规划任务,用\\{function\\}表示每个任务里调用的函数."
    "示例1:###我用将您的问题拆分成两个任务,首先第一个任务[stock_task],我依次获取五粮液和贵州茅台从2013年5月20日到2023年5月20日的净资产回报率roe的时序数据. \n"
    "然后第二个任务[visualization_task],我用折线图绘制五粮液和贵州茅台从2013年5月20日到2023年5月20日的净资产回报率,并计算它们的平均值和中位数. \n\n"
    "在第一个任务中我分别使用了2个工具函数\\{get_stock_code\\},\\{get_Financial_data_from_time_range\\}获取到两只股票的roe数据,"
    "在第二个任务里我们使用折线图\\{plot_stock_data\\}工具函数来绘制他们的roe十年走势,"
    "最后并计算了两只股票十年ROE的中位数\\{output_median_col\\}和均值\\{output_mean_col\\}.\n\n"
    "最后贵州茅台的ROE的均值和中位数是\\{\\},{},五粮液的ROE的均值和中位数是\\{\\},\\{\\}###"
    "示例2:###我用将您的问题拆分成两个任务,首先第一个任务[stock_task],我依次获取20230101到20230520这段时间北向资金每日净流入和每日累计流入时序数据,"
    "第二个任务是[visualization_task],因此我在同一张图里同时绘制北向资金20230101到20230520的每日净流入柱状图和每日累计流入的折线图 \n\n"
    "为了完成第一个任务中我分别使用了2个工具函数\\{get_north_south_money\\},\\{calculate_stock_index\\}分别获取到北上资金的每日净流入量和每日的累计净流入量,"
    "第二个任务里我们使用折线图\\{plot_stock_data\\}绘制来两个指标的变化走势.\n\n"
    "最后我们给您提供了包含两个指标的折线图和数据表格."
    "示例3:###我用将您的问题拆分成两个任务,首先第一个任务[economic_task],我爬取了上市公司贵州茅台和其主营业务介绍信息. \n"
    "然后第二个任务[visualization_task],我用表格打印贵州茅台及其相关信息. \n\n"
    "在第一个任务中我分别使用了1个工具函数\\{get_company_info\\} 获取到贵州茅台的公司信息,"
    "在第二个任务里我们使用折线图\\{print_save_table\\}工具函数来输出表格.\n"
)


class MyThread(Thread):
    """Thread that keeps its target's return value (plain Thread discards it).

    Fetch the value with :meth:`get_result` once the thread has finished.
    """

    def __init__(self, target, args):
        super().__init__()
        self.func = target
        self.args = args
        self.result = None
        self.exception = None

    def run(self):
        try:
            self.result = self.func(*self.args)
        except Exception as exc:
            # Keep the failure so get_result() can surface the real error,
            # then re-raise so threading still reports the traceback.
            self.exception = exc
            raise

    def get_result(self):
        """Return the target's result, re-raising the exception it raised (if any)."""
        if self.exception is not None:
            raise self.exception
        return self.result


def parse_and_exe(call_dict, result_buffer, parallel_step: str = '1'):
    """Execute one function call described by *call_dict* and record its result.

    Reads ``'arg' + parallel_step``, ``'function' + parallel_step``,
    ``'output' + parallel_step`` and ``'description' + parallel_step`` from
    *call_dict* (``parallel_step=''`` selects the unsuffixed keys). String
    arguments containing ``'result'`` or ``'input'`` are treated as references
    to earlier outputs and replaced by ``result_buffer[name][0]``.

    :param call_dict: dict holding the arg/function/output/description keys.
    :param result_buffer: dict mapping output names to ``(value, description)``;
        updated in place.
    :param parallel_step: key suffix selecting which sub-call to execute.
    :return: *result_buffer*, now containing ``output: (result, desc)``.
    """
    arg_list = call_dict['arg' + parallel_step]
    # Resolve references such as "result1" / "input2" to their stored values.
    replace_arg_list = [
        result_buffer[item][0] if isinstance(item, str) and ('result' in item or 'input' in item) else item
        for item in arg_list
    ]
    func_name = call_dict['function' + parallel_step]
    output = call_dict['output' + parallel_step]
    desc = call_dict['description' + parallel_step]
    # SECURITY: call_dict is parsed from LLM output and eval() executes it as
    # Python code. Tool functions arrive via `from tool import *`; restricting
    # the lookup to an explicit allow-list of tool names would be safer.
    if func_name == 'loop_rank':
        replace_arg_list[1] = eval(replace_arg_list[1])
    result = eval(func_name)(*replace_arg_list)
    result_buffer[output] = (result, desc)  # e.g. 'result1': (df1, desc)
    return result_buffer


def _read_json(path):
    """Load a JSON file, decoding it as UTF-8 regardless of the platform locale."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _flatten_prompt(prompt_dict, separator):
    """Concatenate ``key + separator + value`` for every entry, each followed by a blank line."""
    return ''.join(key + separator + value + '\n\n' for key, value in prompt_dict.items())


def _parse_steps(response):
    """Return ``(label, dict_text)`` pairs for each ``stepN={...}`` in *response*.

    Only the text before the first ``###`` (the function-call section) is
    searched. Unlike ``a, b = response.split('###')`` this does not fail when
    the model emits no ``###`` or more than one.
    """
    call_steps = response.split('###', 1)[0]
    return _STEP_PATTERN.findall(call_steps)


def _decode_step(step, content):
    """Log one ``stepN={...}`` match and decode its dict text into a call dict."""
    content = content.replace("'", "\"")  # Replace single quotes with double quotes.
    print('==================')
    print("\n\nstep:", step)
    print('content:', content)
    call_dict = json.loads(content)
    print('It has parallel steps:', len(call_dict) / 4)
    return call_dict


def load_tool_and_prompt(tool_lib, tool_prompt):
    """Read a tool-description JSON and a prompt JSON and flatten them into one prompt.

    Every ``name description`` pair from *tool_lib* is appended to the
    prompt's ``"Function Library:"`` entry; each prompt entry is then emitted
    as ``key + ' ' + value`` followed by a blank line.

    :param tool_lib: path to the tool description JSON file.
    :param tool_prompt: path to the prompt JSON file (must contain a
        ``"Function Library:"`` key when *tool_lib* is non-empty).
    :return: the flattened prompt string.
    """
    tools = _read_json(tool_lib)
    prompt = _read_json(tool_prompt)
    library_text = _flatten_prompt(tools, " ")
    if library_text:
        prompt["Function Library:"] = prompt["Function Library:"] + library_text
    return _flatten_prompt(prompt, ' ')


# Callback plumbing: run() reports progress through add_to_queue and
# gradio_interface() streams the queued messages to the UI.
intermediate_results = queue.Queue()  # Create a queue to store intermediate results.


def add_to_queue(intermediate_result):
    """Push a progress message onto the module-level ``intermediate_results`` queue."""
    intermediate_results.put(f"After planing, the intermediate result is {intermediate_result}")


def check_RPM(run_time_list, new_time, max_RPM=1):
    """Rate-limit helper over a window of the three most recent request times.

    :param run_time_list: list of previous request datetimes; mutated in
        place and kept at length 3 once full.
    :param new_time: datetime of the request about to be made.
    :param max_RPM: threshold, in seconds, compared with the age of the
        oldest stored timestamp.
    :return: seconds the caller should sleep (0 means no wait is needed).
    """
    if len(run_time_list) < 3:
        run_time_list.append(new_time)
        return 0
    elapsed = (new_time - run_time_list[0]).seconds
    run_time_list.pop(0)
    run_time_list.append(new_time)
    # NOTE(review): the intent is "at most 3 calls per minute", which suggests
    # comparing against 60 seconds; with the default max_RPM=1 a wait is only
    # requested when the 4th call lands within one second of the oldest.
    if elapsed < max_RPM:
        # Calculate the required rest time.
        sleep_time = 60 - elapsed
        print('sleep_time:', sleep_time)
        return sleep_time
    return 0


def run(instruction, add_to_queue=None, send_chat_request_Azure=send_official_call,
        openai_key='', api_base='', engine=''):
    """Run the whole pipeline for one user instruction.

    The stages are intent detection, task planning, tool calls for the first
    planned task, visualization calls for the second task, and a final LLM
    summary.

    :param instruction: natural-language user request.
    :param add_to_queue: optional callback receiving the accumulated progress
        text after each stage.
    :param send_chat_request_Azure: chat function called as
        ``f(prompt, openai_key=..., api_base=..., engine=...)``. Despite the
        name (which shadows the imported function) it defaults to
        ``send_official_call``.
    :param openai_key: forwarded to the chat function.
    :param api_base: forwarded to the chat function.
    :param engine: forwarded to the chat function.
    :return: ``(output_text, image, output_result, df)``. ``image`` is a PIL
        image of the current matplotlib figure, ``output_result`` is the LLM
        summary, and ``df`` is the last DataFrame produced by the
        visualization stage (an empty DataFrame if none). If a chat request
        raises, that exception object is *returned* instead of the tuple.
    """
    output_text = ''

    ########################## Step-1: Task select ##########################
    current_time = datetime.datetime.now()
    # If the time has not exceeded 3 PM, use yesterday's data.
    reference_day = current_time - datetime.timedelta(days=1) if current_time.hour < 15 else current_time
    formatted_time = reference_day.strftime("%Y-%m-%d")

    print('===============================Intent Detecting===========================================')
    prompt_intent_detection = _flatten_prompt(_read_json('./prompt_lib/prompt_intent_detection.json'), ": ")
    prompt_intent_detection = (prompt_intent_detection + '\n\n' + 'Instruction:' + '今天的日期是' + formatted_time
                               + ', ' + instruction + ' ###New Instruction: ')
    try:
        response = send_chat_request_Azure(prompt_intent_detection, openai_key=openai_key,
                                           api_base=api_base, engine=engine)
    except Exception as e:
        # Errors are returned (not raised) so the caller can display them.
        return e
    new_instruction = response
    print('new_instruction:', new_instruction)
    output_text += '\n======Intent Detecting Stage=====\n\n'
    output_text += new_instruction + '\n\n'
    if add_to_queue is not None:
        add_to_queue(output_text)

    print('===============================Task Planing===========================================')
    output_text += '=====Task Planing Stage=====\n\n'
    prompt_task = _flatten_prompt(_read_json('./prompt_lib/prompt_task.json'), ": ")
    prompt_task = prompt_task + '\n\n' + 'Instruction:' + new_instruction + ' ###Plan:'
    try:
        response = send_chat_request_Azure(prompt_task, openai_key=openai_key, api_base=api_base, engine=engine)
    except Exception as e:
        return e

    # Each "taskN={...}" holds a single {task_name: task_instruction} pair.
    task_plan = {}
    for _, task_text in _TASK_PATTERN.findall(response):
        task = json.loads(task_text.replace("'", "\""))  # Replace single quotes with double quotes.
        task_plan[list(task.keys())[0]] = list(task.values())[0]

    for key, value in task_plan.items():
        print(key, ':', value)
        output_text += key + ': ' + str(value) + '\n'
    output_text += '\n'
    if add_to_queue is not None:
        add_to_queue(output_text)

    ########################## Step-2: Tool select and use ##########################
    print('===============================Tool select and using Stage===========================================')
    output_text += '======Tool select and using Stage======\n\n'
    task_names = list(task_plan.keys())
    task_instructions = list(task_plan.values())
    # The first planned task name (e.g. "stock_task") selects tool_<name>.json / prompt_<name>.json.
    task_name = task_names[0].split('_task')[0]
    task_instruction = task_instructions[0]
    tool_lib = './tool_lib/' + 'tool_' + task_name + '.json'
    tool_prompt = './prompt_lib/' + 'prompt_' + task_name + '.json'
    prompt_flat = load_tool_and_prompt(tool_lib, tool_prompt)
    prompt_flat = prompt_flat + '\n\n' + 'Instruction :' + task_instruction + ' ###Function Call'
    try:
        response = send_chat_request_Azure(prompt_flat, openai_key=openai_key, api_base=api_base, engine=engine)
    except Exception as e:
        return e

    # Expected response shape: "step1={...},step2={...}, ###Output:{...}".
    matches = _parse_steps(response)
    # e.g. {'result1': ('000001.SH', 'Stock code of China Ping An'), 'result2': (df2, '...')}
    result_buffer = {}
    # Output names of the final step (e.g. ['result5', 'result6']) handed to the next task.
    output_buffer = []
    for step, content in matches:
        call_dict = _decode_step(step, content)
        output_text += step + ': ' + str(call_dict) + '\n\n'
        # Each sub-call uses four keys: argN, functionN, outputN, descriptionN.
        n_parallel = len(call_dict) // 4
        # Run this step's sub-calls concurrently on a thread pool.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = {executor.submit(parse_and_exe, call_dict, result_buffer, str(parallel_step))
                       for parallel_step in range(1, n_parallel + 1)}
            for idx, future in enumerate(concurrent.futures.as_completed(futures)):
                try:
                    future.result()
                    print('parallel step:', idx + 1)
                except Exception as exc:
                    print(f'Generated an exception: {exc}')
        if step == matches[-1][0]:
            # Final step of this task: remember which results to pass on.
            output_buffer.extend(call_dict['output' + str(parallel_step)]
                                 for parallel_step in range(1, n_parallel + 1))
    output_text += '\n'
    if add_to_queue is not None:
        add_to_queue(output_text)

    ########################## Step-3: Visualization ##########################
    print('===============================Visualization Stage===========================================')
    output_text += '======Visualization Stage====\n\n'
    task_name = task_names[1].split('_task')[0]  # e.g. visualization_task -> visualization
    task_instruction = task_instructions[1]
    tool_lib = './tool_lib/' + 'tool_' + task_name + '.json'
    tool_prompt = './prompt_lib/' + 'prompt_' + task_name + '.json'

    # Expose the previous task's final outputs as input1, input2, ... so the
    # visualization calls can reference them by name.
    result_buffer_viz = {}
    previous_result = {}
    for position, output_name in enumerate(output_buffer, start=1):
        rename = 'input' + str(position)
        previous_result[rename] = result_buffer[output_name][1]
        result_buffer_viz[rename] = result_buffer[output_name]

    prompt_flat = load_tool_and_prompt(tool_lib, tool_prompt)
    prompt_flat = (prompt_flat + '\n\n' + 'Instruction: ' + task_instruction + ', Previous_result: '
                   + str(previous_result) + ' ###Function Call')
    try:
        response = send_chat_request_Azure(prompt_flat, openai_key=openai_key, api_base=api_base, engine=engine)
    except Exception as e:
        return e

    for step, content in _parse_steps(response):
        call_dict = _decode_step(step, content)
        result_buffer_viz = parse_and_exe(call_dict, result_buffer_viz, parallel_step='')
        output_text += step + ': ' + str(call_dict) + '\n\n'
        if add_to_queue is not None:
            add_to_queue(output_text)

    # Describe every result for the summary prompt and keep the last
    # DataFrame; default to an empty one so `df` is always bound.
    df = pd.DataFrame()
    str_out = output_text + 'Finally result: '
    for value, description in result_buffer_viz.values():
        if isinstance(value, plt.Axes):
            # grid(True) rather than grid(): a bare call toggles, so a second
            # Axes result would switch the grid off again.
            plt.grid(True)
            str_out += description + ':' + 'plt.Axes' + '\n\n'
        elif isinstance(value, pd.DataFrame):
            df = value
            str_out += description + ':' + 'pd.DataFrame' + '\n\n'
        else:
            str_out += str(description) + ':' + str(value) + '\n\n'

    print('===============================Summary Stage===========================================')
    try:
        output_result = send_chat_request_Azure(_SUMMARY_PROMPT + str_out + '###', openai_key=openai_key,
                                                api_base=api_base, engine=engine)
    except Exception as e:
        return e
    print(output_result)

    # Render the current matplotlib figure into an in-memory PNG. Fonts are
    # applied when text is drawn; savefig takes no fontproperties argument.
    buf = BytesIO()
    plt.savefig(buf, format='png', dpi=300)
    buf.seek(0)
    image = Image.open(buf)
    return output_text, image, output_result, df


def gradio_interface(query, openai_key, openai_key_azure, api_base, engine):
    """Gradio streaming handler that runs the pipeline in a background thread.

    Yields ``(text, image, status, dataframe)`` tuples. While the worker runs,
    each queued progress message is yielded with placeholder image/dataframe
    and status ``'Running'``. The last yield is run()'s
    ``(output_text, image, summary, df)``, or the error text on failure.
    Exactly one of *openai_key* (official key starting with ``'sk'``) or
    *openai_key_azure* must be provided.

    NOTE(review): progress flows through the module-level
    ``intermediate_results`` queue, so concurrent requests share it.
    """
    placeholder_dataframe = pd.DataFrame()
    placeholder_image = np.zeros((100, 100, 3), dtype=np.uint8)  # Create a placeholder image.
    try:
        if openai_key.startswith('sk') and openai_key_azure == '':
            print('send_official_call')
            thread = MyThread(target=run, args=(query, add_to_queue, send_official_call, openai_key, api_base))
        elif openai_key == '' and len(openai_key_azure) > 0:
            print('send_chat_request_Azure')
            thread = MyThread(target=run, args=(query, add_to_queue, send_chat_request_Azure,
                                                openai_key_azure, api_base, engine))
        else:
            # Without this branch `thread` would be unbound below.
            message = ("Please provide either an OpenAI key (starting with 'sk') "
                       "or an Azure OpenAI key, but not both.")
            yield message, placeholder_image, message, placeholder_dataframe
            return
        thread.start()

        # Stream intermediate results while the worker is running.
        while thread.is_alive():
            while not intermediate_results.empty():
                yield intermediate_results.get(), placeholder_image, 'Running', placeholder_dataframe
            time.sleep(0.1)  # Avoid excessive resource consumption.
        # Flush messages queued just before the worker exited so they are not
        # replayed at the start of the next request.
        while not intermediate_results.empty():
            yield intermediate_results.get(), placeholder_image, 'Running', placeholder_dataframe

        result = thread.get_result()
        if isinstance(result, Exception):
            # run() returns (rather than raises) chat-request errors.
            raise result
        finally_text, img, output, df = result
        yield finally_text, img, output, df
    except Exception as e:
        yield str(e), placeholder_image, str(e), placeholder_dataframe


instruction = '画一下五粮液和泸州老窖从2019年年初到2022年年中的收益率走势'

if __name__ == '__main__':
    # Choose the chat backend: send_official_call or send_chat_request_Azure.
    openai_call = send_official_call
    openai_key = os.getenv("OPENAI_KEY")
    result = run(instruction, send_chat_request_Azure=openai_call, openai_key=openai_key, api_base='', engine='')
    if isinstance(result, Exception):
        raise result
    # run() returns (output_text, image, output_result, df) in that order.
    output, image, output_result, df = result
    print(output_result)
    plt.show()