足し算の学習
数をPositional Encoding的な指数的に周期の伸びる回転する単位ベクトルの組み合わせで表現することで、その足し算が意外と小さなネットワークで学習できる。N桁の整数を2N次元に埋め込めば中間層はたかだか32ユニットで3桁の足し算を全問正解することができる。 code::
2桁の足し算の正解率、埋め込みが2次元の場合
16: 2291/5050 = 45.4%
32: 3586/5050 = 71.0%
64: 4276/5050 = 84.7%
128: 4719/5050 = 93.4%
256: 4909/5050 = 97.2%
4次元の場合
16: 4899/5050 = 97.0%
32: 5050/5050 = 100.0%
64: 5050/5050 = 100.0%
128: 5050/5050 = 100.0%
256: 5050/5050 = 100.0%
3桁の足し算(6次元へ埋め込み)
16: 423463/500500 = 84.6%
32: 500500/500500 = 100.0%
8→32→4という小ささで、5050件の2桁の数値同士の結果が2桁に収まる足し算を全問正解するニューラルネットを巡回セールスマン問題でニューラルネットの可視化した結果、中間層が半分ずつに分担して、入力の半分だけに注目してることがわかる。 https://gyazo.com/34f4742106426e7ad3fdaf0684bfb5d7
で、この中間層の並び替えに従って出力層への投射も可視化するとこう。
https://gyazo.com/25c805015f61b5df1a5cc5dff19d5861
2桁の整数を2 * 2次元に埋め込んでいるというところから予想がつくと思うが、要するに周期100と周期10の回転子がそれぞれ10の位と1の位を表現しているのである。ただしあくまでPositional Encodingなので10の位は階段状に上がるのではなく、10倍ゆっくり振動するだけである。 3桁の足し算ができるNN
https://gyazo.com/75de58e055aa536519bc93541e817722
なお2問間違えている: 32: 500498/500500 = 100.0%
12→32→6
中間層が意外と少なくても大丈夫な実験。2桁の足し算の正解率を中間層を減らしながら出す。12ぐらいから性能が顕著に悪化する。8→12→4で9割以上成功できる理由について数学的に説明がつくかどうか少し考えてみたけどわからない。単純な掛け算のだと入力の4倍必要になるのだが、今回は入力が単位円周上に制限されていることによることがおそらく有用活用されているのだろう。
code::
20: 5043/5050 = 99.9%
19: 5041/5050 = 99.8%
18: 4975/5050 = 98.5%
17: 5002/5050 = 99.0%
16: 4957/5050 = 98.2%
15: 4765/5050 = 94.4%
14: 5003/5050 = 99.1%
13: 4921/5050 = 97.4%
12: 4592/5050 = 90.9%
11: 4489/5050 = 88.9%
10: 3632/5050 = 71.9%
9: 2866/5050 = 56.8%
8: 2477/5050 = 49.0%
7: 1919/5050 = 38.0%
6: 1887/5050 = 37.4%
5: 815/5050 = 16.1%
4: 368/5050 = 7.3%
3: 250/5050 = 5.0%