本文主要关注潜在有效的,值得炼丹的Loss函数:
TV loss
Total Variation loss
在图像复原过程中,图像上的一点点噪声可能就会对复原的结果产生非常大的影响,因为很多复原算法都会放大噪声。这时候我们就需要在最优化问题的模型中添加一些正则项来保持图像的光滑性,TV loss是常用的一种正则项(注意是正则项,配合其他loss一起使用,约束噪声)。图片中相邻像素值的差异可以通过降低TV loss来一定程度上解决。比如降噪,对抗checkerboard等等。
1. 初始定义
Rudin等人(Rudin1990)观察到,受噪声污染的图像的总变分比无噪图像的总变分明显的大。 那么最小化TV理论上就可以最小化噪声。图片中相邻像素值的差异可以通过降低TV loss来一定程度上解决。比如降噪,对抗checkerboard等等。总变分定义为梯度幅值的积分:
其中,
,
,
是图像的支持域。限制总变分就会限制噪声。
2. 扩展定义
带阶数的TV loss 定义如下:
但是在图像中,连续域的积分就变成了像素离散域中求和,所以可以这么算:
即:求每一个像素和横向下一个像素的差的平方,加上纵向下一个像素的差的平方。然后开β/2次根。
3. 效果
The total variation (TV) loss encourages spatial smoothness in the generated image.(总变差(TV)损失促进了生成的图像中的空间平滑性。)根据论文Nonlinear total variation based noise removal algorithms的描述,当β < 1时,会出现下图左侧的小点点的artifact。当β > 1时,图像中小点点会被消除,但是代价就是图像的清晰度。效果图如下:
4. 代码实现
这两种实现都默认
,不支持
的调整。
4.1 pytorch
import torch
import torch.nn as nn
from torch.autograd import Variable
class TVLoss(nn.Module):
def __init__(self,TVLoss_weight=1):
super(TVLoss,self).__init__()
self.TVLoss_weight = TVLoss_weight
def forward(self,x):
batch_size = x.size()[0]
h_x = x.size()[2]
w_x = x.size()[3]
count_h = self._tensor_size(x[:,:,1:,:])
count_w = self._tensor_size(x[:,:,:,1:])
h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
def _tensor_size(self,t):
return t.size()[1]*t.size()[2]*t.size()[3]
def main():
# x = Variable(torch.FloatTensor([[[1,2],[2,3]],[[1,2],[2,3]]]).view(1,2,2,2), requires_grad=True)
# x = Variable(torch.FloatTensor([[[3,1],[4,3]],[[3,1],[4,3]]]).view(1,2,2,2), requires_grad=True)
# x = Variable(torch.FloatTensor([[[1,1,1], [2,2,2],[3,3,3]],[[1,1,1], [2,2,2],[3,3,3]]]).view(1, 2, 3, 3), requires_grad=True)
x = Variable(torch.FloatTensor([[[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]]]).view(1, 2, 3, 3),requires_grad=True)
addition = TVLoss()
z = addition(x)
print x
print z.data
z.backward()
print x.grad
if __name__ == '__main__':
4.2 tensorflow
def tv_loss(X, weight):
with tf.variable_scope('tv_loss'):
return weight * tf.reduce_sum(tf.image.total_variation(X))
4. 参考资料
本章节参考以下资料,作一定的整理,方便他人阅读与研究:
- wiki上关于TVLoss的描述:https://en.wikipedia.org/wiki/Total_variation_denoising,https://en.wikipedia.org/wiki/Total_variation_distance_of_probability_measures
- CSDN博客《Total Variation》
- 视频教程Denoising, deconvolution and computed tomography using total variation penalty
- 实验——基于pytorch的噪声估计网络
- pytorch的TV loss实现