Spaces:
Runtime error
Runtime error
π¨ π
Browse filesSigned-off-by: peter szemraj <peterszemraj@gmail.com>
- aggregate.py +31 -18
aggregate.py
CHANGED
@@ -1,3 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import pprint as pp
|
2 |
import logging
|
3 |
import time
|
@@ -14,10 +23,15 @@ logging.basicConfig(
|
|
14 |
|
15 |
|
16 |
class BatchAggregator:
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
21 |
GENERIC_CONFIG = GenerationConfig(
|
22 |
num_beams=8,
|
23 |
early_stopping=True,
|
@@ -29,10 +43,23 @@ class BatchAggregator:
|
|
29 |
no_repeat_ngram_size=4,
|
30 |
encoder_no_repeat_ngram_size=5,
|
31 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
def __init__(
|
34 |
self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
|
35 |
):
|
|
|
|
|
|
|
|
|
|
|
36 |
self.device = None
|
37 |
self.is_compiled = False
|
38 |
self.logger = logging.getLogger(__name__)
|
@@ -125,20 +152,6 @@ class BatchAggregator:
|
|
125 |
"""
|
126 |
self.aggregator.model.generation_config = self.GENERIC_CONFIG
|
127 |
|
128 |
-
if "bart" in self.model_name.lower():
|
129 |
-
self.logger.info("Using BART model, updating generation config")
|
130 |
-
upd = {
|
131 |
-
"num_beams": 8,
|
132 |
-
"repetition_penalty": 1.3,
|
133 |
-
"length_penalty": 1.0,
|
134 |
-
"_from_model_config": False,
|
135 |
-
"max_new_tokens": 256,
|
136 |
-
"min_new_tokens": 32,
|
137 |
-
"no_repeat_ngram_size": 3,
|
138 |
-
"encoder_no_repeat_ngram_size": 6,
|
139 |
-
} # TODO: clean up
|
140 |
-
self.aggregator.model.generation_config.update(**upd)
|
141 |
-
|
142 |
if (
|
143 |
"large"
|
144 |
or "xl" in self.model_name.lower()
|
|
|
1 |
+
"""
|
2 |
+
aggregate.py is a module for aggregating text from multiple sources, or multiple parts of a single source.
|
3 |
+
Primary usage is through the BatchAggregator class.
|
4 |
+
|
5 |
+
How it works:
|
6 |
+
1. We tell the language model to do it.
|
7 |
+
2. The language model does it.
|
8 |
+
3. Yaay!
|
9 |
+
"""
|
10 |
import pprint as pp
|
11 |
import logging
|
12 |
import time
|
|
|
23 |
|
24 |
|
25 |
class BatchAggregator:
|
26 |
+
"""
|
27 |
+
BatchAggregator is a class for aggregating text from multiple sources.
|
28 |
+
|
29 |
+
Usage:
|
30 |
+
>>> from aggregate import BatchAggregator
|
31 |
+
>>> aggregator = BatchAggregator()
|
32 |
+
>>> aggregator.aggregate(["This is a test", "This is another test"])
|
33 |
+
"""
|
34 |
+
|
35 |
GENERIC_CONFIG = GenerationConfig(
|
36 |
num_beams=8,
|
37 |
early_stopping=True,
|
|
|
43 |
no_repeat_ngram_size=4,
|
44 |
encoder_no_repeat_ngram_size=5,
|
45 |
)
|
46 |
+
CONFIGURED_MODELS = [
|
47 |
+
"pszemraj/bart-large-mnli-dolly_hhrlhf-v1",
|
48 |
+
"pszemraj/bart-base-instruct-dolly_hhrlhf",
|
49 |
+
"pszemraj/flan-t5-large-instruct-dolly_hhrlhf",
|
50 |
+
"pszemraj/flan-t5-base-instruct-dolly_hhrlhf",
|
51 |
+
] # these have generation configs defined for this task in their model repos
|
52 |
+
|
53 |
+
DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
|
54 |
|
55 |
def __init__(
|
56 |
self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
|
57 |
):
|
58 |
+
"""
|
59 |
+
__init__ initializes the BatchAggregator class.
|
60 |
+
|
61 |
+
:param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
|
62 |
+
"""
|
63 |
self.device = None
|
64 |
self.is_compiled = False
|
65 |
self.logger = logging.getLogger(__name__)
|
|
|
152 |
"""
|
153 |
self.aggregator.model.generation_config = self.GENERIC_CONFIG
|
154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
if (
|
156 |
"large"
|
157 |
or "xl" in self.model_name.lower()
|