pytorch利用rnn通过sin预测cos 利用lstm预测手写数字

一.利用rnn通过sin预测cos

1.首先可视化一下数据

import numpy as np
from matplotlib import pyplot as plt
def show(sin_np,cos_np):plt.figure()plt.title('Sin and Cos', fontsize='18')plt.plot(steps, sin_np, 'r-', label='sin')plt.plot(steps, cos_np, 'b-', label='cos')plt.legend(loc='best')plt.show()
if __name__ == '__main__':steps = np.linspace(0, np.pi*2, 256, dtype=np.float32)sin_np = np.sin(steps)cos_np = np.cos(steps)#debug to showshow(sin_np, cos_np)

2.构建rnn模型,其中输入数据可以看成是一个batch,步长自定义,维度为1的数据

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
# import matplotlib.animation
import math, randomclass RNN(nn.Module):def __init__(self):super(RNN,self).__init__()self.rnn=nn.Sequential()  #batch,sequence,input_sizeself.rnn.add_module('rnn1',nn.RNN(input_size=1,hidden_size=64,num_layers=1,batch_first=True))self.linear=nn.Sequential()self.linear.add_module('linear1', nn.Linear(64,1))def forward(self, x):y,_=self.rnn(x)# print('y.shape:',y.shape)outs=[]for time_step in range(y.size(1)):outs.append(self.linear(y[:,time_step,:]))return torch.stack(outs,dim=1)DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=RNN().to(DEVICE)
print('model:',model)
optimzer = optim.Adam(model.parameters(),lr=0.0001,weight_decay=0.00001)
criterion = nn.MSELoss()model.load_state_dict(torch.load('model_params.pth',map_location='cpu'))def train():# model=model.cuda()model.train()for epoch in range(10000):start, end = epoch*np.pi, (epoch+2)*np.pisteps = np.linspace(start,end,Times_step,dtype=np.float32)sin_x = np.sin(steps)cos_x = np.cos(steps)# print('sin_x.shape',sin_x.shape)#batch,sequence,input_size (1,256,1)if torch.cuda.is_available():sinx_input = torch.from_numpy(sin_x[np.newaxis,:,np.newaxis]).cuda()# print('sinx_input.shape:',sinx_input.shape)cosx_lable = torch.from_numpy(cos_x[np.newaxis, :, np.newaxis]).cuda()else:sinx_input = torch.from_numpy(sin_x[np.newaxis, :, np.newaxis])# print('sinx_input.shape:',sinx_input.shape)cosx_lable = torch.from_numpy(cos_x[np.newaxis, :, np.newaxis])y_pre = model(sinx_input)# print('y_pre.shape:',y_pre.shape)loss = criterion(y_pre,cosx_lable)optimzer.zero_grad()loss.backward()optimzer.step()if epoch%100==0:print('epoch,loss',epoch,loss)# plt.plot(steps, sinx_lable.cpu().data.numpy().flatten(),color='r')# plt.plot(steps, sinx_input.cpu().data.numpy().flatten(), color='b')# plt.show()torch.save(model.state_dict(), 'model_params.pth')  # save only the parameters
def eval():model.eval()start, end =0 * np.pi, (0+2) * np.pisteps = np.linspace(start, end, Times_step, dtype=np.float32)sin_x = np.sin(steps)print('sin_x:', sin_x)cos_x = np.cos(steps)# print('sin_x.shape',sin_x.shape)# batch,sequence,input_size (1,256,1)sinx_input = torch.from_numpy(sin_x[np.newaxis, :, np.newaxis])model.load_state_dict(torch.load('model_params.pth',map_location='cpu'))with torch.no_grad():y_pre=model(sinx_input)# print('sinx_input.shape:',sinx_input.shape)cosx_lable = torch.from_numpy(cos_x[np.newaxis, :, np.newaxis])plt.plot(steps, cosx_lable.data.numpy().flatten(), color='r',label='cosx_lable')plt.plot(steps, y_pre.data.numpy().flatten(), color='b',label='prediction')plt.legend(loc='best')plt.show()
if __name__ == '__main__':# train()eval()

    

最后一个epoch的结果

二.利用lstm预测手写数字

