-
Notifications
You must be signed in to change notification settings - Fork 1
/
submodule.py
107 lines (80 loc) · 3.21 KB
/
submodule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import torch.nn as nn
import torch.nn.functional as F
def convbn(c1, c2, k, s, p, d):
    """Build a bias-free 2-D convolution followed by batch normalisation.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel size.
        s: Stride.
        p: Padding, used only when the convolution is not dilated (``d <= 1``).
        d: Dilation. When ``d > 1`` the padding is set to ``d`` and ``p`` is
            ignored.

    Returns:
        ``nn.Sequential`` holding the ``nn.Conv2d`` and the ``nn.BatchNorm2d``,
        in that order.
    """
    # With a 3x3 kernel at stride 1, padding == dilation keeps the spatial
    # size unchanged, which is why a dilated conv overrides the caller's `p`.
    padding = d if d > 1 else p
    return nn.Sequential(
        nn.Conv2d(c1, c2, k, s, padding=padding, dilation=d, bias=False),
        nn.BatchNorm2d(c2),
    )
def convbn_3d(c1, c2, k, s, p):
    """Build a bias-free 3-D convolution followed by batch normalisation.

    Args:
        c1: Number of input channels.
        c2: Number of output channels.
        k: Kernel size.
        s: Stride.
        p: Padding.

    Returns:
        ``nn.Sequential`` holding the ``nn.Conv3d`` and the ``nn.BatchNorm3d``,
        in that order.
    """
    return nn.Sequential(
        nn.Conv3d(c1, c2, k, s, padding=p, bias=False),
        nn.BatchNorm3d(c2),
    )
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv+BN layers and a skip connection.

    The first conv+BN is followed by a ReLU; no activation is applied after
    the residual addition.
    """

    expansion = 1

    def __init__(self, c1, c2, s, downsample, p, d):
        """Create the block.

        Args:
            c1: Number of input channels.
            c2: Number of output channels.
            s: Stride of the first convolution (the second always uses 1).
            downsample: Callable applied to the input before it is added to
                the residual branch, or ``None`` to add the input unchanged.
            p: Padding forwarded to ``convbn``.
            d: Dilation forwarded to ``convbn``.
        """
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Sequential(convbn(c1, c2, 3, s, p, d), nn.ReLU(inplace=True))
        self.conv2 = convbn(c2, c2, 3, 1, p, d)
        self.downsample = downsample
        self.stride = s

    def forward(self, x):
        """Return ``conv2(conv1(x))`` plus the (optionally downsampled) input."""
        out = self.conv2(self.conv1(x))
        shortcut = x if self.downsample is None else self.downsample(x)
        # In-place add reuses the residual branch's buffer.
        out += shortcut
        return out
def disparity_regression(x, maxdisp):
    """Collapse dimension 1 of ``x`` into a disparity-weighted sum.

    Channel ``d`` of ``x`` is multiplied by the value ``d`` and the channels
    are summed. If ``x`` holds per-pixel probabilities over dimension 1, the
    result is the expected disparity.

    Args:
        x: 4-D tensor whose dimension 1 indexes the ``maxdisp`` candidates.
        maxdisp: Number of disparity candidates (weights are ``0..maxdisp-1``).

    Returns:
        3-D tensor: ``x`` with dimension 1 reduced away.

    Raises:
        AssertionError: If ``x`` is not 4-dimensional.
    """
    assert x.dim() == 4, "expected a 4-D tensor, got %d dimensions" % x.dim()
    # Build the weights with x's dtype/device so no implicit transfer or
    # type promotion happens in the product below.
    disp_values = torch.arange(maxdisp, dtype=x.dtype, device=x.device)
    disp_values = disp_values.view(1, maxdisp, 1, 1)
    return (x * disp_values).sum(dim=1)
def build_image_pyramid(img, k):
    """Return ``k`` bilinearly rescaled versions of ``img``, coarsest first.

    Level ``i`` is ``img`` resampled by a factor of ``1 / 2**i``; the returned
    list runs from level ``k - 1`` (smallest) down to level 0 (scale 1).

    Args:
        img: 4-D tensor (bilinear interpolation requires N, C, H, W input).
        k: Number of pyramid levels.

    Returns:
        List of ``k`` tensors; empty when ``k`` is 0.
    """
    # align_corners=False is the library default; passing it explicitly pins
    # the sampling convention and avoids the default-changed warning that
    # older PyTorch releases emit for bilinear mode.
    return [
        F.interpolate(img, scale_factor=1.0 / (2 ** i), mode='bilinear',
                      align_corners=False)
        for i in reversed(range(k))
    ]
def build_concat_volume(left_ft, right_ft, maxdisp):
    """Stack left features with disparity-shifted right features.

    For every shift ``d`` in ``[0, maxdisp)`` and every column ``w >= d``::

        volume[:, :C, d, :, w] = left_ft[:, :, :, w]
        volume[:, C:, d, :, w] = right_ft[:, :, :, w - d]

    Columns ``w < d`` have no counterpart in the right features and stay zero.

    Args:
        left_ft: Tensor of shape ``(B, C, H, W)``.
        right_ft: Tensor with the same shape as ``left_ft``.
        maxdisp: Number of shifts to evaluate.

    Returns:
        Tensor of shape ``(B, 2 * C, maxdisp, H, W)`` with the dtype and
        device of ``left_ft``.
    """
    B, C, H, W = left_ft.shape
    volume = left_ft.new_zeros([B, 2 * C, maxdisp, H, W])
    # Shifts >= W leave no overlapping columns, so those slices stay zero and
    # the loop can stop at W. Slicing the right features with `:W - i`
    # (rather than `:-i`) also covers i == 0 without a special case.
    for i in range(min(maxdisp, W)):
        volume[:, :C, i, :, i:] = left_ft[:, :, :, i:]
        volume[:, C:, i, :, i:] = right_ft[:, :, :, :W - i]
    return volume
def build_minus_volume(left_ft, right_ft, maxdisp):
    """Subtract disparity-shifted right features from the left features.

    For every shift ``d`` in ``[0, maxdisp)`` and every column ``w >= d``::

        volume[:, :, d, :, w] = left_ft[:, :, :, w] - right_ft[:, :, :, w - d]

    Columns ``w < d`` have no counterpart in the right features and stay zero.

    Args:
        left_ft: Tensor of shape ``(B, C, H, W)``.
        right_ft: Tensor with the same shape as ``left_ft``.
        maxdisp: Number of shifts to evaluate.

    Returns:
        Tensor of shape ``(B, C, maxdisp, H, W)`` with the dtype and device
        of ``left_ft``.
    """
    B, C, H, W = left_ft.shape
    volume = left_ft.new_zeros([B, C, maxdisp, H, W])
    # Shifts >= W leave no overlapping columns, so those slices stay zero and
    # the loop can stop at W. Slicing the right features with `:W - i`
    # (rather than `:-i`) also covers i == 0 without a special case.
    for i in range(min(maxdisp, W)):
        volume[:, :, i, :, i:] = left_ft[:, :, :, i:] - right_ft[:, :, :, :W - i]
    return volume
def groupwise_correlation(fea1, fea2, num_groups):
    """Average the element-wise product of two feature maps within channel groups.

    The ``C`` channels are split into ``num_groups`` consecutive groups of
    ``C // num_groups`` channels; the product ``fea1 * fea2`` is averaged
    inside each group.

    Args:
        fea1: Tensor of shape ``(B, C, H, W)``.
        fea2: Tensor multiplied element-wise with ``fea1``.
        num_groups: Number of channel groups; must divide ``C``.

    Returns:
        Tensor of shape ``(B, num_groups, H, W)``.

    Raises:
        AssertionError: If ``C`` is not divisible by ``num_groups``.
    """
    B, C, H, W = fea1.shape
    assert C % num_groups == 0, (
        "channel count %d is not divisible by num_groups %d" % (C, num_groups))
    channels_per_group = C // num_groups
    product = fea1 * fea2
    grouped = product.view([B, num_groups, channels_per_group, H, W])
    cost = grouped.mean(dim=2)
    assert cost.shape == (B, num_groups, H, W), "unexpected correlation shape"
    return cost
def build_gwc_volume(left_ft, right_ft, maxdisp, num_groups):
    """Build a group-wise correlation volume over horizontal shifts.

    For every shift ``d`` in ``[0, maxdisp)``, the columns ``w >= d`` of slice
    ``d`` hold ``groupwise_correlation`` of ``left_ft[..., w]`` with
    ``right_ft[..., w - d]``. Columns ``w < d`` stay zero.

    Args:
        left_ft: Tensor of shape ``(B, C, H, W)``.
        right_ft: Tensor with the same shape as ``left_ft``.
        maxdisp: Number of shifts to evaluate.
        num_groups: Number of channel groups; must divide ``C``.

    Returns:
        Tensor of shape ``(B, num_groups, maxdisp, H, W)`` with the dtype and
        device of ``left_ft``.
    """
    B, _, H, W = left_ft.shape
    volume = left_ft.new_zeros([B, num_groups, maxdisp, H, W])
    # Shifts >= W leave no overlapping columns, so those slices stay zero and
    # the loop can stop at W. Slicing the right features with `:W - i`
    # (rather than `:-i`) also covers i == 0 without a special case.
    for i in range(min(maxdisp, W)):
        volume[:, :, i, :, i:] = groupwise_correlation(
            left_ft[:, :, :, i:], right_ft[:, :, :, :W - i], num_groups)
    return volume
if __name__ == '__main__':
    # Running the module directly only reports its own path.
    print(__file__)