pytorch利用rnn通过sin预测cos 利用lstm预测手写数字

一.利用rnn通过sin预测cos

1.首先可视化一下数据

import numpy as np
from matplotlib import pyplot as plt
def show(sin_np,cos_np):plt.figure()plt.title('Sin and Cos', fontsize='18')plt.plot(steps, sin_np, 'r-', label='sin')plt.plot(steps, cos_np, 'b-', label='cos')plt.legend(loc='best')plt.show()
if __name__ == '__main__':steps = np.linspace(0, np.pi*2, 256, dtype=np.float32)sin_np = np.sin(steps)cos_np = np.cos(steps)#debug to showshow(sin_np, cos_np)

2.构建rnn模型,其中输入数据可以看成是一个batch,步长自定义,维度为1的数据

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
# import matplotlib.animation
import math, randomclass RNN(nn.Module):def __init__(self):super(RNN,self).__init__()self.rnn=nn.Sequential()  #batch,sequence,input_sizeself.rnn.add_module('rnn1',nn.RNN(input_size=1,hidden_size=64,num_layers=1,batch_first=True))self.linear=nn.Sequential()self.linear.add_module('linear1', nn.Linear(64,1))def forward(self, x):y,_=self.rnn(x)# print('y.shape:',y.shape)outs=[]for time_step in range(y.size(1)):outs.append(self.linear(y[:,time_step,:]))return torch.stack(outs,dim=1)DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=RNN().to(DEVICE)
print('model:',model)
optimzer = optim.Adam(model.parameters(),lr=0.0001,weight_decay=0.00001)
criterion = nn.MSELoss()model.load_state_dict(torch.load('model_params.pth',map_location='cpu'))def train():# model=model.cuda()model.train()for epoch in range(10000):start, end = epoch*np.pi, (epoch+2)*np.pisteps = np.linspace(start,end,Times_step,dtype=np.float32)sin_x = np.sin(steps)cos_x = np.cos(steps)# print('sin_x.shape',sin_x.shape)#batch,sequence,input_size (1,256,1)if torch.cuda.is_available():sinx_input = torch.from_numpy(sin_x[np.newaxis,:,np.newaxis]).cuda()# print('sinx_input.shape:',sinx_input.shape)cosx_lable = torch.from_numpy(cos_x[np.newaxis, :, np.newaxis]).cuda()else:sinx_input = torch.from_numpy(sin_x[np.newaxis, :, np.newaxis])# print('sinx_input.shape:',sinx_input.shape)cosx_lable = torch.from_numpy(cos_x[np.newaxis, :, np.newaxis])y_pre = model(sinx_input)# print('y_pre.shape:',y_pre.shape)loss = criterion(y_pre,cosx_lable)optimzer.zero_grad()loss.backward()optimzer.step()if epoch%100==0:print('epoch,loss',epoch,loss)# plt.plot(steps, sinx_lable.cpu().data.numpy().flatten(),color='r')# plt.plot(steps, sinx_input.cpu().data.numpy().flatten(), color='b')# plt.show()torch.save(model.state_dict(), 'model_params.pth')  # save only the parameters
def eval():model.eval()start, end =0 * np.pi, (0+2) * np.pisteps = np.linspace(start, end, Times_step, dtype=np.float32)sin_x = np.sin(steps)print('sin_x:', sin_x)cos_x = np.cos(steps)# print('sin_x.shape',sin_x.shape)# batch,sequence,input_size (1,256,1)sinx_input = torch.from_numpy(sin_x[np.newaxis, :, np.newaxis])model.load_state_dict(torch.load('model_params.pth',map_location='cpu'))with torch.no_grad():y_pre=model(sinx_input)# print('sinx_input.shape:',sinx_input.shape)cosx_lable = torch.from_numpy(cos_x[np.newaxis, :, np.newaxis])plt.plot(steps, cosx_lable.data.numpy().flatten(), color='r',label='cosx_lable')plt.plot(steps, y_pre.data.numpy().flatten(), color='b',label='prediction')plt.legend(loc='best')plt.show()
if __name__ == '__main__':# train()eval()

    

最后一个epoch的结果

二.利用lstm预测手写数字

import torch
from torch import nn
from torch.autograd import Variable
import torchvision.datasets as dsets
import torch.utils.data as Data
import matplotlib.pyplot as plt
import torchvisiontorch.manual_seed(1)epochs = 5
BATCH_SIZE = 8
TIME_STEP = 28
INPUT_SIZE = 28
LR = 0.01
DOWNLOAD_MNIST = Falsetrain_data = dsets.MNIST(root='./mnist',train=True,transform=torchvision.transforms.ToTensor(),download=DOWNLOAD_MNIST,
)test_data = torchvision.datasets.MNIST(root='./mnist', train=False)train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)# print(test_data.data.size())
test_x = Variable(test_data.data).type(torch.FloatTensor) / 255.
test_y = test_data.targets
print('==========test data shape========')
print(test_x.size())
print(test_y.size())class RNN(nn.Module):def __init__(self):super(RNN, self).__init__()self.rnn = nn.LSTM(input_size=INPUT_SIZE,hidden_size=64,num_layers=1,batch_first=True,)self.out = nn.Linear(64, 10)def forward(self, x):r_out, (h_n, h_c) = self.rnn(x)# print('r_out',r_out.size())# 取出最后一次循环的r_out传递到全连接层out = self.out(r_out[:, -1, :])return outrnn = RNN()
print('========ini rnn================')
print(rnn)optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()print('========start train model===============')
for epoch in range(epochs):for step, (x, y) in enumerate(train_loader):if step < 10:# print(x.size())#[batch,28,28]input = Variable(x.squeeze())#[batch]label = Variable(y)# print('============train data input shape=========')# print(input.size())# print(label.size())output = rnn(input)loss = loss_func(output, label)optimizer.zero_grad()loss.backward()optimizer.step()else:breaktest_output = rnn(test_x.squeeze())# print(test_output.size())pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()# print(pred_y.shape)accuracy = sum(pred_y == test_y.numpy()) / float(test_y.size(0))print('Epoch: ', epoch, '| train loss:%.4f' % loss.item(), '| test accuracy:%.2f' % accuracy)print('=================start test=======================')
test_output = rnn(test_x[:10].squeeze())
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
print(pred_y, 'prediction number')
print(test_y[:10].numpy(), 'real number')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/493252.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

高德纳咨询公司(Gartner)预测:2019年七大人工智能科技趋势

来源&#xff1a;创新研究摘要&#xff1a;人工智能技术对我们的工作环境、工作种类等等正在产生日益深刻的影响&#xff0c;其结果或好或坏都有可能。为应对这种改变&#xff0c;特别是负面的变化&#xff0c;高德纳咨询公司&#xff08;Gartner&#xff09;于2018年12月13日发…

美爆!《自然》公布2018年19张最震撼的科学图片

来源&#xff1a;前瞻网 摘要&#xff1a;2018年注定将载入科学史册&#xff1a;气候上&#xff0c;从加利福尼亚烧到开普敦的致命野火和极端干旱、历史罕见;医学上&#xff0c;克隆和成像技术的进步既带来希望&#xff0c;也产生了争议;生物上&#xff0c;一系列事件让人们意识…

python实现Trie 树+朴素匹配字符串+RK算法匹配字符串+kmp算法匹配字符串

一.trie树应用&#xff1a; 相应leetcode 常用于搜索提示&#xff0c;如当输入一个网址&#xff0c;可以自动搜索出可能的选择。当没有完全匹配的搜索结果&#xff0c;可以返回前缀最相似的可能。 例如三个单词app, apple, add,我们按照以下规则创建了一颗Trie树.对于从树的根…

天才也勤奋!DeepMind哈萨比斯自述:领导400名博士向前,每天工作至凌晨4点

来源&#xff1a;量子位你见过凌晨4点的伦敦吗&#xff1f;哈萨比斯天天见。这位DeepMind创始人、AlphaGo之父&#xff0c;一直是全球赞颂的当世天才&#xff0c;但每天要到凌晨4点&#xff0c;才能睡下。这是哈萨比斯最新采访中透露的作息时间&#xff0c;他告诉《星期日泰晤士…

RNN知识+LSTM知识+encoder-decoder+ctc+基于pytorch的crnn网络结构

一&#xff0e;基础知识&#xff1a; 下图是一个循环神经网络实现语言模型的示例&#xff0c;可以看出其是基于当前的输入与过去的输入序列&#xff0c;预测序列的下一个字符&#xff0e; 序列特点就是某一步的输出不仅依赖于这一步的输入&#xff0c;还依赖于其他步的输入或输…

利用flask写的接口(base64, 二进制, 上传视频流)+异步+gunicorn部署Flask服务+多gpu卡部署

一.flask写的接口 1.1 manage.py启动服务(发送图片base64版) 这里要注意的是用docker的话,记得端口映射 #coding:utf-8 import base64 import io import logging import picklefrom flask import Flask, jsonify, request from PIL import Image from sklearn import metric…

2018中国自动驾驶市场专题分析

来源&#xff1a;智车科技未来智能实验室是人工智能学家与科学院相关机构联合成立的人工智能&#xff0c;互联网和脑科学交叉研究机构。未来智能实验室的主要工作包括&#xff1a;建立AI智能系统智商评测体系&#xff0c;开展世界人工智能智商评测&#xff1b;开展互联网&#…

python写日志

需要再加入按照日期生成日志 #coding:utf-8 import logging import logging.handlers class Logger:logFile def __init__(self, logFile):self.logFile logFileself.logger logging.getLogger(mylogger)self.logger.setLevel(logging.INFO)rf_handler logging.handlers.…

