KNN 和 SVM 图片分类 任务 代码及细节分享

使用KNN (K-最近邻) 方法进行图像分类也是一个常见的选择。以下是

使用sklearnKNeighborsClassifier进行图像分类的Python脚本:

import os
import cv2
import numpy as np
import logging
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix# 设置日志
logging.basicConfig(filename='training_log.txt', level=logging.INFO, format='%(asctime)s - %(message)s')# 读取图像数据和标签
def load_images_from_folder(folder):images = []labels = []label = 0for subdir in os.listdir(folder):subpath = os.path.join(folder, subdir)if os.path.isdir(subpath):for filename in os.listdir(subpath):if filename.endswith(".jpg"):img_path = os.path.join(subpath, filename)img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)img_normalized = cv2.resize(img, (256, 256))  # 归一化图像大小为256x256images.append(img_normalized.flatten())labels.append(label)label += 1return images, labels# 主函数
def main():# train_folder = "YOUR_TRAIN_DATASET_FOLDER_PATH"  # 替换为你的训练集文件夹路径# test_folder = "YOUR_TEST_DATASET_FOLDER_PATH"    # 替换为你的测试集文件夹路径logging.info("Loading training data from %s", train_folder)X_train, y_train = load_images_from_folder(train_folder)logging.info("Loaded %d training samples", len(X_train))logging.info("Loading test data from %s", test_folder)X_test, y_test = load_images_from_folder(test_folder)logging.info("Loaded %d test samples", len(X_test))logging.info("Training KNeighborsClassifier...")knn = KNeighborsClassifier(n_neighbors=3)  # 使用3个邻居knn.fit(X_train, y_train)logging.info("Training completed.")y_pred = knn.predict(X_test)accuracy = accuracy_score(y_test, y_pred)logging.info("Test Accuracy: %f", accuracy)cm = confusion_matrix(y_test, y_pred)cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]logging.info("Normalized Confusion Matrix:")for row in cm_normalized:logging.info(" - ".join(map(lambda x: "{:.2f}".format(x), row)))if __name__ == "__main__":main()

KNeighborsClassifier来进行训练和预测。默认情况下,我设置了n_neighbors=3,这意味着使用3个最近的邻居进行投票。你可以根据需要调整这个参数。

选择合适的n_neighbors值是很重要的。一个较小的值(如1或3)可能会使模型对噪声更敏感,而一个较大的值可能会使模型更加平滑。通常,选择一个奇数的值可以避免在投票中出现平局的情况。

为了确定最佳的n_neighbors值,你可以使用交叉验证来评估不同值的性能,然后选择性能最好的那个值。

对于多类分类问题,选择合适的n_neighbors值是很重要的。以下是一些建议:

  1. 避免平局:为了避免在投票中出现平局的情况,通常建议选择一个奇数的n_neighbors值。但是,由于你有21个类,这个规则不再适用,因为即使是偶数的邻居数也不太可能导致平局。

  2. 默认值KNeighborsClassifier的默认n_neighbors值是5。这是一个常用的起始值,但可能不是最优的。

  3. 根据类别数量:一个常见的策略是选择一个接近总类别数量的n_neighbors值。在你的情况下,你有21个类,所以你可以考虑从这个数字开始,然后向上或向下调整。

  4. 交叉验证:最佳的方法是使用交叉验证来确定最佳的n_neighbors值。这意味着你会尝试多个不同的n_neighbors值,然后选择在验证集上性能最好的那个。

使用支持向量机(SVM)进行图像分类是一个强大的方法。sklearn提供了SVC(支持向量分类)类来实现SVM。

以下是使用sklearnSVC进行图像分类的Python脚本:

