損失関数(2)
出力される値について、より詳細を確認しましょう。
code:mseloss1.py
import torch.nn as nn
import torch as pt
loss_func = nn.MSELoss()
data1 = pt.tensor(1, 2, 1, dtype=pt.float) data2 = pt.tensor(1, 2, 0, dtype=pt.float) loss = loss_func(data1, data2)
print(loss)
print(type(loss))
print(loss.shape)
code:実行結果.txt
tensor(0.3333)
<class 'torch.Tensor'>
torch.Size([])
結果はテンソル型、形状は0次元のスカラ値である。
itemメソッドによる値の取り出し
いっぱんに、損失関数の値はリストなどに保存しておき、後の処理でグラフ化などに利用される。データを扱いやすくするために、itemメソッド【torch】を用いてテンソルからビルトイン型の値を抽出することがよく行われる。 code:(続き).py
print(loss) # これだとテンソル型
l = loss.item() # ビルトインオブジェクトのfloat型
print(l)
print(type(l))
/icons/hr.icon
※ ブラウザのバックボタンで戻る