TensorFlow Study Notes (8): A Recurrent Neural Network (RNN) on the MNIST Dataset

Preface

The input data for this post is MNIST, short for Modified National Institute of Standards and Technology: a dataset of scanned handwritten digits collected by that institute, each paired with its label, and modified so that it is easy for machine-learning algorithms to read. The dataset can be obtained from the renowned Professor Yann LeCun's website.
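To make the dataset concrete, the short sketch below (not part of the original code; the shapes in the comments are those of the standard MNIST split as loaded by TensorFlow's helper) prints what the loader actually returns:

from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST the same way the code later in this post does.
mnist = input_data.read_data_sets('mnist/', one_hot=True)
print mnist.train.images.shape   # (55000, 784): flattened 28x28 grayscale images
print mnist.train.labels.shape   # (55000, 10):  one-hot digit labels
print mnist.test.images.shape    # (10000, 784)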

Earlier posts in this series, following the official TensorFlow tutorials, have already modeled the MNIST dataset with softmax regression and with a CNN. For completeness, this post applies an RNN to the MNIST data as well; the specific RNN used is an LSTM.
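The key idea is to read each 28×28 image as a sequence: 28 time steps (one per row), each carrying 28 pixel values. This is exactly what n_steps = 28 and n_input = 28 encode in the code below. A minimal sketch of that reshaping, in plain NumPy (illustrative only, not from the original code):

import numpy as np

# One flattened MNIST image (784 values) becomes a 28-step sequence of
# 28-dimensional row vectors; a batch of images becomes (batch, 28, 28).
flat_batch = np.zeros((128, 784), dtype=np.float32)  # placeholder batch
seq_batch = flat_batch.reshape((128, 28, 28))        # (batch_size, n_steps, n_input)
print seq_batch.shape                                # (128, 28, 28)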

For the theoretical background on RNNs and LSTMs, see the linked reference.

Code

# coding: utf-8
# @author: 陈水平
# @date: 2017-02-14

# In[1]:
import tensorflow as tf
import numpy as np

# In[2]:
sess = tf.InteractiveSession()

# In[3]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('mnist/', one_hot=True)

# In[4]:
learning_rate = 0.001
batch_size = 128

n_input = 28
n_steps = 28
n_hidden = 128
n_classes = 10

x = tf.placeholder(tf.float32, [None, n_steps, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])

# In[5]:
def RNN(x, weight, biases):
    # x shape: (batch_size, n_steps, n_input)
    # desired shape: list of n_steps with element shape (batch_size, n_input)
    x = tf.transpose(x, [1, 0, 2])
    x = tf.reshape(x, [-1, n_input])
    x = tf.split(0, n_steps, x)
    outputs = list()
    lstm = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
    state = (tf.zeros([n_steps, n_hidden]),) * 2
    sess.run(state)
    with tf.variable_scope("myrnn2") as scope:
        for i in range(n_steps - 1):
            if i > 0:
                scope.reuse_variables()
            output, state = lstm(x[i], state)
            outputs.append(output)
    final = tf.matmul(outputs[-1], weight) + biases
    return final

# In[6]:
def RNN(x, n_steps, n_input, n_hidden, n_classes):
    # Parameters:
    # Input gate: input, previous output, and bias
    ix = tf.Variable(tf.truncated_normal([n_input, n_hidden], -0.1, 0.1))
    im = tf.Variable(tf.truncated_normal([n_hidden, n_hidden], -0.1, 0.1))
    ib = tf.Variable(tf.zeros([1, n_hidden]))
    # Forget gate: input, previous output, and bias
    fx = tf.Variable(tf.truncated_normal([n_input, n_hidden], -0.1, 0.1))
    fm = tf.Variable(tf.truncated_normal([n_hidden, n_hidden], -0.1, 0.1))
    fb = tf.Variable(tf.zeros([1, n_hidden]))
    # Memory cell: input, state, and bias
    cx = tf.Variable(tf.truncated_normal([n_input, n_hidden], -0.1, 0.1))
    cm = tf.Variable(tf.truncated_normal([n_hidden, n_hidden], -0.1, 0.1))
    cb = tf.Variable(tf.zeros([1, n_hidden]))
    # Output gate: input, previous output, and bias
    ox = tf.Variable(tf.truncated_normal([n_input, n_hidden], -0.1, 0.1))
    om = tf.Variable(tf.truncated_normal([n_hidden, n_hidden], -0.1, 0.1))
    ob = tf.Variable(tf.zeros([1, n_hidden]))
    # Classifier weights and biases
    w = tf.Variable(tf.truncated_normal([n_hidden, n_classes]))
    b = tf.Variable(tf.zeros([n_classes]))

    # Definition of the cell computation
    def lstm_cell(i, o, state):
        input_gate = tf.sigmoid(tf.matmul(i, ix) + tf.matmul(o, im) + ib)
        forget_gate = tf.sigmoid(tf.matmul(i, fx) + tf.matmul(o, fm) + fb)
        update = tf.tanh(tf.matmul(i, cx) + tf.matmul(o, cm) + cb)
        state = forget_gate * state + input_gate * update
        output_gate = tf.sigmoid(tf.matmul(i, ox) + tf.matmul(o, om) + ob)
        return output_gate * tf.tanh(state), state

    # Unrolled LSTM loop
    outputs = list()
    state = tf.Variable(tf.zeros([batch_size, n_hidden]))
    output = tf.Variable(tf.zeros([batch_size, n_hidden]))

    # x shape: (batch_size, n_steps, n_input)
    # desired shape: list of n_steps with element shape (batch_size, n_input)
    x = tf.transpose(x, [1, 0, 2])
    x = tf.reshape(x, [-1, n_input])
    x = tf.split(0, n_steps, x)
    for i in x:
        output, state = lstm_cell(i, output, state)
        outputs.append(output)
    logits = tf.matmul(outputs[-1], w) + b
    return logits

# In[7]:
pred = RNN(x, n_steps, n_input, n_hidden, n_classes)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.global_variables_initializer()

# In[8]:
# Launch the graph
sess.run(init)
for step in range(20000):
    batch_x, batch_y = mnist.train.next_batch(batch_size)
    batch_x = batch_x.reshape((batch_size, n_steps, n_input))
    sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
    if step % 50 == 0:
        acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
        loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
        print "Iter " + str(step) + ", Minibatch Loss= " + \
              "{:.6f}".format(loss) + ", Training Accuracy= " + \
              "{:.5f}".format(acc)
print "Optimization Finished!"

# In[9]:
# Calculate accuracy for 128 mnist test images
test_len = batch_size
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
test_label = mnist.test.labels[:test_len]
print "Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label})

The output is as follows:

Iter 0, Minibatch Loss= 2.540429, Training Accuracy= 0.07812
Iter 50, Minibatch Loss= 2.423611, Training Accuracy= 0.06250
Iter 100, Minibatch Loss= 2.318830, Training Accuracy= 0.13281
Iter 150, Minibatch Loss= 2.276640, Training Accuracy= 0.13281
Iter 200, Minibatch Loss= 2.276727, Training Accuracy= 0.12500
Iter 250, Minibatch Loss= 2.267064, Training Accuracy= 0.16406
Iter 300, Minibatch Loss= 2.234139, Training Accuracy= 0.19531
Iter 350, Minibatch Loss= 2.295060, Training Accuracy= 0.12500
Iter 400, Minibatch Loss= 2.261856, Training Accuracy= 0.16406
Iter 450, Minibatch Loss= 2.220284, Training Accuracy= 0.17969
Iter 500, Minibatch Loss= 2.276015, Training Accuracy= 0.13281
Iter 550, Minibatch Loss= 2.220499, Training Accuracy= 0.14062
Iter 600, Minibatch Loss= 2.219574, Training Accuracy= 0.11719
Iter 650, Minibatch Loss= 2.189177, Training Accuracy= 0.25781
Iter 700, Minibatch Loss= 2.195167, Training Accuracy= 0.19531
Iter 750, Minibatch Loss= 2.226459, Training Accuracy= 0.18750
Iter 800, Minibatch Loss= 2.148620, Training Accuracy= 0.23438
Iter 850, Minibatch Loss= 2.122925, Training Accuracy= 0.21875
Iter 900, Minibatch Loss= 2.065122, Training Accuracy= 0.24219
...
Iter 19350, Minibatch Loss= 0.001304, Training Accuracy= 1.00000
Iter 19400, Minibatch Loss= 0.000144, Training Accuracy= 1.00000
Iter 19450, Minibatch Loss= 0.000907, Training Accuracy= 1.00000
Iter 19500, Minibatch Loss= 0.002555, Training Accuracy= 1.00000
Iter 19550, Minibatch Loss= 0.002018, Training Accuracy= 1.00000
Iter 19600, Minibatch Loss= 0.000853, Training Accuracy= 1.00000
Iter 19650, Minibatch Loss= 0.001035, Training Accuracy= 1.00000
Iter 19700, Minibatch Loss= 0.007034, Training Accuracy= 0.99219
Iter 19750, Minibatch Loss= 0.000608, Training Accuracy= 1.00000
Iter 19800, Minibatch Loss= 0.002913, Training Accuracy= 1.00000
Iter 19850, Minibatch Loss= 0.003484, Training Accuracy= 1.00000
Iter 19900, Minibatch Loss= 0.005693, Training Accuracy= 1.00000
Iter 19950, Minibatch Loss= 0.001904, Training Accuracy= 1.00000
Optimization Finished!
Testing Accuracy: 0.992188
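Note that the final cell only measures accuracy on the first 128 test images (test_len = batch_size), because the hand-rolled LSTM allocates its state and output variables with a fixed first dimension of batch_size and therefore only accepts batches of exactly that size. If you want a figure over the whole test set, one workaround (a sketch under the same constraint, not from the original post; names like accs and chunk_x are illustrative) is to average accuracy over successive chunks of 128 images:

# Sketch: evaluate the full test set in chunks of batch_size, since the
# unrolled LSTM above expects batches of exactly batch_size examples.
accs = []
n_test = mnist.test.images.shape[0]
for start in range(0, n_test - batch_size + 1, batch_size):
    chunk_x = mnist.test.images[start:start + batch_size].reshape((-1, n_steps, n_input))
    chunk_y = mnist.test.labels[start:start + batch_size]
    accs.append(sess.run(accuracy, feed_dict={x: chunk_x, y: chunk_y}))
print "Mean test accuracy over %d chunks:" % len(accs), np.mean(accs)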

