Julia で『ゼロから作る Deep Learning』
1 章
2 章
サクッと省略
3 章
numpy のキモさに耐えれば 3.5 までは実装できるはず 3.6 が問題
Julia 用の MNIST データセットを用意する必要がある code:prepare.jl
using MLDatasets.MNIST
function prepare_testdata(;batch=1)
data, labels = MNIST.testdata()
batch_num = div(length(labels), batch)
data0 = eachslice(reshape(data, (28 * 28 * batch, batch_num)), dims=2)
labels0 = eachslice(transpose(reshape(labels, (batch, batch_num))), dims=1)
collect(zip(data0, labels0))
end
バッチ処理を見越した実装(本当は後で戻ってきたのだが)
3.6.2 で使うサンプルネットワークが pkl ファイルとなっており Julia から読めない
自分は一旦飛ばして 4 章へ行った
pkl ファイルを読み込んでの方法は試したら書く
試したが、numpy のオブジェクトが直接入っておりそのままでは読み出せない模様
非常に癪だが PyCall を使うことにする。おのれ……
code:make_predict.jl
using PyCall
@pyimport pickle
function make_predict(fname="./sample_weight.pkl")
f = pybuiltin("open")(fname, "rb")
network = pickle.load(f)
f.close()
function (x::AbstractArray)
a₁ = x * W₁ + transpose(b₁)
z₁ = h(a₁)
a₂ = z₁ * W₂ + transpose(b₂)
z₂ = h(a₂)
a₃ = z₂ * W₃ + transpose(b₃)
σ(a₃)
end
end
numpy の行と列が逆なためしっちゃかめっちゃかである
ベンチ関数を書いて精度 93.52% が無事出力された
code:bench.jl
function bench01()
predict = make_predict()
data = prepare_testdata()
count = 0
for (x, l) in data
if argmax(predict(x))2 - 1 == l count += 1
end
end
count / length(data)
end
バッチバチにバッチ処理をやっていく
code:batch.jl
function bench01_batch(;batch=100)
predict = make_predict()
data = prepare_testdata(batch=batch)
count = 0
for (xs, ls) in data
for (p, l) in zip(eachslice(predict(transpose(reshape(xs, (28 * 28, batch)))), dims=1), ls)
if argmax(p) - 1 == l
count += 1
end
end
end
count / (length(data) * batch)
end
4 章
前半は定義どおり素直に書いていけば動く
code:grad.jl
function numerical_gradient(
f::Function,
xs::AbstractArray{T, N};
h=1e-4::T,
)::AbstractArray{T, N} where{T <: Number, N}
grad = zeros(size(xs))
for (i, x) in enumerate(xs)
fxh1 = f(xs)
fxh2 = f(xs)
gradi = (fxh1 - fxh2) / (2h) end
grad
end
function gradient_descent(
f::Function,
init_x::AbstractArray{T, N};
learning_rate=0.01::T,
steps=100::UInt128
)::AbstractArray{T, N} where{T <: Number, N}
x = init_x
for _ = 1:steps
x -= learning_rate * numerical_gradient(f, x)
end
x
end
続く