前言

前面爬取了一些手机的尺寸信息,并将其制作成了数据集,接下来就准备将这些数据集进行分类。

python神经网络实现数据多分类 神经网络多分类问题_ci


通过数据分布我们还是能够比较明显地发现这些数据有比较明显的区别。而谈到分类问题解决方案也有很多,比如神经网络(NN)、感知机、决策树、支持向量机(SVM)等。

这里我首先使用了神经网络对数据集进行分类,但经过测试,在数据集很小(不足200)的情况下,神经网络很难拟合出足够好的参数,最终准确率不到90%。最后还是决定用传统方法,这里选用的就是支持向量机,随机抽取每类样本作为训练集的情况下,SVM基本上能够达到全对,除非偶尔几次边际几个点被选做了测试集。

接下来,讲一下神经网络(NN)和SVM的实现,但最后应用还是推荐SVM。

主要流程

准备工作

首先下载下面几个库,都下载最新的好了。不过tf2.0很多新特性还是很麻烦,看个人习惯了,我还是决定尝试一下新东西。

tensorflow
scikit-learn
pandas
matplotlib

神经网络(Neural Network)

主要使用tensorflow2.0里面的keras进行实现的。
这里是跟着tf官网学习了一下(见参考链接2),我们的数据一个是height、width都是数值类型的。主要步骤有:

  1. 分割训练、测试、验证集(验证集其实这么小的数据集没太大必要)
  2. 构建输入层feature_layer(包含height、width)
  3. 构建神经网络(一层输入层 + 两个卷积层 + softmax输出三类,keras实现)
  4. 进行训练(设置了一千轮)
def df_to_dataset(dataframe, shuffle=True, batch_size=32):
    dataframe = dataframe.copy()
    labels = dataframe.pop('label')
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(dataframe))
        ds = ds.batch(batch_size)
    return ds


if __name__ == '__main__':
    # data = read_data()
    dataframe = pd.read_csv('data.csv')

    train, test = train_test_split(dataframe, test_size=0.1)
    train, val = train_test_split(train, test_size=0.1)
    batch_size = 100
    train_ds = df_to_dataset(train, batch_size=batch_size)
    val_ds = df_to_dataset(val, batch_size=batch_size)
    test_ds = df_to_dataset(test, batch_size=batch_size)

    for feature_batch, label_batch in train_ds.take(1):
        print('Every feature:', list(feature_batch.keys()))
        print('A batch of height:', feature_batch['height'])
        print('A batch of width:', feature_batch['width'])
        print('A batch of label:', label_batch)

    feature_columns = []

    height = feature_column.numeric_column('height')
    width = feature_column.numeric_column('width')

    feature_columns.append(height)
    feature_columns.append(width)

    feature_layer = tf.keras.layers.DenseFeatures(feature_columns)

    model = tf.keras.Sequential([
        feature_layer,
        tf.keras.layers.Dense(4, activation='sigmoid'),
        tf.keras.layers.Dense(4, activation='sigmoid'),
        tf.keras.layers.Dense(3, activation='softmax')
    ])

    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'],
                  run_eagerly=True)

    model.fit(train_ds,
              validation_data=val_ds,
              epochs=1000)

    loss, accuracy = model.evaluate(test_ds)
    print("Accuracy", accuracy)

完成神经网络后得出准确率大概不到90%,不管怎么添加训练轮数都提升不了,判断应该是已经收敛了不存在提升了。神经网络一定程度上很依赖数据驱动,从这个实验也可以看出,神经网络不适用小数据集。

支持向量机(Support Vector Machine,SVM)

既然最流行的神经网络不适用,我们的目光就转向了很多传统的方法。虽然传统方法在上限上无法与神经网络匹敌,但是在小数据集的情况下传统方法能够有更好的表现。最经典的传统方法当然是SVM了,这里就选取了SVM进行实现。

主要使用了scikit-learn库进行实现的。

测试SVM

scikit-learn对SVM封装得已经好了,我们基本上主要准备好合适的数据格式就能使用SVM了。
主要问题是选取核(kernel),我最后选的是默认的高斯核rbf,当然也可以选线性核linear、多项式核poly

def test_svm():
    # import data
    dataframe = pd.read_csv('data.csv')

    h = .02  # step size in the mesh
    train, test = train_test_split(dataframe, test_size=0.1)
    train_X = train.iloc[:, :-1].values * 1.1  # add 10% border
    train_Y = train.iloc[:, -1].values
    test_X = test.iloc[:, :-1].values * 1.1  # add 10% border
    test_Y = test.iloc[:, -1].values

    clf = svm.SVC(kernel='rbf')
    clf.fit(train_X, train_Y)
    judge_arr = clf.predict(test_X)-test_Y
    count = 0
    for res in judge_arr:
        if res != 0:
            count = count + 1
    print('accuracy:', (len(judge_arr)-count) / len(judge_arr))

保存模型

使用sklearn的joblib

# save the model
joblib.dump(clf, '../model/%s.pkl' % kernel)

调用模型预测

这边封装了一下,可以选取模型。

def predict_data(test_X, model='rbf'):
    clf = joblib.load('../model/%s.pkl' % model)
    result = clf.predict(test_X)
    return result

可视化*

可视化可以看参考链接4,效果大致如下:

python神经网络实现数据多分类 神经网络多分类问题_python神经网络实现数据多分类_02


可以看到基本上完成了划分。

参考链接

  1. https://www.tensorflow.org/tutorials/structured_data/feature_columns?hl=zh-cn#%E5%B0%86_dataframe_%E6%8B%86%E5%88%86%E4%B8%BA%E8%AE%AD%E7%BB%83%E3%80%81%E9%AA%8C%E8%AF%81%E5%92%8C%E6%B5%8B%E8%AF%95%E9%9B%86
  2. http://scikit-learn.sourceforge.net/stable/modules/generated/sklearn.svm.SVC.html
  3. http://scikit-learn.sourceforge.net/stable/auto_examples/svm/plot_custom_kernel.html