Motivation
A recent project of mine used scikit-learn's Gaussian Naive Bayes model (GNB for short). As the data grew, running GNB on a single machine was clearly going to be slow, so I decided to move it onto Spark. Then I discovered that MLlib does not implement GNB. Fine, do it yourself and you'll never go without~
Principles
GNB builds on Naive Bayes, so let's first walk through how Naive Bayes works.
Naive Bayes
Bayes' theorem:
$$P(Y \mid X) = \frac{P(X \mid Y) \cdot P(Y)}{P(X)}$$
With Bayes' theorem we can compute P(Y|X) whenever P(X|Y) and P(Y) are known. Now treat Y as the class and X as a feature: knowing "the probability P(X|Y) that feature X appears when the class is Y" and "the probability P(Y) that the class is Y", we can compute P(Y|X), the probability that the class is Y given that feature X appears.
The above covers only a single feature. Now consider a model with N features and C classes. Given the features X, the probability that the class is k can be written as
$$P(Y=k \mid X_1,\dots,X_N) = \frac{P(X_1,\dots,X_N \mid Y=k) \cdot P(Y=k)}{P(X_1,\dots,X_N)} = \frac{P(Y=k)\prod_{i=1}^{N}P(X_i \mid Y=k)}{\sum_{j=1}^{C}P(Y=j)\prod_{i=1}^{N}P(X_i \mid Y=j)}$$
Using this formula we can compute, for each class k, the probability that the class is k given the features X; we then predict the k with the largest probability (maximum a posteriori, MAP). That is the whole idea of Naive Bayes.
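Since the denominator $P(X_1,\dots,X_N)$ is identical for every class $k$, in practice the MAP decision only needs to compare numerators:

$$\hat{y} = \arg\max_{k} \; P(Y=k)\prod_{i=1}^{N}P(X_i \mid Y=k)$$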
Wait, something seems off here. What justifies claiming that
$$\prod_{i=1}^{N}P(X_i \mid Y=k) = P(X_1,\dots,X_N \mid Y=k)\,?$$
Exactly: this is where the "naive" in Naive Bayes comes from. It rests on a very strong assumption that all features occur independently of one another, which is also the main limitation of Naive Bayes.
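To see exactly what the assumption buys us: the chain rule always lets us factor the joint likelihood, and conditional independence is what collapses each factor into something we can estimate feature by feature:

$$P(X_1,\dots,X_N \mid Y=k) = \prod_{i=1}^{N} P(X_i \mid X_1,\dots,X_{i-1},\, Y=k) \;\overset{\text{indep.}}{=}\; \prod_{i=1}^{N} P(X_i \mid Y=k)$$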
In real applications we also need to handle edge cases: a class that never appears in the sample set, or a feature that never appears in the samples of some class. That is where a smoothing factor $\lambda$ comes in to adjust the estimates:
$$P(Y=k) = \frac{\text{number of samples labeled } k + \lambda}{\text{number of samples} + \text{number of labels} \times \lambda}$$
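For example, with $\lambda = 1$, 10 training samples, and 3 classes, a class that never appears in the training set still gets prior $(0 + 1)/(10 + 3 \times 1) = 1/13$ rather than 0, so it can never zero out the whole product.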
Under the multinomial model:
$$P(X_i \mid Y=k) = \frac{\text{count of feature } i \text{ in samples labeled } k + \lambda}{\text{count of all features in samples labeled } k + \text{number of distinct features} \times \lambda}$$
Under the Bernoulli model:
$$P(X_i \mid Y=k) = \frac{\text{number of samples labeled } k \text{ containing feature } i + \lambda}{\text{number of samples labeled } k + 2\lambda}$$
Naive Bayes comes in two commonly used flavors: the Bernoulli model and the multinomial model. The difference is that the Bernoulli model only cares whether a feature appears in a sample at all (e.g. whether a given word occurs: 0 or 1), while the multinomial model counts how many times a feature appears in a sample (e.g. a word's occurrence count, an actual number). Both models are built for discrete features. When the features being modeled are continuous variables, there are generally two options: quantize the continuous features into discrete ones (see the sketch below), or use Gaussian Naive Bayes.
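To illustrate the first option, here is a minimal equal-width binning sketch; the helper name, bucket count, and range are illustrative assumptions, not part of the original code:

  // Quantize a continuous value into one of `bins` equal-width buckets over [min, max].
  def discretize(x: Double, min: Double, max: Double, bins: Int): Int = {
    val clipped = math.min(math.max(x, min), max)       // clamp outliers into the range
    val idx = ((clipped - min) / (max - min) * bins).toInt
    math.min(idx, bins - 1)                             // x == max falls into the last bucket
  }

  discretize(6.5, 0.0, 10.0, 5) // == 3, i.e. the 4th of 5 buckets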
Gaussian Naive Bayes
Where the Gaussian model differs from the two models above is in how P(X|Y) is computed: it assumes each feature follows a Gaussian distribution, which makes it a good fit for continuous features.
$$X \mid Y=k \sim N(\mu,\sigma^2), \qquad P(X=a \mid Y=k) = \frac{1}{\sqrt{2\pi}\,\sigma}\, e^{-\frac{(a-\mu)^2}{2\sigma^2}}$$
Both the mean and the variance in the formula above can be estimated from the training set.
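Concretely, for each class $k$ and feature $i$, the maximum-likelihood estimates over the $N_k$ samples of class $k$ are

$$\hat\mu_{k,i} = \frac{1}{N_k}\sum_{x \in \text{class } k} x_i, \qquad \hat\sigma_{k,i}^2 = \frac{1}{N_k}\sum_{x \in \text{class } k} (x_i - \hat\mu_{k,i})^2$$

which is exactly what the run() method below computes for every class.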
With the Gaussian density we turn a continuous value into a probability (a density, strictly speaking), which resolves the continuous-feature problem from the previous section; everything else carries over from plain Naive Bayes unchanged.
Implementation
Talk is cheap, show me the code. On to the concrete implementation. Since the vector types implemented in Spark MLlib expose very little public API, I hand-rolled a LabeledPoint:
class LabeledPoint(val label: Double, val denseVector: DenseVector[Double])
  extends Serializable {
}

object LabeledPoint extends Serializable {
  def apply(label: Double, denseVector: DenseVector[Double]) = {
    new LabeledPoint(label, denseVector)
  }
}
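Constructing one is straightforward (the values here are made up, and breeze.linalg.DenseVector is assumed to be in scope):

  val p = LabeledPoint(1.0, DenseVector(0.5, 2.3)) // label 1.0 with two feature values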
The Gaussian density function: pass in a mean and a variance and it yields the density as a function of x, using currying:
def distributiveFunc(mean: Double, variance: Double)(x: Double): Double = {
  if (variance == 0.0) {
    // degenerate case: zero variance puts all the mass at the mean
    if (x == mean) 1.0
    else 0.0
  } else {
    1.0 / sqrt(2 * Pi * variance) * exp(-pow(x - mean, 2.0) / (2 * variance))
  }
}
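Currying lets us fix (mean, variance) once per class and feature at training time, and carry around only the resulting Double => Double for prediction:

  val standardNormal = distributiveFunc(0.0, 1.0) _ // partially applied: an N(0, 1) density
  standardNormal(0.0)                               // ≈ 0.3989, i.e. 1 / sqrt(2 * Pi)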
The core code in full:
import breeze.linalg.DenseVector
import breeze.numerics._
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import scala.math.Pi
import xyz.qspring.spark.ml.base.LabeledPoint // note: the LabeledPoint defined above

/**
 * Created by qero on 16/8/7.
 */
class GuassianNaiveBayes private (private val input: RDD[LabeledPoint], private val lambda: Double = 1.0) extends Serializable with Logging {

  // curried Gaussian density function, same as above
  def distributiveFunc(mean: Double, variance: Double)(x: Double): Double = {
    if (variance == 0.0) {
      if (x == mean) 1.0
      else 0.0
    } else {
      1.0 / sqrt(2 * Pi * variance) * exp(-pow(x - mean, 2.0) / (2 * variance))
    }
  }

  def run() = {
    val sampleN = input.count
    val grouped = input.map(point => (point.label, point.denseVector)).groupByKey().cache
    val classN = grouped.count

    // prior probability of each class (note the smoothing factor lambda)
    val pi = grouped.map { case (c, a) =>
      val p = (a.toList.length * 1.0 + lambda) / (sampleN + lambda * classN)
      (c, log2(p)) // take logs: summing logs avoids the precision loss of multiplying many small numbers
    }

    // per-class mean and variance of every feature
    val pji = grouped.mapValues(a => {
      val aSum = a.reduce((v1, v2) => v1 + v2)   // element-wise sum of this class's samples
      val aSampleN = a.toArray.length            // number of samples in this class
      val mean = aSum / (aSampleN * 1.0)         // mean
      val variance = a.map(i => {                // variance (center -> square -> sum -> average)
        (i - mean) :* (i - mean)
      }).reduce((v1, v2) => v1 + v2) / (aSampleN * 1.0)
      val paras = mean.toArray.zip(variance.toArray)
      paras.map(p => distributiveFunc(p._1, p._2) _) // yields (class, [density of feature 1, ..., density of feature n])
    })

    new GuassianNBModel(pi.collectAsMap(), pji.collectAsMap())
  }
}

class GuassianNBModel(val pi: collection.Map[Double, Double], val pji: collection.Map[Double, Array[Double => Double]]) extends Serializable {
  def predict(features: DenseVector[Double]) = {
    pji.map { case (label, models) =>
      val score = models.zip(features.toArray).map { case (m, v) =>
        log2(m(v)) // logs again, for the same precision reason
      }.sum + pi(label)
      (score, label) // i.e. (log(P(F1...Fn|Label) * P(Label)), Label)
    }.max // the label paired with the highest score is the prediction
  }
}

object GuassianNaiveBayes extends Serializable {
  def fit(input: RDD[LabeledPoint]) = {
    new GuassianNaiveBayes(input).run()
  }
}
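Two design caveats worth noting: groupByKey gathers every sample of a class into a single in-memory iterable on one executor, so this implementation assumes each class's samples fit in memory (an aggregate-style pass would scale better); and collectAsMap ships all model parameters to the driver, which is cheap here since the model is just one prior plus one density function per (class, feature) pair.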
Test files. First the training set, train.dat:
-0.017612 14.053064 0
-1.395634 4.662541 1
-0.752157 6.538620 0
-1.322371 7.152853 0
0.423363 11.054677 0
0.406704 7.067335 1
0.667394 12.741452 0
-2.460150 6.866805 1
0.569411 9.548755 0
-0.026632 10.427743 0
0.850433 6.920334 1
1.347183 13.175500 0
1.176813 3.167020 1
-1.781871 9.097953 0
-0.566606 5.749003 1
0.931635 1.589505 1
-0.024205 6.151823 1
-0.036453 2.690988 1
-0.196949 0.444165 1
1.014459 5.754399 1
1.985298 3.230619 1
-1.693453 -0.557540 1
-0.576525 11.778922 0
-0.346811 -1.678730 1
-2.124484 2.672471 1
1.217916 9.597015 0
-0.733928 9.098687 0
-3.642001 -1.618087 1
0.315985 3.523953 1
1.416614 9.619232 0
-0.386323 3.989286 1
0.556921 8.294984 1
1.224863 11.587360 0
-1.347803 -2.406051 1
-0.445678 3.297303 1
1.042222 6.105155 1
-0.618787 10.320986 0
1.152083 0.548467 1
0.828534 2.676045 1
-1.237728 10.549033 0
-0.683565 -2.166125 1
0.229456 5.921938 1
-0.959885 11.555336 0
0.492911 10.993324 0
0.184992 8.721488 0
-0.355715 10.325976 0
-0.397822 8.058397 0
0.824839 13.730343 0
1.507278 5.027866 1
0.099671 6.835839 1
-0.344008 10.717485 0
1.785928 7.718645 1
-0.918801 11.560217 0
-0.364009 4.747300 1
-0.841722 4.119083 1
0.490426 1.960539 1
-0.007194 9.075792 0
0.356107 12.447863 0
0.342578 12.281162 0
-0.810823 -1.466018 1
2.530777 6.476801 1
1.296683 11.607559 0
0.475487 12.040035 0
-0.783277 11.009725 0
0.074798 11.023650 0
-1.337472 0.468339 1
-0.102781 13.763651 0
-0.147324 2.874846 1
0.518389 9.887035 0
1.015399 7.571882 0
-1.658086 -0.027255 1
1.319944 2.171228 1
2.056216 5.019981 1
-0.851633 4.375691 1
-1.510047 6.061992 0
-1.076637 -3.181888 1
1.821096 10.283990 0
3.010150 8.401766 1
-1.099458 1.688274 1
-0.834872 -1.733869 1
-0.846637 3.849075 1
And the test set, test.dat:
1.400102 12.628781 0
1.752842 5.468166 1
0.078557 0.059736 1
0.089392 -0.715300 1
1.825662 12.693808 0
0.197445 9.744638 0
0.126117 0.922311 1
-0.679797 1.220530 1
0.677983 2.556666 1
0.761349 10.693862 0
-2.168791 0.143632 1
1.388610 9.341997 0
0.275221 9.543647 0
0.470575 9.332488 0
-1.889567 9.542662 0
-1.527893 12.150579 0
-1.185247 11.309318 0
The test driver:
import breeze.linalg.DenseVector
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import scala.math.pow
import xyz.qspring.spark.ml.base.LabeledPoint

object Main {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("naive_bayes")
    val sc = new SparkContext(conf)
    val data = sc.textFile("data/train.dat")

    Logger.getRootLogger.setLevel(Level.WARN) // quiet down Spark's console logging

    // last column is the label, the rest are features
    val trainData = data.map(line => {
      val items = line.split("\\s+")
      LabeledPoint(items(items.length - 1).toDouble, DenseVector(items.slice(0, items.length - 1).map(_.toDouble)))
    })

    val model = GuassianNaiveBayes.fit(trainData)

    // collect to the driver so the println output shows up locally
    sc.textFile("data/test.dat").collect.foreach(line => {
      val items = line.split("\\s+")
      val res = model.predict(DenseVector(items.slice(0, items.length - 1).map(_.toDouble)))
      println("true is " + items(items.length - 1) + ", predict is " + res._2 + ", score = " + pow(2, res._1))
    })
  }
}
Results
true is 0, predict is 0.0, score = 0.007287035226911837
true is 1, predict is 1.0, score = 0.006537938765007012
true is 1, predict is 1.0, score = 0.012801368971056088
true is 1, predict is 1.0, score = 0.00970655657450153
true is 0, predict is 0.0, score = 0.00305462018270487
true is 0, predict is 0.0, score = 0.03716655013066987
true is 1, predict is 1.0, score = 0.01613160178250759
true is 1, predict is 1.0, score = 0.01548224987302873
true is 1, predict is 1.0, score = 0.01784234527209572
true is 0, predict is 0.0, score = 0.029683595996118462
true is 1, predict is 1.0, score = 0.0037636068269885714
true is 0, predict is 0.0, score = 0.011051732411404247
true is 0, predict is 0.0, score = 0.034819190499309864
true is 0, predict is 0.0, score = 0.03027279470621322
true is 0, predict is 0.0, score = 0.003400879969005375
true is 0, predict is 0.0, score = 0.0060605923826227105
true is 0, predict is 0.0, score = 0.014488715477020412