損失関数(バッチ学習)
ニューラルネットの学習においては、訓練に利用するデータを一件ずつ入力するのではなく、データの集合に対する損失関数を求めることが一般的である。全てのデータをまとめて一度に入力することで学習を行うことをバッチ学習という。
code:loss2.py
import torch as pt
import torch.nn as nn
loss_func = nn.MSELoss()
data1 = pt.tensor(1.0, 2.0], 2.0, 1.0, [3.1, 1.0) # (3) data2 = pt.tensor(1.0, 2.0], 2.0, 1.0, [3.0, 0.0) # (3') loss = loss_func(data1, data2)
print(loss)
print(type(loss))
print(loss.shape)
print(loss.ndim)
code:(結果).py
tensor(0.1683)
<class 'torch.Tensor'>
torch.Size([])
0
バッチ学習に用いられるデータを (3), (3') のように2次元のテンソルで与えている。
このように多次元のデータが入力された場合でも、損失関数の値はスカラ(0次元)のテンソルとして得られる。
/icons/hr.icon
※ ブラウザのバックボタンで戻る