机器学习的练功方式(十一)——逻辑回归

文章目录

    • 致谢
  • 11 逻辑回归
    • 11.1 引入
    • 11.2 激活函数
    • 11.3 损失函数
    • 11.4 梯度下降
    • 11.5 案例:癌症分类预测

致谢

逻辑回归为什么用Sigmoid? - 知乎 (zhihu.com)

逻辑回归中的损失函数的解释_yidiLi的博客-CSDN博客_逻辑回归损失函数

11 逻辑回归

逻辑回归也被称为逻辑斯蒂回归(Logistic Regression),虽被称为回归,但是其实际上是统计学习中经典的分类方法。

逻辑回归常常被用于二分类问题,比较常见的有:

  1. 判断一封电子邮件是否是垃圾邮件;
  2. 判断一次金融交易是否是欺诈;
  3. 区别一个肿瘤是恶性的还是良性的

我们将因变量可能属于的两个类分别称为负向类和正向类,则因变量y∈0,1。其中0表示负向类,1表示正向类。实际上,哪个是正向类哪个是负向类其实并不重要。虽然没有特别界限,但是我们通常把0归结为没有,而1归结为具有我们要寻找的东西。

11.1 引入

对于二分类问题,y取值非0即1;但如果你使用的是线性回归,那么其输出的y^\hat yy^可能远远大于1或者为负数。所以如果我们用线性回归来做分类,效果实际上是很差的。

逻辑回归是怎么把输出变成0到1之间的呢?实际上,它拿线性回归的输出作为输入,然后用函数g(x)=11+e−xg(x) = \frac {1}{1+e^{-x}}g(x)=1+ex1输出,其输出值可以把线性回归的结果投射到0和1的区间上。

11.2 激活函数

上面我们说到将线性回归的输出输入到一个函数上再输出,这样能够使结果投射到0-1的区间上。在深度学习中,我们常常把这类可以将线性回归的结果投射到某个区间函数叫做激活函数,而上面提到的我们通常称为sigmoid激活函数,在深度学习中,我们会陆续见到其他的激活函数。

实际上,为何逻辑回归要用sigmoid函数来投射线性回归的结果是有原理的,如果感兴趣可以去网上搜索相关推导,如果我不懒后续也会手推,但是你现在需要知道的一点是——当因变量服从伯努利分布(0-1分布)时,广义线性模型就为逻辑回归。

实际上,逻辑回归用于二分类问题,而其实际上是softmax回归的特殊情况,softmax回归常常用来解决多分类问题,在后续深度学习的文章中,我们会持续了解它们。

让我们看一下sigmoid激活函数长什么样吧!

import matplotlib.pyplot as plt
import math
import numpy as np
import matplotlibmatplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False# 1 准备数据
def func(x):y = 1 / (1 + math.exp(-x))return yX = np.linspace(-100, 100, 200)
Y = [func(x) for x in X]
# 2 创建画布
plt.figure()plt.grid(True, linestyle='--', alpha=0.5)# 3 画图
plt.plot(X, Y)# 4 设置标题
plt.title("sigmoid函数图像")# 5 释放图像
plt.show()

out:

image-20220321125705319

实际使用中,我们默认分类的界限是0.5,也就是说经由sigmoid函数投射值如果大于0.5则直接分为1类,否则分为0类。

11.3 损失函数

记得我们前面谈到线性回归的损失函数吗,由于线性回归属于回归问题,所以输出是某一个预测值,我们拿预测值和真实值进行比对来衡量其误差,这个比对方式,我们用的是平方损失函数。但是在逻辑回归中,我们并不能继续这么做了,因为逻辑回归非0即1,你还拿平方损失函数比对其误差,如果没有误差,那不就是0吗,这还搞啥子,我们后面还要采用梯度下降呢,参数不动这模型还怎么优化。

