

import os
import cv2
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Model
from keras.layers import Input, Conv2D, UpSampling2D, concatenate
from keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
from keras.optimizers import Adam
from keras.applications import ResNet50
# Expose only GPU index 1 to the framework; this must happen before any model
# is built so TensorFlow never enumerates the other devices.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

# Dataset layout: training labels, training images, and checkpoint output.
labels_dir, images_dir, model_save_dir = (
    r"F:/Task03_Liver/labelsTr",
    r"F:/Task03_Liver/imagesTr",
    r"F:/Task03_Liver/models",
)

# Function to list the paths of the .nii files in a directory
def load_nii_files(directory, extensions=('.nii',)):
    """Return the sorted paths of NIfTI files directly inside ``directory``.

    Args:
        directory: Folder to scan (not recursive).
        extensions: Filename suffix, or tuple of suffixes, to accept.
            Defaults to ``('.nii',)``; pass ``('.nii', '.nii.gz')`` to also
            pick up gzip-compressed volumes.

    Returns:
        A list of full paths sorted by filename. ``os.listdir`` returns
        entries in arbitrary order, so sorting is what allows two directories
        with matching filenames (e.g. images and labels) to be paired by
        position.
    """
    # A bare string would otherwise be turned into a tuple of single characters.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(extensions)
    nii_files = [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        # Hidden files (e.g. "._liver_0.nii" metadata forks) can share the
        # suffix but are not loadable volumes.
        if name.endswith(suffixes) and not name.startswith('.')
    ]
    return nii_files

# Load the nii files
# Gather the NIfTI file paths for the label directory and the image directory.
label_files, image_files = map(load_nii_files, (labels_dir, images_dir))

# Function to load a nii file
def load_nii(file_path):
    """Open a NIfTI file with nibabel and return its voxel data.

    The array comes from ``get_fdata()``, i.e. a floating-point numpy array.
    """
    return nib.load(file_path).get_fdata()

