import os
import cv2
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Model
from keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger, ReduceLROnPlateau
from keras.optimizers import Adam
from keras.utils import to_categorical
from sklearn.model_selection import train_test_split
import segmentation_models as sm
import tensorflow as tf
from tensorflow.keras.metrics import MeanIoU
import gc
import random
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # pin all work to the second GPU; must run before any CUDA initialization below
# Set TensorFlow configuration
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Allocate GPU memory on demand instead of reserving it all up front.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # set_memory_growth raises RuntimeError if GPUs were already initialized.
        print(e)

# Define constants
DATA_SAVE_DIR = r"D:\PROTOS\LIVER\numpy"    # directory holding preprocessed X_data.npy / y_data.npy
MODEL_SAVE_DIR = r"D:\PROTOS\LIVER\models"  # destination for checkpoints and CSV logs
LOG_FILE_PATH = os.path.join(MODEL_SAVE_DIR, 'training_logs.csv')
BATCH_SIZE = 16
EPOCHS = 1000          # effective run length is governed by EarlyStopping below
VALIDATION_SPLIT = 0.1  # Increased for more robust validation
TEST_SPLIT = 0.1
RANDOM_STATE = 42      # fixed seed so the train/val/test split is reproducible
LEARNING_RATE = 0.0001
BACKBONE = 'resnet101'
N_CLASSES = 3          # background + two foreground classes (class weights set at compile time)
ACTIVATION = 'softmax'

def load_nifti(file_path: str) -> np.ndarray:
    """
    Read a NIfTI file and return its voxel data.

    Args:
        file_path (str): Path to the NIfTI file.

    Returns:
        np.ndarray: Voxel data as a floating-point array (get_fdata()).
    """
    return nib.load(file_path).get_fdata()

def preprocess_data(image: np.ndarray, label: np.ndarray, target_height: int, target_width: int) -> tuple:
    """
    Resize, normalize and channel-expand one image/label volume pair.

    Args:
        image (np.ndarray): Image data.
        label (np.ndarray): Label data.
        target_height (int): Target height.
        target_width (int): Target width.

    Returns:
        np.ndarray: Preprocessed image data (3-channel).
        np.ndarray: Preprocessed label data.
    """
    resized_image = resize_slices(image, target_height, target_width)
    resized_label = resize_slices(label, target_height, target_width)

    # Scale intensities to [0, 1] by the volume maximum.
    normalized = normalize_intensity(resized_image)

    # The backbone expects 3-channel input; tile the single grayscale channel.
    three_channel = np.repeat(normalized[..., np.newaxis], 3, axis=-1)

    return three_channel, resized_label

def resize_slices(data: np.ndarray, target_height: int, target_width: int,
                  interpolation: int = cv2.INTER_LINEAR) -> np.ndarray:
    """
    Resize every slice along the first axis of a 3-D volume.

    Args:
        data (np.ndarray): Volume of shape (num_slices, H, W), iterated slice by slice.
        target_height (int): Target height.
        target_width (int): Target width.
        interpolation (int): OpenCV interpolation flag. Defaults to
            cv2.INTER_LINEAR (cv2.resize's own default, so existing callers
            are unaffected). Pass cv2.INTER_NEAREST when resizing label masks
            so integer class ids are not blended into invalid values.

    Returns:
        np.ndarray: Resized volume of shape (num_slices, target_height, target_width).
    """
    # NOTE: cv2.resize takes its size argument as (width, height).
    # `slc` avoids shadowing the builtin `slice`.
    resized_data = [cv2.resize(slc, (target_width, target_height), interpolation=interpolation)
                    for slc in data]
    return np.array(resized_data)

def normalize_intensity(data: np.ndarray) -> np.ndarray:
    """
    Scale data by its maximum value so the result peaks at 1.0.

    Args:
        data (np.ndarray): Data to normalize.

    Returns:
        np.ndarray: data / max(data), or an all-zero float array when the
        maximum is 0 (the original unguarded division produced NaN/inf for
        all-zero volumes).
    """
    max_val = np.max(data)
    if max_val == 0:
        # All-zero slice/volume: dividing would emit NaNs; return zeros.
        return np.zeros_like(data, dtype=float)
    return data / max_val

class DataGenerator(tf.keras.utils.Sequence):
    """Minimal Keras Sequence serving (x, y) batches from in-memory arrays."""

    def __init__(self, x_set: np.ndarray, y_set: np.ndarray, batch_size: int):
        # Arrays are kept by reference; no copy is made.
        self.x = x_set
        self.y = y_set
        self.batch_size = batch_size

    def __len__(self) -> int:
        # Number of batches per epoch; a partial final batch counts as one.
        full, remainder = divmod(len(self.x), self.batch_size)
        return full + (1 if remainder else 0)

    def __getitem__(self, idx: int) -> tuple:
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.x[start:stop], self.y[start:stop]

