improved stop token for supercot
app.py
CHANGED
@@ -38,10 +38,11 @@ def prompt_chat(system_msg, history):
 class Pipeline:
     prefer_async = True
 
-    def __init__(self, endpoint_id, name, prompt_fn):
+    def __init__(self, endpoint_id, name, prompt_fn, stop_tokens=None):
         self.endpoint_id = endpoint_id
         self.name = name
         self.prompt_fn = prompt_fn
+        stop_tokens = stop_tokens or []
         self.generation_config = {
             "max_new_tokens": 1024,
             "top_k": 40,
@@ -52,7 +53,7 @@ class Pipeline:
             "seed": -1,
             "batch_size": 8,
             "threads": -1,
-            "stop": ["</s>", "USER:", "### Instruction:"],
+            "stop": ["</s>", "USER:", "### Instruction:"] + stop_tokens,
         }
 
     def __call__(self, prompt):
@@ -102,7 +103,7 @@ AVAILABLE_MODELS = {
     "hermes-13b": ("p0zqb2gkcwp0ww", prompt_instruct),
     "manticore-13b-chat": ("u6tv84bpomhfei", prompt_chat),
     "airoboros-13b": ("rglzxnk80660ja", prompt_chat),
-    "supercot-13b": ("0be7865dwxpwqk", prompt_instruct),
+    "supercot-13b": ("0be7865dwxpwqk", prompt_instruct, ["Instruction:"]),
     "mpt-7b-instruct": ("jpqbvnyluj18b0", prompt_instruct),
 }
 
@@ -111,7 +112,10 @@ _memoized_models = defaultdict()
 
 def get_model_pipeline(model_name):
     if not _memoized_models.get(model_name):
-        _memoized_models[model_name] = Pipeline(AVAILABLE_MODELS[model_name][0], model_name, AVAILABLE_MODELS[model_name][1])
+        kwargs = {}
+        if len(AVAILABLE_MODELS[model_name]) >= 3:
+            kwargs["stop_tokens"] = AVAILABLE_MODELS[model_name][2]
+        _memoized_models[model_name] = Pipeline(AVAILABLE_MODELS[model_name][0], model_name, AVAILABLE_MODELS[model_name][1], **kwargs)
     return _memoized_models.get(model_name)
 
 start_message = """- The Assistant is helpful and transparent.
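
For illustration, a minimal self-contained sketch of how the new stop-token plumbing behaves. The Pipeline stub below keeps only the pieces this commit touches (the real class also drives the RunPod endpoint), and prompt_instruct is reduced to a placeholder; only the names shown in the diff come from app.py.

from collections import defaultdict

def prompt_instruct(system_msg, history):
    ...  # placeholder; the real builder formats an instruction-style prompt

class Pipeline:
    # Trimmed to the lines changed in this commit; the real Pipeline
    # also sends the prompt to the RunPod endpoint.
    def __init__(self, endpoint_id, name, prompt_fn, stop_tokens=None):
        self.endpoint_id = endpoint_id
        self.name = name
        self.prompt_fn = prompt_fn
        stop_tokens = stop_tokens or []
        self.generation_config = {
            "stop": ["</s>", "USER:", "### Instruction:"] + stop_tokens,
        }

AVAILABLE_MODELS = {
    # Two-element entries keep the default stop list ...
    "mpt-7b-instruct": ("jpqbvnyluj18b0", prompt_instruct),
    # ... a third element appends model-specific stop tokens.
    "supercot-13b": ("0be7865dwxpwqk", prompt_instruct, ["Instruction:"]),
}

_memoized_models = defaultdict()

def get_model_pipeline(model_name):
    if not _memoized_models.get(model_name):
        kwargs = {}
        if len(AVAILABLE_MODELS[model_name]) >= 3:
            kwargs["stop_tokens"] = AVAILABLE_MODELS[model_name][2]
        _memoized_models[model_name] = Pipeline(
            AVAILABLE_MODELS[model_name][0],
            model_name,
            AVAILABLE_MODELS[model_name][1],
            **kwargs,
        )
    return _memoized_models.get(model_name)

print(get_model_pipeline("supercot-13b").generation_config["stop"])
# -> ['</s>', 'USER:', '### Instruction:', 'Instruction:']
print(get_model_pipeline("mpt-7b-instruct").generation_config["stop"])
# -> ['</s>', 'USER:', '### Instruction:']

The extra "Instruction:" entry (without the "###" prefix) presumably catches supercot-13b continuations that the default "### Instruction:" stop string misses, and passing it as an optional third tuple element leaves the two-element entries for the other models unchanged.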