Pytorch实现warping操作
warping思路
如果现在有2张图片,记作为A和B,同时,我们已经得到了这两帧之间的光流,我们可以将图片B和光流进行一个warping操作,从而得回到图片A,那么这个功能在Pytorch里面怎么实现呢?
实现的思路到不是很难,如下:
- 首先,我们需要分别定义水平方向x和竖直方向y的meshgrid(size为
H,W,2
),代表第二帧每个像素的位置; - 然后,我们将meshgrid加上光流,记作
vgrid
,代表第二帧每个像素对应第一帧上的那个位置(这个位置可能是小数、甚至超出图像边界); - 最后,按照第一帧的像素值进行插值即可。
最后插值的步骤Pytorch已经为我们实现好了,函数名为grid_sample()
,所以,其实我们要做的就是定义出vgrid
来。grid_sample()
的参数官方文档解释的很清楚,可能有点迷惑的就是grid
这个参数了。
这个参数的范围是[-1,1],-1代表图片的左边和上边,1则代表图片的右边和下边,所以,我们一定要对vgrid
进行归一化,方法为:
1 | # scale grid to [-1,1] |
此外,还特别需要注意,输入的tensor(也就是第一帧)的范围必须是[0,1],不然输出的结果会很糟糕,类似>>这种情况<<。至于padding_mode
参数则是出界时候怎么处理,官方文档已经说的很明白了。
warping的实现
1 | def optical_flow_warping(x, flo, pad_mode="zeros"): |
affine_grid的使用
经常和grid_sample()
搭配使用的还有affine_grid()
函数,这个函数的作用就是根据几何变换(2D or 3D仿射变换)生成一个grid,这个grid可以输入给grid_sample()
使用,在spatial transfomer network中很常用。
1 | def affine_grid_test(): |