Pytorch高阶API示范——线性回归模型

本文与《20天吃透Pytorch》有所不同,《20天吃透Pytorch》中是继承之前的模型进行拟合,本文是单独建立网络进行拟合。

代码实现:

import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader,TensorDataset"""
1.准备数据
"""
n=800   #样本数量#生成测试用的数据集
X = 10*torch.rand([n,2])-5.0    #torch.rand是均匀分布
w0 = torch.tensor([[2.0],[-3.0]])
b0 = torch.tensor([10.0])
Y = X@w0 + b0 + torch.normal(0.0,2.0,size=[n,1])    ## @表示矩阵乘法,增加正态扰动#数据可视化
plt.figure(figsize= (12,5))
ax1 = plt.subplot(121)
ax1.scatter(X[:,0],Y[:,0],c = 'b',label = 'samples')
ax1.legend()    #图例
plt.xlabel("x1")
plt.ylabel("y",rotation = 0)
ax2 = plt.subplot(122)
ax2.scatter(X[:,1],Y[:,0],c = 'g',label = 'samples')
ax2.legend()
plt.xlabel('x2')
plt.ylabel('y',rotation = 0)
plt.show()"""
构建通道
"""ds = TensorDataset(X,Y)
ds_train,ds_valid = torch.utils.data.random_split(ds,[int (n*0.7),n-int(n*0.7)])  #选取总样本的70%为训练数据
dl_train = DataLoader(ds_train,batch_size=10,shuffle=True)
dl_valid = DataLoader(ds_valid,batch_size=10,shuffle=True)"""
2.定义模型
"""class LinearRegression(torch.nn.Module):def __init__(self):super(LinearRegression, self).__init__()self.fc = nn.Linear(2,1)def forward(self,x):x = self.fc(x)return xnet = LinearRegression()
"""
3.训练模型
"""
loss_func = torch.nn.MSELoss()
optimizer= torch.optim.Adam(net.parameters(),lr = 0.01)eporchs = 10
log_step_freq = 20for eporch in range(1,eporchs+1):net.train()loss_sum = 0.0metric_sum = 0.0step = 1for step,(features,labels) in enumerate(dl_train,1):predictions = net(features)loss = loss_func(predictions,labels)optimizer.zero_grad()loss.backward()optimizer.step()w = net.state_dict()["fc.weight"]b = net.state_dict()["fc.bias"]print("step =", step, "loss = ", loss)print("w =", w)print("b =", b)loss_sum += loss.item()"""
结果可视化
"""
w,b = net.state_dict()["fc.weight"],net.state_dict()["fc.bias"]plt.figure(figsize = (12,5))
ax1 = plt.subplot(121)
ax1.scatter(X[:,0],Y[:,0], c = "b",label = "samples")
ax1.plot(X[:,0],w[0,0]*X[:,0]+b[0],"-r",linewidth = 5.0,label = "model")
ax1.legend()
plt.xlabel("x1")
plt.ylabel("y",rotation = 0)ax2 = plt.subplot(122)
ax2.scatter(X[:,1],Y[:,0], c = "g",label = "samples")
ax2.plot(X[:,1],w[0,1]*X[:,1]+b[0],"-r",linewidth = 5.0,label = "model")
ax2.legend()
plt.xlabel("x2")
plt.ylabel("y",rotation = 0)plt.show()

结果展示:

数据部分:

在这里插入图片描述

线性回归结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/389381.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue 上传图片限制大小和格式

<div class"upload-box clear"><span class"fl">上传图片</span><div class"artistDet-logo-box fl"><el-upload :action"this.baseServerUrl/fileUpload/uploadPic?filepathartwork" list-type"pic…

作业要求 20181023-3 每周例行报告

本周要求参见&#xff1a;https://edu.cnblogs.com/campus/nenu/2018fall/homework/2282 1、本周PSP 总计&#xff1a;927min 2、本周进度条 代码行数 博文字数 用到的软件工程知识点 217 757 PSP、版本控制 3、累积进度图 &#xff08;1&#xff09;累积代码折线图 &…

算命数据_未来的数据科学家或算命精神向导

算命数据Real Estate Sale Prices, Regression, and Classification: Data Science is the Future of Fortune Telling房地产销售价格&#xff0c;回归和分类&#xff1a;数据科学是算命的未来 As we all know, I am unusually blessed with totally-real psychic abilities.众…

openai-gpt_为什么到处都看到GPT-3?

openai-gptDisclaimer: My opinions are informed by my experience maintaining Cortex, an open source platform for machine learning engineering.免责声明&#xff1a;我的看法是基于我维护 机器学习工程的开源平台 Cortex的 经验而 得出 的。 If you frequent any part…

Pytorch高阶API示范——DNN二分类模型

代码部分&#xff1a; import numpy as np import pandas as pd from matplotlib import pyplot as plt import torch from torch import nn import torch.nn.functional as F from torch.utils.data import Dataset,DataLoader,TensorDataset""" 准备数据 &qu…

OO期末总结

