import json
import time

import torch
from langchain.output_parsers import RetryWithErrorOutputParser
from langchain.prompts import PromptTemplate
from langchain.schema import HumanMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_openai import ChatOpenAI, OpenAI

from vouchervision.utils_LLM import SystemLoadMonitor, run_tools, count_tokens, save_individual_prompt, sanitize_prompt
from vouchervision.utils_LLM_JSON_validation import validate_and_align_JSON_keys_with_template


class OpenAIHandler:
    """Send prompts to an OpenAI (or Azure OpenAI) model and turn the reply into validated JSON.

    The user prompt is wrapped with JSON format instructions, sent through a
    LangChain chain, parsed with a retry-with-error JSON parser, aligned to the
    expected JSON template, and then post-processed with the WFO / geolocation /
    Wikipedia tools via ``run_tools``.
    """

    RETRY_DELAY = 10  # Wait 10 seconds before retrying
    MAX_RETRIES = 3  # Maximum number of retries
    STARTING_TEMP = 0.5
    TOKENIZER_NAME = 'gpt-4'
    VENDOR = 'openai'

    def __init__(self, cfg, logger, model_name, JSON_dict_structure, is_azure, llm_object):
        """
        Args:
            cfg: Nested configuration dict; the tool flags are read from
                cfg['leafmachine']['project'] ('tool_WFO', 'tool_GEO', 'tool_wikipedia').
            logger: Logger used for progress and error messages.
            model_name: Model name (also used as the Azure deployment name). If one of
                its '-'-separated parts is 'instruct', the completion client (OpenAI)
                is used instead of the chat client (ChatOpenAI) for non-Azure calls.
            JSON_dict_structure: Template that parsed JSON keys are aligned to.
            is_azure: When True, requests go through ``llm_object``.
            llm_object: LLM object used when ``is_azure`` is True; replaced by None otherwise.
        """
        self.cfg = cfg
        project_cfg = self.cfg['leafmachine']['project']
        self.tool_WFO = project_cfg['tool_WFO']
        self.tool_GEO = project_cfg['tool_GEO']
        self.tool_wikipedia = project_cfg['tool_wikipedia']

        self.logger = logger
        self.model_name = model_name
        self.JSON_dict_structure = JSON_dict_structure
        self.is_azure = is_azure
        self.llm_object = llm_object

        # '-'-separated parts of the model name; used to detect '*-instruct' models.
        self.name_parts = self.model_name.split('-')

        self.monitor = SystemLoadMonitor(logger)
        self.has_GPU = torch.cuda.is_available()

        # Temperature bookkeeping for the retry loop (see _adjust_config / _reset_config).
        self.starting_temp = float(self.STARTING_TEMP)
        self.temp_increment = float(0.2)
        self.adjust_temp = self.starting_temp

        # Set up a parser
        self.parser = JsonOutputParser()

        # The parser's format instructions are baked into every prompt.
        self.prompt = PromptTemplate(
            template="Answer the user query.\n{format_instructions}\n{query}\n",
            input_variables=["query"],
            partial_variables={"format_instructions": self.parser.get_format_instructions()},
        )

        self._set_config()

    def _set_config(self):
        """Initialise generation settings, configure the Azure client, and build the chain.

        NOTE(review): in the code visible here only config['temperature'] is ever
        updated, and none of these values are passed to the OpenAI clients.
        """
        self.config = {'max_new_tokens': 1024,
                       'temperature': self.starting_temp,
                       'random_seed': 2023,
                       'top_p': 1,
                       }

        # Adjusting the LLM settings based on whether Azure is used
        if self.is_azure:
            self.llm_object.deployment_name = self.model_name
            self.llm_object.model_name = self.model_name
        else:
            self.llm_object = None

        self._build_model_chain_parser()

    def format_input_for_azure(self, prompt_text):
        """Send a rendered prompt to the Azure LLM object as a single human message.

        Used as the model step of ``self.chain`` when ``is_azure`` is True.
        ``prompt_text`` is the PromptTemplate output; its ``.text`` is the prompt string.
        """
        msg = HumanMessage(content=prompt_text.text)
        return self.llm_object(messages=[msg])

    def _adjust_config(self):
        """Raise the tracked temperature by ``temp_increment`` after a failed attempt.

        NOTE(review): this only updates ``self.adjust_temp`` and ``self.config``; see
        the note in ``call_llm_api_OpenAI`` about whether the value reaches the model.
        """
        new_temp = self.adjust_temp + self.temp_increment
        self.json_report.set_text(text_main=f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
        self.logger.info(f'Incrementing temperature from {self.adjust_temp} to {new_temp}')
        self.adjust_temp += self.temp_increment
        self.config['temperature'] = self.adjust_temp

    def _reset_config(self):
        """Restore the tracked temperature to ``starting_temp``."""
        self.json_report.set_text(text_main=f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
        self.logger.info(f'Resetting temperature from {self.adjust_temp} to {self.starting_temp}')
        self.adjust_temp = self.starting_temp
        self.config['temperature'] = self.starting_temp

    def _make_llm(self):
        """Return the LLM client shared by the chain and the retry parser.

        Azure -> the provided ``llm_object``; '*-instruct' models -> completion
        client (OpenAI); everything else -> chat client (ChatOpenAI).
        """
        if self.is_azure:
            return self.llm_object
        if 'instruct' in self.name_parts:
            return OpenAI(model=self.model_name)
        return ChatOpenAI(model=self.model_name)

    def _build_model_chain_parser(self):
        """Build ``self.retry_parser`` and ``self.chain`` around a single LLM client."""
        llm = self._make_llm()

        # Set up the retry parser with MAX_RETRIES retries
        self.retry_parser = RetryWithErrorOutputParser.from_llm(
            parser=self.parser,
            llm=llm,
            max_retries=self.MAX_RETRIES,
        )

        # Prepare the chain; Azure goes through format_input_for_azure (message wrapper).
        self.chain = self.prompt | (self.format_input_for_azure if self.is_azure else llm)

    def call_llm_api_OpenAI(self, prompt_template, json_report, paths):
        """Query the model until a valid JSON record is produced or MAX_RETRIES is reached.

        Args:
            prompt_template: Prompt text inserted as ``{query}`` in ``self.prompt``
                (presumably a str; it is also token-counted and saved to disk).
            json_report: Progress reporter exposing ``set_text(text_main=...)``.
            paths: 7-item sequence; only the 6th (Wikipedia JSON output path) and the
                7th (individual-prompt text file path) are used here.

        Returns:
            ``(output, nt_in, nt_out, WFO_record, GEO_record, usage_report)`` on success,
            or ``(None, nt_in, nt_out, None, None, usage_report)`` after all attempts fail.
        """
        _, _, _, _, _, json_file_path_wiki, txt_file_path_ind_prompt = paths
        self.json_report = json_report
        self.json_report.set_text(text_main=f'Sending request to {self.model_name}')
        self.monitor.start_monitoring_usage()
        nt_in = 0
        nt_out = 0

        ind = 0
        while ind < self.MAX_RETRIES:
            ind += 1
            try:
                # NOTE(review): self.prompt only references {format_instructions} and
                # {query}, so this 'model_kwargs' entry appears to be ignored, and the
                # clients are built without a temperature -- i.e. the temperature changes
                # made by _adjust_config() likely never reach the model. Confirm before
                # relying on the temperature escalation.
                model_kwargs = {"temperature": self.adjust_temp}
                # Invoke the chain to generate prompt text
                response = self.chain.invoke({"query": prompt_template, "model_kwargs": model_kwargs})

                # Non-str responses are expected to carry their text in `.content` (chat message).
                response_text = response if isinstance(response, str) else response.content

                # The retry parser re-asks the model on a parse error and needs the full
                # prompt as a PromptValue (it calls .to_string() on it), not a raw str.
                prompt_value = self.prompt.format_prompt(query=prompt_template)
                output = self.retry_parser.parse_with_prompt(response_text, prompt_value=prompt_value)

                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response_text}')
                    self._adjust_config()
                    continue

                nt_in = count_tokens(prompt_template, self.VENDOR, self.TOKENIZER_NAME)
                nt_out = count_tokens(response_text, self.VENDOR, self.TOKENIZER_NAME)

                output = validate_and_align_JSON_keys_with_template(output, self.JSON_dict_structure)
                if output is None:
                    self.logger.error(f'[Attempt {ind}] Failed to extract JSON from:\n{response_text}')
                    self._adjust_config()
                    continue

                self.monitor.stop_inference_timer()  # Starts tool timer too
                json_report.set_text(text_main='Working on WFO, Geolocation, Links')
                _, WFO_record, _, GEO_record = run_tools(
                    output, self.tool_WFO, self.tool_GEO, self.tool_wikipedia, json_file_path_wiki
                )

                save_individual_prompt(sanitize_prompt(prompt_template), txt_file_path_ind_prompt)

                self.logger.info(f"Formatted JSON:\n{json.dumps(output,indent=4)}")

                usage_report = self.monitor.stop_monitoring_report_usage()

                if self.adjust_temp != self.starting_temp:
                    self._reset_config()

                json_report.set_text(text_main='LLM call successful')
                return output, nt_in, nt_out, WFO_record, GEO_record, usage_report

            except Exception as e:
                self.logger.error(f'{e}')
                self._adjust_config()
                # Back off before the next attempt; waiting after the last one only delays the failure.
                if ind < self.MAX_RETRIES:
                    time.sleep(self.RETRY_DELAY)

        self.logger.info(f"Failed to extract valid JSON after [{ind}] attempts")
        self.json_report.set_text(text_main=f'Failed to extract valid JSON after [{ind}] attempts')

        self.monitor.stop_inference_timer()  # Starts tool timer too
        usage_report = self.monitor.stop_monitoring_report_usage()
        self._reset_config()
        json_report.set_text(text_main='LLM call failed')

        return None, nt_in, nt_out, None, None, usage_report