vlff李飞飞 committed on
Commit
af1bfb7
1 Parent(s): 56ec8fb

update oai

Browse files
Files changed (2) hide show
  1. qwen_agent/llm/qwen_oai.py +8 -8
  2. run_server.py +1 -0
qwen_agent/llm/qwen_oai.py CHANGED
@@ -416,24 +416,24 @@ def predict(
416
  class QwenChatAsOAI(BaseChatModel):
417
 
418
  def __init__(self, model: str, api_key: str, model_server: str):
419
- checkpoint_path = model
420
  super().__init__()
421
  tokenizer = AutoTokenizer.from_pretrained(
422
- checkpoint_path,
423
  trust_remote_code=True,
424
  resume_download=True,
425
  )
426
  device_map = "cpu"
427
  # device_map = "auto"
428
  model = AutoModelForCausalLM.from_pretrained(
429
- checkpoint_path,
430
  device_map=device_map,
431
  trust_remote_code=True,
432
  resume_download=True,
433
  ).eval()
434
 
435
  model.generation_config = GenerationConfig.from_pretrained(
436
- checkpoint_path,
437
  trust_remote_code=True,
438
  resume_download=True,
439
  )
@@ -444,7 +444,7 @@ class QwenChatAsOAI(BaseChatModel):
444
  messages: List[Dict],
445
  stop: Optional[List[str]] = None,
446
  ) -> Iterator[str]:
447
- _request = ChatCompletionRequest(model=self.model,
448
  messages=messages,
449
  stop=stop,
450
  stream=True)
@@ -459,7 +459,7 @@ class QwenChatAsOAI(BaseChatModel):
459
  messages: List[Dict],
460
  stop: Optional[List[str]] = None,
461
  ) -> str:
462
- _request = ChatCompletionRequest(model=self.model,
463
  messages=messages,
464
  stop=stop,
465
  stream=False)
@@ -471,12 +471,12 @@ class QwenChatAsOAI(BaseChatModel):
471
  messages: List[Dict],
472
  functions: Optional[List[Dict]] = None) -> Dict:
473
  if functions:
474
- _request = ChatCompletionRequest(model=self.model,
475
  messages=messages,
476
  functions=functions)
477
  response = create_chat_completion(_request)
478
  else:
479
- _request = ChatCompletionRequest(model=self.model,
480
  messages=messages)
481
  response = create_chat_completion(_request)
482
  # TODO: error handling
 
416
  class QwenChatAsOAI(BaseChatModel):
417
 
418
  def __init__(self, model: str, api_key: str, model_server: str):
419
+ self.checkpoint_path = copy.copy(model)
420
  super().__init__()
421
  tokenizer = AutoTokenizer.from_pretrained(
422
+ self.checkpoint_path,
423
  trust_remote_code=True,
424
  resume_download=True,
425
  )
426
  device_map = "cpu"
427
  # device_map = "auto"
428
  model = AutoModelForCausalLM.from_pretrained(
429
+ self.checkpoint_path,
430
  device_map=device_map,
431
  trust_remote_code=True,
432
  resume_download=True,
433
  ).eval()
434
 
435
  model.generation_config = GenerationConfig.from_pretrained(
436
+ self.checkpoint_path,
437
  trust_remote_code=True,
438
  resume_download=True,
439
  )
 
444
  messages: List[Dict],
445
  stop: Optional[List[str]] = None,
446
  ) -> Iterator[str]:
447
+ _request = ChatCompletionRequest(model=self.checkpoint_path,
448
  messages=messages,
449
  stop=stop,
450
  stream=True)
 
459
  messages: List[Dict],
460
  stop: Optional[List[str]] = None,
461
  ) -> str:
462
+ _request = ChatCompletionRequest(model=self.checkpoint_path,
463
  messages=messages,
464
  stop=stop,
465
  stream=False)
 
471
  messages: List[Dict],
472
  functions: Optional[List[Dict]] = None) -> Dict:
473
  if functions:
474
+ _request = ChatCompletionRequest(model=self.checkpoint_path,
475
  messages=messages,
476
  functions=functions)
477
  response = create_chat_completion(_request)
478
  else:
479
+ _request = ChatCompletionRequest(model=self.checkpoint_path,
480
  messages=messages)
481
  response = create_chat_completion(_request)
482
  # TODO: error handling
run_server.py CHANGED
@@ -12,6 +12,7 @@ from qwen_agent.utils.utils import get_local_ip
12
  from qwen_server.schema import GlobalConfig
13
  os.environ["TRANSFORMERS_CACHE"] = ".cache/huggingface/"
14
  os.environ["HF_HOME"] = ".cache/huggingface/"
 
15
 
16
  def parse_args():
17
  parser = argparse.ArgumentParser()
 
12
  from qwen_server.schema import GlobalConfig
13
  os.environ["TRANSFORMERS_CACHE"] = ".cache/huggingface/"
14
  os.environ["HF_HOME"] = ".cache/huggingface/"
15
+ os.environ["MPLCONFIGDIR"] = ".cache/matplotlib/"
16
 
17
  def parse_args():
18
  parser = argparse.ArgumentParser()