# -*- coding: utf-8 -*-
"""
Created on Tue Aug  6 15:12:27 2024

@author: Aus
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, precision_recall_curve
from scipy.ndimage import distance_transform_edt

def validate_input(y_true, y_pred):
    """Check that a pair of label maps can be fed to the metric code.

    Args:
        y_true: Ground-truth labels as a numpy array or tf.Tensor.
        y_pred: Predicted labels as a numpy array or tf.Tensor.

    Raises:
        ValueError: If either input is not a numpy array / tf.Tensor, if the
            shapes differ, if the rank is not 3 or 4, or if a rank-4 input has
            a last dimension other than 1.
    """
    accepted_types = (np.ndarray, tf.Tensor)
    if not isinstance(y_true, accepted_types) or not isinstance(y_pred, accepted_types):
        raise ValueError("Inputs must be numpy arrays or tensorflow tensors")
    if y_true.shape != y_pred.shape:
        raise ValueError("Input arrays must have the same shape")
    # A rank-4 input is only valid as (batch, height, width, 1). Previously a
    # wider channel axis (e.g. one-hot labels or class probabilities) slipped
    # through despite the message below, and produced meaningless metrics.
    is_bad_rank = y_true.ndim not in (3, 4)
    is_wide_channel = y_true.ndim == 4 and y_true.shape[-1] != 1
    if is_bad_rank or is_wide_channel:
        raise ValueError("Inputs must be 3D (batch, height, width) or 4D (batch, height, width, 1)")

def convert_to_tensor(y_true, y_pred):
    """Turn both label maps into tensors, drop a trailing singleton channel, validate.

    Args:
        y_true: Ground-truth labels (anything ``tf.convert_to_tensor`` accepts).
        y_pred: Predicted labels (anything ``tf.convert_to_tensor`` accepts).

    Returns:
        Tuple ``(y_true, y_pred)`` of tf.Tensor objects. A rank-4 input whose
        last dimension is 1 is squeezed to rank 3.

    Raises:
        ValueError: Propagated from ``validate_input``.
    """
    def _as_label_tensor(values):
        # Convert, then remove a size-1 channel axis from rank-4 tensors only.
        tensor = tf.convert_to_tensor(values)
        has_single_channel = tensor.shape.ndims == 4 and tensor.shape[-1] == 1
        return tf.squeeze(tensor, axis=-1) if has_single_channel else tensor

    true_tensor = _as_label_tensor(y_true)
    pred_tensor = _as_label_tensor(y_pred)
    validate_input(true_tensor, pred_tensor)
    return true_tensor, pred_tensor

def calculate_metric(y_true, y_pred, class_index, metric):
    """Compute a single overlap score for one class.

    Args:
        y_true: Ground-truth label map.
        y_pred: Predicted label map.
        class_index: Label of the class to score.
        metric: One of ``'iou'``, ``'dice'``, ``'precision'``, ``'recall'``.

    Returns:
        Scalar float32 tensor. ``tf.math.divide_no_nan`` is used, so the score
        is 0 whenever its denominator is 0.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    true_mask = tf.cast(tf.equal(y_true, class_index), tf.float32)
    pred_mask = tf.cast(tf.equal(y_pred, class_index), tf.float32)

    overlap = tf.reduce_sum(true_mask * pred_mask)
    true_count = tf.reduce_sum(true_mask)
    pred_count = tf.reduce_sum(pred_mask)

    # (numerator, denominator) for each supported score.
    ratios = {
        'iou': (overlap, true_count + pred_count - overlap),
        'dice': (2. * overlap, true_count + pred_count),
        'precision': (overlap, pred_count),
        'recall': (overlap, true_count),
    }
    for name, (numerator, denominator) in ratios.items():
        if metric == name:
            return tf.math.divide_no_nan(numerator, denominator)
    raise ValueError("Invalid metric. Choose from 'iou', 'dice', 'precision', 'recall'.")

