LightningDataModule
Dataloaderの代わりに設定するクラス
参考サイト
DataLoader とデータの前処理コードは散財して終わる傾向があります。
データコードを LightningDataModule に体系化することにより再利用可能にします。
生のDataloaderをTrainerに渡せば学習は一応可能だが、できればDataModuleを使いたい。
データと前処理を密結合にするため
CPUとGPU間のワークロードの自動分散など追加の最適化を行うため
DataLoaderの3種類の利用方法
1.DataloaderをTrainer.fit()に渡す
一番単純、使っているDataloaderをそのまま利用できる。ただ前述した通りおすすめしない。
2.Lightning Module内に書いちゃう
後述する3の方が疎結合になって見通しが良くなるためあまりおすすめしない。
これはLightningModule内にtrain_dataloader, val_dataloader, test_dataloaderなどのメソッドを含めてしまう方法
モデルごとにDataloaderが異なる可能性が高い場合に採用する
fit()でfit(datamodule=data_module)を指定する必要がなくなる。
代わりにmodel等と密結合なので再利用しにくくなる。
疎結合を目指したいから個人的にはうーん。
完全に場合によりけりではある。だから公式は複数の方法を用意してくれている。
3.DataModuleを用いる
公式推奨の方法、それぞれの手順をグループ分けすることで分かりやすくする
ダウンロード手順
(前) 処理手順
分割手順
訓練dataloader
検証dataloader
テストdataloader
今までImageTransformクラス等を作っていたものを1つのクラスにしていてとても好感が持てる
オブジェクトを作成したら、それをfit()に渡せばよい
データセットをダウンロードして処理するために prepare_data() を使用する。
これは別の場所に用意されているデータ郡をダウンロードできる場所
この関数ではまだ状態の割当を行わないこと!
例えばMNISTのデータセットをインストールする際はここで行う。
これを用意して、次のsetup()でそれぞれのGPUにオンメモリで用意してあげる流れになる。
つまり、self.data_dir = data_dirみたいなことしかできない。
分割を行ない、モデル内部をビルドするために setup() を使用する。
DataModuleはDatasetとDataLoaderの呼び出しをする。それによってデータにまつわるすべてのコードを集約することができる。
code:method3.py
class MyDataModule(LightningDataModule):
def __init__(self):
super().__init__()
self.train_dims = None
self.vocab_size = 0
def prepare_data(self):
# called only on 1 GPU
# 1つのGPUのみで呼び出される
# データセットのダウンロードが主な役割
# 一度だけ行なわれる必要のあるデータ処理(DLやトークン化)
download_dataset()
tokenize()
build_vocab()
def setup(self):
# called on every GPU
# 全てのGPU上で自動的に呼び出される
# 分割を行ない、モデル内部で呼び出される形に成形する
vocab = load_vocab()
self.vocab_size = len(vocab)
self.train, self.val, self.test = load_datasets()
self.train_dims = self.train.next_batch.size()
def train_dataloader(self):
transforms = ...
return DataLoader(self.train, batch_size=64)
def val_dataloader(self):
transforms = ...
return DataLoader(self.val, batch_size=64)
def test_dataloader(self):
transforms = ...
return DataLoader(self.test, batch_size=64)