Java教程

两秒了解基础RNN模型

本文主要是介绍两秒了解基础RNN模型,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

RNN是一种序列模型,所谓的序列模型就是序列中包含信息。

序列模型的严格定义是:输入或输出中包含序列数据的模型叫做序列模型。

其有两大特点:

  1. 输入(输出)元素之间是具有顺序关系,不同的顺序,得到的结果应该是不同的,比如‘不睡觉’和‘睡觉不’这两个短语的意思是不同的。
  2. 输入输出不定长。比如两聊天机器人,聊天的对话长度是不定的。

普通神经网络模型

       在说RNN的模型结果前,我们先看看简单的普通神经网络的结构(该图就是个向量机模型,其是神经网络的积木,神经网络就是由多个向量机搭建的):

然后再来看看RNN结构和普通神经网络的区别:

        RNN和普通神经网络的区别就是:上一个数据的输出会作为这一个的输入,其和普通神经网络的区别就是在输入层会多N(中间层有几个神经单元,就多几个,比如本图中中间层只有1个神经元,那么它就比普通神经网络在输入层多1个)个。

        还需要注意的是:RNN的输入是一组数据,一般在一组数据中,只有最后一个才有实际数据给到下一层或者作为结果给出。但是如果RNN层的下一层还是RNN结构,那么就需要每一个数据都要输出结果到下一层(在tensorflow中用参数return_sequence控制)。

现在再来看看我们怎样使用Tensorflow搭建一个简单的RNN模型:

首先创建模型(即说明你的神经网络有几层,每一层有几个神经元,神经元见怎样连接):

然后我们可以输出模型结构:

 

 然后编译模型(即说明你以什么作为损失函数,用那种方式来进行优化参数):

 然后就是训练模型:

 

最后就是使用模型进行预测了。

我在github上提交了一个使用RNN进行股票预测的简单模型,其中包含了用于股票预测的股票历史数据(注:本模型不提供任何实际买股建议),其网址为:rotten-meng/Stock_Price_Predict_Simple_RNN: 使用20天股票的高、开、低、收、上影线、下影线、实体相对于昨日价格的涨跌幅来预测后5天最高收盘价的涨跌幅。模型是使用keras中的SimpleRNN搭建,使用训练数据集中的后20%用做测试集。 (github.com)icon-default.png?t=M276https://github.com/rotten-meng/Stock_Price_Predict_Simple_RNN

这篇关于两秒了解基础RNN模型的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!