当前位置: 移动技术网 > 科技>人工智能>机器学习 > 【机器学习基础】KNN实现手写数字识别

【机器学习基础】KNN实现手写数字识别

2020年09月29日  | 移动技术网科技  | 我要评论
KNN(K-近邻算法)实现手写数字识别K-近邻算法是一种监督机器学习分类算法。它的思想很简单,计算测试点与样本集合之间的欧几里得距离(直线距离),找到测试点与样本集合中距离最近的k个样本集,这k个样本集合中对应的最多的分类就可以作为测试点的分类。本文使用的数据集, 解压后的testDigits文件夹为测试文件,trainingDigits为训练文件# 手写数字识别import numpy as npimport osclass DigitRecoginze(): def _

KNN(K-近邻算法)实现手写数字识别

K-近邻算法是一种监督机器学习分类算法。它的思想很简单,计算测试点与样本集合之间的欧几里得距离(直线距离),找到测试点与样本集合中距离最近的k个样本集,这k个样本集合中对应的最多的分类就可以作为测试点的分类。
本文使用的数据集, 解压后的testDigits文件夹为测试文件,trainingDigits为训练文件
文中的数据为像素图像,保存在txt文件中。图像的像素为32*32。
数字0图像像素矩阵

# 手写数字识别 import numpy as np import os class DigitRecoginze(): def __init__(self): self.label = None self.train_set = None self.test_set = None def img2vector(self, filename): # 将像素图像转化为向量 # 图像像素为32*32 image_vector = np.zeros((1, 1024)) f = open(filename, 'r') for i in range(32): line = f.readline() for j in range(32): image_vector[0, 32*i + j] = int(line[j]) return image_vector def import_data(self,filepath): # 导入数据 data_list = os.listdir(filepath) data_list_number = len(data_list) # 导入label数据 return_label = np.zeros((data_list_number, 1)) for i in range(data_list_number): return_label[i] = (data_list[i].strip().split('_'))[0] # 导入data数据 return_data_set = np.zeros((data_list_number, 1024)) for i in range(data_list_number): return_data_set[i] = self.img2vector(os.path.join(filepath, data_list[i])) return return_data_set, return_label def train_set_normalize(self, train_set): # 归一化 data_range = np.max(train_set) - np.min(train_set) return (train_set - np.min(train_set)) / data_range def single_train(self, train_set, testcase_x, train_label, k = 5): # 计算距离 train_set_size = train_set.shape[0] diff_mat = np.tile(testcase_x, (train_set_size, 1)) - train_set
        distances = (diff_mat**2).sum(axis=1)**0.5 # print(distances) # 排序,这里排序结果表示他的排序位置 distances_sorted = distances.argsort() class_result = {} # 找出k个点 for i in range(k): now_label = int(train_label[distances_sorted[i]][0]) # print(now_label) class_result[now_label] = class_result.get(now_label, 0)+ 1 # 找出最近最多的点 max_num = 0 result_label = 0 for single_result in class_result: if class_result[single_result] > max_num: max_num = class_result[single_result] result_label = single_result return result_label # 训练 def test(self, train_set_filepath, test_set_filepath, k = 5): # 导入数据 train_set, train_label = self.import_data(train_set_filepath) train_set = self.train_set_normalize(train_set) test_set, test_label = self.import_data(test_set_filepath) test_set = self.train_set_normalize(test_set) error_number = 0 all_number = test_set.shape[0] # 对于每一个测试样本进行测试 for i in range(all_number): result_label = self.single_train(train_set, test_set[i,:], train_label, k) if result_label != int(test_label[i][0]): error_number = error_number + 1 print("testcase %d: knn send back %d, the real class is %d" %(i, result_label, int(test_label[i][0]))) print("error ratio = %f" %(float(error_number)/float(all_number))) # 数据位置修改为自己的 FILE_PATH_TEST = r'2020\ML\ML_action\1.KNN\data\digit\testDigits' FILE_PATH_TRAIN = r'2020\ML\ML_action\1.KNN\data\digit\trainingDigits' _dr = DigitRecoginze() _dr.test(FILE_PATH_TRAIN,FILE_PATH_TEST) 

参考文献:
https://github.com/apachecn/AiLearning/blob/master/docs/ml/2.k-%E8%BF%91%E9%82%BB%E7%AE%97%E6%B3%95.md

本文地址:https://blog.csdn.net/qq_37753409/article/details/108863897

如您对本文有疑问或者有任何想说的,请点击进行留言回复,万千网友为您解惑!

相关文章:

验证码:
移动技术网