频繁项集挖掘Apriori算法及其Python实现
Apriori算法是通过限制候选产生发现频繁项集。
Apriori算法使用一种称为逐层搜索的迭代方法,其中k项集用于探索(k+1)项集。首先,通过扫描数据库,累计每个项的计数,并收集满足最小支持度的项,找出频繁1项集的集合,记为L1。然后,使用L1找出频繁2项集的集合L2,使用L2找出L3,如此下去,直到不能再找到频繁k项集。
为了提高频繁项集逐层产生的效率,一种称为先验性质(Apriori property)的重要性质用于压缩搜索空间。
先验性质:频繁项集的所有非空子集也一定是频繁的。
其实,算法的描述理解起来是困难的,有例子帮助理解最好不过。恩,盗图一张。
Python实现:
我们的期望:
输入:数据库中数据和最小支持度
输出:频繁项集
例:
输入:数据,[['A','B','C','D'],['B','C','E'],['A','B','C','E'],['B','D','E'],['A','B','C','D']];最小支持度:0.7
输出:[['B'], ['C'], ['B', 'C']]
写一个方法 def apriori(D, minSup)
,参数D
就是输入的数据库数据,minSup
是最小支持度。
def apriori(D, minSup):
'''频繁项集用keys表示,
key表示项集中的某一项,
cutKeys表示经过剪枝步的某k项集。
C表示某k项集的每一项在事务数据库D中的支持计数
'''
#先求出1项集的集合及其支持计数,注意此处C1是字典,key为项集,value是计数,与下不同,下的C是列表只有计数
C1 = {}
for T in D:
for I in T:
if I in C1:
C1[I] += 1
else:
C1[I] = 1
_keys1 = C1.keys()
#此处对keys的存储格式进行处理,为了方便后边由此得出k+1项集的集合
keys1 = []
for i in _keys1:
keys1.append([i])
n = len(D)
cutKeys1 = []
#对keys1(1项集)进行剪枝步
for k in keys1[:]:
if C1[k[0]]*1.0/n >= minSup:
cutKeys1.append(k)
cutKeys1.sort()
总之,对于1项集要进行特殊处理,然后再用迭代的方法求k+1项集。
好,迭代来了:
all_keys = []
while keys != []:
C = getC(D, keys)
cutKeys = getCutKeys(keys, C, minSup, D)
for key in cutKeys:
all_keys.append(key)
keys = aproiri_gen(cutKeys)
return all_keys
注意,all_keys
是全局变量,存储所有通过剪枝步的k项集。
函数getC(D, keys)
是对keys
中的每一个key进行计数,函数getCutKeys(keys, C, minSup, D)
是剪枝步的实现,函数aproiri_gen(cutKeys)
是由k项集获得k+1项集(连接步)。
这样,算法Apriori就实现了,输入输出试一下:
D = [['A','B','C','D'],['B','C','E'],['A','B','C','E'],['B','D','E'],['A','B','C','D']]
F = apriori(D, 0.7)
print '\nfrequent itemset:\n', F
把最小支持度改成0.5
试试:
完整代码:
# coding:utf8
import sys
def apriori(D, minSup):
"""
频繁项集用keys表示,
key表示项集中的某一项,
cutKeys表示经过剪枝步的某k项集。
C表示某k项集的每一项在事务数据库D中的支持计数
:param D:
:param minSup:
:return:
"""
C1 = {}
for T in D:
for I in T:
if I in C1:
C1[I] += 1
else:
C1[I] = 1
print C1
_keys1 = C1.keys()
keys1 = []
for i in _keys1:
keys1.append([i])
n = len(D)
cutKeys1 = []
for k in keys1[:]:
if C1[k[0]] * 1.0 / n >= minSup:
cutKeys1.append(k)
cutKeys1.sort()
keys = cutKeys1
all_keys = []
while keys != []:
C = getC(D, keys)
cutKeys = getCutKeys(keys, C, minSup, len(D))
for key in cutKeys:
all_keys.append(key)
keys = aproiri_gen(cutKeys)
return all_keys
def getC(D, keys):
"""
对keys中的每一个key进行计数
:param D:
:param keys:
:return:
"""
C = []
for key in keys:
c = 0
for T in D:
have = True
for k in key:
if k not in T:
have = False
if have:
c += 1
C.append(c)
return C
def getCutKeys(keys, C, minSup, length):
"""
剪枝步
:param keys:
:param C:
:param minSup:
:param length:
:return:
"""
for i, key in enumerate(keys):
if float(C[i]) / length < minSup:
keys.remove(key)
return keys
def keyInT(key, T):
"""
判断项key是否在数据库中某一元组T中
:param key:
:param T:
:return:
"""
for k in key:
if k not in T:
return False
return True
def aproiri_gen(keys1):
"""
连接步
:param keys1:
:return:
"""
keys2 = []
for k1 in keys1:
for k2 in keys1:
if k1 != k2:
key = []
for k in k1:
if k not in key:
key.append(k)
for k in k2:
if k not in key:
key.append(k)
key.sort()
if key not in keys2:
keys2.append(key)
return keys2
D = [['A', 'C', 'D'], ['B', 'C', 'E'], ['A', 'B', 'C', 'E'], ['B', 'E']]
F = apriori(D, 0.5)
print '\nfrequent itemset:\n', F
# [['B'], ['C'], ['E'], ['B', 'E'], ['C', 'E']]