1.什么是Apriori算法?
Apriori算法是一种发现频繁项集的基本算法,通过Apriori算法得出频繁项集,以此来产生强关联规则。
2.Apriori的具体实现
通过扫描数据库,累计每个项的出现次数,并收集满足最小支持度计数的项,找出频繁1项集的集合,记为 $L_1$。然后用 $L_1$ 找出频繁2项集的集合 $L_2$,再用 $L_2$ 找出 $L_3$,如此迭代,直到不能再找到频繁 $k$ 项集为止。由于每找出一层 $L_k$ 都需要完整扫描一遍数据库,为了提高效率,算法利用一条先验性质(Apriori property):频繁项集的所有非空子集也一定是频繁的;等价地,若某个项集不是频繁的,则它的任何超集也不可能是频繁的。
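由 $L_{k-1}$ 找出 $L_k$ 的具体步骤如下:
1. 连接步:将 $L_{k-1}$ 中前 $k-2$ 项相同的两个项集连接,生成候选 $k$ 项集的集合 $C_k$;
2. 剪枝步:若 $C_k$ 中某个候选项集包含非频繁的 $(k-1)$ 项子集,则根据先验性质将其删除;
3. 扫描数据库,统计 $C_k$ 中剩余候选项集的支持度计数,满足最小支持度计数的项集构成 $L_k$。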
测试类:
package edu.bjut.jzl.apriori;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
/**
 * Command-line driver: mines frequent itemsets from a transaction file level by
 * level and prints the association rules derived from the largest non-empty level.
 * <p>
 * Usage: {@code java edu.bjut.jzl.apriori.Test [dataFile]}; without an argument the
 * original hard-coded path is used.
 */
public class Test {
    /** Data file used when no path is given on the command line. */
    private static final String DEFAULT_DATA_FILE = "C:\\Users\\ji\\Desktop\\data.txt";

    public static void main(String[] args) throws IOException {
        String filePath = args.length > 0 ? args[0] : DEFAULT_DATA_FILE;
        ReadData rs = new ReadData();
        Apriori apri = new Apriori();
        // Infrequent itemsets of the most recent level: filled by scanData, consumed by pruning.
        ArrayList<ArrayList<String>> fcs = new ArrayList<>();

        // Whole data set, in the form [[T100, I1, I2, I5], ...]
        ArrayList<ArrayList<String>> data = rs.DataAll(filePath);
        // Candidate 1-itemsets, then their support counts
        ArrayList<ArrayList<String>> cs = rs.CandidateSets(data);
        ArrayList<ArrayList<String>> ks = apri.scanData(cs, data, fcs, 0);

        // Keep the deepest non-empty level: rules are generated from it, and the loop
        // below always ends on an empty level.
        ArrayList<ArrayList<String>> lastFrequent = ks;
        // Build frequent (i+1)-itemsets from frequent i-itemsets until none remain,
        // instead of stopping at a fixed size.
        for (int i = 1; !ks.isEmpty(); i++) {
            // Join step
            ArrayList<ArrayList<String>> conns = apri.connection(ks, i);
            // Prune step
            cs = apri.pruning(conns, fcs);
            // Count the candidate (i+1)-itemsets
            ks = apri.scanData(cs, data, fcs, i);
            if (!ks.isEmpty()) {
                lastFrequent = ks;
            }
        }
        // Generate association rules
        apri.AssociationRules(lastFrequent, data);
    }
}
数据读取类:
package edu.bjut.jzl.apriori;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;

import org.omg.CORBA.INTERNAL;
@SuppressWarnings("resource")
/*
 * Reads transaction data from a text file and builds the sorted candidate
 * 1-itemsets. A transaction is an ArrayList<String> whose first element is the
 * transaction ID and whose remaining elements are its items, e.g. [T100, I1, I2, I5].
 */
public class ReadData {

    /**
     * Orders item lists lexicographically: the first differing element decides;
     * when one list is a prefix of the other, the shorter one sorts first.
     */
    private static final Comparator<ArrayList<String>> LEXICOGRAPHIC = new Comparator<ArrayList<String>>() {
        @Override
        public int compare(ArrayList<String> a, ArrayList<String> b) {
            int common = Math.min(a.size(), b.size());
            for (int i = 0; i < common; i++) {
                int c = a.get(i).compareTo(b.get(i));
                if (c != 0) {
                    return c;
                }
            }
            return Integer.compare(a.size(), b.size());
        }
    };

    public ReadData() {
        super();
    }

    /**
     * Reads all transactions from a text file.
     * <p>
     * Each line has the form {@code TID item1,item2,...}: a transaction ID, whitespace,
     * then a comma-separated item list. Reading stops at end of file or at the first
     * blank line. A line holding only a transaction ID yields a transaction with no items;
     * empty items (e.g. from {@code I1,,I2}) are skipped.
     *
     * @param filePath path of the data file
     * @return one list per transaction: the ID followed by its items
     * @throws IOException if the file cannot be opened or read
     */
    public ArrayList<ArrayList<String>> DataAll(String filePath) throws IOException {
        ArrayList<ArrayList<String>> list = new ArrayList<>();
        // try-with-resources closes the reader even when reading fails.
        try (BufferedReader br = new BufferedReader(new FileReader(filePath))) {
            String str;
            // readLine() returns null at EOF; testing only for "" threw an NPE when the
            // file did not end with a blank line.
            while ((str = br.readLine()) != null && !str.trim().isEmpty()) {
                String[] fields = str.trim().split("\\s+", 2);
                ArrayList<String> line = new ArrayList<>();
                line.add(fields[0]);
                if (fields.length > 1) {
                    for (String item : fields[1].split(",")) {
                        String trimmed = item.trim();
                        if (!trimmed.isEmpty()) {
                            line.add(trimmed);
                        }
                    }
                }
                list.add(line);
            }
        }
        System.out.println("全部事务数据:"+list);
        return list;
    }

    /**
     * Builds the candidate 1-itemsets: one single-element list per distinct item,
     * sorted in ascending order.
     *
     * @param list transactions as returned by {@link #DataAll}; element 0 of each is the ID
     * @return the sorted candidate 1-itemsets, e.g. {@code [[I1], [I2], [I3]]}
     */
    public ArrayList<ArrayList<String>> CandidateSets(ArrayList<ArrayList<String>> list) {
        ArrayList<ArrayList<String>> cs = new ArrayList<>();
        for (ArrayList<String> transaction : list) {
            // Start at 1: element 0 is the transaction ID, not an item.
            for (int i = 1; i < transaction.size(); i++) {
                ArrayList<String> single = new ArrayList<>();
                single.add(transaction.get(i));
                if (!cs.contains(single)) {
                    cs.add(single);
                }
            }
        }
        cs = this.SortCandidateSetsOut(this.SortCandidateSetsIn(cs));
        System.out.println("候选1项集:"+cs);
        return cs;
    }

    /**
     * Sorts the items inside each candidate in ascending natural order.
     * The input lists are not modified.
     *
     * @param list candidates to sort
     * @return new lists, one per candidate, with the items sorted
     */
    public ArrayList<ArrayList<String>> SortCandidateSetsIn(ArrayList<ArrayList<String>> list) {
        ArrayList<ArrayList<String>> css = new ArrayList<>();
        for (ArrayList<String> line : list) {
            ArrayList<String> sorted = new ArrayList<>(line);
            Collections.sort(sorted);
            css.add(sorted);
        }
        return css;
    }

    /**
     * Sorts candidates lexicographically (the first differing item decides; a proper
     * prefix sorts first). The input list is not modified; the returned list holds the
     * same candidate objects.
     *
     * @param list candidates to order
     * @return a new list with the candidates in lexicographic order
     */
    public ArrayList<ArrayList<String>> SortCandidateSetsOut(ArrayList<ArrayList<String>> list) {
        ArrayList<ArrayList<String>> css = new ArrayList<>(list);
        Collections.sort(css, LEXICOGRAPHIC);
        return css;
    }
}
Apriori算法类(计数、连接步、剪枝步与关联规则生成):
package edu.bjut.jzl.apriori;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
/**
 * Apriori frequent-itemset miner: candidate support counting, the join
 * ("connection") and prune steps, and association-rule generation.
 * <p>
 * Itemsets are {@code ArrayList<String>}s of items. After {@link #scanData} has
 * counted an itemset, its support count is appended as the last element
 * (e.g. {@code [I1, I2, 4]}); {@link #connection}, {@link #pruning} and
 * {@link #AssociationRules} rely on that trailing count. Transactions are lists
 * whose elements are matched with {@code contains}.
 */
public class Apriori {
    // Minimum support count: itemsets contained in fewer transactions are infrequent.
    final int minCount;
    // Minimum confidence a rule must reach to be printed.
    final double minConfidence;

    /** Creates a miner with the default thresholds: support count 2, confidence 0.7. */
    public Apriori() {
        this(2, 0.7);
    }

    /**
     * Creates a miner with custom thresholds.
     *
     * @param minCount      minimum support count for an itemset to be frequent
     * @param minConfidence minimum confidence for a rule to be reported
     */
    public Apriori(int minCount, double minConfidence) {
        super();
        this.minCount = minCount;
        this.minConfidence = minConfidence;
    }

    /**
     * Counts the support of every candidate and splits the candidates into frequent
     * and infrequent itemsets.
     * <p>
     * Each candidate list is modified in place: its support count is appended as the
     * last element.
     *
     * @param list candidates to count
     * @param data transactions to scan
     * @param fcs  receives the infrequent candidates (with their counts appended)
     * @param n    zero-based level, used only for the printed labels (level n prints as n+1)
     * @return the frequent candidates, in input order, with their counts appended
     */
    public ArrayList<ArrayList<String>> scanData(ArrayList<ArrayList<String>> list, ArrayList<ArrayList<String>> data,
            ArrayList<ArrayList<String>> fcs, int n) {
        ArrayList<ArrayList<String>> css = new ArrayList<>();
        for (ArrayList<String> candidate : list) {
            int count = countSupport(candidate, data);
            // Later steps read the count back from the last position.
            candidate.add(Integer.toString(count));
            if (count >= this.minCount) {
                css.add(candidate);
            } else {
                fcs.add(candidate);
            }
        }
        System.out.println("频繁"+(n+1)+"项集:"+css);
        System.out.println("非频繁"+(n+1)+"项集:"+fcs);
        return css;
    }

    /**
     * Join step: combines every pair of frequent n-itemsets that agree on their first
     * n-1 items into a candidate (n+1)-itemset made of the first itemset's n items
     * followed by the second itemset's n-th item.
     * <p>
     * Candidates come out in ascending item order only if {@code list} is sorted
     * lexicographically; the level-wise driver keeps it that way.
     *
     * @param list frequent n-itemsets, each followed by its support count
     * @param n    size of the itemsets in {@code list}
     * @return the candidate (n+1)-itemsets, without counts
     */
    public ArrayList<ArrayList<String>> connection(ArrayList<ArrayList<String>> list, int n) {
        ArrayList<ArrayList<String>> css = new ArrayList<>();
        int size = list.size();
        for (int i = 0; i < size; i++) {
            ArrayList<String> first = list.get(i);
            for (int j = i + 1; j < size; j++) {
                ArrayList<String> second = list.get(j);
                // For n == 1 the prefix is empty, so every pair is joined.
                if (samePrefix(first, second, n - 1)) {
                    ArrayList<String> candidate = new ArrayList<>(first.subList(0, n));
                    candidate.add(second.get(n - 1));
                    css.add(candidate);
                }
            }
        }
        System.out.println("候选"+(n+1)+"项集"+css);
        return css;
    }

    /**
     * Prune step: removes every candidate that contains all items of some infrequent
     * itemset; by the Apriori property such a candidate cannot be frequent.
     * <p>
     * {@code list} is modified in place and returned; {@code fcs} is cleared afterwards.
     *
     * @param list candidates (without counts)
     * @param fcs  infrequent itemsets as produced by scanData, i.e. each ending with its
     *             support count, which is ignored here
     * @return {@code list} after pruning
     */
    public ArrayList<ArrayList<String>> pruning(ArrayList<ArrayList<String>> list, ArrayList<ArrayList<String>> fcs) {
        for (ArrayList<String> infrequent : fcs) {
            if (infrequent.isEmpty()) {
                continue;
            }
            // Drop the trailing support count: it is not an item and would never match.
            ArrayList<String> items = new ArrayList<>(infrequent.subList(0, infrequent.size() - 1));
            Iterator<ArrayList<String>> it = list.iterator();
            while (it.hasNext()) {
                // Checked per candidate; no match state is carried over between candidates.
                if (it.next().containsAll(items)) {
                    it.remove();
                }
            }
        }
        fcs.clear();
        System.out.println("剪枝后:"+list);
        return list;
    }

    /**
     * Prints every association rule {@code X -> Y} derivable from the given frequent
     * itemsets whose confidence {@code support(X+Y) / support(X)} is at least
     * {@link #minConfidence}. X ranges over all non-empty proper subsets of an itemset
     * (in bitmask order) and Y is the rest of that itemset.
     * <p>
     * Each rule is printed as {@code [X]->[Y] Confidence:c}. The input lists are not
     * modified.
     *
     * @param list frequent itemsets, each ending with its support count as produced by scanData
     * @param data transactions, used to count the support of each antecedent X
     */
    public void AssociationRules(ArrayList<ArrayList<String>> list, ArrayList<ArrayList<String>> data) {
        for (ArrayList<String> frequent : list) {
            if (frequent.isEmpty()) {
                continue;
            }
            int last = frequent.size() - 1;
            ArrayList<String> items = new ArrayList<>(frequent.subList(0, last));
            double supportCountAnd = Integer.parseInt(frequent.get(last));
            int k = items.size();
            // long mask: 1 << k on an int overflows for k >= 31.
            long full = (1L << k) - 1;
            // Skip the empty subset (mask 0) and the whole itemset (mask == full).
            for (long mask = 1; mask < full; mask++) {
                ArrayList<String> antecedent = new ArrayList<>();
                ArrayList<String> consequent = new ArrayList<>();
                for (int j = 0; j < k; j++) {
                    if (((mask >> j) & 1L) == 1L) {
                        antecedent.add(items.get(j));
                    } else {
                        consequent.add(items.get(j));
                    }
                }
                double supportCountAlone = countSupport(antecedent, data);
                if (supportCountAlone == 0) {
                    // Only possible if list was not derived from data; avoid an infinite confidence.
                    continue;
                }
                double confidence = supportCountAnd / supportCountAlone;
                if (confidence >= this.minConfidence) {
                    System.out.print(antecedent);
                    System.out.print("->");
                    System.out.print(consequent);
                    System.out.print(" Confidence:");
                    System.out.println(confidence);
                }
            }
        }
    }

    /** Counts the transactions in {@code data} that contain every element of {@code items}. */
    private static int countSupport(ArrayList<String> items, ArrayList<ArrayList<String>> data) {
        int count = 0;
        for (ArrayList<String> transaction : data) {
            if (transaction.containsAll(items)) {
                count++;
            }
        }
        return count;
    }

    /** Returns true when the first {@code length} elements of {@code a} and {@code b} are equal. */
    private static boolean samePrefix(ArrayList<String> a, ArrayList<String> b, int length) {
        for (int m = 0; m < length; m++) {
            if (!a.get(m).equals(b.get(m))) {
                return false;
            }
        }
        return true;
    }
}
运行结果:
数据集(data.txt,每行格式为“事务ID 项1,项2,...”):
T100 I1,I2,I5
T200 I2,I4
T300 I2,I3
T400 I1,I2,I4
T500 I1,I3
T600 I2,I3
T700 I1,I3
T800 I1,I2,I3,I5
T900 I1,I2,I3