# -*- coding: utf-8 -*-
"""Final model maker script

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1QYXBbLJs3Cqp5RFt2fENc01WvBNV2gCE
"""

#This file has been included in case the user has no program to run the .ipynb file (e.g. Google Colab or Jupyter Notebooks)


# Commented out IPython magic to ensure Python compatibility.
# Load the TensorBoard notebook extension
# %load_ext tensorboard

# Install Model Maker
!pip install -q tflite-model-maker
!pip install tf_explain

# Import required libraries
import os
import datetime

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
assert tf.__version__.startswith('2')

from tensorflow_hub.tools.make_image_classifier import make_image_classifier_lib as hub_lib
from tensorflow_examples.lite.model_maker.core import compat
from tensorflow_examples.lite.model_maker.core.task import train_image_classifier_lib
from tensorflow_examples.lite.model_maker.core.task import model_util

from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.image_classifier import DataLoader

from tf_explain.core.occlusion_sensitivity import OcclusionSensitivity

from PIL import Image, ImageFilter
import cv2
from skimage import img_as_ubyte
from skimage.transform import rotate
from skimage.util import random_noise
import random

# Report the TensorFlow environment up front so setup problems (wrong TF
# version, missing GPU runtime in Colab) are visible before training starts.
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")

#Get all image filenames
filenames = []

#Replace this with your original unmodified image directory
origin_dir = "/content/drive/MyDrive/Dissertation/Images/"
#Replace this with the directory into which you want your augmented images to be saved
augmented_dir = "/content/drive/MyDrive/Dissertation/ImagesAugmented/"

#Collect every JPEG path relative to origin_dir (label subfolder + file name)
for label_dir in os.listdir(origin_dir):
  filenames.extend(
      os.path.join(label_dir, name)
      for name in os.listdir(os.path.join(origin_dir, label_dir))
      if name.endswith("jpg"))

#Clear augmented images prior to running augmentation so stale files from a
#previous run cannot leak into the new dataset
for label_dir in os.listdir(augmented_dir):
  label_path = os.path.join(augmented_dir, label_dir)
  for name in os.listdir(label_path):
    os.remove(os.path.join(label_path, name))

#IMAGE AUGMENTATION
#Perform minor random rotations on each image.
#Fix: skimage.transform.rotate returns a float array even with
#preserve_range=True, so cast back to uint8 so cv2.imwrite stores the
#intended 8-bit pixel values. Paths are now built with os.path.join and the
#extension is stripped with os.path.splitext instead of naive "."-splitting.
for filename in filenames:
  image = cv2.imread(os.path.join(origin_dir, filename))

  #Rotate at most 10 degrees either way to keep the LFT upright
  angle = random.randint(-10, 10)

  rotated_image = rotate(image, angle, preserve_range=True).astype(np.uint8)

  rotated_filename = os.path.splitext(filename)[0] + "_rotated.jpg"
  cv2.imwrite(os.path.join(augmented_dir, rotated_filename), rotated_image)

#Add noise to images
aug_filenames = []

#Re-scan the augmented directory: the rotation step has just written new files
for label_dir in os.listdir(augmented_dir):
  label_path = os.path.join(augmented_dir, label_dir)
  for name in os.listdir(label_path):
    if name.endswith("jpg"):
      aug_filenames.append(os.path.join(label_path, name))

for filename in aug_filenames:
  pixels = np.asarray(Image.open(filename))

  #random_noise returns floats in [0, 1]; rescale back to the 8-bit range
  noisy_pixels = (255 * random_noise(pixels)).astype(np.uint8)

  noisy_filename = filename.split(".")[0] + "_noisy.jpg"
  Image.fromarray(noisy_pixels).save(noisy_filename)

#Adapted from https://stackoverflow.com/questions/32609098/how-to-fast-change-image-brightness-with-python-opencv
#Adjust brightness
aug_filenames = []

#Re-scan so the noisy copies written above are brightened as well
for label_dir in os.listdir(augmented_dir):
  label_path = os.path.join(augmented_dir, label_dir)
  for name in os.listdir(label_path):
    if name.endswith("jpg"):
      aug_filenames.append(os.path.join(label_path, name))

