RNN知识+LSTM知识+encoder-decoder+ctc+基于pytorch的crnn网络结构

一.基础知识:

下图是一个循环神经网络实现语言模型的示例,可以看出其是基于当前的输入与过去的输入序列,预测序列的下一个字符.

序列特点就是某一步的输出不仅依赖于这一步的输入,还依赖于其他步的输入或输出.

其中n为批量大小,d为词向量大小

1.RNN:

xt不止与该时刻输入有关还与上一时刻的输出状态有关,而第t层的误差函数跟输出Ot直接相关,而Ot依赖于前面每一层的xi和si,故存在梯度消失或梯度爆炸的问题,对于长时序很难处理.所以可以进行改造让第t层的误差函数只跟该层{si,xi}有关.

RNN代码简单实现:

def one_hot(x, n_class, dtype=torch.float32):result = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)  # shape: (n, n_class)result.scatter_(1, x.long().view(-1, 1), 1)  # result[i, x[i, 0]] = 1return resultdef to_onehot(X, n_class):return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]def get_parameters(num_inputs, num_hiddens,num_outputs):def init_parameter(shape):param = torch.zeros(shape, device=device,dtype=torch.float32)nn.init.normal_(param, 0, 0.01)return torch.nn.Parameter(param)#权重参数w_xh = init_parameter((num_inputs, num_hiddens))w_hh = init_parameter((num_hiddens, num_hiddens))b_h = torch.nn.Parameter(torch.zeros(num_hiddens,device=device))#输出层参数w_hq = init_parameter((num_hiddens, num_outputs))b_q = torch.nn.Parameter(torch.zeros(num_outputs,device=device))return (w_xh, w_hh, b_h, w_hq, b_q)def rnn(inputs,state,params):w_xh, w_hh, b_h, w_hq, b_q = paramsH = stateoutputs = []for x in inputs:print('===x:', x)        #(batch_size,vocab_size) (vocab_size, num_hiddens)H = torch.tanh(torch.matmul(x, w_xh)+torch.matmul(H, w_hh)+b_h)# (batch_size,num_hiddens) (num_hiddens, num_hiddens)Y = torch.matmul(H, w_hq)+b_q# (batch_size,num_hiddens) (num_hiddens, num_outputs)outputs.append(Y)return outputs, Hdef init_rnn_state(batch_size, num_hiddens,device):return torch.zeros((batch_size, num_hiddens),device=device)def test_one_hot():X = torch.arange(10).view(2, 5)print('==X:', X)inputs = to_onehot(X, 10)print(len(inputs))print('==inputs:', inputs)# print('==inputs:', inputs[-1].shape)def test_rnn():X = torch.arange(5).view(1, 5)print('===X:', X)num_hiddens = 256vocab_size = 10#词典长度num_inputs, num_hiddens, num_outputs = vocab_size, num_hiddens, vocab_sizestate = init_rnn_state(X.shape[0], num_hiddens, device)inputs = to_onehot(X.to(device), vocab_size)print('===len(inputs), inputs', len(inputs), inputs)params = get_parameters(num_inputs, num_hiddens, num_outputs)outputs, state_new = rnn(inputs, state, params)print('==len(outputs), outputs[0].shape:', len(outputs), outputs[0].shape)print('==state.shape:', state.shape)print('==state_new.shape:', state_new.shape)if __name__ == '__main__':# test_one_hot()test_rnn()

2.LSTM:

传统RNN每个模块内只是一个简单的tanh层:

遗忘门:控制上一时间步的记忆细胞;

输入门:控制当前时间步的输入;
输出门:控制从记忆细胞到隐藏状态;
记忆细胞:⼀种特殊的隐藏状态的信息的流动,表示的是长期记忆;

h 是隐藏状态,表示的是短期记忆;

LSTM每个循环的模块内又有4层结构:3个sigmoid层,1个tanh层

细胞状态Ct,类似short cut信息流通畅顺,故可以解决梯度消失或爆炸的问题.

遗忘层,决定信息保留多少

更新层,这里要注意的是用了tanh,值域在-1,1,起到信息加强和减弱的作用.

