量化交易之机器学习篇 - 实现K近邻模型的两种方式


# 导入相关模块import numpy as npfrom collections import Counter
import matplotlib.pyplot as pltfrom sklearn import datasets
from sklearn.utils import shuffledef load_data():iris = datasets.load_iris()# 打乱数据后的数据和标签X, y = shuffle(iris.data, iris.target, random_state=13)# 数据转换为 flout32 格式X = X.astype(np.float32)# 简单划分训练集和测试集, 训练样本 - 测试样本比例为 7:3offset = int(X.shape[0] * 0.7)X_train, y_train = X[:offset], y[:offset]X_test, y_test = X[offset:], y[offset:]# 将标签转换为竖向量y_train = y_train.reshape((-1, 1))y_test = y_test.reshape((-1, 1))return X_train, X_test, y_train, y_testdef compute_distances(X, X_train):"""定义欧氏距离函数X: 测试样本实例矩阵X_train: 训练样本实例矩阵"""# 测试实例样本num_test = X.shape[0]# 训练实例样本量num_train = X_train.shape[0]# 基于训练和测试维度的欧氏距离初始化dists = np.zeros((num_test, num_train))# 测试样本鱼训练样本的矩阵点乘M = np.dot(X, X_train.T)# 测试样本矩阵平方te = np.square(X).sum(axis=1)# 训练样本矩阵平方tr = np.square(X_train).sum(axis=1)# 计算欧式距离dists = np.sqrt(-2 * M + tr + np.matrix(te).T)return distsdef predict_labels(y_train, dists, k=1):"""定义预测函数:param y_train: 训练集标签:param dists: 测试集与训练集之间的欧氏距离矩阵:param k: k值:return: 测试集预测结果"""# 测试样本量num_test = dists.shape[0]# 初始化测试集预测结果y_pred = np.zeros(num_test)# 遍历for i in range(num_test):# 初始化最近邻列表closest_y = []# 按 欧式距离矩阵排序后取索引, 并用训练集标签按排序后的索引取值# 最后展平列表# 注意 np.argsort 函数的用法labels = y_train[np.argsort(dists[i, :])].flatten()# 取最近的k个值closest_y = labels[0:k]# 对最近的k个值进行计数统计# 这里注意 collections 模块中的计数器 Counter 的用法c = Counter(closest_y)# 取计数最多的那个类别y_pred[i] = c.most_common(1)[0][0]return y_predif __name__ == '__main__':# 导入 sklearn iris 数据集X_train, X_test, y_train, y_test = load_data()dists = compute_distances(X=X_test, X_train=X_train)y_test_pred = predict_labels(y_train=y_train, dists=dists, k=1)y_test_pred = y_test_pred.reshape((-1, 1))# 找出预测正确的实例num_correct = np.sum(y_test_pred == y_test)# 计算分类准确率accuracy = float(num_correct) / X_test.shape[0]print('KNN Accuracy based on NumPy: ' + str(accuracy))# 用五折交叉验证寻找最优 k值# 五折num_folds = 5# 候选 k 值k_choices = [1, 3, 5, 8, 10, 12, 15, 20, 50, 100]X_train_folds = []y_train_folds = []# 训练数据划分X_train_folds = np.array_split(X_train, num_folds)# 训练标签划分y_train_folds = np.array_split(y_train, num_folds)k_to_accuracies = {}# 表里所有候选k值for k in k_choices:# 五折遍历for fold in range(num_folds):# 为 传入的训练集单独划分出一个验证集作为测试集validation_X_test = X_train_folds[fold]validation_y_test = y_train_folds[fold]temp_X_train = np.concatenate(X_train_folds[:fold] + X_train_folds[fold+1:])temp_y_train = np.concatenate(y_train_folds[:fold] + y_train_folds[fold+1:])# 计算距离temp_dists = compute_distances(X=validation_X_test, X_train=temp_X_train)temp_y_test_pred = predict_labels(temp_y_train, temp_dists, k=k)temp_y_test_pred = temp_y_test_pred.reshape((-1, 1))# 查看分类准确率num_correct = np.sum(temp_y_test_pred == validation_y_test)num_test = validation_X_test.shape[0]accuracy = float(num_correct) / num_testk_to_accuracies[k] = k_to_accuracies.get(k, []) + [accuracy]for k in sorted(k_to_accuracies):for accuracy in k_to_accuracies[k]:print(f'k = {k}, accuracy = {accuracy}')# 打印不同k值, 不同折数下的分类准确率for k in k_choices:# 取出第k个k值的分类准确率accuracies = k_to_accuracies[k]# 绘制不同k值下分类准确率的散点图plt.scatter([k] * len(accuracies), accuracies)# 计算分类准确率均值并排序accuracies_mean = np.array([np.mean(v) for k, v in sorted(k_to_accuracies.items())])# 计算分类准确率标准差并排序accuracies_std = np.array([np.std(v) for k, v in sorted(k_to_accuracies.items())])# 绘制有质询区间的误差棒图plt.errorbar(k_choices, accuracies_mean, yerr=accuracies_std)# 绘图标题plt.title('Cross-validation on k')# x轴标签plt.xlabel('k')# y轴标签plt.ylabel('Cross-validation accuracy')plt.show()

