支持向量机也是一种既可以处理分类问题,也可以处理回归问题的算法。
关于支持向量机在回归问题上的应用,请参考:TODO
支持向量机分类广泛应用于图像识别、文本分类、生物信息学(例如基因分类)、手写数字识别等领域。
支持向量机的主要思想是找到一个超平面,将不同类别的样本最大化地分隔开。
超平面的位置由支持向量决定,它们是离分隔边界最近的数据点。
对于二分类问题,SVM寻找一个超平面,使得正例和支持向量到超平面的距离之和等于反例和支持向量到超平面的距离之和。
如果这个等式不成立,SVM将寻找一个更远离等式中不利样本的超平面。
下面的示例,演示了支持向量机分类算法在图像识别上的应用。
这次的样本使用的是scikit-learn
自带的手写数字数据集。
import matplotlib.pyplot as plt from sklearn import datasets # 加载手写数据集 data = datasets.load_digits() _, axes = plt.subplots(nrows=2, ncols=4, figsize=(10, 6)) for ax, image, label in zip(np.append(axes[0], axes[1]), data.images, data.target): ax.set_axis_off() ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest") ax.set_title("目标值: {}".format(label))
这里显示了其中的几个手写数字,这个数据集总共有大约1700多个手写数字。
样本数据中,手写数字的图片存储为一个 8x8
的二维数组。
比如:
data.images[0] # 运行结果 array([[ 0., 0., 5., 13., 9., 1., 0., 0.], [ 0., 0., 13., 15., 10., 15., 5., 0.], [ 0., 3., 15., 2., 0., 11., 8., 0.], [ 0., 4., 12., 0., 0., 8., 8., 0.], [ 0., 5., 8., 0., 0., 9., 8., 0.], [ 0., 4., 11., 0., 1., 12., 7., 0.], [ 0., 2., 14., 5., 10., 12., 0., 0.], [ 0., 0., 6., 13., 10., 0., 0., 0.]])
所以,在分割训练集和测试集之前,我们需要先将手写数字的的存储格式从 8x8
的二维数组转换为 64x1
的一维数组。
from sklearn.model_selection import train_test_split n_samples = len(data.images) X = data.images.reshape((n_samples, -1)) y = data.target # 分割训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)
按照9:1的比例来划分训练集和测试集。
然后用scikit-learn
中的SVC
模型来训练样本:
from sklearn.svm import SVC # 定义 reg = SVC() # 训练模型 reg.fit(X_train, y_train)
模型的训练效果:
# 在测试集上进行预测 y_pred = reg.predict(X_test) correct_pred = np.sum(y_pred == y_test) print("预测正确率:{:.2f}%".format(correct_pred / len(y_pred) * 100)) # 运行效果 预测正确率:98.89%
正确率非常高,下面我们看看没识别出来的手写数字是哪些。
wrong_pred = [] for i in range(len(y_pred)): if y_pred[i] != y_test[i]: wrong_pred.append(i) print(wrong_pred) # 运行效果 [156, 158]
在测试集中,只有两个手写数字识别错了。
我面看看识别错的2个手写数字是什么样的:
_, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 3)) for i in range(2): idx = wrong_pred[i] image = X_test[idx].reshape(8, 8) axes[i].set_axis_off() axes[i].imshow(image, cmap=plt.cm.gray_r, interpolation="nearest") axes[i].set_title("预测值({}) 目标值({})".format(y_pred[idx], y_test[idx]))
可以看出,即使人眼去识别,这两个手写数字也不太容易识别。
支持向量机分类算法的优势有:
它的劣势主要有: