2025.7.22 Himmelblau関数のプロット
勾配をベクトル場として描画
code:p.py
import torch as pt
from torch.autograd.functional import jacobian
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import MinMaxScaler
def himmelblau(x):  # takes a single argument
    """Evaluate the Himmelblau function.

    f(x0, x1) = (x0**2 + x1 - 11)**2 + (x0 + x1**2 - 7)**2

    Args:
        x: a pair (x0, x1). Either a tuple of scalars/tensors (e.g. the two
           meshgrid arrays, evaluated element-wise) or a 1-D tensor of length 2
           (the form passed in by torch.autograd.functional.jacobian/hessian).

    Returns:
        f evaluated at x, with the broadcast shape of x[0] and x[1].
    """
    x0, x1 = x[0], x[1]
    return (x0**2 + x1 - 11)**2 + (x0 + x1**2 - 7)**2


def plot_himmelblau(lim=5, num=100, levels=30):
    """Draw contour lines of the Himmelblau function plus a colorbar.

    Args:
        lim: plot the square [-lim, lim] x [-lim, lim] (default 5).
        num: number of grid points per axis (default 100).
        levels: number of contour levels passed to plt.contour (default 30).
    """
    x = pt.linspace(-lim, lim, num)
    y = pt.linspace(-lim, lim, num)
    xx, yy = pt.meshgrid(x, y, indexing='xy')
    zz = himmelblau((xx, yy))  # note: the two grids are passed as ONE tuple argument
    plt.contour(xx, yy, zz, levels=levels)
    # alternative rendering: plt.pcolormesh(xx, yy, zz)
    plt.colorbar()
# Sample the gradient of the Himmelblau function on an mx-by-my grid and draw it
# as a vector field on top of the contour plot, colouring each arrow by |grad f|.
mx, my = 100, 100
xx = pt.linspace(-5, 5, mx)
yy = pt.linspace(-5, 5, my)
x_hist, y_hist = [], []
u_hist, v_hist, c_hist = [], [], []
for x in xx:
    for y in yy:
        p = pt.stack((x, y))  # x, y are already 0-d tensors
        grad = jacobian(himmelblau, p)  # shape (2,): (df/dx, df/dy)
        u, v = grad
        x_hist.append(x.item())
        y_hist.append(y.item())
        u_hist.append(u.item())
        v_hist.append(v.item())
        # Gradient magnitude sqrt(u**2 + v**2); previously |u + v| was computed
        # and then x was appended here by mistake.
        c_hist.append(pt.linalg.norm(grad).item())
plot_himmelblau()
plt.quiver(x_hist, y_hist, u_hist, v_hist, c_hist, cmap='Reds')
plt.show()
鞍点の確認
code:p.py
import torch as pt
from torch.autograd.functional import jacobian, hessian
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import MinMaxScaler
def is_symmetric(x):
    """Return True if the matrix equals its transpose exactly (element-wise).

    Args:
        x: a torch tensor, or a NumPy array (converted to a tensor first).
    """
    mat = pt.tensor(x) if isinstance(x, np.ndarray) else x
    transposed = mat.T
    return pt.equal(mat, transposed)
def matrix_definiteness(x):
    """Classify the definiteness of a square matrix from its eigenvalues.

    Args:
        x: square matrix as a torch tensor or NumPy array (converted to a
           tensor). Integer input is cast to float64 first.

    Returns:
        int code:
            1 -- positive definite      (all eigenvalues > 0)
            2 -- positive semi-definite (all eigenvalues >= 0)
            3 -- negative definite      (all eigenvalues < 0)
            4 -- negative semi-definite (all eigenvalues <= 0)
            0 -- indefinite             (eigenvalues of both signs)

    If x is not exactly symmetric, a warning is printed. In that case only the
    real parts of its (possibly complex) eigenvalues are examined.
    """
    if isinstance(x, np.ndarray):
        x = pt.tensor(x)
    if not (x.is_floating_point() or x.is_complex()):
        # torch.linalg eigen routines reject integer tensors.
        x = x.to(pt.float64)
    if is_symmetric(x) and not x.is_complex():
        # For a real symmetric matrix, eigvalsh is more accurate than eigvals
        # and returns real eigenvalues directly.
        eigs = pt.linalg.eigvalsh(x)
    else:
        print('警告:行列は対称(エルミート)ではない')
        eigs = pt.linalg.eigvals(x).real
    # Order matters: strict checks first, so an all-zero spectrum yields 2.
    if pt.all(eigs > 0):
        return 1  # positive definite
    if pt.all(eigs >= 0):
        return 2  # positive semi-definite
    if pt.all(eigs < 0):
        return 3  # negative definite
    if pt.all(eigs <= 0):
        return 4  # negative semi-definite
    return 0  # indefinite
