# Working code: one training epoch takes roughly 1 hour.
import os
import random
import cv2
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate
from keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
from keras.optimizers import Adam

# Paths to the directories.
# NOTE(review): layout looks like the Medical Segmentation Decathlon
# Task03_Liver dataset — confirm these local paths exist before running.
labels_dir = r"E:\Task03_Liver\labelsTr"  # ground-truth segmentation volumes
images_dir = r"E:\Task03_Liver\imagesTr"  # CT image volumes
model_save_dir = r"E:\Task03_Liver\models"  # per-epoch checkpoint output

# Function to load nii files from a directory
def load_nii_files(directory):
    """Return a sorted list of NIfTI file paths found in *directory*.

    Sorting makes the order deterministic so that images and labels with
    matching filenames can safely be paired with ``zip()`` by the caller
    (``os.listdir`` returns entries in arbitrary order). Accepts both
    plain ``.nii`` and compressed ``.nii.gz`` files, which nibabel loads
    transparently.

    Parameters:
        directory (str): Directory to scan (not recursive).

    Returns:
        list[str]: Full paths to the NIfTI files, sorted by filename.
    """
    nii_files = [
        os.path.join(directory, f)
        for f in sorted(os.listdir(directory))
        if f.endswith(('.nii', '.nii.gz'))
    ]
    return nii_files

# Collect the NIfTI file paths at import time. The training loop below
# pairs image[i] with label[i], which relies on both directories
# containing matching filenames.
label_files = load_nii_files(labels_dir)
image_files = load_nii_files(images_dir)

# Function to load a nii file
def load_nii(file_path):
    """Load a NIfTI file and return its voxel data as a numpy array."""
    return nib.load(file_path).get_fdata()

