FredZhang7 committed

Commit c776d18 · Parent: 8f8b4b7

Update README.md

README.md CHANGED
---
license: creativeml-openrail-m
tags:
- stable-diffusion
- prompt-generator
widget:
- text: "amazing"
- text: "a photo of"
- text: "a portrait of"
- text: "a person standing"
- text: "a boy watching"
datasets:
- poloclub/diffusiondb
- Gustavosta/Stable-Diffusion-Prompts
- bartman081523/stable-diffusion-discord-prompts
- FredZhang7/krea-ai-prompts
---
# DistilGPT2 Stable Diffusion V2 Model Card

DistilGPT2 Stable Diffusion V2 is a text generation model that generates creative and coherent prompts for text-to-image models, given any input text.

This model was trained on 2.47 million descriptive Stable Diffusion prompts, starting from the [FredZhang7/distilgpt2-stable-diffusion](https://huggingface.co/FredZhang7/distilgpt2-stable-diffusion) checkpoint, for 4.27 million steps.

Compared to other GPT2-based prompt generation models, this one runs with 50% faster forward propagation and uses 40% less disk space and RAM.
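The size savings follow from DistilGPT2's architecture: 6 transformer blocks instead of GPT2-small's 12, with the same 768-dimensional hidden size. As a rough illustration (the helper below is a back-of-the-envelope sketch using standard GPT-2 dimensions, not part of this repository):

```python
def gpt2_param_count(n_layer, n_embd=768, vocab=50257, n_ctx=1024):
    # token + position embeddings
    emb = vocab * n_embd + n_ctx * n_embd
    # per block: attention (fused qkv + output projection), with biases
    attn = 3 * n_embd * n_embd + 3 * n_embd + n_embd * n_embd + n_embd
    # per block: 4x-expansion MLP, with biases
    mlp = 2 * (n_embd * 4 * n_embd) + 4 * n_embd + n_embd
    # per block: two layer norms (weight + bias)
    ln = 2 * (2 * n_embd)
    # final layer norm adds 2 * n_embd
    return emb + n_layer * (attn + mlp + ln) + 2 * n_embd

print(gpt2_param_count(12))  # GPT2-small: ~124M parameters
print(gpt2_param_count(6))   # DistilGPT2: ~82M parameters
```

Halving the block count keeps about two thirds of the parameters (the embeddings are shared overhead), which is consistent with the reported disk and RAM savings.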
Major improvements from v1 are:
- 25% more variations
- more capable of generating story-like prompts
- cleaned training data
  * removed prompts that generate images with NSFW scores > 0.5
  * removed duplicates, including prompts that differ only by capitalization or punctuation
  * removed punctuation at random places
  * removed prompts shorter than 15 characters
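The cleaning steps above amount to a single filter pass over the corpus. A minimal sketch (illustrative only, not the actual training pipeline; the NSFW scores are assumed to come from a separate image classifier):

```python
import re

def clean_prompts(prompts, nsfw_scores, nsfw_threshold=0.5, min_length=15):
    """Filter prompts as described above: NSFW, length, and near-duplicates."""
    seen = set()
    cleaned = []
    for prompt, score in zip(prompts, nsfw_scores):
        if score > nsfw_threshold:   # drop prompts whose images scored as NSFW
            continue
        if len(prompt) < min_length: # drop very short prompts
            continue
        # dedup key ignores capitalization and punctuation
        key = re.sub(r'[^\w\s]', '', prompt.lower())
        if key in seen:
            continue
        seen.add(key)
        cleaned.append(prompt)
    return cleaned
```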

### PyTorch

First, install or upgrade the transformers library:

```bash
pip install --upgrade transformers
```

Faster but less fluent generation:

```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import pipeline

# load the tokenizer and the fine-tuned model
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2', pad_token_id=tokenizer.eos_token_id)

prompt = r'a cat sitting'

# generate text using the fine-tuned model
nlp = pipeline('text-generation', model=model, tokenizer=tokenizer)

# generate 5 samples
outs = nlp(prompt, max_length=80, num_return_sequences=5)

# print the results
print('\nInput:\n' + 100 * '-')
print('\033[96m' + prompt + '\033[0m')
print('\nOutput:\n' + 100 * '-')
for i in range(len(outs)):
    outs[i] = str(outs[i]['generated_text']).replace('  ', ' ')  # collapse double spaces
print('\033[92m' + '\n\n'.join(outs) + '\033[0m\n')
```

Example output:

![greedy search](./greedy_search.png)

<br>
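For intuition on what the `temperature` and `top_k` settings below do, here is a toy sketch of a single sampling step (made-up logits, not the transformers implementation):

```python
import math
import random

def sample_next_token(logits, temperature=0.9, top_k=8):
    # keep only the top_k highest-scoring tokens
    top = sorted(enumerate(logits), key=lambda p: p[1], reverse=True)[:top_k]
    # temperature < 1 sharpens the softmax; > 1 flattens it
    weights = [math.exp(score / temperature) for _, score in top]
    total = sum(weights)
    # sample a token id in proportion to its softmax probability
    return random.choices([i for i, _ in top], weights=[w / total for w in weights])[0]

# with top_k=1, sampling degenerates to greedy decoding
print(sample_next_token([1.0, 5.0, 2.0], top_k=1))  # always 1
```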

Slower but more fluent generation:

```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# load the tokenizer and the fine-tuned model
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2', pad_token_id=tokenizer.eos_token_id)
model.eval()

prompt = r'a cat sitting'    # the beginning of the prompt
temperature = 0.9            # a higher temperature produces more diverse results, at a higher risk of less coherent text
top_k = 8                    # the number of tokens to sample from at each step
max_length = 80              # the maximum number of tokens in the output
repetition_penalty = 1.2     # the penalty applied to each repeated token
num_beams = 10               # the number of beams for beam search
num_return_sequences = 5     # the number of results with the highest probabilities out of num_beams

# generate the results with beam search plus sampling
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
output = model.generate(input_ids, do_sample=True, temperature=temperature, top_k=top_k,
                        max_length=max_length, num_return_sequences=num_return_sequences,
                        num_beams=num_beams, repetition_penalty=repetition_penalty,
                        penalty_alpha=0.6, no_repeat_ngram_size=1, early_stopping=True)

# print the results
print('\nInput:\n' + 100 * '-')
print('\033[96m' + prompt + '\033[0m')
print('\nOutput:\n' + 100 * '-')
for i in range(len(output)):
    print('\033[92m' + tokenizer.decode(output[i], skip_special_tokens=True) + '\033[0m\n')
```

Example output:

![contrastive search](./constrastive_search.png)