目的

之前写了一篇文章​​多模式匹配AC算法Java(kotlin)实现,可建模中文​​,里面通过建模char(unicode)来实现跳转,使用的是map。但是通过私下的实验,其实这样做性能并不高,而且代码复杂难懂。更通用的做法是将unicode字符串转换为bytes,每个byte256种情况,也就是为每个节点维护一个256维的数组表示子节点。通过实验,建模bytes的方法比建模char的性能高不少。本文贴出建模byte的kotlin代码。

要点

  1. 根节点比较特殊,其子节点除了模式中的首byte节点之外,剩下的bytes的节点的跳转表要指向根节点。
  2. 失败跳转表构建思路比较绕,其实也比较简单。假如当前节点在byte_x失配了,那么就查其父节点的失配节点是否有byte_x的子节点,如果有那么恭喜,直接到那个节点,继续在树中跳转;如果没有,则查其父节点的失配节点的失配节点是否有byte_x的子节点,如果有,直接跳转过去。由于最差也会跳转到根节点,所以算法收敛。

代码

package com.davezhao.utils

val BYTE_SIZE = Byte.MAX_VALUE - Byte.MIN_VALUE + 1

data class NodeByte(
var finish: Boolean = false, // 当前节点是否为某一模式终止节点
var label: Int = 0, // 当前节点编号,默认为0,即根节点,其余的节点大于0
var pattern: String = "", // 如果当前节点为某一模式的终止节点,则该字段保存终止的模式字符串
val transitionTable: MutableList<Int> = MutableList(BYTE_SIZE, { -1 }) // 当前节点的子节点编号,-1表示不是子节点
)


class AcMatchByte {
private val startNode = NodeByte() // trie的根节点
private var labelCount = 1 //内部有效的节点个数
private val nodes = mutableListOf<NodeByte>(startNode) // 保存所有的节点,位置为节点的编号
private var fail: MutableList<Int> = mutableListOf() // 失败跳转表

/**
* 添加模式,构建trie树,可一次放入多个模式字符串,或者多次调用该函数放入。
* @param patterns
* @return
*/
fun addPatterns(vararg patterns: String) {
var latestLabel: Int = labelCount
for (pattern in patterns) {
var pNode = startNode //从根节点开始构建
for (b in pattern.toByteArray()) { //将当前模式字符串转换为byte
val i = b - Byte.MIN_VALUE // 取当前byte的位置
var nxtNodeLabel = pNode.transitionTable[i] //查看当前节点是否包含i的子节点
if (nxtNodeLabel == -1) { //如果不包含,则需要为其创建子节点
val nxtNode = NodeByte()
nxtNode.label = latestLabel
nodes.add(nxtNode)
pNode.transitionTable[i] = latestLabel // 为i创建指向子节点的跳转表
nxtNodeLabel = latestLabel++ // 全局的节点编号要增加1
}
pNode = nodes[nxtNodeLabel] // 令pNode指向i的子节点,继续下一个byte的构建
}
pNode.finish = true // 一个模式完成后,为模式的最后一个节点设置finish=true
pNode.pattern = pattern // 一个模式完成后,为模式的最后一个节点pattern设为当前pattern
}

labelCount = latestLabel // 构建完所有的模式后,将最新编号赋值给labelCount,以备下次构建
}

fun build() {
// 在构建失败跳转规则之前,需要有一个保底的设置。根节点有256个子节点,哪些非匹配模式的节点需要跳转到根节点本身,以便自动机跳转
for (i in (0 until BYTE_SIZE)) {
if (startNode.transitionTable[i] == -1) {
startNode.transitionTable[i] = 0
}
}

val q = mutableListOf<Int>() // 创建一个队列,用于存储待创建其子节点失败跳转表的节点
fail = MutableList(labelCount, { -1 }) // 失败跳转表是每个节点都有跳转,所以size为state_count个。
startNode.transitionTable.filter { it > 0 }.forEach {
// 将startNode节点中非指向根节点的节点挑出来,设置它们的失败跳转为根节点,并且加入队列,以便创建其子节点的失败跳转表
fail[it] = 0
q.add(it)
}

while (!q.isEmpty()) { // 如果队列为空,则说明所有节点失败跳转构建完毕,退出
val known = q.removeAt(0) // 从队列中取出队头的节点
(0 until BYTE_SIZE).filter { nodes[known].transitionTable[it] > 0 }.forEach { i ->
// 取出当前节点known的所有模式子节点
val nxt = nodes[known].transitionTable[i] // 对于nxt子节点
var p = fail[known] // 首先先得到其父节点的跳转节点
while (!(p != -1 && nodes[p].transitionTable[i] != -1)) { // 然后判断,如果nxt节点父节点known跳转节点p存在,且p也有子节点i,则退出跳转,否则继续寻找p的跳转节点赋值给p继续判断。由于我们有根节点的保底设置,所以最差也会到根节点。
p = fail[p]
}
fail[nxt] = nodes[p].transitionTable[i] // 子节点nxt的失败跳转节点即为其广义父节点的N(大于等于1)次跳转节点的同字符i的子节点
q.add(nxt) // 将nxt放入q,为其子节点设置失败跳转
}
}
}

fun match(str: String): List<String> {
val strB = str.toByteArray()
var pNode = startNode
var i = 0 // 遍历带搜索字符串str的下标
val res = mutableListOf<String>() // 保存搜索到的字符串列表
while (i < strB.size) {
val trans = strB[i] - Byte.MIN_VALUE
if (pNode.transitionTable[trans] != -1) { // 如果当前字符跳转成功,则跳转到当前字符的下一个字符
pNode = nodes[pNode.transitionTable[trans]]
} else { // 否则回退下标,并且使用失败跳转转到下一个节点继续
--i
pNode = nodes[fail[pNode.label]]
}
if (pNode.finish) { // 如果当前节点已经是某一模式终点,则保存该模式
res.add(pNode.pattern)
}
++i // 正常的下标增长
}

return res
}
}

fun main(args: Array<String>) {
val ac = AcMatchByte()
val patterns = arrayOf("his", "hers", "she", "he", "中国")
ac.addPatterns(patterns = *patterns)
ac.addPatterns("国中", "中国中")
ac.build()

val str = "hishers中国人民中国中国"

val res = ac.match(str)
println(res)
}