直接上代码:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# 序列转为稀疏矩阵
# 输入:序列
# 输出:indices非零坐标点,values数据值,shape稀疏矩阵大小
import numpy as np
def sparse_tuple_from(sequences, dtype=np.int32):
indices = []
values = []
for n, seq in enumerate(sequences):
# 标记有数据的位置 seq表示每一个[] n表示行数
#print('n:',n)
#zip()表示迭代后压缩
indices.extend(zip([n] * len(seq), range(len(seq))))
#print('n:',[n] * len(seq))
#输出[0,0,0]
#c=zip([n] * len(seq))
#d=range(len(seq))
#print('C:',c)
#print('D:', d)
#print('indices:',indices)
# 记录数据的真实值
values.extend(seq)
print('values:',values)
#numpy.asarray(arr, dtype=None, order=None)为数据转换形式
indices = np.asarray(indices, dtype=np.int64)
values = np.asarray(values, dtype=dtype)
#计算矩阵的大小
shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)
#遍历indices,取出最大的列+1
print('max:',np.asarray(indices).max(0)[1] + 1)
#4 6
print('总共:',len(sequences), np.asarray(indices).max(0)[1] + 1)
return indices, values, shape
batch_label = []
batch_label.append([56, 45, 2347])
batch_label.append([1, 6, 7, 13, 98])
batch_label.append([2, 6, 4, 32, 12, 78])
batch_label.append([15, 3])
print('batch_label:',batch_label)
batch_label = sparse_tuple_from(batch_label)
print(batch_label)
代码可运行。
扩展:可以读取csv文件,生成稀疏矩阵