概要
前回ナイブベイズをやってみました
今回は SVM を試してみます
SVM も分類器ですがナイブベイズとは異なりマージン最大化を取り入れた境界線を引くことで分類を行います
環境
- macOS 10.13.2
- Ruby 2.4.1p111
- rb-libsvm 1.4.0
- libsvm
インストール
- brew install libsvm
- bundle init
- vim Gemfile
gem "rb-libsvm"
gem "libsvmloader"
bundle install --path vendor
学習データの取得
wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits
wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits.t
今回のデータは分類するラベル数が 0 から 9 までの 10 個、学習させるデータの次元数が 16 になっていました
学習させるサンプルコード
- vim learn.rb
require 'libsvmloader'
require 'libsvm'
class Model
def self.gen
train_data, train_data_labels = LibSVMLoader.load_libsvm_file('./pendigits')
train_data_ary = train_data.each_row.map { |v| Libsvm::Node.features(v.to_a) }
train_data_labels_ary = train_data_labels.to_flat_a
params = Libsvm::SvmParameter.new.tap do |p|
p.cache_size = 1000
p.svm_type = Libsvm::SvmType::C_SVC
p.kernel_type = Libsvm::KernelType::RBF
p.gamma = 0.0001
p.c = 1.0
p.eps = 0.001
end
problem = Libsvm::Problem.new
problem.set_examples(train_data_labels_ary, train_data_ary)
model = Libsvm::Model.train(problem, params)
end
end
ダウンロードしたデータはすでに SVM で学習できる形式になっています
その場合には LibSVMLoader.load_libsvm_file
という便利なメソッドがあります
データ (train_data) とラベル (train_data_labels) は NMatrix クラスのオブジェクトです
データは Libsvm::Node.features
クラスのオブジェクトに変換します
ラベルは to_flat_a という関数をコールして NMatrix から Array に変換します
次に学習時のパラメータの設定です
とりあえずこのまま学習させましょう
一応ここがチューニングポイントにはなっているので精度を上げる場合にはこの辺の値を調整します
あとは Libsvm::Problem
にテストデータと正解のラベルデータを食わせてパラメータと一緒に学習させればモデルができあがります
テストコード
では先ほど作成したモデルを使ってテストデータを食わせてみましょう
vim test.rb
require './learn.rb'
model = Model.gen
test_data, test_data_labels = LibSVMLoader.load_libsvm_file('./pendigits.t')
# これでテストの実施をしてる
preds = test_data.each_row.map { |v| model.predict(Libsvm::Node.features(v.to_a)) }
ok = 0
preds.each.with_index { |label, index|
if label == test_data_labels[index]
ok = ok + 1
end
}
# puts preds.count
# puts ok
puts 100.0 * (ok / preds.count.to_f)
テストデータを読み込んだら model.predict
でテストさせます
テストデータは学習時同様に Libsvm::Node.features
で変換する必要があります
結果は preds という配列のデータで返ってきます
単純にデータに対してモデルが判断したラベルが入っているだけです
なのであとは each で回してそれぞれのラベルが正しかったかどうかを判断します
今回は正しければ ok のカウントを増やしています
あとは全体のデータのうちどれだけ ok があるかを割り算すれば正答率が出せます
今回はサンプルデータを使ったので 98.28473413379074
という値が出ればコードがうまく動作していることになります
最後に
Ruby で SVM を試してみました
SVM はデータさえ揃えばサクっと試せてかつ精度の良い分類器が作れる手法かなと思います
パラメータ部分の説明は今回省略しましたが SVM の肝でもあるのでちゃんと理解しておいたほうが良いかと思います
(RBF カーネル、C、gamma あたりでググるといわいろと出てきます)
また評価の尺度は他にも Precision と Recall, F値 などがあります
SVM には CrossValidation という手法があり全体のデータの一部を学習データ、一部をテストデータとして学習させることができます
これを使えば上記の尺度も出すことができます
CrossValidation も今回のライブラリを使ってできるので興味があれば以下を参考にしてください
https://github.com/febeling/rb-libsvm/blob/master/examples/iris.rb
0 件のコメント:
コメントを投稿