Skip to content

Commit

Permalink
Add leaky relu and flatten to map_torch_types_to_onnx
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Hsieh <[email protected]>
  • Loading branch information
quic-klhsieh authored Apr 10, 2021
1 parent 108b8c5 commit e87c56b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 2 additions & 0 deletions TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@
nn.Sigmoid: ['Sigmoid'],
nn.Upsample: ['Upsample'],
nn.PReLU: ['PRelu'],
nn.LeakyReLU: ['LeakyRelu'],
nn.Flatten: ['Flatten'],
elementwise_ops.Add: ['Add'],
elementwise_ops.Subtract: ['Sub'],
elementwise_ops.Multiply: ['Mul'],
Expand Down
9 changes: 5 additions & 4 deletions TrainingExtensions/torch/test/python/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def __init__(self):

self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.maxpool2 = nn.MaxPool2d(2)
self.relu2 = nn.ReLU()
self.relu2 = nn.LeakyReLU()
self.flatten = nn.Flatten()

self.fc1 = nn.Linear(320, 50)
self.relu3 = nn.ReLU()
Expand All @@ -201,7 +202,7 @@ def forward(self, x1, x2):
x2 = self.relu1_b(self.maxpool1_b(self.conv1_b(x2)))
x = x1 + x2
x = self.relu2(self.maxpool2(self.conv2(x)))
x = x.view(-1, 320)
x = self.flatten(x)
x = self.relu3(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
Expand Down Expand Up @@ -711,7 +712,7 @@ def forward_pass(model, args):

activation_encodings = encodings['activation_encodings']
param_encodings = encodings['param_encodings']
self.assertEqual(15, len(activation_encodings))
self.assertEqual(16, len(activation_encodings))
self.assertIn('conv1_a.bias', param_encodings)
self.assertEqual(param_encodings['conv1_a.bias'][0]['bitwidth'], 32)
self.assertEqual(6, len(param_encodings['conv1_a.weight'][0]))
Expand All @@ -722,7 +723,7 @@ def forward_pass(model, args):

activation_encodings = encodings['activation_encodings']
param_encodings = encodings['param_encodings']
self.assertEqual(15, len(activation_encodings))
self.assertEqual(16, len(activation_encodings))
self.assertIn('conv1_a.bias', param_encodings)
self.assertEqual(param_encodings['conv1_a.bias'][0]['bitwidth'], 32)
self.assertEqual(6, len(param_encodings['conv1_a.weight'][0]))
Expand Down

0 comments on commit e87c56b

Please sign in to comment.