$0 写在前面 善始善终&#xff0c;临近期末&#xff0c;为一学期的收获和努力画一个圆满的句号。 $1 测试与正确性论证的比较 $1-0 什么是测试&#xff1f; 测试是使用人工操作或者程序自动运行的方式来检验它是否满足规定的需求或弄清预期结果与实际结果之间的差别的过程。 它…

puppet puppet模块、file模块

转载&#xff1a;http://blog.51cto.com/ywzhou/1577356 作用&#xff1a;通过puppet模块自动控制客户端的puppet配置&#xff0c;当需要修改客户端的puppet配置时不用在客户端一一设置。 1、服务端配置puppet模块 &#xff08;1&#xff09;模块清单 [rootpuppet ~]# tree /et…

数据可视化及其重要性:Python

Data visualization is an important skill to possess for anyone trying to extract and communicate insights from data. In the field of machine learning, visualization plays a key role throughout the entire process of analysis.对于任何试图从数据中提取和传达见…

熊猫数据集_熊猫迈向数据科学的第三部分

熊猫数据集Data is almost never perfect. Data Scientist spend more time in preprocessing dataset than in creating a model. Often we come across scenario where we find some missing data in data set. Such data points are represented with NaN or Not a Number i…

Pytorch有关张量的各种操作

一&#xff0c;创建张量 1. 生成float格式的张量: a torch.tensor([1,2,3],dtype torch.float)2. 生成从1到10&#xff0c;间隔是2的张量: b torch.arange(1,10,step 2)3. 随机生成从0.0到6.28的10个张量 注意&#xff1a; (1).生成的10个张量中包含0.0和6.28&#xff…

mongodb安装失败与解决方法(附安装教程)

安装mongodb遇到的一些坑 浪费了大量的时间 在此记录一下 主要是电脑系统win10企业版自带的防火墙 当然还有其他的一些坑 一般的问题在第6步骤都可以解决&#xff0c;本教程的安装步骤不够详细的话 请自行百度或谷歌 安装教程很多 我是基于node.js使用mongodb结合Robo 3T数…

【洛谷算法题】P1046-[NOIP2005 普及组] 陶陶摘苹果【入门2分支结构】Java题解

&#x1f468;‍&#x1f4bb;博客主页&#xff1a;花无缺 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! 本文由 花无缺 原创 收录于专栏 【洛谷算法题】 文章目录 【洛谷算法题】P1046-[NOIP2005 普及组] 陶陶摘苹果【入门2分支结构】Java题解&#x1f30f;题目…

web性能优化(理论)

什么是性能优化&#xff1f; 就是让用户感觉你的网站加载速度很快。。。哈哈哈。 分析 让我们来分析一下从用户按下回车键到网站呈现出来经历了哪些和前端相关的过程。 缓存 首先看本地是否有缓存&#xff0c;如果有符合使用条件的缓存则不需要向服务器发送请求了。DNS查询建立…

python多项式回归_如何在Python中实现多项式回归模型

python多项式回归Let’s start with an example. We want to predict the Price of a home based on the Area and Age. The function below was used to generate Home Prices and we can pretend this is “real-world data” and our “job” is to create a model which wi…

充分利用UC berkeleys数据科学专业

By Kyra Wong and Kendall Kikkawa黄凯拉(Kyra Wong)和菊川健多 ( Kendall Kikkawa) 什么是“数据科学”&#xff1f; (What is ‘Data Science’?) Data collection, an important aspect of “data science”, is not a new idea. Before the tech boom, every industry al…

文本二叉树折半查询及其截取值

using System;using System.ComponentModel;using System.Data;using System.Drawing;using System.Text;using System.Windows.Forms;using System.Collections;using System.IO;namespace CS_ScanSample1{ /// <summary> /// Logic 的摘要说明。 /// </summary> …

nn.functional 和 nn.Module入门讲解

本文来自《20天吃透Pytorch》 一&#xff0c;nn.functional 和 nn.Module 前面我们介绍了Pytorch的张量的结构操作和数学运算中的一些常用API。 利用这些张量的API我们可以构建出神经网络相关的组件(如激活函数&#xff0c;模型层&#xff0c;损失函数)。 Pytorch和神经网络…

10.30PMP试题每日一题

SC>0&#xff0c;CPI<1&#xff0c;说明项目截止到当前&#xff1a;A、进度超前&#xff0c;成本超值B、进度落后&#xff0c;成本结余C、进度超前&#xff0c;成本结余D、无法判断 答案将于明天和新题一起揭晓&#xff01; 10.29试题答案&#xff1a;A转载于:https://bl…

02-web框架

1 while True:print(server is waiting...)conn, addr server.accept()data conn.recv(1024) print(data:, data)# 1.得到请求的url路径# ------------dict/obj d["path":"/login"]# d.get(”path“)# 按着http请求协议解析数据# 专注于web业…

ai驱动数据安全治理_AI驱动的Web数据收集解决方案的新起点

ai驱动数据安全治理Data gathering consists of many time-consuming and complex activities. These include proxy management, data parsing, infrastructure management, overcoming fingerprinting anti-measures, rendering JavaScript-heavy websites at scale, and muc…