使用Pytorch实现linear_regression

使用Pytorch实现线性回归

# import necessary packages
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Set necessary Hyper-parameters.
input_size = 1
output_size = 1
num_epochs = 60
learning_rate = 0.001
# Define a Toy dataset.
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168], [9.779], [6.182], [7.59], [2.167], [7.042], [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573], [3.366], [2.596], [2.53], [1.221], [2.827], [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)# Confirm the data shape.
print(x_train.shape, y_train.shape)
(15, 1) (15, 1)
# Linear regression model
model = nn.Linear(input_size, output_size)
# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, )
# Train the model
for epoch in range(num_epochs):# Convert numpy arrays to torch tensorsinputs = torch.from_numpy(x_train)targets = torch.from_numpy(y_train)# Forward passoutputs = model(inputs)loss = criterion(outputs, targets)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()# Set an output counterif (epoch+1) % 5 == 0:print('Epoch [{}/{}], loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))# Plot the graph
predicted = model(torch.from_numpy(x_train)).detach().numpy()
plt.plot(x_train, y_train, 'ro', label='Original data')
plt.plot(x_train, predicted, label='Fitted line')
plt.legend()
plt.show()
Epoch [5/60], loss: 7.1598
Epoch [10/60], loss: 3.0717
Epoch [15/60], loss: 1.4154
Epoch [20/60], loss: 0.7443
Epoch [25/60], loss: 0.4722
Epoch [30/60], loss: 0.3618
Epoch [35/60], loss: 0.3169
Epoch [40/60], loss: 0.2985
Epoch [45/60], loss: 0.2909
Epoch [50/60], loss: 0.2876
Epoch [55/60], loss: 0.2861
Epoch [60/60], loss: 0.2853

在这里插入图片描述

# Save the model checkpoint
torch.save(model.state_dict(), 'model_param.ckpt')
torch.save(model, 'model.ckpt')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/157506.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

操作系统 应用题 例题+参考答案(考研真题)

1.(考研真题)一个多道批处理系统中仅有P1和P2两个作业,P2比P1晚5ms到达,它们的计算和I/O操作顺序如下。 P1:计算60ms,I/O 80ms,计算20ms。 P2:计算120ms,I/O 40ms&…

<Linux>权限管理|权限分类|权限设置|权限掩码|粘滞位

文章目录 Linux权限的概念Linux权限管理a. 文件访问者的分类b. 文件类型和访问权限c. 文件权限表示方法d. 文件权限的设置权限掩码file指令粘滞位 权限总结权限作业 Linux权限的概念 Linux下有两种用户:超级用户(root)和普通用户。 超级用户:可以在Lin…

学生党的福利!移动云重磅升级存储产品体系

如今,随着科学技术不断发展进步,电子产品的生产技术也变得越来越成熟。一方面,电子产品的功能越来越强大,质量越来越可靠;另一方面,产品价格越来越便宜,在人们生活中越来越普及。大学生群体可以…

基于纳什博弈的多微网主体电热双层共享策略(matlab代码)

目录 ​1 主要内容 2 部分代码 3 程序结果 4 下载链接 ​1 主要内容 该程序复现《Multi-Micro-Grid Main Body Electric Heating Double-Layer Sharing Strategy Based on Nash Game》模型,主要做的是构建基于纳什博弈的多微网主体电热双层共享模型,…

[笔记] 错排问题 #错排

