1、A*搜索算法介绍
A*搜寻算法,俗称A星算法,作为启发式搜索算法中的一种,这是一种在图形平面上,有多个节点的路径,求出最低通过成本的算法。常用于游戏中的NPC的移动计算,或线上游戏的BOT的移动计算上。
算法核心:
A*算法最为核心的部分,就在于它的一个估值函数的设计上:
f(n)=g(n)+h(n)
其中f(n)是每个可能试探点的估值,它有两部分组成:
- g(n):表示从起始搜索点到当前点的代价(通常用某结点在搜索树中的深度来表示)。
- h(n):表示启发式搜索中最为重要的一部分,即当前结点到目标结点的估值,h(n)设计的好坏,直接影响着具有此种启发式函数的启发式算法的是否能称为A*算法。
一种具有f(n)=g(n)+h(n)策略的启发式算法能成为A*算法的充分条件是:
1、搜索树上存在着从起始点到终了点的最优路径。
2、问题域是有限的。
3、所有结点的子结点的搜索代价值>0。
4、h(n)=<h*(n) (h*(n)为实际问题的代价值)。
当此四个条件都满足时,一个具有f(n)=g(n)+h(n)策略的启发式算法能成为A*算法,并一定能找到最优解。
算法流程:
首先将起始结点S放入OPEN表,CLOSE表置空,算法开始时:
1、如果OPEN表不为空,从表头取一个结点n,如果为空算法失败。
2、n是目标解吗?是,找到一个解(继续寻找,或终止算法)。
3、将n的所有后继结点展开,就是从n可以直接关联的结点(子结点),如果不在CLOSE表中,就将它们放入OPEN表,并把S放入CLOSE表,同时计算每一个后继结点的估价值f(n),将OPEN表按f(x)排序,最小的放在表头,重复算法,回到1
2、Python实现(以八数码为例)
首先定义初始数码状态与目标数码状态:
a = [[1, 3, 4],
[2, 6, 0],
[7, 5, 8]]
b = [[3, 0, 2],
[6, 1, 5],
[8, 7, 4]]
需要注意的是,初始状态与目标状态序列的逆序对数需要同奇偶,例如a的逆序对数为4,b为10,同为偶,这样可以保证有解。然后我们定义AStar类,在初始化函数中接受S0(初始状态)和G(目标状态),获取行列等信息,然后初始化g(n)、访问过的节点数nodes以及open表和close表。
class AStar:
def __init__(self, S0, G):
self.S0 = S0
self.G = G
self.max_row = len(S0)
self.max_col = len(S0[0])
self.gn = 0
self.nodes = 0
self.close_list = {'Sn': [S0],
'gn': [0],
'hn': [self.get_h(S0)],
'fn': [self.get_h(S0)]}
self.open_list = {'Sn': [], 'gn': [], 'hn': [], 'fn': []}
然后,我们定义一些常用的函数:
(1)get_loc: 获取某个值在数码表中的位置
def get_loc(array, num):
for i in array:
for j in i:
if j == num:
row = array.index(i) + 1
col = i.index(j) + 1
return row, col
return None
(2)get_h: 获取某个状态的启发函数h(n)值,这里h(n)为哈密尔顿距离
def get_h(self, Sn):
h = 0
for i in Sn:
for j in i:
if not j:
continue
temp = j
row_n, col_n = self.get_loc(Sn, temp)
row_g, col_g = self.get_loc(self.G, temp)
h = h + abs(row_n - row_g) + abs(col_n - col_g)
return h
(3)add_open_list: 将可走的某个状态写入open表中
(注:如果该状态已在close表中出现过,则跳过)
def add_open_list(self, temp_s):
if temp_s in self.close_list['Sn']:
return 0
else:
self.open_list['Sn'].append(temp_s)
self.open_list['gn'].append(self.gn)
self.open_list['hn'].append(self.get_h(temp_s))
self.open_list['fn'].append(self.gn + self.get_h(temp_s))
常用函数准备好之后,接下来就是算法的主要部分,状态的搜索及变化,这里定义move函数来实现这一功能。
- 在最开始,是获取当前的深度并+1,由于算法每次将在open表中挑选fn值最小的走法,深度的不一定是线性增长,所以每次需要获取最后放入close表中的状态所对应的深度。
- 其后则是向上下左右四个方向的移动过程,deepcopy用来复制一个新的列表,现有的方法比如b = a[:]在复制后修改b仍会使a的值发生变化,复制的是指针而不是列表值。
from copy import deepcopy as dp
def move(self, Sn):
# restore gn
self.gn = self.close_list['gn'][-1]
self.gn += 1
row_0, col_0 = self.get_loc(Sn, 0)
# up
if row_0 > 1:
temp_n = Sn[row_0 - 2][col_0 - 1]
temp_s = dp(Sn)
temp_s[row_0 - 2][col_0 - 1] = 0
temp_s[row_0 - 1][col_0 - 1] = temp_n
self.add_open_list(temp_s)
# down
if row_0 < self.max_row:
temp_n = Sn[row_0][col_0 - 1]
temp_s = dp(Sn)
temp_s[row_0][col_0 - 1] = 0
temp_s[row_0 - 1][col_0 - 1] = temp_n
self.add_open_list(temp_s)
# left
if col_0 > 1:
temp_n = Sn[row_0 - 1][col_0 - 2]
temp_s = dp(Sn)
temp_s[row_0 - 1][col_0 - 2] = 0
temp_s[row_0 - 1][col_0 - 1] = temp_n
self.add_open_list(temp_s)
# right
if col_0 < self.max_col:
temp_n = Sn[row_0 - 1][col_0]
temp_s = dp(Sn)
temp_s[row_0 - 1][col_0] = 0
temp_s[row_0 - 1][col_0 - 1] = temp_n
self.add_open_list(temp_s)
在可走的状态都放入open表中后,我们需要:
- 判断f(n)值最小的走法,找到最佳状态
- 将最佳状态写入close表的最后
- 从open表中删除这一状态,以免重复走同一步
- 返回移动后的新状态
# get best move
fns = self.open_list['fn']
best_idx = fns.index(min(fns))
new_s = self.open_list['Sn'][best_idx]
# update list
self.close_list['Sn'].append(new_s)
self.close_list['gn'].append(self.open_list['gn'][best_idx])
self.close_list['hn'].append(self.open_list['hn'][best_idx])
self.close_list['fn'].append(fns[best_idx])
for key in self.open_list.keys():
self.open_list[key].pop(best_idx)
return new_s
以上定义的move函数实现了单步的判断和行走,我们需要用一个主函数来调用他走完所有的状态,最终达到目标状态。其中搜索停止的条件是h(n)值为0,即目前状态与目标状态一致。
def run(self):
Sn = self.S0
start = time.clock()
while self.close_list['hn'][-1]:
Sn = self.move(Sn)
self.nodes += 1
elapsed = time.clock() - start
print(f'步数: {self.gn}, 访问节点数: {self.nodes}, 耗时: {elapsed}')
return self.close_list
最后实例化AStar:
if __name__ == '__main__':
method = AStar(a, b)
method.run()
运行得到结果如下,最佳搜索步数为22步,访问过的总结点数为983。
步数: 22, 访问节点数: 983, 耗时: 0.283306
以上就是A*搜索算法的Python实现,最后附上完整代码
# A-star searching method
from copy import deepcopy as dp
import time
a = [[1, 3, 4],
[2, 6, 0],
[7, 5, 8]]
b = [[3, 0, 2],
[6, 1, 5],
[8, 7, 4]]
class AStar:
def __init__(self, S0, G):
self.S0 = S0
self.G = G
self.max_row = len(S0)
self.max_col = len(S0[0])
self.gn = 0
self.nodes = 0
self.close_list = {'Sn': [S0],
'gn': [0],
'hn': [self.get_h(S0)],
'fn': [self.get_h(S0)]}
self.open_list = {'Sn': [], 'gn': [], 'hn': [], 'fn': []}
def run(self):
Sn = self.S0
start = time.clock()
while self.close_list['hn'][-1]:
Sn = self.move(Sn)
self.nodes += 1
elapsed = time.clock() - start
print(f'步数: {self.gn}, 访问节点数: {self.nodes}, 耗时: {elapsed}')
return self.close_list
def move(self, Sn):
# restore gn
self.gn = self.close_list['gn'][-1]
self.gn += 1
row_0, col_0 = self.get_loc(Sn, 0)
# up
if row_0 > 1:
temp_n = Sn[row_0 - 2][col_0 - 1]
temp_s = dp(Sn)
temp_s[row_0 - 2][col_0 - 1] = 0
temp_s[row_0 - 1][col_0 - 1] = temp_n
self.add_open_list(temp_s)
# down
if row_0 < self.max_row:
temp_n = Sn[row_0][col_0 - 1]
temp_s = dp(Sn)
temp_s[row_0][col_0 - 1] = 0
temp_s[row_0 - 1][col_0 - 1] = temp_n
self.add_open_list(temp_s)
# left
if col_0 > 1:
temp_n = Sn[row_0 - 1][col_0 - 2]
temp_s = dp(Sn)
temp_s[row_0 - 1][col_0 - 2] = 0
temp_s[row_0 - 1][col_0 - 1] = temp_n
self.add_open_list(temp_s)
# right
if col_0 < self.max_col:
temp_n = Sn[row_0 - 1][col_0]
temp_s = dp(Sn)
temp_s[row_0 - 1][col_0] = 0
temp_s[row_0 - 1][col_0 - 1] = temp_n
self.add_open_list(temp_s)
# get best move
fns = self.open_list['fn']
best_idx = fns.index(min(fns))
new_s = self.open_list['Sn'][best_idx]
# update list
self.close_list['Sn'].append(new_s)
self.close_list['gn'].append(self.open_list['gn'][best_idx])
self.close_list['hn'].append(self.open_list['hn'][best_idx])
self.close_list['fn'].append(fns[best_idx])
for key in self.open_list.keys():
self.open_list[key].pop(best_idx)
return new_s
def add_open_list(self, temp_s):
if temp_s in self.close_list['Sn']:
return 0
else:
self.open_list['Sn'].append(temp_s)
self.open_list['gn'].append(self.gn)
self.open_list['hn'].append(self.get_h(temp_s))
self.open_list['fn'].append(self.gn + self.get_h(temp_s))
def get_h(self, Sn):
h = 0
for i in Sn:
for j in i:
if not j:
continue
temp = j
row_n, col_n = self.get_loc(Sn, temp)
row_g, col_g = self.get_loc(self.G, temp)
h = h + abs(row_n - row_g) + abs(col_n - col_g)
return h
@staticmethod
def get_loc(array, num):
for i in array:
for j in i:
if j == num:
row = array.index(i) + 1
col = i.index(j) + 1
return row, col
return None
if __name__ == '__main__':
method = AStar(a, b)
method.run()