def train_model(model: Model, train_gen: DataGenerator, val_gen: DataGenerator, epochs: int, callbacks: list) -> tf.keras.callbacks.History:
    """
    Train model, retrying with progressively smaller batch sizes on GPU OOM.

    Args:
        model (Model): Model to train.
        train_gen (DataGenerator): Training data generator.
        val_gen (DataGenerator): Validation data generator.
        epochs (int): Number of epochs.
        callbacks (list): List of callbacks.

    Returns:
        tf.keras.callbacks.History: Training history, or None when training
        fails even at batch size 1.
    """
    history = None
    try:
        history = model.fit(train_gen, epochs=epochs, validation_data=val_gen, callbacks=callbacks, steps_per_epoch=len(train_gen), validation_steps=len(val_gen))
    except tf.errors.ResourceExhaustedError:
        print("ResourceExhaustedError occurred. Trying with reduced batch size...")
        for bs in [8, 2, 1]:
            try:
                # Rebuild the generator from its own arrays rather than the
                # module-level globals x_train/y_train the original relied on,
                # so the retry works for whatever generator was passed in.
                train_gen = DataGenerator(train_gen.x, train_gen.y, bs)
                history = model.fit(train_gen, epochs=epochs, validation_data=val_gen, callbacks=callbacks, steps_per_epoch=len(train_gen), validation_steps=len(val_gen))
                break
            except tf.errors.ResourceExhaustedError:
                continue
        else:
            # for/else: no break means every batch size still exhausted memory.
            print("Unable to train even with reduced batch size. Exiting.")
            history = None
    return history

def evaluate_model(model: Model, x_test: np.ndarray, y_test: np.ndarray) -> tuple:
    """
    Evaluate the model on held-out data.

    Args:
        model (Model): Model to evaluate.
        x_test (np.ndarray): Test input data.
        y_test (np.ndarray): Integer test labels (one-hot encoded here).

    Returns:
        float: Test loss.
        float: Test accuracy.
    """
    y_test_onehot = to_categorical(y_test, num_classes=N_CLASSES)
    # evaluate() returns [loss, *metrics]; only loss and accuracy are reported.
    test_loss, test_accuracy, *_ = model.evaluate(x_test, y_test_onehot)
    return test_loss, test_accuracy

def calculate_iou(model: Model, x_test: np.ndarray, y_test: np.ndarray) -> np.ndarray:
    """
    Calculate IoU for each class.

    Args:
        model (Model): Model to evaluate.
        x_test (np.ndarray): Test input data.
        y_test (np.ndarray): Test label data (class ids in channel 0).

    Returns:
        np.ndarray: IoU for each class (0.0 when a class is absent from both
        prediction and ground truth).
    """
    y_pred = np.argmax(model.predict(x_test), axis=3)
    y_true = y_test[:, :, :, 0]
    iou_scores = []
    for i in range(N_CLASSES):
        # Compute binary IoU for class i directly. The previous version fed
        # boolean masks into MeanIoU(num_classes=N_CLASSES), which averages
        # the IoU of foreground AND background instead of reporting the IoU
        # of class i itself.
        true_i = (y_true == i)
        pred_i = (y_pred == i)
        intersection = np.logical_and(true_i, pred_i).sum()
        union = np.logical_or(true_i, pred_i).sum()
        iou_scores.append(intersection / union if union > 0 else 0.0)
    return np.array(iou_scores)

def plot_results(model: Model, x_test: np.ndarray, y_test: np.ndarray) -> None:
    """
    Plot a random test image alongside its ground truth and prediction.

    Args:
        model (Model): Model to evaluate.
        x_test (np.ndarray): Test input data.
        y_test (np.ndarray): Test label data.
    """
    # randrange excludes the upper bound; the original randint(0, len(x_test))
    # could return len(x_test) and raise IndexError.
    test_img_number = random.randrange(len(x_test))
    test_img = x_test[test_img_number]
    ground_truth = y_test[test_img_number]
    # Add a batch dimension for predict().
    test_img_input = np.expand_dims(test_img, 0)
    prediction = model.predict(test_img_input)
    predicted_img = np.argmax(prediction, axis=3)[0, :, :]

    plt.figure(figsize=(12, 8))
    plt.subplot(231)
    plt.title('Testing Image')
    plt.imshow(test_img[:, :, 0], cmap='gray')
    plt.subplot(232)
    plt.title('Testing Label')
    plt.imshow(ground_truth[:, :, 0], cmap='jet')
    plt.subplot(233)
    plt.title('Prediction on test image')
    plt.imshow(predicted_img, cmap='jet')
    plt.show()