for filename in aug_filenames:
  bgr = cv2.imread(filename)
  hue, sat, val = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV))

  #Brighten the value channel by a random amount, saturating at 255 instead
  #of letting the uint8 addition wrap around
  boost = random.randint(0, 50)
  ceiling = 255 - boost
  val[val > ceiling] = 255
  val[val <= ceiling] += boost

  brightened = cv2.cvtColor(cv2.merge((hue, sat, val)), cv2.COLOR_HSV2BGR)
  cv2.imwrite(filename.split(".")[0] + "_adjustedbrightness.jpg", brightened)

#Adjust hue - Adapted from https://stackoverflow.com/questions/7274221/changing-image-hue-with-python-pi
def rgb_to_hsv(rgb):
    """Vectorised port of colorsys.rgb_to_hsv.

    Args:
      rgb: numpy array whose last axis holds R, G, B values in [0, 255];
        an optional fourth (alpha) channel is copied through untouched.

    Returns:
      Float array of the same shape with H and S in [0.0, 1.0] and V in
      [0.0, 255.0].
    """
    rgb = rgb.astype('float')
    out = np.zeros_like(rgb)
    # Pass any alpha channel straight through.
    out[..., 3:] = rgb[..., 3:]
    red, green, blue = rgb[..., 0], rgb[..., 1], rgb[..., 2]
    brightest = np.max(rgb[..., :3], axis=-1)
    darkest = np.min(rgb[..., :3], axis=-1)
    out[..., 2] = brightest
    # Greyscale pixels (max == min) keep S = 0 and H = 0.
    chromatic = brightest != darkest
    out[chromatic, 1] = (brightest - darkest)[chromatic] / brightest[chromatic]
    span = brightest - darkest
    red_dist = np.zeros_like(red)
    green_dist = np.zeros_like(green)
    blue_dist = np.zeros_like(blue)
    red_dist[chromatic] = (brightest - red)[chromatic] / span[chromatic]
    green_dist[chromatic] = (brightest - green)[chromatic] / span[chromatic]
    blue_dist[chromatic] = (brightest - blue)[chromatic] / span[chromatic]
    # Hue sector depends on which channel dominates, as in colorsys.
    out[..., 0] = np.select(
        [red == brightest, green == brightest],
        [blue_dist - green_dist, 2.0 + red_dist - blue_dist],
        default=4.0 + green_dist - red_dist)
    out[..., 0] = (out[..., 0] / 6.0) % 1.0
    return out

def hsv_to_rgb(hsv):
    """Vectorised port of colorsys.hsv_to_rgb.

    Args:
      hsv: numpy array whose last axis holds H, S in [0.0, 1.0] and V in
        [0.0, 255.0]; an optional fourth (alpha) channel is copied through.

    Returns:
      uint8 array of the same shape with R, G, B values in [0, 255].
    """
    out = np.empty_like(hsv)
    # Pass any alpha channel straight through.
    out[..., 3:] = hsv[..., 3:]
    hue, sat, val = hsv[..., 0], hsv[..., 1], hsv[..., 2]
    sector = (hue * 6.0).astype('uint8')
    frac = (hue * 6.0) - sector
    low = val * (1.0 - sat)
    falling = val * (1.0 - sat * frac)
    rising = val * (1.0 - sat * (1.0 - frac))
    sector = sector % 6
    # sector == 0 is covered by the defaults, exactly as in colorsys.
    cases = [sat == 0.0, sector == 1, sector == 2,
             sector == 3, sector == 4, sector == 5]
    out[..., 0] = np.select(cases, [val, falling, low, low, rising, val], default=val)
    out[..., 1] = np.select(cases, [val, val, val, falling, low, low], default=rising)
    out[..., 2] = np.select(cases, [val, low, rising, val, val, falling], default=low)
    return out.astype('uint8')


def shift_hue(arr, hout):
    """Return a uint8 copy of *arr* (RGB or RGBA) with every pixel's hue
    replaced by *hout* (a float in [0.0, 1.0])."""
    as_hsv = rgb_to_hsv(arr)
    as_hsv[..., 0] = hout
    return hsv_to_rgb(as_hsv)

