# -*- coding: utf-8 -*-
"""
Created on Mon Aug  5 15:23:57 2024

@author: Aus
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix as sk_confusion_matrix
from sklearn.metrics import roc_curve, auc

# Number of segmentation classes; change this to match your label set.
n_classes = 3

# Stand-in data so the script can run on its own: two independently drawn
# 100x100 maps of random integer class ids in [0, n_classes). Replace them with
# your real ground truth (y_test) and arg-maxed predictions (y_pred_argmax).
y_test_ = np.random.randint(low=0, high=n_classes, size=(100, 100))
y_pred_argmax_ = np.random.randint(low=0, high=n_classes, size=(100, 100))

# Input validation
def validate_input(y_test, y_pred_argmax):
    """Check that both label maps are ``tf.Tensor`` objects of equal shape.

    Args:
        y_test: Ground-truth labels.
        y_pred_argmax: Predicted labels.

    Raises:
        ValueError: If either argument is not a ``tf.Tensor``, or if the two
            tensors do not have the same shape.
    """
    both_are_tensors = all(
        isinstance(candidate, tf.Tensor)
        for candidate in (y_test, y_pred_argmax)
    )
    if not both_are_tensors:
        raise ValueError("Input tensors must be of type tf.Tensor")

    if y_test.shape != y_pred_argmax.shape:
        raise ValueError("Input tensors must have the same shape")

    # No rank check is performed: tensors of any dimensionality are accepted
    # as long as the two shapes agree.

# Convert both label maps to tensors so the Keras metrics and validate_input
# (which insists on tf.Tensor) can consume them.
y_test_tensor = tf.convert_to_tensor(y_test_)
y_pred_argmax_tensor = tf.convert_to_tensor(y_pred_argmax_)

# Ground truth often carries a singleton axis (e.g. a trailing channel axis)
# that the arg-maxed prediction lacks. Squeeze only when the shapes actually
# disagree, and squeeze both sides: an unconditional tf.squeeze on the ground
# truth alone also drops a size-1 batch axis, which makes inputs that already
# matched (e.g. (1, H, W) vs (1, H, W)) fail validation. All metrics below are
# element-wise, so removing singleton axes does not change their values.
if y_test_tensor.shape != y_pred_argmax_tensor.shape:
    y_test_tensor = tf.squeeze(y_test_tensor)
    y_pred_argmax_tensor = tf.squeeze(y_pred_argmax_tensor)

# Fail early with a clear message if the two maps still cannot be compared.
validate_input(y_test_tensor, y_pred_argmax_tensor)

# Pixel accuracy: fraction of positions where prediction equals ground truth.
accuracy = tf.keras.metrics.Accuracy()
accuracy.update_state(y_test_tensor, y_pred_argmax_tensor)
pixel_accuracy = accuracy.result().numpy()

# Mean IoU over n_classes classes.
miou = tf.keras.metrics.MeanIoU(num_classes=n_classes)
miou.update_state(y_test_tensor, y_pred_argmax_tensor)
mean_iou = miou.result().numpy()

# Reuse the confusion matrix the metric accumulated; it feeds Cohen's kappa
# and the heatmap further down.
confusion_matrix = miou.total_cm

# Function to calculate metrics
def calculate_metric(y_true, y_pred, class_index, metric):
    """Compute a one-vs-rest segmentation metric for a single class.

    Positions equal to ``class_index`` are treated as positive, everything
    else as negative.

    Args:
        y_true: Tensor of ground-truth class ids.
        y_pred: Tensor of predicted class ids; expected to have the same
            shape as ``y_true`` (the masks are multiplied element-wise).
        class_index: Class id to evaluate.
        metric: One of ``'iou'``, ``'dice'``, ``'precision'``, ``'recall'``.

    Returns:
        A scalar ``tf.float32`` tensor holding the metric. The value is 0
        when the metric's denominator is 0 (for example, the class appears
        in neither ``y_true`` nor ``y_pred``).

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    true_mask = tf.cast(tf.equal(y_true, class_index), tf.float32)
    pred_mask = tf.cast(tf.equal(y_pred, class_index), tf.float32)

    # Each reduction is computed once and shared by all metric formulas.
    intersection = tf.reduce_sum(true_mask * pred_mask)
    true_total = tf.reduce_sum(true_mask)
    pred_total = tf.reduce_sum(pred_mask)

    if metric == 'iou':
        numerator = intersection
        denominator = true_total + pred_total - intersection
    elif metric == 'dice':
        numerator = 2. * intersection
        denominator = true_total + pred_total
    elif metric == 'precision':
        numerator = intersection
        denominator = pred_total
    elif metric == 'recall':
        numerator = intersection
        denominator = true_total
    else:
        raise ValueError("Invalid metric. Choose from 'iou', 'dice', 'precision', 'recall'.")

    # divide_no_nan yields 0 for a zero denominator and always returns a
    # tensor, so callers can rely on `.numpy()` even in that edge case (a
    # Python-float fallback branch would not guarantee that).
    return tf.math.divide_no_nan(numerator, denominator)

