TensorDatasetクラスによるデータセットの作成
ここではTensorDatasetクラス(略してデータセットクラス)を利用したデータセットの生成方法を学ぶ。
code:batch1.py
import torch as pt
from torch.utils.data import TensorDataset, DataLoader
# データの用意
X = 0, 1], 2, 3, 4, 5, 6, 7, [8, 9
y = 0], 1, 2, 3, [4
X_train = pt.tensor(X, dtype=pt.float)
y_train = pt.tensor(y)
dataset_train = TensorDataset(X_train, y_train) # (1)
for i, (X, y) in enumerate(dataset_train):
print(i, ') X:', X, 'y:', y)
関連:enumerate関数【builtin】、TensorDataset関数【torch】
TensorDataset型インスタンスはジェネレータであり、FOR文の制御に用いることで訓練データを順に生成する。
code:結果.py
0 ) X: tensor(0., 1.) y: tensor(0)
1 ) X: tensor(2., 3.) y: tensor(1)
2 ) X: tensor(4., 5.) y: tensor(2)
3 ) X: tensor(6., 7.) y: tensor(3)
4 ) X: tensor(8., 9.) y: tensor(4)
ループの度にデータセットから学習データXと教師データyの組を取り出すことができるので、これを利用するとニューラルネットの学習を行うことができる。ただし、
この場合、入力したテンソルの通りにデータが出力される。よりランダムに取り出すことができれば、探索が局所解に陥ることを抑制できる。
データは1件づつ得られるため処理速度の面で不利である。
/icons/hr.icon
※ ブラウザのバックボタンで戻る