# -*- coding: utf-8 -*-
"""
Created on Sat Aug  3 08:55:33 2024

@author: Aus
"""

import gc
import os
import random
import re

import cv2
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import segmentation_models as sm
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger, ReduceLROnPlateau
from keras.models import Model
from keras.optimizers import Adam
from sklearn.model_selection import train_test_split, ParameterGrid
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.metrics import MeanIoU

# Restrict the process to the GPU with index 1.
# NOTE(review): this is set after `import tensorflow`; it usually still takes
# effect because TF initializes its GPU context lazily on first use — confirm
# on this setup.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Set TensorFlow configuration
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    # Allocate GPU memory on demand instead of reserving the whole card.
    tf.config.experimental.set_memory_growth(gpu, True)

# Define constants
DATA_SAVE_DIR = r"D:\PROTOS\LIVER\numpy"                # preprocessed .npy arrays live here
MODEL_SAVE_DIR = r"D:\PROTOS\LIVER\models_hyper_param"  # checkpoint output directory
LOG_FILE_PATH = os.path.join(MODEL_SAVE_DIR, 'training_logs.csv')  # CSVLogger target
EPOCHS = 1000              # upper bound; EarlyStopping (patience=50) usually stops sooner
VALIDATION_SPLIT = 0.10  # Increased to 10%
TEST_SPLIT = 0.10  # Kept at 10%
RANDOM_STATE = 42          # fixed seed so the train/val/test split is reproducible
BACKBONE = 'resnet101'     # encoder backbone for the segmentation_models U-Net
N_CLASSES = 3              # number of segmentation classes — presumably background/liver/lesion; TODO confirm
ACTIVATION = 'softmax'     # multi-class output activation


# Create directories if they don't exist
os.makedirs(DATA_SAVE_DIR, exist_ok=True)
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)


def load_nifti(file_path: str) -> np.ndarray:
    """
    Read a NIfTI volume from disk and return its voxel data.

    Args:
        file_path (str): Path to the NIfTI file.

    Returns:
        np.ndarray: Voxel array extracted from the file.
    """
    return nib.load(file_path).get_fdata()

def preprocess_data(image: np.ndarray, label: np.ndarray, target_height: int, target_width: int) -> tuple:
    """
    Prepare an image/label pair for training.

    Resizes both volumes to the target in-plane shape, scales image
    intensities, and tiles the single grayscale channel three times so
    the input matches an RGB-pretrained backbone.

    Args:
        image (np.ndarray): Image data.
        label (np.ndarray): Label data.
        target_height (int): Target height.
        target_width (int): Target width.

    Returns:
        tuple: (preprocessed image, preprocessed label).
    """
    resized_image = resize_slices(image, target_height, target_width)
    resized_label = resize_slices(label, target_height, target_width)

    normalized_image = normalize_intensity(resized_image)

    # The backbone expects 3 input channels; replicate the grayscale one.
    three_channel_image = np.repeat(normalized_image[..., np.newaxis], 3, axis=-1)

    return three_channel_image, resized_label

def resize_slices(data: np.ndarray, target_height: int, target_width: int) -> np.ndarray:
    """
    Resize every 2-D slice along the first axis of a volume.

    Args:
        data (np.ndarray): Volume of shape (num_slices, H, W).
        target_height (int): Target height.
        target_width (int): Target width.

    Returns:
        np.ndarray: Resized volume of shape (num_slices, target_height, target_width).
    """
    # Fixes an idiom issue: the original shadowed the builtin `slice` and
    # built the list with a manual append loop.
    # NOTE: cv2.resize takes (width, height), i.e. (cols, rows).
    return np.array([
        cv2.resize(plane, (target_width, target_height)) for plane in data
    ])

def normalize_intensity(data: np.ndarray) -> np.ndarray:
    """
    Scale intensities to [0, 1] by dividing by the global maximum.

    Args:
        data (np.ndarray): Data to normalize.

    Returns:
        np.ndarray: Normalized data. An all-zero input is returned as a
        float copy of zeros instead of producing NaN/inf from a 0/0 division.
    """
    max_value = np.max(data)
    # Guard the degenerate all-zero case: the original `data / np.max(data)`
    # produced NaNs (0/0) with a RuntimeWarning for blank volumes.
    if max_value == 0:
        return data.astype(np.float64, copy=True)
    return data / max_value

class DataGenerator(tf.keras.utils.Sequence):
    """Minimal Keras Sequence serving (x, y) batches from in-memory arrays."""

    def __init__(self, x_set: np.ndarray, y_set: np.ndarray, batch_size: int):
        # Arrays are kept by reference; no copies are made.
        self.x = x_set
        self.y = y_set
        self.batch_size = batch_size

    def __len__(self) -> int:
        # Number of batches per epoch; the last batch may be smaller.
        sample_count = len(self.x)
        return int(np.ceil(sample_count / float(self.batch_size)))

    def __getitem__(self, idx: int) -> tuple:
        # Slice out the idx-th batch from both arrays in lockstep.
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.x[start:stop], self.y[start:stop]

