import os
import cv2
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Model
from keras.layers import Input, Conv2D, UpSampling2D, concatenate
from keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
from keras.optimizers import Adam
from keras.applications import ResNet50
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # pin all TF work to the second GPU
# Paths to the directories
labels_dir = r"F:/Task03_Liver/labelsTr"    # ground-truth segmentation volumes
images_dir = r"F:/Task03_Liver/imagesTr"    # CT image volumes
model_save_dir = r"D:\PROTOS\LIVER\models"  # per-epoch .h5 checkpoints are written here
data_save_dir = r"D:\PROTOS\LIVER\numpy"    # preprocessed X/y .npy arrays live here

# Function to list NIfTI files in a directory
def load_nii_files(directory):
    """Return the sorted full paths of all NIfTI files in *directory*.

    Matches both plain ``.nii`` and compressed ``.nii.gz`` files (the
    Medical Decathlon data is usually distributed gzipped). The result is
    sorted so the ordering is deterministic: the image and label lists are
    zipped together downstream, and ``os.listdir`` order is arbitrary, which
    could otherwise silently mismatch image/label pairs.
    """
    return sorted(
        os.path.join(directory, f)
        for f in os.listdir(directory)
        if f.endswith(('.nii', '.nii.gz'))
    )

# Load the nii files
# NOTE(review): the training loop below pairs these two lists positionally
# (zip) — verify that the nth label file really corresponds to the nth image
# file, e.g. by sorting both lists and checking filename alignment.
label_files = load_nii_files(labels_dir)
image_files = load_nii_files(images_dir)

# Function to load a nii file
def load_nii(file_path):
    """Load a NIfTI file and return its voxel data as a float numpy array."""
    return nib.load(file_path).get_fdata()

