下面我们来实现一个二进制消息协议的案例,这个案例也是我们后面自己实现RPC通讯案例的一部分。

我们现在实现一个RPC的服务接口定义,这个RPC调用可以完成除法操作。

实现本地调用的话,很容易,定义一个函数divide:

class InvalidOperation(Exception):
    """
    自定义非法操作异常
    """
    def __init__(self, message=None):
        self.message = message or "Invalid operation."


def divide(num1, num2=1):
    """
    除法
    :param num1:  int
    :param num2:  int 默认为 1
    :return:  float 商 或者 无效的异常值
    """
    if num2 == 0:
        raise InvalidOperation()
    val = num1 / num2
    return val


# 尝试在本地调用
try:
    val = divide(200, 100)
except InvalidOperation as e:
    print(e.message)
else:
    print(val)
    # print(type(val))

但是如果变成RPC调用的话,调用双方该以什么样的方式传递哪些消息数据呢?

通过上面的学习,我们已经知道采用二进制方式传递消息数据的话,效率更高,所以我么使用二进制方式来实现消息协议。为了突出消息协议本身,我们不再进行额外压缩处理。

我们将上面的过程抽象成接口:

float divide(1:int num1, 2:int num2=1) => InvalidOperation

消息协议分为两部分说明:
调用请求消息

  • 方法名为divide
  • 第1个调用参数为整型int,名为num1
  • 第2个调用参数为整型int,名为num2,默认值为1

    调用返回消息
  • 正常返回float类型
  • 错误会抛出InvalidOperation异常

struct模块

struct是Python标准库提供的二进制编码解码库,允许我们将各种不同类型的变量转换为bytes字节类型,或者将bytes字节类型转换为其他类型变量。通过struct我们可以方便的操作二进制字节。
1. 将其他类型转换为bytes类型

struct.pack(格式, 数据)

>>> struct.pack('!I', 6)
b'\x00\x00\x00\x06'

!表示适用于网络传输的字节顺序

I表示无符号4字节整数

struct支持的格式如下:

python aria2 rpc python aria2 rpc调用_数据


2. 将bytes类型转换为其他类型

struct.unpack(格式, 数据)

>>> a = b'\x00\x00\x00\x06'
>>> a
b'\x00\x00\x00\x06'
>>> struct.unpack('!I', a)
(6,)

注意unpack返回的是元组

代码实现

1. divide消息协议实现
class DivideProtocol(object):
    """
    float divide(1: int num1, 2: int num2=1)
    """
    conn = None

    def _read_all(self, size):
        """
        读取指定长度的字节
        :param size: 长度
        :return: 读取出的二进制数据
        """
        if isinstance(self.conn, BytesIO):
            # BytesIO 类型 用于演示
            buff = b''
            have = 0
            while have < size:
                chunk = self.conn.read(size - have)
                have += len(chunk)
                buff += chunk
            return buff
        else:
            # socket 类型
            buff = b''
            have = 0
            while have < size:
                chunk = self.conn.recv(size - have)
                have += len(chunk)
                buff += chunk
                # 客户端关闭了连接
                if len(chunk) == 0:
                    raise EOFError()
            return buff

    def args_encode(self, num1, num2=1):
        """
        对调用参数进行编码
        :param num1:  int
        :param num2:  int
        :return: 编码后的二进制数据
        """
        # 处理参数 num1, 4字节整型
        buff = struct.pack("!B", 1)
        buff += struct.pack("!i", num1)

        # 处理参数 num2, 4 字节整型,如为 默认值 1 就不再放到消息中
        if num2 != 1:
            buff += struct.pack("!B", 2)
            buff += struct.pack('!i', num2)

        # 处理消息总长度 4 字节无符号整型
        length = len(buff)

        # 处理方法名 字符串类型
        name = "divide"
        # 字符串长度 4 字节无符号整型
        msg = struct.pack('!I', len(name))
        msg += name.encode()

        msg += struct.pack('!I', length) + buff

        return msg

    def args_decode(self, connection):
        """
        获取调用参数并进行解码
        :param connection: 传输工具对象,如socket对象或者BytesIO对象,从中可以读取消息数据
        :return: 解码后的参数字典
        """
        # 保存到当前对象中,供_read_all方式使用
        self.conn = connection
        param_name_map = {
            1: 'num1',
            2: 'num2',
        }
        param_len_map = {
            1: 4,
            2: 4,
        }
        # 用于保存解码后的参数字典
        args = dict()

        # 读取消息总长度,4字无节符号整数
        buff = self._read_all(4)
        length = struct.unpack('!I', buff)[0]

        # 记录已读取的长度
        have = 0

        # 读取第一个参数,4字节整型
        buff = self._read_all(1)
        have += 1
        param_seq = struct.unpack('!B', buff)[0]
        param_len = param_len_map[param_seq]
        buff = self._read_all(param_len)
        have += param_len
        args[param_name_map[param_seq]] = struct.unpack('!i', buff)[0]

        if have >= length:
            return args

        # 读取第二个参数,4字节整型
        buff = self._read_all(1)
        have += 1
        param_seq = struct.unpack('!B', buff)[0]
        param_len = param_len_map[param_seq]
        buff = self._read_all(param_len)
        have += param_len
        args[param_name_map[param_seq]] = struct.unpack('!i', buff)[0]

        return args

    def result_encode(self, result):
        """
        对调用的结果进行编码
        :param result: float 或 InvalidOperation对象
        :return: 编码后的二进制数据
        """
        if isinstance(result, float):
            # 没有异常,正常执行
            # 处理结果类型,1字节无符号整数
            buff = struct.pack('!B', 1)

            # 处理结果值, 4字节float
            buff += struct.pack('!f', result)
        else:
            # 发生了InvalidOperation异常
            # 处理结果类型,1字节无符号整数
            buff = struct.pack('!B', 2)

            # 处理异常结果值, 字符串
            # 处理字符串长度, 4字节无符号整数
            buff += struct.pack('!I', len(result.message))
            # 处理字符串内容
            buff += result.message.encode()

        return buff

    def result_decode(self, connection):
        """
        对调用结果进行解码
        :param connection: 传输工具对象,如socket对象或者BytesIO对象,从中可以读取消息数据
        :return: 结果数据
        """
        self.conn = connection

        # 取出结果类型, 1字节无符号整数
        buff = self._read_all(1)
        result_type = struct.unpack('!B', buff)[0]
        if result_type == 1:
            # float的结果值, 4字节float
            buff = self._read_all(4)
            result = struct.unpack('!f', buff)[0]
            return result
        else:
            # InvalidOperation对象
            # 取出字符串长度, 4字节无符号整数
            buff = self._read_all(4)
            str_len = struct.unpack('!I', buff)[0]
            buff = self._read_all(str_len)
            message = buff.decode()
            return InvalidOperation(message)
