import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, LSTM, RepeatVector, TimeDistributed
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
from keras.utils import to_categorical
import numpy as np
import matplotlib; matplotlib.use('Agg')
import matplotlib.pyplot as plt

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train / 255.0

# TODO: fill in this area
def build_model():
    return None

model = build_model()

for epoch in range(10):
    model.fit(np.expand_dims(np.expand_dims(y_train, axis=-1), axis=-1), x_train, epochs=1)

    for digit in range(10):
        img = model.predict(np.expand_dims(np.array([[digit]]), axis=-1))[0]
        plt.imsave(str(epoch) + '_' + str(digit) + '.png', img)
        plt.clf()