def compute_metrics(y_true, y_pred, n_classes):
    """Compute pixel accuracy, mean IoU, per-class scores and Cohen's kappa.

    Args:
        y_true: Ground-truth label map, 3D (batch, height, width) or 4D with a
            trailing singleton channel.
        y_pred: Predicted label map with the same shape as ``y_true``.
        n_classes: Number of classes passed to ``tf.keras.metrics.MeanIoU``.

    Returns:
        Tuple ``(pixel_accuracy, mean_iou, per_class_metrics, kappa,
        confusion_matrix)``:

        * ``per_class_metrics[i]`` is ``(iou, dice, precision, recall, f1)``
          for class ``i``; each ratio is 0 when its denominator is 0.
        * ``confusion_matrix`` is MeanIoU's accumulated matrix; rows are true
          labels and columns are predicted labels.
        * ``kappa`` is 0.0 when the matrix is empty or expected agreement is 1.
    """
    y_true, y_pred = convert_to_tensor(y_true, y_pred)

    # Pixel Accuracy
    accuracy = tf.keras.metrics.Accuracy()
    accuracy.update_state(y_true, y_pred)
    pixel_accuracy = accuracy.result().numpy()

    # Mean IoU and Confusion Matrix
    miou = tf.keras.metrics.MeanIoU(num_classes=n_classes)
    miou.update_state(y_true, y_pred)
    mean_iou = miou.result().numpy()
    confusion_matrix = miou.total_cm.numpy()

    # Derive every per-class score from the confusion matrix instead of
    # rescanning the full label tensors 4 * n_classes times.
    cm = confusion_matrix.astype(np.float64)
    true_counts = cm.sum(axis=1)    # pixels whose true label is i
    pred_counts = cm.sum(axis=0)    # pixels predicted as i
    true_positives = np.diag(cm)

    def _safe_div(numerator, denominator):
        # Elementwise division that yields 0 where the denominator is 0,
        # matching tf.math.divide_no_nan.
        return np.divide(numerator, denominator,
                         out=np.zeros_like(numerator, dtype=np.float64),
                         where=denominator != 0)

    iou = _safe_div(true_positives, true_counts + pred_counts - true_positives)
    dice = _safe_div(2.0 * true_positives, true_counts + pred_counts)
    precision = _safe_div(true_positives, pred_counts)
    recall = _safe_div(true_positives, true_counts)
    f1_score = 2 * (precision * recall) / (precision + recall + tf.keras.backend.epsilon())
    per_class_metrics = list(zip(iou, dice, precision, recall, f1_score))

    # Cohen's Kappa: (observed - expected agreement) / (1 - expected agreement).
    n = cm.sum()
    if n > 0:
        po = np.trace(cm) / n
        pe = np.dot(true_counts, pred_counts) / (n * n)
        kappa = (po - pe) / (1.0 - pe) if pe != 1.0 else 0.0
    else:
        # Empty input: avoid 0/0 and report no agreement beyond chance.
        kappa = 0.0

    return pixel_accuracy, mean_iou, per_class_metrics, float(kappa), confusion_matrix

def plot_confusion_matrix(confusion_matrix):
    """Display a confusion matrix as an annotated blue heatmap.

    The y axis is labelled as true labels and the x axis as predicted labels.

    Args:
        confusion_matrix: 2D array of counts.
    """
    _, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(confusion_matrix, annot=True, fmt='g', cmap='Blues', ax=ax)
    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Predicted Labels')
    ax.set_ylabel('True Labels')
    plt.show()

def plot_class_metrics(class_indices, metrics, metric_names):
    """Draw one per-class bar chart for each metric, side by side.

    Args:
        class_indices: x positions of the bars, one per class.
        metrics: Sequence of per-class value sequences, aligned with
            ``metric_names``.
        metric_names: Display name of each metric.
    """
    n_metrics = len(metric_names)
    if n_metrics == 0:
        # plt.subplots rejects a zero-column grid; nothing to draw.
        return
    # squeeze=False keeps ``axes`` a 2D array even when n_metrics == 1; the
    # default returns a bare Axes in that case, which cannot be indexed.
    _, axes = plt.subplots(1, n_metrics, figsize=(5 * n_metrics, 5), squeeze=False)
    for ax, metric, name in zip(axes[0], metrics, metric_names):
        ax.bar(class_indices, metric)
        ax.set_xlabel('Class Index')
        ax.set_ylabel(name)
        ax.set_title(f'Class-wise {name}')
    plt.tight_layout()
    plt.show()

