sklearn线性回归详解

图片若未能正常显示,点击下面链接:
http://ihoge.cn/2018/Logistic-regression.html

在线性回归中,我们想要建立一个模型,来拟合一个因变量 y 与一个或多个独立自变量(预测变量) x 之间的关系。

给定:

数据集

{(x(1),y(1)),...,(x(m),y(m))}{(x(1),y(1)),...,(x(m),y(m))}

xixi是d-维向量Xi=(x(i)1,...,x(i)d)Xi=(x1(i),...,xd(i))

y(i)y(i)是一个目标变量,它是一个标量

线性回归模型可以理解为一个非常简单的神经网络:

它有一个实值加权向量w=(w(i),...,w(d))w=(w(i),...,w(d))
它有一个实值偏置量 b
它使用恒等函数作为其激活函数

线性回归模型可以使用以下方法进行训练

a) 梯度下降法

b) 正态方程(封闭形式解)w=(XTX)1XTyw=(XTX)−1XTy

其中 X 是一个矩阵,其形式为(m,nfeatures)(m,nfeatures),包含所有训练样本的维度信息。

而正态方程需要计算(XTX)(XTX)的转置。这个操作的计算复杂度介于O(n2.4features)O(nfeatures2.4)O(n3features)O(nfeatures3)之间,而这取决于所选择的实现方法。因此,如果训练集中数据的特征数量很大,那么使用正态方程训练的过程将变得非常缓慢。

线性回归模型的训练过程有不同的步骤。首先(在步骤 0 中),模型的参数将被初始化。在达到指定训练次数或参数收敛前,重复以下其他步骤。

第 0 步:

用0 (或小的随机值)来初始化权重向量和偏置量,或者直接使用正态方程计算模型参数

第 1 步(只有在使用梯度下降法训练时需要):

计算输入的特征与权重值的线性组合,这可以通过矢量化和矢量传播来对所有训练样本进行处理:
y˙=Xw+by˙=X⋅w+b

其中 X 是所有训练样本的维度矩阵,其形式为(m,nfeatures)(m,nfeatures);这里我用· 表示

第 2 步(只有在使用梯度下降法训练时需要):

用均方误差计算训练集上的损失:J(w,b)=1mmi=1(y˙(i)y(i))2J(w,b)=1m∑i=1m(y˙(i)−y(i))2

第 3 步(只有在使用梯度下降法训练时需要):

对每个参数,计算其对损失函数的偏导数:

Jwj=2mmi=1(y˙(i)y(i))x(i)j∂J∂wj=2m∑i=1m(y˙(i)−y(i))xj(i)

Jb=2mmi=1(y˙(i)y(i))∂J∂b=2m∑i=1m(y˙(i)−y(i))

所有偏导数的梯度计算如下:

ΔwJ=2mXT(y˙y)ΔwJ=2mXT(y˙−y)

ΔbJ=2m(y˙y)ΔbJ=2m(y˙−y)

第 4 步(只有在使用梯度下降法训练时需要):

更新权重向量和偏置量:

w=wηΔwJw=w−ηΔwJ

ΔbJ=2m(y˙y)ΔbJ=2m(y˙−y)

其中η表示学习率

代码实现

数据集

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
np.random.seed(123)X = 2 * np.random.rand(500, 1)
y = 5 + 3 * X + np.random.randn(500, 1)
fig = plt.figure(figsize=(8,6))
plt.scatter(X, y)
plt.title("Dataset")
plt.xlabel("First feature")
plt.ylabel("Second feature")
plt.show()

png

X_train, X_test, y_train, y_test = train_test_split(X, y)
print(f'Shape X_train: {X_train.shape}')
print(f'Shape y_train: {y_train.shape}')
print(f'Shape X_test: {X_test.shape}')
print(f'Shape y_test: {y_test.shape}')
Shape X_train: (375, 1)
Shape y_train: (375, 1)
Shape X_test: (125, 1)
Shape y_test: (125, 1)

线性回归分类 源码编译

 class LinearRegression:def __init__(self):passdef train_gradient_descent(self, X, y, learning_rate=0.01, n_iters=100):"""Trains a linear regression model using gradient descent"""# Step 0: Initialize the parametersn_samples, n_features = X.shapeself.weights = np.zeros(shape=(n_features,1))self.bias = 0costs = []for i in range(n_iters):# Step 1: Compute a linear combination of the input features and weightsy_predict = np.dot(X, self.weights) + self.bias# Step 2: Compute cost over training setcost = (1 / n_samples) * np.sum((y_predict - y)**2)costs.append(cost)if i % 100 == 0:print(f"Cost at iteration {i}: {cost}")# Step 3: Compute the gradientsdJ_dw = (2 / n_samples) * np.dot(X.T, (y_predict - y))dJ_db = (2 / n_samples) * np.sum((y_predict - y)) # Step 4: Update the parametersself.weights = self.weights - learning_rate * dJ_dwself.bias = self.bias - learning_rate * dJ_dbreturn self.weights, self.bias, costsdef train_normal_equation(self, X, y):"""Trains a linear regression model using the normal equation"""self.weights = np.dot(np.dot(np.linalg.inv(np.dot(X.T, X)), X.T), y)self.bias = 0return self.weights, self.biasdef predict(self, X):return np.dot(X, self.weights) + self.bias

