機械学習(caret package)

今回はcaretパッケージの調査です。

機械学習、予測全般のモデル作成とかモデルの評価が入っているパッケージのようです。

多くの関数があるので、調査したものから並べていきます。


  • varImp

予測モデルを作ったときの、変数の重要度を計算する。

次のプログラムでは、花びらの長さなどの4変数を用いて、あやめの種類をk-近傍法で予測した場合に、どの変数が重要なのかを種類別に計算したもの。

#------irisデータで変数の重要度を計算
data(iris)
TrainData    <- iris[,1:4]
TrainClasses <- iris[,5]
knnFit       <- train(TrainData, TrainClasses, "knn")
knnImp       <- varImp(knnFit)
dotPlot(knnImp)

最後のdotplotで図を描いてくれるのですが、見づらいので自作プロットを描くとこのようになります。


f:id:isseing333:20101022173354j:image


Virginica種の予測をするときにはSepal.Widthは重要ではない。

またVercicolor種の予測をするときにはSepal.Lengthは重要ではないようです。

視覚的に散布図で確認してみます。


    • 散布図

f:id:isseing333:20101022173245j:image


確かに、Sepal.Width(縦軸)を見るとVirginica(緑)があまり分かれてないし、Sepal.Length(横軸)を見るとVercicolor(青)があまり分かれていません。

でも上の重要度のプロットで、青と緑が逆でもいいのでは?と思ってしまいますが、それは計算方法次第なんでしょうね。

重要度のプロット変数選択のような結果が、一目で分かるので便利だと思います。


  • train

この関数がcaretパッケージの肝のようで、かなり多くの予測を行うことができます。

主に使うパラメータは、

    • method:予測方法の指定
    • trControl:trainingデータをどうやって選ぶか
    • tuneLength:methodで指定した方法毎に設定できるパラメータ

くらいでしょうか。もちろん他のパラメータも、手法によっては指定します。

かなり多数の手法を適用できるみたいです(予測と聞いて連想する手法は全て入っています)ので、methodで指定できる手法を紹介しておきます。

(重要なので全部書き出します。モデルのグループ、手法名、入っているパッケージ、調整パラメータを表しています。)


Generalized linear model
	glm 		stats 			none
	glmStepAIC 	MASS 			none
Generalized additive model
	gam 		mgcv 			select, method
	gamLoess 	gam 			span, degree
	gamSpline 	gam 			df
Recursive partitioning 
	rpart 		rpart 			maxdepth
	ctree 		party 			mincriterion
	ctree2 		party 			maxdepth
Boosted trees 
	gbm 		gbm 			interaction depth, n.trees, shrinkage
	blackboost 	mboost 			maxdepth, mstop
	ada 		ada 			maxdepth, iter, nu
Boosted regression models 
	glmboost 	mboost 			mstop
	gamboost 	mboost 			mstop
	logitBoost 	caTools 		nIter
Random forests 
	rf 		randomForest 		mtry
	parRF 		randomForest, foreach 	mtry
	cforest 	party 			mtry
Bagging 
	treebag 	ipred 			None
	bag 		caret 			vars
	logicBag 	logicFS 		ntrees, nleaves
Other Trees 
	nodeHarvest 	nodeHarvest 		maxinter, node
	partDSA 	partDSA 		cut.off.growth, MPD
Logic Regression 
	logreg 		LogicReg 		ntrees, codetreesize
Elastic net (glm) 
	glmnet 		glmnet 			alpha, lambda
Neural networks 
	nnet 		nnet 			decay, size
	neuralnet 	neuralnet 		layer1, layer2, layer3
	pcaNNet 	caret 			decay, size
Projection pursuit regression 
	ppr 		stats 			nterms
Principal component regression 
	pcr 		pls 			ncomp
Independent component regression 
	icr 		caret 			n.comp
Partial least squares 
	pls 		pls, caret 		ncomp
Sparse partial least squares 
	spls 		spls, caret 		K, eta, kappa
Support vector machines 
	svmLinear 	kernlab 		C
	svmRadial 	kernlab 		sigma, C
	svmPoly 	kernlab 		scale, degree, C
Relevance vector machines 
	rvmLinear 	kernlab 		none
	rvmRadial 	kernlab 		sigma
	rvmPoly 	kernlab 		scale, degree
Least squares support vector machines 
	lssvmRadial 	kernlab 		sigma
Gaussian processes 
	guassprLinearl 	kernlab 		none
	guassprRadial 	kernlab 		sigma
	guassprPoly 	kernlab 		scale, degree
Linear least squares 
	lm 		stats 			None
	lmStepAIC 	MASS 			None
Robust linear regression 
	rlm 		MASS 			None
Multivariate adaptive regression splines 
	earth 		earth 			degree, nprune
Bagged MARS 
	bagEarth 	caret, earth 		degree, nprune
Rule Based Regression 
	M5Rules 	RWeka 			pruned
Penalized linear models 
	penalized 	penalized 		lambda1, lambda2
	enet 		elasticnet 		lambda, fraction
	lars 		lars 			fraction
	lars2 		lars 			steps
	enet 		elasticnet 		fraction
	foba 		foba 			lambda, k
Supervised principal components 
	superpc 	superpc 		n.components, threshold
