1.什么是Apriori算法?

Apriori算法是一种发现频繁项集的基本算法,通过Apriori算法得出频繁项集,以此来产生强关联规则。

2.Apriori的具体实现

通过扫描数据库,累计每个项的计数,并收集满足最小支持度计数的项,找出频繁1项集的集合,记为 $L_1$;然后通过 $L_1$ 找出频繁2项集的集合 $L_2$,再用 $L_2$ 找出 $L_3$,如此迭代,直到不能再找到频繁 $K$ 项集为止。因为找出每个 $L_K$ 都需要完整扫描一遍数据库,为了提高效率,我们使用一种先验性质(Apriori property):频繁 $K$ 项集的所有非空子集也一定是频繁的。反过来说,只要一个候选项集有任意一个子集是非频繁的,它本身就不可能是频繁的,因此可以不经扫描直接将其剪除。

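由 $L_{K-1}$ 找出 $L_K$ 的具体步骤如下:

1. 连接步:将 $L_{K-1}$ 中前 $K-2$ 项相同的两个项集连接,生成候选 $K$ 项集的集合 $C_K$;
2. 剪枝步:根据先验性质,删除 $C_K$ 中包含非频繁 $(K-1)$ 项子集的候选项集;
3. 扫描数据库,统计剩余候选项集的支持度计数,满足最小支持度计数的候选项集即构成 $L_K$。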

测试类:

package edu.bjut.jzl.apriori;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;

/**
 * Entry point: runs the Apriori pipeline over a transaction file, printing every level of
 * candidate/frequent itemsets and the association rules derived from the largest frequent
 * itemsets that were found.
 */
public class Test {

	/** Data file used when no path is given on the command line. */
	private static final String DEFAULT_FILE_PATH = "C:\\Users\\ji\\Desktop\\data.txt";

	/**
	 * Runs candidate generation, pruning and support counting level by level until no larger
	 * frequent itemsets can be formed, then prints association rules for the last non-empty level.
	 *
	 * @param args optional; {@code args[0]} overrides the default data file path
	 * @throws IOException if the data file cannot be read
	 */
	public static void main(String[] args) throws IOException {
		String filePath = args.length > 0 ? args[0] : DEFAULT_FILE_PATH;
		// Infrequent itemsets of the current level: filled by scanData, consumed and cleared by pruning.
		ArrayList<ArrayList<String>> fcs = new ArrayList<>();
		ReadData rs = new ReadData();
		Apriori apri = new Apriori();
		// Whole data set, formatted as [[T100, I1, I2, I5], ...]
		ArrayList<ArrayList<String>> data = rs.DataAll(filePath);
		// Initial candidate 1-itemsets
		ArrayList<ArrayList<String>> cs = rs.CandidateSets(data);
		// Count the candidate 1-itemsets -> frequent 1-itemsets
		ArrayList<ArrayList<String>> ks = apri.scanData(cs, data, fcs, 0);
		ArrayList<ArrayList<String>> lastFrequent = ks;
		// Grow itemsets one item at a time. At least two frequent k-itemsets are needed to join
		// into a (k+1)-candidate, so stop as soon as a level has fewer than two.
		// (Previously this loop was hard-coded to stop after 3-itemsets.)
		for (int i = 1; ks.size() > 1; i++) {
			// Join step
			ArrayList<ArrayList<String>> conns = apri.connection(ks, i);
			// Prune step
			cs = apri.pruning(conns, fcs);
			// Count the candidate (i+1)-itemsets
			ks = apri.scanData(cs, data, fcs, i);
			if (!ks.isEmpty()) {
				lastFrequent = ks;
			}
		}
		// Rules come from the largest frequent itemsets, not from a possibly empty final level.
		apri.AssociationRules(lastFrequent, data);
	}

}

数据读取类:

package edu.bjut.jzl.apriori;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;
import java.util.TreeSet;

import org.omg.CORBA.INTERNAL;
/**
 * Loads transaction data from a text file and derives the initial candidate 1-itemsets.
 *
 * <p>Expected line format: {@code T100 I1,I2,I5} - a transaction id, whitespace, then the
 * comma-separated items of that transaction.
 */
public class ReadData {

	/** Orders itemsets element by element with {@link String#compareTo}; a proper prefix sorts first. */
	private static final Comparator<ArrayList<String>> LEXICOGRAPHIC = new Comparator<ArrayList<String>>() {
		@Override
		public int compare(ArrayList<String> a, ArrayList<String> b) {
			int shared = Math.min(a.size(), b.size());
			for (int i = 0; i < shared; i++) {
				int cmp = a.get(i).compareTo(b.get(i));
				if (cmp != 0) {
					return cmp;
				}
			}
			return Integer.compare(a.size(), b.size());
		}
	};

	public ReadData() {
		super();
	}

	/**
	 * Reads all transactions from {@code filePath}.
	 *
	 * <p>Reading stops at end of file or at the first blank (whitespace-only) line. A line that
	 * contains only a transaction id yields a transaction with no items.
	 *
	 * @param filePath path of the transaction file
	 * @return one list per transaction: the transaction id followed by its items,
	 *     e.g. {@code [[T100, I1, I2, I5], ...]}
	 * @throws IOException if the file cannot be opened or read
	 */
	public ArrayList<ArrayList<String>> DataAll(String filePath) throws IOException {
		ArrayList<ArrayList<String>> list = new ArrayList<>();
		// try-with-resources closes the reader even when reading fails part-way.
		try (BufferedReader br = new BufferedReader(new FileReader(filePath))) {
			String str;
			// readLine() returns null at EOF; the old loop called equals() on it and threw an NPE
			// whenever the file did not end with a blank line.
			while ((str = br.readLine()) != null) {
				str = str.trim();
				if (str.isEmpty()) {
					break;
				}
				String[] parts = str.split("\\s+", 2);
				ArrayList<String> transaction = new ArrayList<>();
				transaction.add(parts[0]);
				if (parts.length > 1) {
					for (String item : parts[1].split(",")) {
						String trimmed = item.trim();
						if (!trimmed.isEmpty()) {
							transaction.add(trimmed);
						}
					}
				}
				list.add(transaction);
			}
		}
		System.out.println("全部事务数据:"+list);
		return list;
	}

	/**
	 * Builds the candidate 1-itemsets: one single-item list per distinct item, in ascending order.
	 *
	 * @param list transactions as returned by {@link #DataAll}; element 0 of each is the id
	 * @return the sorted, de-duplicated candidate 1-itemsets, e.g. {@code [[I1], [I2], ...]}
	 */
	public ArrayList<ArrayList<String>> CandidateSets(ArrayList<ArrayList<String>> list) {
		// A TreeSet both removes duplicates and orders items by String.compareTo.
		TreeSet<String> items = new TreeSet<>();
		for (ArrayList<String> transaction : list) {
			// Index 0 holds the transaction id, not an item.
			for (int i = 1; i < transaction.size(); i++) {
				items.add(transaction.get(i));
			}
		}
		ArrayList<ArrayList<String>> cs = new ArrayList<>(items.size());
		for (String item : items) {
			ArrayList<String> single = new ArrayList<>(1);
			single.add(item);
			cs.add(single);
		}
		System.out.println("候选1项集:"+cs);
		return cs;
	}

	/**
	 * Sorts the items inside each candidate itemset in ascending order.
	 *
	 * @param list candidate itemsets; not modified
	 * @return new lists with the same items, each sorted ascending, in the same outer order
	 */
	public ArrayList<ArrayList<String>> SortCandidateSetsIn(ArrayList<ArrayList<String>> list) {
		ArrayList<ArrayList<String>> css = new ArrayList<>(list.size());
		for (ArrayList<String> candidate : list) {
			ArrayList<String> sorted = new ArrayList<>(candidate);
			Collections.sort(sorted);
			css.add(sorted);
		}
		return css;
	}

	/**
	 * Sorts candidate itemsets lexicographically (element by element; a shorter prefix sorts first).
	 *
	 * @param list candidate itemsets; not modified
	 * @return a new outer list holding the same inner lists in sorted order
	 */
	public ArrayList<ArrayList<String>> SortCandidateSetsOut(ArrayList<ArrayList<String>> list) {
		ArrayList<ArrayList<String>> css = new ArrayList<>(list);
		Collections.sort(css, LEXICOGRAPHIC);
		return css;
	}
}

先验类:

package edu.bjut.jzl.apriori;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

/**
 * Core Apriori steps: support counting, candidate join, subset-based pruning and rule generation.
 *
 * <p>Itemsets are {@code ArrayList<String>}s of item names, assumed to be sorted in ascending
 * item order. After {@link #scanData} an itemset additionally carries its support count as a
 * trailing String element (e.g. {@code [I1, I2, 4]}); the other steps rely on that convention.
 * Transactions are lists whose element 0 is the transaction id, followed by the items.
 */
public class Apriori {
	// Minimum support count: an itemset is frequent when at least this many transactions contain it.
	final int minCount = 2;
	// Minimum confidence threshold for printing an association rule.
	final double minConfidence =0.7; 
	
	public Apriori() {
		super();
	}

	/**
	 * Counts the support of every candidate itemset and separates frequent from infrequent ones.
	 *
	 * <p>Side effects: each candidate list in {@code list} gets its support count appended in place
	 * as a trailing String, and infrequent candidates are appended to {@code fcs}.
	 *
	 * @param list candidate itemsets, without support counts
	 * @param data transactions; element 0 of each is the transaction id
	 * @param fcs  receives the infrequent candidates (with counts); consumed by {@link #pruning}
	 * @param n    zero-based level, used only in the log output (itemset size is {@code n + 1})
	 * @return the frequent candidates (with counts), in input order
	 */
	public ArrayList<ArrayList<String>> scanData(ArrayList<ArrayList<String>> list, ArrayList<ArrayList<String>> data,
			ArrayList<ArrayList<String>> fcs, int n) {
		ArrayList<ArrayList<String>> css = new ArrayList<>();
		for (ArrayList<String> candidate : list) {
			int count = 0;
			// Count before appending the support count, so the count string is never matched.
			for (ArrayList<String> transaction : data) {
				if (occursIn(transaction, candidate)) {
					count++;
				}
			}
			candidate.add(Integer.toString(count));
			if (count >= this.minCount) {
				css.add(candidate);
			} else {
				fcs.add(candidate);
			}
		}
		System.out.println("频繁"+(n+1)+"项集:"+css);
		System.out.println("非频繁"+(n+1)+"项集:"+fcs);
		return css;
	}

	/**
	 * Join step: builds candidate (n+1)-itemsets from frequent n-itemsets.
	 *
	 * <p>Two itemsets are joined when their first {@code n-1} items are equal; the result is the
	 * first itemset's {@code n} items followed by the second itemset's last item. Only the first
	 * {@code n} elements of each input are read, so a trailing support count is ignored.
	 *
	 * @param list frequent n-itemsets (possibly with trailing counts), sorted ascending
	 * @param n    size of the input itemsets (must be at least 1)
	 * @return candidate (n+1)-itemsets without counts
	 */
	public ArrayList<ArrayList<String>> connection(ArrayList<ArrayList<String>> list, int n) {
		ArrayList<ArrayList<String>> css = new ArrayList<>();
		for (int i = 0; i < list.size(); i++) {
			ArrayList<String> first = list.get(i);
			for (int j = i + 1; j < list.size(); j++) {
				ArrayList<String> second = list.get(j);
				if (first.subList(0, n - 1).equals(second.subList(0, n - 1))) {
					ArrayList<String> joined = new ArrayList<>(first.subList(0, n));
					joined.add(second.get(n - 1));
					css.add(joined);
				}
			}
		}
		System.out.println("候选"+(n+1)+"项集"+css);
		return css;
	}

	/**
	 * Prune step: removes every candidate that contains a known infrequent itemset.
	 *
	 * <p>By the Apriori property such a candidate cannot be frequent. {@code list} is modified in
	 * place and returned; {@code fcs} is cleared afterwards so the next level starts empty.
	 *
	 * @param list candidate itemsets without counts
	 * @param fcs  infrequent itemsets from the previous level, each with a trailing support count
	 * @return {@code list}, after pruning
	 */
	public ArrayList<ArrayList<String>> pruning(ArrayList<ArrayList<String>> list, ArrayList<ArrayList<String>> fcs) {
		Iterator<ArrayList<String>> candidates = list.iterator();
		while (candidates.hasNext()) {
			ArrayList<String> candidate = candidates.next();
			for (ArrayList<String> infrequent : fcs) {
				// Compare only the items. The trailing element is the support count from scanData,
				// which a candidate never contains (the old check included it and never matched,
				// and let match counts leak from one candidate to the next).
				int itemCount = Math.max(0, infrequent.size() - 1);
				if (candidate.containsAll(infrequent.subList(0, itemCount))) {
					candidates.remove();
					break;
				}
			}
		}
		fcs.clear();
		System.out.println("剪枝后:"+list);
		return list;
	}

	/**
	 * Prints every association rule {@code A -> (X \ A)} with confidence at least
	 * {@link #minConfidence}, for each frequent itemset X and each proper non-empty subset A.
	 *
	 * <p>Confidence is {@code support(X) / support(A)}. {@code list} is not modified.
	 *
	 * @param list frequent itemsets, each with its support count as the trailing element
	 * @param data transactions; element 0 of each is the transaction id
	 */
	public void AssociationRules(ArrayList<ArrayList<String>> list, ArrayList<ArrayList<String>> data) {
		for (ArrayList<String> counted : list) {
			if (counted.isEmpty()) {
				continue;
			}
			int last = counted.size() - 1;
			double supportCountAnd = Integer.parseInt(counted.get(last));
			ArrayList<String> itemset = new ArrayList<>(counted.subList(0, last));
			int size = itemset.size();
			// Masks 1 .. 2^size - 2 select every subset except the empty set and the itemset itself;
			// bit j set means item j belongs to the antecedent.
			int full = (1 << size) - 1;
			for (int mask = 1; mask < full; mask++) {
				ArrayList<String> antecedent = new ArrayList<>();
				ArrayList<String> consequent = new ArrayList<>();
				for (int j = 0; j < size; j++) {
					if ((mask & (1 << j)) != 0) {
						antecedent.add(itemset.get(j));
					} else {
						consequent.add(itemset.get(j));
					}
				}
				int supportCountAlone = 0;
				for (ArrayList<String> transaction : data) {
					if (occursIn(transaction, antecedent)) {
						supportCountAlone++;
					}
				}
				if (supportCountAlone == 0) {
					// Unsupported antecedent: confidence is undefined, so no rule is emitted.
					continue;
				}
				double confidence = supportCountAnd / supportCountAlone;
				if (confidence >= this.minConfidence) {
					System.out.print(antecedent);
					System.out.print("->");
					System.out.print(consequent);
					System.out.print("  Confidence:");
					System.out.println(confidence);
				}
			}
		}
	}

	/**
	 * Returns whether every element of {@code items} appears in {@code transaction}, ignoring the
	 * transaction id stored at index 0.
	 */
	private static boolean occursIn(ArrayList<String> transaction, ArrayList<String> items) {
		if (transaction.isEmpty()) {
			return items.isEmpty();
		}
		return transaction.subList(1, transaction.size()).containsAll(items);
	}
}



运行结果:


(此处原为程序运行结果截图,图片在转载时已丢失。)

数据集:

T100 I1,I2,I5
T200 I2,I4
T300 I2,I3
T400 I1,I2,I4
T500 I1,I3
T600 I2,I3
T700 I1,I3
T800 I1,I2,I3,I5
T900 I1,I2,I3