# Function to plot the image and corresponding label
def plot_image_and_label(image, label, slice_index):
    """Display one axial slice of an image volume next to its label.

    Parameters:
        image: 3-D array, indexed as (H, W, slice).
        label: 3-D array with the same slicing convention.
        slice_index (int): Index of the axial slice to show.
    """
    fig, (img_ax, lbl_ax) = plt.subplots(1, 2, figsize=(10, 5))

    # Render both panels identically apart from data and title.
    for ax, volume, title in ((img_ax, image, 'Image'), (lbl_ax, label, 'Label')):
        ax.imshow(volume[:, :, slice_index], cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    plt.show()

# Define target shape: every axial slice is resized to 128x128 before
# training — presumably smaller than the native CT resolution, trading
# detail for speed and memory.
target_height, target_width = 128, 128

# Function to resize image slices
def resize_slices(image, target_height, target_width, interpolation=cv2.INTER_LINEAR):
    """Resize every axial slice of a 3-D volume to (target_height, target_width).

    Parameters:
        image: 3-D array indexed as (H, W, num_slices).
        target_height (int): Output slice height.
        target_width (int): Output slice width.
        interpolation: OpenCV interpolation flag. Defaults to
            ``cv2.INTER_LINEAR`` (same behavior as before). Pass
            ``cv2.INTER_NEAREST`` when resizing label/segmentation volumes
            so discrete class values are not blended into non-integer
            intermediates.

    Returns:
        3-D array of shape (target_height, target_width, num_slices) with
        the input's dtype.
    """
    original_height, original_width, num_slices = image.shape
    resized_image = np.zeros((target_height, target_width, num_slices), dtype=image.dtype)
    for z in range(num_slices):
        # cv2.resize takes its size argument as (width, height).
        resized_image[:, :, z] = cv2.resize(image[:, :, z], (target_width, target_height), interpolation=interpolation)
    return resized_image

# Function to normalize intensity values
def normalize_intensity(image):
    """Min-max normalize *image* to the range [-1, 1].

    Guards against division by zero when the image is constant
    (max == min); such an image maps to all zeros, the midpoint of the
    target range.

    Parameters:
        image: numeric numpy array of any shape.

    Returns:
        Float array of the same shape, with values in [-1, 1].
    """
    img_min = np.min(image)
    img_max = np.max(image)
    value_range = img_max - img_min
    if value_range == 0:
        # Constant input: the original expression would divide by zero.
        return np.zeros_like(image, dtype=np.float64)
    image = (image - img_min) / value_range  # Rescale to [0, 1]
    image = image * 2 - 1  # Rescale to [-1, 1]
    return image

# Function to preprocess image and label
def preprocess_data(image, label, target_height, target_width):
    """Resize and normalize one image/label volume pair for training.

    Fixes two defects in the previous version:
    - The image was min-max normalized to [-1, 1] and then *also* divided
      by 3071, squashing all inputs to roughly +/-3e-4. The division is
      redundant after min-max normalization and has been removed.
    - The label was linearly interpolated and returned as-is, producing
      non-integer targets (and raw class values up to 2 in the Task03
      data — liver/tumor) fed to a sigmoid/binary_crossentropy head. The
      label is now thresholded back to a binary {0, 1} foreground mask.

    Parameters:
        image: 3-D CT volume, indexed as (H, W, slice).
        label: 3-D segmentation volume with the same indexing.
        target_height (int): Output slice height.
        target_width (int): Output slice width.

    Returns:
        (image, label): image normalized to [-1, 1]; label binarized to
        float32 {0, 1}, both resized to (target_height, target_width, S).
    """
    # Resize the image and label slices.
    image = resize_slices(image, target_height, target_width)
    label = resize_slices(label, target_height, target_width)

    # Normalize the image intensities to [-1, 1].
    image = normalize_intensity(image)

    # Linear resizing blends label values; threshold to recover a binary
    # foreground mask (any class > 0 counts as foreground).
    label = (label > 0.5).astype(np.float32)

    return image, label

# U-Net model definition
def unet_model(input_size):
    """Build and compile a 2-D U-Net for binary segmentation.

    Classic 4-level U-Net: a contracting path of double 3x3 conv + max
    pool stages (64 -> 512 filters), a 1024-filter bottleneck, and an
    expanding path that upsamples and concatenates the matching skip
    connection at each level. Output is a single sigmoid channel,
    compiled with Adam and binary cross-entropy.

    Parameters:
        input_size: Input tensor shape, e.g. (H, W, 1).

    Returns:
        A compiled keras ``Model``.
    """
    def double_conv(tensor, filters):
        # Two stacked 3x3 relu convolutions, as in the original U-Net.
        tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return Conv2D(filters, 3, activation='relu', padding='same')(tensor)

    inputs = Input(input_size)

    # Contracting path: remember each pre-pool activation for the skips.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = double_conv(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    x = double_conv(x, 1024)

    # Expanding path: upsample, concatenate the matching skip, convolve.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = concatenate([UpSampling2D(size=(2, 2))(x), skip], axis=-1)
        x = double_conv(x, filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(x)

    model = Model(inputs=[inputs], outputs=[outputs])
    model.compile(optimizer=Adam(), loss='binary_crossentropy', metrics=['accuracy'])

    return model

# Custom callback to print epoch number and save status
class CustomLoggingCallback(Callback):
    """Prints epoch completion and the checkpoint path that the
    ModelCheckpoint callback is expected to have written this epoch."""

    def on_epoch_end(self, epoch, logs=None):
        checkpoint_path = os.path.join(model_save_dir, f'model_epoch_{epoch + 1:02d}.h5')
        print(f'Epoch {epoch + 1} finished.')
        print(f"Model saved: {checkpoint_path}")

# Example usage: build a 2-D slice dataset from the volumes and train the U-Net.
if len(image_files) > 0 and len(label_files) > 0:
    # ModelCheckpoint fails if the target directory does not exist;
    # create it up front.
    os.makedirs(model_save_dir, exist_ok=True)

    # Prepare training data: every axial slice becomes one 2-D sample.
    X_train = []
    y_train = []
    # Sort both lists so image N pairs with label N regardless of the
    # (arbitrary) order os.listdir returned them in.
    for img_file, lbl_file in zip(sorted(image_files), sorted(label_files)):
        image_data = load_nii(img_file)
        label_data = load_nii(lbl_file)
        image_data, label_data = preprocess_data(image_data, label_data, target_height, target_width)

        for z in range(image_data.shape[2]):
            X_train.append(image_data[:, :, z])
            y_train.append(label_data[:, :, z])

    # Stack into (num_slices, H, W, 1) arrays for Keras.
    # NOTE(review): this materializes every slice in memory at once —
    # consider a generator/Sequence if the dataset outgrows RAM.
    X_train = np.array(X_train).reshape(-1, target_height, target_width, 1)
    y_train = np.array(y_train).reshape(-1, target_height, target_width, 1)

    # Create and compile the model
    model = unet_model(input_size=(target_height, target_width, 1))

    # Save a checkpoint every epoch (not only on improvement).
    checkpoint = ModelCheckpoint(os.path.join(model_save_dir, 'model_epoch_{epoch:02d}.h5'),
                                 save_best_only=False,
                                 save_weights_only=False,
                                 save_freq='epoch')
    # Stop when validation loss has not improved for 5 epochs.
    early_stopping = EarlyStopping(monitor='val_loss', patience=5)
    custom_logging = CustomLoggingCallback()

    # Train on 80% of the slices; the last 20% form the validation split.
    history = model.fit(X_train, y_train,
                        validation_split=0.2,
                        epochs=50,
                        batch_size=8,
                        callbacks=[checkpoint, early_stopping, custom_logging])
else:
    print("No nii files found in the specified directories.")