#Apply a random hue shift to every augmented image
aug_filenames = []

for label_dir in os.listdir(augmented_dir):
  for filename in os.listdir(os.path.join(augmented_dir, label_dir)):
    if filename.endswith("jpg"):
        aug_filenames.append(os.path.join(augmented_dir, label_dir, filename))

for filename in aug_filenames:
  img = Image.open(filename).convert('RGBA')
  arr = np.array(img)

  #Pick a hue in [0.0, 1.0). The stray `if __name__ == '__main__':` guard that
  #the original Stack Overflow snippet left inside this loop has been removed;
  #it belongs at module level, not around per-file processing, and was always
  #true when this script runs directly.
  hue_shift = (180 - random.randint(-180, 180)) / 360.0

  new_img = Image.fromarray(shift_hue(arr, hue_shift), 'RGBA').convert("RGB")

  adjusted_hue_filename = filename.split(".")[0] + "_adjustedhue.jpg"
  new_img.save(adjusted_hue_filename)

#MODEL CREATION
# Adapted from https://github.com/tensorflow/examples/blob/master/tensorflow_examples/lite/model_maker/core/task/image_classifier.py
# Adapted to include model history as output and use new hub_train_model function
class ImageClassifierExtended(image_classifier.ImageClassifier):
  """ImageClassifier variant that also surfaces the Keras training history."""

  @staticmethod
  def create(train_data,
             model_spec=model_spec.get('mobilenet_v2'),
             validation_data=None,
             batch_size=None,
             epochs=None,
             train_whole_model=None,
             dropout_rate=0.2,
             learning_rate=0.0001,
             momentum=None,
             use_augmentation=False,
             use_hub_library=True,
             warmup_steps=None,
             model_dir=None,
             do_train=True):
    """Loads data and retrains the model based on data for image classification.

    Args:
      train_data: Training data.
      model_spec: Specification for the model.
      validation_data: Validation data. If None, skips validation process.
      batch_size: Number of samples per training step.
      epochs: Number of epochs for training.
      train_whole_model: If true, the Hub module is trained together with the
        classification layer on top. Otherwise, only train the top
        classification layer.
      dropout_rate: The rate for dropout.
      learning_rate: Base learning rate forwarded to the optimizer.
      momentum: a Python float forwarded to the optimizer.
      use_augmentation: Use data augmentation for preprocessing.
      use_hub_library: Accepted for compatibility with the upstream
        ImageClassifier.create signature; not used here (the hub training
        path is always taken via get_hub_lib_hparams).
      warmup_steps: Accepted for signature compatibility; not used here.
      model_dir: Accepted for signature compatibility; not used here.
      do_train: Accepted for signature compatibility; not used here —
        training always runs.

    Data is shuffled by default.

    Returns:
      A tuple (classifier, history): the trained ImageClassifierExtended
      instance and the tf.keras.callbacks.History from training.
    """
    if compat.get_tf_behavior() not in model_spec.compat_tf_versions:
      raise ValueError('Incompatible versions. Expect {}, but got {}.'.format(
          model_spec.compat_tf_versions, compat.get_tf_behavior()))

    #Get tensorflow hub model hyperparameters
    hparams = get_hub_lib_hparams(
        batch_size=batch_size,
        train_epochs=epochs,
        do_fine_tuning=train_whole_model,
        dropout_rate=dropout_rate,
        learning_rate=learning_rate,
        momentum=momentum)

    #Create image classifier object to be trained
    ic = ImageClassifierExtended(
        model_spec,
        train_data.index_to_label,
        train_data.num_classes,
        hparams=hparams,
        use_augmentation=use_augmentation)

    tf.compat.v1.logging.info('Retraining the models...')
    history = ic.train(train_data, validation_data)

    return ic, history


  def train(self,
            train_data,
            validation_data=None,
            hparams=None,
            steps_per_epoch=None):
    """Feeds the training data for training.

    Args:
      train_data: Training data.
      validation_data: Validation data. If None, skips validation process.
      hparams: An instance of hub_lib.HParams or
        train_image_classifier_lib.HParams. A namedtuple of hyperparameters.
      steps_per_epoch: Integer or None. Total number of steps (batches of
        samples) before declaring one epoch finished and starting the next
        epoch. If `steps_per_epoch` is None, the epoch will run until the
        input dataset is exhausted.

    Returns:
      The tf.keras.callbacks.History object returned by tf.keras.Model.fit().
      (Bug fix: the original stored self.history but returned None, so
      create()'s `history` was always None.)
    """
    self.create_model()
    hparams = self._get_hparams_or_default(hparams)

    if len(train_data) < hparams.batch_size:
      raise ValueError('The size of the train_data (%d) couldn\'t be smaller '
                       'than batch_size (%d). To solve this problem, set '
                       'the batch_size smaller or increase the size of the '
                       'train_data.' % (len(train_data), hparams.batch_size))

    train_ds = train_data.gen_dataset(
        hparams.batch_size,
        is_training=True,
        shuffle=self.shuffle,
        preprocess=self.preprocess)
    steps_per_epoch = get_steps_per_epoch(steps_per_epoch,
                                          hparams.batch_size,
                                          train_data)
    if steps_per_epoch is not None:
      train_ds = train_ds.take(steps_per_epoch)

    validation_ds = None
    if validation_data is not None:
      validation_ds = validation_data.gen_dataset(
          hparams.batch_size, is_training=False, preprocess=self.preprocess)

    # Pick the training routine matching the hyperparameter type.
    # Bug fix: the original assigned `train_model = train_model`, which
    # raised UnboundLocalError (local read before assignment) whenever
    # hparams was a train_image_classifier_lib.HParams.
    if isinstance(hparams, train_image_classifier_lib.HParams):
      train_fn = train_image_classifier_lib.train_model
    else:
      train_fn = hub_train_model

    self.history = train_fn(
        model=self.model,
        hparams=hparams,
        train_ds=train_ds,
        validation_ds=validation_ds,
        steps_per_epoch=steps_per_epoch)

    return self.history

