# -*- coding: utf-8 -*-
"""
Created on Sat Aug  3 11:54:26 2024

@author: Aus
"""


import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
import segmentation_models as sm  # alias must be `sm`: every later use (set_framework, losses, Unet) refers to sm.*
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger
from tensorflow.keras.utils import Sequence
import datetime
# Restrict TensorFlow to the GPU with CUDA device index 1.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def print_with_timestamp(message):
    """Print *message* prefixed with the current local time as [YYYY-mm-dd HH:MM:SS]."""
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(f"[{stamp}] {message}")
# Make segmentation_models build its models on tf.keras (not standalone keras).
# NOTE(review): the package is imported above as `smn`, so `sm` is undefined
# here unless the import alias is corrected to `sm`.
sm.set_framework('tf.keras')

# Check for GPU availability
# gpus = tf.config.experimental.list_physical_devices('GPU')
# print("Physical GPUs:", gpus)
# if gpus:
#     try:
#         # Set GPU memory growth to avoid memory allocation issues
#         # for gpu in gpus:
#         #     tf.config.experimental.set_memory_growth(gpu, True)
        
#         # Use all available GPUs for training
#         # Create a MirroredStrategy.
#         strategy = tf.distribute.MirroredStrategy()
#         # strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1", "/gpu:2"])  # Update device names as needed
#         print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

# Define the directory where your batch files are saved
batch_directory = r"D:\PROTOS\LIVER\numpy"
num_batches = 1 # Assuming you have 6 batches
# Load entire dataset
all_images = []
all_masks = []
for batch_idx in range(num_batches):
    print_with_timestamp(f"Batch {batch_idx} ...")
    # NOTE(review): these f-strings contain no placeholder, so every iteration
    # loads the same two files; with num_batches > 1 the identical data would
    # be appended repeatedly. Presumably the names should include {batch_idx}
    # -- confirm against the files actually on disk.
    image_batch_filename = f'images.npy'
    mask_batch_filename = f'masks.npy'
    image_batch_path = os.path.join(batch_directory, image_batch_filename)
    mask_batch_path = os.path.join(batch_directory, mask_batch_filename)
    image_batch_data = np.load(image_batch_path)
    mask_batch_data = np.load(mask_batch_path)
    all_images.append(image_batch_data)
    all_masks.append(mask_batch_data)
print_with_timestamp("All Data Loaded...")
# Stack the per-batch arrays into single arrays along the sample axis.
ground_truth_array = np.concatenate(all_images, axis=0)
mask_array = np.concatenate(all_masks, axis=0)

# Delete unnecessary variables (free the per-batch copies before training)
del all_images, all_masks
print_with_timestamp("All Data Concatenated...")
import random

# Perform a visual sanity check on a random subset of 10 image/mask pairs.
sanity_check_indices = random.sample(range(len(ground_truth_array)), 10)

for idx in sanity_check_indices:
    plt.figure(figsize=(12, 8))
    # subplot(231)/(232): slots 1 and 2 of a 2x3 grid (other slots unused).
    plt.subplot(231)
    plt.title(f'Ground truth Image {idx}')
    plt.imshow(ground_truth_array[idx,:,:], cmap='gray')
    plt.subplot(232)
    plt.title(f'Mask Image {idx}')
    plt.imshow(mask_array[idx,:,:], cmap='jet')
    plt.show()
    
# Split data into train, validation, and test sets:
# 5% held out for test, then 5% of the remainder for validation.
x_train_val, x_test, y_train_val, y_test = train_test_split(ground_truth_array, mask_array, test_size=0.05, random_state=42)

# Delete unnecessary variables (free memory before the second split)
del ground_truth_array, mask_array

x_train, x_val, y_train, y_val = train_test_split(x_train_val, y_train_val, test_size=0.05, random_state=42)

# Delete unnecessary variables
del x_train_val, y_train_val
print_with_timestamp("Data splitted...")
# One-hot encode the integer masks: label values {0, 1, 2} -> 3 channels.
n_classes = 3
train_masks_cat = to_categorical(y_train, num_classes=n_classes)
# NOTE(review): to_categorical already yields shape (N, H, W, n_classes), so
# this reshape appears to be a no-op kept for explicitness -- confirm.
y_train_cat = train_masks_cat.reshape((y_train.shape[0], y_train.shape[1], y_train.shape[2], n_classes))

val_masks_cat = to_categorical(y_val, num_classes=n_classes)
y_val_cat = val_masks_cat.reshape((y_val.shape[0], y_val.shape[1], y_val.shape[2], n_classes))

test_masks_cat = to_categorical(y_test, num_classes=n_classes)
y_test_cat = test_masks_cat.reshape((y_test.shape[0], y_test.shape[1], y_test.shape[2], n_classes))

# Delete unnecessary variables (y_test is kept: the evaluation sections below use it)
del y_train, y_val, train_masks_cat, val_masks_cat, test_masks_cat
print_with_timestamp("Categorical done...")
# Softmax head: per-pixel probability distribution over the 3 classes.
activation = 'softmax'
LR = 0.0001
optim = tf.keras.optimizers.Adam(LR)

# Class-weighted Dice loss (background down-weighted) combined 1:1 with
# categorical focal loss.
dice_loss = sm.losses.DiceLoss(class_weights=np.array([0.1, 0.4, 0.5]))
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

