admin管理员组文章数量:1122907
LSTM
封装成类
import tensorflow_pass_warning_wgs
import tensorflow as tf, numpy
import tensorflow.contrib as con

tf.set_random_seed(1)


class model():
    """Character-level next-character predictor built on a 2-layer LSTM.

    The sentence is cut into overlapping windows of ``sequnec_length``
    characters; each window's target is the same window shifted one
    character to the right. ``train()`` fits the network on all windows at
    once and prints the sentence reconstructed from the predictions.

    Attributes set by ``__init__``:
        sentence: the raw training text.
        sequnec_length: window length (spelling kept for existing callers).
        idx2char: index -> character vocabulary.
        X_data, Y_data: lists of index windows (inputs / shifted targets).
        hidden_size: number of units per LSTM cell.
        num_classes: vocabulary size.
        batch_size: number of windows.
    """

    def __init__(self, sentence):
        self.sentence = sentence
        self.sequnec_length = 10  # window length
        # get_data also populates self.idx2char, which num_classes needs below.
        self.X_data, self.Y_data = self.get_data(sentence)
        self.hidden_size = 50  # units per LSTM cell
        self.num_classes = len(self.idx2char)
        self.batch_size = len(self.X_data)

    def get_data(self, sentence):
        """Build the vocabulary and the (input, target) index windows.

        Sets ``self.idx2char`` as a side effect and prints every window pair.

        Returns:
            A tuple ``(x_data, y_data)`` of equally long lists; each element is
            a list of ``sequnec_length`` character indices. Both lists are
            empty when the sentence is not longer than one window.
        """
        # Sorted so the char <-> index mapping is identical on every run; a bare
        # set() iterates in hash order, which changes between processes.
        self.idx2char = sorted(set(sentence))
        char_dict = {w: i for i, w in enumerate(self.idx2char)}

        x_data, y_data = [], []
        for i in range(len(sentence) - self.sequnec_length):
            x_str = sentence[i: i + self.sequnec_length]
            # Target is the input window shifted by one character.
            y_str = sentence[i + 1: i + self.sequnec_length + 1]
            print(i, x_str, '->', y_str)
            x_data.append([char_dict[c] for c in x_str])
            y_data.append([char_dict[c] for c in y_str])
        return x_data, y_data

    def _decode(self, logits):
        """Rebuild one string from per-window logits.

        Takes every character of the first window and only the last character
        of each following window, since consecutive windows overlap.

        Args:
            logits: numpy array indexed as (window, position, class).
        """
        index = numpy.argmax(logits, axis=2)
        chars = [self.idx2char[c] for c in index[0]]
        chars.extend(self.idx2char[c] for c in index[1:, -1])
        return ''.join(chars)

    def train(self):
        """Build the graph and train for up to 500 steps.

        Each step prints the step number, the loss and the reconstructed
        text; training stops early once the reconstruction equals
        ``sentence[1:]``.

        Raises:
            ValueError: if the sentence produced no training windows.
        """
        if not self.X_data:
            raise ValueError(
                'sentence must be longer than %d characters to build training windows'
                % self.sequnec_length)

        X = tf.placeholder(tf.int32, [None, self.sequnec_length])
        Y = tf.placeholder(tf.int32, [None, self.sequnec_length])
        X_oneHot = tf.one_hot(X, self.num_classes)

        # Two stacked LSTM cells form the deep RNN.
        cells = [con.rnn.LSTMCell(num_units=self.hidden_size) for _ in range(2)]
        mul_cells = con.rnn.MultiRNNCell(cells)
        outputs, state = tf.nn.dynamic_rnn(mul_cells, X_oneHot, dtype=tf.float32)

        # Flatten the time steps so one dense layer scores every position.
        outputs = tf.reshape(outputs, shape=[-1, self.hidden_size])
        logits = con.layers.fully_connected(outputs, self.num_classes, activation_fn=None)
        # -1 / ones_like keep the graph independent of the fed batch size,
        # matching the None batch dimension of the placeholders.
        logits = tf.reshape(logits, shape=[-1, self.sequnec_length, self.num_classes])
        weight = tf.ones_like(Y, dtype=tf.float32)  # every position counts equally
        cost = tf.reduce_mean(con.seq2seq.sequence_loss(logits=logits, targets=Y, weights=weight))
        optimizer = tf.train.AdamOptimizer(0.1).minimize(cost)

        feed = {X: self.X_data, Y: self.Y_data}
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for i in range(500):
                c_, lo_, _ = sess.run([cost, logits, optimizer], feed_dict=feed)
                ret = self._decode(lo_)
                print(i, c_, ret)
                if ret == self.sentence[1:]:
                    break


sentence = ("if you want to build a ship, don't drum up people together to "
            "collect wood and don't assign them tasks and work, but rather "
            "teach them to long for the endless immensity of the sea.")
