TransformerLens
利用可能なオープンソース モデルについては、こちらで文書化されています 公式に対応しているモデル名はloading_from_pretrained.pyのOFFICIAL_MODEL_NAMESに記載あり。
code:py
OFFICIAL_MODEL_NAMES = [
"gpt2",
"gpt2-medium",
"gpt2-large",
"gpt2-xl",
"distilgpt2",
"facebook/opt-125m",
"facebook/opt-1.3b",
"facebook/opt-2.7b",
"facebook/opt-6.7b",
"facebook/opt-13b",
"facebook/opt-30b",
"facebook/opt-66b",
"EleutherAI/gpt-neo-125M",
"EleutherAI/gpt-neo-1.3B",
"EleutherAI/gpt-neo-2.7B",
"EleutherAI/gpt-j-6B",
"EleutherAI/gpt-neox-20b",
"stanford-crfm/alias-gpt2-small-x21",
"stanford-crfm/battlestar-gpt2-small-x49",
"stanford-crfm/caprica-gpt2-small-x81",
"stanford-crfm/darkmatter-gpt2-small-x343",
"stanford-crfm/expanse-gpt2-small-x777",
"stanford-crfm/arwen-gpt2-medium-x21",
"stanford-crfm/beren-gpt2-medium-x49",
"stanford-crfm/celebrimbor-gpt2-medium-x81",
"stanford-crfm/durin-gpt2-medium-x343",
"stanford-crfm/eowyn-gpt2-medium-x777",
"EleutherAI/pythia-14m",
"EleutherAI/pythia-31m",
"EleutherAI/pythia-70m",
"EleutherAI/pythia-160m",
"EleutherAI/pythia-410m",
"EleutherAI/pythia-1b",
"EleutherAI/pythia-1.4b",
"EleutherAI/pythia-2.8b",
"EleutherAI/pythia-6.9b",
"EleutherAI/pythia-12b",
"EleutherAI/pythia-70m-deduped",
"EleutherAI/pythia-160m-deduped",
"EleutherAI/pythia-410m-deduped",
"EleutherAI/pythia-1b-deduped",
"EleutherAI/pythia-1.4b-deduped",
"EleutherAI/pythia-2.8b-deduped",
"EleutherAI/pythia-6.9b-deduped",
"EleutherAI/pythia-12b-deduped",
"EleutherAI/pythia-70m-v0",
"EleutherAI/pythia-160m-v0",
"EleutherAI/pythia-410m-v0",
"EleutherAI/pythia-1b-v0",
"EleutherAI/pythia-1.4b-v0",
"EleutherAI/pythia-2.8b-v0",
"EleutherAI/pythia-6.9b-v0",
"EleutherAI/pythia-12b-v0",
"EleutherAI/pythia-70m-deduped-v0",
"EleutherAI/pythia-160m-deduped-v0",
"EleutherAI/pythia-410m-deduped-v0",
"EleutherAI/pythia-1b-deduped-v0",
"EleutherAI/pythia-1.4b-deduped-v0",
"EleutherAI/pythia-2.8b-deduped-v0",
"EleutherAI/pythia-6.9b-deduped-v0",
"EleutherAI/pythia-12b-deduped-v0",
"EleutherAI/pythia-160m-seed1",
"EleutherAI/pythia-160m-seed2",
"EleutherAI/pythia-160m-seed3",
"NeelNanda/SoLU_1L_v9_old",
"NeelNanda/SoLU_2L_v10_old",
"NeelNanda/SoLU_4L_v11_old",
"NeelNanda/SoLU_6L_v13_old",
"NeelNanda/SoLU_8L_v21_old",
"NeelNanda/SoLU_10L_v22_old",
"NeelNanda/SoLU_12L_v23_old",
"NeelNanda/SoLU_1L512W_C4_Code",
"NeelNanda/SoLU_2L512W_C4_Code",
"NeelNanda/SoLU_3L512W_C4_Code",
"NeelNanda/SoLU_4L512W_C4_Code",
"NeelNanda/SoLU_6L768W_C4_Code",
"NeelNanda/SoLU_8L1024W_C4_Code",
"NeelNanda/SoLU_10L1280W_C4_Code",
"NeelNanda/SoLU_12L1536W_C4_Code",
"NeelNanda/GELU_1L512W_C4_Code",
"NeelNanda/GELU_2L512W_C4_Code",
"NeelNanda/GELU_3L512W_C4_Code",
"NeelNanda/GELU_4L512W_C4_Code",
"NeelNanda/Attn_Only_1L512W_C4_Code",
"NeelNanda/Attn_Only_2L512W_C4_Code",
"NeelNanda/Attn_Only_3L512W_C4_Code",
"NeelNanda/Attn_Only_4L512W_C4_Code",
"NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr",
"NeelNanda/SoLU_1L512W_Wiki_Finetune",
"NeelNanda/SoLU_4L512W_Wiki_Finetune",
"ArthurConmy/redwood_attn_2l",
"llama-7b-hf",
"llama-13b-hf",
"llama-30b-hf",
"llama-65b-hf",
"meta-llama/Llama-2-7b-hf",
"meta-llama/Llama-2-7b-chat-hf",
"meta-llama/Llama-2-13b-hf",
"meta-llama/Llama-2-13b-chat-hf",
"meta-llama/Llama-2-70b-chat-hf",
"CodeLlama-7b-hf",
"CodeLlama-7b-Python-hf",
"CodeLlama-7b-Instruct-hf",
"meta-llama/Meta-Llama-3-8B",
"meta-llama/Meta-Llama-3-8B-Instruct",
"meta-llama/Meta-Llama-3-70B",
"meta-llama/Meta-Llama-3-70B-Instruct",
"Baidicoot/Othello-GPT-Transformer-Lens",
"bert-base-cased",
"roneneldan/TinyStories-1M",
"roneneldan/TinyStories-3M",
"roneneldan/TinyStories-8M",
"roneneldan/TinyStories-28M",
"roneneldan/TinyStories-33M",
"roneneldan/TinyStories-Instruct-1M",
"roneneldan/TinyStories-Instruct-3M",
"roneneldan/TinyStories-Instruct-8M",
"roneneldan/TinyStories-Instruct-28M",
"roneneldan/TinyStories-Instruct-33M",
"roneneldan/TinyStories-1Layer-21M",
"roneneldan/TinyStories-2Layers-33M",
"roneneldan/TinyStories-Instuct-1Layer-21M",
"roneneldan/TinyStories-Instruct-2Layers-33M",
"stabilityai/stablelm-base-alpha-3b",
"stabilityai/stablelm-base-alpha-7b",
"stabilityai/stablelm-tuned-alpha-3b",
"stabilityai/stablelm-tuned-alpha-7b",
"mistralai/Mistral-7B-v0.1",
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mixtral-8x7B-v0.1",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
"bigscience/bloom-560m",
"bigscience/bloom-1b1",
"bigscience/bloom-1b7",
"bigscience/bloom-3b",
"bigscience/bloom-7b1",
"bigcode/santacoder",
"Qwen/Qwen-1_8B",
"Qwen/Qwen-7B",
"Qwen/Qwen-14B",
"Qwen/Qwen-1_8B-Chat",
"Qwen/Qwen-7B-Chat",
"Qwen/Qwen-14B-Chat",
"Qwen/Qwen1.5-0.5B",
"Qwen/Qwen1.5-0.5B-Chat",
"Qwen/Qwen1.5-1.8B",
"Qwen/Qwen1.5-1.8B-Chat",
"Qwen/Qwen1.5-4B",
"Qwen/Qwen1.5-4B-Chat",
"Qwen/Qwen1.5-7B",
"Qwen/Qwen1.5-7B-Chat",
"Qwen/Qwen1.5-14B",
"Qwen/Qwen1.5-14B-Chat",
"Qwen/Qwen2-0.5B",
"Qwen/Qwen2-0.5B-Instruct",
"Qwen/Qwen2-1.5B",
"Qwen/Qwen2-1.5B-Instruct",
"Qwen/Qwen2-7B",
"Qwen/Qwen2-7B-Instruct",
"microsoft/phi-1",
"microsoft/phi-1_5",
"microsoft/phi-2",
"microsoft/Phi-3-mini-4k-instruct",
"google/gemma-2b",
"google/gemma-7b",
"google/gemma-2b-it",
"google/gemma-7b-it",
"google/gemma-2-2b",
"google/gemma-2-2b-it",
"google/gemma-2-9b",
"google/gemma-2-9b-it",
"google/gemma-2-27b",
"google/gemma-2-27b-it",
"01-ai/Yi-6B",
"01-ai/Yi-34B",
"01-ai/Yi-6B-Chat",
"01-ai/Yi-34B-Chat",
"google-t5/t5-small",
"google-t5/t5-base",
"google-t5/t5-large",
"ai-forever/mGPT",
]
utils.py内のget_device的にmpsいけそう。 code:py
def get_device():
if torch.cuda.is_available():
return torch.device("cuda")
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
# Parse the PyTorch version to check if it's below version 2.0
major_version = int(torch.__version__.split(".")0) if major_version >= 2:
return torch.device("mps")
return torch.device("cpu")
run_with_cahce() 関数を使用して 2 層の注意のみのモデルを実行します。この関数には、モデルの出力ロジット (正規化されていない予測) と各層のアクティベーション パターン (キャッシュ) の 2 つの出力があります。
キャッシュを使用してモデルを実行した後、TransformerLens 関数 to_str_tokens() を使用して生の入力テキストをトークンに変換します。
たとえば、トークン「インテリジェンス」に焦点を合わせてヘッド L0H7 を調べると、この特定のヘッドが前のトークン「マシン」に注意を払っていることがわかります。注意パターンのインタラクティブな視覚化を視覚的に調べると、少なくとも 3 つの異なるタイプを区別できます。
前のトークンヘッドは主に前のトークン、つまりL0H7に関係しています。
現在のトークンヘッドは主に現在のトークン、つまりL1H6に注目しています。
最初のトークンヘッドは主に最初のトークン、つまりL0H3、L1H4、L1H10に対応します。
つまり、例えば「食べ」をクリックしたらそのトークンにフォーカスが当たる。
フォーカスを当てた状態で各Headにホバーすると「あなたの」とかどのトークンに着目しているか分かる。
https://scrapbox.io/files/66de0a7bc7ef8e001c2d09fc.png
各トークンの関連度合いを示す相関図的なもの
である。つまり、一番左上は第一トークンと第一トークン同士の注視の度合い。
一番左端で上から3つ目は第三トークンの第一トークンの注視の度合い。
そのように各トークンの関連度合いを示している。
例えば「ご飯」に関するワードに着目するHeadがあればそれはご飯に反応するHeadだと分かる。
下三角行列っぽくなっているのは第一トークンにとって例えば第三トークンはまだ未入力状態なので。
実態は、各トークンのQueryベクトルとKeyベクトルの内積
各トークン毎に生まれるQueryとKeyの内積を取る(トークンと同じ個数生まれる)。
Hooks: Intervening on Activations
current_activation_valueとhook_pointを受け取り、それを新しいアクティベーション値に変換します。モデルが実行されると、通常通りそのアクティベーションが計算され、その後フック関数が適用され、アクティベーションを置き換えます。フック関数は任意のPython関数であり、正しい形状のテンソルを返す限り利用可能です
code:py
from transformer_lens import HookedTransformer
model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
_logits, cache = model.run_with_cache("Why did the chicken cross the")
residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn")
answer = " road" # Note the proceeding space to match the model's tokenization
logit_attrs = cache.logit_attrs(residual_stream, answer)
print(logit_attrs.shape) # Attention layers
most_important_component_idx = torch.argmax(logit_attrs)
3_attn_out
ニワトリはなぜまでを入れたあとに「道路」を横断するのかのプロンプトを入れることでどのAttentionが「道路」に反応したかを見る例。
forwardメソッドなぜlossが計算できるかと言えばNext Token Predictionだから。
実際に入力したトークンを正解として次のトークンを予測している。
一応generateメソッドもある。これだとlogitsは得られない?
load_in_8bitやってみたらエラー出た。
AssertionError: Quantization not supported
コード読むにload_inはダメにしてるっぽい。
ただload_inは手元に重み落としたときに動的に量子化するオプションなので、そもそも量子化済みのもの使えば良いのでは?
この図が良さそう。
https://raw.githubusercontent.com/chloeli-15/ARENA_img/main/img/transformer-full-updated.png