参考:刷题笔记-错排问题总结 错排问题: 一个n个元素的排列,若一个排列中所有的元素都不在自己原来的位置上,那么这样的一个排列就称为原排列的一个错排。而研究一个排列的错排个数的问题,就称为错排问题(或…

java项目之木里风景文化管理平台(ssm+vue)

项目简介 木里风景文化管理平台实现了以下功能: 前台功能:用户进入系统可以实现首页,旅游公告,景区,景区商品,景区美食,旅游交通工具,红黑榜,个人中心,后台…

squid代理服务器(传统代理、透明代理、反向代理、ACL、日志分析)

一、Squid 代理服务器 (一)代理的工作机制 1、代替客户机向网站请求数据,从而可以隐藏用户的真实IP地址。 2、将获得的网页数据(静态 Web 元素)保存到缓存中并发送给客户机,以便下次请求相同的数据时快速…

【洛谷 B2001】入门测试题目 题解(模拟算法+顺序结构)

入门测试题目 题目描述 求两个整数的和。 输入格式 一行,两个用空格隔开的整数。 输出格式 两个整数的和。 样例 #1 样例输入 #1 1 2样例输出 #1 3样例 #2 样例输入 #2 10230 21312样例输出 #2 31542提示 对于 100 % 100\% 100% 的数据,输…

MySQL 插件之 连接控制插件(Connection-Control)

插件介绍 MySQL 5.7.17 以后提供了Connection-Control插件用来控制客户端在登录操作连续失败一定次数后的响应的延迟。该插件可有效的防止客户端暴力登录的风险(攻击)。该插件包含以下2个组件 CONNECTION_CONTROL:用来控制登录失败的次数及延迟响应时间CONNECTION_C…

Stable Diffusion XL网络结构-超详细原创

强烈推荐先看本人的这篇 Stable Diffusion1.5网络结构-超详细原创-CSDN博客 1 Unet 1.1 详细整体结构 1.2 缩小版整体结构 以生成图像1024x1024为例,与SD1.5的3个CrossAttnDownBlock2D和CrossAttnUpBlock2D相比,SDXL只有2个,但SDXL的Cros…

Rust语言精讲:数据类型全解析

大家好!我是lincyang。 今天,我们将深入探讨Rust语言中的数据类型,这是理解和掌握Rust的基础。 Rust语言数据类型概览 Rust是静态类型语言,所有变量类型在编译时确定。Rust的数据类型分为两类:标量类型和复合类型。…

动态神经网络时间序列预测

大家好,我是带我去滑雪! 神经网络投照是否存在反锁与记忆可以分为静态神经网络与动态神经网络。动态神经网络是指神经网络带有反做与记忆功能,无论是局部反馈还是全局反锁。通过反馈与记忆,神经网络能将前一时刻的数据保留&#x…

【MySQL】InnoDB中的索引

这里写目录标题 索引底层的数据结构:B树B树与B树的区别InnoDB与MyISAM在B树使用索引结构的不同? 聚簇索引非聚簇索引联合索引 B树索引适用的条件查询全值匹配匹配左边的列匹配列前缀匹配范围的值精确匹配某一列并范围匹配另外一列避免使用隐式转换 排序必…

【ARM AMBA AXI 入门 15 -- AXI-Lite 详细介绍】

请阅读【ARM AMBA AXI 总线 文章专栏导读】 文章目录 AXI LiteAXI-Full 介绍AXI Stream 介绍AXI Lite 介绍AXI Full 与 AIX Lite 差异总结AXI Lite AMBA AXI4 规范中包含三种不同的协议接口,分别是: AXI4-FullAXI4-LiteAXI4-Stream 上图中的 AXI FULL 和 AIX-Lite 我们都把…

centos安装神通数据库

1、安装 wget工具 yum install -y wget2、安装rar解压工具 wget --no-check-certificate http://www.rarlab.com/rar/rarlinux-x64-5.3.0.tar.gz tar zxvf rarlinux-x64-5.3.0.tar.gz && cd rar/ && make install3、下载oscar神通数据库(linux 64…

【GUI】-- 12 贪吃蛇小游戏之让小蛇动起来

GUI编程 04 贪吃蛇小游戏 4.3 第三步:让小蛇动起来(键盘控制) 首先,在构造器中要获取焦点事件、键盘监听事件并加入定时器(定时器定义需要实现ActionListener接口并重写actionPerformed方法): //构造器public GamePanel() {init();this.s…

jbase仪器接口设计

jbase的计划有借助虚拟M来实现连仪器,之前陆续写了些TCP逻辑,今天终于整理完成了仪器设计。首先用java的cs程序测试TCP的服务和客户端。 javafx的示例加强 package sample;import javafx.application.Application; import javafx.event.EventHandler; …

【Python第三方库】Requests

1.Requests主要作用是什么? 发送HTTP请求,获取响应数据 建议在学习本模块前,先大致了解下HTTP协议【网络基础】HTTP_记录测试点滴的博客-CSDN博客 2. 如何搭建requests环境 下载requests模块导入requests模块 3.Requests发送请求 (常用)1.get请求 语法格式: request…

什么是死锁,如何避免死锁

死锁的定义: 死锁是指在一个多线程或多进程的系统中,两个或多个进程(线程)被永久阻塞,无法向前推进。这是由于每个进程都在等待系统中的其他进程释放资源,而这些资源又只能由其他进程释放。这样&#xff0c…

Web 自动化神器 TestCafe—页面基本操作篇

前 言 Testcafe是基于node.js的框架,以操作简洁著称,是web自动化的神器 今天主要给大家介绍一下testcafe这个框架和页面元素交互的方法。 一、互动要求 使用 TestCafe 与元素进行交互操作,元素需满足以下条件:☟ 元素在 body 页…