日期:2014-05-16  浏览次数:20495 次

Java实现GMM算法

GMM算法步骤:

1. 初始化参数,包括高斯分布的个数、各分布的权值、均值和协方差;

2. 计算每个节点属于每个分布的概率;

3. 计算每个分布产生每个节点的概率;

4. 更新每个分布的权值、均值和协方差;重复步骤2~4,直到对数似然的增量小于给定阈值为止。

基本参数类:

/**
 * Mutable holder for the parameters of a Gaussian mixture model with k components in d dimensions.
 */
public class Parameter {

	/** Component means: k centre points, each a list of d coordinates. */
	private ArrayList<ArrayList<Double>> means;

	/** Mixture weights of the k components. */
	private ArrayList<Double> weights;

	/** Covariance matrices of the k components, laid out d * d * k. */
	private ArrayList<ArrayList<ArrayList<Double>>> covariances;

	/** @return the stored component means, or {@code null} if none have been set */
	public ArrayList<ArrayList<Double>> getpMiu() {
		return this.means;
	}

	/** @return the stored mixture weights, or {@code null} if none have been set */
	public ArrayList<Double> getpPi() {
		return this.weights;
	}

	/** @return the stored covariance matrices, or {@code null} if none have been set */
	public ArrayList<ArrayList<ArrayList<Double>>> getpSigma() {
		return this.covariances;
	}

	/** @param means component means to store; the list is kept by reference, not copied */
	public void setpMiu(ArrayList<ArrayList<Double>> means) {
		this.means = means;
	}

	/** @param weights mixture weights to store; the list is kept by reference, not copied */
	public void setpPi(ArrayList<Double> weights) {
		this.weights = weights;
	}

	/** @param covariances covariance matrices to store; kept by reference, not copied */
	public void setpSigma(ArrayList<ArrayList<ArrayList<Double>>> covariances) {
		this.covariances = covariances;
	}
}


核心代码如下:

public class GMMAlgorithm {
	