def himmelblau(x):  # takes a single argument
    """Evaluate the Himmelblau function.

    f(x0, x1) = (x0**2 + x1 - 11)**2 + (x0 + x1**2 - 7)**2

    Args:
        x: a pair (x0, x1). Either a tuple of scalars/tensors (e.g. the two
           meshgrid arrays, evaluated element-wise) or a 1-D tensor of length 2
           (the form passed in by torch.autograd.functional.jacobian/hessian).

    Returns:
        f evaluated at x, with the broadcast shape of x[0] and x[1].
    """
    x0, x1 = x[0], x[1]
    return (x0**2 + x1 - 11)**2 + (x0 + x1**2 - 7)**2


def plot_himmelblau(lim=5, num=100, levels=30):
    """Draw contour lines of the Himmelblau function plus a colorbar.

    Args:
        lim: plot the square [-lim, lim] x [-lim, lim] (default 5).
        num: number of grid points per axis (default 100).
        levels: number of contour levels passed to plt.contour (default 30).
    """
    x = pt.linspace(-lim, lim, num)
    y = pt.linspace(-lim, lim, num)
    xx, yy = pt.meshgrid(x, y, indexing='xy')
    zz = himmelblau((xx, yy))  # note: the two grids are passed as ONE tuple argument
    plt.contour(xx, yy, zz, levels=levels)
    # alternative rendering: plt.pcolormesh(xx, yy, zz)
    plt.colorbar()
# Locate the stationary points of the Himmelblau function. At every grid point,
# classify the Hessian's definiteness and record whether the gradient is small.
# Then scatter the near-stationary points with one marker per class.
mx, my = 500, 500
xx = pt.linspace(-5, 5, mx)
yy = pt.linspace(-5, 5, my)

# Marker and legend label for each code returned by matrix_definiteness.
markers = {0: 'x', 1: 'o', 2: 's', 3: '^', 4: 'D'}
labels = {
    0: 'indefinite (saddle)',
    1: 'positive definite (minimum)',
    2: 'positive semi-definite',
    3: 'negative definite (maximum)',
    4: 'negative semi-definite',
}

x_hist, y_hist, c_hist, df_hist = [], [], [], []
for x in xx:
    for y in yy:
        p = pt.stack((x, y))  # x, y are already 0-d tensors
        df = jacobian(himmelblau, p)  # gradient, computed once per point
        hess = hessian(himmelblau, p)
        chk = matrix_definiteness(hess)
        df_hist.append(bool(pt.linalg.norm(df) < 1))
        x_hist.append(x.item())
        y_hist.append(y.item())
        c_hist.append(chk)

mask_df = pt.tensor(df_hist)  # True where |grad f| < 1 (near a stationary point)
x_hist = pt.tensor(x_hist)
y_hist = pt.tensor(y_hist)
c_hist = pt.tensor(c_hist)

# Loop over the distinct classes only; looping over every point repeated the
# same mask + scatter mx*my times. Restricting to mask_df (computed but unused
# before) is presumably the intent: show only the near-stationary points.
for c in pt.unique(c_hist).tolist():
    mask = (c_hist == c) & mask_df
    plt.scatter(x_hist[mask], y_hist[mask], marker=markers[c], label=labels[c])
plot_himmelblau()
plt.legend()
plt.show()
https://scrapbox.io/files/687f580b4df7dec80101fdaf.png
極小点4箇所、極大点1箇所、鞍点4箇所を確認した。