import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import tensorflow as tf
from tensorflow import keras
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Scale the 0-255 pixel values to [0, 1]. Casting to float32 first matches the
# float32 placeholders the graph is fed through, and the arrays take half the
# memory of the float64 result a plain `/ 255.0` would produce.
train_images = train_images.astype(np.float32) / 255.0
test_images = test_images.astype(np.float32) / 255.0

print(train_images.shape)
# Placeholders for a batch of 28x28 images: `x` is the autoencoder input and
# `y_` is the reconstruction target.
x = tf.placeholder(tf.float32, [None, 28, 28])
y_ = tf.placeholder(tf.float32, [None, 28, 28])

# Flatten each image to a 784-element row for the dense layers.
xs = tf.reshape(x, shape=[-1, 28 * 28])
y_shape = tf.reshape(y_, shape=[-1, 28 * 28])

# Encoder: 784 -> 128 -> 16, sigmoid activations.
# NOTE(review): stddev=1 initial weights are large for sigmoid layers and may
# saturate them; a smaller stddev (or Glorot init) is worth evaluating.
with tf.variable_scope("encoder"):
    w1 = tf.get_variable("w1", initializer=tf.random_normal([784, 128], stddev=1))
    w2 = tf.get_variable("w2", initializer=tf.random_normal([128, 16], stddev=1))
    b1 = tf.get_variable("b1", initializer=tf.zeros([1, 128]) + 0.01)
    b2 = tf.get_variable("b2", initializer=tf.zeros([1, 16]) + 0.01)
    l1 = tf.nn.sigmoid(tf.matmul(xs, w1) + b1)
    l2 = tf.nn.sigmoid(tf.matmul(l1, w2) + b2)

# Decoder: 16 -> 128 -> 784, sigmoid output so reconstructions lie in (0, 1).
with tf.variable_scope("decoder"):
    w3 = tf.get_variable("w3", initializer=tf.random_normal([16, 128], stddev=1))
    w4 = tf.get_variable("w4", initializer=tf.random_normal([128, 784], stddev=1))
    b3 = tf.get_variable("b3", initializer=tf.zeros([1, 128]) + 0.01)
    b4 = tf.get_variable("b4", initializer=tf.zeros([1, 784]) + 0.01)
    l3 = tf.nn.sigmoid(tf.matmul(l2, w3) + b3)
    y = tf.nn.sigmoid(tf.matmul(l3, w4) + b4)

# Mean squared reconstruction error, minimized with Adam (learning rate 0.05).
loss = tf.reduce_mean(tf.square(y - y_shape))
opt = tf.train.AdamOptimizer(0.05).minimize(loss)
BATCH_SIZE = 100
NUM_STEPS = 3001
LOG_EVERY = 100

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    num_train = len(train_images)
    for e in range(NUM_STEPS):
        # Walk through the training set in consecutive mini-batches, wrapping
        # around at the end instead of assuming a fixed dataset size.
        ts = e * BATCH_SIZE % num_train
        batch = train_images[ts:ts + BATCH_SIZE]
        # Autoencoder: the reconstruction target is the input batch itself.
        loss_val, _ = sess.run([loss, opt], {x: batch, y_: batch})
        if e % LOG_EVERY == 0:
            print(e, loss_val)
    # Reconstruct the first test image. Only `x` feeds into `y`, so `y_` is
    # not needed here.
    y_out = sess.run(y, {x: test_images[:1]})

# sess.run already returned a NumPy array, so reshape it directly rather than
# adding a tf.reshape op to the graph and evaluating it in a session.
y_show = y_out.reshape(-1, 28, 28)
print(type(y_show))
print(y_show.shape)
print(train_images[0].shape)

# Show the reconstruction first, then the original test image for comparison.
for image in (y_show[0], test_images[0]):
    plt.figure()
    plt.imshow(image)
    plt.colorbar()
    plt.grid(False)
    plt.show()