# -*- coding: utf-8 -*-
"""
Created on Sat Aug  3 08:55:33 2024

@author: Aus
"""

import os
import cv2
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Model
from keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger, ReduceLROnPlateau
from keras.optimizers import Adam
import segmentation_models as sm
import tensorflow as tf
from tensorflow.keras.metrics import MeanIoU
from sklearn.model_selection import train_test_split, ParameterGrid
import gc
import random
from tensorflow.keras.callbacks import Callback
from keras.preprocessing.image import ImageDataGenerator

# Pin training to the second GPU; must be set before TensorFlow creates the device context.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Set TensorFlow configuration:
# enable memory growth so TF allocates GPU memory on demand instead of grabbing it all at startup.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Set mixed precision policy
# (left disabled; enable only after verifying numerical stability of the loss)
# from tensorflow.keras.mixed_precision import experimental as mixed_precision
# policy = mixed_precision.Policy('mixed_float16')
# mixed_precision.set_policy(policy)

# Define constants
DATA_SAVE_DIR = r"D:\PROTOS\LIVER\numpy"          # preprocessed .npy volumes live here
MODEL_SAVE_DIR = r"D:\PROTOS\LIVER\models_hyper_param"  # checkpoints + logs per tuning run
LOG_FILE_PATH = os.path.join(MODEL_SAVE_DIR, 'training_logs.csv')
EPOCHS = 10
VALIDATION_SPLIT = 0.10  # Increased to 10%
TEST_SPLIT = 0.10  # Kept at 10%
RANDOM_STATE = 42        # fixed seed so train/val/test splits are reproducible
BACKBONE = 'resnet101'   # encoder backbone for segmentation_models U-Net
N_CLASSES = 3            # background / liver / lesion (presumably — confirm against mask encoding)
ACTIVATION = 'softmax'

# Create directories if they don't exist
os.makedirs(DATA_SAVE_DIR, exist_ok=True)
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

def load_nifti(file_path: str) -> np.ndarray:
    """
    Read a NIfTI image from disk and return its voxel data.

    Args:
        file_path (str): Path to the NIfTI file.

    Returns:
        np.ndarray: Voxel array as floats (nibabel's get_fdata()).
    """
    return nib.load(file_path).get_fdata()

def preprocess_data(image: np.ndarray, label: np.ndarray, target_height: int, target_width: int) -> tuple:
    """
    Resize a volume/label pair, normalize intensities, and expand to 3 channels.

    Args:
        image (np.ndarray): Stack of image slices.
        label (np.ndarray): Stack of label slices.
        target_height (int): Height to resize each slice to.
        target_width (int): Width to resize each slice to.

    Returns:
        tuple: (image, label) where image is normalized with the single
        intensity channel replicated 3x (backbone expects RGB input),
        and label is resized but otherwise untouched.
    """
    resized_image = resize_slices(image, target_height, target_width)
    resized_label = resize_slices(label, target_height, target_width)

    normalized = normalize_intensity(resized_image)

    # Replicate the grayscale channel so the ImageNet-pretrained encoder
    # receives a 3-channel input.
    three_channel = np.repeat(normalized[..., np.newaxis], 3, axis=-1)

    return three_channel, resized_label

def resize_slices(data: np.ndarray, target_height: int, target_width: int) -> np.ndarray:
    """
    Resize every slice of a volume to (target_height, target_width).

    Args:
        data (np.ndarray): Volume iterated slice-by-slice along axis 0.
        target_height (int): Target height.
        target_width (int): Target width.

    Returns:
        np.ndarray: Stacked resized slices.
    """
    # cv2.resize takes (width, height) order. Renamed loop var: the original
    # shadowed the builtin `slice`; comprehension replaces manual append loop.
    return np.array([cv2.resize(slc, (target_width, target_height)) for slc in data])

