1. 背景
1.1 問題概述
有10萬+條短文本,均是用戶反饋的問題(每條文本長度大概在200字左右),需要對這些文本進行主題聚類,看下用戶反饋的問題都集中在哪些方面。
1.2 工作
先采用Spark MLlib自帶的Kmeans聚類算法對文本進行聚類,因為其訓練速度很快。這里我采用TF-IDF作為特征提取方法,Spark ML Kmeans中的距離計算方法為歐式距離。
運行之后發現效果不太好:某個類的數據量達到了70%,也就是大量的文章都被劃分到了同一組,其他類的數量都較少。
如下圖所示,對歐式距離的Kmeans和其他距離方法進行了對比,實驗中表明歐式距離的結果中,本應屬于不同組的大量文章被劃分成了一組,與我這邊的效果一致,因此接下來可以試著采用余弦相似度作為距離算法。
2. 第一步優化:歐式距離改為余弦相似度
2.1 歐式距離和余弦相似度對比
定義兩個n維向量:X(x1,x2,...,xn)和Y(y1,y2,...,yn)
歐式距離計算公式:
余弦相似度計算公式:
歐式距離主要是衡量空間中兩個點的絕對距離,而余弦相似度注重兩個樣本之間在方向上的差異而非距離上的差異,主要是衡量兩個個體之間的相似性,值越大,說明差異越少,與歐式距離相反(距離越小,差異越小)。
從上圖可以看出,歐氏距離衡量的是空間各點的絕對距離;而余弦距離衡量的是空間向量的夾角,更加體現在方向上的差異,而不是位置。如果保持 A 點位置不變,B 點朝原方向遠離坐標軸原點,那么這個時候余弦距離是保持不變的(因為夾角沒有發生變化),而 A 與 B 兩點的距離顯然在發生變化,這就是歐式距離與余弦相似度的不同之處。
2.2 開發基于Spark和余弦相似度的Kmeans聚類
由于Spark ML中的Kmeans不提供對距離函數進行更新和選擇的接口,因此只能按照Kmeans的原理開發,和在GitHub上借鑒已有的代碼修改開發。
這里有個小技巧,由于余弦相似度越大,兩個體之間的差異越少,所以為了保證在計算每個樣本所屬的最近的中心點的時候與歐式距離一致,這里在計算兩個體之間距離的時候采用如下方法計算:
(代碼框架參考基于歐式距離的Scala實現的Kmeans,基于該代碼修改為Spark和余弦相似度距離。https://blog.csdn.net/u014135021/article/details/53668634)
/**求兩個向量的余弦,1-相似度,結果越大 差異越大,越小差異越小 */
def cos_distance(that: Point) = {
val cos = 1- innerProduct(this.px, that.px) / (module(this.px) * module(that.px))
cos
}
/** 求兩個向量的內積*/
def innerProduct(v1: Vector[Double], v2: Vector[Double]) = {
val listBuffer = ListBuffer[Double]()
for (i <- 0 until v1.length; j <- 0 to v2.length; if i == j) {
if (i == j) listBuffer.append(v1(i) * v2(j))
}
listBuffer.sum
}
3. 第二步優化:Scala代碼改成Spark
如下為Spark的Kmeans主題的聚類迭代部分
其中初始化隨機中心點的方法為:
takeSample(withReplacement: Boolean,num: Int,seed: Long = Utils.random.nextLong): Array[T]
其中參數:
- withReplacement:是否是有放回的抽樣
- num:返回的樣本的大小
- seed:隨機數生成器的種子
//kmeans函數運行主體
def run(sc:SparkContext)
{
var k=0 //當前迭代次數
var f=true //是否還需要接著迭代
val st=System.nanoTime()
//設置隨機種子
val seed = 10000l
val random = new java.util.Random()
random.setSeed(seed)
InitCenterRandom(random)//隨機初始化中心點
while(k<MaxIterations && f)
{
val st1 = System.currentTimeMillis()
k+=1
//計算每個點屬于哪個中心點所在的類,并且記錄每個類中點的數量,與該類中所有向量的和
val data_with_center = data.map(x => {
var cid = FastSearch2(x._2).center_id
(x._1,x._2,cid)
})
data_with_center.cache()
//按照類別ID分組
val result_groupby:RDD[(Int, Iterable[(X,Point,Int)])] = data_with_center.groupBy(_._3)
result_groupby.cache()
result = result_groupby.map(x => {
val center_datas = x._2.map(_._1).toList
(x._1,center_datas)
})
//按照中心點相同groupby
val newPoints = result_groupby.map(x => {
val cid = x._1
val center_datas = x._2
val center_data_size:Int = center_datas.seq.size.toInt
//計算該中心點下所有數據的Point向量和
val totalPoint:Point = center_datas.map(_._2).reduce((x,y) => (x+y))
//新的中心點為該類別下樣本向量和的平均值
val newPoint:Point = totalPoint./(center_data_size)
(cid, newPoint)
})
newPoints.cache()
val newPoints2 = newPoints.collect()
result_groupby.unpersist()
data_with_center.unpersist()
var i = -1
//如果當前中心點中,存在比上一次迭代的中心點的距離大于閾值的情況,還需要接著迭代。
f = CenterPoint.map(x => {
i = i+1
(i, x)
}).zip(newPoints2).map(f=>f._1._2.cos_distance(f._2._2)).exists {_>threshold}
//更新中心點
if(f)
{
newPoints2.map(x=> {
if(x._2 != null) {
CenterPoint(x._1) = x._2
}
})
}
val et1 = System.currentTimeMillis()
println("第"+k+"次聚類,sse=" + getSSE(sc,data) + ",time=" + (et1-st1)/1000+"s")
newPoints.unpersist()
System.gc()
}
val ed=System.nanoTime()
//data.unpersist()
println("Kmeans聚類時間為:"+(ed-st))
}
/*根據隨機種子對象,初始化中心*/
def InitCenterRandom(random:java.util.Random) {
val st=System.nanoTime()
val random_seed = random.nextLong()
CenterPoint = data.takeSample(false, numClusters, random_seed).map(_._2)
val ed=System.nanoTime()
println("隨機中心點生成時間為:"+(ed-st))
}
4. 第三步優化:大數據量的Mini Batch Kmeans
上述的Kmeans算法在大數據量的情況下,運算依然很慢,因此采用KMeans的變種:Mini Batch Kmeans算法,當數據量超過1萬的時候就可以使用該方法。該方法不僅處理速度快,準確度也很高。其實現原理是每次迭代的時候,選取部分樣本來更新當前迭代的中心點。這種分批處理的思路同樣也被應用在了梯度下降等算法中。
如下圖為摘自鏈接:https://blog.csdn.net/cht5600/article/details/76014573
將Kmeans算法與Mini Batch Kmeans算法的聚類結果對比,第三幅圖代表兩種方式分類差異的樣本:
從圖中可以知道,針對同樣數量的文本分別采用Kmeans和Mini Batch Kmeans訓練,其時間差別較大,且inertia相差較少。
inertia:樣本離最近聚類中心的總和,其是K均值模型對象的屬性,表示樣本距離最近的聚類中心的總和,它是作為在沒有真實分類標簽下的非監督式評估指標,該值越小越好,值越小證明樣本在類間的分布越集中,即類內的距離越小。
因此,該方法在盡量保持準確度的情況下,大大減少了聚類時間。在本次實驗中采用10萬樣本,選取1000維特征,每次選取1000個樣本迭代,每次迭代僅需要1分鐘。
如下為Mini Batch Kmeans的迭代部分,大部分邏輯與上面的Kmeans一致,只有在每次迭代的選取的樣本不同:
//kmeans函數運行主體
def runBatch(sc:SparkContext)
{
var k=0
var f=true
val st=System.nanoTime()
//設置隨機種子
val seed = 10000l
val random = new java.util.Random()
random.setSeed(seed)
InitCenterRandom(random)//隨機初始化中心點
while(k<MaxIterations && f)
{
val st1 = System.currentTimeMillis()
k+=1
//堆積選取MiniBatchSize個樣本 轉成RDD
val data_batch = sc.parallelize(data.takeSample(false,ConfigUtil.MiniBatchSize,random.nextLong()))
data_batch.cache()
val data_with_center = data_batch.map(x => {
var cid = FastSearch2(x._2).center_id
(x._1,x._2,cid)
})
data_with_center.cache()
val result_groupby:RDD[(Int, Iterable[(GovComment,Point,Int)])] = data_with_center.groupBy(_._3)
result_groupby.cache()
result = result_groupby.map(x => {
val center_datas = x._2.map(_._1).toList
(x._1,center_datas)
})
//按照中心點相同groupby
val newPoints = result_groupby.map(x => {
val cid = x._1
val center_datas = x._2
val center_data_size:Int = center_datas.seq.size.toInt
//計算該中心點下所有數據的Point向量和
val totalPoint:Point = center_datas.map(_._2).reduce((x,y) => (x+y))
val newPoint:Point = totalPoint./(center_data_size)
(cid, newPoint)
})
newPoints.cache()
val newPoints2 = newPoints.collect()
result_groupby.unpersist()
data_with_center.unpersist()
var i = -1
f = CenterPoint.map(x => {
i = i+1
(i, x)
}).zip(newPoints2).map(f=>f._1._2.cos_distance(f._2._2)).exists {_>threshold}
if(f)
{
newPoints2.map(x=> {
if(x._2 != null) {
CenterPoint(x._1) = x._2
}
})
}
println("第"+k+"次聚類,sse=" + getSSE(sc,data_batch) + ",time=" +(System.currentTimeMillis()-st1)/1000+"s")
newPoints.unpersist()
System.gc()
data_batch.unpersist()
}
val ed=System.nanoTime()
//data.unpersist()
println("Kmeans聚類時間為:"+(ed-st))
}
5. 總結
- 針對短文本聚類,可以在條件允許的情況下提高特征維度;
- Spark Accumulator累加器的使用注意;
- 采用Mini Batch Kmeans可以盡量維持聚類準確度;
- 文本上面相似,余弦相似度效果相對歐式距離好些。
6. 參考
- https://blog.csdn.net/huangfei711/article/details/78469614
- http://xueshu.baidu.com/usercenter/paper/show?paperid=a3195f1409270d32f304145ce00e967e&site=xueshu_se
- https://blog.csdn.net/linvo/article/details/9333019
- https://blog.csdn.net/u014135021/article/details/53668634
- https://blog.csdn.net/cht5600/article/details/76014573
- https://blog.csdn.net/weixin_37536446/article/details/81326932