metrics = ['accuracy', sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

BACKBONE1 = 'resnet101'
# Backbone-specific input normalisation supplied by segmentation_models.
preprocess_input1 = sm.get_preprocessing(BACKBONE1)

X_train1 = preprocess_input1(x_train)
X_val1 = preprocess_input1(x_val)
X_test1 = preprocess_input1(x_test)

# Delete unnecessary variables (keep only the preprocessed copies)
del x_train, x_val, x_test
print_with_timestamp("Preprocessed...")

# Define DataGenerator class
class DataGenerator(Sequence):
    """Serves (image, mask) mini-batches out of two in-memory arrays."""

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # Batches per epoch; the final batch may be smaller than batch_size.
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.x[start:stop], self.y[start:stop]
    

# with strategy.scope():
# U-Net with a ResNet-101 encoder trained from scratch (encoder_weights=None);
# variable spatial size, channel count taken from the training data.
model = sm.Unet(BACKBONE1, encoder_weights=None, input_shape=(None, None, X_train1.shape[-1]), classes=n_classes, activation=activation)
model.compile(optim, total_loss, metrics=metrics)

# checkpoint = ModelCheckpoint('D:/PROTOS/AIRA/Results/model_checkpoint.h5', monitor='loss', verbose=1, save_best_only=True, mode='min')
# early_stop = EarlyStopping(monitor='loss', patience=100, verbose=1)
# log_csv = CSVLogger('D:/PROTOS/AIRA/Results/training_logs.csv', separator=',', append=False)
# callbacks_list = [checkpoint, early_stop, log_csv]
# Define callbacks for saving models and logging
# NOTE(review): save_best_only=False writes a file every epoch, so the
# monitor='loss'/mode='min' settings have no effect on what gets saved.
checkpoint = ModelCheckpoint('D:/PROTOS/LIVER/models/model_checkpoint_{epoch:02d}_{val_loss:.2f}.h5', monitor='loss', verbose=1, save_best_only=False, mode='min')
early_stop = EarlyStopping(monitor='loss', patience=100, verbose=1)
log_csv = CSVLogger('D:/PROTOS/LIVER/models/training_logs.csv', separator=',', append=False)
callbacks_list = [checkpoint, early_stop, log_csv]

# print(model.summary())

print_with_timestamp("Going to Train...")

try:
    # First attempt: batch size 16.
    train_gen = DataGenerator(X_train1, y_train_cat, 16)
    history = model.fit(train_gen, epochs=10, validation_data=(X_val1, y_val_cat), callbacks=callbacks_list)
except tf.errors.ResourceExhaustedError:
    # On GPU out-of-memory, retry with progressively smaller batch sizes.
    # NOTE(review): the retry path trains for epochs=1000 while the first
    # attempt uses epochs=10 -- confirm this asymmetry is intentional.
    print("ResourceExhaustedError occurred. Trying with reduced batch size...")
    for bs in [8, 4, 2, 1]:
        try:
            train_gen = DataGenerator(X_train1, y_train_cat, bs)
            history = model.fit(train_gen, epochs=1000, validation_data=(X_val1, y_val_cat), callbacks=callbacks_list)
            break
        except tf.errors.ResourceExhaustedError:
            continue
    else:
        # for/else: reached only when no batch size succeeded; `history`
        # remains unset in that case.
        print("Unable to train even with reduced batch size. Exiting.")
        

############################################################################################
#TESTING

# Ensure that segmentation models use keras
sm.set_framework('tf.keras')

# Load the saved model with compile=False: only inference is needed, so the
# custom training loss/metric objects do not have to be deserialised.
# Raw string so the backslashes of the Windows path are taken literally.
model_path = r'D:\PROTOS\LIVER\models\model_checkpoint_994_0.68.h5'  # Update with the correct path to your saved model
best_model = tf.keras.models.load_model(model_path, compile=False)

import numpy as np
from tensorflow.keras.metrics import MeanIoU

# Predict with the restored checkpoint.
# (The original called model.predict, which evaluated the in-memory training
# model and silently ignored the checkpoint loaded above.)
y_pred = best_model.predict(X_test1)
# Collapse per-pixel class probabilities to integer labels: shape (N, H, W).
y_pred_argmax = np.argmax(y_pred, axis=3)

#################################################### IOU 
# Create an instance of the MeanIoU metric
iou_metric = MeanIoU(num_classes=n_classes)

# Update the state of the metric
iou_metric.update_state(y_test, y_pred_argmax)

# Print the mean IoU
print("Mean IoU =", iou_metric.result().numpy())

# Get the number of classes
num_classes = iou_metric.num_classes

# get_weights() holds the metric's accumulated confusion matrix
# (rows = ground truth, columns = prediction), flattened.
iou_weights = iou_metric.get_weights()
iou_values = np.array(iou_weights)
iou_values = iou_values.reshape(num_classes, num_classes)

# Per-class IoU = TP / (TP + FP + FN) = TP / (row sum + column sum - TP).
# (The original divided TP by the row sum only, which is recall, not IoU.)
for i in range(num_classes):
    tp = iou_values[i, i]
    union = iou_values[i, :].sum() + iou_values[:, i].sum() - tp
    class_iou = tp / union
    print(f"IoU for class {i} is: {class_iou}")

# Recorded results for checkpoint M-994. NOTE(review): the per-class values
# below were produced by the earlier recall-style formula, not true IoU.
''' M-994
Mean IoU = 0.96547204
IoU for class 0 is: 0.9972414970397949
IoU for class 1 is: 0.9886064529418945
IoU for class 2 is: 0.9583278894424438
'''
# import numpy as np
# from sklearn.metrics import confusion_matrix

# # Remove the extra dimension from y_test and flatten both arrays
# y_test_flat = y_test.squeeze().reshape(-1)
# y_pred_flat = y_pred_argmax.reshape(-1)

# # Convert to integer type if they're not already
# y_test_flat = y_test_flat.astype(int)
# y_pred_flat = y_pred_flat.astype(int)

# # Compute confusion matrix
# cm = confusion_matrix(y_test_flat, y_pred_flat)

# # Pixel Accuracy
# pixel_accuracy = np.sum(np.diag(cm)) / np.sum(cm)
# print("Pixel Accuracy =", pixel_accuracy)

# # Class-wise metrics
# n_classes = cm.shape[0]
# for i in range(n_classes):
#     tp = cm[i, i]
#     fp = np.sum(cm[:, i]) - tp
#     fn = np.sum(cm[i, :]) - tp
#     tn = np.sum(cm) - (tp + fp + fn)
    
#     precision = tp / (tp + fp) if (tp + fp) > 0 else 0
#     recall = tp / (tp + fn) if (tp + fn) > 0 else 0
#     f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
#     print(f"Class {i}:")
#     print(f"  Precision: {precision}")
#     print(f"  Recall: {recall}")
#     print(f"  F1-Score: {f1}")

# # Mean IoU (using previously calculated value)
# print("Mean IoU =", iou_metric.result().numpy())

# # Dice Coefficient (mean F1 score)
# dice = np.mean([2 * cm[i, i] / (np.sum(cm[i, :]) + np.sum(cm[:, i])) for i in range(n_classes)])
# print("Dice Coefficient (Mean F1) =", dice)

# # Cohen's Kappa
# n = np.sum(cm)
# sum_po = np.sum(np.diag(cm))
# sum_pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / n
# kappa = (sum_po - sum_pe) / (n - sum_pe)
# print("Cohen's Kappa =", kappa)



###########################################################  tf - more accurate
import tensorflow as tf
import numpy as np

# Assuming y_test and y_pred_argmax are your ground truth and predicted labels respectively
n_classes = 3  # Adjust this if you have a different number of classes

# Convert to appropriate tensor format if not already
y_test = tf.convert_to_tensor(y_test)
y_pred_argmax = tf.convert_to_tensor(y_pred_argmax)

# Remove the extra dimension from y_test if needed
y_test = tf.squeeze(y_test)

# Pixel Accuracy: fraction of pixels whose predicted label matches the truth.
accuracy = tf.keras.metrics.Accuracy()
accuracy.update_state(y_test, y_pred_argmax)
pixel_accuracy = accuracy.result().numpy()

# IoU
iou_metric = tf.keras.metrics.MeanIoU(num_classes=n_classes)
iou_metric.update_state(y_test, y_pred_argmax)
mean_iou = iou_metric.result().numpy()

# Get confusion matrix (rows = ground truth, columns = prediction)
iou_values = iou_metric.get_weights()[0]

# Function to calculate metrics for each class
def calculate_metrics(confusion_matrix, class_id):
    """Return (IoU, precision, recall) for one class of a confusion matrix.

    *confusion_matrix* is indexed [truth, prediction]; the 1e-10 terms keep
    the ratios finite when a class never occurs.
    """
    tp = confusion_matrix[class_id, class_id]
    fp = confusion_matrix[:, class_id].sum() - tp
    fn = confusion_matrix[class_id, :].sum() - tp

    iou = tp / (tp + fp + fn + 1e-10)
    precision = tp / (tp + fp + 1e-10)
    recall = tp / (tp + fn + 1e-10)

    return iou, precision, recall

# Calculate metrics for each class: one (iou, precision, recall) tuple per class.
class_metrics = [calculate_metrics(iou_values, i) for i in range(n_classes)]

# Dice Coefficient (F1 Score)
def dice_coefficient(y_true, y_pred, class_index):
    """Per-class Dice score (2*|A∩B| / (|A|+|B|)) between two label maps."""
    truth_mask = tf.cast(tf.equal(y_true, class_index), tf.float32)
    pred_mask = tf.cast(tf.equal(y_pred, class_index), tf.float32)
    overlap = tf.reduce_sum(truth_mask * pred_mask)
    # 1e-7 keeps the ratio finite when the class is absent from both maps.
    return (2 * overlap) / (tf.reduce_sum(truth_mask) + tf.reduce_sum(pred_mask) + 1e-7)

# Cohen's Kappa
def cohen_kappa(confusion_matrix):
    """Compute Cohen's kappa (agreement beyond chance) from a confusion matrix."""
    total = np.sum(confusion_matrix)
    observed = np.trace(confusion_matrix)  # agreeing pixels on the diagonal
    row_totals = np.sum(confusion_matrix, axis=1)
    col_totals = np.sum(confusion_matrix, axis=0)
    expected = np.sum(row_totals * col_totals) / total
    po = observed / total
    pe = expected / total
    return (po - pe) / (1 - pe)

# Calculate Cohen's Kappa from the accumulated confusion matrix
kappa = cohen_kappa(iou_values)

# Print results (per-class IoU/precision/recall come from class_metrics;
# Dice is recomputed directly from the label maps).
print(f"Pixel Accuracy: {pixel_accuracy}")

print(f"\nMean IoU: {mean_iou}")

print("\nIoU:")
for i, (iou, _, _) in enumerate(class_metrics):
    print(f"  Class {i}: {iou}")

print("\nPrecision:")
for i, (_, precision, _) in enumerate(class_metrics):
    print(f"  Class {i}: {precision}")

print("\nRecall:")
for i, (_, _, recall) in enumerate(class_metrics):
    print(f"  Class {i}: {recall}")

print("\nDice Coefficient (F1 Score):")
dice_scores = []
for i in range(n_classes):
    dice = dice_coefficient(y_test, y_pred_argmax, i)
    dice_scores.append(dice.numpy())
    print(f"  Class {i}: {dice.numpy()}")

print(f"\nMean Dice Coefficient (F1 Score): {np.mean(dice_scores)}")

print(f"\nCohen's Kappa: {kappa}")

################################################################# claude
# NOTE(review): this section repeats the evaluation above, deriving the
# per-class F1 from the confusion matrix instead of computing Dice directly.

import tensorflow as tf
import numpy as np

# Assuming y_test and y_pred_argmax are your ground truth and predicted labels respectively
n_classes = 3  # Adjust this if you have a different number of classes

# Convert to appropriate tensor format if not already
y_test = tf.convert_to_tensor(y_test)
y_pred_argmax = tf.convert_to_tensor(y_pred_argmax)

# Remove the extra dimension from y_test if needed
y_test = tf.squeeze(y_test)

# Pixel Accuracy
accuracy = tf.keras.metrics.Accuracy()
accuracy.update_state(y_test, y_pred_argmax)
pixel_accuracy = accuracy.result().numpy()

# IoU and Dice Coefficient (F1 Score)
iou_metric = tf.keras.metrics.MeanIoU(num_classes=n_classes)
iou_metric.update_state(y_test, y_pred_argmax)
mean_iou = iou_metric.result().numpy()

# Get confusion matrix (rows = ground truth, columns = prediction)
iou_values = iou_metric.get_weights()[0]

# Function to calculate metrics for each class
def calculate_metrics(confusion_matrix, class_id):
    """Return (IoU, precision, recall, F1) for one class.

    *confusion_matrix* is indexed [truth, prediction]; the 1e-10 terms keep
    every ratio finite when a class never occurs.
    """
    tp = confusion_matrix[class_id, class_id]
    fp = confusion_matrix[:, class_id].sum() - tp
    fn = confusion_matrix[class_id, :].sum() - tp

    iou = tp / (tp + fp + fn + 1e-10)
    precision = tp / (tp + fp + 1e-10)
    recall = tp / (tp + fn + 1e-10)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-10)

    return iou, precision, recall, f1

