We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
学習用の設定で初期化され、Dreamerクラスを受け取って学習させるTrainerクラスを実装します。
Note 2023/1/27: 途中まで作りかけているtrainerに書き足します。
コンストラクタ__init__では、Trainerの設定に関わる項目を受け取る。
__init__
学習フローの実装メソッドfit(env,model) -> log_metricsの実装
fit(env,model) -> log_metrics
学習対象のDreamerの内部処理に必要な属性(device, dtype, tensorboard logger)を付与するメソッドの実装
Dreamer
checkpointを保存するメソッド及びロードするメソッド。
class Trainer: def __init__(self, setting1: int = 1, setting2:str = "abc", ...): self.__dict__.update(locals()) # これによって引数を全て属性に付与できる def save_checkpoint(model) -> None: """モデルを保存するメソッド""" def load_checkpoint(model) -> None: """パラメータをロードするメソッド""" def set_attributes_to_model(model): """モデルに属性を付与するメソッド""" def fit(env, replay_buffer, model) -> log_metrics: self.set_attributes_to_model(model) model.to(self.device, self.dtype) self.load_checkpoint(model) # checkpointがあればロード world_optimizer, controller_optimizer = model.configure_optimizer() for episode in range(self.num_episodes): model.collect_experiences(model, world_optimizer, controller_optimizer) for interval in range(self.collect_interval): experiences = replay_buffer.sample() loss_dict, experiences = model.world_training_step(experiences) # Update world model by returned loss # log loss_dict to console. loss_dict, experiences = model.controller_training_step(experiences) # Update controller model by returned loss # log loss_dict to console. if current_step % self.evaluation_interval == 0: loss_dict = model.evaluation_step(env) # log loss_dict to console if current_step % self.model_saving_interval == 0: self.save_checkpoint(model) metric_dict = model.evaluation_step(env) self.save_checkpoint(model) return metric_dict def evaluation(self, env, model) -> metric_dict: """このメソッドは評価のみを実行したいときに使用します。""" self.load_checkpoint(model) return model.evaluation_step(env)
The text was updated successfully, but these errors were encountered:
プログレスバーとしてtqdmを使用する
tqdm
Sorry, something went wrong.
Geson-anko
No branches or pull requests
タスク内容
学習用の設定で初期化され、Dreamerクラスを受け取って学習させるTrainerクラスを実装します。
提案内容
Note 2023/1/27: 途中まで作りかけているtrainerに書き足します。
コンストラクタ
__init__
では、Trainerの設定に関わる項目を受け取る。学習フローの実装メソッド
fit(env,model) -> log_metrics
の実装学習対象の
Dreamer
の内部処理に必要な属性(device, dtype, tensorboard logger)を付与するメソッドの実装checkpointを保存するメソッド及びロードするメソッド。
達成条件
参考
疑似コード
The text was updated successfully, but these errors were encountered: