Yossh committed
Commit 1c7d0a2 · verified · 1 Parent(s): a48f85c

Update README.md

Files changed (1):
  1. README.md +45 -3
README.md CHANGED
---
license: apache-2.0
---
A model that detects dialogue text and onomatopoeia in anime illustrations.

It uses InternViT-6B-448px-V1-5 as the base model:
https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-5

You should be able to use it by attaching a classifier to the base model's pooler_output layer, something like this:
```python
import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor

# InternViT backbone with a linear classification head on its pooler_output
class CustomModel(nn.Module):
    def __init__(self, base_model, num_classes=2):
        super(CustomModel, self).__init__()
        self.base_model = base_model
        # Linear head over the pooled features, in bfloat16 to match the base model
        self.classifier = nn.Linear(base_model.config.hidden_size, num_classes).to(torch.bfloat16)

    def forward(self, x):
        outputs = self.base_model(x)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

base_model = AutoModel.from_pretrained(
    'OpenGVLab/InternViT-6B-448px-V1-5',
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True).cuda().eval()

model = CustomModel(base_model, num_classes=2).cuda().eval()
# Load the trained classifier-head weights
model.classifier.load_state_dict(torch.load("checkpoints/classifier_weights.pth"))

image = Image.open('./examples/image1.jpg').convert('RGB')

image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternViT-6B-448px-V1-5')

pixel_values = image_processor(images=image, return_tensors='pt').pixel_values.to(torch.bfloat16).cuda()

with torch.no_grad():
    outputs = model(pixel_values)
```
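The forward pass returns raw logits of shape `(1, num_classes)`. Below is a minimal sketch of turning them into a prediction; note that the mapping from class index to label (e.g. whether index 1 means "contains dialogue/onomatopoeia") is not stated in this repo, so treat it as an assumption and verify it against your checkpoint:

```python
# Convert the logits into probabilities and a predicted class index.
# NOTE: the index-to-label mapping is an assumption, not documented here;
# check it against the classifier weights you are using.
probs = torch.softmax(outputs.float(), dim=-1)  # cast up from bfloat16
pred_index = probs.argmax(dim=-1).item()
print(f"predicted class: {pred_index}, probabilities: {probs.squeeze().tolist()}")
```

Since only the small linear head is trained while the 6B-parameter backbone stays frozen, the head's weights are a tiny file and inference cost is dominated by the backbone's forward pass.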