# Per-class metrics: one (iou, dice, precision, recall) tuple per class id,
# stored as NumPy scalars.
per_class_metrics = [
    tuple(
        calculate_metric(y_test_tensor, y_pred_argmax_tensor, i, metric_name).numpy()
        for metric_name in ('iou', 'dice', 'precision', 'recall')
    )
    for i in range(n_classes)
]

# Cohen's kappa from the accumulated confusion matrix:
#   po = observed agreement  = trace / n
#   pe = agreement by chance = sum(column_total * row_total) / n^2
n = tf.cast(tf.reduce_sum(confusion_matrix), tf.float32)
column_totals = tf.reduce_sum(confusion_matrix, axis=0)
row_totals = tf.reduce_sum(confusion_matrix, axis=1)

sum_po = tf.cast(tf.linalg.trace(confusion_matrix), tf.float32)
sum_pe = tf.reduce_sum(tf.cast(column_totals * row_totals, tf.float32)) / n

po = sum_po / n
pe = sum_pe / n

# epsilon keeps the division finite when pe is exactly 1.
kappa = (po - pe) / (1 - pe + tf.keras.backend.epsilon())

# Report the scalar metrics first, then one section per per-class metric.
print(f"Pixel Accuracy: {pixel_accuracy}")
print(f"Mean IoU: {mean_iou}")

# Each entry of per_class_metrics is (iou, dice, precision, recall); the
# column index below selects which of those a section prints.
for heading, column in (
    ("IoU:", 0),
    ("Precision:", 2),
    ("Recall:", 3),
    ("F1 Score (Dice Coefficient):", 1),
):
    print(heading)
    for i, metrics in enumerate(per_class_metrics):
        print(f"  Class {i}: {metrics[column]}")

mean_dice = np.mean([entry[1] for entry in per_class_metrics])
print(f"Mean Dice Coefficient (F1 Score): {mean_dice}")
print(f"Cohen's Kappa: {kappa.numpy()}")

##### Visualizations
import seaborn as sns

