GMM算法步骤:
1. 初始化参数,包括高斯(Gauss)分布的个数k,以及每个分布的权值、均值和协方差;
2. 计算每个节点在每个高斯分布下的概率密度;
3. 结合各分布的权值,计算每个节点由每个分布产生的后验概率(即责任度γ,对每个节点在k个分布上归一化);
4. 用责任度更新每个分布的权值、均值和协方差;重复步骤2~4,直到对数似然的增量小于设定阈值为止。
基本参数类:
/**
 * Parameter holder for a Gaussian mixture model with k components over d-dimensional data.
 *
 * <p>This is a plain mutable container. Each setter stores the list reference it is given, and
 * each getter returns that same reference; no defensive copies are made. Every field is
 * {@code null} until its setter is called.
 */
public class Parameter {

    /** Means: the centre points of the k distributions, each centre having d coordinates. */
    private ArrayList<ArrayList<Double>> means;

    /** Weights of the k mixture components. */
    private ArrayList<Double> weights;

    /** Covariance matrices of the k components, laid out as d * d * k. */
    private ArrayList<ArrayList<ArrayList<Double>>> covariances;

    /** Returns the stored means list (same reference that was set; may be {@code null}). */
    public ArrayList<ArrayList<Double>> getpMiu() {
        return means;
    }

    /** Returns the stored weights list (same reference that was set; may be {@code null}). */
    public ArrayList<Double> getpPi() {
        return weights;
    }

    /** Returns the stored covariance lists (same reference that was set; may be {@code null}). */
    public ArrayList<ArrayList<ArrayList<Double>>> getpSigma() {
        return covariances;
    }

    /** Stores the given means list by reference. */
    public void setpMiu(ArrayList<ArrayList<Double>> pMiu) {
        means = pMiu;
    }

    /** Stores the given weights list by reference. */
    public void setpPi(ArrayList<Double> pPi) {
        weights = pPi;
    }

    /** Stores the given covariance lists by reference. */
    public void setpSigma(ArrayList<ArrayList<ArrayList<Double>>> pSigma) {
        covariances = pSigma;
    }
}
核心代码如下:
public class GMMAlgorithm { /** * * @Title: GMMCluster * @Description: GMM聚类算法的实现类,返回每条数据的类别(0~k-1) * @return int[] * @throws */ public int[] GMMCluster(ArrayList<ArrayList<Double>>dataSet, ArrayList<ArrayList<Double>> pMiu, int dataNum, int k, int dataDimen) { Parameter parameter = iniParameters(dataSet, dataNum, k, dataDimen); double Lpre = -1000000; // 上一次聚类的误差 double threshold = 0.0001; while(true) { ArrayList<ArrayList<Double>> px = computeProbablity(dataSet, pMiu, dataNum, k, dataDimen); double[][] pGama = new double[dataNum][k]; for(int i = 0; i < dataNum; i++) { for(int j = 0; j < k; j++) { pGama[i][j] = px.get(i).get(j) * parameter.getpPi().get(j); } } double[] sumPGama = GMMUtil.matrixSum(pGama, 2); for(int i = 0; i < dataNum; i++) { for(int j = 0; j < k; j++) { pGama[i][j] = pGama[i][j] / sumPGama[i]; } } double[] NK = GMMUtil.matrixSum(pGama, 1); // 第k个高斯生成每个样本的概率的和,所有Nk的总和为N // 更新pMiu double[] NKReciprocal = new double[NK.length]; for(int i = 0; i < NK.length; i++) { NKReciprocal[i] = 1 / NK[i]; } double[][] pMiuTmp = GMMUtil.matrixMultiply(GMMUtil.matrixMultiply(GMMUtil.diag(NKReciprocal), GMMUtil.matrixReverse(pGama)), GMMUtil.toArray(dataSet)); // 更新pPie double[][] pPie = new double[k][1]; for(int i = 0; i < NK.length; i++) { pPie[i][1] = NK[i] / dataNum; } // 更新k个pSigma double[][][] pSigmaTmp = new double[dataDimen][dataDimen][k]; for(int i = 0; i < k; i++) { double[][] xShift = new double[dataNum][dataDimen]; for(int j = 0; j < dataNum; j++) { for(int l = 0; l < dataDimen; l++) { xShift[j][l] = pMiuTmp[i][l]; } } double[] pGamaK = new double[dataNum]; // 第k条pGama值 for(int j = 0; j < dataNum; j++) { pGamaK[j] = pGama[j][i]; } double[][] diagPGamaK = GMMUtil.diag(pGamaK); double[][] pSigmaK = GMMUtil.matrixMultiply(GMMUtil.matrixReverse(xShift), (GMMUtil.matrixMultiply(diagPGamaK, xShift))); for(int j = 0; j < dataDimen; j++) { for(int l = 0; l < dataDimen; l++) { pSigmaTmp[j][l][k] = pSigmaK[j][l] / NK[i]; } } } // 判断是否迭代结束 double[][] a = 
GMMUtil.matrixMultiply(GMMUtil.toArray(px), pPie); for(int i = 0; i < dataNum; i++) { a[i][0] = Math.log(a[i][0]); } double L = GMMUtil.matrixSum(a, 1)[0]; if(L - Lpre < threshold) { break; } Lpre = L; } return null; } /** * * @Title: computeProbablity * @Description: 计算每个节点(共n个)属于每个分布(k个)的概率 * @return ArrayList<ArrayList<Double>> * @throws */ public ArrayList<ArrayList<Double>> computeProbablity(ArrayList<ArrayList<Double>>dataSet, ArrayList<ArrayList<Double>> pMiu, int dataNum, int k, in