def plot_roc_curves(y_true, y_pred_prob, n_classes):
    """Plot one-vs-rest ROC curves (with AUC in the legend) for each class.

    Args:
        y_true: Label map; a rank-4 input with a trailing singleton channel is
            squeezed. Labels are cast to integers.
        y_pred_prob: Per-pixel class scores with ``n_classes`` entries on the
            last axis.
        n_classes: Number of classes.

    Raises:
        ValueError: If the last axis of ``y_pred_prob`` is not ``n_classes``.

    Classes that have no positive pixels, or only positive pixels, are skipped
    with a printed note. For them the ROC curve is undefined (sklearn would
    return NaN rates).
    """
    y_true = np.asarray(y_true)
    y_pred_prob = np.asarray(y_pred_prob)
    if y_true.ndim == 4 and y_true.shape[-1] == 1:
        y_true = np.squeeze(y_true, axis=-1)
    if y_pred_prob.shape[-1] != n_classes:
        # reshape(-1, n_classes) would otherwise silently mix pixels together.
        raise ValueError(
            f"y_pred_prob must have {n_classes} scores on its last axis, "
            f"got shape {y_pred_prob.shape}"
        )
    # Integer cast mirrors what to_categorical did. Comparing per class avoids
    # materialising a full one-hot copy of the labels.
    labels_flat = y_true.reshape(-1).astype(np.int64, copy=False)
    scores_flat = y_pred_prob.reshape(-1, n_classes)

    plt.figure(figsize=(10, 8))
    for i in range(n_classes):
        is_class = (labels_flat == i).astype(np.uint8)
        n_positive = int(np.count_nonzero(is_class))
        if n_positive == 0 or n_positive == is_class.size:
            print(f"Class {i}: skipped in ROC plot (needs both positive and negative pixels)")
            continue
        fpr, tpr, _ = roc_curve(is_class, scores_flat[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f'Class {i} (AUC = {roc_auc:.2f})')

    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc='lower right')
    plt.show()

def plot_pr_curves(y_true, y_pred_prob, n_classes):
    """Plot one-vs-rest precision-recall curves (with AUC in the legend) per class.

    Args:
        y_true: Label map; a rank-4 input with a trailing singleton channel is
            squeezed. Labels are cast to integers.
        y_pred_prob: Per-pixel class scores with ``n_classes`` entries on the
            last axis.
        n_classes: Number of classes.

    Raises:
        ValueError: If the last axis of ``y_pred_prob`` is not ``n_classes``.

    Classes with no ground-truth pixels are skipped with a printed note. For
    them recall is undefined and sklearn only warns.
    """
    y_true = np.asarray(y_true)
    y_pred_prob = np.asarray(y_pred_prob)
    if y_true.ndim == 4 and y_true.shape[-1] == 1:
        y_true = np.squeeze(y_true, axis=-1)
    if y_pred_prob.shape[-1] != n_classes:
        # reshape(-1, n_classes) would otherwise silently mix pixels together.
        raise ValueError(
            f"y_pred_prob must have {n_classes} scores on its last axis, "
            f"got shape {y_pred_prob.shape}"
        )
    # Integer cast mirrors what to_categorical did. Comparing per class avoids
    # materialising a full one-hot copy of the labels.
    labels_flat = y_true.reshape(-1).astype(np.int64, copy=False)
    scores_flat = y_pred_prob.reshape(-1, n_classes)

    plt.figure(figsize=(10, 8))
    for i in range(n_classes):
        is_class = (labels_flat == i).astype(np.uint8)
        if not is_class.any():
            print(f"Class {i}: skipped in PR plot (no ground-truth pixels)")
            continue
        precision, recall, _ = precision_recall_curve(is_class, scores_flat[:, i])
        pr_auc = auc(recall, precision)
        plt.plot(recall, precision, label=f'Class {i} (AUC = {pr_auc:.2f})')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='lower left')
    plt.show()