def train_model(model: Model, train_gen: DataGenerator, val_gen: DataGenerator, epochs: int, callbacks: list) -> tf.keras.callbacks.History:
    """
    Train the model, retrying with progressively smaller batch sizes on GPU OOM.

    Args:
        model (Model): Model to train.
        train_gen (DataGenerator): Training data generator.
        val_gen (DataGenerator): Validation data generator.
        epochs (int): Number of epochs.
        callbacks (list): List of callbacks.

    Returns:
        tf.keras.callbacks.History: Training history, or None if training
        failed even at the smallest fallback batch size.
    """
    try:
        history = model.fit(train_gen, epochs=epochs, validation_data=val_gen,
                            callbacks=callbacks, steps_per_epoch=len(train_gen),
                            validation_steps=len(val_gen))
    except tf.errors.ResourceExhaustedError:
        print("ResourceExhaustedError occurred. Trying with reduced batch size...")
        # Bug fix: the original referenced a global BATCH_SIZE that only exists
        # as a local inside main(), so this retry path raised NameError instead
        # of retrying. Derive the base batch size from the generator itself.
        base_bs = train_gen.batch_size
        for bs in [base_bs, max(1, base_bs // 2), max(1, base_bs // 4)]:
            try:
                train_gen = DataGenerator(train_gen.x, train_gen.y, bs)
                val_gen = DataGenerator(val_gen.x, val_gen.y, bs)
                history = model.fit(train_gen, epochs=epochs, validation_data=val_gen,
                                    callbacks=callbacks, steps_per_epoch=len(train_gen),
                                    validation_steps=len(val_gen))
                break
            except tf.errors.ResourceExhaustedError:
                continue
        else:
            print("Unable to train even with reduced batch size. Exiting.")
            history = None
    return history

def evaluate_model(model: Model, x_test: np.ndarray, y_test: np.ndarray) -> tuple:
    """
    Evaluate the model on the held-out test set.

    Args:
        model (Model): Trained model to evaluate.
        x_test (np.ndarray): Test input data.
        y_test (np.ndarray): Integer-coded test labels.

    Returns:
        tuple: (test loss, test accuracy) — the first two compiled metrics.
    """
    # One-hot encode the integer labels to match the softmax output.
    y_test_onehot = tf.keras.utils.to_categorical(y_test, num_classes=N_CLASSES)
    results = model.evaluate(x_test, y_test_onehot)
    return results[0], results[1]

def calculate_iou(model: Model, x_test: np.ndarray, y_test: np.ndarray) -> np.ndarray:
    """
    Compute a one-vs-rest IoU score for every class.

    Args:
        model (Model): Trained model to evaluate.
        x_test (np.ndarray): Test input data.
        y_test (np.ndarray): Test labels of shape (N, H, W, 1).

    Returns:
        np.ndarray: One IoU value per class, indexed by class id.
    """
    predictions = model.predict(x_test)
    y_pred = np.argmax(predictions, axis=3)
    y_true = y_test[:, :, :, 0]

    scores = []
    for class_id in range(N_CLASSES):
        metric = MeanIoU(num_classes=N_CLASSES)
        # Binarize to "this class vs. everything else" before updating.
        metric.update_state(y_true == class_id, y_pred == class_id)
        scores.append(metric.result().numpy())
    return np.array(scores)

def plot_results(model: Model, x_test: np.ndarray, y_test: np.ndarray) -> None:
    """
    Show one random test image next to its ground truth and the prediction.

    Args:
        model (Model): Trained model to evaluate.
        x_test (np.ndarray): Test input data.
        y_test (np.ndarray): Test label data.
    """
    sample_idx = random.randint(0, len(x_test) - 1)
    image = x_test[sample_idx]
    truth = y_test[sample_idx]

    # Add a batch axis for predict(), then collapse the class axis.
    prediction = model.predict(np.expand_dims(image, 0))
    predicted_mask = np.argmax(prediction, axis=3)[0, :, :]

    plt.figure(figsize=(12, 8))
    panels = [
        (231, 'Testing Image', image[:, :, 0], 'gray'),
        (232, 'Testing Label', truth[:, :, 0], 'jet'),
        (233, 'Prediction on test image', predicted_mask, 'jet'),
    ]
    for position, title, pixels, colormap in panels:
        plt.subplot(position)
        plt.title(title)
        plt.imshow(pixels, cmap=colormap)
    plt.show()

def sanity_check(images: np.ndarray, masks: np.ndarray, num_samples: int = 5) -> None:
    """
    Visualize a few random image/mask pairs to verify data alignment.

    Args:
        images (np.ndarray): Image data.
        masks (np.ndarray): Mask data.
        num_samples (int): Number of samples to visualize.
    """
    for sample_idx in random.sample(range(len(images)), num_samples):
        plt.figure(figsize=(10, 4))
        plt.subplot(1, 2, 1)
        plt.imshow(images[sample_idx, :, :, 0], cmap='gray')
        plt.title('Image')
        plt.subplot(1, 2, 2)
        plt.imshow(masks[sample_idx, :, :], cmap='jet')
        plt.title('Mask')
        plt.show()
        
def plot_training_history(history: tf.keras.callbacks.History) -> None:
    """
    Plot training/validation loss and accuracy curves side by side.

    Args:
        history (tf.keras.callbacks.History): History returned by model.fit.
    """
    records = history.history
    epoch_axis = range(1, len(records['loss']) + 1)

    plt.figure(figsize=(12, 6))

    # (subplot, train key, val key, y-label, title, train legend, val legend)
    panels = [
        (1, 'loss', 'val_loss', 'Loss',
         'Training and Validation Loss', 'Training loss', 'Validation loss'),
        (2, 'accuracy', 'val_accuracy', 'Accuracy',
         'Training and Validation Accuracy', 'Training accuracy', 'Validation accuracy'),
    ]
    for position, train_key, val_key, ylabel, title, train_label, val_label in panels:
        plt.subplot(1, 2, position)
        plt.plot(epoch_axis, records[train_key], 'y', label=train_label)
        plt.plot(epoch_axis, records[val_key], 'r', label=val_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        plt.legend()

    plt.show()


class CustomModelCheckpoint(Callback):
    """
    Checkpoint callback that saves the model with the epoch number and the
    monitored value embedded in the file name.

    Fixes two bugs in the original implementation:
    - the caller-supplied ``filepath`` was silently discarded and replaced by
      a hard-coded format string;
    - the save message printed a literal "(unknown)" instead of the path.
    """

    def __init__(self, filepath, monitor='val_loss', save_best_only=False, mode='min', verbose=0):
        super(CustomModelCheckpoint, self).__init__()
        # Honor the caller's format string, e.g. "model_{epoch:02d}_{val_loss:.4f}.h5".
        self.filepath = filepath
        self.monitor = monitor
        self.save_best_only = save_best_only
        self.mode = mode
        self.verbose = verbose
        # Start "best" at the worst possible value for the chosen mode.
        self.best = float('inf') if mode == 'min' else -float('inf')

    def on_epoch_end(self, epoch, logs=None):
        """Save the model when the monitored metric is present (and improved, if requested)."""
        current = (logs or {}).get(self.monitor)
        if current is None:
            # Metric not logged this epoch (e.g. no validation run) — skip.
            return

        if not self.save_best_only:
            self._save_model(epoch, current)
            return

        improved = (current < self.best) if self.mode == 'min' else (current > self.best)
        if improved:
            self.best = current
            self._save_model(epoch, current)

    def _save_model(self, epoch, current):
        """Format the file name from epoch/metric and write the model to MODEL_SAVE_DIR."""
        filename = os.path.join(MODEL_SAVE_DIR, self.filepath.format(epoch=epoch, val_loss=current))
        if self.verbose > 0:
            print(f"\nSaving model to {filename}")
        self.model.save(filename)
        
def main():
    """
    Full training pipeline: load the preprocessed arrays, split them,
    grid-search learning rate / batch size, then evaluate the best
    checkpoint on the held-out test set.
    """
    # Load data
    images = np.load(os.path.join(DATA_SAVE_DIR, 'images.npy'))
    masks = np.load(os.path.join(DATA_SAVE_DIR, 'masks.npy'))

    # Perform a sanity check
    sanity_check(images, masks)

    # Split data: train vs. (val+test), then val vs. test.
    x_train, x_temp, y_train, y_temp = train_test_split(
        images, masks,
        test_size=(VALIDATION_SPLIT + TEST_SPLIT),
        random_state=RANDOM_STATE)
    x_val, x_test, y_val, y_test = train_test_split(
        x_temp, y_temp,
        test_size=(TEST_SPLIT / (VALIDATION_SPLIT + TEST_SPLIT)),
        random_state=RANDOM_STATE)

    # Save splits to disk so the originals can be released from RAM.
    np.save(os.path.join(DATA_SAVE_DIR, 'x_train.npy'), x_train)
    np.save(os.path.join(DATA_SAVE_DIR, 'y_train.npy'), y_train)
    np.save(os.path.join(DATA_SAVE_DIR, 'x_val.npy'), x_val)
    np.save(os.path.join(DATA_SAVE_DIR, 'y_val.npy'), y_val)
    np.save(os.path.join(DATA_SAVE_DIR, 'x_test.npy'), x_test)
    np.save(os.path.join(DATA_SAVE_DIR, 'y_test.npy'), y_test)

    # Delete data from RAM
    del images, masks, x_temp, y_temp
    gc.collect()

    # Load data from disk
    x_train = np.load(os.path.join(DATA_SAVE_DIR, 'x_train.npy'))
    y_train = np.load(os.path.join(DATA_SAVE_DIR, 'y_train.npy'))
    x_val = np.load(os.path.join(DATA_SAVE_DIR, 'x_val.npy'))
    y_val = np.load(os.path.join(DATA_SAVE_DIR, 'y_val.npy'))
    x_test = np.load(os.path.join(DATA_SAVE_DIR, 'x_test.npy'))
    y_test = np.load(os.path.join(DATA_SAVE_DIR, 'y_test.npy'))

    # Define model
    model = sm.Unet(BACKBONE, classes=N_CLASSES, activation=ACTIVATION,
                    input_shape=(x_train.shape[1], x_train.shape[2], x_train.shape[3]))

    # Define optimizer
    optimizer = Adam(learning_rate=1e-4)

    # Weighted Dice + categorical focal loss for class-imbalanced segmentation.
    dice_loss = sm.losses.DiceLoss(class_weights=np.array([0.1, 0.4, 0.5]))
    focal_loss = sm.losses.CategoricalFocalLoss()
    total_loss = dice_loss + (1 * focal_loss)
    metrics = ['accuracy', sm.metrics.IOUScore(threshold=0.5)]

    # Compile model
    model.compile(optimizer=optimizer, loss=total_loss, metrics=metrics)

    # Callbacks shared by all grid-search runs. (The original also built a
    # CustomModelCheckpoint here, but it was dead code — immediately replaced
    # inside the loop below.)
    early_stopping = EarlyStopping(monitor='val_loss', patience=50,
                                   restore_best_weights=True, mode='min')
    csv_logger = CSVLogger(LOG_FILE_PATH)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=20,
                                  min_lr=1e-6, mode='min')

    # Hyperparameter tuning grid
    param_grid = {'learning_rate': [1e-4, 1e-5], 'batch_size': [16, 8]}
    best_val_loss = float('inf')
    best_params = None
    best_history = None

    for params in ParameterGrid(param_grid):
        print(f"Training with params: {params}")
        batch_size = params['batch_size']
        lr = params['learning_rate']

        # Re-compile with the new learning rate.
        optimizer.learning_rate = lr
        model.compile(optimizer=optimizer, loss=total_loss, metrics=metrics)

        # Fresh generators with the updated batch size.
        train_gen = DataGenerator(x_train, y_train, batch_size)
        val_gen = DataGenerator(x_val, y_val, batch_size)

        # Per-run checkpoint; the file name encodes epoch and val loss.
        checkpoint = CustomModelCheckpoint(
            filepath="model_{epoch:02d}_{val_loss:.4f}.h5",
            monitor='val_loss', save_best_only=False, mode='min', verbose=1)
        history = train_model(model, train_gen, val_gen, EPOCHS,
                              [checkpoint, early_stopping, csv_logger, reduce_lr])

        if history is not None:
            # Plot training history and track the best run.
            plot_training_history(history)
            val_loss = min(history.history['val_loss'])
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_params = params
                best_history = history

    print(f"Best params: {best_params}, Best validation loss: {best_val_loss:.4f}")

    # Pick the checkpoint whose file name encodes the smallest val loss.
    # Bug fix: the original parsed float(f.split('_')[-1].split('.')[0]),
    # which always yields the integer part ("0") of the loss, making every
    # file compare equal. Extract the full decimal value instead.
    loss_pattern = re.compile(r'(\d+\.\d+)\.h5$')
    candidates = [(float(match.group(1)), fname)
                  for fname in os.listdir(MODEL_SAVE_DIR)
                  for match in [loss_pattern.search(fname)] if match]
    best_model_file = min(candidates)[1]

    # Load the best model
    model.load_weights(os.path.join(MODEL_SAVE_DIR, best_model_file))

    # Evaluate model on test set
    test_loss, test_accuracy = evaluate_model(model, x_test, y_test)
    print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')

    # Calculate IoU
    iou_per_class = calculate_iou(model, x_test, y_test)
    for i, iou in enumerate(iou_per_class):
        print(f'IoU for class {i}: {iou:.4f}')

    # Plot results
    plot_results(model, x_test, y_test)

if __name__ == '__main__':
    main()