2. 解析方法名实现
class MethodProtocol(object):
    """解析方法名的实现 """
    def __init__(self, connection):
        self.conn = connection

    def _read_all(self, size):
        """
        读取指定长度的字节
        :param size: 长度
        :return: 读取出的二进制数据
        """
        if isinstance(self.conn, BytesIO):
            # BytesIO类型,用于演示
            buff = b''
            have = 0
            while have < size:
                chunk = self.conn.read(size - have)
                have += len(chunk)
                buff += chunk

            return buff

        else:
            # socket类型
            buff = b''
            have = 0
            while have < size:
                print('have=%d size=%d' % (have, size))
                chunk = self.conn.recv(size - have)
                have += len(chunk)
                buff += chunk

                if len(chunk) == 0:
                    raise EOFError()

            return buff

    def get_method_name(self):
        # 获取方法名
        # 读取字符串长度,4字节无符号整型
        buff = self._read_all(4)
        str_len = struct.unpack('!I', buff)[0]

        # 读取字符串
        buff = self._read_all(str_len)
        name = buff.decode()
        return name
测试
if __name__ == "__main__":
    proto = DivideProtocol()
    # 构造消息
    buff = BytesIO()
    # buff.write(proto.args_encode(100))
    buff.write(proto.args_encode(100, 200))

    # 解读消息
    buff.seek(0)
    name = MethodProtocol(buff).get_method_name()
    print(name)
    args = proto.args_decode(buff)
    print(args)
    buff.close()
完整代码
import struct

from io import BytesIO


class InvalidOperation(Exception):
    """
    自定义非法操作异常
    """
    def __init__(self, message=None):
        self.message = message or "Invalid operation."


