From 9a74b614370efcf2a39cf400e5a06c22b56e4afe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 29 Nov 2022 09:07:43 +0100 Subject: [PATCH 1/4] Add with_bias arg --- stable_baselines3/common/torch_layers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index f87337c62..c8620fb72 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -98,6 +98,7 @@ def create_mlp( output_dim: int, net_arch: List[int], activation_fn: Type[nn.Module] = nn.ReLU, + with_bias: bool = True, squash_output: bool = False, ) -> List[nn.Module]: """ @@ -111,23 +112,24 @@ def create_mlp( The length of this list is the number of layers. :param activation_fn: The activation function to use after each layer. + :param with_bias: If set to False, the layers will not learn an additive bias :param squash_output: Whether to squash the output using a Tanh activation function :return: """ if len(net_arch) > 0: - modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()] + modules = [nn.Linear(input_dim, net_arch[0], bias=with_bias), activation_fn()] else: modules = [] for idx in range(len(net_arch) - 1): - modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1])) + modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1], bias=with_bias)) modules.append(activation_fn()) if output_dim > 0: last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim - modules.append(nn.Linear(last_layer_dim, output_dim)) + modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias)) if squash_output: modules.append(nn.Tanh()) return modules From 8e6c2afba66bda0fce2fba0e88e51e05108766e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 29 Nov 2022 09:07:50 +0100 Subject: [PATCH 2/4] Update changelog --- docs/misc/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b4355c71d..27dc82cfb 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -17,6 +17,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Introduced mypy type checking +- Added ``with_bias`` argument to ``create_mlp`` SB3-Contrib ^^^^^^^^^^^ From a60cd1d75093b899f8d6bf6e0a3948fbb9e7eadf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 29 Nov 2022 12:13:59 +0100 Subject: [PATCH 3/4] move torch_layers to the last position --- stable_baselines3/common/torch_layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index c8620fb72..2ce0cc101 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -98,8 +98,8 @@ def create_mlp( output_dim: int, net_arch: List[int], activation_fn: Type[nn.Module] = nn.ReLU, - with_bias: bool = True, squash_output: bool = False, + with_bias: bool = True, ) -> List[nn.Module]: """ Create a multi layer perceptron (MLP), which is @@ -112,9 +112,9 @@ def create_mlp( The length of this list is the number of layers. :param activation_fn: The activation function to use after each layer. - :param with_bias: If set to False, the layers will not learn an additive bias :param squash_output: Whether to squash the output using a Tanh activation function + :param with_bias: If set to False, the layers will not learn an additive bias :return: """ From f0a305d57caf8c9b317b520ba66f5817ec8d6dc7 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 29 Nov 2022 12:19:46 +0100 Subject: [PATCH 4/4] Update version --- docs/misc/changelog.rst | 2 +- stable_baselines3/version.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 71de963cb..99a35226a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.7.0a3 (WIP) +Release 1.7.0a4 (WIP) -------------------------- Breaking Changes: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 08b6b37ca..0952a4bef 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.7.0a3 +1.7.0a4