if __name__ == '__main__':# 导入 KNeighborsClassifier 模块from sklearn.neighbors import KNeighborsClassifier# 创建 k近邻实例neigh = KNeighborsClassifier(n_neighbors=10)# k 近邻模型拟合neigh.fit(X_train, y_train)# k 近邻模型预测y_pred = neigh.predict(X_test)# 预测结果数组重塑y_pred = y_pred.reshape((-1, 1))# 统计预测正确的个数num_correct = np.sum(y_pred == y_test)# 计算分类准确率accuracy = float(num_correct) / X_test.shape[0]print(f'KNN Accuracy based on sklearn: {accuracy}.')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/38130.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【机器学习】在【Pycharm】中的应用:【线性回归模型】进行【房价预测】

专栏:机器学习笔记 pycharm专业版免费激活教程见资源,私信我给你发 python相关库的安装:pandas,numpy,matplotlib,statsmodels 1. 引言 线性回归(Linear Regression)是一种常见的统计方法和机器学习算法&a…

弹性力学讲义

弹性力学讲义 1. 基本假设和一些概念2. 应力3. 二维应力状态与摩尔库伦屈服准则 1. 基本假设和一些概念 力学:变形体力学–固体力学和流体力学(连续介质力学) 刚体力学–理论力学(一般力学) 物理受理后:要…

Facebook的投流技巧有哪些?

相信大家都知道Facebook拥有着巨大的用户群体和高转化率,在国外社交推广中的影响不言而喻。但随着Facebook广告的竞争越来越激烈,在Facebook广告上获得高投资回报率也变得越来越困难。IPIDEA代理IP今天就教大家如何在Facebook上投放广告的技巧&#xff0…

python–基础篇–正则表达式–是什么

文章目录 定义一:正则表达式就是记录文本规则的代码定义一:正则表达式是一个特殊的字符序列,用于判断一个字符串是否与我们所设定的字符序列是否匹配,也就是说检查一个字符串是否与某种模式匹配。初识 Python 正则表达式 定义一&a…

D : 合适的顺序

Description 给定 8 个数,如果将它们排成一列,每个数的权值是它与相邻的数之积,求一个排列方式,所有数的权值之和最大,输出该权值和. 例如 13242315 的权值和为 1∗33∗1∗22∗3∗44∗2∗22∗4∗33∗2∗11∗3∗55∗1…

新工具:轻松可视化基因组,部分功能超IGV~

本次分享一个Python基因组数据可视化工具figeno。 figeno擅长可视化三代long reads、跨区域基因组断点视图(multi-regions across genomic breakpoints)、表观组数据(HiC、ATAC-seq和ChIP-seq等)可视化、WGS中的CNV和SV可视化等。…

第四周——单词记忆

deploy 部署 attorney 律师 discrimination 歧视,区别 implicit 含蓄的 disposition 性格,倾向 entail 牵涉 retail 零售 imposing 印象深刻的 壮观的 implication 含义 entrenched 根深蒂固的 perplex 使复杂化 comply 遵守 composed 沉着…

小米平板6系列对比

小米平板6系列目前有4款,分别为6、6 Pro、6 Max、6S Pro。具体对比如下表所示。 小米平板型号66 Pro6 Max6S Pro实物图发布时间2023年4月21日2023年4月21日2023年8月14日2024年2月22 日屏幕大小11英寸11英寸14英寸12.4英寸分辨率2.8K2.8K2.8K3K刷新率144Hz144Hz120…