# Calculate metrics for each class: (iou, precision, recall, f1) per class.
class_metrics = [calculate_metrics(iou_values, i) for i in range(n_classes)]

# Cohen's Kappa
def cohen_kappa(confusion_matrix):
    """Cohen's kappa: observed agreement corrected for chance agreement."""
    pixel_count = np.sum(confusion_matrix)
    diagonal_hits = np.trace(confusion_matrix)
    chance_hits = np.sum(
        np.sum(confusion_matrix, axis=0) * np.sum(confusion_matrix, axis=1)
    ) / pixel_count
    po = diagonal_hits / pixel_count   # observed agreement
    pe = chance_hits / pixel_count     # expected (chance) agreement
    return (po - pe) / (1 - pe)

# Calculate Cohen's Kappa from the accumulated confusion matrix
kappa = cohen_kappa(iou_values)

# Print results in the requested format
print(f"Pixel Accuracy: {pixel_accuracy}")

print(f"\nMean IoU: {mean_iou}")

print("\nIoU:")
for i, (iou, _, _, _) in enumerate(class_metrics):
    print(f"  Class {i}: {iou}")

print("\nPrecision:")
for i, (_, precision, _, _) in enumerate(class_metrics):
    print(f"  Class {i}: {precision}")

print("\nRecall:")
for i, (_, _, recall, _) in enumerate(class_metrics):
    print(f"  Class {i}: {recall}")

