akhaliq's picture
akhaliq HF staff
add files
89dc200
raw
history blame
1.89 kB
# -*- encoding: utf-8 -*-
'''
@File : sr_group.py
@Time : 2022/04/02 01:17:21
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import numpy as np
import torch
import torch.nn.functional as F
from SwissArmyTransformer.resources import auto_create
from .direct_sr import DirectSuperResolution
from .iterative_sr import IterativeSuperResolution
class SRGroup:
    """Two-stage super-resolution pipeline: direct SR, then iterative SR.

    ``sr_base`` runs a batch of base-resolution image token sequences through
    a ``DirectSuperResolution`` model and refines the trailing image tokens of
    its output with an ``IterativeSuperResolution`` model.
    """

    # Image tokens expected per row on input to sr_base (presumably a 20x20
    # token grid -- confirm against DirectSuperResolution).
    BASE_IMAGE_TOKENS = 400
    # Trailing tokens of the DSR output forwarded to the iterative stage
    # (presumably a 60x60 token grid -- confirm against the models).
    SR_IMAGE_TOKENS = 3600

    def __init__(self, args, home_path=None):
        """Load both super-resolution stages.

        Args:
            args: Options object forwarded unchanged to both
                ``DirectSuperResolution`` and ``IterativeSuperResolution``.
            home_path: Passed as ``path=`` to ``auto_create`` when resolving
                the ``cogview2-dsr`` and ``cogview2-itersr`` checkpoints;
                ``None`` leaves the location up to ``auto_create``.
        """
        dsr_path = auto_create('cogview2-dsr', path=home_path)
        itersr_path = auto_create('cogview2-itersr', path=home_path)
        dsr = DirectSuperResolution(args, dsr_path)
        # The iterative stage receives the DSR model's transformer instead of
        # building its own, so both stages share that module (presumably to
        # avoid holding two copies of the weights -- confirm).
        itersr = IterativeSuperResolution(
            args, itersr_path, shared_transformer=dsr.model.transformer
        )
        self.dsr = dsr
        self.itersr = itersr

    def sr_base(self, img_tokens, txt_tokens):
        """Super-resolve a batch of base-resolution image token sequences.

        Args:
            img_tokens: 2-D tensor of shape ``(batch, BASE_IMAGE_TOKENS)``.
            txt_tokens: Text tokens shared by the batch, either 1-D
                ``(txt_len,)`` or 2-D ``(1, txt_len)``, which are broadcast
                to every row. Alternatively, 2-D ``(batch, txt_len)`` with one
                row per image.

        Returns:
            The last ``batch`` rows of the iterative stage's output.

        Raises:
            ValueError: If ``img_tokens`` is not 2-D with
                ``BASE_IMAGE_TOKENS`` columns, or if a 2-D ``txt_tokens`` has
                a leading dimension other than 1 or ``batch``.
        """
        # Explicit check instead of ``assert`` so validation survives ``-O``.
        if (img_tokens.dim() != 2
                or img_tokens.shape[-1] != self.BASE_IMAGE_TOKENS):
            raise ValueError(
                f"img_tokens must have shape (batch, {self.BASE_IMAGE_TOKENS}), "
                f"got {tuple(img_tokens.shape)}"
            )
        batch_size = img_tokens.shape[0]
        txt_len = txt_tokens.shape[-1]

        # Broadcast a single text prompt across the whole image batch; a
        # 1-D prompt and a one-row 2-D prompt are treated the same way.
        if txt_tokens.dim() == 1:
            txt_tokens = txt_tokens.unsqueeze(0)
        if txt_tokens.dim() == 2:
            if txt_tokens.shape[0] == 1:
                txt_tokens = txt_tokens.expand(batch_size, txt_len)
            elif txt_tokens.shape[0] != batch_size:
                raise ValueError(
                    f"txt_tokens batch dimension {txt_tokens.shape[0]} does "
                    f"not match img_tokens batch size {batch_size}"
                )

        sred_tokens = self.dsr(txt_tokens, img_tokens)
        # Only the trailing SR_IMAGE_TOKENS columns of the DSR output go to the
        # iterative stage; any leading columns (presumably the text prefix)
        # are dropped. clone() gives itersr its own copy, not a view.
        iter_tokens = self.itersr(
            txt_tokens, sred_tokens[:, -self.SR_IMAGE_TOKENS:].clone()
        )
        # Keep the last batch_size rows (presumably itersr can return more
        # rows than the input batch -- TODO confirm).
        return iter_tokens[-batch_size:]
# def sr_patch(self, img_tokens, txt_tokens):
# assert img_tokens.shape[-1] == 3600 and len(img_tokens.shape) == 2
# batch_size = img_tokens.shape[0] * 9
# txt_len = txt_tokens.shape[-1]
# if len(txt_tokens.shape) == 1:
# txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
# img_tokens = img_tokens.view(img_tokens.shape[0], 3, 20, 3, 20).permute(0, 1, 3, 2, 4).reshape(batch_size, 400)
# iter_tokens = self.sr_base(img_tokens, txt_tokens)
# return iter_tokens