---
license: creativeml-openrail-m
tags:
- stable-diffusion
- prompt-generator
widget:
- text: "amazing"
- text: "a photo of"
- text: "a sci-fi"
- text: "a portrait of"
- text: "a person standing"
- text: "a boy watching"
datasets:
- poloclub/diffusiondb
- Gustavosta/Stable-Diffusion-Prompts
- bartman081523/stable-diffusion-discord-prompts
- FredZhang7/krea-ai-prompts
---
# DistilGPT2 Stable Diffusion V2 Model Card
This model was trained on 2.47 million descriptive Stable Diffusion prompts, continuing from the [FredZhang7/distilgpt2-stable-diffusion](https://huggingface.co/FredZhang7/distilgpt2-stable-diffusion) checkpoint for another 4.27 million steps.
Compared to other GPT2-based prompt generation models, this one runs with 50% faster forward propagation and uses 40% less disk space and RAM.
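The savings come from the DistilGPT2 architecture itself: 6 transformer layers instead of GPT2's 12, and roughly 82M parameters instead of 124M. A minimal sketch (not part of the original card) to compare the two yourself:
```python
from transformers import GPT2LMHeadModel

# load both architectures and compare their parameter counts
distil = GPT2LMHeadModel.from_pretrained('distilgpt2')
gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
print(f'distilgpt2: {distil.num_parameters():,} parameters')  # ~82M
print(f'gpt2:       {gpt2.num_parameters():,} parameters')    # ~124M
```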
Major improvements over v1:
- 25% more variations
- more capable of generating story-like prompts
- cleaned training data (a rough sketch of these filters follows this list)
  * removed prompts that generate images with NSFW scores > 0.5
  * removed duplicates, including prompts that differ only in capitalization or punctuation
  * removed punctuation at random positions
  * removed prompts shorter than 15 characters
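Here is what those filters could look like in code. The helper names and exact logic are assumptions for illustration; the author's actual cleaning pipeline is not published:
```python
import string

def normalize(prompt: str) -> str:
    # lowercase and strip punctuation so near-duplicate prompts collide
    return prompt.lower().translate(str.maketrans('', '', string.punctuation)).strip()

def clean(prompts, nsfw_scores):
    seen, kept = set(), []
    for prompt, score in zip(prompts, nsfw_scores):
        if score > 0.5:        # image scored NSFW > 0.5
            continue
        if len(prompt) < 15:   # prompt shorter than 15 characters
            continue
        key = normalize(prompt)
        if key in seen:        # duplicate up to capitalization/punctuation
            continue
        seen.add(key)
        kept.append(prompt)
    return kept
```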
### PyTorch
```bash
pip install --upgrade transformers
```
Faster but less fluent generation:
```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel, pipeline
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2', pad_token_id=tokenizer.eos_token_id)
prompt = r'a cat sitting'
# generate text using the fine-tuned model
nlp = pipeline('text-generation', model=model, tokenizer=tokenizer)
# generate 5 samples; sampling is required for multiple distinct sequences
outs = nlp(prompt, max_length=80, num_return_sequences=5, do_sample=True)
print('\nInput:\n' + 100 * '-')
print('\033[96m' + prompt + '\033[0m')
print('\nOutput:\n' + 100 * '-')
for i in range(len(outs)):
    # collapse double spaces left in the generated text
    outs[i] = str(outs[i]['generated_text']).replace('  ', ' ')
print('\033[92m' + '\n\n'.join(outs) + '\033[0m\n')
```
Example output:
![greedy search](./greedy_search.png)
<br>
Slower but more fluent generation:
```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = GPT2LMHeadModel.from_pretrained('FredZhang7/distilgpt2-stable-diffusion-v2')
model.eval()
prompt = r'a cat sitting'    # the beginning of the prompt
temperature = 0.9            # higher values produce more diverse results, at a higher risk of incoherent text
top_k = 8                    # the number of tokens to sample from at each step
max_length = 80              # the maximum number of tokens in the output
repetition_penalty = 1.2     # the penalty for each repeated token (values > 1 discourage repetition)
num_return_sequences = 5     # the number of sequences to generate
# generate 5 sequences with top-k sampling; repetition_penalty and
# no_repeat_ngram_size discourage the model from repeating tokens
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
output = model.generate(
    input_ids,
    do_sample=True,
    temperature=temperature,
    top_k=top_k,
    max_length=max_length,
    num_return_sequences=num_return_sequences,
    repetition_penalty=repetition_penalty,
    penalty_alpha=0.6,
    no_repeat_ngram_size=1,
    early_stopping=True,
)
print('\nInput:\n' + 100 * '-')
print('\033[96m' + prompt + '\033[0m')
print('\nOutput:\n' + 100 * '-')
for i in range(len(output)):
    print('\033[92m' + tokenizer.decode(output[i], skip_special_tokens=True) + '\033[0m\n')
```
Example output:
![contrastive search](./constrastive_search.png)
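Note: with `do_sample=True`, `transformers` performs top-k sampling and `penalty_alpha` has no effect; contrastive search proper is selected only when sampling is disabled. A minimal sketch of that variant, reusing `model`, `tokenizer`, and `input_ids` from the block above (contrastive search is deterministic, so it returns a single sequence):
```python
# contrastive search: deterministic decoding with a degeneration penalty,
# selected by transformers when do_sample=False (the default)
output = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=80)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```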