import torch
from torch import nn
from torch.autograd import Variable
import torchvision.datasets as dsets
import torch.utils.data as Data
import matplotlib.pyplot as plt
import torchvisiontorch.manual_seed(1)epochs = 5
BATCH_SIZE = 8
TIME_STEP = 28
INPUT_SIZE = 28
LR = 0.01
DOWNLOAD_MNIST = Falsetrain_data = dsets.MNIST(root='./mnist',train=True,transform=torchvision.transforms.ToTensor(),download=DOWNLOAD_MNIST,
)test_data = torchvision.datasets.MNIST(root='./mnist', train=False)train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)# print(test_data.data.size())
test_x = Variable(test_data.data).type(torch.FloatTensor) / 255.
test_y = test_data.targets
print('==========test data shape========')
print(test_x.size())
print(test_y.size())class RNN(nn.Module):def __init__(self):super(RNN, self).__init__()self.rnn = nn.LSTM(input_size=INPUT_SIZE,hidden_size=64,num_layers=1,batch_first=True,)self.out = nn.Linear(64, 10)def forward(self, x):r_out, (h_n, h_c) = self.rnn(x)# print('r_out',r_out.size())# 取出最后一次循环的r_out传递到全连接层out = self.out(r_out[:, -1, :])return outrnn = RNN()
print('========ini rnn================')
print(rnn)optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()print('========start train model===============')
for epoch in range(epochs):for step, (x, y) in enumerate(train_loader):if step < 10:# print(x.size())#[batch,28,28]input = Variable(x.squeeze())#[batch]label = Variable(y)# print('============train data input shape=========')# print(input.size())# print(label.size())output = rnn(input)loss = loss_func(output, label)optimizer.zero_grad()loss.backward()optimizer.step()else:breaktest_output = rnn(test_x.squeeze())# print(test_output.size())pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()# print(pred_y.shape)accuracy = sum(pred_y == test_y.numpy()) / float(test_y.size(0))print('Epoch: ', epoch, '| train loss:%.4f' % loss.item(), '| test accuracy:%.2f' % accuracy)print('=================start test=======================')
test_output = rnn(test_x[:10].squeeze())
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
print(pred_y, 'prediction number')
print(test_y[:10].numpy(), 'real number')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/493252.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

高德纳咨询公司(Gartner)预测:2019年七大人工智能科技趋势

来源&#xff1a;创新研究摘要&#xff1a;人工智能技术对我们的工作环境、工作种类等等正在产生日益深刻的影响&#xff0c;其结果或好或坏都有可能。为应对这种改变&#xff0c;特别是负面的变化&#xff0c;高德纳咨询公司&#xff08;Gartner&#xff09;于2018年12月13日发…

Java进阶 创建和销毁对象

最近准备写点Javase的东西&#xff0c;希望可以帮助大家写出更好的代码。 1、给不可实例化的类提供私有构造器 比如&#xff1a;每个项目中都有很多工具类&#xff0c;提供了很多static类型的方法供大家使用&#xff0c;谁也不希望看到下面的代码&#xff1a; TextUtils textUt…

美爆!《自然》公布2018年19张最震撼的科学图片

来源&#xff1a;前瞻网 摘要&#xff1a;2018年注定将载入科学史册&#xff1a;气候上&#xff0c;从加利福尼亚烧到开普敦的致命野火和极端干旱、历史罕见;医学上&#xff0c;克隆和成像技术的进步既带来希望&#xff0c;也产生了争议;生物上&#xff0c;一系列事件让人们意识…

python实现Trie 树+朴素匹配字符串+RK算法匹配字符串+kmp算法匹配字符串

一.trie树应用&#xff1a; 相应leetcode 常用于搜索提示&#xff0c;如当输入一个网址&#xff0c;可以自动搜索出可能的选择。当没有完全匹配的搜索结果&#xff0c;可以返回前缀最相似的可能。 例如三个单词app, apple, add,我们按照以下规则创建了一颗Trie树.对于从树的根…

天才也勤奋!DeepMind哈萨比斯自述:领导400名博士向前,每天工作至凌晨4点

来源&#xff1a;量子位你见过凌晨4点的伦敦吗&#xff1f;哈萨比斯天天见。这位DeepMind创始人、AlphaGo之父&#xff0c;一直是全球赞颂的当世天才&#xff0c;但每天要到凌晨4点&#xff0c;才能睡下。这是哈萨比斯最新采访中透露的作息时间&#xff0c;他告诉《星期日泰晤士…

如何学好PPT设计

一般来说&#xff0c;学习PPT大致要经过三个阶段&#xff1a; 1、基础技能&#xff1a;基本的PPT软件学习、基础设计理论的学习、相关辅助软件的掌握(例如图片处理的photoshop&#xff0c;绘制矢量图形的coreldraw、illustrator等)&#xff1b; 2、讲故事的能力&#xff08;就是…

RNN知识+LSTM知识+encoder-decoder+ctc+基于pytorch的crnn网络结构

一&#xff0e;基础知识&#xff1a; 下图是一个循环神经网络实现语言模型的示例&#xff0c;可以看出其是基于当前的输入与过去的输入序列&#xff0c;预测序列的下一个字符&#xff0e; 序列特点就是某一步的输出不仅依赖于这一步的输入&#xff0c;还依赖于其他步的输入或输…

编程模式如何结束未响应的程序

有时要编程结束一个程序的运行。比如说 hWnd 是你要操作的那个窗口的句柄。如果是一般的情况::PostMessage(hWnd,WM_CLOSE,0,0);就可以了。&#xff08;注意不要发送 WM_DESTROY消息。这两者有什么区别呢&#xff1f;WM_CLOSE&#xff0c;会正常关闭程序&#xff0c;比如说&am…

