eliphatfs committed on
Commit
24cb86c
1 Parent(s): 654bd81

Avoid redundant model loading.

Files changed (1)
  1. openshape/demo/caption.py +2 -2
openshape/demo/caption.py CHANGED
@@ -2,7 +2,7 @@ from torch import nn
 import numpy as np
 import torch
 from typing import Tuple, List, Union, Optional
-from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
 from huggingface_hub import hf_hub_download
 
 
@@ -60,7 +60,7 @@ class ClipCaptionModel(nn.Module):
     def __init__(self, prefix_length: int, prefix_size: int = 512):
         super(ClipCaptionModel, self).__init__()
         self.prefix_length = prefix_length
-        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+        self.gpt = GPT2LMHeadModel(GPT2Config.from_pretrained('gpt2'))
         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
         if prefix_length > 10:  # not enough memory
             self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
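
For context, a minimal sketch of the pattern this commit applies (not taken verbatim from the repo; it assumes the caption checkpoint fetched via hf_hub_download supplies the final GPT-2 weights anyway):

from transformers import GPT2Config, GPT2LMHeadModel

# from_pretrained downloads and loads the pretrained GPT-2 weights; here that
# work is redundant because the caption checkpoint loaded later overwrites them.
# gpt = GPT2LMHeadModel.from_pretrained('gpt2')

# Building the model from its config instantiates the same architecture with
# freshly initialized parameters, skipping the extra weight download and load.
gpt = GPT2LMHeadModel(GPT2Config.from_pretrained('gpt2'))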