# Function to plot the image and corresponding label
def plot_image_and_label(image, label, slice_index):
    """Display one slice of an image volume beside the same slice of its label.

    Both volumes are indexed as ``volume[:, :, slice_index]``. The image is drawn
    in the left panel and the label in the right, in grayscale with axes hidden.
    """
    fig, (image_ax, label_ax) = plt.subplots(1, 2, figsize=(10, 5))

    panels = ((image_ax, image, 'Image'), (label_ax, label, 'Label'))
    for ax, volume, title in panels:
        ax.imshow(volume[:, :, slice_index], cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    plt.show()

# Define target shape
target_height = 128  # rows of every resized slice
target_width = 128  # columns of every resized slice

# Function to resize image slices
def resize_slices(image, target_height, target_width, interpolation=None):
    """Resize every slice along the last axis of a 3-D volume.

    Args:
        image: 3-D array laid out as (height, width, slices).
        target_height: Output height of each slice.
        target_width: Output width of each slice.
        interpolation: OpenCV interpolation flag passed to ``cv2.resize``.
            ``None`` (the default) means ``cv2.INTER_LINEAR``. Use
            ``cv2.INTER_NEAREST`` for label maps so no new class values are
            invented between neighbouring voxels.

    Returns:
        Array of shape (target_height, target_width, slices) with the same
        dtype as ``image``.

    Raises:
        ValueError: if ``image`` is not 3-D.
    """
    if image.ndim != 3:
        raise ValueError(
            f"expected a 3-D (height, width, slices) array, got shape {image.shape}"
        )
    # Resolved here rather than in the signature so the default is looked up at
    # call time.
    if interpolation is None:
        interpolation = cv2.INTER_LINEAR

    num_slices = image.shape[2]
    # Every slice is overwritten below, so no zero-initialisation is needed.
    resized_image = np.empty((target_height, target_width, num_slices), dtype=image.dtype)
    for z in range(num_slices):
        # cv2.resize takes the size as (width, height).
        resized_image[:, :, z] = cv2.resize(
            image[:, :, z], (target_width, target_height), interpolation=interpolation
        )
    return resized_image

# Function to normalize intensity values
def normalize_intensity(image):
    """Min-max scale ``image`` into the range [-1, 1].

    The smallest value maps to -1 and the largest to +1. A constant image
    (max == min) maps to all -1 rather than producing NaNs from 0 / 0.

    Args:
        image: Numeric numpy array of any shape.

    Returns:
        Floating-point array of the same shape.
    """
    lo = np.min(image)
    span = np.max(image) - lo
    # Without this guard a constant volume divides 0 by 0 and turns every voxel
    # into NaN, which silently poisons the training loss.
    if span == 0:
        span = 1
    image = (image - lo) / span  # Rescale to [0, 1]
    image = image * 2 - 1  # Rescale to [-1, 1]
    return image

# Function to preprocess image and label
def preprocess_data(image, label, target_height, target_width):
    """Resize, normalise and channel-expand one image volume and its label volume.

    Args:
        image: 3-D array (height, width, slices) of raw intensities.
        label: 3-D array with the same shape as ``image``. Any non-zero voxel is
            treated as foreground.
        target_height: In-plane height of every output slice.
        target_width: In-plane width of every output slice.

    Returns:
        ``(image, label)``:
        ``image`` is float32 with shape (target_height, target_width, slices, 3),
        scaled to [-1, 1], with the single channel replicated three times.
        ``label`` is a float32 {0, 1} mask with shape
        (target_height, target_width, slices).

    Raises:
        ValueError: if ``image`` and ``label`` have different shapes.
    """
    if np.shape(image) != np.shape(label):
        raise ValueError(
            f"image shape {np.shape(image)} does not match label shape {np.shape(label)}"
        )

    # float32 halves memory compared with float64, which matters because every
    # preprocessed slice of every volume is held in RAM for training.
    image = resize_slices(np.asarray(image, dtype=np.float32), target_height, target_width)
    image = normalize_intensity(image)
    # Replicate the single channel to create three channels
    image = np.repeat(image[..., np.newaxis], 3, axis=-1)

    # The network ends in a single sigmoid unit trained with binary
    # cross-entropy, so targets must lie in {0, 1]. Collapse every non-zero
    # label value into foreground before resizing; linearly interpolating a
    # multi-valued label map would invent intermediate class values.
    mask = (np.asarray(label) > 0).astype(np.float32)
    mask = resize_slices(mask, target_height, target_width)
    # Linear resizing blurs the mask edges; re-threshold so targets are binary.
    mask = (mask >= 0.5).astype(np.float32)

    return image, mask

# ResNet-based U-Net model definition
def resnet_unet_model(input_size):
    """Build and compile a U-Net-style segmentation model with a ResNet50 encoder.

    The encoder is ResNet50 (ImageNet weights, no classification head) applied
    to ``Input(input_size)``. The decoder repeatedly upsamples by 2. It
    concatenates the matching encoder activation and applies a 3x3 ReLU
    convolution. A final 2x upsampling and a 1x1 sigmoid convolution then
    produce one probability per pixel.

    Args:
        input_size: Input shape tuple, e.g. (128, 128, 3). Five 2x upsamplings
            undo the encoder's downsampling, so height and width presumably
            need to be multiples of 32 — TODO confirm for sizes other than 128.

    Returns:
        A Keras ``Model`` compiled with Adam, binary cross-entropy loss and the
        accuracy metric.

    NOTE(review): the ImageNet weights were trained with Keras' ResNet50
    preprocessing. Confirm that feeding inputs scaled to [-1, 1] is intended.
    """
    inputs = Input(input_size)
    encoder = ResNet50(weights='imagenet', include_top=False, input_tensor=inputs)

    # (encoder layer used as the skip connection, filters of the conv after the
    # concatenation), ordered from the deepest decoder stage to the shallowest.
    decoder_stages = (
        ('conv4_block6_2_relu', 512),
        ('conv3_block4_2_relu', 256),
        ('conv2_block3_2_relu', 128),
        ('conv1_relu', 64),
    )

    features = encoder.output
    for skip_name, filters in decoder_stages:
        features = UpSampling2D(size=(2, 2))(features)
        skip = encoder.get_layer(skip_name).output
        features = concatenate([features, skip], axis=-1)
        features = Conv2D(filters, (3, 3), activation='relu', padding='same')(features)
    # The last upsampling has no skip connection; it restores the input resolution.
    features = UpSampling2D(size=(2, 2))(features)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(features)
    model = Model(inputs=[inputs], outputs=[outputs])
    model.compile(optimizer=Adam(), loss='binary_crossentropy', metrics=['accuracy'])
    return model

# Custom callback that prints the epoch number and the expected checkpoint path
class CustomLoggingCallback(Callback):
    """Keras callback that reports each finished epoch and its checkpoint path.

    The path is built from ``model_save_dir`` and the pattern
    ``model_epoch_<NN>.h5``. It is printed unconditionally; the callback does
    not check that the file was actually written.
    """

    def on_epoch_end(self, epoch, logs=None):
        # ``epoch`` is 0-based here; both messages report it 1-based.
        epoch_number = epoch + 1
        checkpoint_path = os.path.join(model_save_dir, f'model_epoch_{epoch_number:02d}.h5')
        print(f'Epoch {epoch_number} finished.')
        print(f"Model saved: {checkpoint_path}")

# Training script: load every volume, build the model and train it
def _slices_from_pairs(pairs):
    """Load, preprocess and stack every slice of the given volumes.

    Args:
        pairs: Iterable of ``(image_path, label_path)`` tuples.

    Returns:
        ``(X, y)``: X has shape (n_slices, target_height, target_width, 3) and
        y has shape (n_slices, target_height, target_width, 1), both float32,
        with slices in file order and then slice-index order.
    """
    X_slices = []
    y_slices = []
    for img_file, lbl_file in pairs:
        print(f"Processing image file: {img_file}")
        print(f"Processing label file: {lbl_file}")

        image_data = load_nii(img_file)
        print(f"Loaded image data: {img_file}")

        label_data = load_nii(lbl_file)
        print(f"Loaded label data: {lbl_file}")

        image_data, label_data = preprocess_data(image_data, label_data, target_height, target_width)
        print(f"Preprocessed image and label data: {img_file}, {lbl_file}")

        # float32 halves the memory of the stacked dataset compared with float64.
        image_data = np.asarray(image_data, dtype=np.float32)
        label_data = np.asarray(label_data, dtype=np.float32)
        # Axis 2 indexes slices; move it to the front so each volume
        # contributes a batch of 2-D samples.
        X_slices.append(np.moveaxis(image_data, 2, 0))
        y_slices.append(np.moveaxis(label_data, 2, 0))
        print(f"Appended slices for: {img_file}, {lbl_file}")

    X = np.concatenate(X_slices).reshape(-1, target_height, target_width, 3)
    y = np.concatenate(y_slices).reshape(-1, target_height, target_width, 1)
    print("Converted lists to numpy arrays and reshaped them.")
    return X, y


if len(image_files) > 0 and len(label_files) > 0:
    from sklearn.model_selection import train_test_split

    # Directory listings come back in arbitrary order, so sort both lists
    # before zipping; otherwise an image can be paired with another scan's label.
    sorted_images = sorted(image_files)
    sorted_labels = sorted(label_files)
    if len(sorted_images) != len(sorted_labels):
        print(
            f"Warning: found {len(sorted_images)} image files but "
            f"{len(sorted_labels)} label files; unpaired files are ignored."
        )
    file_pairs = list(zip(sorted_images, sorted_labels))

    # Split by volume, not by slice. Neighbouring slices of one scan are nearly
    # identical, so a slice-level split puts the same patients in both sets.
    # That makes val_loss (which drives EarlyStopping) over-optimistic.
    if len(file_pairs) >= 2:
        train_pairs, val_pairs = train_test_split(file_pairs, test_size=0.2, random_state=42)
        X_train, y_train = _slices_from_pairs(train_pairs)
        X_val, y_val = _slices_from_pairs(val_pairs)
    else:
        # A single volume cannot be split by patient; fall back to its slices.
        X_data, y_data = _slices_from_pairs(file_pairs)
        X_train, X_val, y_train, y_val = train_test_split(X_data, y_data, test_size=0.2, random_state=42)
    print("Split data into training and validation sets.")

    # Create and compile the model
    model = resnet_unet_model(input_size=(target_height, target_width, 3))
    print("Created and compiled the model.")

    # Make sure the checkpoint directory exists before ModelCheckpoint writes to it.
    os.makedirs(model_save_dir, exist_ok=True)

    # Define callbacks
    checkpoint = ModelCheckpoint(os.path.join(model_save_dir, 'model_epoch_{epoch:02d}.h5'),
                                 save_best_only=False,
                                 save_weights_only=False,
                                 save_freq='epoch')
    early_stopping = EarlyStopping(monitor='val_loss', patience=5)
    custom_logging = CustomLoggingCallback()
    print("Defined callbacks.")

    # Train the model.
    # NOTE(review): batch_size=1 means ResNet50's batch-normalisation layers see
    # one sample per step; consider a larger batch if GPU memory allows.
    print("Started training the model.")
    history = model.fit(X_train, y_train,
                        validation_data=(X_val, y_val),
                        epochs=50,
                        batch_size=1,
                        callbacks=[checkpoint, early_stopping, custom_logging])
    print("Finished training the model.")
else:
    print("No nii files found in the specified directories.")