Implement the Dreamer class #47

Merged
merged 31 commits into from
Jan 29, 2023
Changes from 28 commits
Commits
8c54b1d
Merge branch 'feature/#36array-voc-state-env-wrapper' into feature/#3…
Geson-anko Jan 24, 2023
5764a83
ADD registering names for replay buffer
Geson-anko Jan 24, 2023
2faf49f
ADD Dreamer (wip)
Geson-anko Jan 24, 2023
cae808f
ADD trainer (wip)
Geson-anko Jan 24, 2023
63fd8c9
ADD eval and train
Geson-anko Jan 24, 2023
9894368
Built everything for now, without verifying that it works
Geson-anko Jan 24, 2023
efee645
ADD test_dreamer.py
cehl-kurage Jan 26, 2023
098a6b8
merge main
cehl-kurage Jan 26, 2023
8ff4ed4
ADD dummy layer to dummy classes
cehl-kurage Jan 26, 2023
9de156c
ADD test__init__, test_configure_optimizers, test_collect_experiences
cehl-kurage Jan 27, 2023
82e1e7a
ADD import
cehl-kurage Jan 27, 2023
6391085
Fix as_tensor
cehl-kurage Jan 27, 2023
c8c25cf
ADD test_world_training_step
cehl-kurage Jan 27, 2023
1dd2261
Removed print() for debug
cehl-kurage Jan 27, 2023
961b1e0
Fix as_tensor arguments
cehl-kurage Jan 27, 2023
e2aef99
Fix test_world_training_step
cehl-kurage Jan 27, 2023
bb49717
Fix loss computation
cehl-kurage Jan 27, 2023
9cdca39
Fix the order of wrapping env
cehl-kurage Jan 27, 2023
1440aef
Fix optimizer and initializing instances
cehl-kurage Jan 28, 2023
85b596a
Fix device, dtype
cehl-kurage Jan 28, 2023
b338372
ADD test_controller_training_step
cehl-kurage Jan 28, 2023
35100dd
Fix device, dtype
cehl-kurage Jan 28, 2023
b6814b4
ADD configure_replay_buffer
cehl-kurage Jan 28, 2023
a81d124
ADD test_evaluation_step
cehl-kurage Jan 28, 2023
c85a75c
ADD test_configure_replay_buffer
cehl-kurage Jan 28, 2023
1c97d61
Fix configure_replay_buffer
cehl-kurage Jan 28, 2023
6af48f3
pre-commit
cehl-kurage Jan 28, 2023
83f26f4
ADD docstring for configure_replay_buffer, __init__ and removed some …
cehl-kurage Jan 28, 2023
c4a32f9
ADD docstring for configure_replay_buffer, __init__ and removed some …
cehl-kurage Jan 28, 2023
722cef9
Fix attributes of Dreamer and type annotation in World
cehl-kurage Jan 29, 2023
050a78e
marge remote
cehl-kurage Jan 29, 2023
5 changes: 5 additions & 0 deletions src/datamodules/buffer_names.py
@@ -0,0 +1,5 @@
+ACTION = "action"
+VOC_STATE = "voc_state"
+GENERATED_SOUND = "generated_sound"
+TARGET_SOUND = "target_sound"
+DONE = "done"
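A minimal sketch of how these names might serve as shared keys for a replay buffer, assuming a plain dict-of-lists store. `BUFFER_KEYS`, `make_buffer`, and `add_step` are hypothetical helpers for illustration only; the PR's actual buffer registration is not shown in this hunk.

```python
# Constants copied from src/datamodules/buffer_names.py in this PR.
ACTION = "action"
VOC_STATE = "voc_state"
GENERATED_SOUND = "generated_sound"
TARGET_SOUND = "target_sound"
DONE = "done"

# Hypothetical: the full set of names a step is expected to provide.
BUFFER_KEYS = (ACTION, VOC_STATE, GENERATED_SOUND, TARGET_SOUND, DONE)


def make_buffer() -> dict[str, list]:
    """Hypothetical helper: create one empty list per registered name."""
    return {key: [] for key in BUFFER_KEYS}


def add_step(buffer: dict[str, list], **step) -> None:
    """Hypothetical helper: append one step, rejecting unknown or missing names."""
    if set(step) != set(BUFFER_KEYS):
        raise KeyError(f"expected keys {sorted(BUFFER_KEYS)}, got {sorted(step)}")
    for key, value in step.items():
        buffer[key].append(value)


buffer = make_buffer()
add_step(
    buffer,
    **{
        ACTION: [0.1],
        VOC_STATE: [0.0],
        GENERATED_SOUND: [0.2],
        TARGET_SOUND: [0.3],
        DONE: False,
    },
)
```

Referring to the constants instead of repeating string literals means a misspelled name fails at import time as a `NameError` rather than silently creating a new key.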
16 changes: 15 additions & 1 deletion src/models/abc/world.py
@@ -45,7 +45,7 @@ def forward(
         next_obs: _tensor_or_any,
         *args: Any,
         **kwds: Any,
-    ) -> tuple[_dist_or_any, _dist_or_any, _tensor_or_any]:
+    ) -> tuple[_dist_or_any, _dist_or_any, _tensor_or_any, _tensor_or_any]:
Collaborator Author

What led to this change?

Suggested change
-    ) -> tuple[_dist_or_any, _dist_or_any, _tensor_or_any, _tensor_or_any]:
+    ) -> tuple[_dist_or_any, _dist_or_any, _tensor_or_any]:

Collaborator

I don't remember making this change... (??)
I'll revert it.

         """Make world model transition.

         Args:
@@ -65,3 +65,17 @@ def forward(
         next_state_posterior = self.obs_encoder.forward(next_hidden, next_obs)

         return next_state_prior, next_state_posterior, next_hidden
+
+    def eval(self):
+        """Set models to evaluation mode."""
+        self.transition.eval()
+        self.prior.eval()
+        self.obs_encoder.eval()
+        self.obs_decoder.eval()
+
+    def train(self):
+        """Set models to training mode."""
+        self.transition.train()
+        self.prior.train()
+        self.obs_encoder.train()
+        self.obs_decoder.train()
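The `eval`/`train` pair added above delegates to the four sub-models. As a point of comparison, `torch.nn.Module.train` takes a `mode: bool = True` argument and returns `self`, and `eval()` is equivalent to `train(False)`. Below is a minimal sketch of a delegating container that mirrors that convention. It is torch-free so it runs anywhere; `DummyPart` and `WorldSketch` are hypothetical stand-ins, not classes from this PR, and whether `World` itself subclasses `nn.Module` is not visible in this hunk.

```python
class DummyPart:
    """Hypothetical stand-in for a sub-model; it only tracks a training flag."""

    def __init__(self) -> None:
        self.training = True

    def train(self, mode: bool = True) -> "DummyPart":
        self.training = mode
        return self

    def eval(self) -> "DummyPart":
        return self.train(False)


class WorldSketch:
    """Hypothetical container holding the four parts named in the diff."""

    def __init__(self) -> None:
        self.transition = DummyPart()
        self.prior = DummyPart()
        self.obs_encoder = DummyPart()
        self.obs_decoder = DummyPart()

    def _parts(self) -> tuple[DummyPart, ...]:
        return (self.transition, self.prior, self.obs_encoder, self.obs_decoder)

    def train(self, mode: bool = True) -> "WorldSketch":
        # Forward the requested mode to every part, then return self for chaining.
        for part in self._parts():
            part.train(mode)
        return self

    def eval(self) -> "WorldSketch":
        # Same convention as torch.nn.Module: eval() is train(False).
        return self.train(False)


world = WorldSketch().eval()
```

Keeping the `mode` parameter and the `self` return value lets callers write `world.train(False)` or chain `WorldSketch().eval()` the same way they would with a regular module.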