[CVPR'21] LIIF文章及源碼理解

文章名為:Learning Continuous Image Representation with Local Implicit Image Function,簡(jiǎn)稱LIIF,收錄在CVPR21年,源碼為:https://yinboc.github.io/liif/。由于拜讀文章之后,對(duì)實(shí)現(xiàn)比較感興趣,所以學(xué)習(xí)并分享源碼實(shí)現(xiàn)部分(有所簡(jiǎn)化,主要為了將代碼和原文對(duì)應(yīng)上。):


核心思想:首先,需要得到離散圖片的特征z,然后通過一個(gè)網(wǎng)絡(luò)f_{\theta},將特征z和連續(xù)域的坐標(biāo)x映射成目標(biāo)的預(yù)測(cè)值。

以EDSR數(shù)據(jù)集為例,作者構(gòu)建了一個(gè)Encoder(EDSR類),來完成三通道的離散圖片到特征z的映射。

def conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias)

class EDSR(nn.Module):
    def __init__(self):
        super(EDSR, self).__init__()
        # define head module
        m_head = [conv(args.n_colors, n_feats, kernel_size)]
        # define body module
        m_body = [
            ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.out_dim = n_feats
        
    def forward(self, x):
        x = self.head(x)
        res = self.body(x)
        res += x
        x = res
        return x

代碼有所簡(jiǎn)化,主要結(jié)構(gòu)是卷積(self.head)+16層ResBlock+卷積(self.body)。

然后是這個(gè)網(wǎng)絡(luò)f_{\theta},代碼中體現(xiàn)為一個(gè)MLP:

class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        layers = []
        lastv = in_dim
        for hidden in hidden_list:
            layers.append(nn.Linear(lastv, hidden))
            layers.append(nn.ReLU())
            lastv = hidden
        layers.append(nn.Linear(lastv, out_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        shape = x.shape[:-1]
        x = self.layers(x.view(-1, x.shape[-1]))
        return x.view(*shape, -1)

主要結(jié)構(gòu)是(Linear+ReLU)*4 + Linear共9層。

最簡(jiǎn)單的情況是:實(shí)際在預(yù)測(cè)的時(shí)候,以待預(yù)測(cè)下標(biāo)x_{q}相對(duì)最近的離散查詢點(diǎn)的坐標(biāo)來進(jìn)行的:

另外,結(jié)合了三個(gè)特性來優(yōu)化模型表現(xiàn):

  • Feature unfolding:個(gè)人理解:應(yīng)該是多個(gè)離散的像素合成為一個(gè)block,增加信息量。


feat = self.encoder(x) # 輸入x到前面的EDSR實(shí)例,得到特征$z$
feat = nn.functional.unfold(feat, 3, padding=1)
                    .view(feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3]) # (n, c, w, h) -> (n, c*9, block_index1, block_index2),其中,第二維是按照先塊后通道來排列的。而且block_index1和block_index2的數(shù)值對(duì)應(yīng)為w和h。
  • Local ensemble:個(gè)人理解:為了緩解圖像邊界的突變問題,把各個(gè)部分按照距離進(jìn)行加權(quán)平均和,平滑過渡邊界區(qū)域。


其中,需要先解決中心坐標(biāo)的問題。以下是根據(jù)圖像大小來生成對(duì)應(yīng)的中心坐標(biāo)。

def make_coord(shape, ranges=None, flatten=True):
    """ Make coordinates at grid centers.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret

數(shù)據(jù)集生成,主要是dataset的getitem方法:

# 用torchvision來縮放,縮放的方法是雙立方插值
def resize_fn(img, size):
    return transforms.ToTensor()(
        transforms.Resize(size, Image.BICUBIC)(
            transforms.ToPILImage()(img)))
# 把圖像轉(zhuǎn)成對(duì)應(yīng)的坐標(biāo)和數(shù)值
def to_pixel_samples(img):
    """ Convert the image to coord-RGB pairs.
        img: Tensor, (3, H, W)
    """
    coord = make_coord(img.shape[-2:])
    rgb = img.view(3, -1).permute(1, 0)
    return coord, rgb
# 
def __getitem__(self, idx):
    img_lr, img_hr = self.dataset[idx]
    p = idx / (len(self.dataset) - 1)
    w_hr = round(self.size_min + (self.size_max - self.size_min) * p) # 隨著idx的變大,目標(biāo)圖像越來越大。
    img_hr = resize_fn(img_hr, w_hr) # 縮放目標(biāo)圖像

    if self.augment: # 圖像增強(qiáng)
        if random.random() < 0.5: # 概率性反轉(zhuǎn)
            img_lr = img_lr.flip(-1)
            img_hr = img_hr.flip(-1)

    if self.gt_resize is not None: # 如果指定了目標(biāo)圖像大小,則縮放目標(biāo)圖像
        img_hr = resize_fn(img_hr, self.gt_resize)

    # 得到hr的坐標(biāo)和rgb數(shù)值
    hr_coord, hr_rgb = to_pixel_samples(img_hr)

    # hr部分只隨機(jī)取出一部分像素?(不放回取樣)
    if self.sample_q is not None:
        sample_lst = np.random.choice(
            len(hr_coord), self.sample_q, replace=False)
        hr_coord = hr_coord[sample_lst]
        hr_rgb = hr_rgb[sample_lst]
    # 得到目標(biāo)圖像對(duì)應(yīng)cell的大小。
    cell = torch.ones_like(hr_coord)
    cell[:, 0] *= 2 / img_hr.shape[-2]
    cell[:, 1] *= 2 / img_hr.shape[-1]
    # 每次是輸入圖像,目標(biāo)圖像的坐標(biāo),目標(biāo)圖像的cell大小,目標(biāo)圖像。
    return {
        'inp': img_lr,
        'coord': hr_coord,
        'cell': cell,
        'gt': hr_rgb
    }

然后再對(duì)應(yīng)到具體的操作:整體的操作是:構(gòu)造目標(biāo)圖像的坐標(biāo)coord_,然后通過這個(gè)坐標(biāo)去采樣特征和坐標(biāo),包括:從當(dāng)前輸入低分圖像的特征feat采樣到q_feat;從當(dāng)前輸入低分圖像的坐標(biāo)feat_coord采樣得到q_coord。然后構(gòu)造出相對(duì)的坐標(biāo)rel_coord,再乘上特征對(duì)應(yīng)的shape大小,得到相對(duì)的特征空間偏移,作為公式(4)中的x_{q}-v_{t}^{*}

# 先通過函數(shù)算出$z$
feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() \
    .permute(2, 0, 1) \
    .unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])

# field radius (global: [-1, 1])
rx = 2 / feat.shape[-2] / 2
ry = 2 / feat.shape[-1] / 2
preds = []
areas = []
for vx in vx_lst:
    for vy in vy_lst:
        coord_ = coord.clone()
        coord_[:, :, 0] += vx * rx + eps_shift
        coord_[:, :, 1] += vy * ry + eps_shift
        coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
        q_feat = F.grid_sample(
            feat, coord_.flip(-1).unsqueeze(1),
            mode='nearest', align_corners=False)[:, :, 0, :] \
            .permute(0, 2, 1)
        q_coord = F.grid_sample(
            feat_coord, coord_.flip(-1).unsqueeze(1),
            mode='nearest', align_corners=False)[:, :, 0, :] \
            .permute(0, 2, 1)
        rel_coord = coord - q_coord
        rel_coord[:, :, 0] *= feat.shape[-2]
        rel_coord[:, :, 1] *= feat.shape[-1]
        inp = torch.cat([q_feat, rel_coord], dim=-1)

        # cell decoding放在下面解說。
        if self.cell_decode:
            rel_cell = cell.clone()
            rel_cell[:, :, 0] *= feat.shape[-2]
            rel_cell[:, :, 1] *= feat.shape[-1]
            inp = torch.cat([inp, rel_cell], dim=-1)

        bs, q = coord.shape[:2]
        pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
        preds.append(pred)

        area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
        areas.append(area + 1e-9)

tot_area = torch.stack(areas).sum(dim=0)
if self.local_ensemble: # 為什么要做面積計(jì)算的結(jié)果翻轉(zhuǎn)?
    t = areas[0]; areas[0] = areas[3]; areas[3] = t
    t = areas[1]; areas[1] = areas[2]; areas[2] = t
ret = 0
for pred, area in zip(preds, areas):
    ret = ret + pred * (area / tot_area).unsqueeze(-1)
return ret

根據(jù)相對(duì)的坐標(biāo),rel_coord可以算出對(duì)應(yīng)部分的面積,進(jìn)而在最后以公式(4)的比例累加的方式進(jìn)行計(jì)算。(體現(xiàn)在最后8行代碼)

  • Cell decoding:個(gè)人理解:把坐標(biāo)以及對(duì)應(yīng)的小塊大小作為額外信息輸入,主要是c的加入,實(shí)驗(yàn)表明是有效果的。

先把目標(biāo)圖像的cell大小,乘上特征對(duì)應(yīng)的shape大小,變?yōu)橄鄬?duì)大小rel_cell。這部分就是公式(5)中的c

最終,把這些特征傳入MLP(即self.imnet)進(jìn)行計(jì)算,獲得預(yù)測(cè)的值。


最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

推薦閱讀更多精彩內(nèi)容