Learning Continuous Image Representation with Local Implicit Image Function
- abstract
- Local Implicit Image Function
- Feature unfolding
- Local ensemble
- Cell decoding
- LIIF class 完全代码
abstract
物理世界以连续的方式呈现视觉图像,但计算机以离散2D像素数组的方式存储和显示图像。此文学习图像的连续表示,使用局部隐式图像函数(Local Implicit Image Function,LIIF)将图像坐标和坐标周围的2D深度特征作为输入,预测输出给定坐标下的RGB值。通过自监督超分辨率任务来训练一个编码器和LIIF表示来生成像素图像的连续表示,可以做到任意倍数的分辨率,甚至可以推算不在训练任务中的30倍以上超分。通过将图像模型化为一个在连续域中的函数,可以恢复和生成任意分辨率的图像。隐式函数的思想是将一个对象表示为一个函数,将坐标映射到相应的信号(如3D对象表面的符号距离,图像中的RGB值)。神经隐式函数采用深度神经网络参数化。为了跨实例共享知识,而不是为每个对象拟合单独的隐式函数,提出了基于编码器的方法来预测每个对象的潜在编码。然后隐式函数由所有对象共享,同时它将潜在代码作为额外的输入。
Local Implicit Image Function
在LIIF表示中,每个连续图像由二维特征映射
表示。 一个神经隐式函数
(以
为其参数)被所有图像共享,它被参数化为
并采取
(简便省略
)形式,其中
是一个向量,
是连续图像域中的二维坐标,
是预测信号(即RGB值)。
对于定义的,每个向量
都可以看作是表示函数
。
可以看作是一个连续的图像,即映射坐标到RGB值的函数。假设
的
特征向量(称为隐码latent codes)均匀分布在
的连续图像域的2D空间中,并为它们中的每一个分配一个2D坐标。
对于图像,坐标
处的RGB值定义为
,其中
是
中与
最近的(欧几里德距离)隐码,
是图像域中潜码
的坐标。 例如
是当前定义中
的
,而
被定义为
的坐标。
在所有图像共享的隐式函数下,连续图像由二维特征映射
表示,该特征映射被看作是在2D域中均匀分布的
隐码。 在
中的每个潜在码
表示连续图像的局部部分,负责预测与它最近的坐标集的信号。
从图像得到归一化坐标值和RGB值
def make_coord(shape, ranges=None, flatten=True):
""" Make coordinates at grid centers.
"""
coord_seqs = []
for i, n in enumerate(shape):
if ranges is None:
v0, v1 = -1, 1
else:
v0, v1 = ranges[i]
r = (v1 - v0) / (2 * n)
seq = v0 + r + (2 * r) * torch.arange(n).float()
coord_seqs.append(seq)
ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
if flatten:
ret = ret.view(-1, ret.shape[-1])
return ret
coord = make_coord((h, w)) #h,w为SR目标的高宽
def to_pixel_samples(img):
""" Convert the image to coord-RGB pairs.
img: Tensor, (3, H, W)
"""
coord = make_coord(img.shape[-2:]) #(h*w,2)--(h*w,[x,y])
rgb = img.view(3, -1).permute(1, 0) #(h*w,3)--(h*w,[R,G,B])
return coord, rgb
Feature unfolding
为了丰富隐码包含的信息,对特征展开得到
。
是在
中
相邻隐码的合并。
当指的是一组向量的连接时,
在其边界外被零向量填充。
的
被以下
后变为
。
feat = F.unfold(feat, 3, padding=1).view(feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])
Local ensemble
是一个不连续预测,由于
的信号预测是通过查询
中最近的隐码
完成的,所以当
在图像域中移动时,
会突然从一个隐码切换到另一个隐码。在
选择切换的那些坐标周围,两个无限接近坐标的信号将从不同的隐码中预测出来,只要学习的隐式函数
不是完美的,在
选择切换的边界处没出现不连续的图形。为了解决这个问题,使用局部集成技术,扩大每个隐码的表示
指左上、右上,左下,右下子空间中最近的隐码,
指
的坐标,
是
和
(
是
的对角,如00对11,10对01)之间的矩形面积。权重由
归一化。特征图
在边界外是镜像填充的,因此这也适用于边界附近的坐标。
这是为了让由隐码表示的局部图像块与其相邻块重叠,使得在每个坐标处有四个隐码用于独立预测信号。然后,这四个预测通过用归一化置信度投票来合并,归一化置信度与查询点和其最近的隐码对角对应点之间的矩形面积成比例,因此当查询坐标更近时,置信度变得更高。通过这种投票,它在转换坐标(即图中的虚线)处实现了连续过渡。
vx_lst = [-1, 1]
vy_lst = [-1, 1]
eps_shift = 1e-6
rx = 2 / feat.shape[-2] / 2 #2/H/2
ry = 2 / feat.shape[-1] / 2 #2/W/2
feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() #[LR_H,LR_W,2]
feat_coord = feat_coord.permute(2, 0, 1).unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])#[N,2,LR_H,LR_W]
preds = []
areas = []
for vx in vx_lst:
for vy in vy_lst:
coord_ = coord.clone()#[N,SR_H*SR_W,2]
coord_[:, :, 0] += vx * rx + eps_shift
coord_[:, :, 1] += vy * ry + eps_shift
coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
q_feat = F.grid_sample(feat, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,C*9,1,SR_H*SR_W]
q_feat = q_feat[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,C*9]
q_coord = F.grid_sample(feat_coord, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,2,1,SR_H*SR_W]
q_coord = q_coord[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,2]
rel_coord = coord - q_coord #[N,SR_H*SR_W,2]
rel_coord[:, :, 0] *= feat.shape[-2]
rel_coord[:, :, 1] *= feat.shape[-1]
inp = torch.cat([q_feat, rel_coord], dim=-1) #[N,SR_H*SR_W,C*9+2]
if self.cell_decode:
rel_cell = cell.clone()
rel_cell[:, :, 0] *= feat.shape[-2]
rel_cell[:, :, 1] *= feat.shape[-1]
inp = torch.cat([inp, rel_cell], dim=-1) #[N,SR_H*SR_W,C*9+2+2]
bs, q = coord.shape[:2] #bs=N q=SR_H*SR_W
#[N*SR_H*SR_W,C*9+2+2] --> [N*SR_H*SR_W,3]
pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1) #[N,SR_H*SR_W,3]
preds.append(pred) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]
area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
areas.append(area + 1e-9) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]
tot_area = torch.stack(areas).sum(dim=0) #[N,SR_H*SR_W]
if self.local_ensemble:
t = areas[0]; areas[0] = areas[3]; areas[3] = t #swap(areas[0],areas[3])
t = areas[1]; areas[1] = areas[2]; areas[2] = t #swap(areas[1],areas[2])
Cell decoding
为了LIIF能够表示基于像素形式的任意分辨率呈现,假设给定了所需分辨率,一种简单方法是查询连续表示中像素中心坐标处的RGB值,但因为查询像素的预测RGB值与其大小无关,其像素区域中的信息除了中心值都被丢弃,可能不是最佳的。
包含指定查询像素的高度和宽度两个值,
是值
和
的连接(concatenation),
是附加输入。
能理解为使用形状
渲染以坐标
为中心的像素的RGB值。对于
的分辨率,
是图像宽度的
。逻辑上,当
时,
,即连续图像可以看作像素无限小的图像。
cell = torch.ones_like(coord) #[SR_H*SR_W,2] [1*2/SR_H,1*2/SR_W]
cell[:, 0] *= 2 / h
cell[:, 1] *= 2 / w
if self.cell_decode:
rel_cell = cell.clone()
rel_cell[:, :, 0] *= feat.shape[-2]
rel_cell[:, :, 1] *= feat.shape[-1]
inp = torch.cat([inp, rel_cell], dim=-1) #[N,SR_H*SR_W,C*9+2+2]
LIIF class 完全代码
class LIIF(nn.Module):
def __init__(self, encoder_spec, imnet_spec=None,
local_ensemble=True, feat_unfold=True, cell_decode=True):
super().__init__()
self.local_ensemble = local_ensemble
self.feat_unfold = feat_unfold
self.cell_decode = cell_decode
self.encoder = models.make(encoder_spec)
#print("self.encoder.out_dim",self.encoder.out_dim)
if imnet_spec is not None:
imnet_in_dim = self.encoder.out_dim #64
if self.feat_unfold:
imnet_in_dim *= 9
imnet_in_dim += 2 # attach coord 指定查询像素的坐标 [x,y]
if self.cell_decode:
imnet_in_dim += 2 #[Cell_h, Cell_w]指定查询像素的高度和宽度的两个值
self.imnet = models.make(imnet_spec, args={'in_dim': imnet_in_dim})
else:
self.imnet = None
def gen_feat(self, inp):
self.feat = self.encoder(inp)
return self.feat
def query_rgb(self, coord, cell=None):
#coord [N,SR_H*SR_*W,2]
#cell [N,SR_H*SR_*W,2]
feat = self.feat #[N,C,LR_H,LR_W]
if self.imnet is None:
ret = F.grid_sample(feat, coord.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)
ret = ret[:, :, 0, :].permute(0, 2, 1)
return ret
if self.feat_unfold:
# [N,C*3*3,H,W]
feat = F.unfold(feat, 3, padding=1).view(feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])
if self.local_ensemble:
vx_lst = [-1, 1]
vy_lst = [-1, 1]
eps_shift = 1e-6
else:
vx_lst, vy_lst, eps_shift = [0], [0], 0
# field radius (global: [-1, 1])
rx = 2 / feat.shape[-2] / 2 #2/H/2
ry = 2 / feat.shape[-1] / 2 #2/W/2
feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() #[LR_H,LR_W,2]
feat_coord = feat_coord.permute(2, 0, 1).unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])#[N,2,LR_H,LR_W]
preds = []
areas = []
for vx in vx_lst:
for vy in vy_lst:
coord_ = coord.clone()#[N,SR_H*SR_W,2]
coord_[:, :, 0] += vx * rx + eps_shift
coord_[:, :, 1] += vy * ry + eps_shift
coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
q_feat = F.grid_sample(feat, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,C*9,1,SR_H*SR_W]
q_feat = q_feat[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,C*9]
q_coord = F.grid_sample(feat_coord, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,2,1,SR_H*SR_W]
q_coord = q_coord[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,2]
rel_coord = coord - q_coord #[N,SR_H*SR_W,2]
rel_coord[:, :, 0] *= feat.shape[-2]
rel_coord[:, :, 1] *= feat.shape[-1]
inp = torch.cat([q_feat, rel_coord], dim=-1) #[N,SR_H*SR_W,C*9+2]
if self.cell_decode:
rel_cell = cell.clone()
rel_cell[:, :, 0] *= feat.shape[-2]
rel_cell[:, :, 1] *= feat.shape[-1]
inp = torch.cat([inp, rel_cell], dim=-1) #[N,SR_H*SR_W,C*9+2+2]
bs, q = coord.shape[:2] #bs=N q=SR_H*SR_W
#[N*SR_H*SR_W,C*9+2+2] --> [N*SR_H*SR_W,3]
pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1) #[N,SR_H*SR_W,3]
preds.append(pred) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]
area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
areas.append(area + 1e-9) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]
tot_area = torch.stack(areas).sum(dim=0) #[N,SR_H*SR_W]
if self.local_ensemble:
t = areas[0]; areas[0] = areas[3]; areas[3] = t #swap(areas[0],areas[3])
t = areas[1]; areas[1] = areas[2]; areas[2] = t #swap(areas[1],areas[2])
ret = 0
for pred, area in zip(preds, areas):
ret = ret + pred * (area / tot_area).unsqueeze(-1)
return ret
def forward(self, inp, coord, cell):
self.gen_feat(inp)
return self.query_rgb(coord, cell)