输出层,上述两层的信息相加流通到这里以后,经过tanh函数得到输出值候选项,而候选项中的哪些部分最终会被输出由一个sigmoid层来决定.这时就得到了输出状态和输出值,下一时刻也是如此.

 

LSTM简单实现代码:

def one_hot(x, n_class, dtype=torch.float32):result = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)  # shape: (n, n_class)result.scatter_(1, x.long().view(-1, 1), 1)  # result[i, x[i, 0]] = 1return resultdef to_onehot(X, n_class):return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]def get_parameters(num_inputs, num_hiddens,num_outputs):def init_parameter(shape):param = torch.zeros(shape, device=device,dtype=torch.float32)nn.init.normal_(param, 0, 0.01)return torch.nn.Parameter(param)def final_init_parameter():return (init_parameter((num_inputs, num_hiddens)),init_parameter((num_hiddens, num_hiddens)),torch.nn.Parameter(torch.zeros(num_hiddens,device=device,dtype=torch.float32,requires_grad=True)))w_xf, w_hf, b_f = final_init_parameter()#遗忘门参数w_xi, w_hi, b_i = final_init_parameter()#输入门参数w_xo, w_ho, b_o = final_init_parameter()#输出门参数w_xc, w_hc, b_c = final_init_parameter()#记忆门参数w_hq = init_parameter((num_hiddens, num_outputs))#输出层参数b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32, requires_grad=True))return nn.ParameterList([w_xi, w_hi, b_i, w_xf, w_hf, b_f, w_xo, w_ho, b_o, w_xc, w_hc, b_c, w_hq, b_q])def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))def lstm(inputs, states, params):[w_xi, w_hi, b_i, w_xf, w_hf, b_f, w_xo, w_ho, b_o, w_xc, w_hc, b_c, w_hq, b_q] = params[H, C] = statesoutputs = []for x in inputs:print('===x:',x)I = torch.sigmoid(torch.matmul(x, w_xi) + torch.matmul(H, w_hi) + b_i)#输入门数据F = torch.sigmoid(torch.matmul(x, w_xf) + torch.matmul(H, w_hf) + b_f)#遗忘门数据O = torch.sigmoid(torch.matmul(x, w_xo) + torch.matmul(H, w_ho) + b_o)#输出门数据C_tila = torch.tanh(torch.matmul(x, w_xc) + torch.matmul(H, w_hc) + b_c)#C冒数据C = F*C + I*C_tilaH = torch.tanh(C)*O# print('H.shape', H.shape)# print('w_hq.shape', w_hq.shape)# print('b_q.shape:', b_q.shape)Y = torch.matmul(H, w_hq)+b_qoutputs.append(Y)return outputs, (H,C)def test_lstm():batch_size = 1X = torch.arange(5).view(batch_size, 5)print('===X:', X)num_hiddens = 256vocab_size = 10  # 词典长度inputs = to_onehot(X.to(device), vocab_size)print('===len(inputs), inputs', len(inputs), inputs)num_inputs, num_hiddens, num_outputs = vocab_size, num_hiddens, vocab_sizestates = init_lstm_state(batch_size, num_hiddens, device='cpu')params = get_parameters(num_inputs, num_hiddens, num_outputs)outputs, new_states = lstm(inputs, states, params)H, C = new_statesprint('===H.shape', H.shape)print('===C.shape', C.shape)print('===len(outputs), outputs[0].shape:', len(outputs), outputs[0].shape)
if __name__ == '__main__':# test_one_hot()# test_rnn()test_lstm()

 

3.Seq2seq模型在于,encoder层,由双层lstm实现隐藏状态编码信息,decoder层由双层lstm将encode层隐藏状态编码信息解码出来,这样也造成了decoder依赖最终时间步的隐藏状态,且RNN机制实际中存在长程梯度消失的问题,对于较长的句子,所以随着所需翻译句子的长度的增加,这种结构的效果会显著下降,也就引入后面的attention。与此同时,解码的目标词语可能只与原输入的部分词语有关,而并不是与所有的输入有关。  例如,当把“Hello world”翻译成“Bonjour le monde”时,“Hello”映射成“Bonjour”,“world”映射成“monde”。 # 在seq2seq模型中, 解码器只能隐式地从编码器的最终状态中选择相应的信息。然而,注意力机制可以将这种选择过程显式地建模。

