Rubyで機械学習: ベイズ分類器でプログラミング言語分類器を作る(作れなかった)
失敗した話です。
データセットのダウンロードにはasync-httpを使ってみた。
code:download-rosetta-code.rb
require "async/http/internet/instance"
# Download the Rosetta Code dataset files into the current directory:
# dataset_infos.json and the training parquet file. HTTP redirects are
# followed by hand before the final 200 response is saved.
#
# NOTE(review): this listing appears to have lost lines. `response` and
# `res` are never assigned in the visible code, and the last two `end`s
# have no matching opener here. Presumably the missing lines were the
# Async::HTTP::Internet.get calls (with the dataset URLs) and an enclosing
# block — TODO restore them from the original script.
def main
Sync do
response.save("dataset_infos.json")
end
# Follow redirects: re-request the Location target while the status is in
# 300..309. There is no cap on the number of hops.
# NOTE(review): `res.headers"location"` is not valid Ruby; presumably
# `res.headers["location"]` was intended. A relative Location value is
# passed through unresolved — confirm the server always sends absolute URLs.
# NOTE(review): the redirect response is never closed before the next
# request; consider res.close — confirm against the async-http docs.
while (300..309).cover?(res.status)
res = Async::HTTP::Internet.get(res.headers"location") end
# Save only a successful final response; any other status aborts with the
# response's inspect output.
case res.status
when 200
res.save("train-00000-of-00001-8b4da49264116bbf.parquet")
else
raise res.inspect
end
end
end
end
# Start the download only when this file is executed directly, not when it
# is required from somewhere else.
main if __FILE__ == $PROGRAM_NAME
最初は全部のデータで学習させてみたけど、誤判定が甚だしく、知らないような言語がよく出てきてしまっていたので、Rosetta Codeの言語のうちRougeで扱えるものに絞ることにした。ラベルにはRougeでのタグを使うことにした。以下はRosetta Codeの言語名からRougeのタグへのマップ。
code:langs.json
{
"ABAP": "abap",
"APEX": "apex",
"AWK": "awk",
"ActionScript": "actionscript",
"Ada": "ada",
"Apex": "apex",
"AppleScript": "applescript",
"Applescript": "applescript",
"Bash": "shell",
"Bash Shell": "shell",
"Batch": "batchfile",
"Batch File": "batchfile",
"Batch file": "batchfile",
"C": "c",
"C / C++": "cpp",
"C Shell": "shell",
"C and C++": "cpp",
"C#": "csharp",
"C# and Visual Basic .NET": "csharp",
"C++": "cpp",
"C++/CLI": "cpp",
"C/C++": "cpp",
"CLIPS": "common_lisp",
"CMake": "cmake",
"CSS": "css",
"C_sharp": "csharp",
"Ceylon": "ceylon",
"Chef": "ruby",
"Clojure": "clojure",
"CoffeeScript": "coffeescript",
"Coffeescript": "coffeescript",
"ColdFusion": "cfscript",
"Common Lisp": "common_lisp",
"Component Pascal": "pascal",
"Coq": "coq",
"Crystal": "crystal",
"D": "d",
"Dafny": "dafny",
"Dart": "dart",
"ECL": "ecl",
"ECMAScript": "javascript",
"EchoLisp": "common_lisp",
"Echolisp": "common_lisp",
"Eiffel": "eiffel",
"Elixir": "elixir",
"Elm": "elm",
"Emacs Lisp": "common_lisp",
"Erlang": "erlang",
"F#": "fsharp",
"FORTRAN": "fortran",
"Factor": "factor",
"Free Pascal": "pascal",
"FreePascal": "pascal",
"GDScript": "gdscript",
"GLSL": "glsl",
"GNU make": "make",
"Go": "go",
"Groovy": "groovy",
"HTML": "html",
"HTML5": "html",
"Hack": "hack",
"Haskell": "haskell",
"Haxe": "haxe",
"Idris": "idris",
"Io": "io",
"J": "j",
"JSON": "json",
"Janet": "janet",
"Java": "java",
"JavaScript": "javascript",
"JavaScript + HTML": "javascript",
"JavaScript + SVG": "javascript",
"Javascript/NodeJS": "javascript",
"Jinja": "jinja",
"Kotlin": "kotlin",
"LLVM": "llvm",
"LaTeX": "tex",
"Lasso": "lasso",
"Lean": "lean",
"Lisp": "common_lisp",
"LiveScript": "livescript",
"Lua": "lua",
"Lua/Torch": "lua",
"MATLAB": "matlab",
"MATLAB / Octave": "matlab",
"Make": "make",
"MatLab": "matlab",
"Mathematica / Wolfram Language": "mathematica",
"Mathematica/ Wolfram Language": "mathematica",
"Mathematica//Wolfram Language": "mathematica",
"Mathematica/Wolfram Language": "mathematica",
"MySQL": "sql",
"NewLISP": "common_lisp",
"NewLisp": "common_lisp",
"Nial": "nial",
"Nim": "nim",
"OCaml": "ocaml",
"Objective-C": "objective_c",
"OpenLisp": "common_lisp",
"Oracle": "sql",
"PHP": "php",
"PHP+SQLite": "php",
"PL/SQL": "plsql",
"Pascal": "pascal",
"Perl": "perl",
"Perl5i": "perl",
"Plain TeX": "tex",
"Pony": "pony",
"PostScript": "postscript",
"PostgreSQL": "sql",
"PowerShell": "powershell",
"PowerShell+SQLite": "powershell",
"Powershell": "powershell",
"Prolog": "prolog",
"Python": "python",
"Python 3.x Long Form": "python",
"Python+SQLite": "python",
"Q": "q",
"R": "r",
"Racket": "racket",
"ReasonML": "reasonml",
"Ruby": "ruby",
"Ruby with RSpec": "ruby",
"Rust": "rust",
"SAS": "sas",
"SQL": "sql",
"SQL PL": "sql",
"SQL/PostgreSQL": "sql",
"SQLite": "sql",
"SVG": "xml",
"Sass/SCSS": "scss",
"Scala": "scala",
"Scheme": "scheme",
"Sed": "sed",
"Shell": "shell",
"Smalltalk": "smalltalk",
"SuperCollider": "supercollider",
"Swift": "swift",
"Swift Playground": "swift",
"TOML": "toml",
"TypeScript": "typescript",
"Typescript": "typescript",
"UNIX Shell": "shell",
"VBA": "vb",
"VBA (Visual Basic for Application)": "vb",
"VBA Excel": "vb",
"VBA/Visual Basic": "vb",
"VHDL": "vhdl",
"Vala": "vala",
"Verilog": "verilog",
"Visual Basic": "vb",
"Visual Basic .NET": "vb",
"Wollok": "wollok",
"XPath": "xpath",
"XPath 2.0": "xpath",
"XQuery": "xquery",
"XSLT": "xml",
"XSLT 1.0": "xml",
"XSLT 2.0": "xml",
"Xojo": "xojo",
"Zig": "zig",
"Zsh": "shell",
"bash": "shell",
"clojure": "clojure",
"haskell": "haskell",
"make": "make",
"php": "php",
"rust": "rust",
"sed": "sed",
"zig": "zig"
}
code:train.rb
require "parquet"
require "classifier"
require "pathname"
require "json"
# Both paths are resolved next to this script, not the working directory.
# Trained classifier in Marshal format, created by build_classifier.
MODEL_PATH = Pathname(__dir__).join("model.rubymarshal")
# Rosetta Code language name => Rouge tag mapping.
LANGS_PATH = Pathname(__dir__).join("langs.json")
# Classify the source code read from +file+ and print the predicted Rouge tag.
# A model is trained and saved first if none exists yet.
#
# @param file [#read] the text to classify (ARGF when run as a script)
# @return [nil]
def main(file)
  build_classifier
  load_classifier.then { |model| puts model.classify(file.read) }