m = model(sentence)
m.train()
普通方法
from __future__ import print_function

import tensorflow as tf
import numpy as np
from tensorflow.contrib import rnn
import tensorflow.contrib as rnnn
from tensorflow.python.ops.rnn import dynamic_rnn

# Character-level LSTM: learn to predict the next character of `sentence`
# from overlapping fixed-length windows, then print the reconstructed text.

tf.set_random_seed(777)  # reproducibility

sentence = ("if you want to build a ship, don't drum up people together to "
            "collect wood and don't assign them tasks and work, but rather "
            "teach them to long for the endless immensity of the sea.")

# Sorted so the char <-> index mapping is the same on every run; iterating a
# bare set() follows hash order, which changes between processes.
char_set = sorted(set(sentence))
char_dic = {w: i for i, w in enumerate(char_set)}

hidden_size = 50  # units per LSTM cell
sequence_length = 10  # any arbitrary number
data_dim = len(char_set)
num_classes = len(char_set)
learning_rate = 0.1

# Build the dataset: each target window is the input window shifted by one.
dataX = []
dataY = []
for i in range(0, len(sentence) - sequence_length):
    x_str = sentence[i:i + sequence_length]
    y_str = sentence[i + 1: i + sequence_length + 1]
    print(i, x_str, '->', y_str)

    dataX.append([char_dic[c] for c in x_str])  # x str to index
    dataY.append([char_dic[c] for c in y_str])  # y str to index

batch_size = len(dataX)
print(batch_size)

X = tf.placeholder(tf.int32, [None, sequence_length])
Y = tf.placeholder(tf.int32, [None, sequence_length])

# One-hot encoding: (batch, sequence_length, num_classes)
X_one_hot = tf.one_hot(X, num_classes)


def cell():
    """Make one LSTM cell with hidden_size units.

    Alternatives tried by the original author:
    rnn.BasicLSTMCell(hidden_size, state_is_tuple=True) or rnn.GRUCell(hidden_size).
    """
    return rnn.LSTMCell(hidden_size, state_is_tuple=True)


multi_cells = rnn.MultiRNNCell([cell() for _ in range(2)], state_is_tuple=True)

# outputs: unfolded hidden states, (batch, sequence_length, hidden_size)
outputs, _states = dynamic_rnn(multi_cells, X_one_hot, dtype=tf.float32)

# Fully connected layer applied to every time step.
X_for_fc = tf.reshape(outputs, [-1, hidden_size])
outputs = tf.contrib.layers.fully_connected(X_for_fc, num_classes, activation_fn=None)
# outputs = rnnn.layers.fully_connected(X_for_fc, num_classes, activation_fn=None)

# Reshape for sequence_loss. -1 / ones_like keep the graph independent of the
# fed batch size, matching the None batch dimension of the placeholders.
outputs = tf.reshape(outputs, [-1, sequence_length, num_classes])
weights = tf.ones_like(Y, dtype=tf.float32)  # all weights are 1 (equal weights)

# Loss
sequence_loss = tf.contrib.seq2seq.sequence_loss(logits=outputs, targets=Y, weights=weights)
# sequence_loss = rnnn.seq2seq.sequence_loss(logits=outputs, targets=Y, weights=weights)
mean_loss = tf.reduce_mean(sequence_loss)
train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(mean_loss)

# The context manager releases the session's resources when training ends.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(500):
        _, lossval = sess.run([train_op, mean_loss], feed_dict={X: dataX, Y: dataY})

        # Predictions after this step's update: (batch, sequence_length, num_classes)
        results = sess.run(outputs, feed_dict={X: dataX})
        index = np.argmax(results, axis=2)
        # Windows overlap, so take the whole first window and then only the
        # last character of every following window to make a sentence.
        ret = ''.join([char_set[t] for t in index[0]] +
                      [char_set[t] for t in index[1:, -1]])

        print(i, lossval, ret)
        if ret == sentence[1:]:
            break

# Disabled alternative: print the last char of each result to check it works
# results = sess.run(outputs, feed_dict={X: dataX}) #(170,10,25)
# for j, result in enumerate(results):
# index = np.argmax(result, axis=1)
# if j is 0: #第一个结果10个字符组成一个句子 print all for the first result to make a sentence
# # print(''.join([char_set[t] for t in index]), end='')
# ret =''.join([char_set[t] for t in index])
# else: #其它取最后一个字符
# # print(char_set[index[-1]], end='')
# ret = ret + char_set[index[-1]]
本文标签: LSTM
版权声明:本文标题:LSTM 内容由网友自发贡献,该文观点仅代表作者本人, 转载请联系作者并注明出处:http://www.betaflare.com/web/1693795607a243291.html, 本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容,一经查实,本站将立刻删除。
发表评论