jaxtyping
jax だけでなく numpy 等にもつけられる
jaxtyping
Array annotations - jaxtyping
numpyやPyTorchの配列にdtypeとshapeをアノテーションするjaxtypingのススメ - Speaker Deck
pytorchのdtype, shapeを型安全にするjaxtypingのすすめ
こういう感じで使う
code:check.py
def matmul(
a: Floatndarray, "m n",
b: Floatndarray, "n k", # n は a の2次元目と同じサイズ
) -> Floatndarray, "m k":
...
書き方
Array annotations - jaxtyping
dtype[array, shape], such as Float[Array, "batch channels"]
dtype: jaxtyping の型で指定する値の型
素朴には from jaxtyping import Float, Int, Bool 程度
https://docs.kidger.site/jaxtyping/api/array/#dtype
Float32, Int64 や Complex など、ツリーがある
array: 具体的な実装クラス
https://docs.kidger.site/jaxtyping/api/array/#array
numpy.ndarray
torch.Tensor
jax.Array
など
shape: 文字列で次元を表現
int は固定サイズ, str は可変サイズの軸に名前をつける
code:example.py
"batch seq" # 名前付き次元 (同名は同サイズを保証)
"3 224 224" # 固定サイズ (整数リテラル)
"batch 3 224" # 名前付きと固定サイズの混在
"vocab" # 1次元
"" # スカラー (0次元)
"..." # 任意の shape (dtype のみチェック)
"*batch dim" # *name: 0個以上の次元にマッチ (可変長)
"... height width" # ...: *_ と同等、匿名の可変長次元
"batch _" # _: その次元のチェックを無効化
"#batch height" # #name: ブロードキャスト可能 (サイズ1も許容)
"dim-1" # 数式: 他の次元を参照した計算
"rows=3 cols=4" # name=value: ドキュメント用
チェック
書いただけではチェックされない
mypy は shape をチェックしない(dtype のみ)
pytest 実行時にチェックを有効にする
beartype に馴染みがないからこれで済ませたい
code:pyproject.toml
tool.pytest.ini_options
addopts = "--jaxtyping-packages=mypackage,beartype.beartype"
個別のチェック
実行時に TypeCheckError が投げられる
code:runtime_check.py
from jaxtyping import jaxtyped
from beartype import beartype
@jaxtyped(typechecker=beartype)
def func(x: Floatndarray, "batch dim") -> Floatndarray, "batch":
...
import hook
code:import_hook.py
from jaxtyping import install_import_hook
# 以降の mypackage の import で自動的にチェックが有効に
install_import_hook("mypackage", "beartype.beartype")
ruff (flake8) のチェックとの兼ね合い
F722 - Syntax error in forward annotation
型アノテーションに文字列式を使う際に、その文字列が Python の式か検証される
code:pyproject.toml
tool.ruff.lint
ignore = "F722"
F821 undefined name
https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error
Float[torch.Tensor, " vocab"] のようにスペース入れて回避する
#Python #ML