Pytorch实现warping操作

Posted by 刘知安 on 2020-08-16
文章目录
  1. Pytorch实现warping操作
    1. warping思路
    2. warping的实现
    3. affine_grid的使用

Pytorch实现warping操作

warping思路

如果现在有2张图片,记作为A和B,同时,我们已经得到了这两帧之间的光流,我们可以将图片B和光流进行一个warping操作,从而得回到图片A,那么这个功能在Pytorch里面怎么实现呢?

实现的思路到不是很难,如下:

  • 首先,我们需要分别定义水平方向x和竖直方向y的meshgrid(size为H,W,2),代表第二帧每个像素的位置;
  • 然后,我们将meshgrid加上光流,记作vgrid,代表第二帧每个像素对应第一帧上的那个位置(这个位置可能是小数、甚至超出图像边界);
  • 最后,按照第一帧的像素值进行插值即可。

最后插值的步骤Pytorch已经为我们实现好了,函数名为grid_sample(),所以,其实我们要做的就是定义出vgrid来。grid_sample()的参数官方文档解释的很清楚,可能有点迷惑的就是grid这个参数了。

这个参数的范围是[-1,1],-1代表图片的左边和上边,1则代表图片的右边和下边,所以,我们一定要对vgrid进行归一化,方法为:

1
2
3
# scale grid to [-1,1]
vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0

此外,还特别需要注意,输入的tensor(也就是第一帧)的范围必须是[0,1],不然输出的结果会很糟糕,类似>>这种情况<<。至于padding_mode参数则是出界时候怎么处理,官方文档已经说的很明白了。

warping的实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def optical_flow_warping(x, flo, pad_mode="zeros"):
"""
warp an image/tensor (im2) back to im1, according to the optical flow

x: [B, C, H, W] (im2)
flo: [B, 2, H, W] flow
pad_mode (optional): ref to https://pytorch.org/docs/stable/nn.functional.html#grid-sample
"zeros": use 0 for out-of-bound grid locations,
"border": use border values for out-of-bound grid locations
"""
B, C, H, W = x.size()
# mesh grid
xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
grid = torch.cat((xx, yy), 1).float()

vgrid = grid + flo # warp后,新图每个像素对应原图的位置

# scale grid to [-1,1]
vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0

vgrid = vgrid.permute(0, 2, 3, 1)
output = F.grid_sample(x, vgrid, padding_mode=pad_mode)

mask = torch.ones(x.size())
mask = F.grid_sample(mask, vgrid)

mask[mask < 0.9999] = 0
mask[mask > 0] = 1

return output * mask

>>测试结果<<

affine_grid的使用

经常和grid_sample()搭配使用的还有affine_grid()函数,这个函数的作用就是根据几何变换(2D or 3D仿射变换)生成一个grid,这个grid可以输入给grid_sample()使用,在spatial transfomer network中很常用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def affine_grid_test():
img_path = "/Users/liuzhian/PycharmProjects/VUNet/lena.jpg"
trans = transforms.Compose([
transforms.Resize(size=200),
transforms.ToTensor(), ])
img = Image.open(img_path)
img_torch = trans(img)
print(img_torch.size())
# 向右平移(0.2/2)*width,向下平移动(0.4/2)*height
theta = torch.tensor([
[1, 0, -0.2],
[0, 1, -0.4]
], dtype=torch.float)
grid = F.affine_grid(theta.unsqueeze(0), img_torch.unsqueeze(0).size())

print(grid.size())
print(grid[0, :, :, 0] * 200)

output = F.grid_sample(img_torch.unsqueeze(0), grid)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().transpose(1, 2, 0))
plt.show()