利用flask写的接口(base64, 二进制, 上传视频流)+异步+gunicorn部署Flask服务+多gpu卡部署

一.flask写的接口 1.1 manage.py启动服务(发送图片base64版) 这里要注意的是用docker的话,记得端口映射 #coding:utf-8 import base64 import io import logging import picklefrom flask import Flask, jsonify, request from PIL import Image from sklearn import metric…

2018中国自动驾驶市场专题分析

来源&#xff1a;智车科技未来智能实验室是人工智能学家与科学院相关机构联合成立的人工智能&#xff0c;互联网和脑科学交叉研究机构。未来智能实验室的主要工作包括&#xff1a;建立AI智能系统智商评测体系&#xff0c;开展世界人工智能智商评测&#xff1b;开展互联网&#…

Ant步步为营(4)ant启动tomcat

前序&#xff1a; 最近产品要release&#xff0c;一直忙着测试&#xff0c;没有时间学习ant了&#xff0c;今天终于没什么事了赶紧写点东西。这个启动tomcat是好些天之前写的了。在这里跟大家分享一下。 build.xml <?xml version"1.0"?> <project name…

python写日志

需要再加入按照日期生成日志 #coding:utf-8 import logging import logging.handlers class Logger:logFile def __init__(self, logFile):self.logFile logFileself.logger logging.getLogger(mylogger)self.logger.setLevel(logging.INFO)rf_handler logging.handlers.…

MIT科学家Dimitri P. Bertsekas最新2019出版《强化学习与最优控制》(附书稿PDF讲义)...

来源&#xff1a;专知摘要&#xff1a;MIT科学家Dimitri P. Bertsekas今日发布了一份2019即将出版的《强化学习与最优控制》书稿及讲义&#xff0c;该专著目的在于探索这人工智能与最优控制的共同边界&#xff0c;形成一个可以在任一领域具有背景的人员都可以访问的桥梁。REINF…

PHP中的中文截取乱码问题_gb2312_utf-8

一、字符串编码为gb2312&#xff0c;一个中文占俩字节 public static function chinesesubstr($str, $start, $len) { // $str指字符串,$start指字符串的起始位置&#xff0c;$len指字符串长度$strlen $start $len; // 用$strlen存储字符串的总长度&#xff0c;即从字符串的起…

yolov3 anchors用kmeans聚类出先验框+anchor宽高比分析

一&#xff0e;yolov v3聚类出框 # -*- coding: utf-8 -*- import numpy as np import random import argparse import os# # 参数名称 # parser argparse.ArgumentParser(description使用该脚本生成YOLO-V3的anchor boxes\n) # parser.add_argument(--input_annotation_txt…

Geoff Hinton:全新的想法将比微小的改进更有影响力

来源&#xff1a;AI科技评论摘要&#xff1a;日前&#xff0c;WIRED 对 Hinton 进行了一次专访&#xff0c;在访谈中&#xff0c;WIRED 针对人工智能带来的道德挑战和面临的挑战等问题进行了提问&#xff0c;以下为谈话内容。“作为一名谷歌高管&#xff0c;我认为在公开场合抱…

修改TOMCAT服务器图标为应用LOGO

在tomcat下部署应用程序&#xff0c;运行后&#xff0c;发现在地址栏中会显示tomcat的小猫咪图标。有时候&#xff0c;我们自己不想显示这个图标&#xff0c;想换成自己定义的的图标&#xff0c;那么按如下方法操作即可&#xff1a; 参考网上的解决方案&#xff1a;1、将$TOMCA…

python连接mysql的一些基础知识+安装Navicat可视化数据库+flask_sqlalchemy写数据库

一&#xff0e;mysql基础知识 &#xff11;&#xff0e;connect连接数据库 import pymysqldef get_conn():conn pymysql.connect(hostxxx.xxx.xxx.xxx, port3306, userroot, passwd, dbnewspaper_rest) # db:表示数据库名称return conn &#xff12;&#xff0e;创建表 im…

工业互联网平台创新发展白皮书(2018)

来源&#xff1a;走向智能论坛摘要&#xff1a;近日&#xff0c;在“2018年产业互联网与数据经济大会——首届工业互联网平台创新发展暨两化融合推进会”上&#xff0c;国家工业信息安全发展研究中心尹丽波主任发布并解读了《工业互联网平台创新发展白皮书&#xff08;2018&…

迭代器模式和组合模式混用

迭代器模式和组合模式混用 前言 园子里说设计模式的文章算得上是海量了&#xff0c;所以本篇文章所用到的迭代器设计模式和组合模式不提供原理解析&#xff0c;有兴趣的朋友可以到一些前辈的设计模式文章上学学&#xff0c;很多很有意思的。在Head First 设计模式这本书中&…