Callback
訓練に直接関係のない追加のコード。Loggerやearlystopping等。Trainerに渡すことで利用できる。 EarlyStopping(モデルが改善しなさそうなら学習を止める)
WandbLogger(ログを取る)
ModelCheckpoint(モデルを保存する)
自作Callbackを作成
LightningModuleまたはTrainerに渡すことができる
モデル構造によって挙動が変わるならLightningModuleに渡す
どんな訓練でも使えそうならTrainerに渡す
自作Callbackのベストプラクティス
コールバックは機能的に分離されていなければなりません。
コールバックは、他のコールバックの動作に依存してはいけません。
コールバックから手動でメソッドを呼び出さないでください。
メソッドを直接呼び出すこと(例:on_validation_end)は強く推奨しません。
可能な限り、コールバックの実行順序に依存しないようにしてください。
例えば例として
訓練に直接関係のない可視化のためのログ集め+可視化
訓練のある時点で電子メールを送る
モデルを増大する
学習率を更新する
勾配を可視化する
例えば学習率減衰をするものは以下のようになる。
Callbackを自作する場合はpytorch_lightning.callbacks.Callbackを継承して、学習ループ内のhookに該当する抽象メソッドをオーバーライドする。色々できるので本当に一例。
code:python
class DecayLearningRate(pl.callbacks.Callback):
def __init__(self):
self.old_lrs = []
def on_train_start(self, trainer, pl_module):
# track the initial learning rates
for opt_idx, optimizer in enumerate(trainer.optimizers):
group = [param_group'lr' for param_group in optimizer.param_groups] self.old_lrs.append(group)
def on_train_epoch_end(self, trainer, pl_module, outputs):
for opt_idx, optimizer in enumerate(trainer.optimizers):
new_lr_group = []
for p_idx, param_group in enumerate(optimizer.param_groups):
old_lr = old_lr_groupp_idx new_lr = old_lr * 0.98
new_lr_group.append(new_lr)