-
Notifications
You must be signed in to change notification settings - Fork 5
/
diffusion.py
196 lines (166 loc) · 6.42 KB
/
diffusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import reduce
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from policy.diffusion_modules.conditional_unet1d import ConditionalUnet1D
from policy.diffusion_modules.mask_generator import LowdimMaskGenerator
class DiffusionUNetPolicy(nn.Module):
    """Diffusion action head: a 1D conditional U-Net denoised with a DDIM scheduler.

    Observation features for ``n_obs_steps`` steps are flattened into one
    global conditioning vector of size ``obs_feature_dim * n_obs_steps``.
    Action trajectories of shape ``(B, horizon, action_dim)`` are produced by
    iteratively denoising Gaussian noise.

    Args:
        action_dim: Size of a single action step (``Da``).
        horizon: Number of action steps per generated trajectory (``T``).
        n_obs_steps: Number of observation steps folded into the condition (``To``).
        obs_feature_dim: Feature size of one observation step (``Do``).
        num_inference_steps: DDIM steps used at inference; ``None`` falls back
            to the scheduler's ``num_train_timesteps``.
        diffusion_step_embed_dim, down_dims, kernel_size, n_groups,
        cond_predict_scale: Forwarded unchanged to ``ConditionalUnet1D``.
        **kwargs: Stored on the instance and forwarded to ``scheduler.step``
            by ``predict_action``.
    """

    def __init__(self,
                 action_dim,
                 horizon,
                 n_obs_steps,
                 obs_feature_dim,
                 num_inference_steps=20,
                 diffusion_step_embed_dim=256,
                 down_dims=(256, 512),
                 kernel_size=5,
                 n_groups=8,
                 cond_predict_scale=True,
                 # parameters passed to step
                 **kwargs):
        super().__init__()
        # Every observation step contributes obs_feature_dim features to the
        # single global condition vector consumed by the U-Net.
        global_cond_dim = obs_feature_dim * n_obs_steps
        self.model = ConditionalUnet1D(
            input_dim=action_dim,
            local_cond_dim=None,
            global_cond_dim=global_cond_dim,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            cond_predict_scale=cond_predict_scale,
        )
        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=100,
            beta_start=0.0001,
            beta_end=0.02,
            beta_schedule="squaredcos_cap_v2",
            clip_sample=True,
            set_alpha_to_one=True,
            steps_offset=0,
            prediction_type="epsilon",
        )
        # NOTE(review): with obs_dim=0 and action_visible=False this presumably
        # produces an all-False mask (no inpainting) — confirm against
        # LowdimMaskGenerator.
        self.mask_generator = LowdimMaskGenerator(
            action_dim=action_dim,
            obs_dim=0,
            max_n_obs_steps=n_obs_steps,
            fix_obs_steps=True,
            action_visible=False,
        )
        self.horizon = horizon
        self.obs_feature_dim = obs_feature_dim
        self.action_dim = action_dim
        self.n_obs_steps = n_obs_steps
        self.kwargs = kwargs
        if num_inference_steps is None:
            num_inference_steps = self.noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps

    # ========= helpers ============
    @staticmethod
    def _flatten_obs_features(readout, n_obs_steps, batch_size=None):
        """Reshape stacked per-step features into one global condition per sample.

        ``readout`` is expected to have a leading dimension of
        ``B * n_obs_steps``, with the ``n_obs_steps`` rows of each sample
        stored contiguously (this is the ordering ``reshape`` relies on).

        Args:
            readout: Tensor whose leading dimension is ``B * n_obs_steps``.
            n_obs_steps: Observation steps per sample (``To``).
            batch_size: Expected ``B``. If ``None``, it is inferred as
                ``readout.shape[0] // n_obs_steps``.

        Returns:
            Tensor of shape ``(B, -1)``; for ``(B*To, Do)`` input this is
            ``(B, To*Do)``.

        Raises:
            ValueError: if the leading dimension of ``readout`` is not
                consistent with ``n_obs_steps`` (and ``batch_size`` if given).
        """
        n_rows = readout.shape[0]
        if batch_size is None:
            if n_rows % n_obs_steps != 0:
                raise ValueError(
                    f"readout leading dim {n_rows} is not divisible by "
                    f"n_obs_steps={n_obs_steps}")
            batch_size = n_rows // n_obs_steps
        elif n_rows != batch_size * n_obs_steps:
            raise ValueError(
                f"readout leading dim {n_rows} != batch_size ({batch_size}) "
                f"* n_obs_steps ({n_obs_steps})")
        return readout.reshape(batch_size, -1)

    # ========= inference ============
    def conditional_sample(self,
                           condition_data, condition_mask,
                           local_cond=None, global_cond=None,
                           generator=None,
                           # keyword arguments to scheduler.step
                           **kwargs
                           ):
        """Run the reverse diffusion loop starting from Gaussian noise.

        Args:
            condition_data: Tensor fixing the output shape/dtype/device; its
                entries selected by ``condition_mask`` are re-imposed at every step.
            condition_mask: Boolean tensor indexing ``condition_data``.
            local_cond, global_cond: Conditioning passed to the U-Net.
            generator: Optional ``torch.Generator`` used for the initial noise
                and forwarded to ``scheduler.step``.
            **kwargs: Extra keyword arguments for ``scheduler.step``.

        Returns:
            The denoised trajectory, with the same shape as ``condition_data``.
        """
        model = self.model
        scheduler = self.noise_scheduler
        trajectory = torch.randn(
            size=condition_data.shape,
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator)
        # set step values (mutates the shared scheduler's timestep schedule)
        scheduler.set_timesteps(self.num_inference_steps)
        for t in scheduler.timesteps:
            # 1. apply conditioning
            trajectory[condition_mask] = condition_data[condition_mask]
            # 2. predict model output
            model_output = model(trajectory, t,
                                 local_cond=local_cond, global_cond=global_cond)
            # 3. compute previous sample: x_t -> x_t-1
            trajectory = scheduler.step(
                model_output, t, trajectory,
                generator=generator,
                **kwargs
            ).prev_sample
        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]
        return trajectory

    def predict_action(self, readout) -> torch.Tensor:
        """Sample an action trajectory conditioned on observation features.

        Args:
            readout: Observation features with leading dimension
                ``B * n_obs_steps``, as accepted by ``_flatten_obs_features``.

        Returns:
            Tensor of shape ``(B, horizon, action_dim)`` with the dtype and
            device of ``readout``.

        Raises:
            ValueError: if ``readout.shape[0]`` is not a multiple of ``n_obs_steps``.
        """
        # (B*To, ...) -> (B, To*Do): condition through a single global feature.
        global_cond = self._flatten_obs_features(readout, self.n_obs_steps)
        batch_size = global_cond.shape[0]
        # Nothing is known about the actions, so the conditioning mask is all-False.
        cond_data = torch.zeros(
            size=(batch_size, self.horizon, self.action_dim),
            device=readout.device, dtype=readout.dtype)
        cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        sample = self.conditional_sample(
            cond_data,
            cond_mask,
            local_cond=None,
            global_cond=global_cond,
            **self.kwargs)
        return sample[..., :self.action_dim]

    # ========= training ============
    def compute_loss(self, readout, actions):
        """Return the scalar denoising loss for a batch of ground-truth actions.

        Args:
            readout: Observation features with leading dimension
                ``actions.shape[0] * n_obs_steps``.
            actions: Clean action trajectories; batch is the leading dimension
                (presumably ``(B, horizon, action_dim)`` — confirm with caller).

        Returns:
            Scalar tensor: masked MSE between the U-Net prediction and the
            target selected by the scheduler's ``prediction_type``.

        Raises:
            ValueError: if ``readout`` does not hold ``n_obs_steps`` rows per
                sample, or the scheduler's ``prediction_type`` is unsupported.
        """
        trajectory = actions
        # Batch size comes from the actions; readout stacks To rows per sample.
        batch_size = trajectory.shape[0]
        global_cond = self._flatten_obs_features(
            readout, self.n_obs_steps, batch_size)  # (B, To*Do)
        # generate inpainting mask
        condition_mask = self.mask_generator(trajectory.shape)
        # Noise must match the trajectory's dtype as well as its device.
        noise = torch.randn_like(trajectory)
        # Sample a random diffusion timestep for each trajectory
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (batch_size,), device=trajectory.device, dtype=torch.long)
        # Forward diffusion: noise the clean trajectories according to the
        # noise magnitude at each sampled timestep.
        noisy_trajectory = self.noise_scheduler.add_noise(
            trajectory, noise, timesteps)
        # compute loss mask: conditioned entries do not contribute to the loss
        loss_mask = ~condition_mask
        # apply conditioning
        noisy_trajectory[condition_mask] = trajectory[condition_mask]
        # Predict the noise residual (or the clean sample)
        pred = self.model(noisy_trajectory, timesteps,
                          local_cond=None, global_cond=global_cond)
        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")
        loss = F.mse_loss(pred, target, reduction='none')
        loss = loss * loss_mask.type(loss.dtype)
        # Flatten to (B, -1) per sample, then average over everything.
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        loss = loss.mean()
        return loss