-
Notifications
You must be signed in to change notification settings - Fork 177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix sb3_contrib/ars/policies.py
type hint
#122
Conversation
and to be released too |
One thing that need to be checked before merging: does it break pre-trained models? (we might have to do some manual renaming in case it does) |
All |
Probably because
to
|
from sb3_contrib import ARS
from torch import nn
FOLDER = "rl-trained-agents"
env_id = "CartPole-v1"
for model_name in ["best_model", env_id]:
model = ARS.load(f"{FOLDER}/ars/{env_id}_1/{model_name}.zip")
model.policy.action_net = nn.Sequential(model.policy.action_net)
model.save(f"{FOLDER}/ars/{env_id}_1/{model_name}.zip") It seems to have worked. Does that look right to you? |
For HF, I think this should also work, but I'm not familiar enough with HF Hub to be sure: from huggingface_sb3 import load_from_hub, push_to_hub
from torch import nn
from sb3_contrib import ARS
checkpoint = load_from_hub(
repo_id="sb3/ars-CartPole-v1",
filename="ars-CartPole-v1.zip",
)
model = ARS.load(checkpoint)
model.policy.action_net = nn.Sequential(model.policy.action_net)
push_to_hub(
repo_id="sb3/ars-CartPole-v1",
filename="ars-CartPole-v1.zip",
commit_message="Update action_net structure",
) |
I fixed that in my latest commit ;) (the name of the commit is wrong though ^^ read "load" instead of "save") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, please check what I did before merging ;)
Looks good, I just turned the docstring of |
Description
Requires DLR-RM/stable-baselines3#1188 to be merged
Context
Types of changes
Checklist:
make format
(required)make check-codestyle
andmake lint
(required)make pytest
andmake type
both pass. (required)Note: we are using a maximum length of 127 characters per line