使用梯度下降进行训练

regressor = LinearRegression()
w_trained, b_trained, costs = regressor.train_gradient_descent(X_train, y_train, learning_rate=0.005, n_iters=600)
fig = plt.figure(figsize=(8,6))
plt.plot(np.arange(600), costs)
plt.title("Development of cost during training")
plt.xlabel("Number of iterations")
plt.ylabel("Cost")
plt.show()
Cost at iteration 0: 66.45256981003433
Cost at iteration 100: 2.208434614609594
Cost at iteration 200: 1.2797812854182806
Cost at iteration 300: 1.2042189195356685
Cost at iteration 400: 1.1564867816573
Cost at iteration 500: 1.121391041394467Text(0,0.5,'Cost')

png

测试(梯度下降模型)

n_samples, _ = X_train.shape
n_samples_test, _ = X_test.shapey_p_train = regressor.predict(X_train)
y_p_test = regressor.predict(X_test)error_train =  (1 / n_samples) * np.sum((y_p_train - y_train) ** 2)
error_test =  (1 / n_samples_test) * np.sum((y_p_test - y_test) ** 2)print(f"Error on training set: {np.round(error_train, 4)}")
print(f"Error on test set: {np.round(error_test)}")
Error on training set: 1.0955
Error on test set: 1.0

使用正规方程(normal equation)训练

X_b_train = np.c_[np.ones((n_samples)), X_train]
X_b_test = np.c_[np.ones((n_samples_test)), X_test]reg_normal = LinearRegression()
w_trained = reg_normal.train_normal_equation(X_b_train, y_train)

测试(正规方程模型)

y_p_train = reg_normal.predict(X_b_train)
y_p_test = reg_normal.predict(X_b_test)error_train =  (1 / n_samples) * np.sum((y_p_train - y_train) ** 2)
error_test =  (1 / n_samples_test) * np.sum((y_p_test - y_test) ** 2)print(f"Error on training set: {np.round(error_train, 4)}")
print(f"Error on test set: {np.round(error_test, 4)}")
Error on training set: 1.0228
Error on test set: 1.0432

可视化测试预测

fig = plt.figure(figsize=(8,6))
plt.scatter(X_train, y_train)
plt.scatter(X_test, y_p_test)
plt.xlabel("First feature")
plt.ylabel("Second feature")
plt.show()
Text(0,0.5,'Second feature')

png

转载注明出处:
http://ihoge.cn/2018/Logistic-regression.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/293086.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux之more命令

