
GAN (Generative Adversarial Network): Text-to-Image Principles and Basic Implementation, Part 11

Date: 2022-04-03 08:05:52


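In essence, this is an RNN-based text embedding model combined with a conditional GAN.

First, an RNN converts the text into a vector; the resulting text vector is then fed into both the G and D networks.

Unlike an ordinary GAN, there is one additional failure case here: an image that looks quite realistic but does not match its description. This case must also be penalized.

Without this penalty, the only information D receives comes from G's generated images, and it loses the ability to judge whether an image matches its description.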
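The three cases D has to tell apart can be sketched in a few lines. This is a minimal illustration in plain Python, not the training code itself: `bce_with_logits` and `discriminator_loss` are names introduced here, and the logits are made-up numbers.

```python
import math

def bce_with_logits(logit, label):
    """Binary cross-entropy on a raw logit (numerically stable form)."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def discriminator_loss(logit_real, logit_fake, logit_wrong):
    """Sum the three cases: only (real image, matching text) is labelled 1."""
    loss_real = bce_with_logits(logit_real, 1.0)    # real image + matching text
    loss_fake = bce_with_logits(logit_fake, 0.0)    # generated image + text
    loss_wrong = bce_with_logits(logit_wrong, 0.0)  # real image + mismatched text
    return loss_real + loss_fake + loss_wrong

# A discriminator fooled by a realistic but mismatched image (high logit on
# the wrong pair) pays a larger loss than one that rejects that pair.
fooled = discriminator_loss(3.0, -3.0, 3.0)
careful = discriminator_loss(3.0, -3.0, -3.0)
print(fooled > careful)  # True
```

The third term is what forces D to look at the text as well as the image.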

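Why is a noise input still needed?

Because in most cases a sentence describes only the content (what the flower looks like) and says nothing about style (style mainly covers background and pose).

In that situation we want the noise to supply the style, so that the generated images are more realistic and more varied.

In addition, by means of feature visualization, z can be given the role of adding a specific style, which addresses the fact that the text description itself says nothing about style. A randomized z adds different styles and thereby increases the realism and diversity of the generated samples.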
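The idea of one description combined with many noise vectors can be sketched with NumPy. The sizes and the random stand-in for the sentence embedding are assumptions for illustration; the generator in this article first passes the noise through a dense layer before concatenating, which this sketch skips.

```python
import numpy as np

rng = np.random.default_rng(0)
noise_dim, text_dim, n_samples = 100, 128, 4   # sizes chosen for illustration

# One sentence embedding (a stand-in for the RNN output), repeated per sample.
text_embedding = rng.normal(size=(1, text_dim))
condition = np.repeat(text_embedding, n_samples, axis=0)

# A different noise vector per sample supplies what the sentence leaves out.
z = rng.uniform(-1.0, 1.0, size=(n_samples, noise_dim))

generator_input = np.concatenate([z, condition], axis=1)
print(generator_input.shape)  # (4, 228)
```

Every row shares the same condition and differs only in z, so any variation between the resulting images has to come from the noise.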

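Three key parts:

1. Text processing: how do we extract information from the text to serve as the generator's condition?

First, vectorize the text.

Then, extract the text information.

2. Image processing: negative examples must be added to training.

That is, when the input text and image do not correspond, a penalty must be given.

Build the input queue from:

a correct image + its correct text

a wrong image + text that does not describe it

3. Create the input queue.

Make sure each text stays paired with its image.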
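The pairing requirements above can be sketched as follows. This is an illustration under assumptions: the file names are hypothetical, `text_ids` stands in for the per-image text embeddings, and `mismatched_indices` is a helper introduced here. A plain random permutation can occasionally pair an image with itself; the helper rules that out.

```python
import numpy as np

def mismatched_indices(n, rng):
    """Return a permutation in which position i never maps to i (needs n >= 2)."""
    order = rng.permutation(n)
    shifted = np.roll(order, 1)
    wrong = np.empty(n, dtype=int)
    wrong[order] = shifted     # each item is sent to its neighbour in a random cycle
    return wrong

rng = np.random.default_rng(0)
image_files = np.array(["img_0.jpg", "img_1.jpg", "img_2.jpg", "img_3.jpg"])
text_ids = np.arange(4)        # stand-in for one text embedding per image

# Shuffle images and texts with the SAME permutation so the pairs stay aligned.
index = rng.permutation(len(image_files))
image_files, text_ids = image_files[index], text_ids[index]

# A second list supplies, for every text, an image that belongs to another text.
wrong_files = image_files[mismatched_indices(len(image_files), rng)]
print(list(zip(text_ids, image_files, wrong_files)))
```

Note that a different image can still show the same flower class, so a "wrong" image is only guaranteed to be a different file, not a different kind of flower.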

import tensorflow as tf
from gensim.models import word2vec
from gensim.models import Word2Vec
import pandas as pd
import glob
import numpy as np
import os
import matplotlib.pyplot as plt
%matplotlib inline

from IPython import display

os.listdir('../input/gan-text-to-image-102flowers-rieyuguanghua')

n_input = 100
n_hidden = 128
image_height = 64
image_width = 64
image_depth = 3
noise_dim = 100
maxlength = 250
NUM_EPOCHS = 100
batch_size = 64

if not os.path.exists('102flowers'):
    !mkdir 102flowers
    !tar zxvf ../input/102flowersdataset/102flowers.tgz -C ./102flowers/
    display.clear_output()

all_text_filename = glob.glob('../input/cvpr/cvpr_flowers/text_c10/class_*/image_*.txt')
all_text_filename.sort(key=lambda x:x.split('/')[-1])

all_image_filename = glob.glob('./102flowers/jpg/*.jpg')
all_image_filename.sort()

all_text_filename = np.array(all_text_filename)
all_image_filename = np.array(all_image_filename)
wrong_image_filename = all_image_filename[np.random.permutation(len(all_image_filename))]

dataset_image = tf.data.Dataset.from_tensor_slices((all_image_filename, wrong_image_filename))

if not os.path.exists('../input/gan-text-to-image-102flowers-rieyuguanghua/all_text.txt'):
    with open('all_text.txt', 'at') as f:
        for a_text in all_text_filename:
            f.write(open(a_text).read().replace('\n', '') + '\n')
if not os.path.exists('../input/gan-text-to-image-102flowers-rieyuguanghua/word_model'):
    sentences = word2vec.Text8Corpus('all_text.txt')
    # size= is the gensim < 4.0 argument name; gensim 4.x calls it vector_size
    model = word2vec.Word2Vec(sentences, size=100)
    model.save('word_model')
else:
    model = Word2Vec.load('../input/gan-text-to-image-102flowers-rieyuguanghua/word_model')
    !cp ../input/gan-text-to-image-102flowers-rieyuguanghua/all_text.txt ./
    !cp ../input/gan-text-to-image-102flowers-rieyuguanghua/word_model ./
word_vectors = model.wv

maxlength = max([len(open(a_text).read().split()) for a_text in all_text_filename])

n_steps = maxlength

def pad(x, maxlength=200):
    x1 = np.zeros((maxlength,100))
    x1[:len(x)] = x
    return x1

def text_vec(text_filenames):
    vec = []
    for a_text in text_filenames:
        all_word = open(a_text).read().split()
        all_vec = [word_vectors[w] for w in all_word if w in word_vectors]
        vec.append(all_vec)
    data = pd.Series(vec)
    data = data.apply(pad, maxlength=maxlength)
    data_ = np.concatenate(data).reshape(len(data),maxlength,100)
    return data_
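The lookup-and-pad step can be tried in isolation with a toy vocabulary. This is a sketch under assumptions: the 3-dimensional vectors in `toy_vectors` are invented, and `embed_sentence` is a helper introduced here that also truncates long sentences and guards against a sentence with no known words, two cases the functions above do not handle.

```python
import numpy as np

# Hypothetical 3-dimensional word vectors; a real model would be trained.
toy_vectors = {
    "yellow": np.array([1.0, 0.0, 0.0]),
    "petals": np.array([0.0, 1.0, 0.0]),
    "red":    np.array([0.0, 0.0, 1.0]),
}

def embed_sentence(sentence, vectors, max_len, dim):
    """Look up each known word and zero-pad the sequence to max_len rows."""
    rows = [vectors[w] for w in sentence.split() if w in vectors]
    rows = rows[:max_len]                      # truncate overly long sentences
    padded = np.zeros((max_len, dim))
    if rows:                                   # guard: no known words at all
        padded[:len(rows)] = np.stack(rows)
    return padded

emb = embed_sentence("yellow petals with a red center", toy_vectors, max_len=5, dim=3)
print(emb.shape)  # (5, 3)
```

Words missing from the vocabulary are skipped, so the three known words fill the first three rows and the remaining rows stay zero.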

data_text_emb = text_vec(all_text_filename)

def read_image(image_filename):
    image = tf.read_file(image_filename)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize_image_with_crop_or_pad(image, 512, 512)
    image = tf.image.resize_images(image, (256, 256))
    #image = tf.image.convert_image_dtype(image, tf.float32)
    image = (image - tf.reduce_min(image))/(tf.reduce_max(image) - tf.reduce_min(image))
    return image
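The last step of `read_image` is a per-image min-max rescale to [0, 1], which matches the [0, 1] range the generator outputs. A small NumPy sketch of the same arithmetic follows; `min_max_normalize` is a name introduced here, and the guard for a constant image is an addition that the TensorFlow version above does not have.

```python
import numpy as np

def min_max_normalize(image):
    """Rescale pixel values linearly so the smallest is 0 and the largest is 1."""
    image = image.astype(np.float32)
    lo, hi = image.min(), image.max()
    if hi == lo:                 # constant image: avoid dividing by zero
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)

pixels = np.array([[50, 100], [150, 250]], dtype=np.uint8)
normalized = min_max_normalize(pixels)
print(normalized)
```

Because the scale depends on each image's own minimum and maximum, this is not the same as dividing by 255: a dark image is stretched to the full range.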

def _pre_func(real_image_name, wrong_image_name):
    wrong_image = read_image(wrong_image_name)
    real_image = read_image(real_image_name)
    return real_image, wrong_image

dataset_image = dataset_image.map(_pre_func)

dataset_image = dataset_image.batch(batch_size)

iterator = tf.data.Iterator.from_structure(dataset_image.output_types, dataset_image.output_shapes)
real_image_batch, wrong_image_batch = iterator.get_next()

input_text = tf.placeholder(tf.float32, [None, n_steps, n_input])
inputs_noise = tf.placeholder(tf.float32, [None, noise_dim], name='inputs_noise')

def length(shuru):
    return tf.reduce_sum(tf.sign(tf.reduce_max(tf.abs(shuru),reduction_indices=2)),reduction_indices=1)
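`length` recovers the true sequence length from the zero padding: a time step counts as used if any of its features is non-zero. The same computation in NumPy, as a sketch with made-up data (`sequence_length` is a name introduced here):

```python
import numpy as np

def sequence_length(batch):
    """Count non-zero time steps in a zero-padded [batch, steps, features] array."""
    used = np.sign(np.max(np.abs(batch), axis=2))   # 1 where a step has any non-zero value
    return used.sum(axis=1).astype(int)

batch = np.zeros((2, 4, 3))
batch[0, :3] = 1.0     # first sequence has 3 real steps
batch[1, :1] = -2.0    # second sequence has 1 real step
lengths = sequence_length(batch)
print(lengths.tolist())  # [3, 1]
```

This relies on real word vectors never being exactly all-zero, which holds in practice for trained embeddings.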

def text_rnn(input_text, batch_size=64, reuse= tf.AUTO_REUSE):
    cell = tf.contrib.rnn.GRUCell(n_hidden,
                                  kernel_initializer = tf.truncated_normal_initializer(stddev=0.0001),
                                  bias_initializer = tf.truncated_normal_initializer(stddev=0.0001),
                                  reuse=reuse)
    output, _ = tf.nn.dynamic_rnn(cell, input_text, dtype=tf.float32,
                                  sequence_length = length(input_text))
    index = tf.range(0,batch_size)*n_steps + (tf.cast(length(input_text),tf.int32) - 1)
    flat = tf.reshape(output,[-1,int(output.get_shape()[2])])
    last = tf.gather(flat,index)
    return last
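The index arithmetic in `text_rnn` picks, for every sequence, the RNN output at its last real time step. A NumPy sketch with fake outputs makes the flat-index trick easier to follow; all values here are invented for checking.

```python
import numpy as np

batch_size, n_steps, n_hidden = 2, 4, 3
# Fake RNN outputs: the value at [b, t, :] encodes b * 10 + t for easy checking.
output = np.zeros((batch_size, n_steps, n_hidden))
for b in range(batch_size):
    for t in range(n_steps):
        output[b, t] = b * 10 + t

lengths = np.array([3, 1])                       # real steps per sequence
index = np.arange(batch_size) * n_steps + (lengths - 1)
flat = output.reshape(-1, n_hidden)              # [batch * steps, hidden]
last = flat[index]                               # last valid step of each sequence
print(last[:, 0].tolist())  # [2.0, 10.0]
```

The trick only works when `batch_size` equals the real number of rows in the batch: with a larger `batch_size`, some indices point past the end of `flat`.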

def get_generator(noise_img, image_depth, condition_label, is_train=True, alpha=0.2):
    with tf.variable_scope("generator", reuse= tf.AUTO_REUSE):
        # 100 x 1 to 4 x 4 x 512
        # Fully connected layer
        noise_img = tf.to_float(noise_img)
        noise_img = tf.layers.dense(noise_img, n_hidden)
        noise_img = tf.maximum(alpha * noise_img, noise_img)
        noise_img_ = tf.concat([noise_img, condition_label], 1)
        layer1 = tf.layers.dense(noise_img_, 4*4*512)
        layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
        layer1 = tf.layers.batch_normalization(layer1, training=is_train)
        layer1 = tf.nn.relu(layer1)
        # batch normalization
        #layer1 = tf.layers.batch_normalization(layer1, training=is_train)
        # ReLU
        #layer1 = tf.nn.relu(layer1)
        # dropout
        layer1 = tf.nn.dropout(layer1, keep_prob=0.8)
        # 4 x 4 x 512 to 8 x 8 x 256
        layer2 = tf.layers.conv2d_transpose(layer1, 256, 3, strides=2, padding='same')
        layer2 = tf.layers.batch_normalization(layer2, training=is_train)
        layer2 = tf.nn.relu(layer2)
        layer2 = tf.nn.dropout(layer2, keep_prob=0.8)
        # 8 x 8 x 256 to 16 x 16 x 128
        layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=is_train)
        layer3 = tf.nn.relu(layer3)
        layer3 = tf.nn.dropout(layer3, keep_prob=0.8)
        # 16 x 16 x 128 to 32 x 32 x 64
        layer4 = tf.layers.conv2d_transpose(layer3, 64, 3, strides=2, padding='same')
        layer4 = tf.layers.batch_normalization(layer4, training=is_train)
        layer4 = tf.nn.relu(layer4)
        # 64 x 64 x 32
        layer5 = tf.layers.conv2d_transpose(layer4, 32, 3, strides=2, padding='same')
        layer5 = tf.layers.batch_normalization(layer5, training=is_train)
        layer5 = tf.nn.relu(layer5)
        # 128 x 128 x 16
        layer6 = tf.layers.conv2d_transpose(layer5, 16, 3, strides=2, padding='same')
        layer6 = tf.layers.batch_normalization(layer6, training=is_train)
        layer6 = tf.nn.relu(layer6)
        # 256 x 256 x 3
        logits = tf.layers.conv2d_transpose(layer6, image_depth, 3, strides=2, padding='same')
        outputs = tf.tanh(logits)
        outputs = (outputs/2) + 0.5
        outputs = tf.clip_by_value(outputs, 0.0, 1.0)
        return outputs
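The spatial sizes in the generator's comments follow from a simple rule: with `padding='same'`, a stride-2 transposed convolution doubles height and width. A short sketch of that arithmetic and of the final tanh rescale, using helper names introduced here for illustration:

```python
import math

def transposed_conv_size(size, stride=2):
    """Output height/width of a 'same'-padded transposed convolution."""
    return size * stride

size = 4
sizes = [size]
for _ in range(6):                 # six stride-2 layers follow the 4 x 4 reshape
    size = transposed_conv_size(size)
    sizes.append(size)
print(sizes)  # [4, 8, 16, 32, 64, 128, 256]

def to_unit_range(logit):
    """Map a raw output through tanh, then from [-1, 1] to [0, 1]."""
    return math.tanh(logit) / 2 + 0.5

print(to_unit_range(0.0))  # 0.5
```

Rescaling to [0, 1] keeps the generated images in the same value range as the min-max normalized real images that the discriminator sees.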

def get_discriminator(inputs_img, condition_label, reuse= tf.AUTO_REUSE, alpha=0.2):
    with tf.variable_scope("discriminator", reuse=reuse):
        # 256 x 256 x 3 to 128 x 128 x 16
        # No BN on the first layer
        layer1 = tf.layers.conv2d(inputs_img, 16, 3, strides=2, padding='same')
        layer1 = tf.maximum(alpha * layer1, layer1)
        layer1 = tf.nn.dropout(layer1, keep_prob=0.8)
        # 128 x 128 x 16 to 64 x 64 x 32
        layer2 = tf.layers.conv2d(layer1, 32, 3, strides=2, padding='same')
        layer2 = tf.layers.batch_normalization(layer2, training=True)
        layer2 = tf.maximum(alpha * layer2, layer2)
        layer2 = tf.nn.dropout(layer2, keep_prob=0.8)
        # 32 x 32 x 64
        layer3 = tf.layers.conv2d(layer2, 64, 3, strides=2, padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=True)
        layer3 = tf.maximum(alpha * layer3, layer3)
        layer3 = tf.nn.dropout(layer3, keep_prob=0.8)
        # 16 x 16 x 128
        layer4 = tf.layers.conv2d(layer3, 128, 3, strides=2, padding='same')
        layer4 = tf.layers.batch_normalization(layer4, training=True)
        layer4 = tf.maximum(alpha * layer4, layer4)
        # 8 x 8 x 256
        layer5 = tf.layers.conv2d(layer4, 256, 3, strides=2, padding='same')
        layer5 = tf.layers.batch_normalization(layer5, training=True)
        layer5 = tf.maximum(alpha * layer5, layer5)
        # 4 x 4 x 512
        layer6 = tf.layers.conv2d(layer5, 512, 3, strides=2, padding='same')
        layer6 = tf.layers.batch_normalization(layer6, training=True)
        layer6 = tf.maximum(alpha * layer6, layer6)
        # Project the text vector to 512 channels and tile it over the 4 x 4 grid
        text_emb = tf.layers.dense(condition_label, 512)
        text_emb = tf.maximum(alpha * text_emb, text_emb)
        text_emb = tf.expand_dims(text_emb, 1)
        text_emb = tf.expand_dims(text_emb, 2)
        text_emb = tf.tile(text_emb, [1,4,4,1])
        layer_concat = tf.concat([layer6, text_emb], 3)
        layer7 = tf.layers.conv2d(layer_concat, 512, 1, strides=1, padding='same')
        layer7 = tf.layers.batch_normalization(layer7, training=True)
        layer7 = tf.maximum(alpha * layer7, layer7)
        flatten = tf.reshape(layer7, (-1, 4*4*512))
        logits = tf.layers.dense(flatten, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs
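The discriminator injects the text late: the projected text vector is copied to every cell of the 4 x 4 feature map and concatenated along the channel axis. The shape bookkeeping can be checked with NumPy; the arrays below are placeholders, not real features.

```python
import numpy as np

batch, text_dim = 2, 512
features = np.zeros((batch, 4, 4, 512))            # stand-in for the last conv feature map
text_emb = np.arange(batch * text_dim, dtype=float).reshape(batch, text_dim)

tiled = text_emb[:, None, None, :]                 # [batch, 1, 1, 512]
tiled = np.tile(tiled, (1, 4, 4, 1))               # [batch, 4, 4, 512]
combined = np.concatenate([features, tiled], axis=3)
print(combined.shape)  # (2, 4, 4, 1024)
```

Every spatial position then carries the same text information, so the following 1 x 1 convolution can compare image features with the description at each location.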

def get_loss(inputs_image, wrong_image, inputs_noise, condition_label, image_depth, smooth=0.1):
    g_outputs = get_generator(inputs_noise, image_depth, condition_label, is_train=True)
    d_logits_real, d_outputs_real = get_discriminator(inputs_image, condition_label)
    d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, condition_label, reuse=tf.AUTO_REUSE)
    d_logits_wrong, d_outputs_wrong = get_discriminator(wrong_image, condition_label, reuse=tf.AUTO_REUSE)
    print(inputs_image.get_shape(), condition_label.get_shape())
    # Compute the losses
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.ones_like(d_outputs_fake)*(1-smooth)))
    #g_loss_l1 = tf.reduce_mean(tf.abs(g_outputs - inputs_image))
    #g_loss = g_loss_ + g_loss_l1
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_real, labels=tf.ones_like(d_outputs_real)*(1-smooth)))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.ones_like(d_outputs_fake)*smooth))
    # Real image paired with text that does not describe it
    d_loss_wrong = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_wrong, labels=tf.ones_like(d_outputs_wrong)*smooth))
    d_loss = d_loss_real + d_loss_fake + d_loss_wrong
    return g_loss, d_loss
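`get_loss` uses smoothed targets (0.9 for "real", 0.1 for "fake" and "wrong") instead of hard 1 and 0. The effect can be seen on a single logit; this plain-Python sketch uses the standard numerically stable form of sigmoid cross-entropy, and the function name is introduced here.

```python
import math

def bce_with_logits(logit, label):
    """Sigmoid cross-entropy on a raw logit (numerically stable form)."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

smooth = 0.1
real_target = 1.0 - smooth   # 0.9 instead of 1
fake_target = smooth         # 0.1 instead of 0

# With a smoothed target t the loss is minimised at sigmoid(logit) == t,
# so ever-larger logits are no longer rewarded.
best_logit = math.log(real_target / (1.0 - real_target))
print(round(best_logit, 3))  # 2.197
print(bce_with_logits(10.0, real_target) > bce_with_logits(best_logit, real_target))  # True
```

An overconfident discriminator is therefore penalized, which is the usual motivation for label smoothing in GAN training.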

def get_optimizer(g_loss, d_loss, beta1=0.4, learning_rate=0.001):
    train_vars = tf.trainable_variables()
    g_vars = [var for var in train_vars if var.name.startswith("generator")]
    d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
    # Optimizer
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        g_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)
        d_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
    return g_opt, d_opt

def plot_images(samples):
    #samples = (samples+1)/2
    fig, axes = plt.subplots(nrows=1, ncols=10, sharex=True, sharey=True, figsize=(20,2))
    for img, ax in zip(samples, axes):
        ax.imshow(img.reshape((256, 256, 3)))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0)

def show_generator_output(sess, n_images, inputs_noise, output_dim, test_text_vec):
    # condition_text = tf.to_float(condition_text)
    # last, b_size = sess.run(text_vec(condition_text, batch_size=n_images, reuse=tf.AUTO_REUSE))
    samples = sess.run(get_generator(inputs_noise, output_dim, test_text_vec, is_train=False))
    return samples

# Define parameters
n_samples = 10
learning_rate = 0.0002
beta1 = 0.5

# Store losses
losses = []
step = 0
last = text_rnn(input_text)
g_loss, d_loss = get_loss(real_image_batch, wrong_image_batch, inputs_noise, last, image_depth, smooth=0.1)
g_train_opt, d_train_opt = get_optimizer(g_loss, d_loss, beta1, learning_rate)
saver = tf.train.Saver()
with tf.Session() as sess:
    #sess.run(tf.global_variables_initializer())
    model_file=tf.train.latest_checkpoint('../input/gan-text-to-image-102flowers-rieyuguanghua')
    saver.restore(sess, model_file)
    for epoch in range(791, 831):
        # Shuffle texts and images with the same permutation so pairs stay aligned
        index = np.random.permutation(len(all_image_filename))
        data_text_emb = data_text_emb[index]
        all_image_filename = all_image_filename[index]
        wrong_image_filename = all_image_filename[np.random.permutation(len(all_image_filename))]
        dataset_image = tf.data.Dataset.from_tensor_slices((all_image_filename, wrong_image_filename))
        dataset_image = dataset_image.map(_pre_func)
        dataset_image = dataset_image.repeat(1)
        dataset_image = dataset_image.batch(batch_size)
        dataset_image_op = iterator.make_initializer(dataset_image)
        sess.run(dataset_image_op)
        i = 0
        while True:
            try:
                batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_dim))
                text_emb_batch = data_text_emb[i: i + batch_size]
                i = i + batch_size
                _ = sess.run([g_train_opt, d_train_opt],
                             feed_dict={input_text: text_emb_batch,
                                        inputs_noise: batch_noise})
                #if step % 50 == 0:
                #    saver.save(sess, "./model10.ckpt")
                #    train_loss_d = d_loss.eval({input_text: text_emb_batch,
                #                                inputs_noise: batch_noise})
                #    train_loss_g = g_loss.eval({input_text: text_emb_batch,
                #                                inputs_noise: batch_noise})
                #    losses.append((train_loss_d, train_loss_g))
                #    print("Step {}....".format(step+1),
                #          "Discriminator Loss: {:.4f}....".format(train_loss_d),
                #          "Generator Loss: {:.4f}....".format(train_loss_g))
                # Show images
                step += 1
            #except tf.errors.OutOfRangeError as e:
            except:
                # Any error ends the epoch (normally the iterator running out of data)
                #saver.save(sess, "./model10.ckpt")
                print('epoch', epoch, 'step', step)
                #print(e)
                #try:
                #    sess.run(real_image_batch)
                #except Exception as e:
                #    print(e)
                break
        if epoch%10 == 0:
            #saver.save(sess, "./model10.ckpt")
            n_samples = 10
            condition_text = data_text_emb[:n_samples]
            test_noise = np.random.uniform(-1, 1, size=[n_samples, noise_dim])
            last_test = text_rnn(input_text, batch_size=n_samples, reuse=tf.AUTO_REUSE)
            test_text_vec = sess.run(last_test, feed_dict={input_text: condition_text})
            samples = show_generator_output(sess, n_samples, test_noise, 3, test_text_vec)
            plot_images(samples)
    saver.save(sess, "./model11.ckpt")

tf.reset_default_graph()
vec = []
test_word = """the petals on this flower are yellow with a red center,the petals on this flower are yellow with a red center"""
# Split into words first; iterating over the string directly would yield single characters
all_vec = [word_vectors[w] for w in test_word.split() if w in word_vectors]
vec.append(all_vec)
data = pd.Series(vec)
data = data.apply(pad, maxlength=maxlength)
data_ = np.concatenate(data).reshape(len(data),maxlength,100)
losses = []
step = 0
n_samples = 10
# Repeat the single sentence so the text batch matches the n_samples noise vectors
test_text_vec = np.repeat(data_, n_samples, axis=0).astype(np.float32)
test_noise = np.random.uniform(-1, 1, size=[n_samples, noise_dim])
last_test = text_rnn(test_text_vec, batch_size=n_samples, reuse=tf.AUTO_REUSE)
new_image = get_generator(test_noise, image_depth, last_test)
saver = tf.train.Saver()
with tf.Session() as sess:
    model_file=tf.train.latest_checkpoint('../input/gan-text-to-image-102flowers-rieyuguanghua')
    saver.restore(sess, model_file)
    samples = show_generator_output(sess, n_samples, test_noise, 3, last_test)
    plot_images(samples)