import os
import cv2
import numpy as np
import logging
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix# 设置日志
logging.basicConfig(filename='training_log_svm.txt', level=logging.INFO, format='%(asctime)s - %(message)s')# 读取图像数据和标签
def load_images_from_folder(folder):images = []labels = []label = 0for subdir in os.listdir(folder):subpath = os.path.join(folder, subdir)if os.path.isdir(subpath):for filename in os.listdir(subpath):if filename.endswith(".jpg"):img_path = os.path.join(subpath, filename)img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)img_normalized = cv2.resize(img, (256, 256))  # 归一化图像大小为256x256images.append(img_normalized.flatten())labels.append(label)label += 1return images, labels# 主函数
def main():# train_folder = "YOUR_TRAIN_DATASET_FOLDER_PATH"  # 替换为你的训练集文件夹路径# test_folder = "YOUR_TEST_DATASET_FOLDER_PATH"    # 替换为你的测试集文件夹路径logging.info("Loading training data from %s", train_folder)X_train, y_train = load_images_from_folder(train_folder)logging.info("Loaded %d training samples", len(X_train))logging.info("Loading test data from %s", test_folder)X_test, y_test = load_images_from_folder(test_folder)logging.info("Loaded %d test samples", len(X_test))logging.info("Training SVM...")svm = SVC(kernel='linear', C=1)  # 使用线性核和C=1svm.fit(X_train, y_train)logging.info("Training completed.")y_pred = svm.predict(X_test)accuracy = accuracy_score(y_test, y_pred)logging.info("Test Accuracy: %f", accuracy)cm = confusion_matrix(y_test, y_pred)cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]logging.info("Normalized Confusion Matrix:")for row in cm_normalized:logging.info(" - ".join(map(lambda x: "{:.2f}".format(x), row)))if __name__ == "__main__":main()

使用了SVC类来进行训练和预测。我选择了线性核(kernel='linear')和正则化参数C=1(C=1)。这些参数可能需要根据你的数据进行调整。

SVM有许多可调参数,如核函数、C值和其他参数。

科学地调整SVM的参数:

  1. 交叉验证:使用交叉验证是评估不同参数组合性能的关键。例如,你可以使用sklearnGridSearchCV来自动进行参数搜索和交叉验证。

  2. 核函数选择

    • 线性核:当特征数量很大或数据线性可分时使用。
    • RBF核:当数据有非线性边界时使用。它有一个参数gamma需要调整。
    • 多项式核:当数据的边界是多项式形式时使用。它有degreecoef0两个参数需要调整。
    • Sigmoid核:在某些特定的数据集上可能有效,但不常用。
  3. C值:C是SVM的正则化参数。较小的C值会导致较大的间隔,但可能允许一些误分类。较大的C值会尝试最大化训练数据的正确分类,但可能导致过拟合。

  4. gamma值:仅对RBF核和多项式核有效。它定义了单个训练样本的影响范围。较小的值意味着更大的影响范围,较大的值意味着更小的影响范围。

  5. 使用GridSearchCV:这是一个自动化的参数搜索方法,它会尝试所有给定的参数组合,并使用交叉验证来评估每种组合的性能。

from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC# 定义参数网格
param_grid = {'C': [0.1, 1, 10, 100],'gamma': [1, 0.1, 0.01, 0.001],'kernel': ['linear', 'rbf', 'poly', 'sigmoid']
}# 使用GridSearchCV进行参数搜索
grid_search = GridSearchCV(SVC(), param_grid, refit=True, verbose=3, cv=5)
grid_search.fit(X_train, y_train)# 打印最佳参数
print("Best parameters found: ", grid_search.best_params_)
  1. 考虑数据的规模和分布:如果数据集很大,可能需要使用随机样本或考虑使用LinearSVC,它是专门为大数据集优化的。

  2. 特征缩放:确保所有特征都在相同的尺度上。SVM对特征的尺度非常敏感,所以通常使用StandardScaler来缩放数据。

  3. 评估指标:确保你使用了合适的评估指标来评估模型的性能。例如,对于不平衡的数据集,考虑使用F1分数或AUC-ROC曲线,而不仅仅是准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/116432.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MQ常见的问题(kafka保证消息不丢失)

