Trainer
公式サイト説明
↑ 引数も多く、重要なことは大体書いてあるので読もう!!
devices はどのデバイスを使用するかを指す、など
fitメソッドの引数を埋めるだけで学習を実行できる
kerasに似てハイレベルのラッパーとなっている
code:python
from pytorch_lightning import Trainer
trainer = Trainer()
# 学習開始
trainer.fit(lightning_module, datamodule=data_module)
trainer.fit(model, train_dataloader, val_dataloader)
のようにdatamoduleではなくdataloaderを引数にとることもできる。
trainerのメソッドには大きく4種類
trainer.fit()
学習データセットと検証データセットを使ってモデルを最適化するメソッド
trainer.validate()
検証データセットに対して1epochのみ実行を行うメソッド
trainer.test()
テストデータセットに対して1epochのみ実行を行うメソッド。必要になるまでテストデータセットで実行しないためにfit()から分離されている
trainer.predict()
データにチアして推論を実行するメソッド。model.forward()が実行されて予測が計算される。
分散(学習)やバッチ予測をするのに用いる
Trainerオブジェクトを作成するときに定義するロギングは無効化される
train, val, testと同様にpredict_step()predict_epoch_end()などの実装が可能に
predict(ckpt_path="")を入れると勝手にこれを使ってくれるようになるため、trainで使っていた新規モデルを作成するといったコードはそのままに、指定されたモデルで推論を行うことができるようになっている
predict_step()などを実装していない場合は勝手にforwordメソッドを呼んでくれる
ここが指定されていない場合もCheckPointCallbackが設定されている場合は、最適なfit()後のも出るチェックポイントが読み込まれる!すごい!
trainer.tune()
トレーニングの前にハイパラを調整する
学習済モデルを利用する方法
trainerのメソッドの引数にfit(ckpt_path="")というものを定義してあげるとそれがモデルとして使われる。
これはpredict()にも当てはまる。