Seq2seq代码案例,batch为4,单词长度为7,每个单词对应的embedding向量为8,lstm为两层

import torch.nn as nn
import d2l
import torch
import math#由于依赖最终时间步的隐藏状态,RNN机制实际中存在长程梯度消失的问题,对于较长的句子,
# 我们很难寄希望于将输入的序列转化为定长的向量而保存所有的有效信息,
# 所以随着所需翻译句子的长度的增加,这种结构的效果会显著下降。
#与此同时,解码的目标词语可能只与原输入的部分词语有关,而并不是与所有的输入有关。
# 例如,当把“Hello world”翻译成“Bonjour le monde”时,“Hello”映射成“Bonjour”,“world”映射成“monde”。
# 在seq2seq模型中,
# 解码器只能隐式地从编码器的最终状态中选择相应的信息。然而,注意力机制可以将这种选择过程显式地建模。#双层lstm实现隐藏层编码信息encode
class Seq2SeqEncoder(d2l.Encoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqEncoder, self).__init__(**kwargs)self.num_hiddens = num_hiddensself.num_layers = num_layersself.embedding = nn.Embedding(vocab_size, embed_size)#每个字符编码成一个向量self.rnn = nn.LSTM(embed_size, num_hiddens, num_layers, dropout=dropout, batch_first=False)def begin_state(self, batch_size, device):#(H, C)return [torch.zeros(size=(self.num_layers, batch_size, self.num_hiddens), device=device),torch.zeros(size=(self.num_layers, batch_size, self.num_hiddens), device=device)]def forward(self, X, *args):X = self.embedding(X)  # X shape: (batch_size, seq_len, embed_size)print('===encode X.shape', X.shape)X = X.transpose(0, 1)  # (seq_len, batch_size, embed_size)print('===encode X.shape', X.shape)state = self.begin_state(X.shape[1], device=X.device)out, state = self.rnn(X,state)print('===encode out.shape:', out.shape)#(seq_len, batch_size, num_hiddens)H, C = stateprint('===encode H.shape:', H.shape)#(num_layers, batch_size, num_hiddens)print('===encode C.shape:', C.shape)#(num_layers, batch_size, num_hiddens)return out, state#双层lstm将encode层隐藏层信息解码出来
class Seq2SeqDecoder(d2l.Decoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqDecoder, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.LSTM(embed_size, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, *args):return enc_outputs[1]def forward(self, X, state):X = self.embedding(X).transpose(0, 1)print('==decode X.shape', X.shape)# (seq_len, batch_size, embed_size)out, state = self.rnn(X, state)print('==decode out.shape:', out.shape)# (seq_len, batch_size, num_hiddens)H, C = stateprint('==decode H.shape:', H.shape)  # (num_layers, batch_size, num_hiddens)print('==decode C.shape:', C.shape)  # (num_layers, batch_size, num_hiddens)# Make the batch to be the first dimension to simplify loss computation.out = self.dense(out).transpose(0, 1)# (batch_size, seq_len, vocab_size)print('==decode final out.shape', out.shape)return out, statedef SequenceMask(X, X_len,value=0):print(X)print(X_len)print(X_len.device)maxlen = X.size(1)print('==torch.arange(maxlen)[None, :]:', torch.arange(maxlen)[None, :])print('==X_len[:, None]:', X_len[:, None])mask = torch.arange(maxlen)[None, :] < X_len[:, None]print(mask)X[~mask] = valueprint('X:', X)return Xdef masked_softmax(X, valid_length):# X: 3-D tensor, valid_length: 1-D or 2-D tensorsoftmax = nn.Softmax(dim=-1)if valid_length is None:return softmax(X)else:shape = X.shapeif valid_length.dim() == 1:try:valid_length = torch.FloatTensor(valid_length.numpy().repeat(shape[1], axis=0))  # [2,2,3,3]except:valid_length = torch.FloatTensor(valid_length.cpu().numpy().repeat(shape[1], axis=0))  # [2,2,3,3]else:valid_length = valid_length.reshape((-1,))# fill masked elements with a large negative, whose exp is 0X = SequenceMask(X.reshape((-1, shape[-1])), valid_length)return softmax(X).reshape(shape)
class MLPAttention(nn.Module):def __init__(self, ipt_dim, units, dropout, **kwargs):super(MLPAttention, self).__init__(**kwargs)# Use flatten=True to keep query's and key's 3-D shapes.self.W_k = nn.Linear(ipt_dim, units, bias=False)self.W_q = nn.Linear(ipt_dim, units, bias=False)self.v = nn.Linear(units, 1, bias=False)self.dropout = nn.Dropout(dropout)def forward(self, query, key, value, valid_length):query, key = self.W_k(query), self.W_q(key)print("==query.size, key.size::", query.size(), key.size())# expand query to (batch_size, #querys, 1, units), and key to# (batch_size, 1, #kv_pairs, units). Then plus them with broadcast.print('query.unsqueeze(2).shape', query.unsqueeze(2).shape)print('key.unsqueeze(1).shape', key.unsqueeze(1).shape)features = query.unsqueeze(2) + key.unsqueeze(1)#print("features:",features.size())  #--------------开启scores = self.v(features).squeeze(-1)print('===scores:', scores.shape)attention_weights = self.dropout(masked_softmax(scores, valid_length))return torch.bmm(attention_weights, value)def test_encoder():encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)X = torch.zeros((4, 7), dtype=torch.long)  # (batch_size, seq_len)output, state = encoder(X)def test_decoder():X = torch.zeros((4, 7), dtype=torch.long)  # (batch_size, seq_len)encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)state = decoder.init_state(encoder(X))out, state = decoder(X, state)def test_loss():X = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])SequenceMask(X, torch.FloatTensor([2, 3]))def test_dot():keys = torch.ones((2, 10, 2), dtype=torch.float)values = torch.arange((40), dtype=torch.float).view(1, 10, 4).repeat(2, 1, 1)print('==values.shape:', values.shape)# print(values)atten = MLPAttention(ipt_dim=2, units=8, dropout=0)atten(torch.ones((2, 1, 2), dtype=torch.float), keys, values, torch.FloatTensor([2, 6]))if __name__ == '__main__':test_encoder()# test_decoder()

encode输出: 

decode输出:

二.基于pytorch的crnn网络结构

地址:https://github.com/zonghaofan/crnn_pytorch

1.网络图:

首先卷积提取特征以后再用两层双向lstm提取时序特征

2.代码实现

import torch.nn as nn
import torch.nn.functional as F
import torchclass BiLSTM(nn.Module):def __init__(self,nIn,nHidden,nOut):super(BiLSTM,self).__init__()self.lstm=nn.LSTM(input_size=nIn,hidden_size=nHidden,bidirectional=True)self.embdding=nn.Linear(nHidden*2,nOut)#Sequence batch channels (W,b,c)def forward(self, input):recurrent,_=self.lstm(input)S,b,h=recurrent.size()S_line = recurrent.view(S*b,h)output=self.embdding(S_line)#[S*b,nout]output=output.view(S,b,-1)return outputclass CRNN(nn.Module):def __init__(self,imgH,imgC,nclass,nhidden):assert imgH==32super(CRNN,self).__init__()cnn = nn.Sequential()cnn.add_module('conv{}'.format(0), nn.Conv2d(imgC, 64, 3, 1, 1))cnn.add_module('relu{}'.format(0), nn.ReLU(True))cnn.add_module('pooling{}'.format(0),nn.MaxPool2d(2,2))cnn.add_module('conv{}'.format(1), nn.Conv2d(64, 128, 3, 1, 1))cnn.add_module('relu{}'.format(1), nn.ReLU(True))cnn.add_module('pooling{}'.format(1), nn.MaxPool2d(2, 2))cnn.add_module('conv{}'.format(2), nn.Conv2d(128, 256, 3, 1, 1))cnn.add_module('relu{}'.format(2), nn.ReLU(True))cnn.add_module('conv{}'.format(3), nn.Conv2d(256, 256, 3, 1, 1))cnn.add_module('relu{}'.format(3), nn.ReLU(True))cnn.add_module('pooling{}'.format(3), nn.MaxPool2d((1,2), 2))cnn.add_module('conv{}'.format(4), nn.Conv2d(256, 512, 3, 1, 1))cnn.add_module('relu{}'.format(4), nn.ReLU(True))cnn.add_module('BN{}'.format(4), nn.BatchNorm2d(512))cnn.add_module('conv{}'.format(5), nn.Conv2d(512, 512, 3, 1, 1))cnn.add_module('relu{}'.format(5), nn.ReLU(True))cnn.add_module('BN{}'.format(5), nn.BatchNorm2d(512))cnn.add_module('pooling{}'.format(5), nn.MaxPool2d((1, 2), 2))cnn.add_module('conv{}'.format(6), nn.Conv2d(512, 512, 2, 1, 0))cnn.add_module('relu{}'.format(6), nn.ReLU(True))self.cnn=cnnself.rnn=nn.Sequential(BiLSTM(512,nhidden,nhidden),BiLSTM(nhidden, nhidden, nclass))def forward(self,input):conv = self.cnn(input)print('conv.size():',conv.size())b,c,h,w=conv.size()assert h==1conv=conv.squeeze(2)#b ,c wconv=conv.permute(2,0,1) #w,b,crnn_out=self.rnn(conv)print('rnn_out.size():',rnn_out.size())out=F.log_softmax(rnn_out,dim=2)print('out.size():',out.size())return out
def lstm_test():print('===================LSTM===========================')model = BiLSTM(512, 256, 5600)print(model)x = torch.rand((41, 32, 512))print('input:', x.size())out = model(x)print(out.size())
def crnn_test():print('===================CRNN===========================')model = CRNN(32, 1, 3600, 256)print(model)x = torch.rand((32, 1, 32, 200))  # b c h wprint('input:', x.size())out = model(x)print(out.size())
if __name__ == '__main__':lstm_test()crnn_test()

lstm输出:

#crnn输出

3.提特征输入ctc过程

上面代码可看成,输入为(32,1,32,200)cnn提取特征过后,每张图片11个特征向量,每个特征向量长度为512,在LSTM中一个时间步就传入一个特征向量进行分类。一个特征向量就相当于原图中的一个小矩形区域,RNN的目标就是预测这个矩形区域为哪个字符,即根据输入的特征向量,进行预测,得到所有字符的softmax概率分布,这是一个长度为字符类别数的向量,作为CTC层的输入。如下图所示就是输入ctc的示例图

 

4.ctc loss

首先思考ctc解决什么问题,一般分类就是一张图片对应一类,那样拉一个全连接进行softmax即可分类,对于这种一张图片有好几个字符,上述就解决不了,故有一种思路是将输入图片的字符切割出来在进行分类,那这样的问题是分割不准怎么办?所以面临这种输入类别不定长的时候,就可以利用ctc进行解决,ctc的思想就是将输入图片提取特征变成时序步长,给出输入时序步长X的所有可能结果Y的输出分布。那么根据这个分布,我们可以输出最可能的结果。大胆猜测对于有一张个头差不多的猫猫与狗狗水平挨着的图片,ctc可能也能解决分类问题.

ctc的损失函数可以对CNN和RNN进行端到端的联合训练。

RNN这里有去冗余操作,例如,上图中RNN中有5个时间步,但最终输出两个字符,理想情况下 t0, t1, t2时刻都应映射为“a”,t3, t4 时刻都应映射为“b”,然后连接起来得到“aaabb”,那么合并结果为“ab”。但是在识别book这类字符会有问题。最后以“-”符号代表blank,RNN 输出序列时,在文本标签中的重复的字符之间插入一个“-”,比如输出序列为“bbooo-ookk”,则最后将被映射为“book”,即有blank字符隔开的话,连续相同字符就不进行合并。即对字符序列先删除连续重复字符,然后从路径中删除所有“-”字符,这个称为解码过程,而编码则是由神经网络来实现。引入blank机制,我们就可以很好地解决重复字符的问题。

4.1训练过程

其中t0,t1代表两个时间步长,黑色线代表a字符的路径,虚线代表空文本路径。

例如:对于时序步长为2的识别,有两个时间步长(t0,t1)和三个可能的字符为“a”,“b”和“-”,我们得到两个概率分布向量,如果采取最大概率路径解码的方法,则“--”的概率最大,即真实字符为空的概率为0.6*0.6=0.36。

但是为字符“a”的情况有多种组合,“aa”, “a-“和“-a”都是代表“a”,所以,输出“a”的概率应该为三种之和:

0.4*0.4+0.4*0.6+0.4*0.6=0.64 ,故a的概率最高。如果标签文本为“a”,则通过计算图像中为“a”的所有可能的对齐组合(或者路径)的分数之和来计算损失函数。

对于RNN给定输入概率分布矩阵为y={y1,y2,...,yT},T是序列长度,最后映射为标签文本l的总概率为:

其中B(π)代表从序列到序列的映射函数B变换后是文本l的所有路径集合,而π则是其中的一条路径。每条路径的概率为各个时间步中对应字符的分数的乘积。然后训练网络使得这个概率值最大化,类似于普通的分类,CTC的损失函数定义为概率的负最大似然函数,为了计算方便,对似然函数取对数。

然后通过对损失函数的计算,就可以对之前的神经网络进行反向传播,神经网络的参数根据所使用的优化器进行更新,从而找到最可能的像素区域对应的字符。这种通过映射变换和所有可能路径概率之和的方式使得CTC不需要对原始的输入字符序列进行准确的切分。

4.2推理过程

推理阶段,过程与训练阶段有所不同,我们用训练好的神经网络来识别新的文本图像。如果我们像上面一样将每种可能文本的所有路径计算出来,对于很长的时间步和很长的字符序列来说,计算量是非常庞大。

由于RNN在每一个时间步的输出为所有字符类别的概率分布,所以,我们取其中最大概率的字符作为该时间步的输出字符,然后将所有时间步得到一个字符进行拼接得到一个序列路径,即最大概率路径,再根据上面介绍的合并序列方法得到最终的预测文本结果。

如上图5个时间步长,输出结果为a->a->a->blank>b,合并去重结果就为ab。要注意的是字符之间有间距,需要添加blank。

4.3ctc loss代码示例

import torch
from torch import nnT = 50      #时序步长
C = 20      #类别数 排除blank
N = 2      # Batch size
S = 30      #一个batch中的label的最大时序步长
S_min = 10  #一个batch中label的最小字符个数# Initialize random batch of input vectors, for *size = (T,N,C)
#rnn输出结果
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
print('==input.shape:', input.shape)
#字符对应的label idx
target = torch.randint(low=1, high=C+1, size=(N, S), dtype=torch.long)
print('==target:', target)
#序列长度的值
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
print('==input_lengths.shape:', input_lengths.shape)
print('=input_lengths:', input_lengths)
#字符的长度
target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
print('==target_lengths.shape:', target_lengths.shape)
print('==target_lengths:', target_lengths)ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)