MIT科学家Dimitri P. Bertsekas最新2019出版《强化学习与最优控制》(附书稿PDF讲义)...

来源&#xff1a;专知摘要&#xff1a;MIT科学家Dimitri P. Bertsekas今日发布了一份2019即将出版的《强化学习与最优控制》书稿及讲义&#xff0c;该专著目的在于探索这人工智能与最优控制的共同边界&#xff0c;形成一个可以在任一领域具有背景的人员都可以访问的桥梁。REINF…

yolov3 anchors用kmeans聚类出先验框+anchor宽高比分析

一&#xff0e;yolov v3聚类出框 # -*- coding: utf-8 -*- import numpy as np import random import argparse import os# # 参数名称 # parser argparse.ArgumentParser(description使用该脚本生成YOLO-V3的anchor boxes\n) # parser.add_argument(--input_annotation_txt…

Geoff Hinton:全新的想法将比微小的改进更有影响力

来源&#xff1a;AI科技评论摘要&#xff1a;日前&#xff0c;WIRED 对 Hinton 进行了一次专访&#xff0c;在访谈中&#xff0c;WIRED 针对人工智能带来的道德挑战和面临的挑战等问题进行了提问&#xff0c;以下为谈话内容。“作为一名谷歌高管&#xff0c;我认为在公开场合抱…

修改TOMCAT服务器图标为应用LOGO

在tomcat下部署应用程序&#xff0c;运行后&#xff0c;发现在地址栏中会显示tomcat的小猫咪图标。有时候&#xff0c;我们自己不想显示这个图标&#xff0c;想换成自己定义的的图标&#xff0c;那么按如下方法操作即可&#xff1a; 参考网上的解决方案&#xff1a;1、将$TOMCA…

python连接mysql的一些基础知识+安装Navicat可视化数据库+flask_sqlalchemy写数据库

一&#xff0e;mysql基础知识 &#xff11;&#xff0e;connect连接数据库 import pymysqldef get_conn():conn pymysql.connect(hostxxx.xxx.xxx.xxx, port3306, userroot, passwd, dbnewspaper_rest) # db:表示数据库名称return conn &#xff12;&#xff0e;创建表 im…

工业互联网平台创新发展白皮书(2018)

来源&#xff1a;走向智能论坛摘要&#xff1a;近日&#xff0c;在“2018年产业互联网与数据经济大会——首届工业互联网平台创新发展暨两化融合推进会”上&#xff0c;国家工业信息安全发展研究中心尹丽波主任发布并解读了《工业互联网平台创新发展白皮书&#xff08;2018&…

迭代器模式和组合模式混用

迭代器模式和组合模式混用 前言 园子里说设计模式的文章算得上是海量了&#xff0c;所以本篇文章所用到的迭代器设计模式和组合模式不提供原理解析&#xff0c;有兴趣的朋友可以到一些前辈的设计模式文章上学学&#xff0c;很多很有意思的。在Head First 设计模式这本书中&…

python实现可扩容队列

#coding:utf-8 """ fzh created on 2019/10/15 构建一个队列 """ import datetimeclass LoopQueue(object):def __init__(self, n10):self.arr [None] * (n1) # 由于特意浪费了一个空间&#xff0c;所以arr的实际大小应该是用户传入的容量1sel…

5G 产业链重要投资节点

来源&#xff1a;兴业证券 ▌5G:大通信容量及超低延时&#xff0c;未来多项应用的基础5G:高工作频率以及频谱带宽带来高通信容量5G(5thgeneration)是指第五代移动电话通信标准。3GPP(第三代合作伙伴计划&#xff0c;电信标准化机构)将5G标准分为了NSA(非独立组网)和SA(独立组网…

Kneser猜想与相关推广

本文本来是想放在Borsuk-Ulam定理的应用这篇文章当中。但是这个文章实在是太长&#xff0c;导致有喧宾夺主之嫌&#xff0c;从而独立出为一篇文章&#xff0c;仅供参考。$\newcommand{\di}{\mathrm{dist}}$ &#xff08;图1&#xff1a;Kneser叙述他的猜想原文手稿&#xff09;…

python .py文件变为.so文件进行加密

&#xff11;.mytest.py 需要加密的内容 #coding:utf-8 import datetimeclass Today():def get_time(self):print(datetime.datetime.now())def say(self):print("hello word!")today Today() today.say() today.get_time() 2.执行setup.py 也就是加密脚本 from…

从技术上解读大数据的应用现状和开源未来

来源&#xff1a;网络大数据作者 | 韩锐、 Lizy Kurian John、詹剑锋摘要&#xff1a;近年来&#xff0c;随着大数据系统的快速发展&#xff0c;各式各样的开源基准测试集被开发出来&#xff0c;以评测和分析大数据系统并促进其技术改进。然而&#xff0c;迄今为止&#xff0c;…