文章名為:Learning Continuous Image Representation with Local Implicit Image Function,簡(jiǎn)稱LIIF,收錄在CVPR21年,源碼為:https://yinboc.github.io/liif/。由于拜讀文章之后,對(duì)實(shí)現(xiàn)比較感興趣,所以學(xué)習(xí)并分享源碼實(shí)現(xiàn)部分(有所簡(jiǎn)化,主要為了將代碼和原文對(duì)應(yīng)上。):
核心思想:首先,需要得到離散圖片的特征,然后通過一個(gè)網(wǎng)絡(luò)
,將特征
和連續(xù)域的坐標(biāo)
映射成目標(biāo)的預(yù)測(cè)值。
以EDSR數(shù)據(jù)集為例,作者構(gòu)建了一個(gè)Encoder(EDSR類),來完成三通道的離散圖片到特征的映射。
def conv(in_channels, out_channels, kernel_size, bias=True):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size//2), bias=bias)
class EDSR(nn.Module):
def __init__(self):
super(EDSR, self).__init__()
# define head module
m_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
m_body = [
ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
) for _ in range(n_resblocks)
]
m_body.append(conv(n_feats, n_feats, kernel_size))
self.head = nn.Sequential(*m_head)
self.body = nn.Sequential(*m_body)
self.out_dim = n_feats
def forward(self, x):
x = self.head(x)
res = self.body(x)
res += x
x = res
return x
代碼有所簡(jiǎn)化,主要結(jié)構(gòu)是卷積(self.head)+16層ResBlock+卷積(self.body)。
然后是這個(gè)網(wǎng)絡(luò),代碼中體現(xiàn)為一個(gè)MLP:
class MLP(nn.Module):
def __init__(self, in_dim, out_dim, hidden_list):
super().__init__()
layers = []
lastv = in_dim
for hidden in hidden_list:
layers.append(nn.Linear(lastv, hidden))
layers.append(nn.ReLU())
lastv = hidden
layers.append(nn.Linear(lastv, out_dim))
self.layers = nn.Sequential(*layers)
def forward(self, x):
shape = x.shape[:-1]
x = self.layers(x.view(-1, x.shape[-1]))
return x.view(*shape, -1)
主要結(jié)構(gòu)是(Linear+ReLU)*4 + Linear共9層。
最簡(jiǎn)單的情況是:實(shí)際在預(yù)測(cè)的時(shí)候,以待預(yù)測(cè)下標(biāo)相對(duì)最近的離散查詢點(diǎn)的坐標(biāo)來進(jìn)行的:
另外,結(jié)合了三個(gè)特性來優(yōu)化模型表現(xiàn):
-
Feature unfolding:個(gè)人理解:應(yīng)該是多個(gè)離散的像素合成為一個(gè)block,增加信息量。
feat = self.encoder(x) # 輸入x到前面的EDSR實(shí)例,得到特征$z$
feat = nn.functional.unfold(feat, 3, padding=1)
.view(feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3]) # (n, c, w, h) -> (n, c*9, block_index1, block_index2),其中,第二維是按照先塊后通道來排列的。而且block_index1和block_index2的數(shù)值對(duì)應(yīng)為w和h。
-
Local ensemble:個(gè)人理解:為了緩解圖像邊界的突變問題,把各個(gè)部分按照距離進(jìn)行加權(quán)平均和,平滑過渡邊界區(qū)域。
其中,需要先解決中心坐標(biāo)的問題。以下是根據(jù)圖像大小來生成對(duì)應(yīng)的中心坐標(biāo)。
def make_coord(shape, ranges=None, flatten=True):
""" Make coordinates at grid centers.
"""
coord_seqs = []
for i, n in enumerate(shape):
if ranges is None:
v0, v1 = -1, 1
else:
v0, v1 = ranges[i]
r = (v1 - v0) / (2 * n)
seq = v0 + r + (2 * r) * torch.arange(n).float()
coord_seqs.append(seq)
ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
if flatten:
ret = ret.view(-1, ret.shape[-1])
return ret
數(shù)據(jù)集生成,主要是dataset的getitem方法:
# 用torchvision來縮放,縮放的方法是雙立方插值
def resize_fn(img, size):
return transforms.ToTensor()(
transforms.Resize(size, Image.BICUBIC)(
transforms.ToPILImage()(img)))
# 把圖像轉(zhuǎn)成對(duì)應(yīng)的坐標(biāo)和數(shù)值
def to_pixel_samples(img):
""" Convert the image to coord-RGB pairs.
img: Tensor, (3, H, W)
"""
coord = make_coord(img.shape[-2:])
rgb = img.view(3, -1).permute(1, 0)
return coord, rgb
#
def __getitem__(self, idx):
img_lr, img_hr = self.dataset[idx]
p = idx / (len(self.dataset) - 1)
w_hr = round(self.size_min + (self.size_max - self.size_min) * p) # 隨著idx的變大,目標(biāo)圖像越來越大。
img_hr = resize_fn(img_hr, w_hr) # 縮放目標(biāo)圖像
if self.augment: # 圖像增強(qiáng)
if random.random() < 0.5: # 概率性反轉(zhuǎn)
img_lr = img_lr.flip(-1)
img_hr = img_hr.flip(-1)
if self.gt_resize is not None: # 如果指定了目標(biāo)圖像大小,則縮放目標(biāo)圖像
img_hr = resize_fn(img_hr, self.gt_resize)
# 得到hr的坐標(biāo)和rgb數(shù)值
hr_coord, hr_rgb = to_pixel_samples(img_hr)
# hr部分只隨機(jī)取出一部分像素?(不放回取樣)
if self.sample_q is not None:
sample_lst = np.random.choice(
len(hr_coord), self.sample_q, replace=False)
hr_coord = hr_coord[sample_lst]
hr_rgb = hr_rgb[sample_lst]
# 得到目標(biāo)圖像對(duì)應(yīng)cell的大小。
cell = torch.ones_like(hr_coord)
cell[:, 0] *= 2 / img_hr.shape[-2]
cell[:, 1] *= 2 / img_hr.shape[-1]
# 每次是輸入圖像,目標(biāo)圖像的坐標(biāo),目標(biāo)圖像的cell大小,目標(biāo)圖像。
return {
'inp': img_lr,
'coord': hr_coord,
'cell': cell,
'gt': hr_rgb
}
然后再對(duì)應(yīng)到具體的操作:整體的操作是:構(gòu)造目標(biāo)圖像的坐標(biāo)coord_,然后通過這個(gè)坐標(biāo)去采樣特征和坐標(biāo),包括:從當(dāng)前輸入低分圖像的特征feat采樣到q_feat;從當(dāng)前輸入低分圖像的坐標(biāo)feat_coord采樣得到q_coord。然后構(gòu)造出相對(duì)的坐標(biāo)rel_coord,再乘上特征對(duì)應(yīng)的shape大小,得到相對(duì)的特征空間偏移,作為公式(4)中的
# 先通過函數(shù)算出$z$
feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() \
.permute(2, 0, 1) \
.unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])
# field radius (global: [-1, 1])
rx = 2 / feat.shape[-2] / 2
ry = 2 / feat.shape[-1] / 2
preds = []
areas = []
for vx in vx_lst:
for vy in vy_lst:
coord_ = coord.clone()
coord_[:, :, 0] += vx * rx + eps_shift
coord_[:, :, 1] += vy * ry + eps_shift
coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
q_feat = F.grid_sample(
feat, coord_.flip(-1).unsqueeze(1),
mode='nearest', align_corners=False)[:, :, 0, :] \
.permute(0, 2, 1)
q_coord = F.grid_sample(
feat_coord, coord_.flip(-1).unsqueeze(1),
mode='nearest', align_corners=False)[:, :, 0, :] \
.permute(0, 2, 1)
rel_coord = coord - q_coord
rel_coord[:, :, 0] *= feat.shape[-2]
rel_coord[:, :, 1] *= feat.shape[-1]
inp = torch.cat([q_feat, rel_coord], dim=-1)
# cell decoding放在下面解說。
if self.cell_decode:
rel_cell = cell.clone()
rel_cell[:, :, 0] *= feat.shape[-2]
rel_cell[:, :, 1] *= feat.shape[-1]
inp = torch.cat([inp, rel_cell], dim=-1)
bs, q = coord.shape[:2]
pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
preds.append(pred)
area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
areas.append(area + 1e-9)
tot_area = torch.stack(areas).sum(dim=0)
if self.local_ensemble: # 為什么要做面積計(jì)算的結(jié)果翻轉(zhuǎn)?
t = areas[0]; areas[0] = areas[3]; areas[3] = t
t = areas[1]; areas[1] = areas[2]; areas[2] = t
ret = 0
for pred, area in zip(preds, areas):
ret = ret + pred * (area / tot_area).unsqueeze(-1)
return ret
根據(jù)相對(duì)的坐標(biāo),rel_coord可以算出對(duì)應(yīng)部分的面積,進(jìn)而在最后以公式(4)的比例累加的方式進(jìn)行計(jì)算。(體現(xiàn)在最后8行代碼)
- Cell decoding:個(gè)人理解:把坐標(biāo)以及對(duì)應(yīng)的小塊大小作為額外信息輸入,主要是
的加入,實(shí)驗(yàn)表明是有效果的。
先把目標(biāo)圖像的cell大小,乘上特征對(duì)應(yīng)的shape大小,變?yōu)橄鄬?duì)大小rel_cell。這部分就是公式(5)中的。
最終,把這些特征傳入MLP(即self.imnet)進(jìn)行計(jì)算,獲得預(yù)測(cè)的值。