Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trainerクラスの実装 #33

Open
8 of 9 tasks
Geson-anko opened this issue Jan 22, 2023 · 1 comment
Open
8 of 9 tasks

Trainerクラスの実装 #33

Geson-anko opened this issue Jan 22, 2023 · 1 comment
Assignees
Labels
enhancement New feature or request

Comments

@Geson-anko
Copy link
Collaborator

Geson-anko commented Jan 22, 2023

タスク内容

学習用の設定で初期化され、Dreamerクラスを受け取って学習させるTrainerクラスを実装します。

提案内容

Note 2023/1/27: 途中まで作りかけているtrainerに書き足します。

  • コンストラクタ__init__では、Trainerの設定に関わる項目を受け取る。

    • 学習ループに関わる変数(episode数や経験を集めるインターバル数など)
    • 任意のモデルの学習に関わる変数(gradient_clip_valなど)
    • 演算デバイスや精度に関わる変数(device, dtype)
    • tensorboardへのロガー
    • モデルの保存に関わる項目 (checkpoint, checkpointを記録する変数など)
    • hydraのDictConfigオブジェクト
  • 学習フローの実装メソッドfit(env,model) -> log_metricsの実装

  • 学習対象のDreamerの内部処理に必要な属性(device, dtype, tensorboard logger)を付与するメソッドの実装

  • checkpointを保存するメソッド及びロードするメソッド。

達成条件

  • コンストラクタが実装された
  • fitメソッドが実装された。
  • 学習対象のモデルに属性を付与するメソッドが実装された
  • checkpointを保存するメソッドが実装された
  • checkpointを読み込むメソッドが実装された。
  • Dreamerの学習が実行できる。
  • Tensorboardでその学習結果を見ることができる
  • 学習の途中経過をコンソールに出力できる
  • 学習済のモデルを読み込んで評価関数を呼び出すメソッドが存在する。

参考

疑似コード

class Trainer:
    def __init__(self, setting1: int = 1, setting2:str = "abc", ...):
        self.__dict__.update(locals()) # これによって引数を全て属性に付与できる
        
    def save_checkpoint(model) -> None:
        """モデルを保存するメソッド"""
    def load_checkpoint(model) -> None:
        """パラメータをロードするメソッド"""
    def set_attributes_to_model(model):
        """モデルに属性を付与するメソッド"""
    
    def fit(env, replay_buffer, model) -> log_metrics:
        self.set_attributes_to_model(model)
        
        model.to(self.device, self.dtype)
        
        self.load_checkpoint(model) # checkpointがあればロード
        
        world_optimizer, controller_optimizer = model.configure_optimizer()
        
        for episode in range(self.num_episodes):
            model.collect_experiences(model, world_optimizer, controller_optimizer)
            
            for interval in range(self.collect_interval):
                experiences = replay_buffer.sample()
                
                loss_dict, experiences = model.world_training_step(experiences)
                # Update world model by returned loss
                # log loss_dict to console.
                
                loss_dict, experiences = model.controller_training_step(experiences)
                # Update controller model by returned loss
                # log loss_dict to console.
                
                if current_step % self.evaluation_interval == 0:
                    loss_dict = model.evaluation_step(env)
                    # log loss_dict to console
                    
                if current_step % self.model_saving_interval == 0:
                    self.save_checkpoint(model)
        metric_dict = model.evaluation_step(env)
        self.save_checkpoint(model)
        return metric_dict

    def evaluation(self, env, model) -> metric_dict:
        """このメソッドは評価のみを実行したいときに使用します。"""
        self.load_checkpoint(model)
        return model.evaluation_step(env)
@Geson-anko
Copy link
Collaborator Author

プログレスバーとしてtqdmを使用する

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

No branches or pull requests

1 participant