end
# Train a naive Bayes language classifier on the Rosetta Code parquet file
# and Marshal-dump it to MODEL_PATH. Does nothing if MODEL_PATH already
# exists, so delete the model file to retrain.
#
# Rows are visited in a random order (ShuffledTable). The first ~80% are
# meant for training and the rest for measuring accuracy. Accuracy is
# printed to stderr, along with progress every 100 rows.
#
# NOTE(review): this listing appears to have lost lines; see the notes below.
def build_classifier
return if MODEL_PATH.exist?
# langs.json maps Rosetta Code language names to Rouge tags.
langs = JSON.parse(LANGS_PATH.read)
# NOTE(review): unused in the visible code — presumably a filter restricting
# rows to these language names was lost. TODO confirm/restore.
available_lang_names = langs.keys
# Only the label column and the source column are kept.
data = Arrow::Table.load("train-00000-of-00001-8b4da49264116bbf.parquet")
.select_columns("language_name", "code")
shuffled = ShuffledTable.new(data)
# NOTE(review): `n_records` is never assigned here — presumably a line such
# as `n_records = data.n_rows` (or the count after filtering) was lost.
n_training_data = (n_records * 0.8).to_i
# One category per distinct Rouge tag.
classifier = Classifier::Bayes.new(*langs.values.uniq)
correct = 0
shuffled.each_record.with_index do |record, i|
$stderr.puts "#{i + 1}/#{n_records}" if i % 100 == 0
# NOTE(review): `<=` sends n_training_data + 1 rows (indices 0..n_training_data)
# to training, but the accuracy below divides by n_records - n_training_data;
# `i < n_training_data` is presumably intended.
# NOTE(review): both branches are empty — the training step and the
# evaluation step that increments `correct` appear to have been lost.
if i <= n_training_data
else
end
end
# NOTE(review): the next line fuses two statements with no separator, which
# is a syntax error; `MODEL_PATH.write Marshal.dump(classifier)` needs its
# own line. Marshal output is binary, so MODEL_PATH.binwrite (paired with
# binread when loading) would avoid newline translation on Windows.
$stderr.puts "Accuracy: #{correct * 100.to_f / (n_records - n_training_data)}%" MODEL_PATH.write Marshal.dump(classifier)
end
# Deserialize the classifier that build_classifier wrote to MODEL_PATH.
# Marshal.load can instantiate arbitrary objects, so MODEL_PATH must only
# ever point at a file this script produced itself.
#
# @return [Object] the unmarshalled classifier
def load_classifier
  serialized = File.read(MODEL_PATH)
  Marshal.load(serialized)