# Load data (preprocessed arrays produced elsewhere in the pipeline)
X_data = np.load(os.path.join(DATA_SAVE_DIR, 'X_data.npy'))
y_data = np.load(os.path.join(DATA_SAVE_DIR, 'y_data.npy'))

# Split off the test set first, then carve validation out of the remainder.
x_train_val, x_test, y_train_val, y_test = train_test_split(X_data, y_data, test_size=TEST_SPLIT, random_state=RANDOM_STATE)
x_train, x_val, y_train, y_val = train_test_split(x_train_val, y_train_val, test_size=VALIDATION_SPLIT, random_state=RANDOM_STATE)

# Delete intermediate arrays to save memory
del X_data, y_data, x_train_val, y_train_val
gc.collect()

# Create data generators
train_gen = DataGenerator(x_train, y_train, BATCH_SIZE)
val_gen = DataGenerator(x_val, y_val, BATCH_SIZE)

# Define model; encoder_weights=None trains the backbone from scratch, so the
# input channel count can be taken directly from the data.
model = sm.Unet(BACKBONE, encoder_weights=None, input_shape=(None, None, x_train.shape[-1]), classes=N_CLASSES, activation=ACTIVATION)

# Compile with class-weighted Dice + focal loss (weights favor the two
# foreground classes over background).
model.compile(Adam(LEARNING_RATE), loss=sm.losses.DiceLoss(class_weights=np.array([0.1, 0.4, 0.5])) + sm.losses.CategoricalFocalLoss(), metrics=['accuracy', sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)])

# Define callbacks: checkpoint best model, stop after 100 stale epochs,
# log to CSV, and halve the LR after 50 stale epochs.
callbacks_list = [
    ModelCheckpoint(os.path.join(MODEL_SAVE_DIR, 'model_checkpoint_{epoch:02d}_{val_loss:.2f}.h5'), monitor='val_loss', verbose=1, save_best_only=True, mode='min', save_weights_only=False),
    EarlyStopping(monitor='val_loss', patience=100, verbose=1),
    CSVLogger(LOG_FILE_PATH, separator=',', append=False),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=50, verbose=1, min_lr=1e-6)
]

# Train model
history = train_model(model, train_gen, val_gen, EPOCHS, callbacks_list)

# Evaluate model only if training succeeded
if history is not None:
    test_loss, test_accuracy = evaluate_model(model, x_test, y_test)
    print(f'Test Loss: {test_loss:.4f}')
    print(f'Test Accuracy: {test_accuracy:.4f}')

    # Per-class IoU. Distinct loop variable avoids the original's
    # `for i, iou in enumerate(iou)` shadowing of the scores array.
    iou_scores = calculate_iou(model, x_test, y_test)
    for i, score in enumerate(iou_scores):
        print(f'IoU for class {i}: {score:.4f}')

    # Plot results
    plot_results(model, x_test, y_test)

# Clear session to release GPU memory held by the graph
tf.keras.backend.clear_session()

# from vertebrae code: visualize one random test sample.
# randrange excludes the upper bound; the original randint(0, len(x_test))
# could return len(x_test) and raise IndexError. (random is already imported
# at the top of the file.)
test_img_number = random.randrange(len(x_test))
test_img = x_test[test_img_number]
ground_truth = y_test[test_img_number]
# Add a batch dimension for predict().
test_img_input = np.expand_dims(test_img, 0)
print(test_img_input.shape)
prediction = model.predict(test_img_input)
predicted_img = np.argmax(prediction, axis=3)[0, :, :]

plt.figure(figsize=(12, 8))
plt.subplot(231)
plt.title('Testing Image')
plt.imshow(test_img[:, :, 0], cmap='gray')
plt.subplot(232)
plt.title('Testing Label')
plt.imshow(ground_truth[:, :, 0], cmap='jet')
plt.subplot(233)
plt.title('Prediction on test image')
plt.imshow(predicted_img, cmap='jet')
plt.show()

# train performance test: visualize one random training sample to gauge fit.
# randrange excludes the upper bound; the original randint(0, len(x_train))
# could return len(x_train) and raise IndexError. (random is already imported
# at the top of the file.)
test_img_number = random.randrange(len(x_train))
test_img = x_train[test_img_number]
ground_truth = y_train[test_img_number]
# Add a batch dimension for predict().
test_img_input = np.expand_dims(test_img, 0)
print(test_img_input.shape)
prediction = model.predict(test_img_input)
predicted_img = np.argmax(prediction, axis=3)[0, :, :]

plt.figure(figsize=(12, 8))
plt.subplot(231)
plt.title('Testing Image')
plt.imshow(test_img[:, :, 0], cmap='gray')
plt.subplot(232)
plt.title('Testing Label')
plt.imshow(ground_truth[:, :, 0], cmap='jet')
plt.subplot(233)
plt.title('Prediction on test image')
plt.imshow(predicted_img, cmap='jet')
plt.show()