# https://github.com/tensorflow/examples/blob/ad3ab9b65e67459172077d13acb5f7da83cabd80/tensorflow_examples/lite/model_maker/core/task/train_image_classifier_lib.py
# Adapted to include callbacks e.g. for tensorboard
def hub_train_model(model, hparams, train_ds, validation_ds, steps_per_epoch):
  """Trains model with the given data and hyperparameters.

  If using a DistributionStrategy, call this under its `.scope()`.

  Args:
    model: The tf.keras.Model from _build_model().
    hparams: A namedtuple of hyperparameters. This function expects
      .train_epochs: a Python integer with the number of passes over the
        training dataset;
      .learning_rate: a Python float forwarded to the optimizer;
      .momentum: a Python float forwarded to the optimizer;
      .label_smoothing: a Python float forwarded to the loss.
    train_ds: tf.data.Dataset, training data to be fed in tf.keras.Model.fit().
    validation_ds: tf.data.Dataset, validation data to be fed in
      tf.keras.Model.fit().
    steps_per_epoch: Integer or None. Total number of steps (batches of samples)
      before declaring one epoch finished and starting the next epoch. If
      `steps_per_epoch` is None, the epoch will run until the input dataset is
      exhausted.
  Returns:
    The tf.keras.callbacks.History object returned by tf.keras.Model.fit().
  """
  loss = tf.keras.losses.CategoricalCrossentropy(
      label_smoothing=hparams.label_smoothing)
  model.compile(
      # `learning_rate` replaces the deprecated `lr` keyword argument.
      optimizer=tf.keras.optimizers.SGD(
          learning_rate=hparams.learning_rate, momentum=hparams.momentum),
      loss=loss,
      metrics=["accuracy"])

  # Log each run to a timestamped directory so TensorBoard can compare runs.
  log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
  tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

  return model.fit(
      train_ds,
      epochs=hparams.train_epochs,
      steps_per_epoch=steps_per_epoch,
      validation_data=validation_ds,
      callbacks=[tensorboard_callback])
def get_hub_lib_hparams(**kwargs):
  """Gets the hyperparameters for the tensorflow hub's library.

  Starts from hub_lib's default hyperparameters and overlays any keyword
  overrides (e.g. batch_size, train_epochs, dropout_rate, learning_rate).
  """
  hparams = hub_lib.get_default_hparams()
  return train_image_classifier_lib.add_params(hparams, **kwargs)

def get_steps_per_epoch(steps_per_epoch=None, batch_size=None, train_data=None):
  """Estimates the number of training steps in one epoch.

  Resolution order:
    1. An explicit `steps_per_epoch` always wins.
    2. Otherwise derive it as `len(train_data) // batch_size`.
    3. If the data has no length (or batch_size is None), return None so the
       caller simply runs the dataset to exhaustion.

  Args:
    steps_per_epoch: int, training steps per epoch.
    batch_size: int, batch size.
    train_data: training data.

  Returns:
    Estimated training steps per epoch, or None when it cannot be determined.
  """
  if steps_per_epoch is not None:
    # Set manually by the caller.
    return steps_per_epoch
  try:
    return len(train_data) // batch_size
  except TypeError:
    # train_data has no len(), or batch_size is None.
    return None

#Replace these with the source directory for your augmented and validation images
image_path = '/content/drive/MyDrive/Dissertation/ImagesAugmented'
val_image_path = '/content/drive/MyDrive/Dissertation/ImagesValidation'

#Load the datasets from their labelled image folders
data = DataLoader.from_folder(image_path)
validation_data = DataLoader.from_folder(val_image_path)

#Hold back 20% of the augmented data as a test set (the validation set is
#loaded from its own directory above)
train_data, test_data = data.split(0.8)

#Sanity-check the training data by plotting its first 25 images
plt.figure(figsize=(10, 10))
for index, (image, label) in enumerate(data.gen_dataset().unbatch().take(25)):
  plt.subplot(5, 5, index + 1)
  plt.xticks([])
  plt.yticks([])
  plt.grid(False)
  plt.imshow(image.numpy(), cmap=plt.cm.gray)
  plt.xlabel(data.index_to_label[label.numpy()])
plt.show()

#Train the model
model, history = ImageClassifierExtended.create(train_data,
                                                train_whole_model=True,
                                                epochs=20,
                                                validation_data=validation_data)

#Evaluate model using test data
loss, accuracy = model.evaluate(test_data)

# Colour helper for the prediction plot: 'black' when the prediction matches
# the ground-truth label, 'red' when it does not.
def get_label_color(val1, val2):
  return 'black' if val1 == val2 else 'red'

# Plot 100 test images with their predicted labels. Any prediction that
# disagrees with the provided test label is highlighted in red.
plt.figure(figsize=(20, 20))
predicts = model.predict_top_k(test_data)
for index, (image, label) in enumerate(test_data.gen_dataset().unbatch().take(100)):
  axis = plt.subplot(10, 10, index + 1)
  plt.xticks([])
  plt.yticks([])
  plt.grid(False)
  plt.imshow(image.numpy(), cmap=plt.cm.gray)

  predict_label = predicts[index][0][0]
  truth = test_data.index_to_label[label.numpy()]
  axis.xaxis.label.set_color(get_label_color(predict_label, truth))
  plt.xlabel('Predicted: %s' % predict_label)
plt.show()

#Export the model to SavedModel format
model.export(export_dir='.', export_format=[ExportFormat.LABEL, ExportFormat.SAVED_MODEL])

#Load the saved model
#Reloaded as a plain Keras model so tf_explain can drive it directly.
saved_model = tf.keras.models.load_model('saved_model')

#Choose one image or an array of images (requires some modification) to perform Occlusion Sensitivity on.
image_path = '/content/drive/MyDrive/Dissertation/VisOcclusionSensitivity/DataCollect_2020-11-24-08-46-14-374.jpg'
# Load a sample image (or multiple ones)
# 224x224 matches the classifier's expected input size.
img = tf.keras.preprocessing.image.load_img(image_path, target_size=(224, 224))
img = tf.keras.preprocessing.image.img_to_array(img)
# NOTE(review): this rebinds `data` (previously the DataLoader) to the
# (images, labels) tuple format tf_explain expects; `data` is not used as a
# DataLoader again after this point.
data = ([img], None)