# def plot_sample_predictions(X_test, y_true, y_pred, n_samples=5, class_names=['Background', 'Class 1', 'Class 2']):
#     if y_true.ndim == 4 and y_true.shape[-1] == 1:
#         y_true = np.squeeze(y_true, axis=-1)
#     indices = np.random.choice(len(X_test), n_samples, replace=False)
    
#     # Create a custom colormap
#     colors = ['blue', 'green', 'red']  # You can change these colors
#     n_classes = len(class_names)
#     # cmap = ListedColormap(colors[:n_classes])
    
#     for idx in indices:
#         fig, axs = plt.subplots(1, 3, figsize=(20, 6))
        
#         plt.figure(figsize=(15, 5))
        
#         plt.subplot(1, 3, 1)
#         plt.title('Original Image')
#         plt.imshow(image[:,:,0], cmap='gray')
        
#         plt.subplot(1, 3, 2)
#         plt.title('True Segmentation')
#         plt.imshow(image[:,:,0], cmap='gray')
#         plt.imshow(true_mask, alpha=alpha, cmap='jet')
        
#         plt.subplot(1, 3, 3)
#         plt.title('Predicted Segmentation')
#         plt.imshow(image[:,:,0], cmap='gray')
#         plt.imshow(pred_mask, alpha=alpha, cmap='jet')
        
#         plt.tight_layout()
#         plt.show()

def plot_uncertainty(y_pred_prob, threshold=0.5):
    """Show a per-pixel uncertainty heatmap for the first sample of a batch.

    Uncertainty is ``1 - max score over the last axis``, so it is 0 wherever
    the top class score is 1.

    Args:
        y_pred_prob: Class scores with classes on the last axis; only index 0
            along the first axis is displayed.
        threshold: Not used by this function; kept for API compatibility.
    """
    top_score = np.max(y_pred_prob, axis=-1)
    first_sample_map = (1 - top_score)[0]

    plt.figure(figsize=(10, 5))
    plt.imshow(first_sample_map, cmap='hot')
    plt.colorbar(label='Uncertainty')
    plt.title('Uncertainty Map')
    plt.axis('off')
    plt.show()

def analyze_error_types(y_true, y_pred, n_classes, *, background_class=0, plot=True):
    """Split wrongly labelled pixels into three error types and report their shares.

    Every pixel with ``y_true != y_pred`` is counted exactly once:

    * ``under_segmentation``: a foreground pixel predicted as background;
    * ``over_segmentation``: a background pixel predicted as foreground;
    * ``misclassification``: a foreground pixel predicted as a different
      foreground class.

    "Background" is the single label ``background_class`` (0 by default);
    every other label counts as foreground. The previous per-class loop
    counted each wrong pixel several times, so the reported distribution was
    a constant that did not depend on the predictions.

    Args:
        y_true: Ground-truth label map; a rank-4 input with a trailing
            singleton channel is squeezed.
        y_pred: Predicted label map; squeezed the same way.
        n_classes: Number of classes; used to validate ``background_class``.
        background_class: Label treated as background.
        plot: If true, show a bar chart of each type's share of all errors.

    Returns:
        dict mapping each error type name to its pixel count.

    Raises:
        ValueError: If ``background_class`` is not in ``[0, n_classes)`` or
            the label maps have different shapes after squeezing.
    """
    if not 0 <= background_class < n_classes:
        raise ValueError(
            f"background_class must be in [0, {n_classes}), got {background_class}"
        )
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.ndim == 4 and y_true.shape[-1] == 1:
        y_true = np.squeeze(y_true, axis=-1)
    if y_pred.ndim == 4 and y_pred.shape[-1] == 1:
        y_pred = np.squeeze(y_pred, axis=-1)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred shapes differ: {y_true.shape} vs {y_pred.shape}"
        )

    wrong = y_true != y_pred
    true_is_bg = y_true == background_class
    pred_is_bg = y_pred == background_class

    # The three masks are disjoint and together cover every wrong pixel.
    error_types = {
        'under_segmentation': int(np.count_nonzero(wrong & ~true_is_bg & pred_is_bg)),
        'over_segmentation': int(np.count_nonzero(wrong & true_is_bg & ~pred_is_bg)),
        'misclassification': int(np.count_nonzero(wrong & ~true_is_bg & ~pred_is_bg)),
    }

    total_errors = sum(error_types.values())
    # Guard against 0/0 when the prediction is perfect.
    shares = {
        error_type: (count / total_errors if total_errors > 0 else 0.0)
        for error_type, count in error_types.items()
    }
    for error_type, share in shares.items():
        percentage = share * 100
        print(f"{error_type}: {percentage:.2f}%")

    if plot:
        plt.figure(figsize=(10, 5))
        plt.bar(list(shares.keys()), list(shares.values()))
        plt.title('Error Type Distribution')
        plt.ylabel('Proportion of Errors')
        plt.show()

    return error_types