MQ常见的问题 1,mq如何避免消息堆积问题。 消息堆积:生产者的生产速率远远大于消费者的消费速率,使消息大批量的堆积在消息队列。 解决方案:1,提升消费者的消费速率(增加消费者集群) 2&…

强化学习代码实战(2) --- 多臂赌博机

目录 前言 1.Python基础 2.Numpy基础 3.多臂赌博机 参考文献 前言 本文内容来自于南京大学郭宪老师在博文视点学院录制的视频,课程仅9元地址,配套书籍为深入浅出强化学习 编程实战 郭宪地址。 1.Python基础 1. print() 可以用该语句查看当前数据的情…

2023版 STM32实战11 SPI总线读写W25Q

SPI全称 英文全称:Serial peripheral Interface 串行外设接口 SPI特点 -1- 串行(逐bit传输) -2- 同步(共用时钟线) -3- 全双工(收发可同时进行) -4- 通信只能由主机发起(一主,多从机) 开发使用习惯和理解 -1- CS片选一般配置为软件控制 -2- 片选低电平有效,从…

2023.10.22 关于 定时器(Timer) 详解

目录 引言 标准库定时器使用 自己实现定时器的代码 模拟实现的两大方面 核心思路 重点理解 自己实现的定时器代码最终代码版本 引言 定时器用于在 预定的时间间隔之后 执行特定的任务或操作 实例理解: 在服务器开发中,客户端向服务器发送请求&#…

Spring Cloud 之 GateWay简介及简单DEMO的搭建

(1)Filter(过滤器): 和Zuul的过滤器在概念上类似,可以使用它拦截和修改请求,并且对上游的响应,进行二次处理。过滤器为org.springframework.cloud.gateway.filter.GatewayFilter类的…

微信小程序WeUI项目weui-miniprogram如何运行起来?

微信小程序WeUI项目weui-miniprogram如何运行起来? 解决方法: 1、下载 https://github.com/wechat-miniprogram/weui-miniprogram 2、在项目根目录weui-miniprogram-master执行以下命令安装依赖: npm install 3、继续执行编译命令: npm r…

Unity3D 基础——鼠标悬停更改物体颜色,移走恢复

方法介绍 【unity学习笔记】OnMouseEnter、OnMouseOver、OnMouseExit_unity onmouseover_一白梦人的博客-CSDN博客https://blog.csdn.net/a1208498468/article/details/117856445 GetComponent()详解_getcomponet<> 动态名称-CSDN博客https://blog.csdn.net/kaixindrag…

C# 写入文件比较

数据长度&#xff1a;128188个long BinaryWriter每次写一个long 耗时14.7828ms StreamWriter每次写一个long 耗时44.0934 ms FileStream每次写一个long 耗时20.5142 ms FileStream固定chunk写入&#xff0c;循环操作数组&#xff0c;耗时13.4126 ms byte[] chunk new byte[d…

国际物流报关流程

国际物流报关流程如下&#xff1a; 准备资料&#xff1a;根据国家和地区的要求&#xff0c;准备相关的报关资料&#xff0c;包括进出口货物清单、运输合同、发票、装箱单等文件。 海关申报&#xff1a;将准备好的报关资料递交给海关申报&#xff0c;海关会根据报关单上的信息对…

1. 前缀码判定

题目 前缀码&#xff1a;任何一个字符的编码都不是同一字符集中另一个字符的编码的前缀。 请编写一个程序&#xff0c;判断输入的n个由1和0组成的编码是否为前缀码。如果这n个编码是前缀码&#xff0c;则输出"YES”&#xff1b;否则输出第一个与前面编码发生矛盾的编码。…

牛客网刷题-(2)

