meshgrid【torch】
2025.7.16 3Dグラフ【matplotlib】を使ってプロットするデータを生成することに用いられる。使い方はほぼ meshgrid【numpy】と同じであるので、ここでは相違点に着目した説明を行う。
与えるベクトルの型はtensor型に指定されている。
今後、引数 indexing が必須となるようだ。
code:p1.py
import torch as pt
def func(x, y, z):
return x**2 + y*z
x = pt.tensor(0, 2, 4, 6)
y = pt.tensor(1, 3, 5)
XX, YY = pt.meshgrid(x, y)
print('# XX:\n', XX)
print('# YY:\n', YY)
'''
tensor([0, 2, 4, 6,
0, 2, 4, 6,
0, 2, 4, 6])
# YY:
tensor([1, 1, 1, 1,
3, 3, 3, 3,
5, 5, 5, 5])
'''
引数 indexing は、生成されるテンソルの次元サイズを指定する。
ij の場合、動作はnumpy版と同じ。
xyの場合、与えた引数ベクトルサイズの順になる。
code:p.py
import torch as pt
def func(x, y, z):
return x**2 + y*z
x = pt.tensor(1)
y = pt.tensor(2, 3)
z = pt.tensor(4, 5, 6)
XX1, YY1, ZZ1 = pt.meshgrid(x, y, z, indexing='ij')
f1 = func(XX1, YY1, ZZ1)
print(f1.shape)
XX2, YY2, ZZ2 = pt.meshgrid(x, y, z, indexing='xy')
f2 = func(XX2, YY2, ZZ2)
print(f2.shape)
'''
torch.Size(2, 1, 3) <--- 'ij'だと y, z, x の順、numpy版と同じ
torch.Size(1, 2, 3) <--- 'xy'だと x, y, z の順
'''