print("\nF1 Score (Dice Coefficient):")
for i, (_, _, _, f1) in enumerate(class_metrics):
    print(f"  Class {i}: {f1}")

print(f"\nMean Dice Coefficient (F1 Score): {np.mean([m[3] for m in class_metrics])}")

print(f"\nCohen's Kappa: {kappa}")
#####################################################
import tensorflow as tf
from tensorflow.keras.metrics import MeanIoU, Precision, Recall, Accuracy
import numpy as np
from sklearn.metrics import cohen_kappa_score

# Parameters
# NOTE(review): num_test_images/image_height/image_width are never used below.
num_test_images = 10
image_height = 256
image_width = 256
num_classes = 3

# Ground truth labels from the held-out test split.
ground_truth_labels = y_test

# Predicted per-pixel class indices from the model.
predicted_class_indices = y_pred_argmax

# Convert to appropriate tensor format and cast to int32
ground_truth_labels_tensor = tf.cast(tf.convert_to_tensor(ground_truth_labels), tf.int32)
predicted_class_indices_tensor = tf.cast(tf.convert_to_tensor(predicted_class_indices), tf.int32)

# Ensure predicted_class_indices has the same rank as ground_truth_labels
# (ground truth may carry a trailing singleton channel dimension).
if len(predicted_class_indices_tensor.shape) < len(ground_truth_labels_tensor.shape):
    predicted_class_indices_tensor = tf.expand_dims(predicted_class_indices_tensor, axis=-1)

# Pixel Accuracy
pixel_accuracy_metric = Accuracy()
pixel_accuracy_metric.update_state(ground_truth_labels_tensor, predicted_class_indices_tensor)
pixel_accuracy = pixel_accuracy_metric.result().numpy()
print(f"\nPixel Accuracy: {pixel_accuracy}")

# IoU and Dice Coefficient (F1 Score)
iou_metric = MeanIoU(num_classes=num_classes)
iou_metric.update_state(ground_truth_labels_tensor, predicted_class_indices_tensor)
mean_iou = iou_metric.result().numpy()
print(f"\nMean IoU: {mean_iou}\n")

# Class-wise IoU: TP / (row sum + column sum - TP) on the confusion matrix.
iou_values = iou_metric.get_weights()[0]
for class_index in range(num_classes):
    class_iou = iou_values[class_index, class_index] / (np.sum(iou_values[class_index, :]) + np.sum(iou_values[:, class_index]) - iou_values[class_index, class_index])
    print(f"IoU for class {class_index}: {class_iou}")

# Dice Coefficient (F1 Score)
# NOTE(review): 2*IoU/(1+IoU) converts a single IoU to Dice exactly, but
# applying it to the *mean* IoU only approximates the mean Dice.
dice_coefficient = 2 * mean_iou / (1 + mean_iou)
print(f"\nDice Coefficient (F1 Score): {dice_coefficient}\n")

# Class-wise Dice Coefficient
def calculate_dice_coefficient(ground_truth_labels, predicted_class_indices, class_index):
    """Dice score (2*|A∩B| / (|A|+|B|)) for the binary masks of one class."""
    truth_mask = tf.cast(tf.equal(ground_truth_labels, class_index), tf.float32)
    pred_mask = tf.cast(tf.equal(predicted_class_indices, class_index), tf.float32)
    overlap = tf.reduce_sum(truth_mask * pred_mask)
    # 1e-7 keeps the ratio finite when the class is absent from both maps.
    return (2 * overlap) / (tf.reduce_sum(truth_mask) + tf.reduce_sum(pred_mask) + 1e-7)

for class_index in range(num_classes):
    dice_coefficient = calculate_dice_coefficient(ground_truth_labels_tensor, predicted_class_indices_tensor, class_index)
    print(f"Dice Coefficient for class {class_index}: {dice_coefficient.numpy()}")


# Class-wise Precision and Recall:
# one fresh Precision/Recall metric per class, fed one-vs-rest boolean masks.
class_precision_metrics = [Precision() for _ in range(num_classes)]
class_recall_metrics = [Recall() for _ in range(num_classes)]

for class_index in range(num_classes):
    class_mask = tf.equal(ground_truth_labels_tensor, class_index)
    class_predictions = tf.equal(predicted_class_indices_tensor, class_index)
    class_precision_metrics[class_index].update_state(class_mask, class_predictions)
    class_recall_metrics[class_index].update_state(class_mask, class_predictions)

print("\nPrecision values for each class:")
for class_index in range(num_classes):
    precision = class_precision_metrics[class_index].result().numpy()
    print(f"Class {class_index}: {precision}")

print("\nRecall values for each class:")
for class_index in range(num_classes):
    recall = class_recall_metrics[class_index].result().numpy()
    print(f"Class {class_index}: {recall}")
    
# Cohen's Kappa using scikit-learn
# Flatten the arrays to 1D (cohen_kappa_score expects label vectors)
ground_truth_labels_flat = tf.reshape(ground_truth_labels_tensor, [-1]).numpy()
predicted_class_indices_flat = tf.reshape(predicted_class_indices_tensor, [-1]).numpy()
cohen_kappa_score_value = cohen_kappa_score(ground_truth_labels_flat, predicted_class_indices_flat)
print(f"\nCohen's Kappa: {cohen_kappa_score_value}")
################### tf
# NOTE(review): third repetition of the same evaluation, this time computing
# every per-class metric directly from the label maps.
import tensorflow as tf
import numpy as np