&#x1f308;write in front&#x1f308; &#x1f9f8;大家好&#xff0c;我是Aileen&#x1f9f8;.希望你看完之后&#xff0c;能对你有所帮助&#xff0c;不足请指正&#xff01;共同学习交流. &#x1f194;本文由Aileen_0v0&#x1f9f8; 原创 CSDN首发&#x1f412; 如…

jvm垃圾回收算法有哪些及原理

目录 垃圾回收器1 Serial收集器2 Parallel收集器3 ParNew收集器4 CMS收集器5 G1回收器三色标记算法标记算法的过程三色标记算法缺陷多标漏标 垃圾回收器 垃圾回收机制&#xff0c;我们已经知道什么样的对象会成为垃圾。对象回收经历了什么——垃圾回收算法。那么谁来负责回收垃…

华为云 CodeArts Snap 智能编程助手 PyCharm 插件安装与使用指南

1 插件安装下载 1.1 搜索插件 打开 PyCharm&#xff0c;选择 File&#xff0c;点击 Settings。 选择 Plugins&#xff0c;点击 Marketplace&#xff0c;并在搜索框中输入 Huawei Cloud CodeArts Snap。 1.2 安装插件 如上图所示&#xff0c;点击 Install 按钮安装 Huawei Cl…

2. 计算WPL

题目 Huffman编码是通信系统中常用的一种不等长编码&#xff0c;它的特点是&#xff1a;能够使编码之后的电文长度最短。 更多关于Huffman编码的内容参考教材第十章。 输入&#xff1a; 第一行为要编码的符号数量n 第二行&#xff5e;第n1行为每个符号出现的频率 输…

npm install报错 缺少python

报错信息&#xff1a; Building:E:tolsnvmnodesnodeexe : ode emos ant-desig-we-eos odemodules node-gypbintnode-gp.s rebld -verbose -Libsass_ext --Libsas_cflags- lags --libsass_librarygyp info it worked if it ends with ok gyp verb cli [ gyp verb cliE: toolsnv…

一步一步认知机器学习

1&#xff0c;前言 之前学习并且实操了一些算法框架用来探索相关方向的可能性&#xff0c;但是总不了解相关的步骤。因为一步一步按照别人给出的步骤去操作&#xff0c;解决一些操作时出现的问题&#xff0c;基本可以达到目的。但是这个也基本限制了在那个框架而已。对于算法还…

Socket 是什么? 总结+详解

文章摘要&#xff1a;Socket 套接字 编程接口 netstat-ano 创建 建立连接 断开 删除 1.Socket 是什么 Socket &#xff1a;套接字&#xff08;socket&#xff09;是一个抽象层&#xff0c;应用程序可以通过它发送或接收数据&#xff0c;可对其进行像对文件一样的打开、读写和…

SpringBoot的日志系统(日志分组、文件输出、滚动归档)

[toc](目录) > SpringBoot3需要jdk17 # 1. 简介 1. Spring5及以后Spring自己实现了commons-logging&#xff0c;来作为内部的日志。日志的jar包是org.springframework:spring-jcl:6.0.10。查看org.apache.commons.logging.LogAdapter Java package org.apache.commons.log…

如何把Elasticsearch中的数据导出为CSV格式的文件

前言| 本文结合用户实际需求用按照数据量从小到大的提供三种方式从ES中将数据导出成CSV形式。本文将重点介Kibana/Elasticsearch高效导出的插件、工具集&#xff0c;通过本文你可以了解如下信息&#xff1a; 1&#xff0c;从kibana导出数据到csv文件 2&#xff0c;logstash导…

VMware vCenter Server 6.7安装过程记录

0、前言 最近由于一些原因需要安装测试VMware ESXi&#xff0c;无奈所有服务器都是十几年前的&#xff0c;配置低也不支持。后来通过VMware兼容性列表查询&#xff0c;快要放弃的时候发现唯一一台Dell R420&#xff0c;如获至宝。通过查询得知最高支持到6.5 U3&#xff0c;好在…