[Question] Can a model be used in environments with different observation_space sizes? #2031

Open

SummerDiver opened this issue Nov 1, 2024 · 4 comments
Labels: custom gym env (Issue related to Custom Gym Env), question (Further information is requested)

Comments

@SummerDiver

❓ Question

I am trying to use stable-baselines3 to handle a graph-related problem. Graphs of different sizes have different numbers of nodes and edges, so the environments defined on them have observation spaces of different sizes.

My goal is to train an agent in environment A and then use it in environment B. I have written a custom feature extractor based on a graph neural network, which should be able to handle graph inputs of varying sizes.

However, when I feed an observation generated by environment B into the model, an error occurs:

    env_small = GymMISEnv(data_folder_path) # with observation shape (148640, 2)
    env_big = GymMISEnv(test_data_folder_path) # with observation shape (250151, 2)
    policy_kwargs = {
        "features_extractor_class": GNNFeatureExtractor,
        "features_extractor_kwargs": {"features_dim": 64},
        "net_arch": dict(pi=[128, 64], vf=[128, 64]),
    }
    model = PPO("MultiInputPolicy", env_small, policy_kwargs=policy_kwargs, verbose=1) # env_small to train
    obs, _ = env_big.reset() # switch to env_big
    model.predict(obs)  # Error: Unexpected observation shape (250151, 2) for Box environment, please use (148640, 2) or (n_env, 148640, 2) for the observation shape.

Is there a way to handle this problem?

@araffin
Member

araffin commented Nov 1, 2024

Hello,
there is no "out of the box" solution for your problem; you will probably need to add an adapter layer so that the observations end up with the same size (and the model might also complain, since we have some checks at load time).
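A minimal sketch of such an adapter, written as a padding wrapper (the PadObservation class, the max_nodes bound, and zero-padding are illustrative assumptions, not SB3 API; the snippet also assumes the observation is a single Box of shape (n, 2), as in the error message):

    import numpy as np
    import gymnasium as gym
    from gymnasium import spaces

    class PadObservation(gym.ObservationWrapper):
        # Zero-pad (n, 2) observations up to a fixed (max_nodes, 2) shape so
        # that every wrapped environment exposes the same observation_space.
        def __init__(self, env, max_nodes):
            super().__init__(env)
            self.max_nodes = max_nodes
            self.observation_space = spaces.Box(
                low=-np.inf, high=np.inf, shape=(max_nodes, 2), dtype=np.float32
            )

        def observation(self, obs):
            padded = np.zeros((self.max_nodes, 2), dtype=np.float32)
            padded[: obs.shape[0]] = obs
            return padded

    # Both environments then report the same observation shape:
    # env_small = PadObservation(GymMISEnv(data_folder_path), max_nodes=250151)
    # env_big = PadObservation(GymMISEnv(test_data_folder_path), max_nodes=250151)

The cost is that the policy always sees max_nodes rows, so the GNN feature extractor would need to mask or otherwise ignore the padded entries.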

@SummerDiver
Author

I see, thanks a lot for replying, I'll try to see if I can handle this. Otherwise I'm afraid I'll have to implement the whole thing without sb3, unfortunately. Still, I really appreciate your excellent work!

@SummerDiver
Author

SummerDiver commented Nov 8, 2024

Hi there, it occurs to me that since the policy network actually has the same structure for both env_small and env_big, does it make sense if I create model_A using env_small and model_B using env_big, and then:

model_A = PPO("MultiInputPolicy", env_small, policy_kwargs=policy_kwargs, verbose=1) # env_small to train
model_B = PPO("MultiInputPolicy", env_big, policy_kwargs=policy_kwargs, verbose=1) # env_big to test
model_B.policy.load_state_dict(model_A.policy.state_dict()) # transfer model_A's policy settings to model_B
model_B.predict(obs_B) # then I can deal with env_big using model_B with model_A's policy network 

I've tried this and it seems to work. However, when it comes to saving and loading, I have to:

model_A = PPO.load("model_A.zip")
model_B = PPO("MultiInputPolicy", env_big, policy_kwargs=policy_kwargs, verbose=1) 
model_B.policy.load_state_dict(model_A.policy.state_dict())

It looks a little hacky, and I have to create model_A first just to get its state_dict even though I don't really need the rest of it, so I wonder if I can save and load only model_A's state_dict instead of the whole model? I've found that the old stable-baselines seems to have this feature (hill-a/stable-baselines#344); does sb3 also support this?

@araffin
Member

araffin commented Nov 8, 2024

I've tried this and it seems to work. However, when it comes to saving and loading, I have to:

Looks fine; model.policy is just a PyTorch module and can be saved/loaded on its own (https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#id3)
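For example (a sketch using plain PyTorch state_dicts; the file name is illustrative):

    import torch

    # After training, save only the policy weights:
    torch.save(model_A.policy.state_dict(), "policy_A.pth")

    # Later, build model_B for env_big and load just the weights:
    model_B = PPO("MultiInputPolicy", env_big, policy_kwargs=policy_kwargs, verbose=1)
    model_B.policy.load_state_dict(torch.load("policy_A.pth"))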

I've found that the old stable-baselines seems to have this feature (hill-a/stable-baselines#344); does sb3 also support this?

https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#accessing-and-modifying-model-parameters
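For instance, get_parameters()/set_parameters() can do the transfer without touching the state_dicts directly (a sketch, assuming model_A and model_B were built as above):

    params = model_A.get_parameters()  # dict of state_dicts, e.g. {"policy": ...}
    model_B.set_parameters(params, exact_match=True)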
