TensorDatasetクラスによるデータセットの作成
ここではTensorDatasetクラス(略してデータセットクラス)を利用したデータセットの生成方法を学ぶ。
code:batch1.py
import torch as pt
from torch.utils.data import TensorDataset, DataLoader
# データの用意
X_train = pt.tensor(X, dtype=pt.float)
y_train = pt.tensor(y)
dataset_train = TensorDataset(X_train, y_train) # (1)
for i, (X, y) in enumerate(dataset_train):
print(i, ') X:', X, 'y:', y)
TensorDataset型インスタンスはジェネレータであり、FOR文の制御に用いることで訓練データを順に生成する。 code:結果.py
ループの度にデータセットから学習データXと教師データyの組を取り出すことができるので、これを利用するとニューラルネットの学習を行うことができる。ただし、
この場合、入力したテンソルの通りにデータが出力される。よりランダムに取り出すことができれば、探索が局所解に陥ることを抑制できる。
データは1件づつ得られるため処理速度の面で不利である。
/icons/hr.icon
※ ブラウザのバックボタンで戻る