# Assuming y_test and y_pred_argmax are your ground truth and predicted labels respectively
n_classes = 3  # Adjust this if you have a different number of classes

# Convert to appropriate tensor format if not already
y_test = tf.convert_to_tensor(y_test)
y_pred_argmax = tf.convert_to_tensor(y_pred_argmax)

# Remove the extra dimension from y_test if needed
y_test = tf.squeeze(y_test)

# Pixel Accuracy
accuracy = tf.keras.metrics.Accuracy()
accuracy.update_state(y_test, y_pred_argmax)
pixel_accuracy = accuracy.result().numpy()

# Mean IoU and Confusion Matrix
miou = tf.keras.metrics.MeanIoU(num_classes=n_classes)
miou.update_state(y_test, y_pred_argmax)
mean_iou = miou.result().numpy()
# total_cm is MeanIoU's internal accumulated confusion matrix.
confusion_matrix = miou.total_cm

# Shared helper: float {0,1} masks selecting the `class_index` pixels of each map.
def _class_masks(y_true, y_pred, class_index):
    true_mask = tf.cast(tf.equal(y_true, class_index), tf.float32)
    pred_mask = tf.cast(tf.equal(y_pred, class_index), tf.float32)
    return true_mask, pred_mask

# Function to calculate IoU
def calculate_iou(y_true, y_pred, class_index):
    """Intersection-over-union of one class between two label maps."""
    true_mask, pred_mask = _class_masks(y_true, y_pred, class_index)
    intersection = tf.reduce_sum(true_mask * pred_mask)
    union = tf.reduce_sum(true_mask) + tf.reduce_sum(pred_mask) - intersection
    return intersection / (union + tf.keras.backend.epsilon())

# Function to calculate Dice Coefficient
def calculate_dice_coefficient(y_true, y_pred, class_index):
    """Dice score (2*|A∩B| / (|A|+|B|)) of one class."""
    true_mask, pred_mask = _class_masks(y_true, y_pred, class_index)
    intersection = tf.reduce_sum(true_mask * pred_mask)
    return (2. * intersection) / (tf.reduce_sum(true_mask) + tf.reduce_sum(pred_mask) + tf.keras.backend.epsilon())

# Function to calculate Precision
def calculate_precision(y_true, y_pred, class_index):
    """True positives over predicted positives for one class."""
    true_mask, pred_mask = _class_masks(y_true, y_pred, class_index)
    true_positives = tf.reduce_sum(true_mask * pred_mask)
    return true_positives / (tf.reduce_sum(pred_mask) + tf.keras.backend.epsilon())

# Function to calculate Recall
def calculate_recall(y_true, y_pred, class_index):
    """True positives over actual positives for one class."""
    true_mask, pred_mask = _class_masks(y_true, y_pred, class_index)
    true_positives = tf.reduce_sum(true_mask * pred_mask)
    return true_positives / (tf.reduce_sum(true_mask) + tf.keras.backend.epsilon())

# Calculate per-class metrics; each tuple is (iou, dice, precision, recall).
per_class_metrics = []
for i in range(n_classes):
    iou = calculate_iou(y_test, y_pred_argmax, i)
    dice = calculate_dice_coefficient(y_test, y_pred_argmax, i)
    precision = calculate_precision(y_test, y_pred_argmax, i)
    recall = calculate_recall(y_test, y_pred_argmax, i)
    per_class_metrics.append((iou.numpy(), dice.numpy(), precision.numpy(), recall.numpy()))

# Cohen's Kappa computed from MeanIoU's accumulated confusion matrix
n = tf.cast(tf.reduce_sum(confusion_matrix), tf.float32)
sum_po = tf.cast(tf.linalg.trace(confusion_matrix), tf.float32)
sum_pe = tf.reduce_sum(tf.cast(tf.reduce_sum(confusion_matrix, axis=0) * tf.reduce_sum(confusion_matrix, axis=1), tf.float32)) / n
po = sum_po / n
pe = sum_pe / n
kappa = (po - pe) / (1 - pe + tf.keras.backend.epsilon())

# Print results
print(f"Pixel Accuracy: {pixel_accuracy}")
print(f"Mean IoU: {mean_iou}")

print("IoU:")
for i, metrics in enumerate(per_class_metrics):
    print(f"  Class {i}: {metrics[0]}")

print("Precision:")
for i, metrics in enumerate(per_class_metrics):
    print(f"  Class {i}: {metrics[2]}")

print("Recall:")
for i, metrics in enumerate(per_class_metrics):
    print(f"  Class {i}: {metrics[3]}")

print("F1 Score (Dice Coefficient):")
for i, metrics in enumerate(per_class_metrics):
    print(f"  Class {i}: {metrics[1]}")

print(f"Mean Dice Coefficient (F1 Score): {np.mean([m[1] for m in per_class_metrics])}")
print(f"Cohen's Kappa: {kappa.numpy()}")

### optimise
import tensorflow as tf
import numpy as np

# Assuming y_test and y_pred_argmax are your ground truth and predicted labels respectively
n_classes = 3  # Adjust this if you have a different number of classes

# Convert to appropriate tensor format if not already
# (these are already tensors after the previous section; the conversion is
# an idempotent no-op in that case)
y_test = tf.convert_to_tensor(y_test)
y_pred_argmax = tf.convert_to_tensor(y_pred_argmax)

# Remove the extra dimension from y_test if needed
y_test = tf.squeeze(y_test)

# Pixel Accuracy
accuracy = tf.keras.metrics.Accuracy()
accuracy.update_state(y_test, y_pred_argmax)
pixel_accuracy = accuracy.result().numpy()

# Mean IoU and Confusion Matrix
miou = tf.keras.metrics.MeanIoU(num_classes=n_classes)
miou.update_state(y_test, y_pred_argmax)
mean_iou = miou.result().numpy()
confusion_matrix = miou.total_cm