def normalize_intensity(data: np.ndarray) -> np.ndarray:
    """
    Scale intensities to [0, 1] by dividing by the global maximum.

    Note: this is max-scaling only, not min-max; negative values (e.g. raw
    Hounsfield units) would remain negative — presumably the data is already
    clipped to non-negative before this point (confirm upstream).

    Args:
        data (np.ndarray): Data to normalize.

    Returns:
        np.ndarray: Normalized float data.
    """
    max_val = float(np.max(data))
    if max_val == 0:
        # All-zero input: avoid division by zero (the original produced NaNs).
        return data.astype(np.float64)
    return data / max_val

def load_dataset(data, labels, batch_size):
    """
    Wrap arrays in a shuffled, batched, prefetching tf.data pipeline.

    Args:
        data (np.ndarray): Image data.
        labels (np.ndarray): Label data.
        batch_size (int): Batch size.

    Returns:
        tf.data.Dataset: Pipeline yielding (image, label) batches.
    """
    pairs = tf.data.Dataset.from_tensor_slices((data, labels))
    shuffled = pairs.shuffle(buffer_size=1024)
    # Prefetch overlaps host-side batching with device compute.
    return shuffled.batch(batch_size).prefetch(tf.data.AUTOTUNE)

def train_model(model: Model, x_train: np.ndarray, y_train: np.ndarray, x_val: np.ndarray, y_val: np.ndarray, epochs: int, batch_size: int, callbacks: list) -> tf.keras.callbacks.History:
    """
    Train the model, halving the batch size on GPU out-of-memory errors.

    Args:
        model (Model): Model to train.
        x_train (np.ndarray): Training input data.
        y_train (np.ndarray): Training label data (integer class indices).
        x_val (np.ndarray): Validation input data.
        y_val (np.ndarray): Validation label data (integer class indices).
        epochs (int): Number of epochs.
        batch_size (int): Initial batch size.
        callbacks (list): List of callbacks.

    Returns:
        tf.keras.callbacks.History: Training history, or None if training
        failed even at batch size 1.
    """
    # One-hot encode once instead of on every retry attempt.
    y_train_cat = tf.keras.utils.to_categorical(y_train, num_classes=N_CLASSES)
    y_val_cat = tf.keras.utils.to_categorical(y_val, num_classes=N_CLASSES)

    def _fit(bs):
        # Single fit attempt at a given batch size.
        return model.fit(x_train, y_train_cat,
                         validation_data=(x_val, y_val_cat),
                         epochs=epochs, batch_size=bs, callbacks=callbacks)

    try:
        return _fit(batch_size)
    except tf.errors.ResourceExhaustedError:
        print("ResourceExhaustedError occurred. Trying with reduced batch size...")
        # Start at batch_size // 2: the original retried the batch size that
        # had just failed. Skip sizes below 1 (possible for small batch_size).
        for bs in (batch_size // 2, batch_size // 4, batch_size // 8, 1):
            if bs < 1:
                continue
            try:
                print(f"Trying batch size: {bs}")
                return _fit(bs)
            except tf.errors.ResourceExhaustedError:
                continue
        print("Unable to train even with reduced batch size. Exiting.")
        return None

def evaluate_model(model: Model, x_test: np.ndarray, y_test: np.ndarray) -> tuple:
    """
    Evaluate the model on a held-out set.

    Args:
        model (Model): Compiled model to evaluate.
        x_test (np.ndarray): Test input data.
        y_test (np.ndarray): Test label data (integer class indices).

    Returns:
        tuple: (test_loss, test_accuracy) — the first two compiled metrics.
    """
    one_hot_labels = tf.keras.utils.to_categorical(y_test, num_classes=N_CLASSES)
    results = model.evaluate(x_test, one_hot_labels)
    return results[0], results[1]

def calculate_iou(model: Model, x_test: np.ndarray, y_test: np.ndarray) -> np.ndarray:
    """
    Calculate a per-class IoU score from model predictions.

    Args:
        model (Model): Model to evaluate.
        x_test (np.ndarray): Test input data.
        y_test (np.ndarray): Test label data; assumes shape (N, H, W, 1) —
            channel 0 holds the integer class index.

    Returns:
        np.ndarray: One score per class (length N_CLASSES).
    """
    # Collapse softmax output to a class-index map per pixel.
    y_pred = np.argmax(model.predict(x_test), axis=3)
    y_true = y_test[:, :, :, 0]
    iou_scores = []
    for i in range(N_CLASSES):
        # NOTE(review): the inputs here are binary masks (== i), but MeanIoU
        # is constructed with num_classes=N_CLASSES (3). Keras averages over
        # the classes actually present, so this yields the mean of the
        # foreground IoU and the background IoU for class i — not the plain
        # foreground IoU. If per-class foreground IoU is intended, this should
        # be MeanIoU(num_classes=2) or a direct intersection/union ratio.
        # Left unchanged to keep reported numbers comparable with prior runs.
        iou_keras = MeanIoU(num_classes=N_CLASSES)
        iou_keras.update_state(y_true == i, y_pred == i)
        iou_scores.append(iou_keras.result().numpy())
    return np.array(iou_scores)

def calculate_dice_coefficient(y_true, y_pred, smooth=1, n_classes=None):
    """
    Calculate the Dice Coefficient for each class.

    Args:
        y_true (np.ndarray): True labels (integer class indices).
        y_pred (np.ndarray): Predicted labels (integer class indices).
        smooth (int): Smoothing factor to avoid division by zero.
        n_classes (int, optional): Number of classes to score. Defaults to
            the module-level N_CLASSES, preserving the original behavior.

    Returns:
        np.ndarray: Dice Coefficient for each class, in class-index order.
    """
    if n_classes is None:
        n_classes = N_CLASSES
    dice_scores = []
    for cls in range(n_classes):
        # Binary foreground masks for this class, flattened for the dot product.
        truth = (y_true == cls).astype(np.float32).ravel()
        pred = (y_pred == cls).astype(np.float32).ravel()
        intersection = np.sum(truth * pred)
        # Smoothed Dice: 2|A∩B| / (|A| + |B|); smooth keeps empty classes at ~1.
        dice = (2. * intersection + smooth) / (np.sum(truth) + np.sum(pred) + smooth)
        dice_scores.append(dice)
    return np.array(dice_scores)

def plot_results(model: Model, x_test: np.ndarray, y_test: np.ndarray) -> None:
    """
    Show one random test slice alongside its ground truth and prediction.

    Args:
        model (Model): Trained model used for prediction.
        x_test (np.ndarray): Test input data, shape (N, H, W, C).
        y_test (np.ndarray): Test label data, shape (N, H, W, 1).
    """
    idx = random.randint(0, len(x_test) - 1)
    sample = x_test[idx]
    truth = y_test[idx]
    # Model expects a batch dimension; argmax collapses softmax to class ids.
    prediction = model.predict(np.expand_dims(sample, 0))
    predicted_mask = np.argmax(prediction, axis=3)[0, :, :]

    plt.figure(figsize=(12, 8))
    plt.subplot(231)
    plt.title('Testing Image')
    plt.imshow(sample[:, :, 0], cmap='gray')
    plt.subplot(232)
    plt.title('Testing Label')
    plt.imshow(truth[:, :, 0], cmap='jet')
    plt.subplot(233)
    plt.title('Prediction on test image')
    plt.imshow(predicted_mask, cmap='jet')
    plt.show()

def sanity_check(images: np.ndarray, masks: np.ndarray, num_samples: int = 15) -> None:
    """
    Visually spot-check random image/mask pairs before training.

    Args:
        images (np.ndarray): Image data, shape (N, H, W, C).
        masks (np.ndarray): Mask data, shape (N, H, W, 1).
        num_samples (int): How many random pairs to display.
    """
    for _ in range(num_samples):
        pick = random.randint(0, len(images) - 1)

        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(images[pick][:, :, 0], cmap='gray')
        plt.title('Image')
        plt.subplot(1, 2, 2)
        plt.imshow(masks[pick][:, :, 0], cmap='jet')
        plt.title('Mask')
        plt.show()

def plot_training_history(history: tf.keras.callbacks.History) -> None:
    """
    Plot loss and accuracy curves for training and validation.

    Args:
        history (tf.keras.callbacks.History): History returned by model.fit.
    """
    hist = history.history

    plt.figure(figsize=(12, 6))

    # Left panel: loss curves.
    plt.subplot(1, 2, 1)
    plt.plot(hist['loss'], label='Training Loss')
    plt.plot(hist['val_loss'], label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss')

    # Right panel: accuracy curves.
    plt.subplot(1, 2, 2)
    plt.plot(hist['accuracy'], label='Training Accuracy')
    plt.plot(hist['val_accuracy'], label='Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Accuracy')

    plt.show()

class CustomModelCheckpoint(Callback):
    """
    Custom callback for saving model checkpoints based on a monitored metric.

    Args:
        filepath (str): File path to save the model.
        monitor (str): Metric name to monitor (key in the epoch logs).
        verbose (int): Verbosity mode; > 0 prints a message on each save.
        save_best_only (bool): Whether to save only the best model.
        mode (str): 'min' if lower metric is better, 'max' if higher is better.

    Raises:
        ValueError: If mode is not 'min' or 'max' (the original silently
        left monitor_op undefined, causing an AttributeError mid-training).
    """
    def __init__(self, filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min'):
        super(CustomModelCheckpoint, self).__init__()
        self.filepath = filepath
        self.monitor = monitor
        self.verbose = verbose
        self.save_best_only = save_best_only
        self.mode = mode
        if mode == 'min':
            self.monitor_op = np.less
            # np.inf, not np.Inf: the capitalized alias was removed in NumPy 2.0.
            self.best = np.inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")

    def on_epoch_end(self, epoch, logs=None):
        # Keras may pass logs=None; guard before .get().
        logs = logs or {}
        current = logs.get(self.monitor)
        if current is None:
            return
        if self.monitor_op(current, self.best):
            if self.verbose > 0:
                # Message intentionally printed before self.best is updated so
                # it shows the previous best value.
                print(f'\nEpoch {epoch + 1}: {self.monitor} improved from {self.best} to {current}, saving model to {self.filepath}')
            self.best = current
            self.model.save(self.filepath)

# Define hyperparameters grid searched exhaustively by ParameterGrid
# (2 learning rates x 2 epoch counts = 4 configurations at batch size 16).
param_grid = {
    'batch_size': [16],
    'learning_rate': [1e-3, 1e-4],
    'epochs': [5, 10]# 'epochs': [50, 100]
}

# Function to perform hyperparameter tuning
def hyperparameter_tuning(x_train, y_train, x_val, y_val, param_grid):
    """
    Grid-search over param_grid, selecting the model with the best mean IoU
    on the validation set.

    Args:
        x_train (np.ndarray): Training images.
        y_train (np.ndarray): Training masks (integer class indices).
        x_val (np.ndarray): Validation images.
        y_val (np.ndarray): Validation masks (integer class indices).
        param_grid (dict): ParameterGrid-compatible search space with keys
            'batch_size', 'learning_rate', 'epochs'.

    Returns:
        tuple: (best_model, best_params); both None if every run failed.
    """
    best_score = -np.inf
    best_params = None
    best_model = None

    for params in ParameterGrid(param_grid):
        print(f"Testing parameters: {params}")

        # NOTE: the original built ImageDataGenerator flows and tf.data
        # datasets here but never passed them to training (train_model takes
        # raw arrays) — and it would have augmented the validation data too.
        # That dead setup is removed; wire augmentation into train_model
        # deliberately (train split only) if it is wanted.

        # Initialize a fresh model per configuration.
        model = sm.Unet(BACKBONE, classes=N_CLASSES, activation=ACTIVATION)
        # learning_rate= : the lr= alias is removed in recent Keras optimizers.
        model.compile(optimizer=Adam(learning_rate=params['learning_rate']),
                      loss=sm.losses.categorical_focal_dice_loss,
                      metrics=['accuracy', sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)])

        # Define callbacks
        checkpoint_filepath = os.path.join(MODEL_SAVE_DIR, 'best_model.h5')
        callbacks = [
            CustomModelCheckpoint(filepath=checkpoint_filepath, monitor='val_loss', save_best_only=True),
            EarlyStopping(monitor='val_loss', patience=50, verbose=1, restore_best_weights=True),
            CSVLogger(LOG_FILE_PATH, append=True),
            ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=20, verbose=1)
        ]

        # BUG FIX: the original called train_model with 5 arguments, but its
        # signature is (model, x_train, y_train, x_val, y_val, epochs,
        # batch_size, callbacks) — a guaranteed TypeError at runtime.
        history = train_model(model, x_train, y_train, x_val, y_val,
                              params['epochs'], params['batch_size'], callbacks)
        if history is None:
            # Training failed even at minimum batch size; skip this config.
            print("Skipping configuration: training failed.")
            continue

        # Restore the best checkpoint before scoring.
        model.load_weights(checkpoint_filepath)

        # Evaluate on the validation set (test set stays untouched until main).
        test_loss, test_accuracy = evaluate_model(model, x_val, y_val)
        print(f"Validation Loss: {test_loss}, Validation Accuracy: {test_accuracy}")

        # Model selection criterion: mean IoU over classes.
        iou_scores = calculate_iou(model, x_val, y_val)
        mean_iou = np.mean(iou_scores)
        print(f"Mean IoU: {mean_iou}")

        if mean_iou > best_score:
            best_score = mean_iou
            best_params = params
            best_model = model

    print(f"Best parameters: {best_params}")
    print(f"Best mean IoU: {best_score}")
    return best_model, best_params

def main():
    """Run the full pipeline: load data, tune, evaluate, and plot."""
    # Load preprocessed arrays saved by an earlier preprocessing step.
    x_data = np.load(os.path.join(DATA_SAVE_DIR, 'images.npy'))
    y_data = np.load(os.path.join(DATA_SAVE_DIR, 'masks.npy'))

    # Visual sanity check of image/mask alignment before committing to training.
    sanity_check(x_data, y_data)

    # Two-stage split: carve off the test set first, then the validation set.
    x_temp, x_test, y_temp, y_test = train_test_split(x_data, y_data, test_size=TEST_SPLIT, random_state=RANDOM_STATE)
    x_train, x_val, y_train, y_val = train_test_split(x_temp, y_temp, test_size=VALIDATION_SPLIT, random_state=RANDOM_STATE)
    # Free the full arrays before training to reduce peak memory.
    del x_data, y_data, x_temp, y_temp

    # Grid search; returns the model scoring the best validation mean IoU.
    best_model, best_params = hyperparameter_tuning(x_train, y_train, x_val, y_val, param_grid)

    # Final held-out evaluation.
    test_loss, test_accuracy = evaluate_model(best_model, x_test, y_test)
    print(f'Test Loss: {test_loss}, Test Accuracy: {test_accuracy}')

    iou_scores = calculate_iou(best_model, x_test, y_test)
    print(f'IoU Scores: {iou_scores}')

    # Dice on hard (argmax) predictions against the class-index channel.
    y_pred = np.argmax(best_model.predict(x_test), axis=3)
    y_true = y_test[:, :, :, 0]
    dice_scores = calculate_dice_coefficient(y_true, y_pred)
    print(f'Dice Scores: {dice_scores}')

    plot_results(best_model, x_test, y_test)

    # Encourage release of large arrays/models before interpreter exit.
    gc.collect()

if __name__ == '__main__':
    main()