def plot_segmentation_overlay(image, true_mask, pred_mask, alpha=0.5):
    """Show an image beside its ground-truth and predicted masks overlaid on it.

    Both overlays share a single colour scale spanning the labels present in
    either mask, so a given class gets the same ``jet`` colour in both panels.
    Previously each panel normalised to its own label range, so the colours
    were not comparable.

    Args:
        image: (height, width, channels) image, of which only channel 0 is
            drawn, or a 2D (height, width) image.
        true_mask: Ground-truth label map, 2D or with a trailing singleton
            channel.
        pred_mask: Predicted label map, same conventions as ``true_mask``.
        alpha: Opacity of the mask overlays.
    """
    image = np.asarray(image)
    true_mask = np.asarray(true_mask)
    pred_mask = np.asarray(pred_mask)

    base = image[:, :, 0] if image.ndim == 3 else image
    # Drop a trailing singleton channel so each mask is a plain 2D label map.
    if true_mask.ndim == 3 and true_mask.shape[-1] == 1:
        true_mask = true_mask[..., 0]
    if pred_mask.ndim == 3 and pred_mask.shape[-1] == 1:
        pred_mask = pred_mask[..., 0]
    vmin = min(true_mask.min(), pred_mask.min())
    vmax = max(true_mask.max(), pred_mask.max())

    panels = (
        ('Original Image', None),
        ('True Segmentation', true_mask),
        ('Predicted Segmentation', pred_mask),
    )
    plt.figure(figsize=(15, 5))
    for position, (title, mask) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        plt.imshow(base, cmap='gray')
        if mask is not None:
            plt.imshow(mask, alpha=alpha, cmap='jet', vmin=vmin, vmax=vmax)
    plt.tight_layout()
    plt.show()
    
def _print_evaluation_summary(pixel_accuracy, mean_iou, kappa, per_class_metrics):
    """Print the global scores followed by the per-class score listing."""
    print(f"Pixel Accuracy: {pixel_accuracy:.4f}")
    print(f"Mean IoU: {mean_iou:.4f}")
    print(f"Cohen's Kappa: {kappa:.4f}")

    print("\nPer-class metrics:")
    for i, (iou, dice, precision, recall, f1) in enumerate(per_class_metrics):
        print(f"  Class {i}:")
        print(f"    IoU: {iou:.4f}")
        print(f"    Dice: {dice:.4f}")
        print(f"    Precision: {precision:.4f}")
        print(f"    Recall: {recall:.4f}")
        print(f"    F1 Score: {f1:.4f}")