more命令,功能类似 cat ,cat命令是整个文件的内容从上到下显示在屏幕上。 more会以一页一页的显示方便使用者逐页阅读,而最基本的指令就是按空白键(space)就往下一页显示,按 b 键就会往回(back&…

DateOnly和TimeOnly类型居然不能序列化!!! .Net 6下实现自定义JSON序列化

前言.Net 6引入了DateOnly和TimeOnly结构,可以存储日期和时间。但在实际使用时,发现一个很尴尬的问题,DateOnly和TimeOnly居然不能被序列化:var builder WebApplication.CreateBuilder(args);var app builder.Build();app.MapGe…

PHP面向对象之旅:抽象类继承抽象类(转)

可以理解为对抽象类的扩展 抽象类继承另外一个抽象类时,不用重写其中的抽象方法。抽象类中,不能重写抽象父类的抽象方法。这样的用法,可以理解为对抽象类的扩展。 下面的例子,演示了一个抽象类继承自另外一个抽象类时,…

Scala编程指南

1.scala简介 2004年,martin ordersky发明,javac的编译器,后来spark,kafka应用广泛,twitter应用推广。它具备面向对象和函数式编程的特点。 官网:www.scala-lang.org,最近版本2.12.5,我们用的是2.10.42.环境…

win7 64位下如何安装配置mysql-5.7.4-m14-winx64(安装记录)

1. mysql-5.7.4-m14-winx64.zip下载 官方网站下载地址:http://dev.mysql.com/get/Downloads/MySQL-5.6/mysql-5.6.17-winx64.zip 2、解压到D:\mysql.(路径自己指定) 3、在D:\mysql\mysql-5.7.4-m14-winx64下新建my.ini配置文件 内容如下&am…

Android之ndk之gdb调试

https://code.google.com/p/android/issues/detail?id152832

使用插件创建 .NET Core 应用程序

使用插件创建 .NET Core 应用程序本教程展示了如何创建自定义的 AssemblyLoadContext 来加载插件。AssemblyDependencyResolver 用于解析插件的依赖项。该教程正确地将插件依赖项与主机应用程序隔离开来。将了解如何执行以下操作:构建支持插件的项目。创建自定义…

支持向量机SVC

原文: http://ihoge.cn/2018/SVWSVC.html 支持向量机(support vector machine)是一种分类算法,但是也可以做回归,根据输入的数据不同可做不同的模型(若输入标签为连续值则做回归,若输入标签为分类值则用SVC()做分类&…

Shell 控制并发

方法1: #!/bin/bash c0 for i in seq -w 18 31;dowhile [ $c -ge 3 ];doc$(jobs -p |wc -w)sleep 1sdonebash run_cal_us_tmp.sh 201407$i &#echo "sleep 5shaha" &c$(jobs -p |wc -w) done优点:实现简单 缺点:若sleep 时间较短&…

Android之严苛模式(StrictMode)

Android 2.3提供一个称为严苛模式(StrictMode)的调试特性,Google称该特性已经使数百个Android上的Google应用程序受益。那它都做什么呢?它将报告与线程及虚拟机相关的策略违例。一旦检测到策略违例(policy violation&a…

Shell 脚本——测试命令

********************************************一、测试命令简介二、测试结构三、整数比较运算符四、字符串运算符五、文件操作符六、逻辑运算符********************************************一、测试命令简介Shell中存在一组测试命令,该组测试命令用于测试某种条件…

如何通过 C# 将文本变为声音 ?

咨询区 user2110292我的项目有一个需求需要将可以将 文本 转化为 声音,请问大家是否有开源的 C# 库 来解决这件事情?回答区 HABJAN最近 Google 发布了一个开源的 Google Cloud Text To Speech 包,.NET版本的github链接:https://gi…

sklearn集合算法预测泰坦尼克号幸存者

原文: http://ihoge.cn/2018/sklearn-ensemble.html 随机森林分类预测泰坦尼尼克号幸存者 import pandas as pd import numpy as npdef read_dataset(fname):data pd.read_csv(fname, index_col0)data.drop([Name, Ticket, Cabin], axis1, inplaceTrue)lables …

Oracle数据库案例整理-Oracle系统执行时故障-Shared Pool内存不足导致数据库响应缓慢...

1.1 现象描写叙述 数据库节点响应缓慢,部分用户业务受到影响。 查看数据库告警日志,開始显示ORA-07445错误,然后是大量的ORA-04031错误和ORA-00600错误。 检查数据库日志,数据库仍处于活动状态的信息例如以下: S…

C#读写txt文件的两种方法介绍

1.添加命名空间 System.IO; System.Text; 2.文件的读取 (1).使用FileStream类进行文件的读取,并将它转换成char数组,然后输出。 byte[] byData new byte[100];char[] charData new char[1000];public void Read(){try{FileStream file new FileStream…

Beetlex官网迁移完成

由于beetlex.io域名无法指向国内,使用国内的服务器很多时候有抽风情况出现,所以把网站迁回国内;新的域名也申请完成并且申请备案通过,现在可以通过https://beetlex-io.com来访问Beetlex的官网.接下把涉及的费用和部署情况也说一下…

linux之tail 命令

tail 命令从指定点开始将文件写到标准输出.使用tail命令的-f选项可以方便的查阅正在改变的日志文件,tail -f filename会把filename里最尾部的内容显示在屏幕上,并且不但刷新,使你看到最新的文件内容. 1.命令格式; tail[必要参数][选择参数][文件] 2.…

SVM支持向量机绘图

原文: http://ihoge.cn/2018/SVM绘图.html %matplotlib inline import matplotlib.pyplot as plt import numpy as np class1 np.array([[1, 1], [1, 3], [2, 1], [1, 2], [2, 2]]) class2 np.array([[4, 4], [5, 5], [5, 4], [5, 3], [4, 5], [6, 4]]) plt.f…

Hibernate学习——建立一个简单的Hibernate项目

最近老师让做个web小应用,大三的时候学习过一点J2EE的东西,也做过一些web相关的XXX管理系统,都是用servlet,jsp这些完成的,虽然勉强能够完成任务,但其中各种代码掺杂在一起,不好看而且维护起来也…