eliphatfs
commited on
Commit
•
24cb86c
1
Parent(s):
654bd81
Avoid redundant model loading.
Browse files
openshape/demo/caption.py
CHANGED
@@ -2,7 +2,7 @@ from torch import nn
|
|
2 |
import numpy as np
|
3 |
import torch
|
4 |
from typing import Tuple, List, Union, Optional
|
5 |
-
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
6 |
from huggingface_hub import hf_hub_download
|
7 |
|
8 |
|
@@ -60,7 +60,7 @@ class ClipCaptionModel(nn.Module):
|
|
60 |
def __init__(self, prefix_length: int, prefix_size: int = 512):
|
61 |
super(ClipCaptionModel, self).__init__()
|
62 |
self.prefix_length = prefix_length
|
63 |
-
self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
|
64 |
self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
|
65 |
if prefix_length > 10: # not enough memory
|
66 |
self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
|
|
|
2 |
import numpy as np
|
3 |
import torch
|
4 |
from typing import Tuple, List, Union, Optional
|
5 |
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
|
6 |
from huggingface_hub import hf_hub_download
|
7 |
|
8 |
|
|
|
60 |
def __init__(self, prefix_length: int, prefix_size: int = 512):
|
61 |
super(ClipCaptionModel, self).__init__()
|
62 |
self.prefix_length = prefix_length
|
63 |
+
self.gpt = GPT2LMHeadModel(GPT2Config.from_pretrained('gpt2'))
|
64 |
self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
|
65 |
if prefix_length > 10: # not enough memory
|
66 |
self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
|