# Confusion-matrix heatmap, annotated with the raw counts.
conf_matrix = confusion_matrix.numpy()
plt.figure(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt='g', cmap='Blues')
plt.title('Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()

# Pull the IoU (column 0) and Dice (column 1) series out of the metric table.
class_indices = list(range(n_classes))
iou_values = [row[0] for row in per_class_metrics]
dice_values = [row[1] for row in per_class_metrics]

# One bar chart per metric; only the data, colour and labels differ.
for bar_heights, bar_color, metric_label, chart_title in (
    (iou_values, 'skyblue', 'IoU', 'Class-wise IoU'),
    (dice_values, 'lightgreen', 'Dice Coefficient', 'Class-wise Dice Coefficient'),
):
    plt.figure(figsize=(10, 5))
    plt.bar(class_indices, bar_heights, color=bar_color)
    plt.xlabel('Class Index')
    plt.ylabel(metric_label)
    plt.title(chart_title)
    plt.show()

# Function to plot individual sample predictions
def plot_individual_sample_prediction(image, true_mask, pred_mask, index):
    """Show an image, its label mask and the predicted mask side by side.

    Args:
        image: Array indexed as ``image[:, :, 0]``; that first channel is
            drawn in grayscale.
        true_mask: Array indexed as ``true_mask[:, :, 0]``; that first
            channel is drawn with the 'jet' colormap.
        pred_mask: Array passed to ``imshow`` as-is, drawn with 'jet'.
        index: Sample number shown in the figure's super-title.
    """
    plt.figure(figsize=(15, 5))

    # (title, data, colormap) for the three panels, left to right.
    panels = (
        ('Testing Image', image[:, :, 0], 'gray'),
        ('Testing Label', true_mask[:, :, 0], 'jet'),
        ('Prediction on test image', pred_mask, 'jet'),
    )
    for position, (panel_title, panel_data, colormap) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(panel_title)
        plt.imshow(panel_data, cmap=colormap)

    plt.suptitle(f'Sample {index}')
    plt.show()

# Plot a random selection of test samples next to the model's prediction.
# NOTE: relies on X_test1, y_test and model being defined before this point.
import random  # required by random.sample below

# Upper bound on the number of samples to display; clamped to the size of the
# test set because random.sample raises ValueError when asked for more items
# than the population holds.
n_samples = min(50, len(X_test1))
sample_indices = random.sample(range(len(X_test1)), n_samples)

for idx, sample_idx in enumerate(sample_indices):
    test_img = X_test1[sample_idx]
    ground_truth = y_test[sample_idx]

    # The model is fed a batch, so wrap the single image in a leading axis.
    test_img_input = np.expand_dims(test_img, 0)
    prediction = model.predict(test_img_input)
    # Arg-max over axis 3 picks the class id per position; [0] then drops the
    # batch axis added above.
    predicted_img = np.argmax(prediction, axis=3)[0, :, :]

    # Samples are numbered from 1 in the figure titles.
    plot_individual_sample_prediction(test_img, ground_truth, predicted_img, idx + 1)


###
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import tensorflow as tf

# Per-class ROC curves (each class against the rest).
# Relies on y_test, X_test1 and model already being defined.
n_classes = 3  # Number of classes

# One-hot encode the ground truth and get the model's per-class scores.
y_test_onehot = tf.keras.utils.to_categorical(y_test, num_classes=n_classes)
y_pred_probs = model.predict(X_test1)

# Collapse every leading axis so each row is one position and each column one
# class, which is the layout roc_curve is given below.
y_test_onehot_flat = y_test_onehot.reshape(-1, n_classes)
y_pred_probs_flat = y_pred_probs.reshape(-1, n_classes)

# Results keyed by class index.
fpr, tpr, roc_auc = {}, {}, {}

for i in range(n_classes):
    class_truth = y_test_onehot_flat[:, i]
    class_scores = y_pred_probs_flat[:, i]
    fpr[i], tpr[i], _ = roc_curve(class_truth, class_scores)
    roc_auc[i] = auc(fpr[i], tpr[i])
    print(f"Class {i} - FPR: {fpr[i]}")
    print(f"Class {i} - TPR: {tpr[i]}")
    print(f"Class {i} - AUC: {roc_auc[i]:.2f}\n")

# Draw all class curves on one figure, with the diagonal as the chance line.
plt.figure(figsize=(10, 8))
for i in roc_auc:
    plt.plot(fpr[i], tpr[i], label=f'Class {i} (area = {roc_auc[i]:.2f})')

plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc='lower right')
plt.show()