一、代码实现
# KNN import numpy as np import pandas as pd import matplotlib.pyplot as plt from sklearn.datasets import load_iris # data = load_iris() url = 'https://www.gairuo.com/file/data/dataset/iris.data' data = pd.read_csv(url) data ["species"] = data["species"].map({"setosa":1,"virginica":0,"versicolor":2}) # 删除重复的数据 data = data.drop_duplicates() # 查看各个类别的鸢尾花 data["species"].value_counts() # KNN class KNN: """ 使用Python实现K近邻算法 实现分类 """ def __init__(self,k): """ 初始化方法 :parameter :param k:int 邻居的个数 """ self.k = k def fit(self, X, y): """ 训练方法 :parameter :param X: 类数组类型 ,形状为:[样本数量,特征数量] 待训练的样本特征 :param y: 类数组类型,形状为:[样本数量] 每个样本的目标值 (标签) :return: None """ # 将X,y转换成ndarray数组 self.X = np.asarray(X) self.y = np.asarray(y) def predict(self,X): """ 根据参数传递的样本,对样本进行预测。 :parameter :param X: 类数组类型 ,形状为:[样本数量,特征数量] 待训练的样本特征 :return: 数组的类型 (预测的结果) """ X = np.asarray(X) result = [] # 对ndarray数组进行遍历,每次取数组的一行。 for x in X: # 对于测试集的每个样本依次对训练集中的所有样本求距离 dis = np.sqrt(np.sum((x - self.X) ** 2,axis=1)) # 返回数组排序后,每个元素在原数组中的索引 index = dis.argsort() # 进行截断,只取k个元素[取距离最近的k个元素的索引] index = index[:self.k] # 返回数组中每个元素出现的次数。元素必须是非负的整数 count = np.bincount(self.y[index]) # 返回ndarray数组中最大的元素对应的索引。该索引是我们判定的类别。 # 最大元素就是出现次数最多的元素 result.append(count.argmax()) pass return np.asarray(result) pass # 提取出每个类别的鸢尾花数据 t0 = data[data["species"] == 0] t1 = data[data["species"] == 1] t2 = data[data["species"] == 2] # 打乱顺序 每次打乱顺序都是一样的 t0 = t0.sample(len(t0),random_state=0) t1 = t1.sample(len(t1),random_state=0) t2 = t2.sample(len(t2),random_state=0) # 构建训练集和测试集 concat:拼接数组 axis = 0 纵向拼接 train_X = pd.concat([t0.iloc[:40,:-1],t1.iloc[:40,:-1],t2.iloc[:40,:-1]],axis=0) train_y = pd.concat([t0.iloc[:40,-1],t1.iloc[:40,-1],t2.iloc[:40,-1]],axis=0) test_X = pd.concat([t0.iloc[40:,:-1],t1.iloc[40:,:-1],t2.iloc[40:,:-1]],axis=0) test_y = pd.concat([t0.iloc[40:,-1],t1.iloc[40:,-1],t2.iloc[40:,-1]],axis=0) # 创建KNN对象,进行训练与测试 knn = KNN(k=3) knn.fit(train_X,train_y) # 进行测试 result = knn.predict(test_X) # 预测正确率 print("预测正确率 =",np.sum(result == test_y)/len(result))
二、代码结果
预测正确率 = 0.9629629629629629