# Function to plot the image and corresponding label
def plot_image_and_label(image, label, slice_index):
    """Show one axial slice of a volume and its segmentation side by side."""
    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(10, 5))

    # Render both panels with the same settings; only the data/title differ.
    panels = ((left_ax, image, 'Image'), (right_ax, label, 'Label'))
    for ax, volume, title in panels:
        ax.imshow(volume[:, :, slice_index], cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    plt.show()

# Define target shape. 256 is divisible by 32, matching the five 2x
# upsampling steps of the decoder in resnet_unet_model below.
target_height, target_width = 256, 256

# Function to resize image slices
def resize_slices(image, target_height, target_width, interpolation=cv2.INTER_LINEAR):
    """Resize every axial slice of a (H, W, Z) volume in-plane.

    Parameters:
        image: 3-D array indexed as (height, width, slice).
        target_height, target_width: output in-plane size.
        interpolation: OpenCV interpolation flag. Defaults to bilinear
            (the original behavior); pass cv2.INTER_NEAREST when resizing
            label masks so integer class ids are not blended into
            fractional values.

    Returns:
        Array of shape (target_height, target_width, num_slices), same dtype.
    """
    original_height, original_width, num_slices = image.shape
    resized_image = np.zeros((target_height, target_width, num_slices), dtype=image.dtype)
    for z in range(num_slices):
        # cv2.resize expects (width, height) order, not (height, width).
        resized_image[:, :, z] = cv2.resize(
            image[:, :, z], (target_width, target_height), interpolation=interpolation
        )
    return resized_image

# Function to normalize intensity values
def normalize_intensity(image):
    """Linearly rescale intensities to the range [-1, 1].

    Guards against a constant volume (max == min), which previously caused a
    0/0 division and produced an all-NaN result; such volumes are mapped to
    the lower bound -1 instead.
    """
    lo = np.min(image)
    hi = np.max(image)
    if hi == lo:
        # Constant input: rescaling is undefined; treat it as uniformly minimal.
        return np.full_like(image, -1.0, dtype=float)
    image = (image - lo) / (hi - lo)  # Rescale to [0, 1]
    image = image * 2 - 1  # Rescale to [-1, 1]
    return image

# Function to preprocess image and label
def preprocess_data(image, label, target_height, target_width):
    """Resize, normalize, and channel-replicate one image/label volume pair.

    Parameters:
        image: (H, W, Z) intensity volume.
        label: (H, W, Z) integer segmentation mask.
        target_height, target_width: in-plane output size.

    Returns:
        (image, label) where image is (target_h, target_w, Z, 3) in [-1, 1]
        and label is (target_h, target_w, Z) with class ids preserved.
    """
    # Resize the image slices (bilinear is fine for intensities).
    image = resize_slices(image, target_height, target_width)

    # Resize the label with nearest-neighbor interpolation. The previous code
    # used bilinear resizing for the mask too, which blends integer class ids
    # (e.g. liver=1, tumor=2) into fractional values along boundaries.
    _, _, num_slices = label.shape
    resized_label = np.zeros((target_height, target_width, num_slices), dtype=label.dtype)
    for z in range(num_slices):
        resized_label[:, :, z] = cv2.resize(
            label[:, :, z], (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )
    label = resized_label

    # Normalize the image to [-1, 1]
    image = normalize_intensity(image)

    # Replicate the single channel to three channels so the volume matches
    # the 3-channel input expected by the ImageNet-pretrained encoder.
    image = np.repeat(image[..., np.newaxis], 3, axis=-1)

    return image, label

# ResNet-based U-Net model definition
def resnet_unet_model(input_size):
    """Build and compile a U-Net with an ImageNet-pretrained ResNet50 encoder.

    Parameters:
        input_size: (height, width, channels) tuple. Channels must be 3 for
            the ImageNet weights; height/width must be divisible by 32 so the
            five 2x upsampling steps below restore the input resolution.

    Returns:
        A compiled keras Model emitting a single-channel sigmoid mask with
        the same spatial size as the input.
    """
    inputs = Input(input_size)
    
    # Use ResNet50 as the encoder (include_top=False drops the classifier head).
    resnet_encoder = ResNet50(weights='imagenet', include_top=False, input_tensor=inputs)
    
    # Decoder: repeatedly upsample 2x and concatenate the encoder activation
    # at the matching spatial resolution (U-Net skip connections). The layer
    # names below are internal ResNet50 layer names from keras.applications.
    x = resnet_encoder.output
    x = UpSampling2D(size=(2, 2))(x)
    x = concatenate([x, resnet_encoder.get_layer('conv4_block6_2_relu').output], axis=-1)
    x = Conv2D(512, (3, 3), activation='relu', padding='same')(x)
    x = UpSampling2D(size=(2, 2))(x)
    x = concatenate([x, resnet_encoder.get_layer('conv3_block4_2_relu').output], axis=-1)
    x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)
    x = UpSampling2D(size=(2, 2))(x)
    x = concatenate([x, resnet_encoder.get_layer('conv2_block3_2_relu').output], axis=-1)
    x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = UpSampling2D(size=(2, 2))(x)
    x = concatenate([x, resnet_encoder.get_layer('conv1_relu').output], axis=-1)
    x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    # Final 2x upsample back to full input resolution (no skip available here).
    x = UpSampling2D(size=(2, 2))(x)
    
    # 1x1 conv + sigmoid: per-pixel foreground probability (binary mask).
    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    
    model = Model(inputs=[inputs], outputs=[outputs])
    
    # NOTE(review): accuracy is a weak metric for segmentation with heavy
    # class imbalance; consider adding Dice/IoU — confirm before relying on it.
    model.compile(optimizer=Adam(), loss='binary_crossentropy', metrics=['accuracy'])
    
    return model

# Custom callback to print epoch number and save status
class CustomLoggingCallback(Callback):
    """Prints epoch completion plus the checkpoint path ModelCheckpoint writes."""

    def on_epoch_end(self, epoch, logs=None):
        # Mirror the filename pattern used by the ModelCheckpoint callback.
        checkpoint_path = os.path.join(model_save_dir, f'model_epoch_{epoch + 1:02d}.h5')
        print(f'Epoch {epoch + 1} finished.')
        print(f"Model saved: {checkpoint_path}")

# # Prepare training and validation data
# X_data = []
# y_data = []

# for img_file, lbl_file in zip(image_files, label_files):
#     print(f"Processing image file: {img_file}")
#     print(f"Processing label file: {lbl_file}")
    
#     image_data = load_nii(img_file)
#     print(f"Loaded image data: {img_file}")
    
#     label_data = load_nii(lbl_file)
#     print(f"Loaded label data: {lbl_file}")
    
#     image_data, label_data = preprocess_data(image_data, label_data, target_height, target_width)
#     print(f"Preprocessed image and label data: {img_file}, {lbl_file}")
    
#     for z in range(image_data.shape[2]):
#         if np.sum(label_data[:, :, z]) > 0:  # Check if the slice has a corresponding mask
#             X_data.append(image_data[:, :, z])
#             y_data.append(label_data[:, :, z])
#     print(f"Appended slices for: {img_file}, {lbl_file}")

# X_data = np.array(X_data).reshape(-1, target_height, target_width, 3)
# y_data = np.array(y_data).reshape(-1, target_height, target_width, 1)
# print("Converted lists to numpy arrays and reshaped them.")

# # Save preprocessed data as numpy arrays
# np.save(os.path.join(data_save_dir, 'X_data.npy'), X_data)
# np.save(os.path.join(data_save_dir, 'y_data.npy'), y_data)
# print("Saved preprocessed data as numpy arrays.")

# Load the preprocessed slice arrays saved by the (commented-out) pipeline above.
# X_data: (N, H, W, 3) images in [-1, 1]; y_data: (N, H, W, 1) masks.
X_data = np.load(os.path.join(data_save_dir, 'X_data.npy'))
y_data = np.load(os.path.join(data_save_dir, 'y_data.npy'))
print("Loaded X_data and y_data from .npy files.")

# Split data into training and validation sets (fixed seed for reproducibility).
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(X_data, y_data, test_size=0.2, random_state=42)
print("Split data into training and validation sets.")

# Create and compile the model
model = resnet_unet_model(input_size=(target_height, target_width, 3))
print("Created and compiled the model.")

# Define callbacks: checkpoint every epoch (not only the best), stop after
# 20 epochs without val_loss improvement, and log saves via the custom callback.
checkpoint = ModelCheckpoint(os.path.join(model_save_dir, 'model_epoch_{epoch:02d}.h5'), 
                             save_best_only=False, 
                             save_weights_only=False, 
                             save_freq='epoch')
early_stopping = EarlyStopping(monitor='val_loss', patience=20)
custom_logging = CustomLoggingCallback()
print("Defined callbacks.")

# Train the model. The start message is printed BEFORE fit() — the original
# printed "Started training" after fit() returned, i.e. after training ended.
print("Started training the model.")
history = model.fit(X_train, y_train, 
                    validation_data=(X_val, y_val), 
                    epochs=1000, 
                    batch_size=8, 
                    callbacks=[checkpoint, early_stopping, custom_logging])
print("Finished training the model.")