PyFlink 记载Keras 模型,进行实时性参数预测
# -*- coding: utf-8 -*
import logging
import os
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.table import StreamTableEnvironment, EnvironmentSettings, DataTypes
from pyflink.table.udf import ScalarFunction, udf
"""
加载Mysql中的重连数据进行模型训练、验证
"""
# ########################### 初始化流处理环境 ###########################
# 创建 Blink 流处理环境,注意此处需要指定 StreamExecutionEnvironment,否则无法导入 java 函数
env = StreamExecutionEnvironment.get_execution_environment()
env_settings = EnvironmentSettings.new_instance().in_streaming_mode().use_blink_planner().build()
env.set_parallelism(1)
t_env = StreamTableEnvironment.create(env, environment_settings=env_settings)
# 设置该参数以使用 UDF
t_env.get_config().get_configuration().set_boolean("python.fn-execution.4memory.managed", True)
t_env.get_config().get_configuration().set_string("taskmanager.memory.task.off-heap.size", "80m")
# ########################### 指定 jar 依赖 ###########################
# dir_kafka_sql_connect = os.path.join(os.path.abspath(os.path.dirname(__file__)),
# 'flink-sql-connector-kafka_2.11-1.11.2.jar')
# t_env.get_config().get_configuration().set_string("pipeline.jars", 'file:///' + dir_kafka_sql_connect)
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'ram.log')
logging.basicConfig(filename=filename, level=logging.INFO)
# ########################### 注册 UDF ###########################
class myKerasMLP(ScalarFunction):
def __init__(self):
print("Model __init__方法")
# 加载模型
self.model_name = 'Parameter_Predict_Net'
self.weights = 'Parameter_Predict_Net_weights'
self.redis_params = dict(host='localhost', password='123456', port=6379, db=1)
self.model = None
# y 的定义域
self.classes = list(range(10))
# 自定义的 4 类指标,用于评估模型和样本,指标值将暴露给外部系统以便于实时监控模型的状况
self.metric_counter = None # 从作业开始至今的所有样本数量
self.metric_predict_acc = 0 # 模型预测的准确率(用过去 10 条样本来评估)
self.metric_distribution_y = None # 标签 y 的分布
self.metric_total_10_sec = None # 过去 10 秒内训练过的样本数量
self.metric_right_10_sec = None # 过去 10 秒内的预测正确的样本数
def open(self, function_context):
"""
访问指标系统,并注册指标,以便于在 webui (localhost:8081) 实时查看算法的运行情况。
:param function_context:
:return:
"""
if self.model:
print("模型已加载..")
else:
self.model = self.load_model()
self.model.summary()
print("Model open方法")
# 访问指标系统,并定义 Metric Group 名称为 online_ml 以便于在 webui 查找
# Metric Group + Metric Name 是 Metric 的唯一标识
metric_group = function_context.get_metric_group().add_group("online_ml")
# 目前 PyFlink 1.11.2 版本支持 4 种指标:计数器 Counters,量表 Gauges,分布 Distribution 和仪表 Meters 。
# 目前这些指标都只能是整数
# 1、计数器 Counter,用于计算某个东西出现的次数,可以通过 inc()/inc(n:int) 或 dec()/dec(n:int) 来增加或减少值
self.metric_counter = metric_group.counter('sample_count') # 训练过的样本数量
# 2、量表 Gauge,用于根据业务计算指标,可以比较灵活地使用
# 目前 pyflink 只支持 Gauge 为整数值
metric_group.gauge("prediction_acc", lambda: int(self.metric_predict_acc * 100))
# 3、分布 Distribution,用于报告某个值的分布信息(总和,计数,最小,最大和平均值)的指标,可以通过 update(n: int) 来更新值
# 目前 pyflink 只支持 Distribution 为整数值
self.metric_distribution_y = metric_group.distribution("metric_distribution_y")
# 4、仪表 Meters,用于汇报平均吞吐量,可以通过 mark_event(n: int) 函数来更新事件数。
# 统计过去 10 秒内的样本量、预测正确的样本量
self.metric_total_10_sec = metric_group.meter("total_10_sec", time_span_in_seconds=10)
self.metric_right_10_sec = metric_group.meter("right_10_sec", time_span_in_seconds=10)
def eval(self, *args):
"""
模型预测
:param args: 参数集合
:return:
"""
from sklearn.preprocessing import StandardScaler
import numpy as np
import redis
import pickle
redis = redis.StrictRedis(**self.redis_params)
# 加载训练好的StandardScaler,应用于单条记录的归一化
x_sc = pickle.loads(redis.get("x_sc"))
y_sc = pickle.loads(redis.get("y_sc"))
# 拼接参数
a = []
for u in args:
a.append(u)
# shape :(7,1)
print("a:", np.array(a))
# shape :(1,7)
b = np.transpose(np.array(a).reshape(-1, 1))
# 数据归一化
data = x_sc.transform(b)
y_pred = self.model.predict(data)
# 反归一化
trueY = y_sc.inverse_transform(y_pred)
# 返回预测结果
return trueY[0][0], trueY[0][1]
def load_model(self):
"""
加载模型,如果 redis 里存在模型,则优先从 redis 加载,否则初始化一个新模型
:return:
"""
import redis
import pickle
import logging
from keras.models import model_from_json
from redis import StrictRedis
logging.info('载入模型!')
redis = redis.StrictRedis(**self.redis_params)
model = None
try:
# 从redis中获取模型、应用pickle.loads加载模型
print(redis.get("NT_Parameter_Predict_Net"))
model = model_from_json(redis.get("NT_Parameter_Predict_Net"))
model.set_weights(pickle.loads(redis.get("NT_Parameter_Predict_Net_weights")))
model.summary()
except TypeError:
logging.error('Redis 内没有指定名称的模型,请先训练模型保存至Redis')
return model
##############################################
# 特征输入:
# hotime, before_ta, before_rssi, after_ta, after_rssil, nb_tath, nb_rssith
# 训练输出:
# nbrta nbrssithrd
#
##############################################
myKerasMLP = udf(myKerasMLP(), input_types=[DataTypes.STRING(), DataTypes.STRING(), DataTypes.STRING(), DataTypes.STRING(),
DataTypes.STRING(), DataTypes.STRING(), DataTypes.STRING()],
result_type=DataTypes.ARRAY(DataTypes.FLOAT()))
print('UDF 模型加载完成!')
# t_env.create_temporary_system_function('train_and_predict', myKerasMLP)
t_env.register_function('train_and_predict', myKerasMLP)
print('UDF 注册成功')
# ########################### 创建源表(source) ###########################
# 使用 MySQL-CDC 连接器从 MySQL 的 binlog 里提取更改。
# 该连接器非官方连接器,写法请参照扩展阅读 2。
t_env.execute_sql("""
CREATE TABLE source (
hotime STRING ,
before_ta STRING ,
before_rssi STRING ,
after_ta STRING ,
after_rssil STRING ,
nb_tath STRING ,
nb_rssith STRING ,
nbr_rssi STRING ,
nbr_ta STRING
) WITH (
'connector' = 'jdbc',
'url' = 'jdbc:mysql://localhost:3306/hadoop',
'table-name' = 'nt_data',
'username' = 'root',
'password' = '123456'
)
""")
t_env.execute_sql("""
CREATE TABLE print_table (
hotime STRING ,
before_ta STRING ,
before_rssi STRING ,
after_ta STRING ,
after_rssil STRING ,
nb_tath STRING ,
nb_rssith STRING ,
predict ARRAY<FLOAT >
) WITH (
'connector' = 'print'
)
""")
# ########################### ###########################
t_env.sql_query("""
SELECT
hotime ,
before_ta ,
before_rssi ,
after_ta ,
after_rssil ,
nb_tath ,
nb_rssith ,
train_and_predict(hotime, before_ta, before_rssi, after_ta, after_rssil, nb_tath, nb_rssith) predict
FROM
source
""").insert_into("print_table")
t_env.execute('NT重连预测参数')