import ast
import re

import tenacity
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer


class LLM:
    def __init__(self, model_id="Qwen/Qwen2.5-7B-Instruct"):
        self.model_id = model_id
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load the model and tokenizer based on the model_id
        if "meta-llama" in self.model_id:
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id, torch_dtype=torch.bfloat16, device_map="auto"
            )
        elif "InternVL" in self.model_id:
            # InternVL ships custom modeling code, so trust_remote_code is required.
            self.model = AutoModel.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                device_map="auto",
            ).eval()
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id, trust_remote_code=True, use_fast=False
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id, torch_dtype="auto", device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)

    @torch.no_grad()
    def generate(self, query):
        if "meta-llama" in self.model_id:
            messages = [
                {"role": "user", "content": [{"type": "text", "text": f"{query}"}]}
            ]
            text = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
            generated_ids = self.model.generate(
                model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
                max_new_tokens=512,
            )
            # Strip the prompt tokens so only the newly generated tokens remain.
            generated_ids = [
                output_ids[len(input_ids):]
                for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]
            response = self.tokenizer.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0]
        elif "InternVL" in self.model_id:
            generation_config = dict(max_new_tokens=1024, do_sample=True)
            # Text-only chat: pass None in place of pixel values.
            response = self.model.chat(
                self.tokenizer,
                None,
                query,
                generation_config,
                history=None,
                return_history=False,
            )
        else:
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": query},
            ]
            text = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
            generated_ids = self.model.generate(
                model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
                max_new_tokens=512,
            )
            generated_ids = [
                output_ids[len(input_ids):]
                for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]
            response = self.tokenizer.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0]
        return response

    @tenacity.retry(stop=tenacity.stop_after_delay(10))
    def answer(self, query, objects):
        query = f"""
Extract the object that satisfies the intent of the query, or determine the tool that aligns with the purpose of {query}.
Pick the best option from the following: {', '.join(objects)}.
Please return a list of all suitable options, as long as they make sense, as a Python list in the following format:
```python\n['option1', 'option2', ...]```"""
        res = self.generate(query)
        # Prefer a fenced ```python ... ``` block. Use \s* after "python" since
        # the model emits a real newline there, not a literal backslash-n.
        match = re.search(r"`{3}python\s*(.*?)`{3}", res, re.DOTALL)
        if match:
            res = match.group(1)
            # literal_eval safely parses the list literal (avoids arbitrary-code eval);
            # underscores and hyphens are stripped to normalize option names.
            res = [r.translate(str.maketrans("", "", "_-")) for r in ast.literal_eval(res)]
            return res
        else:
            # Try to extract content directly from brackets []
            match_brackets = re.search(r"\[(.*?)\]", res, re.DOTALL)
            if match_brackets:
                res = match_brackets.group(0)  # Include brackets for literal_eval
                res = [r.translate(str.maketrans("", "", "_-")) for r in ast.literal_eval(res)]
                return res
            else:
                raise ValueError(f"Failed to parse response: {res}")
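
# --- Usage sketch (hypothetical; not part of the original file) ---
# A minimal example of how LLM.answer might be called to route a query to one
# of a fixed set of candidate tools. The model id and the candidate list below
# are assumptions for illustration only.
if __name__ == "__main__":
    llm = LLM(model_id="Qwen/Qwen2.5-7B-Instruct")
    tools = ["calculator", "web_search", "code_interpreter"]
    # answer() prompts the model to pick all matching options and parses the
    # returned Python list; underscores/hyphens are stripped from each option,
    # so "web_search" would come back as "websearch".
    selected = llm.answer("What is 37 * 52?", tools)
    print(selected)  # e.g. ['calculator']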