Rastrigin 関数
code:p.py
import torch as pt
import matplotlib.pyplot as plt
def rastrigin_function(n, x, y):
    """Evaluate the Rastrigin function from two coordinate tensors.

    Computes ``A * n + (x**2 - A*cos(2*pi*x)) + (y**2 - A*cos(2*pi*y))``
    element-wise, with ``A = 10``.

    Args:
        n: Dimension count used for the constant term ``A * n``. Exactly two
            coordinate terms are added here, so pass ``n = 2`` to get the
            standard two-dimensional Rastrigin value.
        x: Tensor of first coordinates.
        y: Tensor of second coordinates, broadcastable against ``x``.

    Returns:
        Tensor of function values with the broadcast shape of ``x`` and ``y``.
    """
    A = 10
    # Same left-to-right accumulation order as the constant + x-term + y-term sum.
    x_term = x**2 - A * pt.cos(2 * pt.pi * x)
    y_term = y**2 - A * pt.cos(2 * pt.pi * y)
    return A * n + x_term + y_term
# Sample the two-dimensional Rastrigin function on a 100 x 100 grid over
# [-5.12, 5.12] x [-5.12, 5.12] and draw it as a 3-D surface.
x_range = -5.12, 5.12
y_range = -5.12, 5.12
# Each range is a (min, max) tuple; index it to get the linspace end points.
xx = pt.linspace(x_range[0], x_range[1], 100)
yy = pt.linspace(y_range[0], y_range[1], 100)
xxx, yyy = pt.meshgrid(xx, yy, indexing='xy')
n = 2  # 2-dim
zzz = rastrigin_function(n, xxx, yyy)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.plot_surface(xxx, yyy, zzz, cmap='jet')
plt.show()
https://scrapbox.io/files/6889a50e66aa52fb92eb01ae.png
高次元の場合でも対応可能にした。n = 1, 2 の場合のみ可視化を行う。
code:p.py
import torch as pt
import matplotlib.pyplot as plt
def rastrigin_function(x):
    """Evaluate the n-dimensional Rastrigin function.

    Computes ``A * n + sum_i (x[i]**2 - A * cos(2*pi*x[i]))`` element-wise,
    with ``A = 10`` and ``n = len(x)``.

    Args:
        x: Sequence of coordinate tensors, one per dimension (for example the
            tuple returned by ``torch.meshgrid``). All entries must be
            broadcastable against each other.

    Returns:
        Tensor of function values with the broadcast shape of the entries of
        ``x``. For an empty sequence the plain integer ``0`` is returned.
    """
    A = 10
    n = len(x)
    ret = A * n
    for xi in x:
        # Out-of-place add: never mutates a tensor that may be shared.
        ret = ret + (xi**2 - A * pt.cos(2 * pt.pi * xi))
    return ret
# Evaluate the Rastrigin function on an n-dimensional grid and visualise it
# when n is 1 (line plot) or 2 (3-D surface).
n = 2  # number of dimensions
x_range = -5.12, 5.12
x_steps = 100
# One identical 1-D sample axis per dimension; x_range is a (min, max) tuple.
xx = [pt.linspace(x_range[0], x_range[1], x_steps) for _ in range(n)]
# meshgrid returns a tuple of n coordinate tensors, one per dimension.
xxx = pt.meshgrid(xx, indexing='xy')
zzz = rastrigin_function(xxx)
if n == 1:
    plt.plot(xxx[0], zzz)
elif n == 2:
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.plot_surface(xxx[0], xxx[1], zzz, cmap='jet')
# For n >= 3 nothing has been drawn above, so there is no figure to display.
plt.show()