【人工智能基础】RNN实验

一、RNN特性

权重共享

wordi · weight + bais

持久记忆单元

wordi · weightword + baisword + hi · weighth + baish

二、公式化表达

ht = f(ht - 1, xt)
ht = tanh(Whhht - 1 + Wxhxt)
yt = Whyht

三、RNN网络正弦波波形预测

环境准备

import numpy as np
import torch
from torch import nn,optim
from matplotlib import pyplot as plt# 时间轴采样数
num_time_steps = 50
input_size = 1
hidden_size = 16
output_size = 1
lr = 0.01

RNN类

class Net(nn.Module):def __init__(self,):super(Net, self).__init__()self.rnn = nn.RNN(input_size = input_size, hidden_size = hidden_size, num_layers = 1,# 格式为[batch, seq, feature]batch_first = True)for p in self.rnn.parameters():nn.init.normal_(p,mean=0.0, std=0.001)self.linear = nn.Linear(hidden_size, output_size)def forward(self, x, hidden_prev):out, hidden_prev = self.rnn(x, hidden_prev)# [1, seq, h] => [seq, h]out = out.view(-1,hidden_size)# [seq, h] => [seq, 1]out = self.linear(out)# [seq, 1] => [1, seq, 1], 需要和y做均方差out = out.unsqueeze(dim=0)return out, hidden_prev.clone()

正弦数据构建函数

def create_image():start = np.random.randint(3, size=1)[0]time_steps = np.linspace(start, start + 10, num_time_steps)data = np.sin(time_steps)data = data.reshape(num_time_steps, 1)x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)return time_steps,x, y

训练模型


model = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr)hidden_prev = torch.zeros(1,1, hidden_size)
for iter in range(6000):time_steps,x, y = create_image()output, hidden_prev = model(x, hidden_prev)hidden_prev = hidden_prev.detach()loss = criterion(output,y)model.zero_grad()loss.backward()for p in model.parameters():torch.nn.utils.clip_grad_norm_(p,10)optimizer.step()if iter % 1000 == 0:plt.plot(time_steps[:-1], x.ravel(), c = 'b')plt.plot(time_steps[:-1], y.ravel(), c= 'r')plt.plot(time_steps[:-1], output.detach().numpy().ravel(), c= 'g')plt.show()print('Iteration:{} loss {}'.format(iter, loss.item()))

可以看到第二次绘制图像的时候,输出曲线基本拟合了目标曲线

未训练图像

训练后图像

图像预测

time_steps,x, y = create_image()predictions = []
# input = x[:, 0, :]
for i in range(x.shape[1]):input = x[:, i, :].view(1, 1, 1)(pred, hiden_prev) = model(input, hidden_prev)input = predpredictions.append(pred.detach().numpy().ravel()[0])x = x.data.numpy().ravel()y = y.data.numpy()
plt.scatter(time_steps[:-1], x.ravel(), s=90)
plt.plot(time_steps[:-1], x.ravel())plt.scatter(time_steps[1:],predictions)
plt.show()

输出的预测曲线基本与目标曲线相同

预测结果
p.s. 最后的实验应该是输入一个点,通过这个点来预测出整个正弦曲线,但是我尝试了很多次都失败了,只能修改成根据正弦函数的上一个点来预测下一个点
失败的图像

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/8617.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode-hot100题解—Day7

原题链接:力扣热题-HOT100 我把刷题的顺序调整了一下,所以可以根据题号进行参考,题号和力扣上时对应的,那么接下来就开始刷题之旅吧~ 1-8题见LeetCode-hot100题解—Day1 9-16题见LeetCode-hot100题解—Day2 17-24题见LeetCode-hot…

面试 Java 基础八股文十问十答第二十九期

面试 Java 基础八股文十问十答第二十九期 作者:程序员小白条,个人博客 相信看了本文后,对你的面试是有一定帮助的!关注专栏后就能收到持续更新! ⭐点赞⭐收藏⭐不迷路!⭐ 1)类加载过程 类加载…

OpenNJet产品体验-手把手在Ubuntu20.04系统从零部署到应用OpenNJet

目录 一、引言 二、OpenNJet产品安装 2.1下载OpenNJet安装包 2.2安装OpenNJet V2.0.1 ​2.3快速启动并测试OpenNJet 三、OpenNJet产品应用体验 3.1配置OpenNJet 3.2 部署 Web 应用程序 3.3启动 NJet 3.4访问 Web 应用程序 四、总结 一、引言 OpenNJet应用引擎是高性…

[iOS]从拾遗到Runtime(上)

[iOS]从拾遗到Runtime(上) 文章目录 [iOS]从拾遗到Runtime(上)写在前面名词介绍instance 实例对象class 类对象meta-class 元类对象为什么要有元类? runtimeMethod(objc_method)SEL(objc_selector)IMP 类缓存(objc_cache)Category(objc_category) 消息传递消息传递的…

Vue 3 中的 h() 与 mergeProps() API 详解

前言 在 Vue 3 中,随着 Composition API 的引入,我们有了更多的灵活性和控制权来构建我们的组件。其中,h() 函数和 mergeProps() 是在构建渲染函数或 JSX/TSX 时经常使用的两个工具。下面,我将对这两个 API 进行详细的解释。 h()…