5.一些可能改动的点

最后两层pooling设置为h=1,w=2的矩形,是因为文本大多数是高小而宽长,这样就可以不丢失宽度信息,利于区分i和L. 

如果数字过小,那就可以让横向长度不变,pool可以换成如下,这样横向长度基本不变,纵向减少两倍。

pool2 = nn.MaxPool2d((2, 2), (2, 1), (0, 1))
x=torch.rand((32,1,32,100))
print('=========input========')
print(x.shape)
print('=========output========')
pool = nn.MaxPool2d(kernel_size=(2,2),stride=(2,2))
y = pool(x)
print(y.shape)
#                   (h-2)/2+1          (w-1)/1+1
pool = nn.MaxPool2d(kernel_size=(2,1),stride=(2,1))
y = pool(x)
print(y.shape)#                   (h-2)+2*p/2+1          (w-2)+2*p/1+1
pool = nn.MaxPool2d(kernel_size=(2,2),stride=(2,1),padding=(1,0))
y = pool(x)
print(y.shape)

6.finetune新加字符

由于原先数据集不一定找得到,对于新加的字符,对除了最后一层的全连接进行冻结,例如原先最后一层是(512,5000),现在新加10个字符,变为(512,5010),则将原先的那一层权重矩阵平移过来.只需要训练(512,10)的矩阵.