# Start explainer
# class_index=0 explains the first label; patch_size=50 is the side length of
# the square patch slid over the image to measure sensitivity.
explainer = OcclusionSensitivity()
grid = explainer.explain(data, saved_model, class_index=0, patch_size=50)


explainer.save(grid, "/content/drive/MyDrive/Dissertation/VisOcclusionSensitivity", "osens.png")

#Convert the model to tflite format
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model') # path to the SavedModel directory
tflite_model = converter.convert()

# Save the model
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)

#Attach metadata to TFLite file to enable ML Binding
#Source - https://www.tensorflow.org/lite/convert/metadata
from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb

#Build the TFLite Model Metadata describing this image classifier so tools
#such as Android Studio's ML Model Binding can generate typed wrappers.

# Creates model info.
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "LFT Recogniser"
model_meta.description = ("Identify the most prominent object in the "
                          "image from a set of 2 categories - Lateral Flow "
                          "Test (LFT) or Not Lateral Flow Test (Not LFT). "
                          "Built upon Mobilenet_V2).")
model_meta.version = "v1"
model_meta.author = "Luke Morris"
model_meta.license = ("Apache License. Version 2.0 "
                      "http://www.apache.org/licenses/LICENSE-2.0.")

# Creates input info. (A redundant second construction of output_meta in the
# original export has been removed; output_meta is built once below.)
input_meta = _metadata_fb.TensorMetadataT()
input_meta.name = "image"
input_meta.description = (
    "Input image to be classified. The expected image is {0} x {1}, with "
    "three channels (red, blue, and green) per pixel. Each value in the "
    "tensor is a single byte between 0 and 255.".format(224, 224))
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = (
    _metadata_fb.ColorSpaceType.RGB)
input_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.ImageProperties)
# Normalisation: (pixel - 127.5) / 127.5, i.e. bytes mapped into [-1, 1].
input_normalization = _metadata_fb.ProcessUnitT()
input_normalization.optionsType = (
    _metadata_fb.ProcessUnitOptions.NormalizationOptions)
input_normalization.options = _metadata_fb.NormalizationOptionsT()
input_normalization.options.mean = [127.5]
input_normalization.options.std = [127.5]
input_meta.processUnits = [input_normalization]
input_stats = _metadata_fb.StatsT()
input_stats.max = [255]
input_stats.min = [0]
input_meta.stats = input_stats

# Creates output info.
output_meta = _metadata_fb.TensorMetadataT()
output_meta.name = "probability"
output_meta.description = "Probabilities of the 2 labels respectively."
output_meta.content = _metadata_fb.ContentT()
# Fixed: the original assigned to `content_properties` (snake_case), which the
# generated flatbuffers object API ignores — the attribute is
# `contentProperties`, matching the input tensor above.
output_meta.content.contentProperties = _metadata_fb.FeaturePropertiesT()
output_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_stats = _metadata_fb.StatsT()
output_stats.max = [1.0]
output_stats.min = [0.0]
output_meta.stats = output_stats
# Associate the label file so interpreters can map class indices to names.
label_file = _metadata_fb.AssociatedFileT()
label_file.name = "labels.txt"
label_file.description = "Labels for objects that the model can recognize."
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
output_meta.associatedFiles = [label_file]

# Creates subgraph info tying the input and output tensors to the model.
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [output_meta]
model_meta.subgraphMetadata = [subgraph]

# Serialise the metadata into a flatbuffer.
b = flatbuffers.Builder(0)
b.Finish(
    model_meta.Pack(b),
    _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()

#Populate the tflite model with metadata
populator = _metadata.MetadataPopulator.with_model_file('model.tflite')
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files(["labels.txt"])
populator.populate()

# Commented out IPython magic to ensure Python compatibility.
# %tensorboard --logdir logs/fit

#Save the model training results to Tensorboard
#NOTE(review): `tensorboard dev upload` publishes these logs publicly to
#TensorBoard.dev — do not upload logs containing sensitive data.
!tensorboard dev upload \
  --logdir logs/fit \
  --name "LFT Recognition Model Maker Experiment" \
  --description "Experiment using Model Maker and transfer learning with Mobilenet v2 as a base" \
  --one_shot