end
# Read-only view over a table that yields its rows in a random order fixed
# at construction time. Rows are yielded as Arrow::Record objects bound to
# the wrapped table by index, so no row data is copied.
class ShuffledTable
  # @param table [#n_rows] the table to iterate over (an Arrow::Table here)
  # @param random [Random, #rand] RNG used for the permutation. Pass a seeded
  #   Random.new(seed) for a reproducible train/test split. The default, the
  #   Random class itself, uses Ruby's default generator, exactly like
  #   Array#shuffle with no argument.
  def initialize(table, random: Random)
    @table = table
    # The row count has to come from the wrapped table: ShuffledTable has no
    # n_rows of its own until @indices exists.
    @indices = (0...table.n_rows).to_a.shuffle(random: random)
  end

  # @return [Integer] number of rows in the wrapped table
  def n_rows
    @indices.size
  end

  # Yield every row once, in shuffled order.
  #
  # @yieldparam record [Arrow::Record] the row at the next shuffled index
  # @return [Enumerator] when no block is given; the enumerator is sized so
  #   that callers such as with_index or progress reporting know the total
  # @return [Array<Integer>] the shuffled index list when a block is given
  def each_record
    return to_enum(__method__) { n_rows } unless block_given?

    @indices.each do |i|
      yield Arrow::Record.new(@table, i)
    end
  end
end
# When executed directly, classify the contents of ARGF (the files named on
# the command line, or stdin if none are given).
main(ARGF) if __FILE__ == $PROGRAM_NAME
一応できたけど全然まともに判定してくれない。
訓練データの扱いが悪い気がする……とも思うが、プログラムのソースコードは似たような物になることも多いから、そもそも判定が難しいのかも知れない。
やること(やるとは言っていない):
言語をもっと絞ってみる
てか実験的な少量データでClassifierライブラリーがまともに動くか確認する