参考:

http://colah.github.io/posts/2015-08-Understanding-LSTMs/

https://www.cnblogs.com/zhangchaoyang/articles/6684906.html

https://aijishu.com/a/1060000000135614

https://www.cnblogs.com/ydcode/p/11038064.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/493245.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

利用flask写的接口(base64, 二进制, 上传视频流)+异步+gunicorn部署Flask服务+多gpu卡部署

一.flask写的接口 1.1 manage.py启动服务(发送图片base64版) 这里要注意的是用docker的话,记得端口映射 #coding:utf-8 import base64 import io import logging import picklefrom flask import Flask, jsonify, request from PIL import Image from sklearn import metric…

2018中国自动驾驶市场专题分析

来源&#xff1a;智车科技未来智能实验室是人工智能学家与科学院相关机构联合成立的人工智能&#xff0c;互联网和脑科学交叉研究机构。未来智能实验室的主要工作包括&#xff1a;建立AI智能系统智商评测体系&#xff0c;开展世界人工智能智商评测&#xff1b;开展互联网&#…

python写日志

需要再加入按照日期生成日志 #coding:utf-8 import logging import logging.handlers class Logger:logFile def __init__(self, logFile):self.logFile logFileself.logger logging.getLogger(mylogger)self.logger.setLevel(logging.INFO)rf_handler logging.handlers.…