def _plot_random_overlays(X_test, y_true, y_pred, n_samples):
    """Show overlays for up to ``n_samples`` distinct, randomly chosen test images."""
    n_samples = min(n_samples, len(X_test))
    if n_samples <= 0:
        return
    sample_indices = np.random.choice(len(X_test), n_samples, replace=False)
    for sample_idx in sample_indices:
        # Reuse the batch predictions instead of calling model.predict again.
        plot_segmentation_overlay(X_test[sample_idx], y_true[sample_idx], y_pred[sample_idx])


def evaluate_segmentation_model(model, X_test, y_true, n_classes, n_samples=10):
    """Run and display a full evaluation report for a segmentation model.

    ``X_test`` is predicted once. The report then prints pixel accuracy, mean
    IoU, Cohen's kappa and per-class scores. It shows the confusion matrix,
    per-class bar charts, ROC / PR curves, an uncertainty map and the
    error-type breakdown, plus overlays for up to ``n_samples`` random test
    images. Any exception is caught and reported with its traceback rather
    than propagated.

    The sample overlays now use this function's own ``X_test`` / ``y_true``
    arguments; they previously read the globals ``X_test1`` / ``y_test``.

    Args:
        model: Object with a ``predict`` method returning per-pixel class
            scores with classes on the last axis.
        X_test: Test images as a numpy array (indexed by sample on axis 0).
        y_true: Ground-truth label maps as a numpy array; a rank-4 input with
            a trailing singleton channel is squeezed.
        n_classes: Number of classes.
        n_samples: Maximum number of sample overlays to display.
    """
    print("Starting model evaluation...")

    try:
        # Print shapes and ranges
        print(f"X_test shape: {X_test.shape}, range: [{X_test.min()}, {X_test.max()}]")
        print(f"y_true shape: {y_true.shape}, range: [{y_true.min()}, {y_true.max()}]")

        # Predict once; everything below reuses these arrays.
        y_pred_prob = model.predict(X_test)
        y_pred = np.argmax(y_pred_prob, axis=-1)

        print(f"y_pred_prob shape: {y_pred_prob.shape}, range: [{y_pred_prob.min()}, {y_pred_prob.max()}]")
        print(f"y_pred shape: {y_pred.shape}, range: [{y_pred.min()}, {y_pred.max()}]")

        # Only squeeze a real channel axis, not the width of a 3D label map.
        if y_true.ndim == 4 and y_true.shape[-1] == 1:
            y_true = np.squeeze(y_true, axis=-1)

        pixel_accuracy, mean_iou, per_class_metrics, kappa, confusion_matrix = compute_metrics(y_true, y_pred, n_classes)
        _print_evaluation_summary(pixel_accuracy, mean_iou, kappa, per_class_metrics)

        plot_confusion_matrix(confusion_matrix)

        class_indices = list(range(n_classes))
        metrics = list(zip(*per_class_metrics))
        metric_names = ['IoU', 'Dice', 'Precision', 'Recall', 'F1 Score']
        plot_class_metrics(class_indices, metrics, metric_names)

        plot_roc_curves(y_true, y_pred_prob, n_classes)
        plot_pr_curves(y_true, y_pred_prob, n_classes)
        plot_uncertainty(y_pred_prob)

        analyze_error_types(y_true, y_pred, n_classes)

        _plot_random_overlays(X_test, y_true, y_pred, n_samples)

        print("Model evaluation completed.")
    except Exception as e:
        # Best-effort interactive report: keep going, but don't hide where it broke.
        import traceback
        print(f"An error occurred during evaluation: {e}")
        traceback.print_exc()

# Example usage:
# model = load_your_model()
# X_test = load_your_test_data()
# y_true = load_your_test_labels()
# n_classes = 3
if __name__ == "__main__":
    # best_model, X_test1, y_test and n_classes are not defined in this file;
    # presumably they already exist in the interactive namespace this script
    # is run in. The guard keeps a plain `import` of this module from raising
    # NameError.
    evaluate_segmentation_model(best_model, X_test1, y_test, n_classes)
    # print("X_test shape:", X_test1.shape)
    # print("y_test shape:", y_test.shape)

