[PyTorch][chapter 6][李宏毅深度学习][Logistic Regression]

前言:

         logistic回归又称logistic回归分析,是一种广义的线性回归分析模型,常用于数据挖掘,疾病自动诊断,经济预测等领域。 逻辑回归根据给定的自变量数据集来估计事件的发生概率,由于结果是一个概率,因此因变量的范围在 0 和 1 之间。 [3]例如,探讨引发疾病的危险因素,并根据危险因素预测疾病发生的概率等。

         训练样本特别小的时候用 Generative  Model会有较好的效果,大的样本使用Discriminative Model,Discriminative Model里面常用的二分类模型sigmoid ,多分类模型softmax


sigmoid 简介(Discriminative Model

    二分类模型

    1.1   模型定义

           使用了sigmoid 函数作为激活函数

           f(x)=\sigma(z)=\frac{1}{1+e^{-z}}

           z=wx+b=\sum_i w_ix_i+b

           输出 (0,1)

    1.2  损失函数

           假设有N个二分类样本

          

           \left\{\begin{matrix} \hat{y}=1\, \, \, \, ,if \, c_1 \\ \hat{y}=0\, \, \, \, \, , if \, c_2 \end{matrix}\right.

            损失函数定义为

            L(w,b)=f(x^1)f(x^2)(1-f(x^3))..

           我们要找到参数w,b使得上面概率最大

            w^{*},b^{*}=argmax_{w,b}L(w,b)

            根据交叉熵原理:我们对式子取对数。因为是求式子的最大值,可以转换成式子乘以负1,之后求最小值

            w^*,b^*=argmin_{w,b} -lnL(w,b)

            L(w,b)=-\sum_{i}^{N}-\begin{Bmatrix} \hat{y^i}ln f(x^i)+(1-\hat{y^i})ln (1-f(x^i)) \end{Bmatrix}

    1.3 梯度

           对w的求导分为两部分

          \frac{\partial lnf}{\partial f}\frac{\partial f}{\partial z}\frac{\partial z}{\partial w}=\frac{1}{f}f(1-f)x

                               =(1-f)x

           \frac{\partial ln1-f}{\partial f}\frac{\partial f}{\partial z}\frac{\partial z}{\partial w}=\frac{-1}{1-f}f(1-f)x

                                   =-fx

          合并起来

               \frac{\partial L}{\partial w}=\sum_i -\begin{Bmatrix} \hat{y^{i}}(1-f)x^{i}-(1-\hat{y^{i}})fx^{i} \end{Bmatrix}

                     =\sum_{i}-(\hat{y^{i}}-f)x^{i}

                    =\sum_{i}(f-\hat{y^{i}})x^{i}

     1.4 跟Linear 区别



二  Multi-class Classification(softmax)

      多分类模型

    2.1  模型定义

          使用了 softmax 作为激活函数 

           y=\sigma(z_i)=\frac{e^{z_i}}{\sum_{j=1}^{K}e^{z_j}}  

 

  2.2 损失函数

            使cross Entropy

             标签是一个one-hot 向量,非零项代表其类别

            L_{w,b}(\hat{y},y)=\sum_{i=1}^{K}\hat{y_i}logy_i

 

2.3 梯度

       

          y_i=softmax(z_i)=\frac{e^{z_i}}{\sum_j e^{z_j}}

         \left\{\begin{matrix} \frac{\partial y_i}{\partial z_j}=y_i*(1-y_j),j=i\\ \frac{\partial y_i}{\partial z_j}=-y_iy_j ,j\neq i \end{matrix}\right.

     损失函数为

       L=-\sum_{i}^{K}\hat{y_k}logy_k

      只跟其中的非零项有关系,假设非零项为y_i

          \frac{\partial L}{\partial z_j}=\left\{\begin{matrix} \frac{-1}{y_i}*y_i*(1-y_i)=y_i-1,j=i\\ \frac{-1}{y_i}*(-y_i y_j)=y_j-0,j\neq i \end{matrix}\right.

           因为标签值是one-hot

             \left\{\begin{matrix} \hat{y_j}=1,i=j\\ \hat{y_j}=0,i \neq j \end{matrix}\right.

             所以

              \frac{\partial L}{\partial z_j}=y_j-\hat{y_j}


三 代码

  

任务:

         给定的个人资料,预测此人的年收入是否大于50k

数据集说明:
                共有32561训练集数据,16281 测试集数据
(8140 in private test set and 8141 in public test set)

数据集情况:共14个feature 

 ?  代表不确定性
1 age 年龄: continuous.
2 workclass 工作性质: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
3 fnlwgt: continuous. *The number of people the census takers believe that observation represents.人口普查员认为这一观察结果所代表的人数。

4 education 教育水平: 
   Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.

5 education-num: continuous.

6 marital-status 婚姻状况: 
    Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.

7 occupation 工作: 
   Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.

8 relationship 关系: 
    Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.

9 race 种族: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
10 sex 性别: Female, Male.
11 capital-gain 资本收益: continuous.
12 capital-loss资本损失: continuous.
13 hours-per-week 每周工作时长: continuous.
14  native-country原国际: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.

 针对非数值型的属性,采用了one-hot 编码

分为两个文件:

dataLoader.py: csv文件读取,特征工程

lr.py:  模型训练  y=xw

          其中

                   x=[x,1]增广矩阵,

                   w =[b,w]增广矩阵

# -*- coding: utf-8 -*-
"""
Created on Tue Dec 12 14:51:45 2023@author: chengxf2
"""import numpy as np
import pandas as pd
from random import shuffle
from math import floor, logdef sample(X, Y):                                 #X and Y are np.arrayrandomize = np.arange(X.shape[0])np.random.shuffle(randomize)return (X[randomize], Y[randomize])def split_valid_set(X, Y, percentage):m = X.shape[0]valid_size = int(floor(m * percentage))X, Y = sample(X, Y)X_valid, Y_valid = X[ : valid_size], Y[ : valid_size]X_train, Y_train = X[valid_size:], Y[valid_size:]return X_train, Y_train, X_valid, Y_validdef dataProcess_Y(rawData):df_y = rawData['income']y = pd.DataFrame((df_y==' >50K').astype("int64"), columns=["income"])print('\n y',y.shape)return ydef dataProcess_X(rawData):#axis=1, 删除列 axis=0 删除 indexif "income" in rawData.columns:Data = rawData.drop(["sex", 'income'], axis=1)#(32561, 13) else:Data = rawData.drop(["sex"], axis=1)#读取非数字的columnlistObjectColumn = [col for col in Data.columns if Data[col].dtypes == "object"] #数字的columnlistNonObjedtColumn = [x for x in list(Data) if x not in listObjectColumn] ObjectData = Data[listObjectColumn]NonObjectData = Data[listNonObjedtColumn]#insert set into nonobject data with male = 0 and female = 1NonObjectData.insert(0 ,"sex", (rawData["sex"] == " Female").astype(int))#set every element in object rows as an attribute,相当于one-hot 编码ObjectData = pd.get_dummies(ObjectData)Data = pd.concat([NonObjectData, ObjectData], axis=1)Data_x = Data.astype("int64")# Data_y = (rawData["income"] == " <=50K").astype(np.int)print("\n data_x: ",Data_x.shape)#normalizeData_x = (Data_x - Data_x.mean()) / Data_x.std()return Data_xdef data_loader():trainData =  pd.read_csv("data/train.csv")testData =  pd.read_csv("data/test.csv")test_label = pd.read_csv("data/correct_answer.csv")# here is one more attribute in trainDatax_train = dataProcess_X(trainData).drop(['native_country_ Holand-Netherlands'], axis=1).valuesx_test = dataProcess_X(testData).valuesy_train = dataProcess_Y(trainData).valuesy_test =  test_label['label'].values#x=>x[1,x]x_train = np.concatenate((np.ones((x_train.shape[0], 1)), x_train), axis=1)x_test = np.concatenate((np.ones((x_test.shape[0], 1)), x_test), axis=1)valid_set_percentage = 0.1X_train, Y_train, X_valid, Y_valid = split_valid_set(x_train, y_train, valid_set_percentage)return X_train, Y_train, X_valid, Y_valid ,x_test,y_test

import numpy as npfrom numpy.linalg import inv
import matplotlib.pyplot as plt
from dataLoader import data_loader
from dataLoader import sample
import os
from math import floor, log
import pandas as pdoutput_dir = "output/"def sigmoid(z):res = 1 / (1.0 + np.exp(-z))return np.clip(res, 1e-8, (1-(1e-8)))def valid(X, Y, w):a = np.dot(w,X.T)y = sigmoid(a)y_ = np.around(y)result = (np.squeeze(Y) == y_)print('Valid acc = %f' % (float(result.sum()) / result.shape[0]))return y_def train(X_train, Y_train):n= len(X_train[0])print("\n n ",n)w = np.zeros(n)l_rate = 0.001batch_size = 32m = len(X_train)step_num = int(floor(m / batch_size))epoch_num = 30list_cost = []total_loss = 0.0for epoch in range(1, epoch_num):total_loss = 0.0X_train, Y_train = sample(X_train, Y_train)for idx in range(1, step_num):X = X_train[idx*batch_size:(idx+1)*batch_size]Y = Y_train[idx*batch_size:(idx+1)*batch_size]s_grad = np.zeros(len(X[0]))z = np.dot(X, w)y = sigmoid(z)#squeeze 即把shape中为1的维度去掉loss = y - np.squeeze(Y)cross_entropy = -1 * (np.dot(np.squeeze(Y.T), np.log(y)) + np.dot((1 - np.squeeze(Y.T)), np.log(1 - y)))/ len(Y)total_loss += cross_entropygrad = np.sum( X * (y-np.squeeze(Y)).reshape((batch_size, 1)), axis=0)# grad = np.dot(X.T, loss)w = w - l_rate * grad#print("\n epoch :%d, total_loss: %7.3f"%(epoch, total_loss/batch_size))list_cost.append(total_loss)# valid(X_valid, Y_valid, w)plt.plot(np.arange(len(list_cost)), list_cost)plt.title("Train Process")plt.xlabel("epoch_num")plt.ylabel("Cost Function (Cross Entropy)")plt.savefig(os.path.join(os.path.dirname(output_dir), "TrainProcess"))plt.show()return wif __name__ == "__main__":X_train, Y_train, X_valid, Y_valid,x_test,y_test  = data_loader()w_train = train(X_train, Y_train)valid(X_valid, Y_valid, w_train)print("\n x_test",x_test.shape, "\t y_test ",y_test.shape,"\t w",w_train.shape)valid(x_test, y_test, w_train)df = pd.DataFrame({"id": np.arange(1, 16282), "label": y_test})if not os.path.exists(output_dir):os.mkdir(output_dir)df.to_csv(os.path.join(output_dir + 'lr_output.csv'), sep='\t', index=False)

https://github.com/maplezzz/ML2017S_Hung-yi-Lee_HW
动手学深度学习——softmax回归(原理解释+代码详解)-CSDN博客

https://www.cnblogs.com/hider/p/15431858.html 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/216445.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【docker 】 安装docker(centOS7)

官网 docker官网 github源码 官网 在CentOS上安装Docker引擎 官网 在Debian上安装Docker引擎 官网 在 Fedora上安装Docker引擎 官网 在ubuntu上安装Docker引擎 官网 在RHEL (s390x)上安装Docker引擎 官网 在SLES上安装Docker引擎 最完善的资料都在官网。 卸载旧版本 …

异地现场工控设备,如何实现远程配置、调试?

南京某企业专注于工业物联领域&#xff0c;在相关项目中往往会在各个点位部署基于Linux系统的中控主机&#xff0c;实现各类物联设备信息的采集、汇总。但是&#xff0c;由于各点位分散多地&#xff0c;且数量达到了上百个&#xff0c;虽然中控主机具备4G物联网接入能力&#x…

Vue3-07-样式绑定-style绑定的写法总结

style 绑定的方式 1.html中直接一个属性一个属性的写&#xff1b; 2.直接绑定一个对象&#xff1b; 3.绑定一个包含多个样式对象的数组。style绑定样式的注意点 推荐使用 驼峰命名 规则来编写样式的名称&#xff0c;如 &#xff1a; fontSize:12px; 如果使用 中线分割的规则时…

医美行业-上游厂商的营销规模分析与测算

一、医美行业整体发展趋势&#xff1a;轻医美逐步占领市场&#xff0c;占比逐年增加&#xff0c;规模增速远超手术类医美 从2019年开始医美行业扩张速度放缓&#xff0c;2020年受疫情影响中国美容用户的医美行为有所减少&#xff0c;增速放缓至9.9%&#xff0c;随着疫情的好转及…

计算机组成原理-堆栈寻址

文章目录 堆栈寻址软堆栈vs硬堆栈小结 堆栈寻址 栈结构后进后出 软堆栈vs硬堆栈 硬堆栈用寄存器实现 软堆栈就是用内存实现 小结 入栈和出栈即栈顶元素位置的变化不同

【计算机网络】滑动窗口 流量控制 拥塞控制 概念概述

参考资料&#xff1a;计算机网络第八版-视频课程

LabVIEW实时建模检测癌细胞的异常

LabVIEW实时建模检测癌细胞的异常 癌症是全球健康的主要挑战之一&#xff0c;每年导致许多人死亡。世界卫生组织指出&#xff0c;不健康的生活方式和日益严重的环境污染是癌症发生的主要原因之一。癌症的发生通常与基因突变有关&#xff0c;这些突变导致细胞失去正常的增长和分…

java--Map集合的遍历方式

1.Map集合的遍历方式之一&#xff1a;需要用的Map的如下方法 2.Map集合的遍历方式二&#xff1a;键值对 3.Map集合的遍历方式三&#xff1a;Lambda 需要用的Map的如下方法

mybatis-plus使用达梦数据库处理枚举类型报错的问题

使用mybatis-plus连接达梦数据库&#xff0c;枚举类型无法读取 枚举类&#xff1a; 实体&#xff1a; 数据库字段&#xff1a; mybatis-plus枚举包配置&#xff1a; 调用查询方法&#xff1a; List<QualityRuleTemplate> qualityRuleTemplates ruleTemplateServic…

可视化监控云平台/智能监控EasyCVR如何使用脚本创建ramdisk挂载并在ramdisk中临时运行

视频云存储/安防监控EasyCVR视频汇聚平台基于云边端智能协同&#xff0c;支持海量视频的轻量化接入与汇聚、转码与处理、全网智能分发、视频集中存储等。安防管理视频平台EasyCVR拓展性强&#xff0c;视频能力丰富&#xff0c;具体可实现视频监控直播、视频轮播、视频录像、云存…

Datawhale 12月组队学习 leetcode基础 day1 枚举

这是一个新的专栏&#xff0c;主要是一些算法的基础&#xff0c;对想要刷leedcode的同学会有一定的帮助&#xff0c;如果在算法学习中遇到了问题&#xff0c;也可以直接评论或者私信博主&#xff0c;一定倾囊相助 进入正题&#xff0c;今天咱们要说的枚举算法&#xff0c;这是个…

高压功率放大器的作用是什么

高压功率放大器是一种电子设备&#xff0c;其作用是将低电平的信号增强到高功率水平&#xff0c;以驱动要求高电压和电流的负载。它在各种应用中起着至关重要的作用&#xff0c;包括无线通信、医疗仪器、雷达系统、工业控制等领域。 高压功率放大器在无线通信中具有重要意义。在…

[common c/c++] 为什么使用 semaphore 的生产者消费者模型需要两个信号量

正文&#xff1a; 信号量没有触及上限则阻塞post的原语&#xff0c;同时信号量除了系统限制的信号量最大值之外并没有接口可以用来设置上限。因此在一个信号量场景下&#xff0c;生产者在 post 信号的时候是没有束缚的&#xff0c;如果不控制生产量的话&#xff0c;会导致系统…

小红书产品评测怎么做?商家必看

以小红书为代表的社交电商平台&#xff0c;产品评测成为了消费者决策的重要参考。一篇高质量的产品评测&#xff0c;不仅能够帮助消费者全面了解产品也能提升商家品牌的知名度和口碑。因此&#xff0c;小红书产品评测的重要性不言而喻。 本文旨在为商家提供一份详尽的小红书产…

基于Qt的Live2D模型显示以及控制

基本说明 Live2D官方提供有控制Live2D模型的SDK,而且还提供了一个基于OpenGL的C项目Example,我们可以基于该项目改成Qt的项目&#xff0c;做一个桌面端的Live2D桌宠程序。 官方例子 经过改造效果如下图所示。 官方项目配置 下载官方提供的SDK例程,&#xff0c;选择Cubism …

网上很火的记事软件有哪些?可以分类记事的工具选哪个

日常记事在生活及工作方面都是非常重要&#xff0c;选择好用的记事软件可以督促各项任务的按时完成&#xff0c;。随着科技的发展&#xff0c;越来越多的记事软件涌现出来&#xff0c;让人眼花缭乱。那么&#xff0c;网上很火的记事软件有哪些&#xff1f;可以分类记事的工具应…

TestCase与TransactionTestCase的区别

目录 一、概述 二、区别 1、事务管理方式 2、性能影响 3、适用场景 三、示例代码 TestCase示例代码 TransactionTestCase示例代码 四、总结 TestCase与TransactionTestCase是Django框架中两个重要的测试类&#xff0c;用于对数据库操作进行测试。在编写测试用例时&…

【PHP编程实战】手把手教你如何下载文件,实例代码详解!

本文将向大家详细介绍PHP文件下载实例代码&#xff0c;具有一定的参考价值。对于一个网站而言&#xff0c;文件下载功能几乎是必备的。因此&#xff0c;了解如何使用PHP实现文件下载是非常必要的。在接下来的内容中&#xff0c;我们将一起探讨PHP文件下载的实现方法。 无控制类…

版本控制:让你的代码有迹可循

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

软件测试之持续集成详解

什么是持续集成&#xff1f; 持续集成是一种 DevOps 软件开发实践。采用持续集成时&#xff0c;开发人员会定期将代码变更合并到一个中央存储库中&#xff0c;之后系统会自动运行构建和测试操作。持续集成通常是指软件发布流程的构建或集成阶段&#xff0c;需要用到自动化组件…