MIT科学家Dimitri P. Bertsekas最新2019出版《强化学习与最优控制》(附书稿PDF讲义)...

来源&#xff1a;专知摘要&#xff1a;MIT科学家Dimitri P. Bertsekas今日发布了一份2019即将出版的《强化学习与最优控制》书稿及讲义&#xff0c;该专著目的在于探索这人工智能与最优控制的共同边界&#xff0c;形成一个可以在任一领域具有背景的人员都可以访问的桥梁。REINF…

yolov3 anchors用kmeans聚类出先验框+anchor宽高比分析

一&#xff0e;yolov v3聚类出框 # -*- coding: utf-8 -*- import numpy as np import random import argparse import os# # 参数名称 # parser argparse.ArgumentParser(description使用该脚本生成YOLO-V3的anchor boxes\n) # parser.add_argument(--input_annotation_txt…

Geoff Hinton:全新的想法将比微小的改进更有影响力

来源&#xff1a;AI科技评论摘要&#xff1a;日前&#xff0c;WIRED 对 Hinton 进行了一次专访&#xff0c;在访谈中&#xff0c;WIRED 针对人工智能带来的道德挑战和面临的挑战等问题进行了提问&#xff0c;以下为谈话内容。“作为一名谷歌高管&#xff0c;我认为在公开场合抱…

