This model detects dialogue text and sound effects (onomatopoeia) in anime illustrations. It uses InternViT-6B-448px-V1-5 as its base model: https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-5

Attaching a classification head to the base model's pooler_output, roughly as shown below, should work:

import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor

class CustomModel(nn.Module):
    """Wraps the base ViT and adds a linear classification head on its pooled output."""
    def __init__(self, base_model, num_classes=2):
        super().__init__()
        self.base_model = base_model
        # The head is cast to bfloat16 to match the base model's dtype.
        self.classifier = nn.Linear(base_model.config.hidden_size, num_classes).to(torch.bfloat16)

    def forward(self, x):
        outputs = self.base_model(x)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# Load the InternViT backbone in bfloat16 on the GPU.
base_model = AutoModel.from_pretrained(
    'OpenGVLab/InternViT-6B-448px-V1-5',
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True).cuda().eval()

# Wrap the backbone with the classification head and load the trained head weights.
device = torch.device('cuda')
model = CustomModel(base_model, num_classes=2).to(device).eval()
model.classifier.load_state_dict(torch.load("checkpoints/classifier_weights.pth", map_location=device))

# Preprocess an example image into the 448x448 pixel values the model expects.
image = Image.open('./examples/image1.jpg').convert('RGB')

image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternViT-6B-448px-V1-5')

pixel_values = image_processor(images=image, return_tensors='pt').pixel_values.to(torch.bfloat16).cuda()

with torch.no_grad():
    outputs = model(pixel_values)
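
The logits can then be turned into a prediction. A minimal sketch; note that the meaning of each class index (e.g. 1 = "contains dialogue/onomatopoeia") is an assumption, since the label mapping is not documented here:

# Convert logits to probabilities and pick the most likely class.
# NOTE: the class-index meaning (e.g. 1 = "text present") is an assumption;
# check it against the labels used when the classifier head was trained.
probs = torch.softmax(outputs.float(), dim=-1)
pred = probs.argmax(dim=-1).item()
print(f"predicted class: {pred}, probabilities: {probs.squeeze().tolist()}")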