NN設計8(検証)
学習が完了したパラメータがどの程度の推定精度を持つかを検証するために、訓練に利用しなかった検証用のデータ(X_test, y_test)に対する推定精度を求めている。
code:(追加).py
model.eval() # ネットを推論モードに設定 # (8-1)
y_predict = model(X_test) # (8-2)
y_predict_max_index = y_predict.max(dim=1)1 # (8-3)
y_predict_compare = (y_predict_max_index == y_test) # (8-4)
y_predict_accuracy = y_predict_compare.sum() / len(y_predict_compare) # (8-5)
print('精度: ', y_predict_accuracy.item()) # (8-6)
本授業の構成では細かいところまで理解するための時間が足りないので、この頁は参考程度に留めて貰って構わない。大まかにいうと以下のような処理を行っている。
(8-3)maxメソッドによる最大要素のインデックス(添字)取得
(8-4)配列の論理演算により、推定と正解の一致状況を抽出
(8-5) bool値の扱いについてにより、Trueの数をカウントし、データの総数で割る
(8-6)itemメソッド【torch】により、精度の値をビルトイン型として取り出す
/icons/hr.icon
※ ブラウザのバックボタンで戻る