为此,我们急需有一个新的损失函数来取代之前的平方损失函数。在逻辑回归中,我们常常使用对数似然损失函数来衡量损失。其公式为:
cost(h0(x),y)={−log⁡(h0(x))ify=1−log⁡(1−h0(x))ify=0cost(h_0(x),y) = \left\{ \begin{aligned} -\log(h_0(x)) && if &&y = 1\\ -\log(1-h_0(x)) && if && y = 0 \end{aligned} \right. cost(h0(x),y)={log(h0(x))log(1h0(x))ifify=1y=0
当y = 1时,我们可以观察其函数图像,如下所示:

image-20220321132008218

明显地,如果h0(x)h_0(x)h0(x)越接近1,而我们y = 1,那么说明我们分类地越准确。

同理,当y = 0时,我们也可以观察其图像,如下所示:

image-20220321132318970

h0(x)h_0(x)h0(x)越接近于1,而我们y = 0,说明我们分类的很差,损失很大。

为了后续计算梯度下降方便,我们必须对上述的分段函数做一个简化。为此,我们将其形式改写为:
cost(h0(x),y)=∑i=1m−yilog(h0(x))−(1−yi)log(1−h0(x))cost(h_0(x),y) = \sum^m_{i = 1}-y_ilog(h_0(x))-(1-y_i)log(1-h_0(x)) cost(h0(x),y)=i=1myilog(h0(x))(1yi)log(1h0(x))
这个公式的好处在于,当你的y取1,那么−yilog(h0(x))-y_ilog(h_0(x))yilog(h0(x))会保留,后面一部分会为0,这样就等同于上面的分段函数,当你的y取0,同理,−yilog(h0(x))-y_ilog(h_0(x))yilog(h0(x))会为0,后面一部分保留。

11.4 梯度下降

同样地,我们在打基础阶段还是照样使用梯度下降来优化损失函数,从而找到逻辑回归中的权重参数。

除了梯度下降算法以外,还有一些常被用来令代价函数最小的算法,这些算法更加复杂和优越,而且通常不需要人工选择学习率,通常比梯度下降算法要更加快速。这些算法有:共轭梯度(Conjugate Gradient)局部优化法(Broyden fletcher goldfarb shann,BFGS)有限内存局部优化法(LBFGS)

11.5 案例:癌症分类预测

让我们来看看sklearn为我们提供的API。

sklearn.linear_model.LogisticRegression(solver = ‘liblinear’,penalty = ‘l2’,C = 1.0)

  • solver:优化求解方式。默认使用开源的liblinear库实现优化,内部使用了坐标轴下降法来迭代优化损失函数
  • penalty:正则化种类
  • C:正则化力度

在之前的学习中,我们一直用着sklearn自带的数据集。是的,它们是好用,这仅仅只是因为人家帮你预处理好了,他在帮你偷懒!所以,我想你是时候学着处理一些东西了。

我们先去UCI把我们要用到的数据集下载下来,或者你也可以不下载,直接利用pandas读取网站即可:索引 /ml/机器学习数据库/乳腺癌-威斯康星州 (uci.edu)

其中数据放于http://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/breast-cancer-wisconsin.data

而数据介绍放于http://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/breast-cancer-wisconsin.names

我们先来读取数据。由于原始数据是没有标签的,所以我们顺便给数据打上标签。

import pandas as pd
import numpy as np# 1 读取数据
path = "http://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/breast-cancer-wisconsin.data"
colum_name = ['Sample code number', 'Clump Thickness','Uniformity of Cell Size', 'Uniformity of Cell Shape','Marginal Adhesion', 'Single Epithelial Cell Size', 'Bare Nuclei','Bland Chromatin', 'Normal Nucleoli', 'Mitoses', 'Class']
data = pd.read_csv(path, names=colum_name)

从原数据集上看,数据中有?,我们在数据预处理中通常先将其替换为NaN,然后进行删除。如果你喜欢填充往?部分填充均值也是可以的。

# 2 缺失值处理
data = data.replace(to_replace = "?",value = np.nan)
data.dropna(inplace = True)

处理完成后,我们可以通过直接查看数据看看是否处理成功,也可以通过ifnull()函数查看是否还有空值。

pd.set_option('display.max_rows', None) # 显示所有行
data
data.isnull().any()

看到没有空值后,我们下一步要对原始数据集进行特征和分类标签的分离。

# 3 划分数据集
from sklearn.model_selection import train_test_split# 筛选特征值和目标值
x = data.iloc[:,1:-1]
y = data["Class"]
x.head()

然后进行切割数据集。

# 划分数据集
x_train,x_test,y_train,y_test = train_test_split(x,y)
x_train.head()

逻辑回归本质上属于广义线性模型,所以我们还要进行一下标准化。

# 4 标准化
from sklearn.preprocessing import StandardScaler
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)
x_train

标准化完成后,我们调用逻辑回归预估器进行模型的训练:

# 5 逻辑回归
from sklearn.linear_model import LogisticRegression
estimator = LogisticRegression()
estimator.fit(x_train,y_train)

我们可以看一下训练后的回归系数和偏置。

# 回归系数和偏置
estimator.coef_
estimator.intercept_

最后我们对模型进行评估。

# 6 模型评估
y_predict = estimator.predict(x_test)
print("y_predict:\n",y_predict)
print("直接比对真实值和与预测值:\n",y_test == y_predict)score = estimator.score(x_test,y_test)
print("准确率为:\n",score)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/398617.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ODB——基于c++的ORM映射框架尝试(安装)

这篇博客应该是和之前的重拾cgi一起的。当时为了模仿java的web框架,从页面的模板,到数据库的ORM,都找个对应的库来进行尝试。数据库用的就是ODB,官方网站是http://www.codesynthesis.com/products/odb/。 1、安装 odb是直接提供源…

【百度地图API】如何制作一张魔兽地图!!——CS地图也可以,哈哈哈

【百度地图API】如何制作一张魔兽地图!!——CS地图也可以,哈哈哈 原文:【百度地图API】如何制作一张魔兽地图!!——CS地图也可以,哈哈哈摘要: 你玩魔兽不?你知道如何做一张魔兽地图不…

Linux OpenGL 实践篇-2 创建一个窗口

OpenGL 作为一个图形接口,并没有包含窗口的相关内容,但OpenGL使用必须依赖窗口,即必须在窗口中绘制。这就要求我们必须了解一种窗口系统,但不同的操作系统提供的创建窗口的API都不相同,如果我们在学习OpenGL时要去学习…

C++从0到1的入门级教学(一)——C++初识

文章目录1 C初识1.1 入门1.1.1 简介1.1.2 输入和输出1.1.3 头文件名1.1.5 名称空间1.2 注释1.3 变量1.4 常量1.4.1 C定义常量两种方式1.5 关键字1.6 标识符命名规则1 C初识 1.1 入门 1.1.1 简介 既然是第一次学习,我们就使用大家初学任何编程语言都会用的"h…

linux系统分两种更普遍的包,rpm和tar,这两种安装包如何解压与安装

2019独角兽企业重金招聘Python工程师标准>>> RPM软件包管理器&#xff1a;一种用于互联网下载包的打包及安装工具&#xff0c;它包含在某些Linux分发版中。它生成具有.RPM扩展名的文件。rpm -ivh xxxx.rpm <-安装rpm包 -i install的意思 -v view 查看更详细的…

C++类的数组元素查找最大值问题

找出一个整型数组中的元素的最大值。 1 /*找出一个整型数组中的元素的最大值。*/2 3 #include <iostream>4 using namespace std;5 6 class ArrayMax //创建一个类7 {8 public :9 void set_value(); 10 void max_value(); 11 void sh…

ABNFBNF 巴克斯范式

https://www.cnblogs.com/qook/p/5957436.html转载于:https://www.cnblogs.com/ArcherHuang/p/8479897.html

C++从0到1的入门级教学(二)——数据类型

文章目录2 数据类型2.1 简单变量2.2 基本数据类型2.2.1 整型2.2.2 实型&#xff08;浮点型&#xff09;2.2.3 字符型2.2.4 布尔类型2.3 sizeof关键字2.4 类型转换2.5 转义字符2.6 重新谈及变量2.6.1 字面值常量2.6.2 变量2.6.3 列表初始化2.7 数据的输入2 数据类型 2.1 简单变…

大数乘法

很久没手写过大数运算了&#xff0c;以前也都是直接贴模板的&#xff0c;今晚的模拟笔试最后一道大数乘法就没调好&#xff0c;gg…… #include <iostream> #include <string> #include <cstdio> #include <cstring> using namespace std;string num1,…

获取class的名字

ele str.get_attribute(class)&#xff08;获取class的名字&#xff09;转载于:https://www.cnblogs.com/zero-77/p/8482362.html

为什么下了android 4.1 的SDK后在本地用浏览器看api说明文档时,浏览器打开api的html文件很慢?试了好几款浏览器都一样。为什么?...

http://www.oschina.net/question/436724_61401 http://www.google.com/jsapi 他惹的祸 注释掉就可以了&#xff5e; <!-- <script src"http://www.google.com/jsapi" type"text/javascript"></script> --> 很多页面都有&#xff0c;…

深度学习修炼(三)——自动求导机制

文章目录致谢3 自动求导机制3.1 传播机制与计算图3.1.1 前向传播3.1.2 反向传播3.2 自动求导3.3 再来做一次3.4 线性回归3.4.1 回归3.4.2 线性回归的基本元素3.4.3 线性模型3.4.4 线性回归的实现3.4.4.1 获取数据集3.4.4.2 模型搭建3.4.4.3 损失函数3.4.4.4 训练模型3.5 后记致…

5、android使用意图传递数据之全局变量传递

实例&#xff1a; 1、layout的代码 activity_main.xml     <LinearLayout xmlns:android"http://schemas.android.com/apk/res/android" xmlns:tools"http://schemas.android.com/tools" android:layout_width"match_parent" android:lay…

安装rf所需要的库

1. RF 在两个Python中安装 robotframework 执行命令 pip install robotframework 2. seleniumlibrary 在两个Python中安装 seleniumlibrary 执行命令 pip install --pre --upgrade robotframework-seleniumlibrary 3. RIDE 在Python2中安装 RIDE 执行命令 pip install robot…

CNN for Visual Recognition (assignment1_Q1)

参考&#xff1a;http://cs231n.github.io/assignment1/ Q1: k-Nearest Neighbor classifier (30 points) 1 import numpy as np2 from matplotlib.cbook import todate3 4 class KNearestNeighbor:5 """ a kNN classifier with L2 distance ""&quo…

深度学习修炼(四)——补充知识

文章目录致谢4 补充知识4.1 微积分4.1.1 导数和微分4.1.2 偏导数4.1.3 梯度4.1.4 链式求导4.2 Hub模块致谢 导数与微分到底有什么区别&#xff1f; - 知乎 (zhihu.com) 4 补充知识 在这一小节的学习中&#xff0c;我们会对上一小节的知识点做一个补充&#xff0c;并且拓展一个…

java使用POI jar包读写xls文件

主要使用poi jar来操作excel文件。代码中用到的数据库表信息见ORACLE之表。使用public ArrayList<Person> getPersonAllRecords()获得所有的记录。 1 public class PersonXLS {2 3 public static void main(String[] args) throws IOException {4 5 …

U-boot 打补丁,编译,设置环境变量,

&#xff08;1&#xff09;U-boot 的最终目的是&#xff1a; 启动内核 U-boot 从Flash上读取内核&#xff0c;把内核放到SDRAM上&#xff0c;运行内核 设置环境变量 print  显示出环境变量 set bootdelay 10 save reset  重启转载于:https://www.cnblogs.com/bkyysd/p/42…

深度学习修炼(五)——基于pytorch神经网络模型进行气温预测

文章目录5 基于pytorch神经网络模型进行气温预测5.1 实现前的知识补充5.1.1 神经网络的表示5.1.2 隐藏层5.1.3 线性模型出错5.1.4 在网络中加入隐藏层5.1.5 激活函数5.1.6 小批量随机梯度下降5.2 实现的过程5.2.1 预处理5.2.2 搭建网络模型5.3 简化实现5.4 评估模型5 基于pytor…

Android 应用程序集成FaceBook 登录及二次封装

1、首先在Facebook 开发者平台注册一个账号 https://developers.facebook.com/ 开发者后台 https://developers.facebook.com/apps 2、创建账号并且获得 APP ID 图一 图二 图三 图四 图五 3、获取app签名的Key Hashes 值&#xff08;两种方式&#xff09; 3.1方法1&#xff1…