PyTorch
PyTorchの名前の由来
colabで使う
!pip install -U torch torchvision
特徴
高速に動作
独自テンソル処理ライブラリ
NumPyより高速
Resize
CenterCrop
画像の切り抜き
Normalize
色情報の標準化
Tensor
import torch
x = torch.zeros(5, 4)のようにtorchが提供するもので同様のことが行える
お互いに変換したり、演算したり
.numpy()でTensorをnumpy型に変換
vectorやmatrixのことは全てtensorと呼ぶ
x = torch.tensor([5.5, 3])
DataSet
データとそのラベルをペアにして保持したクラス
作られたインスタンスは前処理が終わったデータとラベルのペアになる
コンストラクタの引数に前処理クラスを渡すことで、自動的に前処理を行える
学習データと訓練データのそれぞれ用のインスタンスを作る
もちろん同じクラスから。
import torch.utils.data as dataのdataを継承してDataSetクラスを定義する
DataLoader
Datasetからどのようにデータを取り出すのかを設定するクラス
tourch.utils.data.DataLoaderを使えばいい
訓練用と、検証用のDataLoaderを作成する
モデルを作る時の流れ
DataSetクラスを作成
事前に前処理クラスを作成する
DataLoaderクラスの作成
モデルを定義
順電波関数の定義
損失関数の定義
最適化手法の設定
学習
テストデータで推論
torch.Tensorクラスでテンソルを扱う
.requires_grad=True
.backward()で勾配を取得
.gradで勾配データを保持
.detach()で勾配をトラックしなくなる
.requires_grad_で.requires_grad=Trueフラグを反転
Functionクラス
.grad_fn
v1.3
学習教材
tfからの移植