Java教程

使用kNN近邻算法识别水果种类(学习笔记)

本文主要是介绍使用kNN近邻算法识别水果种类(学习笔记),对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

1. 查看数据集,了解大概信息

import pandas as pd
fruits_df = pd.read_table('fruit_data_width_colors.txt')
print(fruits_df.head(10))
print('样本个数:',len(fruits_df))

输出信息为:

    fruit_label fruit_name     fruit_subtype  mass  width  height  color_score
0             1      apple      granny_smith   192    8.4     7.3         0.55
1             1      apple      granny_smith   180    8.0     6.8         0.59
2             1      apple      granny_smith   176    7.4     7.2         0.60
3             2   mandarin          mandarin    86    6.2     4.7         0.80
4             2   mandarin          mandarin    84    6.0     4.6         0.79
5             2   mandarin          mandarin    80    5.8     4.3         0.77
6             2   mandarin          mandarin    80    5.9     4.3         0.81
7             2   mandarin          mandarin    76    5.8     4.0         0.81
8             1      apple          braeburn   178    7.1     7.8         0.92
9             1      apple          braeburn   172    7.4     7.0         0.89
样本个数: 59

我们可以观察到数据分别为水果标签,水果名字,水果种类,水果质量,宽度,高度,颜色分数

我们继续通过做图来了解各个水果的样本量

import seaborn as sns
sns.countplot(fruits_df['fruit_name'],label="Count")
plt.show()

可以画出各个水果的数据量

image.png

image.png

2. 数据预处理

在预处理阶段我们要将数据划分为训练数据集和测试数据集,同时为了之后方便对比预测标签和真实标签我们要将水果标签和水果名字进行配对。

from sklearn.model_selection import train_test_split
#标签配对
fruit_name_dict = dict(zip(fruits_df['fruit_label'],fruits_df['fruit_name']))
print(fruit_name_dict)
#划分训练集和测试集
X = fruits_df[['mass','width','height','color_score']]
y = fruits_df['fruit_label']

X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=1/4,random_state=0)
print('数据集样本数:{},训练样本数:{},测试集样本数:{}'.format(len(X),len(X_train),len(X_test)))

返回输出结果:

{1: 'apple', 2: 'mandarin', 3: 'orange', 4: 'lemon'}
数据集样本数:59,训练样本数:44,测试集样本数:15

这里呢我们也可以做一下变量关系的可视化为下一步做铺垫:

sns.pairplot(data=fruits_df,hue='fruit_name',vars=['mass','width','height','color_score'],diag_kind='hist')
plt.show()

可以画出各个变量之间的关系

image.png

image.png

3. 建立模型,训练模型

我们这里使用的是sklearn中的kNN近邻算法模型,k-近邻算法是一种基于样本的算法(无参模型),算法的步骤大致为:

  1. 计算出测试样本和所有训练样本的距离
  2. 为测试样本选择k个与其距离最小的训练样本
  3. 统计出k个训练样本中大多数样本所属的分类
  4. 这个分类就是待分类数据所属的分类

那对于这个算法模型我们需要人为地给出一个k值,也就是通过几个最近的数据点来进行判断,这个就可以根据我们之前做出的可视化图形来进行判断,那如何得到最优邻点个数呢?
我们可以通过交叉验证的方式

一方面我们可以通过可视化图形来确定,一方面我们可以多尝试几个k值,来看看预测准确率

我们先将k值设置为5

from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=5) #建立模型
knn.fit(X_train,y_train) #训练模型

我们来测试一下模型:

from sklearn.metrics import accuracy_score
y_pred = knn.predict(X_test)
print('预测标签:',y_pred)
print('真实标签:', y_test.values)
acc = accurancy_score(y_test,y_pred)
print('准确率:',acc)

输出结果:

预测标签: [3 1 4 4 1 1 3 3 1 4 2 1 3 1 4]
真实标签: [3 3 4 3 1 1 3 4 3 1 2 1 3 3 3]
准确率: 0.5333333333333333

4. 调整模型

我们来测试一下每个k值所对应的测试准确率,看看k值取为多少的时候准确率最高:

k_range = range(1,20)
acc_scores = []

for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train,y_train)
    acc_scores.append((knn.score(X_test,y_test)))

plt.figure()
plt.xlabel('k')
plt.ylabel('accuracy')
plt.plot(k_range,acc_scores,marker='o')
plt.xticks([0,5,11,15,21])
plt.show()

返回图片为:

image.png

image.png

可以看出当k取值为6的时候准确率最高,为0.6

通过这个实践我们可以了解到k-近邻算法
需要注意的问题:

  1. 相似性度量
  2. 近邻点个数,通过交叉验证得到最优近邻点个数

kNN的优缺点:
优点:算法简单直观,易于实现,且不需要额外的数据,只依靠数据本身
缺点:计算量较大,分类速度慢,需要预先指定k值

数据文件与代码(百度网盘)
提取码:br30

这篇关于使用kNN近邻算法识别水果种类(学习笔记)的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!