修改TOMCAT服务器图标为应用LOGO

在tomcat下部署应用程序&#xff0c;运行后&#xff0c;发现在地址栏中会显示tomcat的小猫咪图标。有时候&#xff0c;我们自己不想显示这个图标&#xff0c;想换成自己定义的的图标&#xff0c;那么按如下方法操作即可&#xff1a; 参考网上的解决方案&#xff1a;1、将$TOMCA…

python连接mysql的一些基础知识+安装Navicat可视化数据库+flask_sqlalchemy写数据库

一&#xff0e;mysql基础知识 &#xff11;&#xff0e;connect连接数据库 import pymysqldef get_conn():conn pymysql.connect(hostxxx.xxx.xxx.xxx, port3306, userroot, passwd, dbnewspaper_rest) # db:表示数据库名称return conn &#xff12;&#xff0e;创建表 im…

工业互联网平台创新发展白皮书(2018)

来源&#xff1a;走向智能论坛摘要&#xff1a;近日&#xff0c;在“2018年产业互联网与数据经济大会——首届工业互联网平台创新发展暨两化融合推进会”上&#xff0c;国家工业信息安全发展研究中心尹丽波主任发布并解读了《工业互联网平台创新发展白皮书&#xff08;2018&…

迭代器模式和组合模式混用

迭代器模式和组合模式混用 前言 园子里说设计模式的文章算得上是海量了&#xff0c;所以本篇文章所用到的迭代器设计模式和组合模式不提供原理解析&#xff0c;有兴趣的朋友可以到一些前辈的设计模式文章上学学&#xff0c;很多很有意思的。在Head First 设计模式这本书中&…