# Function to calculate metrics
def calculate_metric(y_true, y_pred, class_index, metric):
    """Compute one of 'iou' | 'dice' | 'precision' | 'recall' for one class."""
    y_true_class = tf.cast(tf.equal(y_true, class_index), tf.float32)
    y_pred_class = tf.cast(tf.equal(y_pred, class_index), tf.float32)
    intersection = tf.reduce_sum(y_true_class * y_pred_class)
    true_total = tf.reduce_sum(y_true_class)
    pred_total = tf.reduce_sum(y_pred_class)
    eps = tf.keras.backend.epsilon()

    if metric == 'iou':
        return intersection / (true_total + pred_total - intersection + eps)
    if metric == 'dice':
        return (2. * intersection) / (true_total + pred_total + eps)
    if metric == 'precision':
        return intersection / (pred_total + eps)
    if metric == 'recall':
        return intersection / (true_total + eps)
    raise ValueError("Invalid metric. Choose from 'iou', 'dice', 'precision', 'recall'.")

# Calculate per-class metrics; each tuple is (iou, dice, precision, recall).
per_class_metrics = []
for i in range(n_classes):
    iou = calculate_metric(y_test, y_pred_argmax, i, 'iou')
    dice = calculate_metric(y_test, y_pred_argmax, i, 'dice')
    precision = calculate_metric(y_test, y_pred_argmax, i, 'precision')
    recall = calculate_metric(y_test, y_pred_argmax, i, 'recall')
    per_class_metrics.append((iou.numpy(), dice.numpy(), precision.numpy(), recall.numpy()))

# Cohen's Kappa from the accumulated confusion matrix
n = tf.cast(tf.reduce_sum(confusion_matrix), tf.float32)
sum_po = tf.cast(tf.linalg.trace(confusion_matrix), tf.float32)
sum_pe = tf.reduce_sum(tf.cast(tf.reduce_sum(confusion_matrix, axis=0) * tf.reduce_sum(confusion_matrix, axis=1), tf.float32)) / n
po = sum_po / n
pe = sum_pe / n
kappa = (po - pe) / (1 - pe + tf.keras.backend.epsilon())

# Print results
print(f"Pixel Accuracy: {pixel_accuracy}")
print(f"Mean IoU: {mean_iou}")

print("IoU:")
for i, metrics in enumerate(per_class_metrics):
    print(f"  Class {i}: {metrics[0]}")

print("Precision:")
for i, metrics in enumerate(per_class_metrics):
    print(f"  Class {i}: {metrics[2]}")

print("Recall:")
for i, metrics in enumerate(per_class_metrics):
    print(f"  Class {i}: {metrics[3]}")

print("F1 Score (Dice Coefficient):")
for i, metrics in enumerate(per_class_metrics):
    print(f"  Class {i}: {metrics[1]}")

print(f"Mean Dice Coefficient (F1 Score): {np.mean([m[1] for m in per_class_metrics])}")
print(f"Cohen's Kappa: {kappa.numpy()}")

### further optimization
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix as sk_confusion_matrix
from sklearn.metrics import roc_curve, auc

# Assuming y_test and y_pred_argmax are your ground truth and predicted labels respectively
n_classes = 3  # Adjust this if you have a different number of classes

# Input validation
def validate_input(y_test, y_pred_argmax):
    """Raise ValueError unless both inputs are tf.Tensors of identical shape."""
    for tensor in (y_test, y_pred_argmax):
        if not isinstance(tensor, tf.Tensor):
            raise ValueError("Input tensors must be of type tf.Tensor")
    if y_test.shape != y_pred_argmax.shape:
        raise ValueError("Input tensors must have the same shape")
    # if len(y_test.shape) != 1:
    #     raise ValueError("Input tensors must be 1D")

# Convert to appropriate tensor format if not already
y_test_tensor = tf.convert_to_tensor(y_test)
y_pred_argmax_tensor = tf.convert_to_tensor(y_pred_argmax)

# Remove the extra dimension from y_test if needed
y_test_tensor = tf.squeeze(y_test_tensor)

# Validate input (raises ValueError on type or shape mismatch)
validate_input(y_test_tensor, y_pred_argmax_tensor)

# Pixel Accuracy
accuracy = tf.keras.metrics.Accuracy()
accuracy.update_state(y_test_tensor, y_pred_argmax_tensor)
pixel_accuracy = accuracy.result().numpy()

# Mean IoU and Confusion Matrix
miou = tf.keras.metrics.MeanIoU(num_classes=n_classes)
miou.update_state(y_test_tensor, y_pred_argmax_tensor)
mean_iou = miou.result().numpy()
confusion_matrix = miou.total_cm

# Function to calculate metrics
# NOTE(review): this variant wraps each ratio in tf.cond so an absent class
# yields 0.0 instead of relying on an epsilon. The zero branch returns a
# Python float while the other branch returns a tensor -- fine in eager
# execution, but confirm both branches match if this is ever graph-traced.
def calculate_metric(y_true, y_pred, class_index, metric):
    """Compute 'iou', 'dice', 'precision' or 'recall' for one class,
    returning 0.0 when the corresponding denominator is zero."""
    y_true_class = tf.cast(tf.equal(y_true, class_index), tf.float32)
    y_pred_class = tf.cast(tf.equal(y_pred, class_index), tf.float32)
    intersection = tf.reduce_sum(y_true_class * y_pred_class)
    union = tf.reduce_sum(y_true_class) + tf.reduce_sum(y_pred_class) - intersection
    
    if metric == 'iou':
        # Handle edge case where union is zero
        return tf.cond(tf.equal(union, 0), lambda: 0.0, lambda: intersection / union)
    elif metric == 'dice':
        # Handle edge case where sum of true and predicted is zero
        return tf.cond(tf.equal(tf.reduce_sum(y_true_class) + tf.reduce_sum(y_pred_class), 0), lambda: 0.0, lambda: (2. * intersection) / (tf.reduce_sum(y_true_class) + tf.reduce_sum(y_pred_class)))
    elif metric == 'precision':
        # Handle edge case where predicted is zero
        return tf.cond(tf.equal(tf.reduce_sum(y_pred_class), 0), lambda: 0.0, lambda: intersection / tf.reduce_sum(y_pred_class))
    elif metric == 'recall':
        # Handle edge case where true is zero
        return tf.cond(tf.equal(tf.reduce_sum(y_true_class), 0), lambda: 0.0, lambda: intersection / tf.reduce_sum(y_true_class))
    else:
        raise ValueError("Invalid metric. Choose from 'iou', 'dice', 'precision', 'recall'.")

