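# -*- coding: utf-8 -*-
"""
Evaluation utilities for multi-class semantic-segmentation predictions.

Computes pixel accuracy, mean IoU, per-class IoU / Dice / precision / recall
and Cohen's kappa, then plots the confusion matrix, per-class bar charts,
one-vs-rest ROC curves and precision-recall curves.

Created on Mon Aug  5 16:40:13 2024

@author: Aus
"""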
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, precision_recall_curve

# Input validation
def validate_input(y_test, y_pred_argmax):
    """Ensure labels and predictions are tensors whose static shapes agree.

    Args:
        y_test: ground-truth labels; must already be a ``tf.Tensor``.
        y_pred_argmax: predicted labels; must already be a ``tf.Tensor``.

    Raises:
        ValueError: if either argument is not a ``tf.Tensor`` (checked
            first), or if the two shapes differ.
    """
    for candidate in (y_test, y_pred_argmax):
        if not isinstance(candidate, tf.Tensor):
            raise ValueError("Input tensors must be of type tf.Tensor")

    # Both are tensors from here on, so `.shape` is a TensorShape.
    if y_test.shape != y_pred_argmax.shape:
        raise ValueError("Input tensors must have the same shape")

# Convert to appropriate tensor format if not already
def convert_to_tensor(y_test, y_pred_argmax):
    """Convert labels and predictions to tensors and reconcile their shapes.

    ``y_test`` often carries extra singleton axes that the argmax predictions
    lack (presumably a trailing channel axis, e.g. ``(N, H, W, 1)`` vs
    ``(N, H, W)`` -- confirm against the caller).

    A bare ``tf.squeeze`` would remove *every* size-1 axis, including a batch
    axis of 1 or a spatial axis of 1. That would make otherwise matching
    inputs fail validation. Instead, when the two tensors hold the same data
    up to singleton axes, ``y_test`` is reshaped to the prediction's layout.

    Args:
        y_test: ground-truth labels (array-like or tensor).
        y_pred_argmax: predicted labels (array-like or tensor).

    Returns:
        ``(y_test_tensor, y_pred_argmax_tensor)`` with identical shapes.

    Raises:
        ValueError: propagated from ``validate_input`` when the shapes are
            genuinely incompatible.
    """
    y_test_tensor = tf.convert_to_tensor(y_test)
    y_pred_argmax_tensor = tf.convert_to_tensor(y_pred_argmax)

    if y_test_tensor.shape != y_pred_argmax_tensor.shape:
        squeezed_test = tf.squeeze(y_test_tensor)
        if squeezed_test.shape == tf.squeeze(y_pred_argmax_tensor).shape:
            # Same elements in the same order, differing only in singleton
            # axes: adopt the prediction layout so no real axis is lost.
            y_test_tensor = tf.reshape(y_test_tensor, tf.shape(y_pred_argmax_tensor))
        else:
            # Incompatible shapes: keep the squeezed form and let
            # validate_input raise its usual error.
            y_test_tensor = squeezed_test

    validate_input(y_test_tensor, y_pred_argmax_tensor)
    return y_test_tensor, y_pred_argmax_tensor

# Calculate metrics
def calculate_metric(y_true, y_pred, class_index, metric):
    """Compute one overlap metric for a single class from two label maps.

    Both inputs are compared element-wise with ``class_index`` to form binary
    masks T (truth) and P (prediction). Then:

    * ``'iou'``       -> |T and P| / |T or P|
    * ``'dice'``      -> 2 |T and P| / (|T| + |P|)
    * ``'precision'`` -> |T and P| / |P|
    * ``'recall'``    -> |T and P| / |T|

    A zero denominator yields 0.0.

    Args:
        y_true: ground-truth label tensor.
        y_pred: predicted label tensor, same shape as ``y_true``.
        class_index: integer class id to evaluate.
        metric: one of ``'iou'``, ``'dice'``, ``'precision'``, ``'recall'``.

    Returns:
        A scalar ``tf.float32`` tensor. It is always a tensor, so callers can
        safely call ``.numpy()`` on it.

    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    """
    if metric not in ('iou', 'dice', 'precision', 'recall'):
        raise ValueError("Invalid metric. Choose from 'iou', 'dice', 'precision', 'recall'.")

    y_true_class = tf.cast(tf.equal(y_true, class_index), tf.float32)
    y_pred_class = tf.cast(tf.equal(y_pred, class_index), tf.float32)
    intersection = tf.reduce_sum(y_true_class * y_pred_class)
    true_count = tf.reduce_sum(y_true_class)
    pred_count = tf.reduce_sum(y_pred_class)

    # divide_no_nan returns 0 where the denominator is 0 and always yields a
    # tensor. By contrast, tf.cond(..., lambda: 0.0, ...) runs eagerly and can
    # hand back the bare Python float 0.0, which has no `.numpy()` method.
    if metric == 'iou':
        return tf.math.divide_no_nan(intersection, true_count + pred_count - intersection)
    if metric == 'dice':
        return tf.math.divide_no_nan(2. * intersection, true_count + pred_count)
    if metric == 'precision':
        return tf.math.divide_no_nan(intersection, pred_count)
    # metric == 'recall'
    return tf.math.divide_no_nan(intersection, true_count)

def compute_metrics(y_test, y_pred_argmax, n_classes):
    """Compute global and per-class segmentation metrics.

    Args:
        y_test: ground-truth labels (anything ``convert_to_tensor`` accepts).
        y_pred_argmax: predicted labels.
        n_classes: number of classes for Mean IoU and the per-class table.

    Returns:
        Tuple ``(pixel_accuracy, mean_iou, per_class_metrics, kappa,
        confusion_matrix)``:

        * ``per_class_metrics`` holds one ``(iou, dice, precision, recall)``
          tuple per class.
        * ``kappa`` is Cohen's kappa as a NumPy scalar.
        * ``confusion_matrix`` is the ``MeanIoU`` accumulator converted to a
          NumPy array.
    """
    labels, preds = convert_to_tensor(y_test, y_pred_argmax)

    # Keras Accuracy: fraction of positions where prediction equals label.
    acc_metric = tf.keras.metrics.Accuracy()
    acc_metric.update_state(labels, preds)
    pixel_accuracy = acc_metric.result().numpy()

    # MeanIoU also accumulates the confusion matrix reused for kappa below.
    iou_metric = tf.keras.metrics.MeanIoU(num_classes=n_classes)
    iou_metric.update_state(labels, preds)
    mean_iou = iou_metric.result().numpy()
    confusion_matrix = iou_metric.total_cm.numpy()

    metric_names = ('iou', 'dice', 'precision', 'recall')
    per_class_metrics = [
        tuple(calculate_metric(labels, preds, cls, name).numpy() for name in metric_names)
        for cls in range(n_classes)
    ]

    # Cohen's kappa: observed agreement (diagonal share) corrected by the
    # agreement expected from the row/column marginals.
    total = tf.reduce_sum(confusion_matrix)
    observed = tf.linalg.trace(confusion_matrix) / total
    marginal_products = tf.reduce_sum(
        tf.reduce_sum(confusion_matrix, axis=0) * tf.reduce_sum(confusion_matrix, axis=1)
    )
    expected = (marginal_products / total) / total
    kappa = (observed - expected) / (1 - expected + tf.keras.backend.epsilon())

    return pixel_accuracy, mean_iou, per_class_metrics, kappa.numpy(), confusion_matrix

# Visualization functions
def plot_confusion_matrix(confusion_matrix):
    """Display ``confusion_matrix`` as an annotated blue heat map.

    Matrix rows are drawn along the y-axis (labelled "True Labels") and
    columns along the x-axis (labelled "Predicted Labels").
    """
    _, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(confusion_matrix, annot=True, fmt='g', cmap='Blues', ax=ax)
    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Predicted Labels')
    ax.set_ylabel('True Labels')
    plt.show()

def plot_bar_chart(class_indices, values, metric_name):
    """Draw one bar per class showing ``values`` for ``metric_name``.

    Bars are sky blue when ``metric_name`` is exactly ``'IoU'`` and light
    green for any other metric.
    """
    bar_color = 'lightgreen'
    if metric_name == 'IoU':
        bar_color = 'skyblue'

    _, ax = plt.subplots(figsize=(10, 5))
    ax.bar(class_indices, values, color=bar_color)
    ax.set_xlabel('Class Index')
    ax.set_ylabel(metric_name)
    ax.set_title(f'Class-wise {metric_name}')
    plt.show()

def plot_roc_curves(y_test, model, X_test1, n_classes, y_pred_probs=None):
    """Plot one-vs-rest ROC curves, one per class, and print each class's AUC.

    Every element of ``y_test`` is one sample. The prediction scores are
    reshaped to ``(-1, n_classes)`` so that row ``k`` holds the class scores
    for the ``k``-th flattened label.

    Args:
        y_test: integer ground-truth labels. A trailing singleton channel
            axis is harmless because labels are flattened.
        model: object with a ``predict`` method; only used when
            ``y_pred_probs`` is not supplied.
        X_test1: inputs passed to ``model.predict``.
        n_classes: number of classes, i.e. the size of the score axis.
        y_pred_probs: optional precomputed ``model.predict(X_test1)`` output.
            Passing it avoids a second full inference pass.

    Returns:
        dict mapping class index to ROC AUC. The value is ``nan`` for classes
        whose ROC is undefined (no positive or no negative labels in
        ``y_test``).

    Raises:
        ValueError: if the number of labels differs from the number of
            prediction rows.
    """
    if y_pred_probs is None:
        y_pred_probs = model.predict(X_test1)

    y_true_flat = np.asarray(y_test).reshape(-1)
    y_probs_flat = np.asarray(y_pred_probs).reshape(-1, n_classes)
    if y_true_flat.shape[0] != y_probs_flat.shape[0]:
        raise ValueError(
            f"y_test has {y_true_flat.shape[0]} labels but predictions have "
            f"{y_probs_flat.shape[0]} rows of {n_classes} class scores"
        )

    roc_auc = {}
    plt.figure(figsize=(10, 8))
    for i in range(n_classes):
        # One-vs-rest binary target; equals the i-th one-hot column without
        # materialising the whole one-hot array.
        is_positive = (y_true_flat == i).astype(np.uint8)
        n_pos = int(is_positive.sum())
        if n_pos == 0 or n_pos == is_positive.size:
            # ROC needs both positives and negatives; otherwise TPR or FPR is
            # 0/0 and the AUC would be nan.
            roc_auc[i] = float('nan')
            print(f"Class {i} - AUC: undefined (class absent from, or filling, y_test)")
            continue
        fpr, tpr, _ = roc_curve(is_positive, y_probs_flat[:, i])
        roc_auc[i] = auc(fpr, tpr)
        print(f"Class {i} - AUC: {roc_auc[i]:.2f}")
        plt.plot(fpr, tpr, label=f'Class {i} (area = {roc_auc[i]:.2f})')

    # Chance-level diagonal.
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc='lower right')
    plt.show()
    return roc_auc

def plot_precision_recall_curves(y_test, model, X_test1, n_classes, y_pred_probs=None):
    """Plot one-vs-rest precision-recall curves per class and print each area.

    Labels are flattened and scores reshaped to ``(-1, n_classes)``, so row
    ``k`` of the scores pairs with the ``k``-th flattened label.

    For each class, a dotted horizontal line in the curve's colour marks the
    no-skill baseline. A random classifier's precision equals the class
    prevalence at every recall, so the line sits at that prevalence (it is
    not a diagonal).

    Args:
        y_test: integer ground-truth labels (flattened internally).
        model: object with a ``predict`` method; only used when
            ``y_pred_probs`` is not supplied.
        X_test1: inputs passed to ``model.predict``.
        n_classes: number of classes, i.e. the size of the score axis.
        y_pred_probs: optional precomputed ``model.predict(X_test1)`` output.
            Passing it avoids a second full inference pass.

    Returns:
        dict mapping class index to trapezoidal PR AUC; ``nan`` for classes
        absent from ``y_test`` (recall is undefined there).

    Raises:
        ValueError: if the number of labels differs from the number of
            prediction rows.
    """
    if y_pred_probs is None:
        y_pred_probs = model.predict(X_test1)

    y_true_flat = np.asarray(y_test).reshape(-1)
    y_probs_flat = np.asarray(y_pred_probs).reshape(-1, n_classes)
    if y_true_flat.shape[0] != y_probs_flat.shape[0]:
        raise ValueError(
            f"y_test has {y_true_flat.shape[0]} labels but predictions have "
            f"{y_probs_flat.shape[0]} rows of {n_classes} class scores"
        )

    pr_auc = {}
    plt.figure(figsize=(10, 8))
    for i in range(n_classes):
        is_positive = (y_true_flat == i).astype(np.uint8)
        n_pos = int(is_positive.sum())
        if n_pos == 0:
            # No positives: recall is 0/0, so the curve and its area are meaningless.
            pr_auc[i] = float('nan')
            print(f"Class {i} - Precision-Recall AUC: undefined (class absent from y_test)")
            continue
        precision, recall, _ = precision_recall_curve(is_positive, y_probs_flat[:, i])
        # NOTE(review): trapezoidal auc() on a PR curve can be optimistic;
        # sklearn's average_precision_score is the usual alternative.
        pr_auc[i] = auc(recall, precision)
        print(f"Class {i} - Precision-Recall AUC: {pr_auc[i]:.2f}")
        (line,) = plt.plot(recall, precision, label=f'Class {i} (area = {pr_auc[i]:.2f})')
        plt.axhline(n_pos / is_positive.size, color=line.get_color(), linestyle=':', linewidth=1)

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='lower left')
    plt.show()
    return pr_auc

def plot_individual_sample_prediction(image, true_mask, pred_mask, index):
    """Show an input image, its ground-truth mask and the prediction side by side.

    Args:
        image: input image. With 3 or more dims, the first channel
            (``[:, :, 0]``) is shown in grey; a 2-D image is shown as-is.
        true_mask: ground-truth mask, with the same channel handling as
            ``image``. Shown with the 'jet' colour map.
        pred_mask: predicted label map, shown as-is with the 'jet' colour map
            (presumably 2-D ``(H, W)`` -- confirm against caller).
        index: value displayed in the figure's super-title ``Sample {index}``.
    """
    def _first_channel(arr):
        # Accept plain (H, W) arrays as well as channel-last (H, W, C) ones.
        return arr[:, :, 0] if np.ndim(arr) >= 3 else arr

    panels = (
        ('Testing Image', _first_channel(image), 'gray'),
        ('Testing Label', _first_channel(true_mask), 'jet'),
        ('Prediction on test image', pred_mask, 'jet'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (title, data, cmap) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(data, cmap=cmap)

    fig.suptitle(f'Sample {index}')
    plt.show()

# Main function to run the evaluation
def evaluate_model(y_test, y_pred_argmax, model, X_test1, n_classes=3):
    """Print a metrics report for a segmentation model and draw its plots.

    Scalar and per-class metrics come from ``compute_metrics`` on ``y_test``
    vs ``y_pred_argmax``. The ROC and precision-recall plots obtain class
    scores via ``model.predict(X_test1)``.

    Args:
        y_test: ground-truth labels.
        y_pred_argmax: predicted labels (presumably the argmax over the
            model's class axis -- confirm against caller).
        model: model whose ``predict`` output feeds the ROC / PR curves.
        X_test1: inputs for ``model.predict``.
        n_classes: number of classes (default 3).
    """
    results = compute_metrics(y_test, y_pred_argmax, n_classes)
    pixel_accuracy, mean_iou, per_class_metrics, kappa, confusion_matrix = results

    print(f"Pixel Accuracy: {pixel_accuracy}")
    print(f"Mean IoU: {mean_iou}")

    # Each per-class entry is (iou, dice, precision, recall); split into columns.
    iou_values, dice_values, precision_values, recall_values = (
        [entry[column] for entry in per_class_metrics] for column in range(4)
    )

    report_sections = (
        ("IoU:", iou_values),
        ("Precision:", precision_values),
        ("Recall:", recall_values),
        ("F1 Score (Dice Coefficient):", dice_values),
    )
    for heading, values in report_sections:
        print(heading)
        for class_id, value in enumerate(values):
            print(f"  Class {class_id}: {value}")

    print(f"Mean Dice Coefficient (F1 Score): {np.mean(dice_values)}")
    print(f"Cohen's Kappa: {kappa}")

    class_indices = list(range(n_classes))
    plot_confusion_matrix(confusion_matrix)
    plot_bar_chart(class_indices, iou_values, 'IoU')
    plot_bar_chart(class_indices, dice_values, 'Dice Coefficient')
    plot_roc_curves(y_test, model, X_test1, n_classes)
    plot_precision_recall_curves(y_test, model, X_test1, n_classes)

# Example usage. Guarded so that importing this module does not run the
# evaluation (the names below are not defined in this file).
# NOTE(review): y_test, y_pred_argmax, model and X_test1 must already exist in
# the executing namespace, e.g. an interactive session after training where
# y_pred_argmax = np.argmax(model.predict(X_test1), axis=3) -- TODO confirm.
if __name__ == "__main__":
    evaluate_model(y_test, y_pred_argmax, model, X_test1, n_classes=3)

''' w/o precision-recall curve
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, confusion_matrix as sk_confusion_matrix

# Input validation
def validate_input(y_test, y_pred_argmax):
    if not isinstance(y_test, tf.Tensor) or not isinstance(y_pred_argmax, tf.Tensor):
        raise ValueError("Input tensors must be of type tf.Tensor")
    if y_test.shape != y_pred_argmax.shape:
        raise ValueError("Input tensors must have the same shape")

# Convert to appropriate tensor format if not already
def convert_to_tensor(y_test, y_pred_argmax):
    y_test_tensor = tf.convert_to_tensor(y_test)
    y_pred_argmax_tensor = tf.convert_to_tensor(y_pred_argmax)
    y_test_tensor = tf.squeeze(y_test_tensor)
    validate_input(y_test_tensor, y_pred_argmax_tensor)
    return y_test_tensor, y_pred_argmax_tensor

# Calculate metrics
def calculate_metric(y_true, y_pred, class_index, metric):
    y_true_class = tf.cast(tf.equal(y_true, class_index), tf.float32)
    y_pred_class = tf.cast(tf.equal(y_pred, class_index), tf.float32)
    intersection = tf.reduce_sum(y_true_class * y_pred_class)
    union = tf.reduce_sum(y_true_class) + tf.reduce_sum(y_pred_class) - intersection
    
    if metric == 'iou':
        return tf.cond(tf.equal(union, 0), lambda: 0.0, lambda: intersection / union)
    elif metric == 'dice':
        return tf.cond(tf.equal(tf.reduce_sum(y_true_class) + tf.reduce_sum(y_pred_class), 0), lambda: 0.0, lambda: (2. * intersection) / (tf.reduce_sum(y_true_class) + tf.reduce_sum(y_pred_class)))
    elif metric == 'precision':
        return tf.cond(tf.equal(tf.reduce_sum(y_pred_class), 0), lambda: 0.0, lambda: intersection / tf.reduce_sum(y_pred_class))
    elif metric == 'recall':
        return tf.cond(tf.equal(tf.reduce_sum(y_true_class), 0), lambda: 0.0, lambda: intersection / tf.reduce_sum(y_true_class))
    else:
        raise ValueError("Invalid metric. Choose from 'iou', 'dice', 'precision', 'recall'.")

def compute_metrics(y_test, y_pred_argmax, n_classes):
    y_test_tensor, y_pred_argmax_tensor = convert_to_tensor(y_test, y_pred_argmax)
    
    # Pixel Accuracy
    accuracy = tf.keras.metrics.Accuracy()
    accuracy.update_state(y_test_tensor, y_pred_argmax_tensor)
    pixel_accuracy = accuracy.result().numpy()
    
    # Mean IoU and Confusion Matrix
    miou = tf.keras.metrics.MeanIoU(num_classes=n_classes)
    miou.update_state(y_test_tensor, y_pred_argmax_tensor)
    mean_iou = miou.result().numpy()
    confusion_matrix = miou.total_cm.numpy()
    
    # Per-class metrics
    per_class_metrics = []
    for i in range(n_classes):
        iou = calculate_metric(y_test_tensor, y_pred_argmax_tensor, i, 'iou')
        dice = calculate_metric(y_test_tensor, y_pred_argmax_tensor, i, 'dice')
        precision = calculate_metric(y_test_tensor, y_pred_argmax_tensor, i, 'precision')
        recall = calculate_metric(y_test_tensor, y_pred_argmax_tensor, i, 'recall')
        per_class_metrics.append((iou.numpy(), dice.numpy(), precision.numpy(), recall.numpy()))

    # Cohen's Kappa
    n = tf.reduce_sum(confusion_matrix)
    sum_po = tf.linalg.trace(confusion_matrix)
    sum_pe = tf.reduce_sum(tf.reduce_sum(confusion_matrix, axis=0) * tf.reduce_sum(confusion_matrix, axis=1)) / n
    po = sum_po / n
    pe = sum_pe / n
    kappa = (po - pe) / (1 - pe + tf.keras.backend.epsilon())

    return pixel_accuracy, mean_iou, per_class_metrics, kappa.numpy(), confusion_matrix

# Visualization functions
def plot_confusion_matrix(confusion_matrix):
    plt.figure(figsize=(10, 8))
    sns.heatmap(confusion_matrix, annot=True, fmt='g', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.show()

def plot_bar_chart(class_indices, values, metric_name):
    plt.figure(figsize=(10, 5))
    plt.bar(class_indices, values, color='skyblue' if metric_name == 'IoU' else 'lightgreen')
    plt.xlabel('Class Index')
    plt.ylabel(metric_name)
    plt.title(f'Class-wise {metric_name}')
    plt.show()

def plot_roc_curves(y_test, model, X_test1, n_classes):
    y_test_onehot = tf.keras.utils.to_categorical(y_test, num_classes=n_classes)
    y_pred_probs = model.predict(X_test1)
    
    y_test_onehot_flat = y_test_onehot.reshape(-1, n_classes)
    y_pred_probs_flat = y_pred_probs.reshape(-1, n_classes)
    
    fpr = {}
    tpr = {}
    roc_auc = {}

    plt.figure(figsize=(10, 8))
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test_onehot_flat[:, i], y_pred_probs_flat[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
        print(f"Class {i} - AUC: {roc_auc[i]:.2f}")
        plt.plot(fpr[i], tpr[i], label=f'Class {i} (area = {roc_auc[i]:.2f})')
    
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc='lower right')
    plt.show()

def plot_individual_sample_prediction(image, true_mask, pred_mask, index):
    plt.figure(figsize=(15, 5))
    
    plt.subplot(1, 3, 1)
    plt.title('Testing Image')
    plt.imshow(image[:,:,0], cmap='gray')
    
    plt.subplot(1, 3, 2)
    plt.title('Testing Label')
    plt.imshow(true_mask[:,:,0], cmap='jet')
    
    plt.subplot(1, 3, 3)
    plt.title('Prediction on test image')
    plt.imshow(pred_mask, cmap='jet')
    
    plt.suptitle(f'Sample {index}')
    plt.show()

# Main function to run the evaluation
def evaluate_model(y_test, y_pred_argmax, model, X_test1, n_classes=3):
    pixel_accuracy, mean_iou, per_class_metrics, kappa, confusion_matrix = compute_metrics(y_test, y_pred_argmax, n_classes)
    
    print(f"Pixel Accuracy: {pixel_accuracy}")
    print(f"Mean IoU: {mean_iou}")

    class_indices = list(range(n_classes))
    iou_values = [m[0] for m in per_class_metrics]
    dice_values = [m[1] for m in per_class_metrics]
    precision_values = [m[2] for m in per_class_metrics]
    recall_values = [m[3] for m in per_class_metrics]

    print("IoU:")
    for i, iou in enumerate(iou_values):
        print(f"  Class {i}: {iou}")

    print("Precision:")
    for i, precision in enumerate(precision_values):
        print(f"  Class {i}: {precision}")

    print("Recall:")
    for i, recall in enumerate(recall_values):
        print(f"  Class {i}: {recall}")

    print("F1 Score (Dice Coefficient):")
    for i, dice in enumerate(dice_values):
        print(f"  Class {i}: {dice}")

    print(f"Mean Dice Coefficient (F1 Score): {np.mean(dice_values)}")
    print(f"Cohen's Kappa: {kappa}")

    plot_confusion_matrix(confusion_matrix)
    plot_bar_chart(class_indices, iou_values, 'IoU')
    plot_bar_chart(class_indices, dice_values, 'Dice Coefficient')
    plot_roc_curves(y_test, model, X_test1, n_classes)

    # Generating a few sample predictions
    n_samples = 50  # Number of samples to display
    sample_indices = np.random.choice(range(len(X_test1)), n_samples, replace=False)

    for idx, sample_idx in enumerate(sample_indices):
        test_img = X_test1[sample_idx]
        ground_truth = y_test[sample_idx]
        
        test_img_input = np.expand_dims(test_img, 0)
        prediction = model.predict(test_img_input)
        predicted_img = np.argmax(prediction, axis=3)[0,:,:]
        
        # Plotting the individual sample prediction
        plot_individual_sample_prediction(test_img, ground_truth, predicted_img, idx + 1)



# # after training is complete, make sure to run this:
# # Predict the output of the model
# y_pred = model.predict(X_test1)
# y_pred_argmax = np.argmax(y_pred, axis=3)

# Example usage
evaluate_model(y_test, y_pred_argmax, model, X_test1, n_classes=3)
'''


############################################################################################# OLD ########################################
# import tensorflow as tf
# import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns
# from sklearn.metrics import roc_curve, auc, confusion_matrix as sk_confusion_matrix

# # Input validation
# def validate_input(y_test, y_pred_argmax):
#     if not isinstance(y_test, tf.Tensor) or not isinstance(y_pred_argmax, tf.Tensor):
#         raise ValueError("Input tensors must be of type tf.Tensor")
#     if y_test.shape != y_pred_argmax.shape:
#         raise ValueError("Input tensors must have the same shape")

# # Convert to appropriate tensor format if not already
# def convert_to_tensor(y_test, y_pred_argmax):
#     y_test_tensor = tf.convert_to_tensor(y_test)
#     y_pred_argmax_tensor = tf.convert_to_tensor(y_pred_argmax)
#     y_test_tensor = tf.squeeze(y_test_tensor)
#     validate_input(y_test_tensor, y_pred_argmax_tensor)
#     return y_test_tensor, y_pred_argmax_tensor

# # Calculate metrics
# def calculate_metric(y_true, y_pred, class_index, metric):
#     y_true_class = tf.cast(tf.equal(y_true, class_index), tf.float32)
#     y_pred_class = tf.cast(tf.equal(y_pred, class_index), tf.float32)
#     intersection = tf.reduce_sum(y_true_class * y_pred_class)
#     union = tf.reduce_sum(y_true_class) + tf.reduce_sum(y_pred_class) - intersection
    
#     if metric == 'iou':
#         return tf.cond(tf.equal(union, 0), lambda: 0.0, lambda: intersection / union)
#     elif metric == 'dice':
#         return tf.cond(tf.equal(tf.reduce_sum(y_true_class) + tf.reduce_sum(y_pred_class), 0), lambda: 0.0, lambda: (2. * intersection) / (tf.reduce_sum(y_true_class) + tf.reduce_sum(y_pred_class)))
#     elif metric == 'precision':
#         return tf.cond(tf.equal(tf.reduce_sum(y_pred_class), 0), lambda: 0.0, lambda: intersection / tf.reduce_sum(y_pred_class))
#     elif metric == 'recall':
#         return tf.cond(tf.equal(tf.reduce_sum(y_true_class), 0), lambda: 0.0, lambda: intersection / tf.reduce_sum(y_true_class))
#     else:
#         raise ValueError("Invalid metric. Choose from 'iou', 'dice', 'precision', 'recall'.")

# def compute_metrics(y_test, y_pred_argmax, n_classes):
#     y_test_tensor, y_pred_argmax_tensor = convert_to_tensor(y_test, y_pred_argmax)
    
#     # Pixel Accuracy
#     accuracy = tf.keras.metrics.Accuracy()
#     accuracy.update_state(y_test_tensor, y_pred_argmax_tensor)
#     pixel_accuracy = accuracy.result().numpy()
    
#     # Mean IoU and Confusion Matrix
#     miou = tf.keras.metrics.MeanIoU(num_classes=n_classes)
#     miou.update_state(y_test_tensor, y_pred_argmax_tensor)
#     mean_iou = miou.result().numpy()
#     confusion_matrix = miou.total_cm.numpy()
    
#     # Per-class metrics
#     per_class_metrics = []
#     for i in range(n_classes):
#         iou = calculate_metric(y_test_tensor, y_pred_argmax_tensor, i, 'iou')
#         dice = calculate_metric(y_test_tensor, y_pred_argmax_tensor, i, 'dice')
#         precision = calculate_metric(y_test_tensor, y_pred_argmax_tensor, i, 'precision')
#         recall = calculate_metric(y_test_tensor, y_pred_argmax_tensor, i, 'recall')
#         per_class_metrics.append((iou.numpy(), dice.numpy(), precision.numpy(), recall.numpy()))

#     # Cohen's Kappa
#     n = tf.reduce_sum(confusion_matrix)
#     sum_po = tf.linalg.trace(confusion_matrix)
#     sum_pe = tf.reduce_sum(tf.reduce_sum(confusion_matrix, axis=0) * tf.reduce_sum(confusion_matrix, axis=1)) / n
#     po = sum_po / n
#     pe = sum_pe / n
#     kappa = (po - pe) / (1 - pe + tf.keras.backend.epsilon())

#     return pixel_accuracy, mean_iou, per_class_metrics, kappa.numpy(), confusion_matrix

# # Visualization functions
# def plot_confusion_matrix(confusion_matrix):
#     plt.figure(figsize=(10, 8))
#     sns.heatmap(confusion_matrix, annot=True, fmt='g', cmap='Blues')
#     plt.title('Confusion Matrix')
#     plt.xlabel('Predicted Labels')
#     plt.ylabel('True Labels')
#     plt.show()

# def plot_bar_chart(class_indices, values, metric_name):
#     plt.figure(figsize=(10, 5))
#     plt.bar(class_indices, values, color='skyblue' if metric_name == 'IoU' else 'lightgreen')
#     plt.xlabel('Class Index')
#     plt.ylabel(metric_name)
#     plt.title(f'Class-wise {metric_name}')
#     plt.show()

# def plot_roc_curves(y_test, model, X_test1, n_classes):
#     y_test_onehot = tf.keras.utils.to_categorical(y_test, num_classes=n_classes)
#     y_pred_probs = model.predict(X_test1)
    
#     y_test_onehot_flat = y_test_onehot.reshape(-1, n_classes)
#     y_pred_probs_flat = y_pred_probs.reshape(-1, n_classes)
    
#     fpr = {}
#     tpr = {}
#     roc_auc = {}

#     plt.figure(figsize=(10, 8))
#     for i in range(n_classes):
#         fpr[i], tpr[i], _ = roc_curve(y_test_onehot_flat[:, i], y_pred_probs_flat[:, i])
#         roc_auc[i] = auc(fpr[i], tpr[i])
#         plt.plot(fpr[i], tpr[i], label=f'Class {i} (area = {roc_auc[i]:.2f})')
    
#     plt.plot([0, 1], [0, 1], 'k--')
#     plt.xlim([0.0, 1.0])
#     plt.ylim([0.0, 1.05])
#     plt.xlabel('False Positive Rate')
#     plt.ylabel('True Positive Rate')
#     plt.title('Receiver Operating Characteristic')
#     plt.legend(loc='lower right')
#     plt.show()

# def plot_individual_sample_prediction(image, true_mask, pred_mask, index):
#     plt.figure(figsize=(15, 5))
    
#     plt.subplot(1, 3, 1)
#     plt.title('Testing Image')
#     plt.imshow(image[:,:,0], cmap='gray')
    
#     plt.subplot(1, 3, 2)
#     plt.title('Testing Label')
#     plt.imshow(true_mask[:,:,0], cmap='jet')
    
#     plt.subplot(1, 3, 3)
#     plt.title('Prediction on test image')
#     plt.imshow(pred_mask, cmap='jet')
    
#     plt.suptitle(f'Sample {index}')
#     plt.show()

# # Main function to run the evaluation
# def evaluate_model(y_test, y_pred_argmax, model, X_test1, n_classes=3):
#     pixel_accuracy, mean_iou, per_class_metrics, kappa, confusion_matrix = compute_metrics(y_test, y_pred_argmax, n_classes)
    
#     print(f"Pixel Accuracy: {pixel_accuracy}")
#     print(f"Mean IoU: {mean_iou}")

#     class_indices = list(range(n_classes))
#     iou_values = [m[0] for m in per_class_metrics]
#     dice_values = [m[1] for m in per_class_metrics]
#     precision_values = [m[2] for m in per_class_metrics]
#     recall_values = [m[3] for m in per_class_metrics]

#     print("IoU:")
#     for i, iou in enumerate(iou_values):
#         print(f"  Class {i}: {iou}")

#     print("Precision:")
#     for i, precision in enumerate(precision_values):
#         print(f"  Class {i}: {precision}")

#     print("Recall:")
#     for i, recall in enumerate(recall_values):
#         print(f"  Class {i}: {recall}")

#     print("F1 Score (Dice Coefficient):")
#     for i, dice in enumerate(dice_values):
#         print(f"  Class {i}: {dice}")

#     print(f"Mean Dice Coefficient (F1 Score): {np.mean(dice_values)}")
#     print(f"Cohen's Kappa: {kappa}")

#     plot_confusion_matrix(confusion_matrix)
#     plot_bar_chart(class_indices, iou_values, 'IoU')
#     plot_bar_chart(class_indices, dice_values, 'Dice Coefficient')
#     plot_roc_curves(y_test, model, X_test1, n_classes)

#     # Generating a few sample predictions
#     n_samples = 5  # Number of samples to display
#     sample_indices = np.random.choice(range(len(X_test1)), n_samples, replace=False)

#     for idx, sample_idx in enumerate(sample_indices):
#         test_img = X_test1[sample_idx]
#         ground_truth = y_test[sample_idx]
        
#         test_img_input = np.expand_dims(test_img, 0)
#         prediction = model.predict(test_img_input)
#         predicted_img = np.argmax(prediction, axis=3)[0,:,:]
        
#         # Plotting the individual sample prediction
#         plot_individual_sample_prediction(test_img, ground_truth, predicted_img, idx + 1)

# # Example usage
# evaluate_model(y_test, y_pred_argmax, model, X_test1, n_classes=3)
