GMM算法步骤(EM迭代):
1. 初始化参数,包括高斯分布个数k,以及各分布的权值、均值和协方差;
2. 计算每个分布产生每个节点的概率(概率密度);
3. 结合各分布的权值,计算每个节点属于每个分布的后验概率(响应度);
4. 更新每个分布的权值、均值和协方差,重复步骤2~4直到对数似然收敛。
基本参数类:
/**
 * Parameter container for a Gaussian mixture model with k components over d-dimensional data.
 *
 * <p>Setters store the given list references as-is and getters return those same references;
 * no defensive copies are made.
 */
public class Parameter {

    /** Component means: one d-dimensional centre for each of the k components. */
    private ArrayList<ArrayList<Double>> pMiu;

    /** Mixing weights, one per component (k values). */
    private ArrayList<Double> pPi;

    /** Covariance matrices of the k components, laid out as d x d x k. */
    private ArrayList<ArrayList<ArrayList<Double>>> pSigma;

    /** @return the component means (may be {@code null} if never set) */
    public ArrayList<ArrayList<Double>> getpMiu() {
        return this.pMiu;
    }

    /** @return the mixing weights (may be {@code null} if never set) */
    public ArrayList<Double> getpPi() {
        return this.pPi;
    }

    /** @return the covariance matrices, d x d x k (may be {@code null} if never set) */
    public ArrayList<ArrayList<ArrayList<Double>>> getpSigma() {
        return this.pSigma;
    }

    /** @param pMiu component means to store (reference is kept, not copied) */
    public void setpMiu(ArrayList<ArrayList<Double>> pMiu) {
        this.pMiu = pMiu;
    }

    /** @param pPi mixing weights to store (reference is kept, not copied) */
    public void setpPi(ArrayList<Double> pPi) {
        this.pPi = pPi;
    }

