BAAI
/

Anhforth commited on
Commit
bfc26a0
1 Parent(s): 181b080

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +76 -22
README.md CHANGED
@@ -121,31 +121,85 @@ image.save("./alt.png")
121
 
122
  ![alt](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/hub/alt.png)
123
 
124
- ## FlagAI Example
125
-
126
- 以下示例将为文本输入`Anime portrait of natalie portman as an anime girl by stanley artgerm lau, wlop, rossdraws, james jean, andrei riabovitchev, marc simonetti, and sakimichan, trending on artstation` 在目录`./AltDiffusionOutputs`下生成图片结果。
127
-
128
- The following example will generate image results for text input `Anime portrait of natalie portman as an anime girl by stanley artgerm lau, wlop, rossdraws, james jean, andrei riabovitchev, marc simonetti, and sakimichan, trending on artstation` under the default output directory `./AltDiffusionOutputs`
129
 
130
  ```python
 
131
  import torch
132
- from flagai.auto_model.auto_loader import AutoLoader
133
- from flagai.model.predictor.predictor import Predictor
134
-
135
- # Initialize
136
- prompt = "Anime portrait of natalie portman as an anime girl by stanley artgerm lau, wlop, rossdraws, james jean, andrei riabovitchev, marc simonetti, and sakimichan, trending on artstation"
137
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
138
-
139
-
140
- loader = AutoLoader(task_name="text2img", #contrastive learning
141
- model_name="AltDiffusion",
142
- model_dir="./checkpoints")
143
-
144
- model = loader.get_model()
145
- model.eval()
146
- model.to(device)
147
- predictor = Predictor(model)
148
- predictor.predict_generate_images(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  ```
150
 
151
 
 
121
 
122
  ![alt](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/hub/alt.png)
123
 
124
+ ## Transformers Example
 
 
 
 
125
 
126
  ```python
127
+ import os
128
  import torch
129
+ import transformers
130
+ from transformers import BertPreTrainedModel
131
+ from transformers.models.clip.modeling_clip import CLIPPreTrainedModel
132
+ from transformers.models.xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
133
+ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
134
+ from diffusers import StableDiffusionPipeline
135
+ from transformers import BertPreTrainedModel,BertModel,BertConfig
136
+ import torch.nn as nn
137
+ import torch
138
+ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
139
+ from transformers import XLMRobertaModel
140
+ from transformers.activations import ACT2FN
141
+ from typing import Optional
142
+
143
+
144
+ class RobertaSeriesConfig(XLMRobertaConfig):
145
+ def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=768,pooler_fn='cls',learn_encoder=False, **kwargs):
146
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
147
+ self.project_dim = project_dim
148
+ self.pooler_fn = pooler_fn
149
+ # self.learn_encoder = learn_encoder
150
+
151
+ class RobertaSeriesModelWithTransformation(BertPreTrainedModel):
152
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
153
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
154
+ base_model_prefix = 'roberta'
155
+ config_class= XLMRobertaConfig
156
+ def __init__(self, config):
157
+ super().__init__(config)
158
+ self.roberta = XLMRobertaModel(config)
159
+ self.transformation = nn.Linear(config.hidden_size, config.project_dim)
160
+ self.post_init()
161
+
162
+ def get_text_embeds(self,bert_embeds,clip_embeds):
163
+ return self.merge_head(torch.cat((bert_embeds,clip_embeds)))
164
+
165
+ def set_tokenizer(self, tokenizer):
166
+ self.tokenizer = tokenizer
167
+
168
+ def forward(self, input_ids: Optional[torch.Tensor] = None) :
169
+ attention_mask = (input_ids != self.tokenizer.pad_token_id).to(torch.int64)
170
+ outputs = self.base_model(
171
+ input_ids=input_ids,
172
+ attention_mask=attention_mask,
173
+ )
174
+
175
+ projection_state = self.transformation(outputs.last_hidden_state)
176
+
177
+ return (projection_state,)
178
+
179
+ model_path_encoder = "BAAI/RobertaSeriesModelWithTransformation"
180
+ model_path_diffusion = "BAAI/AltDiffusion"
181
+ device = "cuda"
182
+
183
+ seed = 12345
184
+ tokenizer = XLMRobertaTokenizer.from_pretrained(model_path_encoder, use_auth_token=True)
185
+ tokenizer.model_max_length = 77
186
+
187
+ text_encoder = RobertaSeriesModelWithTransformation.from_pretrained(model_path_encoder, use_auth_token=True)
188
+ text_encoder.set_tokenizer(tokenizer)
189
+ print("text encode loaded")
190
+ pipe = StableDiffusionPipeline.from_pretrained(model_path_diffusion,
191
+ tokenizer=tokenizer,
192
+ text_encoder=text_encoder,
193
+ use_auth_token=True,
194
+ )
195
+ print("diffusion pipeline loaded")
196
+ pipe = pipe.to(device)
197
+
198
+ prompt = "Thirty years old lee evans as a sad 19th century postman. detailed, soft focus, candle light, interesting lights, realistic, oil canvas, character concept art by munkácsy mihály, csók istván, john everett millais, henry meynell rheam, and da vinci"
199
+ with torch.no_grad():
200
+ image = pipe(prompt, guidance_scale=7.5).images[0]
201
+
202
+ image.save("3.png")
203
  ```
204
 
205