2024OD机试卷-求最多可以派出多少支团队 (java\python\c++)

题目:求最多可以派出多少支团队 题目描述 用数组代表每个人的能力,一个比赛活动要求参赛团队的最低能力值为N,每个团队可以由1人或者2人组成,且1个人只能参加1个团队,计算出最多可以派出多少只符合要求的团队。 输入描述 第一行代表总人数,范围1-500000 第二行数组代…

大模型微调之 在亚马逊AWS上实战LlaMA案例(四)

大模型微调之 在亚马逊AWS上实战LlaMA案例(四) 在 Amazon SageMaker JumpStart 上微调 Llama 2 以生成文本 Meta 能够使用Amazon SageMaker JumpStart微调 Llama 2 模型。 Llama 2 系列大型语言模型 (LLM) 是预先训练和微调的生成文本模型的集合&#x…

C++string续

一.find_first_of与find 相同:都是从string里面找字符,传参格式一样(都可以从某个位置开始找) 不同:find_first_of只能找字符,find可以找字符串 find_first_of参数里面的string与char*是每个字符的集合,指找出string…

普通组件的注册-局部注册和全局注册

目录 一、局部注册和全局注册-概述 二、局部注册的使用示例 三、全局注册的使用示例 一、局部注册和全局注册-概述 组件注册有两种方式: 局部注册:只能在注册的组件内使用。使用方法:创建.vue文件,在使用的组件内导入并注册。…

QX-mini51单片机学习-----(3)流水灯

目录 1宏定义 2函数的定义 3延时函数 4标准库函数中的循环移位函数 5循环移位函数与左移和右移运算符的区别 6实例 7keil中DeBug的用法 1宏定义 是预处理语句不需要分号 #define uchar unsigned char//此时uchar代替unsigned char typedef是关键字 后面是接分号…

Python实战开发及案例分析(5)—— 贪心算法

贪心算法是一种在每一步选择中都采取当前状态下最好或最优(即最有利)的选择,从而希望导致结果是全局最好或最优的算法。贪心算法不能保证得到最优解,但在某些问题中非常有效,并容易实现。 案例分析:找零问…

【Web前端】JavaScript—01

1.Javascript简介 简称JS,是当前最流行、应用最广泛的客户端脚本语言,用来在网页中添加一些动态效果与交互功能。在web开发领域有着举足轻重的地位。 2.JavaScript包含内容 核心ECMAScript(es):提供语言的语法和基本对象(数据类型、运算符、流程控制等…

Android Studio 之颜色

在Android中,颜色值由透明度alpha和RGB(红、绿、蓝)三原色定义,有八位十六进制数与六位十六进制数两种编码,例如八位编码FFEEDDCC,FF表示透明度,EE表示红色的浓度,DD表示绿色的浓度&…

Java特性之设计模式【代理模式】

一、代理模式 概述 在代理模式(Proxy Pattern)中,一个类代表另一个类的功能。这种类型的设计模式属于结构型模式 在代理模式中,我们创建具有现有对象的对象,以便向外界提供功能接口 主要解决: 在直接访问…

设计模式——外观模式(Facade)

外观模式(Facade Pattern) 是一种结构型设计模式,它为一个子系统中的一组接口提供一个统一的高层接口,使得子系统更加容易使用。这种类型的设计模式属于结构型模式,它向客户端提供了一个接口,隐藏了子系统的…

项目管理-项目资源管理2/2

项目管理:每天进步一点点~ 活到老,学到老 ヾ(◍∇◍)ノ゙ 何时学习都不晚,加油 资源管理:6个过程“硅谷火箭管控” ①规划资源管理: 写计划 ②估算活动资源:估算团队资源&…

【代码随想录37期】Day01 二分查找 + 移除元素

二分查找 力扣704 贴一下之前的笔记: 没想到一下子写不出来,忘记什么是二分法了,这里回顾一下: 「二分查找 binary search」是一种基于分治策略的高效搜索算法。 它利用数据的有序性,每轮减少一半搜索范围&#xff…

Kafak 消费异常:The coordinator is not available.

Kafak 消费异常:The coordinator is not available. 1. 问题描述2. 问题排查2.1 Topic 状态异常2.2 `__consumer_offsets` 简介1. 问题描述 在新环境部署 Kafak 时,发现可以正常产生消息,但是无法正常消费消息,消费消息的异常日志如下: 11:59:53.315 [main] DEBUG org.a…

PPP点对点协议

概述 Point-to-Point Protocol,点到点协议,工作于数据链路层,在链路层上传输网络层协议前验证链路的对端,主要用于在全双工的同异步链路上进行点到点的数据传输。 PPP主要是用来通过拨号或专线方式在两个网络节点之间建立连接、…

docker-本地私有仓库、harbor私有仓库部署与管理

一、本地私有仓库: 1、本地私有仓库简介: docker本地仓库,存放镜像,本地的机器上传和下载,pull/push。 使用私有仓库有许多优点: 节省网络带宽,针对于每个镜像不用每个人都去中央仓库上面去下…