# -*- coding: utf-8 -*-
"""
Created on Mon Aug  5 09:42:47 2024

@author: Aus
"""

import tensorflow as tf
from tensorflow.keras.metrics import MeanIoU, Precision, Recall, Accuracy
import numpy as np
from sklearn.metrics import cohen_kappa_score

# Parameters of the synthetic evaluation set: number of label maps, their
# spatial size, and how many distinct class ids may occur in them.
num_test_images = 10
image_height = 256
image_width = 256
num_classes = 3

# Create random ground truth labels (ground_truth_labels)
ground_truth_labels = np.random.randint(0, num_classes, size=(num_test_images, image_height, image_width, 1))

# Create random predicted labels (predicted_class_indices)
predicted_class_indices = np.random.randint(0, num_classes, size=(num_test_images, image_height, image_width, 1))

# Convert to appropriate tensor format and cast to int32
ground_truth_labels_tensor = tf.cast(tf.convert_to_tensor(ground_truth_labels), tf.int32)
predicted_class_indices_tensor = tf.cast(tf.convert_to_tensor(predicted_class_indices), tf.int32)

# Ensure predicted_class_indices has the same rank as ground_truth_labels:
# a prediction with fewer axes gets a trailing axis appended.
if len(predicted_class_indices_tensor.shape) < len(ground_truth_labels_tensor.shape):
    predicted_class_indices_tensor = tf.expand_dims(predicted_class_indices_tensor, axis=-1)

# Pixel Accuracy
pixel_accuracy_metric = Accuracy()
pixel_accuracy_metric.update_state(ground_truth_labels_tensor, predicted_class_indices_tensor)
pixel_accuracy = pixel_accuracy_metric.result().numpy()
print(f"\nPixel Accuracy: {pixel_accuracy}")

# IoU and Dice Coefficient (F1 Score)
iou_metric = MeanIoU(num_classes=num_classes)
iou_metric.update_state(ground_truth_labels_tensor, predicted_class_indices_tensor)
mean_iou = iou_metric.result().numpy()
print(f"\nMean IoU: {mean_iou}\n")

# Class-wise IoU
# The confusion matrix is built explicitly (rows = ground truth class,
# columns = predicted class) instead of being read back from the metric's
# internal weights, so this section does not depend on how the metric object
# stores its state.
iou_values = tf.math.confusion_matrix(
    tf.reshape(ground_truth_labels_tensor, [-1]),
    tf.reshape(predicted_class_indices_tensor, [-1]),
    num_classes=num_classes,
    dtype=tf.int64,
).numpy().astype(np.float64)

true_positive_counts = np.diag(iou_values)
# Per class: (# pixels labelled as the class) + (# pixels predicted as the class).
label_plus_prediction_counts = iou_values.sum(axis=1) + iou_values.sum(axis=0)

class_dice_values = []
for class_index in range(num_classes):
    true_positives = true_positive_counts[class_index]
    label_plus_prediction = label_plus_prediction_counts[class_index]
    if label_plus_prediction > 0:
        # union = |labels| + |predictions| - |intersection|
        class_iou = true_positives / (label_plus_prediction - true_positives)
        class_dice_values.append(2 * true_positives / label_plus_prediction)
    else:
        # The class occurs in neither labels nor predictions: its IoU is
        # undefined, so report NaN and leave it out of the Dice average.
        class_iou = float("nan")
    print(f"IoU for class {class_index}: {class_iou}")

# Dice Coefficient (F1 Score)
# Dice = 2 * IoU / (1 + IoU) only holds for a single class; applying it to the
# mean IoU does not give the mean Dice. Average the per-class Dice scores instead.
dice_coefficient = float(np.mean(class_dice_values)) if class_dice_values else float("nan")
print(f"\nDice Coefficient (F1 Score): {dice_coefficient}\n")

# Class-wise Dice Coefficient
def calculate_dice_coefficient(ground_truth_labels, predicted_class_indices, class_index):
    """Return the Dice coefficient of a single class.

    Both label inputs are turned into 0/1 float32 masks that mark where they
    equal ``class_index``; the score is ``2 * overlap / (sum of both masks)``.

    Args:
        ground_truth_labels: Reference class ids, compared element-wise with
            ``class_index``.
        predicted_class_indices: Predicted class ids; the masks are multiplied
            element-wise, so presumably the same shape as
            ``ground_truth_labels`` (confirm against callers).
        class_index: Class id to score.

    Returns:
        A scalar float32 tensor. A constant 1e-7 is added to the denominator,
        so a class present in neither input yields 0 instead of dividing by
        zero.
    """
    truth_mask = tf.cast(tf.equal(ground_truth_labels, class_index), tf.float32)
    prediction_mask = tf.cast(tf.equal(predicted_class_indices, class_index), tf.float32)

    # Pixels where both masks are 1 are the ones the two inputs agree on.
    overlap = tf.reduce_sum(truth_mask * prediction_mask)
    mask_total = tf.reduce_sum(truth_mask) + tf.reduce_sum(prediction_mask) + 1e-7
    return (2 * overlap) / mask_total

# Class-wise Dice: one score per class id, printed in class order.
for class_index in range(num_classes):
    dice_coefficient = calculate_dice_coefficient(
        ground_truth_labels_tensor,
        predicted_class_indices_tensor,
        class_index,
    )
    dice_value = dice_coefficient.numpy()
    print(f"Dice Coefficient for class {class_index}: {dice_value}")


# Class-wise Precision and Recall
# Each class is scored one-vs-rest: the boolean masks mark where the labels and
# the predictions equal that class, and each class gets its own metric objects.
class_precision_metrics = []
class_recall_metrics = []

for class_index in range(num_classes):
    class_mask = tf.equal(ground_truth_labels_tensor, class_index)
    class_predictions = tf.equal(predicted_class_indices_tensor, class_index)

    precision_metric = Precision()
    precision_metric.update_state(class_mask, class_predictions)
    class_precision_metrics.append(precision_metric)

    recall_metric = Recall()
    recall_metric.update_state(class_mask, class_predictions)
    class_recall_metrics.append(recall_metric)

print("\nPrecision values for each class:")
for class_index, precision_metric in enumerate(class_precision_metrics):
    precision = precision_metric.result().numpy()
    print(f"Class {class_index}: {precision}")

print("\nRecall values for each class:")
for class_index, recall_metric in enumerate(class_recall_metrics):
    recall = recall_metric.result().numpy()
    print(f"Class {class_index}: {recall}")

# Cohen's Kappa using scikit-learn
# scikit-learn is given flat 1-D label arrays, so collapse every axis first.
ground_truth_labels_flat = ground_truth_labels_tensor.numpy().ravel()
predicted_class_indices_flat = predicted_class_indices_tensor.numpy().ravel()
cohen_kappa_score_value = cohen_kappa_score(ground_truth_labels_flat, predicted_class_indices_flat)
print(f"\nCohen's Kappa: {cohen_kappa_score_value}")



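#######################################################################################################################
# NOTE: Everything below is an earlier draft of the script above, kept commented out for reference only.
# It lists the same metrics using the older variable names (y_test, y_pred_argmax, n_classes).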
# import tensorflow as tf
# from tensorflow.keras.metrics import MeanIoU, Precision, Recall, Accuracy
# import numpy as np
# from sklearn.metrics import cohen_kappa_score

# # Parameters
# num_images = 10
# image_height = 256
# image_width = 256
# n_classes = 3

# # Create random ground truth labels (y_test)
# y_test = np.random.randint(0, n_classes, size=(num_images, image_height, image_width, 1))

# # Create random predicted labels (y_pred_argmax)
# y_pred_argmax = np.random.randint(0, n_classes, size=(num_images, image_height, image_width, 1))

# # Convert to appropriate tensor format and cast to int32
# y_test = tf.cast(tf.convert_to_tensor(y_test), tf.int32)
# y_pred_argmax = tf.cast(tf.convert_to_tensor(y_pred_argmax), tf.int32)

# # Ensure y_pred_argmax has the same shape as y_test
# if len(y_pred_argmax.shape) < len(y_test.shape):
#     y_pred_argmax = tf.expand_dims(y_pred_argmax, axis=-1)

# # Pixel Accuracy
# accuracy = Accuracy()
# accuracy.update_state(y_test, y_pred_argmax)
# pixel_accuracy = accuracy.result().numpy()
# print(f"Pixel Accuracy: {pixel_accuracy}")

# # IoU and Dice Coefficient (F1 Score)
# iou_metric = MeanIoU(num_classes=n_classes)
# iou_metric.update_state(y_test, y_pred_argmax)
# mean_iou = iou_metric.result().numpy()
# print(f"Mean IoU: {mean_iou}")

# # Class-wise IoU
# iou_values = iou_metric.get_weights()[0]
# for i in range(n_classes):
#     class_iou = iou_values[i, i] / (np.sum(iou_values[i, :]) + np.sum(iou_values[:, i]) - iou_values[i, i])
#     print(f"IoU for class {i}: {class_iou}")

# # Dice Coefficient (F1 Score)
# dice = 2 * mean_iou / (1 + mean_iou)
# print(f"Dice Coefficient (F1 Score): {dice}")

# # Class-wise Dice Coefficient
# def dice_coefficient(y_true, y_pred, class_index):
#     y_true_f = tf.cast(tf.equal(y_true, class_index), tf.float32)
#     y_pred_f = tf.cast(tf.equal(y_pred, class_index), tf.float32)
#     intersection = tf.reduce_sum(y_true_f * y_pred_f)
#     return (2 * intersection) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + 1e-7)

# for i in range(n_classes):
#     dice = dice_coefficient(y_test, y_pred_argmax, i)
#     print(f"Dice Coefficient for class {i}: {dice.numpy()}")

# # Class-wise Precision and Recall
# for i in range(n_classes):
#     class_precision = Precision()
#     class_recall = Recall()
#     class_mask = tf.equal(y_test, i)
#     class_predictions = tf.equal(y_pred_argmax, i)
#     class_precision.update_state(class_mask, class_predictions)
#     class_recall.update_state(class_mask, class_predictions)
#     print(f"Class {i}:")
#     print(f"  Precision: {class_precision.result().numpy()}")
#     print(f"  Recall: {class_recall.result().numpy()}")

# # Cohen's Kappa using scikit-learn
# # Flatten the arrays to 1D
# y_test_flat = tf.reshape(y_test, [-1]).numpy()
# y_pred_argmax_flat = tf.reshape(y_pred_argmax, [-1]).numpy()
# cohen_kappa = cohen_kappa_score(y_test_flat, y_pred_argmax_flat)
# print(f"Cohen's Kappa: {cohen_kappa}")