---
license: apache-2.0
datasets:
- emozilla/booksum-summary-analysis_gptneox-8192
- kmfoda/booksum
---

# mpt-7b-storysummarizer

This is a fine-tuned version of [mosaicml/mpt-7b-storywriter](https://huggingface.co/mosaicml/mpt-7b-storywriter) trained on [emozilla/booksum-summary-analysis_gptneox-8192](https://huggingface.co/datasets/emozilla/booksum-summary-analysis_gptneox-8192), which is adapted from [kmfoda/booksum](https://huggingface.co/datasets/kmfoda/booksum).
The training run was performed using [llm-foundry](https://github.com/mosaicml/llm-foundry) on an 8xA100 80 GB node at a context length of 8192. The run can be viewed on [wandb](https://wandb.ai/emozilla/booksum/runs/457ym4r9).

## How to Use

This model is intended for summarization and literary analysis of fiction stories. It can be prompted in one of two ways:

```
SOME_FICTION

### SUMMARY:
```

or

```
SOME_FICTION

### ANALYSIS:
```

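The two prompt headers above can be assembled programmatically. A minimal sketch — the `build_prompt` helper is illustrative, not part of the model's released code:

```python
def build_prompt(text, mode="SUMMARY"):
    """Append the '### SUMMARY:' or '### ANALYSIS:' header the model was fine-tuned on."""
    if mode not in ("SUMMARY", "ANALYSIS"):
        raise ValueError("mode must be 'SUMMARY' or 'ANALYSIS'")
    return f"{text}\n\n### {mode}:"

prompt = build_prompt("Once upon a time, ...", mode="ANALYSIS")
```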
A `repetition_penalty` of ~1.04 seems to work best. For summary prompts, simple greedy search suffices, while a temperature of 0.8 works well for analysis.
The model often prints `'#'` to delineate the end of a summary or analysis. You can use a custom `StoppingCriteria` to end generation at that token.

```python
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the most recently generated token matches a stop id
        for stop_id in self.stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

stop_ids = tokenizer("#").input_ids
stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
```
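As a sanity check on the stopping logic, a pure-Python stand-in (plain lists instead of torch tensors, so it runs without a model) performs the same last-token comparison:

```python
# Mirrors StopOnTokens.__call__: stop when the most recent token in the
# (batch, seq) id matrix equals any stop id. Plain lists stand in for tensors,
# and the token id for "#" is a hypothetical value for illustration.
def should_stop(input_ids, stop_ids):
    return any(input_ids[0][-1] == stop_id for stop_id in stop_ids)

hash_id = 2  # hypothetical token id for "#"
print(should_stop([[5, 7, hash_id]], [hash_id]))  # True: last token is "#"
print(should_stop([[5, 7, 9]], [hash_id]))        # False: keep generating
```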

Pass `stopping_criteria` as an argument to the model's `generate` function to stop on `#`.

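The decoding recommendations above can be collected into `generate()` keyword arguments. A sketch under those recommendations — the `max_new_tokens` budget is an illustrative choice, not from the original:

```python
# Greedy decoding for summaries; sampled decoding at temperature 0.8 for analysis.
summary_kwargs = dict(
    max_new_tokens=512,       # illustrative output budget
    repetition_penalty=1.04,
    do_sample=False,          # greedy search
)
analysis_kwargs = dict(
    max_new_tokens=512,
    repetition_penalty=1.04,
    do_sample=True,
    temperature=0.8,
)
# e.g. model.generate(**inputs, stopping_criteria=stopping_criteria, **summary_kwargs)
```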
The code for this model includes adaptations from [Birchlabs/mosaicml-mpt-7b-chat-qlora](https://huggingface.co/Birchlabs/mosaicml-mpt-7b-chat-qlora) that allow MPT models to be loaded with `device_map="auto"` and `load_in_8bit=True`.
For longer contexts, the following is recommended:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("emozilla/mpt-7b-storysummarizer")
model = AutoModelForCausalLM.from_pretrained(
    "emozilla/mpt-7b-storysummarizer",
    load_in_8bit=True,
    trust_remote_code=True,
    device_map="auto")
```