# Calculate per-class metrics
# Each entry is the tuple (iou, dice, precision, recall) for class i; the
# printing code below indexes into this ordering.
per_class_metrics = []
for i in range(n_classes):
    iou = calculate_metric(y_test_tensor, y_pred_argmax_tensor, i, 'iou')
    dice = calculate_metric(y_test_tensor, y_pred_argmax_tensor, i, 'dice')
    precision = calculate_metric(y_test_tensor, y_pred_argmax_tensor, i, 'precision')
    recall = calculate_metric(y_test_tensor, y_pred_argmax_tensor, i, 'recall')
    per_class_metrics.append((iou.numpy(), dice.numpy(), precision.numpy(), recall.numpy()))

# Cohen's Kappa
# Computed from the accumulated confusion matrix:
#   po = observed agreement = trace / n
#   pe = chance agreement   = sum(row_i * col_i) / n^2
n = tf.cast(tf.reduce_sum(confusion_matrix), tf.float32)
sum_po = tf.cast(tf.linalg.trace(confusion_matrix), tf.float32)
sum_pe = tf.reduce_sum(tf.cast(tf.reduce_sum(confusion_matrix, axis=0) * tf.reduce_sum(confusion_matrix, axis=1), tf.float32)) / n
po = sum_po / n
pe = sum_pe / n  # sum_pe was already divided by n once, so pe = .../n^2
kappa = (po - pe) / (1 - pe + tf.keras.backend.epsilon())  # epsilon guards pe == 1

# Print results
# (per_class_metrics tuples are ordered (iou, dice, precision, recall))
print(f"Pixel Accuracy: {pixel_accuracy}")
print(f"Mean IoU: {mean_iou}")

print("IoU:")
for i, metrics in enumerate(per_class_metrics):
    print(f"  Class {i}: {metrics[0]}")

print("Precision:")
for i, metrics in enumerate(per_class_metrics):
    print(f"  Class {i}: {metrics[2]}")

print("Recall:")
for i, metrics in enumerate(per_class_metrics):
    print(f"  Class {i}: {metrics[3]}")

print("F1 Score (Dice Coefficient):")
for i, metrics in enumerate(per_class_metrics):
    print(f"  Class {i}: {metrics[1]}")

print(f"Mean Dice Coefficient (F1 Score): {np.mean([m[1] for m in per_class_metrics])}")
print(f"Cohen's Kappa: {kappa.numpy()}")

# Visualization
# import matplotlib.pyplot as plt
# import seaborn as sns
# from sklearn.metrics import precision_recall_curve, average_precision_score

# # Assuming y_test and y_pred_argmax are your ground truth and predicted labels respectively
# # Also assuming you have access to y_pred_prob, which are the probabilities for each class

# # 1. Confusion Matrix Heatmap
# plt.figure(figsize=(10, 8))
# sns.heatmap(confusion_matrix.numpy(), annot=True, fmt='g', cmap='Blues')
# plt.title('Confusion Matrix')
# plt.ylabel('True label')
# plt.xlabel('Predicted label')
# plt.show()

# # 2. Per-class IoU Bar Chart
# class_ious = [metrics[0] for metrics in per_class_metrics]
# plt.figure(figsize=(10, 6))
# plt.bar(range(n_classes), class_ious)
# plt.title('Per-class IoU')
# plt.xlabel('Class')
# plt.ylabel('IoU')
# plt.xticks(range(n_classes))
# plt.show()

# # 3. Sample Prediction Visualization
# def visualize_prediction(image, true_mask, pred_mask, num_classes):
#     fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
    
#     ax1.imshow(image)
#     ax1.set_title('Original Image')
#     ax1.axis('off')
    
#     ax2.imshow(true_mask, cmap='viridis', vmin=0, vmax=num_classes-1)
#     ax2.set_title('True Mask')
#     ax2.axis('off')
    
#     ax3.imshow(pred_mask, cmap='viridis', vmin=0, vmax=num_classes-1)
#     ax3.set_title('Predicted Mask')
#     ax3.axis('off')
    
#     plt.show()

# # Assuming you have a sample image, true mask, and predicted mask
# # visualize_prediction(sample_image, sample_true_mask, sample_pred_mask, n_classes)

# # 4. Precision-Recall Curve (adapted for multi-class)
# plt.figure(figsize=(10, 8))

# for i in range(n_classes):
#     y_true = (y_test == i).astype(int)
#     y_score = y_pred_prob[:, i]
    
#     precision, recall, _ = precision_recall_curve(y_true, y_score)
#     average_precision = average_precision_score(y_true, y_score)
    
#     plt.plot(recall, precision, lw=2, label=f'Class {i} (AP = {average_precision:.2f})')

# plt.xlabel('Recall')
# plt.ylabel('Precision')
# plt.title('Precision-Recall Curve for Each Class')
# plt.legend(loc='best')
# plt.show()

# # 5. Learning Curves (if you have access to training history)
# # Assuming you have training and validation losses stored
# plt.figure(figsize=(10, 6))
# plt.plot(train_losses, label='Training Loss')
# plt.plot(val_losses, label='Validation Loss')
# plt.title('Learning Curves')
# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.legend()
# plt.show()


import seaborn as sns

