NN設計8(検証)
学習が完了したパラメータがどの程度の推定精度を持つかを検証するために、訓練に利用しなかった検証用のデータ(X_test, y_test)に対する推定精度を求めている。
code:(追加).py
model.eval() # ネットを推論モードに設定 # (8-1)
y_predict = model(X_test) # (8-2)
y_predict_max_index = y_predict.max(dim=1)1 # (8-3) y_predict_compare = (y_predict_max_index == y_test) # (8-4)
y_predict_accuracy = y_predict_compare.sum() / len(y_predict_compare) # (8-5)
print('精度: ', y_predict_accuracy.item()) # (8-6)
本授業の構成では細かいところまで理解するための時間が足りないので、この頁は参考程度に留めて貰って構わない。大まかにいうと以下のような処理を行っている。
(8-3)maxメソッドによる最大要素のインデックス(添字)取得
/icons/hr.icon
※ ブラウザのバックボタンで戻る