Pytorch高阶API示范——DNN二分类模型

代码部分:

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader,TensorDataset"""
准备数据
"""#正负样本数量
n_positive,n_negative = 2000,2000#生成正样本, 小圆环分布
r_p = 5.0 + torch.normal(0.0,1.0,size = [n_positive,1])
theta_p = 2*np.pi*torch.rand([n_positive,1])
Xp = torch.cat([r_p*torch.cos(theta_p),r_p*torch.sin(theta_p)],axis = 1)
Yp = torch.ones_like(r_p)#生成负样本, 大圆环分布
r_n = 8.0 + torch.normal(0.0,1.0,size = [n_negative,1])
theta_n = 2*np.pi*torch.rand([n_negative,1])
Xn = torch.cat([r_n*torch.cos(theta_n),r_n*torch.sin(theta_n)],axis = 1)
Yn = torch.zeros_like(r_n)#汇总样本
X = torch.cat([Xp,Xn],axis = 0)
Y = torch.cat([Yp,Yn],axis = 0)#可视化
plt.figure(figsize = (6,6))
plt.scatter(Xp[:,0],Xp[:,1],c = "r")
plt.scatter(Xn[:,0],Xn[:,1],c = "g")
plt.legend(["positive","negative"])plt.show()"""
#构建输入数据管道
"""
ds = TensorDataset(X,Y)
dl = DataLoader(ds,batch_size = 10,shuffle=True)"""
2, 定义模型
"""
class DNNModel(nn.Module):def __init__(self):super(DNNModel, self).__init__()self.fc1 = nn.Linear(2,4)self.fc2 = nn.Linear(4,8)self.fc3 = nn.Linear(8,1)def forward(self,x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))y = nn.Sigmoid()(self.fc3(x))return ydef loss_func(self,y_pred,y_true):return nn.BCELoss()(y_pred,y_true)def metric_func(self,y_pred,y_true):y_pred = torch.where(y_pred > 0.5, torch.ones_like(y_pred, dtype=torch.float32),torch.zeros_like(y_pred, dtype=torch.float32))acc = torch.mean(1 - torch.abs(y_true - y_pred))return acc@propertydef optimizer(self):return torch.optim.Adam(self.parameters(), lr=0.001)model = DNNModel()# 测试模型结构
(features,labels) = next(iter(dl))
predictions = model(features)loss = model.loss_func(predictions,labels)
metric = model.metric_func(predictions,labels)print("init loss:",loss.item())
print("init metric:",metric.item())"""
3,训练模型
"""
def train_step(model, features, labels):# 正向传播求损失predictions = model(features)loss = model.loss_func(predictions,labels)metric = model.metric_func(predictions,labels)# 反向传播求梯度loss.backward()# 更新模型参数model.optimizer.step()model.optimizer.zero_grad()return loss.item(),metric.item()# 测试train_step效果
features,labels = next(iter(dl))    #非for循环就用next
train_step(model,features,labels)def train_model(model,epochs):for epoch in range(1,epochs+1):loss_list,metric_list = [],[]for features, labels in dl:lossi,metrici = train_step(model,features,labels)loss_list.append(lossi)metric_list.append(metrici)loss = np.mean(loss_list)metric = np.mean(metric_list)if epoch%100==0:print("epoch =",epoch,"loss = ",loss,"metric = ",metric)train_model(model,epochs = 300)# 结果可视化
fig, (ax1,ax2) = plt.subplots(nrows=1,ncols=2,figsize = (12,5))
ax1.scatter(Xp[:,0],Xp[:,1], c="r")
ax1.scatter(Xn[:,0],Xn[:,1],c = "g")
ax1.legend(["positive","negative"])
ax1.set_title("y_true")Xp_pred = X[torch.squeeze(model.forward(X)>=0.5)]
Xn_pred = X[torch.squeeze(model.forward(X)<0.5)]ax2.scatter(Xp_pred[:,0],Xp_pred[:,1],c = "r")
ax2.scatter(Xn_pred[:,0],Xn_pred[:,1],c = "g")
ax2.legend(["positive","negative"])
ax2.set_title("y_pred")plt.show()

结果展示:

数据部分:

在这里插入图片描述

结果分类:

在这里插入图片描述

思考:

本文中的DNN模型,将loss(损失),metric(准确率),optimizer(优化器)的定义放在了DNN网络中,这也产生了一系列的问题。首先在调用这些函数时,需要用网络名+“.”来调用。例如:loss = model.loss_func(predictions,labels)
但是这里最重要的一点是 def optimizer(self):在optimizer函数上面必须有:@property,若没有将会出现AttributeError: 'function' object has no attribute 'step'的报错。
@property装饰器的作用:我们可以使用@property装饰器来创建只读属性,@property装饰器会将方法转换为相同名称的只读属性,可以与所定义的属性配合使用,这样可以防止属性被修改。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/389376.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OO期末总结

$0 写在前面 善始善终&#xff0c;临近期末&#xff0c;为一学期的收获和努力画一个圆满的句号。 $1 测试与正确性论证的比较 $1-0 什么是测试&#xff1f; 测试是使用人工操作或者程序自动运行的方式来检验它是否满足规定的需求或弄清预期结果与实际结果之间的差别的过程。 它…

数据可视化及其重要性:Python

Data visualization is an important skill to possess for anyone trying to extract and communicate insights from data. In the field of machine learning, visualization plays a key role throughout the entire process of analysis.对于任何试图从数据中提取和传达见…

【洛谷算法题】P1046-[NOIP2005 普及组] 陶陶摘苹果【入门2分支结构】Java题解

