# -*- coding: utf-8 -*-
"""
Created on Tue Aug  6 09:21:59 2024

@author: Aus
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, precision_recall_curve
from scipy.ndimage import distance_transform_edt

# Input validation
def validate_input(y_test, y_pred_argmax):
    """Check that labels and predictions are supported, shape-compatible arrays.

    Both arguments must be NumPy arrays or TensorFlow tensors, and their shapes
    must be identical.

    Raises:
        ValueError: if either argument is of an unsupported type, or if the two
            shapes differ.
    """
    accepted_types = (np.ndarray, tf.Tensor)
    if not all(isinstance(arr, accepted_types) for arr in (y_test, y_pred_argmax)):
        raise ValueError("Inputs must be numpy arrays or tensorflow tensors")

    shapes_match = y_test.shape == y_pred_argmax.shape
    if not shapes_match:
        raise ValueError("Input arrays must have the same shape")

# Convert to appropriate tensor format if not already
def convert_to_tensor(y_test, y_pred_argmax):
    """Convert labels and predictions to tensors and validate their shapes.

    ``y_test`` often carries a trailing singleton channel axis (e.g.
    ``(N, H, W, 1)``) that ``y_pred_argmax`` (e.g. ``(N, H, W)`` from an argmax)
    lacks.  Only that trailing axis is dropped when it is the one extra
    dimension.  A bare ``tf.squeeze`` removes *every* size-1 axis, so a batch
    of one image (or any other length-1 axis) made otherwise matching inputs
    fail validation.  When the shapes differ in some other way, the previous
    full squeeze is still applied as a fallback.

    Returns:
        ``(y_test_tensor, y_pred_argmax_tensor)``.

    Raises:
        ValueError: (from ``validate_input``) if the shapes still differ.
    """
    y_test_tensor = tf.convert_to_tensor(y_test)
    y_pred_argmax_tensor = tf.convert_to_tensor(y_pred_argmax)

    if y_test_tensor.shape != y_pred_argmax_tensor.shape:
        has_extra_channel = (
            y_test_tensor.shape.rank == y_pred_argmax_tensor.shape.rank + 1
            and y_test_tensor.shape[-1] == 1
        )
        if has_extra_channel:
            y_test_tensor = tf.squeeze(y_test_tensor, axis=-1)
        else:
            # Fallback to the previous behaviour (drop all size-1 axes).
            y_test_tensor = tf.squeeze(y_test_tensor)

    validate_input(y_test_tensor, y_pred_argmax_tensor)
    return y_test_tensor, y_pred_argmax_tensor

# Calculate metrics
def calculate_metric(y_true, y_pred, class_index, metric):
    """Return one overlap score for a single class as a scalar float32 tensor.

    Both label arrays are binarised against ``class_index`` (T = true mask,
    P = predicted mask) and the requested score is formed from their counts:

    * ``'iou'``       -> |T and P| / |T or P|
    * ``'dice'``      -> 2 |T and P| / (|T| + |P|)
    * ``'precision'`` -> |T and P| / |P|
    * ``'recall'``    -> |T and P| / |T|

    A zero denominator yields 0 (``tf.math.divide_no_nan``).

    Raises:
        ValueError: if ``metric`` is not one of the four names above.
    """
    true_hits = tf.cast(tf.equal(y_true, class_index), tf.float32)
    pred_hits = tf.cast(tf.equal(y_pred, class_index), tf.float32)
    overlap = tf.reduce_sum(true_hits * pred_hits)
    n_true = tf.reduce_sum(true_hits)
    n_pred = tf.reduce_sum(pred_hits)

    if metric == 'iou':
        numerator, denominator = overlap, n_true + n_pred - overlap
    elif metric == 'dice':
        numerator, denominator = 2. * overlap, n_true + n_pred
    elif metric == 'precision':
        numerator, denominator = overlap, n_pred
    elif metric == 'recall':
        numerator, denominator = overlap, n_true
    else:
        raise ValueError("Invalid metric. Choose from 'iou', 'dice', 'precision', 'recall'.")

    return tf.math.divide_no_nan(numerator, denominator)

def compute_metrics(y_test, y_pred_argmax, n_classes):
    """Compute global and per-class segmentation metrics.

    Args:
        y_test: ground-truth labels (array or tensor; see ``convert_to_tensor``).
        y_pred_argmax: predicted labels with the same layout.
        n_classes: number of classes for ``MeanIoU`` and the per-class loop.

    Returns:
        ``(pixel_accuracy, mean_iou, per_class_metrics, kappa, confusion_matrix)``.
        ``per_class_metrics[i]`` is ``(iou, dice, precision, recall)`` for class
        ``i``.  ``confusion_matrix`` is the ``n_classes x n_classes`` matrix
        accumulated by ``tf.keras.metrics.MeanIoU``.  NOTE(review): Keras builds
        it with ``tf.math.confusion_matrix(y_true, y_pred)``, i.e. rows = true
        labels; confirm for your TF version.
    """
    labels, predictions = convert_to_tensor(y_test, y_pred_argmax)

    # Fraction of elements where prediction equals label.
    accuracy_metric = tf.keras.metrics.Accuracy()
    accuracy_metric.update_state(labels, predictions)
    pixel_accuracy = accuracy_metric.result().numpy()

    # MeanIoU also exposes the confusion matrix it accumulated.
    iou_metric = tf.keras.metrics.MeanIoU(num_classes=n_classes)
    iou_metric.update_state(labels, predictions)
    mean_iou = iou_metric.result().numpy()
    confusion_matrix = iou_metric.total_cm.numpy()

    metric_names = ('iou', 'dice', 'precision', 'recall')
    per_class_metrics = [
        tuple(calculate_metric(labels, predictions, cls, name).numpy() for name in metric_names)
        for cls in range(n_classes)
    ]

    # Cohen's kappa: (p_o - p_e) / (1 - p_e), with p_o the observed agreement
    # (diagonal mass) and p_e the agreement expected from the marginals.
    total = tf.reduce_sum(confusion_matrix)
    observed = tf.linalg.trace(confusion_matrix) / total
    marginal_products = (tf.reduce_sum(confusion_matrix, axis=0)
                         * tf.reduce_sum(confusion_matrix, axis=1))
    expected = tf.reduce_sum(marginal_products) / total / total
    kappa = tf.math.divide_no_nan(observed - expected, 1 - expected)

    return pixel_accuracy, mean_iou, per_class_metrics, kappa.numpy(), confusion_matrix

# Visualization functions
def plot_confusion_matrix(confusion_matrix):
    """Show a confusion matrix as an annotated heatmap.

    Rows are labelled "True Labels" and columns "Predicted Labels".

    Integer-valued matrices (raw pixel counts, as produced by ``compute_metrics``)
    are annotated as whole numbers with thousands separators.  The former
    ``fmt='g'`` printed large counts in lossy scientific notation (e.g.
    ``6.29146e+07``).  Matrices with non-integer entries (e.g. a normalised
    matrix) keep the general ``'g'`` format.
    """
    matrix = np.asarray(confusion_matrix)
    is_count_matrix = bool(np.all(np.isfinite(matrix)) and np.all(matrix == np.round(matrix)))
    if is_count_matrix:
        matrix = matrix.astype(np.int64)
        annotation_format = ',d'
    else:
        annotation_format = 'g'

    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt=annotation_format, cmap='Blues')
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.show()

def plot_bar_chart(class_indices, values, metric_name):
    """Draw one bar per class for a single metric.

    Bars are sky-blue when ``metric_name`` is exactly ``'IoU'`` and light green
    for every other metric.
    """
    bar_color = 'lightgreen'
    if metric_name == 'IoU':
        bar_color = 'skyblue'

    plt.figure(figsize=(10, 5))
    plt.bar(class_indices, values, color=bar_color)
    plt.xlabel('Class Index')
    plt.ylabel(metric_name)
    plt.title(f'Class-wise {metric_name}')
    plt.show()

def plot_roc_curves(y_test, model, X_test1, n_classes):
    """Plot one-vs-rest ROC curves, one per class, treating each pixel as a sample.

    Labels are one-hot encoded with ``to_categorical``. ``model.predict(X_test1)``
    is reshaped to ``(-1, n_classes)`` to match them. Each class's AUC is
    printed and shown in the legend, with the chance diagonal for reference.
    Any exception is printed instead of raised, so the caller keeps going.
    """
    try:
        onehot = tf.keras.utils.to_categorical(y_test, num_classes=n_classes)
        probs = model.predict(X_test1)

        onehot_pixels = onehot.reshape(-1, n_classes)
        prob_pixels = probs.reshape(-1, n_classes)

        plt.figure(figsize=(10, 8))
        for cls in range(n_classes):
            false_pos_rate, true_pos_rate, _ = roc_curve(onehot_pixels[:, cls], prob_pixels[:, cls])
            area = auc(false_pos_rate, true_pos_rate)
            print(f"Class {cls} - AUC: {area:.2f}")
            plt.plot(false_pos_rate, true_pos_rate, label=f'Class {cls} (area = {area:.2f})')

        # Chance-level reference.
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic')
        plt.legend(loc='lower right')
        plt.show()
    except Exception as err:
        print(f"Error in plotting ROC curves: {str(err)}")

def plot_precision_recall_curves(y_test, model, X_test1, n_classes):
    """Plot one-vs-rest precision-recall curves, one per class, over all pixels.

    Each pixel is one sample. Labels are one-hot encoded with
    ``to_categorical``, and ``model.predict(X_test1)`` is reshaped to
    ``(-1, n_classes)``. The printed and legend "area" is the trapezoidal area
    under each curve (``sklearn.metrics.auc``).

    Each class also gets a dashed horizontal no-skill baseline at its positive
    rate; that is the precision a random scorer achieves at every recall.  The
    former ``[0, 1] -> [1, 0]`` anti-diagonal is not a meaningful reference for
    PR curves.  Any exception is printed instead of raised.
    """
    try:
        onehot = tf.keras.utils.to_categorical(y_test, num_classes=n_classes)
        probs = model.predict(X_test1)

        onehot_flat = onehot.reshape(-1, n_classes)
        probs_flat = probs.reshape(-1, n_classes)

        plt.figure(figsize=(10, 8))
        for i in range(n_classes):
            precision, recall, _ = precision_recall_curve(onehot_flat[:, i], probs_flat[:, i])
            pr_auc = auc(recall, precision)
            print(f"Class {i} - Precision-Recall AUC: {pr_auc:.2f}")
            (curve,) = plt.plot(recall, precision, label=f'Class {i} (area = {pr_auc:.2f})')
            # No-skill baseline for this class, drawn in the curve's colour.
            prevalence = onehot_flat[:, i].mean()
            plt.axhline(prevalence, color=curve.get_color(), linestyle='--', linewidth=1)

        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall Curve')
        plt.legend(loc='lower left')
        plt.show()
    except Exception as e:
        print(f"Error in plotting Precision-Recall curves: {str(e)}")

def plot_individual_sample_prediction(image, true_mask, pred_mask, index):
    """Show an input image, its label mask and the predicted mask side by side.

    Args:
        image: ``(H, W)`` or ``(H, W, C)``; for 3-D input the first channel is
            shown in grey.
        true_mask: ``(H, W)`` or ``(H, W, C)``; for 3-D input the first channel
            is shown.
        pred_mask: ``(H, W)`` label map.  A trailing singleton channel
            ``(H, W, 1)`` is dropped, since ``imshow`` cannot display that shape.
        index: sample number used in the figure title.

    Previously 2-D image/label inputs (e.g. one slice of ``(N, H, W)`` labels)
    raised on ``[:, :, 0]``.  Any exception is printed instead of raised.
    """
    def _first_channel(arr):
        # 3-D -> first channel (original behaviour); 2-D passes through.
        arr = np.asarray(arr)
        return arr[:, :, 0] if arr.ndim == 3 else arr

    try:
        prediction = np.asarray(pred_mask)
        if prediction.ndim == 3 and prediction.shape[-1] == 1:
            prediction = prediction[:, :, 0]

        plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.title('Testing Image')
        plt.imshow(_first_channel(image), cmap='gray')

        plt.subplot(1, 3, 2)
        plt.title('Testing Label')
        plt.imshow(_first_channel(true_mask), cmap='jet')

        plt.subplot(1, 3, 3)
        plt.title('Prediction on test image')
        plt.imshow(prediction, cmap='jet')

        plt.suptitle(f'Sample {index}')
        plt.show()
    except Exception as e:
        print(f"Error in plotting individual sample prediction: {str(e)}")

def plot_learning_curves(history):
    """Plot per-epoch training/validation loss and, when recorded, accuracy.

    Args:
        history: an object whose ``.history`` dict maps metric names to
            per-epoch lists (e.g. a Keras ``History``), or ``None``.

    The loss panel needs ``'loss'``; ``'val_loss'`` is added when present.  The
    accuracy panel is drawn only if ``'accuracy'`` was recorded, with
    ``'val_accuracy'`` when present.  Previously a model compiled without an
    accuracy metric, or trained without validation data, hit a ``KeyError``
    after the loss panel was drawn, leaving an unfinished figure.  Other errors
    are printed instead of raised.
    """
    if history is None:
        print("Warning: Training history not available. Learning curves cannot be plotted.")
        return

    try:
        logs = history.history
        epochs = range(1, len(logs['loss']) + 1)

        # (key, train label, validation label, title, y-axis label)
        panels = [('loss', 'Training loss', 'Validation loss',
                   'Training and validation loss', 'Loss')]
        if 'accuracy' in logs:
            panels.append(('accuracy', 'Training acc', 'Validation acc',
                           'Training and validation accuracy', 'Accuracy'))

        plt.figure(figsize=(12, 4))
        for position, (key, train_label, val_label, title, ylabel) in enumerate(panels, start=1):
            plt.subplot(1, len(panels), position)
            plt.plot(epochs, logs[key], 'y', label=train_label)
            val_key = f'val_{key}'
            if val_key in logs:
                plt.plot(epochs, logs[val_key], 'r', label=val_label)
            plt.title(title)
            plt.xlabel('Epochs')
            plt.ylabel(ylabel)
            plt.legend()

        plt.tight_layout()
        plt.show()
        print("Learning Curves plotted")
    except Exception as e:
        print(f"Error in plotting learning curves: {str(e)}")

def plot_dice_coefficient_distribution(y_test, y_pred, n_classes):
    """Violin plot of per-image Dice coefficients, one violin per class.

    Previously a single dataset-wide Dice value was computed per class.  The
    "distribution" was then just ``n_classes`` points in one violin.  Now each
    image gets its own score per class.

    Labels and predictions are flattened per image (first axis = image).  A
    trailing singleton channel on ``y_test`` therefore lines up with
    ``(N, H, W)`` predictions, and a batch of one image is not collapsed the way
    ``squeeze()`` did.  Images where a class is absent from both masks have an
    undefined Dice (0/0) and are left out of that class's violin.  Any
    exception is printed instead of raised.
    """
    try:
        labels = np.asarray(y_test)
        preds = np.asarray(y_pred)
        n_images = labels.shape[0]
        labels = labels.reshape(n_images, -1)
        preds = preds.reshape(n_images, -1)

        dice_scores = []
        for i in range(n_classes):
            true_cls = labels == i
            pred_cls = preds == i
            overlap = np.sum(true_cls & pred_cls, axis=1)
            mask_sizes = np.sum(true_cls, axis=1) + np.sum(pred_cls, axis=1)
            defined = mask_sizes > 0
            dice_scores.append(2.0 * overlap[defined] / mask_sizes[defined])

        plt.figure(figsize=(10, 6))
        sns.violinplot(data=dice_scores)
        plt.title('Dice Coefficient Distribution')
        plt.xlabel('Class')
        plt.ylabel('Dice Coefficient')
        plt.show()
    except Exception as e:
        print(f"Error in plotting Dice coefficient distribution: {str(e)}")

def plot_misclassification_examples(X_test, y_test, y_pred, n_classes, n_examples=5):
    """Show up to ``n_examples`` samples per true class that contain errors.

    Row ``i`` of the grid lists, in index order, samples where at least one
    element whose true label is ``i`` was mispredicted.

    * Segmentation (labels have spatial axes): each panel is the input image
      with the mispredicted class-``i`` pixels overlaid in red.  The title gives
      the sample index and the number of such pixels.
    * Classification (1-D labels): panels are titled with the true and predicted
      label.

    Fixes for segmentation input: a trailing singleton channel on ``y_test`` is
    dropped, so it no longer broadcasts against ``(N, H, W)`` predictions.
    ``(H, W, 1)`` images are shown as 2-D, since ``imshow`` rejects that shape.
    Titles no longer print whole masks.  Single-row and single-column grids
    work (``squeeze=False``).  Any exception is printed instead of raised.
    """
    try:
        labels = np.asarray(y_test)
        preds = np.asarray(y_pred)
        if labels.ndim == preds.ndim + 1 and labels.shape[-1] == 1:
            labels = labels[..., 0]
        is_segmentation = labels.ndim > 1

        n_samples = labels.shape[0]
        wrong = (labels != preds).reshape(n_samples, -1)
        flat_labels = labels.reshape(n_samples, -1)

        fig, axes = plt.subplots(n_classes, n_examples,
                                 figsize=(n_examples * 3, n_classes * 3), squeeze=False)

        for i in range(n_classes):
            # Per-sample count of class-i elements that were mispredicted.
            error_counts = np.sum(wrong & (flat_labels == i), axis=1)
            samples = np.flatnonzero(error_counts)[:n_examples]
            for j in range(n_examples):
                ax = axes[i, j]
                ax.axis('off')
                if j >= len(samples):
                    continue
                idx = samples[j]

                image = np.asarray(X_test[idx])
                if image.ndim == 3 and image.shape[-1] == 1:
                    image = image[..., 0]
                ax.imshow(image, cmap='gray' if image.ndim == 2 else None)

                if is_segmentation:
                    error_mask = ((labels[idx] != preds[idx]) & (labels[idx] == i)).astype(float)
                    ax.imshow(np.ma.masked_where(error_mask == 0, error_mask),
                              cmap='Reds', vmin=0.0, vmax=1.0, alpha=0.6)
                    ax.set_title(f'Sample {idx}: {error_counts[idx]} px')
                else:
                    ax.set_title(f'True: {labels[idx]}, Pred: {preds[idx]}')

        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"Error in plotting misclassification examples: {str(e)}")

def plot_uncertainty_visualization(X_test, y_pred_prob, threshold=0.5, index=0):
    """Overlay a per-pixel uncertainty map on one test image.

    Uncertainty is ``1 - max(p)`` over the last axis of ``y_pred_prob`` (the
    class axis).

    Args:
        X_test: input images; sample ``index`` is displayed.  A ``(H, W, 1)``
            image is shown as a 2-D grey image; ``imshow`` rejects that shape,
            which previously made this plot fail.
        y_pred_prob: predicted class probabilities, class axis last.
        threshold: currently unused; kept for signature compatibility.
        index: which sample to display (default 0, the previous behaviour).

    Any exception is printed instead of raised.
    """
    try:
        uncertainty = 1 - np.max(y_pred_prob, axis=-1)

        image = np.asarray(X_test[index])
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        image_cmap = 'gray' if image.ndim == 2 else None

        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(image, cmap=image_cmap)
        plt.title('Original Image')
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(image, cmap=image_cmap)
        plt.imshow(uncertainty[index], cmap='hot', alpha=0.5)
        plt.title('Uncertainty Map')
        plt.colorbar()  # attaches to the last image drawn: the uncertainty map
        plt.axis('off')

        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"Error in plotting uncertainty visualization: {str(e)}")


# New function for class distribution visualization
def plot_class_distribution(y_test):
    """Bar chart of how many elements carry each class label in ``y_test``.

    Labels are rounded to the nearest integer so float-typed masks can be
    counted.  ``np.bincount`` then requires them to be non-negative.
    """
    pixel_counts = np.bincount(np.round(y_test).astype(int).ravel())
    class_ids = range(len(pixel_counts))

    plt.figure(figsize=(10, 5))
    plt.bar(class_ids, pixel_counts)
    plt.title('Class Distribution in Test Set')
    plt.xlabel('Class')
    plt.ylabel('Pixel Count')
    plt.xticks(class_ids)
    plt.show()

# New function for segmentation overlay
def plot_segmentation_overlay(image, true_mask, pred_mask, alpha=0.5):
    """Show an image alone, with the true mask overlaid, and with the predicted mask.

    Args:
        image: ``(H, W)`` or ``(H, W, C)``; for 3-D input the first channel is
            shown in grey.
        true_mask, pred_mask: ``(H, W)`` label maps.  A trailing singleton
            channel ``(H, W, 1)`` is dropped.  That is the shape ``y_test[i]``
            has when labels keep a channel axis, and it previously made
            ``imshow`` raise.
        alpha: opacity of the mask overlays.
    """
    base = np.asarray(image)
    if base.ndim == 3:
        base = base[:, :, 0]

    masks = []
    for mask in (true_mask, pred_mask):
        mask = np.asarray(mask)
        if mask.ndim == 3 and mask.shape[-1] == 1:
            mask = mask[:, :, 0]
        masks.append(mask)
    true_2d, pred_2d = masks

    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.title('Original Image')
    plt.imshow(base, cmap='gray')

    plt.subplot(1, 3, 2)
    plt.title('True Segmentation')
    plt.imshow(base, cmap='gray')
    plt.imshow(true_2d, alpha=alpha, cmap='jet')

    plt.subplot(1, 3, 3)
    plt.title('Predicted Segmentation')
    plt.imshow(base, cmap='gray')
    plt.imshow(pred_2d, alpha=alpha, cmap='jet')

    plt.show()


####################
def plot_pixel_accuracy_map(y_test, y_pred, index=0):
    """Show which pixels of one image were predicted correctly (green) or not (red).

    Args:
        y_test: label masks; a trailing singleton channel axis is dropped when
            ``y_pred`` lacks it.
        y_pred: predicted label masks.
        index: which image of the batch to show (default 0, the first).

    Fixes: ``squeeze()`` also removed a batch axis of length 1, after which
    ``[0]`` picked a single row.  Without ``vmin``/``vmax``, an all-correct image
    autoscaled to a constant and ``imshow`` painted it entirely red, i.e. "wrong".
    """
    labels = np.asarray(y_test)
    preds = np.asarray(y_pred)
    if labels.ndim == preds.ndim + 1 and labels.shape[-1] == 1:
        labels = labels[..., 0]

    accuracy_map = (labels == preds).astype(float)

    plt.figure(figsize=(10, 5))
    plt.imshow(accuracy_map[index], cmap='RdYlGn', vmin=0.0, vmax=1.0)
    plt.colorbar(label='Accuracy')
    if index == 0:
        plt.title('Per-pixel Accuracy Map (First image in batch)')
    else:
        plt.title(f'Per-pixel Accuracy Map (image {index} in batch)')
    plt.show()

def _boundary_band(label_image, distance_threshold):
    """Boolean mask of pixels within ``distance_threshold`` px of a different label."""
    band = np.zeros(label_image.shape, dtype=bool)
    for value in np.unique(label_image):
        region = label_image == value
        if region.all():
            # Only one label in this image: it has no boundary.
            continue
        # For pixels inside the region: Euclidean distance to the nearest pixel
        # carrying a different label (0 outside the region).
        depth = distance_transform_edt(region)
        band |= region & (depth <= distance_threshold)
    return band


def plot_boundary_accuracy(y_test, y_pred, distance_threshold=3):
    """Compare pixel accuracy near label boundaries with accuracy elsewhere.

    A pixel is "boundary" when it lies within ``distance_threshold`` pixels
    (Euclidean) of a pixel with a different ground-truth label.  The band is
    computed per image.

    The previous version measured each pixel's distance to *its own* class
    (always 0), so every pixel counted as boundary and the non-boundary accuracy
    was NaN.  It also ran the distance transform across the batch axis, mixing
    neighbouring images.

    Args:
        y_test: label masks ``(N, H, W)`` (a trailing singleton channel is
            dropped) or a single ``(H, W)`` mask.
        y_pred: predicted masks with matching layout.
        distance_threshold: band half-width in pixels.

    Returns:
        ``(boundary_accuracy, non_boundary_accuracy)``; either is NaN when its
        region is empty.
    """
    labels = np.asarray(y_test)
    preds = np.asarray(y_pred)
    if labels.ndim == preds.ndim + 1 and labels.shape[-1] == 1:
        labels = labels[..., 0]
    if labels.ndim == 2:  # a single image
        labels, preds = labels[np.newaxis], preds[np.newaxis]

    boundaries = np.stack([_boundary_band(img, distance_threshold) for img in labels])
    correct = labels == preds

    boundary_accuracy = correct[boundaries].mean() if boundaries.any() else np.nan
    non_boundary_accuracy = correct[~boundaries].mean() if (~boundaries).any() else np.nan

    plt.figure(figsize=(10, 5))
    plt.bar(['Boundary', 'Non-Boundary'], [boundary_accuracy, non_boundary_accuracy])
    plt.title('Accuracy: Boundary vs Non-Boundary')
    plt.ylabel('Accuracy')
    plt.show()

    return boundary_accuracy, non_boundary_accuracy

def plot_iou_distribution(y_test, y_pred, n_classes):
    """Violin plot of per-image IoU, one violin per class.

    Labels and predictions are flattened per image (first axis = image).  A
    trailing singleton channel on ``y_test`` therefore lines up with
    ``(N, H, W)`` predictions, and a batch of one image is not collapsed (the
    old ``squeeze()`` + ``axis=(1, 2)`` failed there).  Images where a class
    appears in neither mask have an undefined IoU (0/0).  They are excluded
    instead of producing NaNs and divide-by-zero warnings.
    """
    labels = np.asarray(y_test)
    preds = np.asarray(y_pred)
    n_images = labels.shape[0]
    labels = labels.reshape(n_images, -1)
    preds = preds.reshape(n_images, -1)

    ious = []
    for i in range(n_classes):
        true_cls = labels == i
        pred_cls = preds == i
        intersection = np.sum(true_cls & pred_cls, axis=1)
        union = np.sum(true_cls | pred_cls, axis=1)
        defined = union > 0
        ious.append(intersection[defined] / union[defined])

    plt.figure(figsize=(10, 6))
    sns.violinplot(data=ious)
    plt.title('IoU Distribution')
    plt.xlabel('Class')
    plt.ylabel('IoU')
    plt.show()

def analyze_error_types(y_test, y_pred, n_classes, background_class=0):
    """Split mispredicted pixels into three disjoint error types, print and plot them.

    Every wrong pixel falls into exactly one category relative to
    ``background_class`` (default 0; NOTE(review): confirm that label 0 is
    background for your dataset):

    * ``under_segmentation``: foreground pixel predicted as background (missed).
    * ``over_segmentation``: background pixel predicted as foreground (spurious).
    * ``misclassification``: foreground pixel predicted as another foreground class.

    The previous per-class tally counted each wrong pixel once as under- and once
    as over-segmentation (for its true and predicted classes), and
    ``n_classes - 2`` times as misclassification.  The reported split was
    therefore always ``1/n_classes`` each (33.33% x 3 for three classes),
    whatever the model did.  It also divided by zero when there were no errors.

    Args:
        y_test: label masks; a trailing singleton channel axis is dropped when
            ``y_pred`` lacks it.
        y_pred: predicted label masks.
        n_classes: kept for signature compatibility; not needed for this split.
        background_class: label value treated as background.

    Returns:
        dict mapping error type to pixel count.
    """
    labels = np.asarray(y_test)
    preds = np.asarray(y_pred)
    if labels.ndim == preds.ndim + 1 and labels.shape[-1] == 1:
        labels = labels[..., 0]

    wrong = labels != preds
    true_bg = labels == background_class
    pred_bg = preds == background_class
    error_types = {
        'under_segmentation': int(np.sum(wrong & ~true_bg & pred_bg)),
        'over_segmentation': int(np.sum(wrong & true_bg & ~pred_bg)),
        'misclassification': int(np.sum(wrong & ~true_bg & ~pred_bg)),
    }

    total_errors = sum(error_types.values())
    if total_errors == 0:
        print("No misclassified pixels; error type distribution not plotted.")
        return error_types

    for error_type, count in error_types.items():
        percentage = (count / total_errors) * 100
        print(f"{error_type}: {percentage:.2f}%")

    plt.figure(figsize=(10, 5))
    # A list, not dict_keys, so matplotlib treats the names as categories.
    plt.bar(list(error_types.keys()), [v / total_errors for v in error_types.values()])
    plt.title('Error Type Distribution')
    plt.ylabel('Proportion of Errors')
    plt.show()

    return error_types
# Modified main evaluation function
def _print_metric_report(pixel_accuracy, mean_iou, per_class_metrics, kappa):
    """Print the global and per-class numbers returned by ``compute_metrics``."""
    print(f"Pixel Accuracy: {pixel_accuracy}")
    print(f"Mean IoU: {mean_iou}")

    print("\nClass-wise metrics:")
    for i, (iou, dice, precision, recall) in enumerate(per_class_metrics):
        print(f"  Class {i}:")
        print(f"    IoU: {iou}")
        print(f"    Precision: {precision}")
        print(f"    Recall: {recall}")
        print(f"    F1 Score (Dice Coefficient): {dice}")

    dice_values = [m[1] for m in per_class_metrics]
    print(f"\nMean Dice Coefficient (F1 Score): {np.mean(dice_values)}")
    print(f"Cohen's Kappa: {kappa}")


def _plot_random_sample_overlays(X_test1, y_test, y_pred_prob, n_samples=5):
    """Overlay ground truth and prediction for up to ``n_samples`` random samples.

    Predictions are taken from the already-computed ``y_pred_prob`` (argmax over
    the last, class, axis) instead of one extra ``model.predict`` per sample.
    """
    n_available = len(X_test1)
    # np.random.choice(..., replace=False) raises if asked for more than exist.
    n_samples = min(n_samples, n_available)
    sample_indices = np.random.choice(n_available, n_samples, replace=False)

    for sample_idx in sample_indices:
        test_img = X_test1[sample_idx]
        ground_truth = np.asarray(y_test[sample_idx])
        if ground_truth.ndim == 3 and ground_truth.shape[-1] == 1:
            ground_truth = ground_truth[..., 0]
        predicted_img = np.argmax(y_pred_prob[sample_idx], axis=-1)
        plot_segmentation_overlay(test_img, ground_truth, predicted_img)


def evaluate_model(y_test, y_pred_argmax, model, X_test1, n_classes, history=None):
    """Run the full evaluation: learning curves, metrics, plots and sample overlays.

    Args:
        y_test: ground-truth label masks, optionally with a trailing singleton
            channel axis.
        y_pred_argmax: predicted label masks (argmax over classes).
        model: object with ``predict(X_test1)`` returning per-class
            probabilities, class axis last (as ``argmax(axis=-1)`` assumes).
        X_test1: model inputs, one per label mask.
        n_classes: number of classes.
        history: optional training history for learning curves.

    The first exception aborts the remaining steps.  It is printed with its
    traceback rather than raised.
    """
    import traceback

    print("Starting model evaluation...")

    try:
        if history:
            print("Plotting learning curves...")
            plot_learning_curves(history)
        else:
            print("Warning: Training history not available. Learning curves cannot be plotted.")

        pixel_accuracy, mean_iou, per_class_metrics, kappa, confusion_matrix = compute_metrics(
            y_test, y_pred_argmax, n_classes)
        _print_metric_report(pixel_accuracy, mean_iou, per_class_metrics, kappa)

        class_indices = list(range(n_classes))
        iou_values = [m[0] for m in per_class_metrics]
        dice_values = [m[1] for m in per_class_metrics]

        print("\nPlotting confusion matrix...")
        plot_confusion_matrix(confusion_matrix)

        print("Plotting IoU bar chart...")
        plot_bar_chart(class_indices, iou_values, 'IoU')

        print("Plotting Dice Coefficient bar chart...")
        plot_bar_chart(class_indices, dice_values, 'Dice Coefficient')

        print("Plotting ROC curves...")
        plot_roc_curves(y_test, model, X_test1, n_classes)

        print("Plotting Precision-Recall curves...")
        plot_precision_recall_curves(y_test, model, X_test1, n_classes)

        print("Plotting Dice Coefficient distribution...")
        plot_dice_coefficient_distribution(y_test, y_pred_argmax, n_classes)

        print("Plotting misclassification examples...")
        plot_misclassification_examples(X_test1, y_test, y_pred_argmax, n_classes)

        print("Plotting uncertainty visualization...")
        y_pred_prob = model.predict(X_test1)
        plot_uncertainty_visualization(X_test1, y_pred_prob)

        print("Plotting class distribution...")
        plot_class_distribution(y_test)

        print("Plotting per-pixel accuracy map...")
        plot_pixel_accuracy_map(y_test, y_pred_argmax)

        print("Analyzing boundary accuracy...")
        plot_boundary_accuracy(y_test, y_pred_argmax)

        print("Plotting IoU distribution...")
        plot_iou_distribution(y_test, y_pred_argmax, n_classes)

        print("Analyzing error types...")
        analyze_error_types(y_test, y_pred_argmax, n_classes)

        # A few random samples with ground-truth / prediction overlays.
        _plot_random_sample_overlays(X_test1, y_test, y_pred_prob)

        print("Model evaluation completed.")
    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

# Example usage
# This section is meant to run in an interactive session (e.g. Spyder) where
# y_test, y_pred_argmax, best_model and X_test1 (and optionally history)
# already exist.  The __main__ guard stops a plain `import` of this module from
# running the evaluation, which previously raised NameError on import.
#
# # Ensure that segmentation models use keras (requires `import segmentation_models as sm`)
# sm.set_framework('tf.keras')

# # Load the saved model (raw string: the Windows path contains backslashes)
# model_path = r'D:\PROTOS\LIVER\models\model_checkpoint_994_0.68.h5'  # Update with the correct path to your saved model
# best_model = tf.keras.models.load_model(model_path, compile=False)

# # Predict the output of the model
# y_pred = best_model.predict(X_test1)
# y_pred_argmax = np.argmax(y_pred, axis=-1)
if __name__ == "__main__":
    # history is optional; fall back to None when it was never defined.
    evaluate_model(y_test, y_pred_argmax, best_model, X_test1, n_classes=3,
                   history=globals().get('history'))

# Test - Data
# y_test_ = np.random.randint(0, 3, (959, 256, 256))
# y_pred_argmax_ = np.random.randint(0, 3, (959, 256, 256))
# n_classes_ = 3
# best_model_ = None  # Replace with your model (evaluate_model calls best_model_.predict)
# X_test1_ = np.random.rand(959, 256, 256, 1)

# evaluate_model(y_test_, y_pred_argmax_, best_model_, X_test1_, n_classes=n_classes_, history=None)