	/**
	 * Clusters the data with a Gaussian mixture model fitted by expectation-maximisation (EM).
	 *
	 * <p>Each iteration asks {@code computeProbablity} for the probability of every record under
	 * every component, turns it into responsibilities (E-step) and re-estimates the component
	 * weights, means and covariances (M-step). It stops once the log-likelihood improves by less
	 * than {@code threshold}, or after {@code maxIterations} rounds.
	 *
	 * @param dataSet   the records; record {@code i} is {@code dataSet.get(i)} with
	 *                  {@code dataDimen} coordinates
	 * @param pMiu      initial means, one row of {@code dataDimen} values per component; when
	 *                  {@code null}, the means of the {@code Parameter} returned by
	 *                  {@code iniParameters} are used instead
	 * @param dataNum   number of records (leading rows of {@code dataSet}) to cluster
	 * @param k         number of mixture components; must be positive
	 * @param dataDimen number of coordinates per record
	 * @return for each of the {@code dataNum} records, the index (0..k-1) of the component with
	 *         the largest responsibility in the final E-step; an empty array when
	 *         {@code dataNum <= 0}
	 * @throws IllegalArgumentException if {@code k <= 0}
	 */
	public int[] GMMCluster(ArrayList<ArrayList<Double>> dataSet, ArrayList<ArrayList<Double>> pMiu, int dataNum, int k, int dataDimen) {
		if (k <= 0) {
			throw new IllegalArgumentException("k must be positive, got " + k);
		}
		if (dataNum <= 0) {
			return new int[0];
		}

		Parameter parameter = iniParameters(dataSet, dataNum, k, dataDimen);
		final double threshold = 0.0001;
		// Safety cap: nothing else bounds the loop if the likelihood keeps creeping upwards.
		final int maxIterations = 1000;

		// Current estimates, updated every iteration and fed back into the next E-step.
		ArrayList<ArrayList<Double>> miu = (pMiu != null) ? pMiu : parameter.getpMiu();
		// NOTE(review): assumes parameter.getpPi() holds k weights - confirm against iniParameters.
		double[] pi = new double[k];
		for (int j = 0; j < k; j++) {
			pi[j] = parameter.getpPi().get(j);
		}
		// Covariances in a d x d x k layout ([row][col][component]).
		// NOTE(review): computeProbablity takes no covariance argument, so these estimates are not
		// fed back into the E-step; wire them through once its signature allows it.
		double[][][] sigma = null;
		double[][] pGama = new double[dataNum][k];
		// Start at -Infinity so the first finite log-likelihood always counts as an improvement.
		double lPre = Double.NEGATIVE_INFINITY;

		for (int iter = 0; iter < maxIterations; iter++) {
			// px.get(i).get(j): probability of record i under component j (as used by the E-step below).
			ArrayList<ArrayList<Double>> px = computeProbablity(dataSet, miu, dataNum, k, dataDimen);

			// E-step: pGama[i][j] = pi_j * p(x_i | j) / sum_m pi_m * p(x_i | m).
			for (int i = 0; i < dataNum; i++) {
				ArrayList<Double> pxRow = px.get(i);
				double rowSum = 0;
				for (int j = 0; j < k; j++) {
					pGama[i][j] = pxRow.get(j) * pi[j];
					rowSum += pGama[i][j];
				}
				// A record with zero total probability keeps all-zero responsibilities instead of
				// producing 0/0 = NaN that would poison every estimate below.
				if (rowSum > 0) {
					for (int j = 0; j < k; j++) {
						pGama[i][j] /= rowSum;
					}
				}
			}

			// Effective number of records assigned to each component (at most dataNum in total).
			double[] nk = new double[k];
			for (int i = 0; i < dataNum; i++) {
				for (int j = 0; j < k; j++) {
					nk[j] += pGama[i][j];
				}
			}

			// M-step, means: responsibility-weighted average of the records, computed directly
			// rather than through an N x N diagonal matrix.
			double[][] newMiu = new double[k][dataDimen];
			for (int i = 0; i < dataNum; i++) {
				ArrayList<Double> x = dataSet.get(i);
				for (int j = 0; j < k; j++) {
					for (int l = 0; l < dataDimen; l++) {
						newMiu[j][l] += pGama[i][j] * x.get(l);
					}
				}
			}
			for (int j = 0; j < k; j++) {
				ArrayList<Double> oldMean = miu.get(j);
				for (int l = 0; l < dataDimen; l++) {
					// An empty component keeps its previous mean instead of becoming 0/0 = NaN.
					newMiu[j][l] = (nk[j] > 0) ? newMiu[j][l] / nk[j] : oldMean.get(l);
				}
				// M-step, weights.
				pi[j] = nk[j] / dataNum;
			}

			// M-step, covariances: weighted outer products of each record's deviation from the new mean.
			double[][][] newSigma = new double[dataDimen][dataDimen][k];
			for (int j = 0; j < k; j++) {
				for (int i = 0; i < dataNum; i++) {
					ArrayList<Double> x = dataSet.get(i);
					double weight = pGama[i][j];
					for (int a = 0; a < dataDimen; a++) {
						double da = x.get(a) - newMiu[j][a];
						for (int b = 0; b < dataDimen; b++) {
							newSigma[a][b][j] += weight * da * (x.get(b) - newMiu[j][b]);
						}
					}
				}
				if (nk[j] > 0) {
					for (int a = 0; a < dataDimen; a++) {
						for (int b = 0; b < dataDimen; b++) {
							newSigma[a][b][j] /= nk[j];
						}
					}
				}
			}
			sigma = newSigma;
			miu = toNestedList(newMiu);

			// Convergence: log-likelihood sum_i log(sum_j p(x_i | j) * pi_j) with the updated weights.
			double logLikelihood = 0;
			for (int i = 0; i < dataNum; i++) {
				ArrayList<Double> pxRow = px.get(i);
				double mixture = 0;
				for (int j = 0; j < k; j++) {
					mixture += pxRow.get(j) * pi[j];
				}
				logLikelihood += Math.log(mixture);
			}
			// Written as !(gain >= threshold) so a NaN gain (e.g. -Infinity minus -Infinity) also stops.
			double gain = logLikelihood - lPre;
			if (!(gain >= threshold)) {
				break;
			}
			lPre = logLikelihood;
		}

		// Hard assignment: component with the largest responsibility (ties go to the lowest index).
		int[] labels = new int[dataNum];
		for (int i = 0; i < dataNum; i++) {
			int best = 0;
			for (int j = 1; j < k; j++) {
				if (pGama[i][j] > pGama[i][best]) {
					best = j;
				}
			}
			labels[i] = best;
		}
		return labels;
	}

	/** Copies a dense matrix into the nested-list form that {@code computeProbablity} accepts. */
	private static ArrayList<ArrayList<Double>> toNestedList(double[][] matrix) {
		ArrayList<ArrayList<Double>> rows = new ArrayList<ArrayList<Double>>(matrix.length);
		for (double[] row : matrix) {
			ArrayList<Double> list = new ArrayList<Double>(row.length);
			for (double value : row) {
				list.add(value);
			}
			rows.add(list);
		}
		return rows;
	}
	
	/**
	 * 
	* @Title: computeProbablity 
	* @Description: 计算每个节点(共n个)属于每个分布(k个)的概率
	* @return ArrayList<ArrayList<Double>>
	* @throws
	 */
	public ArrayList<ArrayList<Double>> computeProbablity(ArrayList<ArrayList<Double>>dataSet, ArrayList<ArrayList<Double>> pMiu, int dataNum, int k, in