FredZhang7
commited on
Commit
•
ab114ee
1
Parent(s):
0aabb74
Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,55 @@
|
|
1 |
---
|
2 |
-
license:
|
|
|
|
|
|
|
|
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
license: creativeml-openrail-m
|
3 |
+
tags:
|
4 |
+
- stable-diffusion
|
5 |
+
- prompt-generator
|
6 |
+
- distilgpt2
|
7 |
---
|
8 |
+
# Distilgpt2 Stable Diffusion Model Card
|
9 |
+
Distilgpt2 Stable Diffusion is a text-to-text model used to generate creative and coherent prompts given any text.
|
10 |
+
This model was finetuned on 2.03M stable diffusion prompts from [Stable Diffusion discord](https://huggingface.co/datasets/bartman081523/stable-diffusion-discord-prompts), [Lexica.art](https://huggingface.co/datasets/Gustavosta/Stable-Diffusion-Prompts), and (my hand-picked) [Krea.ai](./krea.ai.txt). I filtered the hand-picked prompts based on the output results from Stable Diffusion v1.4.
|
11 |
+
|
12 |
+
### PyTorch
|
13 |
+
|
14 |
+
```bash
|
15 |
+
pip install --upgrade transformers
|
16 |
+
```
|
17 |
+
|
18 |
+
```python
|
19 |
+
# download DistilGPT2 Stable Diffusion if haven't already
|
20 |
+
import os
|
21 |
+
if not os.path.exists('./distil-sd-gpt2.pt'):
|
22 |
+
import urllib.request
|
23 |
+
print('Downloading model...')
|
24 |
+
urllib.request.urlretrieve('https://huggingface.co/FredZhang7/distilgpt2-stable-diffusion/resolve/main/distil-sd-gpt2.pt', './distil-sd-gpt2.pt')
|
25 |
+
print('Model downloaded.')
|
26 |
+
|
27 |
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
28 |
+
|
29 |
+
# load the pretrained tokenizer
|
30 |
+
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
|
31 |
+
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
32 |
+
tokenizer.max_len = 512
|
33 |
+
|
34 |
+
# load the fine-tuned model
|
35 |
+
import torch
|
36 |
+
model = GPT2LMHeadModel.from_pretrained('distilgpt2')
|
37 |
+
model.load_state_dict(torch.load('model.pt'))
|
38 |
+
|
39 |
+
# generate text using fine-tuned model
|
40 |
+
from transformers import pipeline
|
41 |
+
nlp = pipeline('text-generation', model=model, tokenizer=tokenizer)
|
42 |
+
ins = "a beautiful city"
|
43 |
+
|
44 |
+
# generate 5 samples
|
45 |
+
outs = nlp(ins, max_length=80, num_return_sequences=10)
|
46 |
+
|
47 |
+
# print the 5 samples
|
48 |
+
for i in range(len(outs)):
|
49 |
+
outs[i] = str(outs[i]['generated_text']).replace(' ', '')
|
50 |
+
print('\033[96m' + ins + '\033[0m')
|
51 |
+
print('\033[93m' + '\n\n'.join(outs) + '\033[0m')
|
52 |
+
```
|
53 |
+
|
54 |
+
Example Output:
|
55 |
+
![Example Output](https://media.discordapp.net/attachments/884528247998664744/1049544706163482704/image.png?width=1440&height=479)
|