class DivideProtocol(object):
    """
    float divide(1: int num1, 2: int num2=1)
    """
    conn = None

    def _read_all(self, size):
        """
        读取指定长度的字节
        :param size: 长度
        :return: 读取出的二进制数据
        """
        if isinstance(self.conn, BytesIO):
            # BytesIO 类型 用于演示
            buff = b''
            have = 0
            while have < size:
                chunk = self.conn.read(size - have)
                have += len(chunk)
                buff += chunk
            return buff
        else:
            # socket 类型
            buff = b''
            have = 0
            while have < size:
                chunk = self.conn.recv(size - have)
                have += len(chunk)
                buff += chunk
                # 客户端关闭了连接
                if len(chunk) == 0:
                    raise EOFError()
            return buff

    def args_encode(self, num1, num2=1):
        """
        对调用参数进行编码
        :param num1:  int
        :param num2:  int
        :return: 编码后的二进制数据
        """
        # 处理参数 num1, 4字节整型
        buff = struct.pack("!B", 1)
        buff += struct.pack("!i", num1)

        # 处理参数 num2, 4 字节整型,如为 默认值 1 就不再放到消息中
        if num2 != 1:
            buff += struct.pack("!B", 2)
            buff += struct.pack('!i', num2)

        # 处理消息总长度 4 字节无符号整型
        length = len(buff)

        # 处理方法名 字符串类型
        name = "divide"
        # 字符串长度 4 字节无符号整型
        msg = struct.pack('!I', len(name))
        msg += name.encode()

        msg += struct.pack('!I', length) + buff

        return msg

    def args_decode(self, connection):
        """
        获取调用参数并进行解码
        :param connection: 传输工具对象,如socket对象或者BytesIO对象,从中可以读取消息数据
        :return: 解码后的参数字典
        """
        # 保存到当前对象中,供_read_all方式使用
        self.conn = connection
        param_name_map = {
            1: 'num1',
            2: 'num2',
        }
        param_len_map = {
            1: 4,
            2: 4,
        }
        # 用于保存解码后的参数字典
        args = dict()

        # 读取消息总长度,4字无节符号整数
        buff = self._read_all(4)
        length = struct.unpack('!I', buff)[0]

        # 记录已读取的长度
        have = 0

        # 读取第一个参数,4字节整型
        buff = self._read_all(1)
        have += 1
        param_seq = struct.unpack('!B', buff)[0]
        param_len = param_len_map[param_seq]
        buff = self._read_all(param_len)
        have += param_len
        args[param_name_map[param_seq]] = struct.unpack('!i', buff)[0]

        if have >= length:
            return args

        # 读取第二个参数,4字节整型
        buff = self._read_all(1)
        have += 1
        param_seq = struct.unpack('!B', buff)[0]
        param_len = param_len_map[param_seq]
        buff = self._read_all(param_len)
        have += param_len
        args[param_name_map[param_seq]] = struct.unpack('!i', buff)[0]

        return args

    def result_encode(self, result):
        """
        对调用的结果进行编码
        :param result: float 或 InvalidOperation对象
        :return: 编码后的二进制数据
        """
        if isinstance(result, float):
            # 没有异常,正常执行
            # 处理结果类型,1字节无符号整数
            buff = struct.pack('!B', 1)

            # 处理结果值, 4字节float
            buff += struct.pack('!f', result)
        else:
            # 发生了InvalidOperation异常
            # 处理结果类型,1字节无符号整数
            buff = struct.pack('!B', 2)

            # 处理异常结果值, 字符串
            # 处理字符串长度, 4字节无符号整数
            buff += struct.pack('!I', len(result.message))
            # 处理字符串内容
            buff += result.message.encode()

        return buff

    def result_decode(self, connection):
        """
        对调用结果进行解码
        :param connection: 传输工具对象,如socket对象或者BytesIO对象,从中可以读取消息数据
        :return: 结果数据
        """
        self.conn = connection

        # 取出结果类型, 1字节无符号整数
        buff = self._read_all(1)
        result_type = struct.unpack('!B', buff)[0]
        if result_type == 1:
            # float的结果值, 4字节float
            buff = self._read_all(4)
            result = struct.unpack('!f', buff)[0]
            return result
        else:
            # InvalidOperation对象
            # 取出字符串长度, 4字节无符号整数
            buff = self._read_all(4)
            str_len = struct.unpack('!I', buff)[0]
            buff = self._read_all(str_len)
            message = buff.decode()
            return InvalidOperation(message)


class MethodProtocol(object):
    """解析方法名的实现 """
    def __init__(self, connection):
        self.conn = connection

    def _read_all(self, size):
        """
        读取指定长度的字节
        :param size: 长度
        :return: 读取出的二进制数据
        """
        if isinstance(self.conn, BytesIO):
            # BytesIO类型,用于演示
            buff = b''
            have = 0
            while have < size:
                chunk = self.conn.read(size - have)
                have += len(chunk)
                buff += chunk

            return buff

        else:
            # socket类型
            buff = b''
            have = 0
            while have < size:
                print('have=%d size=%d' % (have, size))
                chunk = self.conn.recv(size - have)
                have += len(chunk)
                buff += chunk

                if len(chunk) == 0:
                    raise EOFError()

            return buff

    def get_method_name(self):
        # 获取方法名
        # 读取字符串长度,4字节无符号整型
        buff = self._read_all(4)
        str_len = struct.unpack('!I', buff)[0]

        # 读取字符串
        buff = self._read_all(str_len)
        name = buff.decode()
        return name


if __name__ == "__main__":
    proto = DivideProtocol()
    # 构造消息
    buff = BytesIO()
    # buff.write(proto.args_encode(100))
    buff.write(proto.args_encode(100, 200))

    # 解读消息
    buff.seek(0)
    name = MethodProtocol(buff).get_method_name()
    print(name)
    args = proto.args_decode(buff)
    print(args)
    buff.close()