import os
import time

import numpy as np
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

import lib.config.config as cfg
from lib.datasets import roidb as rdl_roidb
from lib.datasets.factory import get_imdb
from lib.datasets.imdb import imdb as imdb2
from lib.layer_utils.roi_data_layer import RoIDataLayer
from lib.nets.vgg16 import vgg16
from lib.utils.timer import Timer

try:
    import cPickle as pickle
except ImportError:
    import pickle
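
# NOTE: this script uses the TensorFlow 1.x API (tf.Session, tf.ConfigProto,
# tf.train.MomentumOptimizer, pywrap_tensorflow) and will not run unmodified
# under TensorFlow 2.x.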


def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training."""
    # Data augmentation: horizontally-flipped copies of the training images
    # are always appended here.
    if True:
        print('Appending horizontally-flipped training examples...')
        imdb.append_flipped_images()
        print('done')

    print('Preparing training data...')
    rdl_roidb.prepare_roidb(imdb)
    print('done')

    return imdb.roidb



def combined_roidb(imdb_names):
    """Combine the roidbs of one or more datasets (names joined by '+')."""

    def get_roidb(imdb_name):
        imdb = get_imdb(imdb_name)
        print('Loaded dataset `{:s}` for training'.format(imdb.name))
        imdb.set_proposal_method("gt")
        print('Set proposal method: {:s}'.format("gt"))
        roidb = get_training_roidb(imdb)
        return roidb

    roidbs = [get_roidb(s) for s in imdb_names.split('+')]
    roidb = roidbs[0]
    if len(roidbs) > 1:
        for r in roidbs[1:]:
            roidb.extend(r)
        tmp = get_imdb(imdb_names.split('+')[1])
        imdb = imdb2(imdb_names, tmp.classes)
    else:
        imdb = get_imdb(imdb_names)
    return imdb, roidb
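
# Usage note: several datasets can be combined by joining their names with '+',
# e.g. combined_roidb("voc_2007_trainval+voc_2012_trainval"), provided every
# name is registered in lib.datasets.factory.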



class Train:

    def __init__(self):

        # Create the network
        if cfg.FLAGS.network == 'vgg16':
            self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)
        else:
            raise NotImplementedError

        # Load the training set (hard-coded to PASCAL VOC 2007 trainval)
        self.imdb, self.roidb = combined_roidb("voc_2007_trainval")

        # The data layer shuffles the roidb and serves one minibatch per forward() call
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')



    def train(self):

        # Create session
        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        tfconfig.gpu_options.allow_growth = True
        sess = tf.Session(config=tfconfig)
        with sess.graph.as_default():

            tf.set_random_seed(cfg.FLAGS.rng_seed)
            layers = self.net.create_architecture(sess, "TRAIN", self.imdb.num_classes, tag='default')
            loss = layers['total_loss']
            lr = tf.Variable(cfg.FLAGS.learning_rate, trainable=False)
            momentum = cfg.FLAGS.momentum
            optimizer = tf.train.MomentumOptimizer(lr, momentum)

            gvs = optimizer.compute_gradients(loss)


            # Double bias: if configured, double the gradient of the bias terms
            # (emulating the common Caffe practice of a 2x learning rate for biases)
            if cfg.FLAGS.double_bias:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.FLAGS.double_bias and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = optimizer.apply_gradients(final_gvs)
            else:
                train_op = optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            # writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            # valwriter = tf.summary.FileWriter(self.tbvaldir)


        # Load weights
        # Fresh training starts directly from the ImageNet-pretrained weights
        print('Loading initial model weights from {:s}'.format(cfg.FLAGS.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        var_keep_dic = self.get_variables_in_checkpoint_file(cfg.FLAGS.pretrained_model)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, cfg.FLAGS.pretrained_model)
        print('Loaded.')
        # Fix the variables that were excluded from the restore above: the RGB
        # conv1 weights are converted to BGR, and for VGG16 the convolutional
        # fc6 and fc7 weights are reshaped into fully connected weights
        self.net.fix_variables(sess, cfg.FLAGS.pretrained_model)
        print('Fixed.')
        sess.run(tf.assign(lr, cfg.FLAGS.learning_rate))
        last_snapshot_iter = 0


        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        while iter < cfg.FLAGS.max_iters + 1:
            # Learning rate step decay
            if iter == cfg.FLAGS.step_size + 1:
                # Add snapshot here before reducing the learning rate
                # self.snapshot(sess, iter)
                sess.run(tf.assign(lr, cfg.FLAGS.learning_rate * cfg.FLAGS.gamma))

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            # Compute the graph without summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = self.net.train_step(sess, blobs, train_op)
            timer.toc()
            iter += 1

            # Display training information
            if iter % (cfg.FLAGS.display) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n ' %
                      (iter, cfg.FLAGS.max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            # Periodically save a snapshot of the model
            if iter % cfg.FLAGS.snapshot_iterations == 0:
                self.snapshot(sess, iter)
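
        # Note: no final snapshot is written after the loop ends; the last saved
        # model is from the most recent snapshot_iterations boundary.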


    def get_variables_in_checkpoint_file(self, file_name):
        """Return a dict mapping variable names to shapes for a checkpoint file,
        or None if the checkpoint cannot be read."""
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")


    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indices of the database
        perm = self.data_layer._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename
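
# A minimal sketch (not part of the training flow above) of how the metadata
# written by Train.snapshot() could be read back when resuming; it assumes the
# same pickle order used in snapshot() (numpy RNG state, _cur, _perm, iter).
def load_snapshot_meta(nfilename):
    """Restore the numpy RNG state saved by snapshot() and return (cur, perm, iter)."""
    with open(nfilename, 'rb') as fid:
        st0 = pickle.load(fid)
        cur = pickle.load(fid)
        perm = pickle.load(fid)
        last_iter = pickle.load(fid)
    np.random.set_state(st0)
    return cur, perm, last_iter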



if __name__ == '__main__':
    train = Train()
    train.train()