Quantile Regression Forests 
	qrf 		quantregForest 		mtry
Linear discriminant analysis 
	lda 		MASS 			None
	Linda 		rrcov 			None
Quadratic discriminant analysis 
	qda 		MASS 			None
	QdaCov 		rrcov 			None
Stabilized linear discriminant analysis 
	slda 		ipred 			None
Heteroscedastic discriminant analysis 
	hda 		hda 			newdim, lambda, gamma
Stepwise discriminant analysis 
	stepLDA 	klaR 			maxvar, direction
	stepQDA 	klaR 			maxvar, direction
Stepwise diagonal discriminant analysis 
	sddaLDA 	SDDA 			None
	sddaQDA 	SDDA 			None
Shrinkage discriminant analysis 
	sda 		sda 			diagonal
Sparse linear discriminant analysis 
	sparseLDA 	sparseLDA 		NumVars, lambda
Regularized discriminant analysis 
	rda 		klaR 			lambda, gamma
Mixture discriminant analysis 
	mda 		mda 			subclasses
Sparse mixture discriminant analysis 
	smda 		sparseLDA 		NumVars, R, lambda
Penalized discriminant analysis 
	pda 		mda 			lambda
	pda2 		mda 			df
Stabilised linear discriminant analysis 
	slda 		ipred 			None
High dimensional discriminant analysis 
	hdda 		HDclassif 		model, threshold
Flexible discriminant analysis (MARS) 
	fda 		mda, earth 		degree, nprune
Bagged FDA 
	bagFDA 		caret, earth 		degree, nprune
Logistic/multinomial regression 
	multinom 	nnet 			decay
Penalized logistic regression 
	plr 		stepPlr 		lambda, cp
Rule-based classification 
	J48 		RWeka 			C
	OneR 		RWeka 			None
	PART 		RWeka 			threshold, pruned
	JRip 		RWeka 			NumOpt
Logic Forests 
	logforest 	LogicForest 		None
Bayesian multinomial probit model 
	vbmpRadial 	vbmp 			estimateTheta
k nearest neighbors 
	knn3 		caret 			k
Nearest shrunken centroids 
	pam 		pamr 			threshold
	scrda 		rda 			alpha, delta
Naive Bayes 
	nb 		klaR 			usekernel
Generalized partial least squares 
	gpls 		gpls 			K.prov
Learned vector quantization 
	lvq 		class 			size, k
ROC Curves 
	rocc 		rocc 			xgenes

すげーーーーー!!!

1つのパッケージでこれだけできるとは。。。数にして93個。

しかもクロスバリデーションやブートストラップをしながら。

大体の手法は数式まで分からなくてもどういう計算方法かわかるけど、博士のうちにこれ全部勉強したかったな~(これから予測モデルについて勉強する人にオススメします)。

HastieとかBishopの教科書も、これの一部しか載ってない感じだもんなぁ。

こういう風にまとまっているだけでも価値があります。

Efron先生のLARSもちゃんと入っているのに驚き。

好きなんだよな~、LARS。実データで使ったことないけどw


ちなみに、最近ブログで書いていたシミュレーションで使っている手法はlda、svm、knn、nnet、pca、plsですね。

まぁ満遍なく使っている感じはするかなぁ(自画自賛ww)。

このパッケージを使った方がプログラムが綺麗にはなっただろうな~(時間はあんまり変わらないだろうけど)。

こんな感じで検定についてまとまっているパッケージはないのだろうか??

検定は手法によって対象データの形が違うから難しいかな?


ちょっと記事が長くなってきたので今日はこの辺で。

プロット図のプログラムはこちらです↓

#---------caret
library(caret)

#------irisデータで変数の重要度を計算
data(iris)
TrainData    <- iris[,1:4]
TrainClasses <- iris[,5]
knnFit       <- train(TrainData, TrainClasses, "knn")
knnImp       <- varImp(knnFit)
dotPlot(knnImp)

#------dotPlotは色の制御ができないので自作のプロットを描く
#---knnImpデータを加工する
x1 <- data.frame(Imp   = knnImp$importance[, 1], 
		 var   = rownames(knnImp$importance), 
		 class = colnames(knnImp$importance)[1])
x2 <- data.frame(Imp   = knnImp$importance[, 2], 
		 var   = rownames(knnImp$importance), 
		 class = colnames(knnImp$importance)[2])
x3 <- data.frame(Imp   = knnImp$importance[, 3], 
		 var   = rownames(knnImp$importance), 
		 class = colnames(knnImp$importance)[3])

x  <- rbind(x1, x2, x3)

#---ggplot2で描く
library(ggplot2)
#---複数のウインドウがあった場合に最初のウインドウに移動する
dev.set(dev.prev())
ggplot(x, aes(Imp, var, col = class, shape = class)) + 
	geom_point(size = 4) + 
	geom_point(col="grey90", size=1.5) +
	xlab("重要度") + ylab("変数")

#------iris種と変数の関連をチェック
#---新規ウインドウを作成する
x11()
dev.set(dev.next())
ggplot(iris, aes(Sepal.Length, Sepal.Width, col = Species, shape = Species)) + 
	geom_point(size = 3) + 
	geom_point(col="grey90", size=1)

ページTOPへ