    /** @param pSigma covariance matrices to store, d x d x k (reference is kept, not copied) */
    public void setpSigma(ArrayList<ArrayList<ArrayList<Double>>> pSigma) {
        this.pSigma = pSigma;
    }
}
核心代码如下:
public class GMMAlgorithm {
/**
 * Clusters {@code dataSet} by fitting a Gaussian mixture model with expectation-maximization
 * and returns the class (0 to k-1) of every sample.
 *
 * <p>Each iteration runs an E-step and, unless the log-likelihood has stopped improving, an
 * M-step:
 * <ul>
 *   <li>E-step: component densities come from {@code computeProbablity}. They are weighted by
 *       the mixing weights and normalised per sample into responsibilities.</li>
 *   <li>M-step: re-estimates the means, mixing weights and covariance matrices from those
 *       responsibilities. The estimates are written back into the {@link Parameter} obtained
 *       from {@code iniParameters}.</li>
 * </ul>
 *
 * <p>Iteration stops when any of these holds:
 * <ul>
 *   <li>the log-likelihood gain is below {@code 1e-4};</li>
 *   <li>the gain is NaN;</li>
 *   <li>1000 iterations have run.</li>
 * </ul>
 *
 * @param dataSet   samples; {@code dataNum} rows of {@code dataDimen} values
 * @param pMiu      initial component means; {@code k} rows of {@code dataDimen} values.
 *                  The list is only read, never modified.
 * @param dataNum   number of samples
 * @param k         number of Gaussian components
 * @param dataDimen dimensionality of a sample
 * @return for each sample, the index of the component with the largest responsibility in the
 *         last E-step (ties resolve to the lowest index); an empty array when {@code dataNum} is 0
 * @throws IllegalArgumentException if {@code k < 1}
 */
public int[] GMMCluster(ArrayList<ArrayList<Double>> dataSet, ArrayList<ArrayList<Double>> pMiu, int dataNum, int k, int dataDimen) {
    if (k < 1) {
        throw new IllegalArgumentException("k must be >= 1, got " + k);
    }
    if (dataNum == 0) {
        return new int[0];
    }
    Parameter parameter = iniParameters(dataSet, dataNum, k, dataDimen);
    final double threshold = 0.0001; // minimum log-likelihood gain needed to keep iterating
    final int maxIterations = 1000;  // hard stop in case the likelihood never settles

    // Samples as a primitive matrix; loop-invariant, so converted once.
    double[][] data = new double[dataNum][dataDimen];
    for (int i = 0; i < dataNum; i++) {
        ArrayList<Double> row = dataSet.get(i);
        for (int l = 0; l < dataDimen; l++) {
            data[i][l] = row.get(l);
        }
    }

    // Means fed to the next E-step. This starts from the caller's initial means, as in the
    // first pass of the original loop. Each M-step then replaces it with the new estimate.
    ArrayList<ArrayList<Double>> currentMiu = pMiu;
    // -infinity so the first iteration always counts as an improvement.
    double lPre = Double.NEGATIVE_INFINITY;
    double[][] pGama = new double[dataNum][k];

    for (int iter = 0; iter < maxIterations; iter++) {
        // ---- E-step ----
        // NOTE(review): computeProbablity only receives the means. Whether it also reads the
        // covariances stored in `parameter` cannot be seen from here — confirm.
        ArrayList<ArrayList<Double>> px = computeProbablity(dataSet, currentMiu, dataNum, k, dataDimen);
        ArrayList<Double> pPi = parameter.getpPi();
        double logLikelihood = 0;
        for (int i = 0; i < dataNum; i++) {
            ArrayList<Double> pxRow = px.get(i);
            double rowSum = 0;
            for (int j = 0; j < k; j++) {
                pGama[i][j] = pxRow.get(j) * pPi.get(j);
                rowSum += pGama[i][j];
            }
            // rowSum = sum_j pi_j * p(x_i | j), the mixture density of sample i under the
            // current model; summing its log gives the log-likelihood.
            logLikelihood += Math.log(rowSum);
            for (int j = 0; j < k; j++) {
                // A sample with zero density under every component would give 0/0 = NaN;
                // spread its responsibility uniformly instead.
                pGama[i][j] = rowSum > 0 ? pGama[i][j] / rowSum : 1.0 / k;
            }
        }

        // ---- convergence check ----
        // A NaN gain (e.g. -inf minus -inf) also stops the loop rather than spinning on.
        double gain = logLikelihood - lPre;
        if (Double.isNaN(gain) || gain < threshold) {
            break;
        }
        lPre = logLikelihood;

        // ---- M-step ----
        // Nk: effective number of samples owned by each component (all Nk sum to dataNum).
        double[] nk = new double[k];
        for (int i = 0; i < dataNum; i++) {
            for (int j = 0; j < k; j++) {
                nk[j] += pGama[i][j];
            }
        }

        double[][] miu = new double[k][dataDimen];
        // Covariances use the same d x d x k layout as Parameter.pSigma.
        double[][][] sigma = new double[dataDimen][dataDimen][k];
        double[] shift = new double[dataDimen];
        for (int j = 0; j < k; j++) {
            if (!(nk[j] > 0)) {
                // Empty component: keep its previous mean instead of dividing 0 by 0.
                // Its covariance stays zero; its new weight (0) removes it from later E-steps.
                ArrayList<Double> oldMean = currentMiu.get(j);
                for (int l = 0; l < dataDimen; l++) {
                    miu[j][l] = oldMean.get(l);
                }
                continue;
            }
            // mu_j = sum_i gamma_ij * x_i / Nk_j
            for (int i = 0; i < dataNum; i++) {
                double w = pGama[i][j];
                for (int l = 0; l < dataDimen; l++) {
                    miu[j][l] += w * data[i][l];
                }
            }
            for (int l = 0; l < dataDimen; l++) {
                miu[j][l] /= nk[j];
            }
            // Sigma_j = sum_i gamma_ij * (x_i - mu_j)(x_i - mu_j)^T / Nk_j.
            // Accumulated directly rather than through an n x n diag(gamma) matrix.
            for (int i = 0; i < dataNum; i++) {
                double w = pGama[i][j];
                for (int l = 0; l < dataDimen; l++) {
                    shift[l] = data[i][l] - miu[j][l];
                }
                for (int a = 0; a < dataDimen; a++) {
                    for (int b = 0; b < dataDimen; b++) {
                        sigma[a][b][j] += w * shift[a] * shift[b];
                    }
                }
            }
            for (int a = 0; a < dataDimen; a++) {
                for (int b = 0; b < dataDimen; b++) {
                    sigma[a][b][j] /= nk[j];
                }
            }
        }

        // pi_j = Nk_j / N
        ArrayList<Double> newPi = new ArrayList<>(k);
        for (int j = 0; j < k; j++) {
            newPi.add(nk[j] / dataNum);
        }

        ArrayList<ArrayList<Double>> miuList = meansToList(miu);
        parameter.setpMiu(miuList);
        parameter.setpPi(newPi);
        parameter.setpSigma(covariancesToList(sigma));
        currentMiu = miuList;
    }

    // Each sample goes to the component with the highest posterior responsibility.
    int[] labels = new int[dataNum];
    for (int i = 0; i < dataNum; i++) {
        int best = 0;
        for (int j = 1; j < k; j++) {
            if (pGama[i][j] > pGama[i][best]) {
                best = j;
            }
        }
        labels[i] = best;
    }
    return labels;
}

/** Converts a k x d mean matrix into the nested-list form stored in {@link Parameter}. */
private static ArrayList<ArrayList<Double>> meansToList(double[][] miu) {
    ArrayList<ArrayList<Double>> result = new ArrayList<>(miu.length);
    for (double[] row : miu) {
        ArrayList<Double> list = new ArrayList<>(row.length);
        for (double v : row) {
            list.add(v);
        }
        result.add(list);
    }
    return result;
}

/** Converts a d x d x k covariance array into the nested-list form stored in {@link Parameter}. */
private static ArrayList<ArrayList<ArrayList<Double>>> covariancesToList(double[][][] sigma) {
    ArrayList<ArrayList<ArrayList<Double>>> result = new ArrayList<>(sigma.length);
    for (double[][] plane : sigma) {
        ArrayList<ArrayList<Double>> planeList = new ArrayList<>(plane.length);
        for (double[] row : plane) {
            ArrayList<Double> list = new ArrayList<>(row.length);
            for (double v : row) {
                list.add(v);
            }
            planeList.add(list);
        }
        result.add(planeList);
    }
    return result;
}
/**
*
* @Title: computeProbablity
* @Description: 计算每个节点(共n个)属于每个分布(k个)的概率
* @return ArrayList<ArrayList<Double>>
* @throws
*/
public ArrayList<ArrayList<Double>> computeProbablity(ArrayList<ArrayList<Double>>dataSet, ArrayList<ArrayList<Double>> pMiu, int dataNum, int k, in