|
from bart import BartCaptionModel |
|
from audio_utils import load_audio, STR_CH_FIRST |
|
import torch |
|
|
|
# Convert a training checkpoint into a fully pickled BartCaptionModel.
#
# Steps (numbered to match the error messages):
#   1. pick a device (not used by the conversion itself),
#   2. build BartCaptionModel(max_length=MAX_LENGTH),
#   3. load the checkpoint at CHECKPOINT_PATH onto the CPU,
#   4. take its 'state_dict' entry,
#   5. load those weights into the model,
#   6. save the whole model object to OUTPUT_PATH.
#
# Any failing step stops the script with a RuntimeError naming the step,
# chained to the underlying exception. The previous bare `except: print(n)`
# handlers swallowed every error (even KeyboardInterrupt) and let later steps
# fail with confusing NameErrors, so the script could "finish" without ever
# writing OUTPUT_PATH.
from contextlib import contextmanager

CHECKPOINT_PATH = "transfer.pth"
OUTPUT_PATH = "model.pth"
MAX_LENGTH = 128


@contextmanager
def _step(number, description):
    """Re-raise any exception from the ``with`` body as a labelled RuntimeError.

    Args:
        number: Step number reported in the error message.
        description: Human-readable description of what the step does.

    Raises:
        RuntimeError: Wrapping (and chained to) the original exception.
    """
    try:
        yield
    except Exception as err:
        # Catch Exception, not everything: Ctrl-C / SystemExit must still propagate.
        raise RuntimeError(f"step {number} failed: {description}") from err


# Step 1: torch.cuda.is_available() reports False instead of raising, so no
# handler is needed. `device` is not used below: the checkpoint is mapped to
# CPU and the model is never moved to `device`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

with _step(2, f"constructing BartCaptionModel(max_length={MAX_LENGTH})"):
    model = BartCaptionModel(max_length=MAX_LENGTH)

# torch.load unpickles the file, which can execute arbitrary code: only use
# checkpoints from a trusted source.
# NOTE(review): on PyTorch versions where weights_only defaults to True, a
# checkpoint holding non-tensor objects may fail to load here — confirm.
with _step(3, f"loading checkpoint {CHECKPOINT_PATH!r}"):
    pretrained_object = torch.load(CHECKPOINT_PATH, map_location="cpu")

with _step(4, f"reading 'state_dict' from {CHECKPOINT_PATH!r}"):
    state_dict = pretrained_object["state_dict"]

# Strict loading (PyTorch's default): missing or unexpected keys raise, and
# the chained error lists them.
with _step(5, "loading the state dict into the model"):
    model.load_state_dict(state_dict)

# Saves the whole pickled module, not just its weights, so loading OUTPUT_PATH
# later requires `bart.BartCaptionModel` to be importable under the same path.
with _step(6, f"saving the model to {OUTPUT_PATH!r}"):
    torch.save(model, OUTPUT_PATH)