jaxtyping
こういう感じで使う
code:check.py
def matmul(
...
書き方
dtype[array, shape], such as Float[Array, "batch channels"]
dtype: jaxtyping の型で指定する値の型
素朴には from jaxtyping import Float, Int, Bool 程度
Float32, Int64 や Complex など、ツリーがある
array: 具体的な実装クラス
numpy.ndarray
torch.Tensor
jax.Array
など
shape: 文字列で次元を表現
int は固定サイズ, str は可変サイズの軸に名前をつける
code:example.py
"batch seq" # 名前付き次元 (同名は同サイズを保証)
"3 224 224" # 固定サイズ (整数リテラル)
"batch 3 224" # 名前付きと固定サイズの混在
"vocab" # 1次元
"" # スカラー (0次元)
"..." # 任意の shape (dtype のみチェック)
"*batch dim" # *name: 0個以上の次元にマッチ (可変長)
"... height width" # ...: *_ と同等、匿名の可変長次元
"batch _" # _: その次元のチェックを無効化
"#batch height" # #name: ブロードキャスト可能 (サイズ1も許容) "dim-1" # 数式: 他の次元を参照した計算
"rows=3 cols=4" # name=value: ドキュメント用
チェック
書いただけではチェックされない
mypy は shape をチェックしない(dtype のみ)
pytest 実行時にチェックを有効にする
beartype に馴染みがないからこれで済ませたい
code:pyproject.toml
addopts = "--jaxtyping-packages=mypackage,beartype.beartype"
個別のチェック
実行時に TypeCheckError が投げられる
code:runtime_check.py
from jaxtyping import jaxtyped
from beartype import beartype
@jaxtyped(typechecker=beartype)
...
import hook
code:import_hook.py
from jaxtyping import install_import_hook
# 以降の mypackage の import で自動的にチェックが有効に
install_import_hook("mypackage", "beartype.beartype")
ruff (flake8) のチェックとの兼ね合い
型アノテーションに文字列式を使う際に、その文字列が Python の式か検証される
code:pyproject.toml
Float[torch.Tensor, " vocab"] のようにスペース入れて回避する