43 - 部门工资前三高的所有员工(高频 SQL 50 题基础版)

43 - 部门工资前三高的所有员工 # dense_rank 排名selectDepartment,Employee,Salary from(selectd.name as Department,e.name as Employee,e.salary as Salary,(dense_rank() over (partition by d.name order by e.salary desc)) as rankingfrom Employee e left join Depar…

数据库-存储过程,函数与触发器

创建存储过程:create procedure 存储过程名(参数) eg: CREATE PROCEDURE proc1() BEGIN SELECT * FROM user; END; 执行存储过程:call 存储过程名 创建带有参数的存储过程 存储过程的参数有三种: IN:输入参数,也是…

18 学渣的逆袭之路

在小学阶段(本篇特指五年级,一到四年级随便学学就可以逆袭90分,六年级难度飙升),无论你的分数怎么低,只要有一颗上进的心,就绝对可以逆袭95! 在本篇文章,我将会讲解“对于…

【前端那些事】Node.js的安装并配置镜像源

1、官网下载地址 Download Node.js 一步一步点击安装即可,可自定义安装目录 2、配置镜像源 # 设置淘宝镜像源 npm config set registry https://registry.npmmirror.com# 查看使用的镜像源 npm config get registry 如果需要恢复为npm默认的官方源&#xf…

intellij idea安装R包ggplot2报错问题求解

1、intellij idea安装R包ggplot2问题 在我上次解决图形显示问题后,发现安装ggplot2包时出现了问题,这在之前高版本中并没有出现问题, install.packages(ggplot2) ERROR: lazy loading failed for package lifecycle * removing C:/Users/V…

java:aocache的单实例缓存(一)

上一篇博客《java:aocache:基于aspectJ实现的方法缓存工具》介绍了aocache的基本使用, 介绍AoCacheable注解时说过,AoCacheable可以定义在构造方法上,定义在构造方法,该构建方法就成了单实例模式。 也就是说,只要构建…

Java实现按高度或宽度等比压缩图片尺寸

要实现一个能够按高度或宽度等比压缩图片并返回InputStream的Java方法,你需要先计算图片的原始宽高比,然后根据目标尺寸(宽度或高度)计算出等比缩放后的另一个维度。以下是一个示例代码: import javax.imageio.ImageI…

理解cpu对地址的操作

在硬件编程中,外设寄存器被映射到内存或输入/输出(I/O)的地址空间上,这种映射使得CPU能够通过读写这些地址来控制外设。这种机制通常被称为内存映射I/O(MMIO)或端口映射I/O(PMIO),其中MMIO更为常见。以GPIO(通用输入输出)为例,下面是这个过程的一般性描述: MMIO(…

【云原生】MiniKube部署Kubernetes最小化集群

MiniKube安装Kubernetes集群(一步到位) 文章目录 MiniKube安装Kubernetes集群(一步到位)资源列表基础环境一、环境配置1.1、更新系统1.2、安装Docker1.3、配置Docker加速器 二、部署MiniKube2.1、安装kubectl2.2、安装MiniKube2.2…

深入了解Qt 控件:Display Widgets部件(1) 以及 QT自定义控件(电池)

QT自定义控件(电池) 在线调色板Qt之CSS专栏Chapter1 QT自定义控件(电池)Chapter2 Qt教程 — 3.5 深入了解Qt 控件:Display Widgets部件(1)1 Display Widgets简介2 如何使用Display Widgets部件 Chapter3 Qt自定义控件电池组件使用前言一、最基…

“论大数据处理架构及其应用”高分范文,软考高级,系统架构设计师

论文真题 大数据处理架构是专门用于处理和分析巨量复杂数据集的软件架构。它通常包括数据收集、存储、处理、分析和可视化等多个层面,旨在从海量、多样化的数据中提取有价值的信息。Lambda架构是大数据平台里最成熟、最稳定的架构,它是一种将批处理和流…

springboot 3.x相比之前版本有什么区别

Spring Boot 3.x相比之前的版本(尤其是Spring Boot 2.x),主要存在以下几个显著的区别和新特性: Java版本要求: Spring Boot 3.x要求至少使用Java 17作为最低版本,同时已经通过了Java 19的测试,…