Spaces:
Running
Running
vlff李飞飞
committed on
Commit
•
af1bfb7
1
Parent(s):
56ec8fb
update oai
Browse files
- qwen_agent/llm/qwen_oai.py +8 -8
- run_server.py +1 -0
qwen_agent/llm/qwen_oai.py
CHANGED
@@ -416,24 +416,24 @@ def predict(
|
|
416 |
class QwenChatAsOAI(BaseChatModel):
|
417 |
|
418 |
def __init__(self, model: str, api_key: str, model_server: str):
|
419 |
-
checkpoint_path = model
|
420 |
super().__init__()
|
421 |
tokenizer = AutoTokenizer.from_pretrained(
|
422 |
-
checkpoint_path,
|
423 |
trust_remote_code=True,
|
424 |
resume_download=True,
|
425 |
)
|
426 |
device_map = "cpu"
|
427 |
# device_map = "auto"
|
428 |
model = AutoModelForCausalLM.from_pretrained(
|
429 |
-
checkpoint_path,
|
430 |
device_map=device_map,
|
431 |
trust_remote_code=True,
|
432 |
resume_download=True,
|
433 |
).eval()
|
434 |
|
435 |
model.generation_config = GenerationConfig.from_pretrained(
|
436 |
-
checkpoint_path,
|
437 |
trust_remote_code=True,
|
438 |
resume_download=True,
|
439 |
)
|
@@ -444,7 +444,7 @@ class QwenChatAsOAI(BaseChatModel):
|
|
444 |
messages: List[Dict],
|
445 |
stop: Optional[List[str]] = None,
|
446 |
) -> Iterator[str]:
|
447 |
-
_request = ChatCompletionRequest(model=self.model,
|
448 |
messages=messages,
|
449 |
stop=stop,
|
450 |
stream=True)
|
@@ -459,7 +459,7 @@ class QwenChatAsOAI(BaseChatModel):
|
|
459 |
messages: List[Dict],
|
460 |
stop: Optional[List[str]] = None,
|
461 |
) -> str:
|
462 |
-
_request = ChatCompletionRequest(model=self.model,
|
463 |
messages=messages,
|
464 |
stop=stop,
|
465 |
stream=False)
|
@@ -471,12 +471,12 @@ class QwenChatAsOAI(BaseChatModel):
|
|
471 |
messages: List[Dict],
|
472 |
functions: Optional[List[Dict]] = None) -> Dict:
|
473 |
if functions:
|
474 |
-
_request = ChatCompletionRequest(model=self.model,
|
475 |
messages=messages,
|
476 |
functions=functions)
|
477 |
response = create_chat_completion(_request)
|
478 |
else:
|
479 |
-
_request = ChatCompletionRequest(model=self.model,
|
480 |
messages=messages)
|
481 |
response = create_chat_completion(_request)
|
482 |
# TODO: error handling
|
|
|
416 |
class QwenChatAsOAI(BaseChatModel):
|
417 |
|
418 |
def __init__(self, model: str, api_key: str, model_server: str):
|
419 |
+
self.checkpoint_path = copy.copy(model)
|
420 |
super().__init__()
|
421 |
tokenizer = AutoTokenizer.from_pretrained(
|
422 |
+
self.checkpoint_path,
|
423 |
trust_remote_code=True,
|
424 |
resume_download=True,
|
425 |
)
|
426 |
device_map = "cpu"
|
427 |
# device_map = "auto"
|
428 |
model = AutoModelForCausalLM.from_pretrained(
|
429 |
+
self.checkpoint_path,
|
430 |
device_map=device_map,
|
431 |
trust_remote_code=True,
|
432 |
resume_download=True,
|
433 |
).eval()
|
434 |
|
435 |
model.generation_config = GenerationConfig.from_pretrained(
|
436 |
+
self.checkpoint_path,
|
437 |
trust_remote_code=True,
|
438 |
resume_download=True,
|
439 |
)
|
|
|
444 |
messages: List[Dict],
|
445 |
stop: Optional[List[str]] = None,
|
446 |
) -> Iterator[str]:
|
447 |
+
_request = ChatCompletionRequest(model=self.checkpoint_path,
|
448 |
messages=messages,
|
449 |
stop=stop,
|
450 |
stream=True)
|
|
|
459 |
messages: List[Dict],
|
460 |
stop: Optional[List[str]] = None,
|
461 |
) -> str:
|
462 |
+
_request = ChatCompletionRequest(model=self.checkpoint_path,
|
463 |
messages=messages,
|
464 |
stop=stop,
|
465 |
stream=False)
|
|
|
471 |
messages: List[Dict],
|
472 |
functions: Optional[List[Dict]] = None) -> Dict:
|
473 |
if functions:
|
474 |
+
_request = ChatCompletionRequest(model=self.checkpoint_path,
|
475 |
messages=messages,
|
476 |
functions=functions)
|
477 |
response = create_chat_completion(_request)
|
478 |
else:
|
479 |
+
_request = ChatCompletionRequest(model=self.checkpoint_path,
|
480 |
messages=messages)
|
481 |
response = create_chat_completion(_request)
|
482 |
# TODO: error handling
|
run_server.py
CHANGED
@@ -12,6 +12,7 @@ from qwen_agent.utils.utils import get_local_ip
|
|
12 |
from qwen_server.schema import GlobalConfig
|
13 |
os.environ["TRANSFORMERS_CACHE"] = ".cache/huggingface/"
|
14 |
os.environ["HF_HOME"] = ".cache/huggingface/"
|
|
|
15 |
|
16 |
def parse_args():
|
17 |
parser = argparse.ArgumentParser()
|
|
|
12 |
from qwen_server.schema import GlobalConfig
|
13 |
os.environ["TRANSFORMERS_CACHE"] = ".cache/huggingface/"
|
14 |
os.environ["HF_HOME"] = ".cache/huggingface/"
|
15 |
+
os.environ["MPLCONFIGDIR"] = ".cache/matplotlib/"
|
16 |
|
17 |
def parse_args():
|
18 |
parser = argparse.ArgumentParser()
|