Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sb3_contrib/ars/policies.py type hint #122

Merged
merged 17 commits into from
Dec 12, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@
Changelog
==========

Release 1.7.0a1 (WIP)
Release 1.7.0a4 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters,
please use an ``EvalCallback`` instead
- Removed deprecated ``sde_net_arch`` parameter
- Upgraded to Stable-Baselines3 >= 1.7.0a4

New Features:
^^^^^^^^^^^^^
- Introduced mypy type checking
- Added ``with_bias`` parameter to ``ARSPolicy``

Bug Fixes:
^^^^^^^^^^
Expand All @@ -29,6 +31,7 @@ Others:
- Fixed flake8 config
- Fixed ``sb3_contrib/common/utils.py`` type hint
- Fixed ``sb3_contrib/common/recurrent/type_aliases.py`` type hint
- Fixed ``sb3_contrib/ars/policies.py`` type hint

Release 1.6.2 (2022-10-10)
--------------------------
Expand Down
20 changes: 7 additions & 13 deletions sb3_contrib/ars/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class ARSPolicy(BasePolicy):
:param action_space: The action space of the environment
:param net_arch: Network architecture, defaults to a 2 layers MLP with 64 hidden nodes.
:param activation_fn: Activation function
:param with_bias: If set to False, the layers will not learn an additive bias
:param squash_output: For continuous actions, whether the output is squashed
or not using a ``tanh()`` function. If not squashed with tanh the output will instead be clipped.
"""
Expand All @@ -26,6 +27,7 @@ def __init__(
action_space: gym.spaces.Space,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
with_bias: bool = True,
squash_output: bool = True,
):

Expand All @@ -45,9 +47,11 @@ def __init__(

if isinstance(action_space, gym.spaces.Box):
action_dim = get_action_dim(action_space)
actor_net = create_mlp(self.features_dim, action_dim, net_arch, activation_fn, squash_output=True)
actor_net = create_mlp(
self.features_dim, action_dim, net_arch, activation_fn, with_bias=with_bias, squash_output=squash_output
)
elif isinstance(action_space, gym.spaces.Discrete):
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn)
actor_net = create_mlp(self.features_dim, action_space.n, net_arch, activation_fn, with_bias=with_bias)
else:
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")

Expand Down Expand Up @@ -98,17 +102,7 @@ def __init__(
squash_output: bool = False,
):

super().__init__(observation_space, action_space, squash_output=squash_output)

if isinstance(action_space, gym.spaces.Box):
action_dim = get_action_dim(action_space)
self.action_net = nn.Linear(self.features_dim, action_dim, bias=with_bias)
if squash_output:
self.action_net = nn.Sequential(self.action_net, nn.Tanh())
elif isinstance(action_space, gym.spaces.Discrete):
self.action_net = nn.Linear(self.features_dim, action_space.n, bias=with_bias)
else:
araffin marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError(f"Error: ARS policy not implemented for action space of type {type(action_space)}.")
super().__init__(observation_space, action_space, net_arch=[], with_bias=with_bias, squash_output=squash_output)


MlpPolicy = ARSPolicy
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.7.0a1
1.7.0a4
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ show_error_codes = True
exclude = (?x)(
sb3_contrib/ppo_recurrent/ppo_recurrent.py$
| sb3_contrib/ars/ars.py$
| sb3_contrib/ars/policies.py$
| sb3_contrib/qrdqn/qrdqn.py$
| sb3_contrib/qrdqn/policies.py$
| sb3_contrib/common/recurrent/policies.py$
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=1.7.0a0",
"stable_baselines3>=1.7.0a4",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
Expand Down