下面我们来实现一个二进制消息协议的案例,这个案例也是我们后面自己实现RPC通讯案例的一部分。
我们现在实现一个RPC的服务接口定义,这个RPC调用可以完成除法操作。
实现本地调用的话,很容易,定义一个函数divide:
class InvalidOperation(Exception):
"""
自定义非法操作异常
"""
def __init__(self, message=None):
self.message = message or "Invalid operation."
def divide(num1, num2=1):
"""
除法
:param num1: int
:param num2: int 默认为 1
:return: float 商 或者 无效的异常值
"""
if num2 == 0:
raise InvalidOperation()
val = num1 / num2
return val
# 尝试在本地调用
try:
val = divide(200, 100)
except InvalidOperation as e:
print(e.message)
else:
print(val)
# print(type(val))
但是如果变成RPC调用的话,调用双方该以什么样的方式传递哪些消息数据呢?
通过上面的学习,我们已经知道采用二进制方式传递消息数据的话,效率更高,所以我么使用二进制方式来实现消息协议。为了突出消息协议本身,我们不再进行额外压缩处理。
我们将上面的过程抽象成接口:
float divide(1:int num1, 2:int num2=1) => InvalidOperation
消息协议分为两部分说明:
调用请求消息
- 方法名为divide
- 第1个调用参数为整型int,名为num1
- 第2个调用参数为整型int,名为num2,默认值为1
调用返回消息 - 正常返回float类型
- 错误会抛出InvalidOperation异常
struct模块
struct是Python标准库提供的二进制编码解码库,允许我们将各种不同类型的变量转换为bytes字节类型,或者将bytes字节类型转换为其他类型变量。通过struct我们可以方便的操作二进制字节。
1. 将其他类型转换为bytes类型
struct.pack(格式, 数据)
如
>>> struct.pack('!I', 6)
b'\x00\x00\x00\x06'
!表示适用于网络传输的字节顺序
I表示无符号4字节整数
struct支持的格式如下:
2. 将bytes类型转换为其他类型
struct.unpack(格式, 数据)
如
>>> a = b'\x00\x00\x00\x06'
>>> a
b'\x00\x00\x00\x06'
>>> struct.unpack('!I', a)
(6,)
注意unpack返回的是元组
代码实现
1. divide消息协议实现
class DivideProtocol(object):
"""
float divide(1: int num1, 2: int num2=1)
"""
conn = None
def _read_all(self, size):
"""
读取指定长度的字节
:param size: 长度
:return: 读取出的二进制数据
"""
if isinstance(self.conn, BytesIO):
# BytesIO 类型 用于演示
buff = b''
have = 0
while have < size:
chunk = self.conn.read(size - have)
have += len(chunk)
buff += chunk
return buff
else:
# socket 类型
buff = b''
have = 0
while have < size:
chunk = self.conn.recv(size - have)
have += len(chunk)
buff += chunk
# 客户端关闭了连接
if len(chunk) == 0:
raise EOFError()
return buff
def args_encode(self, num1, num2=1):
"""
对调用参数进行编码
:param num1: int
:param num2: int
:return: 编码后的二进制数据
"""
# 处理参数 num1, 4字节整型
buff = struct.pack("!B", 1)
buff += struct.pack("!i", num1)
# 处理参数 num2, 4 字节整型,如为 默认值 1 就不再放到消息中
if num2 != 1:
buff += struct.pack("!B", 2)
buff += struct.pack('!i', num2)
# 处理消息总长度 4 字节无符号整型
length = len(buff)
# 处理方法名 字符串类型
name = "divide"
# 字符串长度 4 字节无符号整型
msg = struct.pack('!I', len(name))
msg += name.encode()
msg += struct.pack('!I', length) + buff
return msg
def args_decode(self, connection):
"""
获取调用参数并进行解码
:param connection: 传输工具对象,如socket对象或者BytesIO对象,从中可以读取消息数据
:return: 解码后的参数字典
"""
# 保存到当前对象中,供_read_all方式使用
self.conn = connection
param_name_map = {
1: 'num1',
2: 'num2',
}
param_len_map = {
1: 4,
2: 4,
}
# 用于保存解码后的参数字典
args = dict()
# 读取消息总长度,4字无节符号整数
buff = self._read_all(4)
length = struct.unpack('!I', buff)[0]
# 记录已读取的长度
have = 0
# 读取第一个参数,4字节整型
buff = self._read_all(1)
have += 1
param_seq = struct.unpack('!B', buff)[0]
param_len = param_len_map[param_seq]
buff = self._read_all(param_len)
have += param_len
args[param_name_map[param_seq]] = struct.unpack('!i', buff)[0]
if have >= length:
return args
# 读取第二个参数,4字节整型
buff = self._read_all(1)
have += 1
param_seq = struct.unpack('!B', buff)[0]
param_len = param_len_map[param_seq]
buff = self._read_all(param_len)
have += param_len
args[param_name_map[param_seq]] = struct.unpack('!i', buff)[0]
return args
def result_encode(self, result):
"""
对调用的结果进行编码
:param result: float 或 InvalidOperation对象
:return: 编码后的二进制数据
"""
if isinstance(result, float):
# 没有异常,正常执行
# 处理结果类型,1字节无符号整数
buff = struct.pack('!B', 1)
# 处理结果值, 4字节float
buff += struct.pack('!f', result)
else:
# 发生了InvalidOperation异常
# 处理结果类型,1字节无符号整数
buff = struct.pack('!B', 2)
# 处理异常结果值, 字符串
# 处理字符串长度, 4字节无符号整数
buff += struct.pack('!I', len(result.message))
# 处理字符串内容
buff += result.message.encode()
return buff
def result_decode(self, connection):
"""
对调用结果进行解码
:param connection: 传输工具对象,如socket对象或者BytesIO对象,从中可以读取消息数据
:return: 结果数据
"""
self.conn = connection
# 取出结果类型, 1字节无符号整数
buff = self._read_all(1)
result_type = struct.unpack('!B', buff)[0]
if result_type == 1:
# float的结果值, 4字节float
buff = self._read_all(4)
result = struct.unpack('!f', buff)[0]
return result
else:
# InvalidOperation对象
# 取出字符串长度, 4字节无符号整数
buff = self._read_all(4)
str_len = struct.unpack('!I', buff)[0]
buff = self._read_all(str_len)
message = buff.decode()
return InvalidOperation(message)
2. 解析方法名实现
class MethodProtocol(object):
"""解析方法名的实现 """
def __init__(self, connection):
self.conn = connection
def _read_all(self, size):
"""
读取指定长度的字节
:param size: 长度
:return: 读取出的二进制数据
"""
if isinstance(self.conn, BytesIO):
# BytesIO类型,用于演示
buff = b''
have = 0
while have < size:
chunk = self.conn.read(size - have)
have += len(chunk)
buff += chunk
return buff
else:
# socket类型
buff = b''
have = 0
while have < size:
print('have=%d size=%d' % (have, size))
chunk = self.conn.recv(size - have)
have += len(chunk)
buff += chunk
if len(chunk) == 0:
raise EOFError()
return buff
def get_method_name(self):
# 获取方法名
# 读取字符串长度,4字节无符号整型
buff = self._read_all(4)
str_len = struct.unpack('!I', buff)[0]
# 读取字符串
buff = self._read_all(str_len)
name = buff.decode()
return name
测试
if __name__ == "__main__":
proto = DivideProtocol()
# 构造消息
buff = BytesIO()
# buff.write(proto.args_encode(100))
buff.write(proto.args_encode(100, 200))
# 解读消息
buff.seek(0)
name = MethodProtocol(buff).get_method_name()
print(name)
args = proto.args_decode(buff)
print(args)
buff.close()
完整代码
import struct
from io import BytesIO
class InvalidOperation(Exception):
"""
自定义非法操作异常
"""
def __init__(self, message=None):
self.message = message or "Invalid operation."
class DivideProtocol(object):
"""
float divide(1: int num1, 2: int num2=1)
"""
conn = None
def _read_all(self, size):
"""
读取指定长度的字节
:param size: 长度
:return: 读取出的二进制数据
"""
if isinstance(self.conn, BytesIO):
# BytesIO 类型 用于演示
buff = b''
have = 0
while have < size:
chunk = self.conn.read(size - have)
have += len(chunk)
buff += chunk
return buff
else:
# socket 类型
buff = b''
have = 0
while have < size:
chunk = self.conn.recv(size - have)
have += len(chunk)
buff += chunk
# 客户端关闭了连接
if len(chunk) == 0:
raise EOFError()
return buff
def args_encode(self, num1, num2=1):
"""
对调用参数进行编码
:param num1: int
:param num2: int
:return: 编码后的二进制数据
"""
# 处理参数 num1, 4字节整型
buff = struct.pack("!B", 1)
buff += struct.pack("!i", num1)
# 处理参数 num2, 4 字节整型,如为 默认值 1 就不再放到消息中
if num2 != 1:
buff += struct.pack("!B", 2)
buff += struct.pack('!i', num2)
# 处理消息总长度 4 字节无符号整型
length = len(buff)
# 处理方法名 字符串类型
name = "divide"
# 字符串长度 4 字节无符号整型
msg = struct.pack('!I', len(name))
msg += name.encode()
msg += struct.pack('!I', length) + buff
return msg
def args_decode(self, connection):
"""
获取调用参数并进行解码
:param connection: 传输工具对象,如socket对象或者BytesIO对象,从中可以读取消息数据
:return: 解码后的参数字典
"""
# 保存到当前对象中,供_read_all方式使用
self.conn = connection
param_name_map = {
1: 'num1',
2: 'num2',
}
param_len_map = {
1: 4,
2: 4,
}
# 用于保存解码后的参数字典
args = dict()
# 读取消息总长度,4字无节符号整数
buff = self._read_all(4)
length = struct.unpack('!I', buff)[0]
# 记录已读取的长度
have = 0
# 读取第一个参数,4字节整型
buff = self._read_all(1)
have += 1
param_seq = struct.unpack('!B', buff)[0]
param_len = param_len_map[param_seq]
buff = self._read_all(param_len)
have += param_len
args[param_name_map[param_seq]] = struct.unpack('!i', buff)[0]
if have >= length:
return args
# 读取第二个参数,4字节整型
buff = self._read_all(1)
have += 1
param_seq = struct.unpack('!B', buff)[0]
param_len = param_len_map[param_seq]
buff = self._read_all(param_len)
have += param_len
args[param_name_map[param_seq]] = struct.unpack('!i', buff)[0]
return args
def result_encode(self, result):
"""
对调用的结果进行编码
:param result: float 或 InvalidOperation对象
:return: 编码后的二进制数据
"""
if isinstance(result, float):
# 没有异常,正常执行
# 处理结果类型,1字节无符号整数
buff = struct.pack('!B', 1)
# 处理结果值, 4字节float
buff += struct.pack('!f', result)
else:
# 发生了InvalidOperation异常
# 处理结果类型,1字节无符号整数
buff = struct.pack('!B', 2)
# 处理异常结果值, 字符串
# 处理字符串长度, 4字节无符号整数
buff += struct.pack('!I', len(result.message))
# 处理字符串内容
buff += result.message.encode()
return buff
def result_decode(self, connection):
"""
对调用结果进行解码
:param connection: 传输工具对象,如socket对象或者BytesIO对象,从中可以读取消息数据
:return: 结果数据
"""
self.conn = connection
# 取出结果类型, 1字节无符号整数
buff = self._read_all(1)
result_type = struct.unpack('!B', buff)[0]
if result_type == 1:
# float的结果值, 4字节float
buff = self._read_all(4)
result = struct.unpack('!f', buff)[0]
return result
else:
# InvalidOperation对象
# 取出字符串长度, 4字节无符号整数
buff = self._read_all(4)
str_len = struct.unpack('!I', buff)[0]
buff = self._read_all(str_len)
message = buff.decode()
return InvalidOperation(message)
class MethodProtocol(object):
"""解析方法名的实现 """
def __init__(self, connection):
self.conn = connection
def _read_all(self, size):
"""
读取指定长度的字节
:param size: 长度
:return: 读取出的二进制数据
"""
if isinstance(self.conn, BytesIO):
# BytesIO类型,用于演示
buff = b''
have = 0
while have < size:
chunk = self.conn.read(size - have)
have += len(chunk)
buff += chunk
return buff
else:
# socket类型
buff = b''
have = 0
while have < size:
print('have=%d size=%d' % (have, size))
chunk = self.conn.recv(size - have)
have += len(chunk)
buff += chunk
if len(chunk) == 0:
raise EOFError()
return buff
def get_method_name(self):
# 获取方法名
# 读取字符串长度,4字节无符号整型
buff = self._read_all(4)
str_len = struct.unpack('!I', buff)[0]
# 读取字符串
buff = self._read_all(str_len)
name = buff.decode()
return name
if __name__ == "__main__":
proto = DivideProtocol()
# 构造消息
buff = BytesIO()
# buff.write(proto.args_encode(100))
buff.write(proto.args_encode(100, 200))
# 解读消息
buff.seek(0)
name = MethodProtocol(buff).get_method_name()
print(name)
args = proto.args_decode(buff)
print(args)
buff.close()