# Plot confusion matrix
# total_cm from MeanIoU is a TF variable; convert to numpy for seaborn.
conf_matrix = confusion_matrix.numpy()
plt.figure(figsize=(10, 8))
sns.heatmap(conf_matrix, annot=True, fmt='g', cmap='Blues')
plt.title('Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()

# Extract IoU and Dice Coefficient for each class
# (per_class_metrics tuples are ordered (iou, dice, precision, recall))
class_indices = list(range(n_classes))
iou_values = [m[0] for m in per_class_metrics]
dice_values = [m[1] for m in per_class_metrics]

# Plot IoU
plt.figure(figsize=(10, 5))
plt.bar(class_indices, iou_values, color='skyblue')
plt.xlabel('Class Index')
plt.ylabel('IoU')
plt.title('Class-wise IoU')
plt.show()

# Plot Dice Coefficient
plt.figure(figsize=(10, 5))
plt.bar(class_indices, dice_values, color='lightgreen')
plt.xlabel('Class Index')
plt.ylabel('Dice Coefficient')
plt.title('Class-wise Dice Coefficient')
plt.show()

# Function to plot individual sample predictions
def plot_individual_sample_prediction(image, true_mask, pred_mask, index):
    """Show a test image, its ground-truth mask and its prediction side by side.

    `image` and `true_mask` are sliced on channel 0; `pred_mask` is shown
    as-is (already a 2-D class-id map).
    """
    panels = (
        ('Testing Image', image[:, :, 0], 'gray'),
        ('Testing Label', true_mask[:, :, 0], 'jet'),
        ('Prediction on test image', pred_mask, 'jet'),
    )
    plt.figure(figsize=(15, 5))
    for col, (title, data, colormap) in enumerate(panels, start=1):
        plt.subplot(1, 3, col)
        plt.title(title)
        plt.imshow(data, cmap=colormap)
    plt.suptitle(f'Sample {index}')
    plt.show()

# Generating a few sample predictions
import random  # fix: `random` is not imported at the top of this file

# Clamp so random.sample never raises when the test set has < 50 items.
n_samples = min(50, len(X_test1))  # Number of samples to display
sample_indices = random.sample(range(len(X_test1)), n_samples)

for idx, sample_idx in enumerate(sample_indices):
    test_img = X_test1[sample_idx]
    ground_truth = y_test[sample_idx]

    # Model expects a batch dimension; argmax over the channel axis turns
    # per-class probabilities into a 2-D class-id map.
    test_img_input = np.expand_dims(test_img, 0)
    prediction = model.predict(test_img_input)
    predicted_img = np.argmax(prediction, axis=3)[0, :, :]

    # Plotting the individual sample prediction
    plot_individual_sample_prediction(test_img, ground_truth, predicted_img, idx + 1)


###
# Per-class one-vs-rest ROC/AUC over all pixels of the test set.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import tensorflow as tf

# Assuming y_test and model1 are available
# Convert y_test to one-hot encoding
n_classes = 3  # Number of classes
y_test_onehot = tf.keras.utils.to_categorical(y_test, num_classes=n_classes)

# Get the predicted probabilities from the model
y_pred_probs = model.predict(X_test1)  # Predicted probabilities

# Flatten the arrays for ROC calculation
# Every pixel becomes one sample with n_classes scores.
y_test_onehot_flat = y_test_onehot.reshape(-1, n_classes)
y_pred_probs_flat = y_pred_probs.reshape(-1, n_classes)

# Calculate ROC curve and AUC for each class
fpr = {}
tpr = {}
roc_auc = {}

for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test_onehot_flat[:, i], y_pred_probs_flat[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
    print(f"Class {i} - FPR: {fpr[i]}")
    print(f"Class {i} - TPR: {tpr[i]}")
    print(f"Class {i} - AUC: {roc_auc[i]:.2f}\n")

# Plotting the ROC curves
plt.figure(figsize=(10, 8))
for i in range(n_classes):
    plt.plot(fpr[i], tpr[i], label=f'Class {i} (area = {roc_auc[i]:.2f})')
    
plt.plot([0, 1], [0, 1], 'k--')  # diagonal = chance level
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc='lower right')
plt.show()

############################################
# Standalone evaluation section: reload a trained checkpoint, recompile it
# with the training losses/metrics, and reload + preprocess the numpy data.
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import segmentation_models as sm

# NOTE(review): the following import block duplicates the one just above
# (and the top of the file); redundant but harmless.
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
import segmentation_models as sm
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger
from tensorflow.keras.utils import Sequence
import datetime

# Ensure that segmentation models use keras
sm.set_framework('tf.keras')

# Load the saved model
model_path = 'E:/LIVER_TUMOR_DETECTION/models/MODELS/model_checkpoint_10_0.69.h5'  # Update with the correct path to your saved model
# compile=False: the custom loss/metric objects are re-attached below
# instead of being deserialized from the checkpoint.
model = tf.keras.models.load_model(model_path, compile=False)

# Compile the model with the same settings used during training
n_classes = 3
activation = 'softmax'
LR = 0.0001
optim = tf.keras.optimizers.Adam(LR)

# Weighted Dice + categorical focal loss, matching the training setup.
dice_loss = sm.losses.DiceLoss(class_weights=np.array([0.1, 0.4, 0.5]))
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1 * focal_loss)

metrics = ['accuracy', sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

model.compile(optimizer=optim, loss=total_loss, metrics=metrics)

# Preprocess the test data (assuming X_test1 and y_test_cat have been defined as in your training code)
BACKBONE = 'resnet101'
preprocess_input = sm.get_preprocessing(BACKBONE)

# Load the test data again if not already loaded
batch_directory = 'E:/LIVER_TUMOR_DETECTION/NUMPY/'
image_batch_filename = 'images.npy'
mask_batch_filename = 'masks.npy'
image_batch_path = os.path.join(batch_directory, image_batch_filename)
mask_batch_path = os.path.join(batch_directory, mask_batch_filename)
X_test1 = np.load(image_batch_path)
y_test = np.load(mask_batch_path)

# Preprocess the test data
X_test1 = preprocess_input(X_test1)

# Convert y_test to categorical
# one-hot encode, then restore the spatial layout (N, H, W, n_classes)
test_masks_cat = to_categorical(y_test, num_classes=n_classes)
y_test_cat = test_masks_cat.reshape((y_test.shape[0], y_test.shape[1], y_test.shape[2], n_classes))
# Function to plot results
def plot_results(model, x_test, y_test):
    """Pick one random test sample and show image, label and prediction.

    `y_test` is expected one-hot encoded; argmax over the last axis is
    taken for display.
    """
    import random

    sample_idx = random.randint(0, len(x_test) - 1)
    sample_img = x_test[sample_idx]
    sample_mask = y_test[sample_idx]
    # add batch dim for predict; collapse class scores to a class-id map
    prediction = model.predict(np.expand_dims(sample_img, 0))
    predicted_mask = np.argmax(prediction, axis=3)[0, :, :]

    plt.figure(figsize=(12, 8))
    plt.subplot(131)
    plt.title('Testing Image')
    plt.imshow(sample_img[:, :, 0], cmap='gray')
    plt.subplot(132)
    plt.title('Testing Label')
    plt.imshow(np.argmax(sample_mask, axis=-1), cmap='jet')
    plt.subplot(133)
    plt.title('Prediction on test image')
    plt.imshow(predicted_mask, cmap='jet')
    plt.show()

# Call the plot_results function
# (passes the one-hot y_test_cat: plot_results argmaxes the label for display)
plot_results(model, X_test1, y_test_cat)