python实现可扩容队列

#coding:utf-8 """ fzh created on 2019/10/15 构建一个队列 """ import datetimeclass LoopQueue(object):def __init__(self, n10):self.arr [None] * (n1) # 由于特意浪费了一个空间&#xff0c;所以arr的实际大小应该是用户传入的容量1sel…

5G 产业链重要投资节点

来源&#xff1a;兴业证券 ▌5G:大通信容量及超低延时&#xff0c;未来多项应用的基础5G:高工作频率以及频谱带宽带来高通信容量5G(5thgeneration)是指第五代移动电话通信标准。3GPP(第三代合作伙伴计划&#xff0c;电信标准化机构)将5G标准分为了NSA(非独立组网)和SA(独立组网…

Kneser猜想与相关推广

本文本来是想放在Borsuk-Ulam定理的应用这篇文章当中。但是这个文章实在是太长&#xff0c;导致有喧宾夺主之嫌&#xff0c;从而独立出为一篇文章&#xff0c;仅供参考。$\newcommand{\di}{\mathrm{dist}}$ &#xff08;图1&#xff1a;Kneser叙述他的猜想原文手稿&#xff09;…

python .py文件变为.so文件进行加密

&#xff11;.mytest.py 需要加密的内容 #coding:utf-8 import datetimeclass Today():def get_time(self):print(datetime.datetime.now())def say(self):print("hello word!")today Today() today.say() today.get_time() 2.执行setup.py 也就是加密脚本 from…

从技术上解读大数据的应用现状和开源未来

来源&#xff1a;网络大数据作者 | 韩锐、 Lizy Kurian John、詹剑锋摘要&#xff1a;近年来&#xff0c;随着大数据系统的快速发展&#xff0c;各式各样的开源基准测试集被开发出来&#xff0c;以评测和分析大数据系统并促进其技术改进。然而&#xff0c;迄今为止&#xff0c;…

十八岁华裔天才携手「量子计算先驱」再次颠覆量子计算

来源&#xff1a;机器之心编译参与&#xff1a;刘晓坤、李泽南摘要&#xff1a;量子计算再一次「被打败了」。今年 8 月&#xff0c;刚刚年满 18 岁的 Ewin Tang 证明了经典算法能以和量子计算机相近的速度解决推荐问题&#xff0c;这位天才少女&#xff08;更正&#xff1a;不…

resnet系列+mobilenet v2+pytorch代码实现

一.resnet系列backbone import torch.nn as nn import math import torch.utils.model_zoo as model_zooBatchNorm2d nn.BatchNorm2d__all__ [ResNet, resnet18, resnet34, resnet50, resnet101, deformable_resnet18, deformable_resnet50,resnet152]model_urls {resnet18:…

广度优先搜索(BFS)与深度优先搜索(DFS)

一.广度优先搜索&#xff08;BFS&#xff09; 1.二叉树代码 # 实现一个二叉树 class TreeNode:def __init__(self, x):self.val xself.left Noneself.right Noneself.nexts []root_node TreeNode(1) node_2 TreeNode(2) node_3 TreeNode(3) node_4 TreeNode(4) node_…

骁龙855在AI性能上真的秒杀麒麟980?噱头而已

来源&#xff1a;网易智能摘要&#xff1a;前段时间的高通发布会上&#xff0c;有关骁龙855 AI性能达到友商竞品两倍的言论可谓是赚足了眼球。高通指出&#xff0c;骁龙855针对CPU、GPU、DSP都进行了AI计算优化&#xff0c;结合第四代AI引擎可以实现每秒超过7万亿次运算&#x…

MySQL主从复制(Master-Slave)与读写分离(MySQL-Proxy)实践 转载

http://heylinux.com/archives/1004.html MySQL主从复制&#xff08;Master-Slave&#xff09;与读写分离&#xff08;MySQL-Proxy&#xff09;实践 Mysql作为目前世界上使用最广泛的免费数据库&#xff0c;相信所有从事系统运维的工程师都一定接触过。但在实际的生产环境中&am…