&#x1f468;‍&#x1f4bb;博客主页&#xff1a;花无缺 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! 本文由 花无缺 原创 收录于专栏 【洛谷算法题】 文章目录 【洛谷算法题】P1046-[NOIP2005 普及组] 陶陶摘苹果【入门2分支结构】Java题解&#x1f30f;题目…

python多项式回归_如何在Python中实现多项式回归模型

python多项式回归Let’s start with an example. We want to predict the Price of a home based on the Area and Age. The function below was used to generate Home Prices and we can pretend this is “real-world data” and our “job” is to create a model which wi…

充分利用UC berkeleys数据科学专业

By Kyra Wong and Kendall Kikkawa黄凯拉(Kyra Wong)和菊川健多 ( Kendall Kikkawa) 什么是“数据科学”&#xff1f; (What is ‘Data Science’?) Data collection, an important aspect of “data science”, is not a new idea. Before the tech boom, every industry al…

02-web框架

1 while True:print(server is waiting...)conn, addr server.accept()data conn.recv(1024) print(data:, data)# 1.得到请求的url路径# ------------dict/obj d["path":"/login"]# d.get(”path“)# 按着http请求协议解析数据# 专注于web业…

ai驱动数据安全治理_AI驱动的Web数据收集解决方案的新起点

ai驱动数据安全治理Data gathering consists of many time-consuming and complex activities. These include proxy management, data parsing, infrastructure management, overcoming fingerprinting anti-measures, rendering JavaScript-heavy websites at scale, and muc…

铁拳nat映射_铁拳如何重塑我的数据可视化设计流程

铁拳nat映射It’s been a full year since I’ve become an independent data visualization designer. When I first started, projects that came to me didn’t relate to my interests or skills. Over the past eight months, it’s become very clear to me that when cl…

DengAI —如何应对数据科学竞赛? (EDA)

了解机器学习 (Understanding ML) This article is based on my entry into DengAI competition on the DrivenData platform. I’ve managed to score within 0.2% (14/9069 as on 02 Jun 2020). Some of the ideas presented here are strictly designed for competitions li…

java.net.SocketException: Software caused connection abort: socket write erro

场景&#xff1a;接口测试 编辑器&#xff1a;eclipse 版本&#xff1a;Version: 2018-09 (4.9.0) testng版本&#xff1a;TestNG version 6.14.0 执行testng.xml时报错信息&#xff1a; 出现此报错原因之一&#xff1a;网上有人说是testng版本与eclipse版本不一致造成的&#…

使用K-Means对美因河畔法兰克福的社区进行聚类

介绍 (Introduction) This blog post summarizes the results of the Capstone Project in the IBM Data Science Specialization on Coursera. Within the project, the districts of Frankfurt am Main in Germany shall be clustered according to their venue data using t…

样本均值的抽样分布_抽样分布样本均值

样本均值的抽样分布One of the most important concepts discussed in the context of inferential data analysis is the idea of sampling distributions. Understanding sampling distributions helps us better comprehend and interpret results from our descriptive as …

玩转ceph性能测试---对象存储(一)

笔者最近在工作中需要测试ceph的rgw&#xff0c;于是边测试边学习。首先工具采用的intel的一个开源工具cosbench&#xff0c;这也是业界主流的对象存储测试工具。 1、cosbench的安装&#xff0c;启动下载最新的cosbench包wget https://github.com/intel-cloud/cosbench/release…

因果关系和相关关系 大数据_数据科学中的相关性与因果关系

因果关系和相关关系 大数据Let’s jump into it right away.让我们马上进入。 相关性 (Correlation) Correlation means relationship and association to another variable. For example, a movement in one variable associates with the movement in another variable. For…

vue取数据第一个数据_我作为数据科学家的第一个月

vue取数据第一个数据A lot.很多。 I landed my first job as a Data Scientist at the beginning of August, and like any new job, there’s a lot of information to take in at once.我于8月初找到了数据科学家的第一份工作&#xff0c;并且像任何新工作一样&#xff0c;一…

STL-开篇

基本概念 STL&#xff1a; Standard Template Library&#xff0c;标准模板库 定义&#xff1a; c引入的一个标准类库 特点&#xff1a;1&#xff09;数据结构和算法的 c实现&#xff08; 采用模板类和模板函数&#xff09;2&#xff09;数据的存储和算法的分离3&#xff09;高…

rcp rapido_为什么气流非常适合Rapido

rcp rapidoBack in 2019, when we were building our data platform, we started building the data platform with Hadoop 2.8 and Apache Hive, managing our own HDFS. The need for managing workflows whether it’s data pipelines, i.e. ETL’s, machine learning predi…

Mysql5.7开启远程

2019独角兽企业重金招聘Python工程师标准>>> 1.注掉bind-address #bind-address 127.0.0.1 2.开启远程访问权限 grant all privileges on *.* to root"xxx.xxx.xxx.xxx" identified by "密码"; 或 grant all privileges on *.* to root"%…

分类结果可视化python_可视化分类结果的另一种方法

分类结果可视化pythonI love good data visualizations. Back in the days when I did my PhD in particle physics, I was stunned by the histograms my colleagues built and how much information was accumulated in one single plot.我喜欢出色的数据可视化。 早在我获得…

算法组合 优化算法_算法交易简化了风险价值和投资组合优化

算法组合 优化算法Photo by Markus Spiske (left) and Jamie Street (right) on UnsplashMarkus Spiske (左)和Jamie Street(右)在Unsplash上的照片 In the last post, we saw how actual algorithms are developed and tested. In this post, we will figure out the level of…