# -*- coding: utf-8 -*-
"""
Created on Mon Aug  5 17:22:37 2024

@author: Aus
"""

import os
import json
import nibabel as nib
import numpy as np

def read_nii_file(path):
    """
    Load a NIfTI file from disk.

    Parameters:
    path (str): Location of the .nii / .nii.gz file.

    Returns:
    tuple: (nibabel image object, voxel data as returned by ``get_fdata()``).
    """
    nifti_img = nib.load(path)
    voxels = nifti_img.get_fdata()
    return nifti_img, voxels

def get_orientation(img):
    """
    Return the axis-direction codes of a NIfTI image.

    The codes are computed by ``nib.aff2axcodes`` from the image affine and
    returned as a list (e.g. ['R', 'A', 'S']). A list, rather than the tuple
    nibabel produces, compares equal to the list loaded back from JSON.
    """
    axis_codes = nib.aff2axcodes(img.affine)
    return list(axis_codes)

def print_nii_info(img, img_data, label=False, *, verbose=False):
    """
    Print detailed information about the NIfTI image or label.

    Output is off by default so batch processing stays quiet. With
    ``verbose=False`` the function returns immediately and computes nothing.
    Previously it computed every statistic and then discarded it.

    Parameters:
    img: nibabel image whose filename, affine and header are reported.
    img_data (numpy.ndarray): voxel array belonging to ``img``.
    label (bool): treat the volume as a label map and also report its value
        range and how many slices contain a mask.
    verbose (bool): keyword-only; print the information only when True.
    """
    if not verbose:
        return

    data_type = "Label" if label else "Image"
    header = img.header

    print(f"{data_type} File path: {img.get_filename()}")
    print(f"{data_type} Shape: {img_data.shape}")
    print(f"{data_type} Affine: \n{img.affine}")
    print(f"{data_type} Dimension (dim): {header['dim']}")
    print(f"{data_type} Pixel Dimensions (pixdim): {header['pixdim']}")
    print(f"{data_type} Orientation: {get_orientation(img)}")

    if label:
        print(f"{data_type} Unique values: {np.unique(img_data)}")
        print(f"{data_type} Min value: {np.min(img_data)}")
        print(f"{data_type} Max value: {np.max(img_data)}")
        if img_data.ndim >= 3:
            # Slices are taken along the third axis. A slice "has a mask" when
            # any voxel in it (across all other axes) is > 0, so the count can
            # never exceed the number of slices.
            other_axes = tuple(ax for ax in range(img_data.ndim) if ax != 2)
            slices_with_masks = int(np.sum(np.any(img_data > 0, axis=other_axes)))
            print(f"{data_type} Number of slices: {img_data.shape[2]}")
            print(f"{data_type} Number of slices with masks: {slices_with_masks}")
        else:
            # Indexing shape[2] would raise IndexError for 1-D/2-D data.
            print(f"{data_type} has fewer than 3 dimensions; slice counts skipped")

    print("\n")

def save_orientation_to_json(orientation, json_path):
    """
    Write ``orientation`` to ``json_path`` as ``{"orientation": ...}``.

    Any existing file at ``json_path`` is overwritten.
    """
    payload = {'orientation': orientation}
    with open(json_path, mode='w') as handle:
        json.dump(payload, handle)

def load_orientation_from_json(json_path):
    """
    Read back the value stored under the ``'orientation'`` key of a JSON file.

    Raises KeyError if the file has no ``'orientation'`` entry.
    """
    with open(json_path, mode='r') as handle:
        contents = json.load(handle)
    return contents['orientation']

def _list_nifti_files(directory):
    """Return the sorted names of visible .nii / .nii.gz files in ``directory``."""
    # Skip hidden entries: archives created on macOS can contain
    # "._<name>.nii.gz" AppleDouble stubs that match the extension but are
    # not loadable NIfTI volumes.
    return sorted(
        name
        for name in os.listdir(directory)
        if name.endswith((".nii", ".nii.gz")) and not name.startswith(".")
    )


def process_data(images_dir, labels_dir, json_path):
    """
    Check image/label orientations against a reference and count labelled slices.

    Image and label files are paired by position after sorting each
    directory's NIfTI filenames, so both directories should use matching
    names. The orientation of the first image is written to ``json_path``
    and used as the reference for every file that follows.

    Parameters:
    images_dir (str): directory holding the image volumes.
    labels_dir (str): directory holding the label volumes.
    json_path (str): JSON file that receives the reference orientation.
    """
    image_files = _list_nifti_files(images_dir)
    label_files = _list_nifti_files(labels_dir)

    # zip() silently drops unpaired files, so make a count mismatch visible.
    n_pairs = min(len(image_files), len(label_files))
    if len(image_files) != len(label_files):
        print(
            f"Warning: found {len(image_files)} image file(s) but "
            f"{len(label_files)} label file(s); only the first {n_pairs} "
            f"sorted pair(s) will be checked."
        )

    pairs = list(zip(image_files, label_files))
    if not pairs:
        print("No image/label pairs found.")
        return

    saved_orientation = None
    all_same_orientation = True
    total_slices_with_masks = 0
    total_files = len(pairs)

    for idx, (image_file, label_file) in enumerate(pairs, start=1):
        image_path = os.path.join(images_dir, image_file)
        label_path = os.path.join(labels_dir, label_file)

        print(f"Processing {image_file} and {label_file}")

        img, img_data = read_nii_file(image_path)
        lbl, lbl_data = read_nii_file(label_path)

        image_orientation = get_orientation(img)
        label_orientation = get_orientation(lbl)

        # The first image defines the reference orientation. Persist it once
        # and keep it in memory instead of re-reading the JSON every iteration.
        if saved_orientation is None:
            save_orientation_to_json(image_orientation, json_path)
            saved_orientation = image_orientation
            print("Orientation saved to JSON")

        print_nii_info(img, img_data)
        print_nii_info(lbl, lbl_data, label=True)

        image_matches = image_orientation == saved_orientation
        label_matches = label_orientation == saved_orientation
        if not image_matches:
            print(f"Orientation error: Image file {image_file} has a different orientation: {image_orientation}")
        if not label_matches:
            print(f"Orientation error: Label file {label_file} has a different orientation: {label_orientation}")
        if image_matches and label_matches:
            print("Matching orientation")
        else:
            all_same_orientation = False

        # Count slices (along the third axis) that contain any voxel > 0.
        total_slices_with_masks += int(np.sum(np.any(lbl_data > 0, axis=(0, 1))))

        print(f"{total_files - idx} files remaining")
        print("-" * 30)

    if all_same_orientation:
        print("All images and labels have the same orientation.")

    print(f"Total number of slices with masks: {total_slices_with_masks}")


# Paths to the directories and JSON file
labels_dir = r"F:/Task03_Liver/labelsTr"
images_dir = r"F:/Task03_Liver/imagesTr"
json_path = r"F:/Task03_Liver/orientation.json"

# Run the batch check only when this file is executed as a script, so that
# importing the module to reuse its helpers does not scan directories or
# write the orientation JSON.
if __name__ == "__main__":
    print("Processing images and labels...")
    process_data(images_dir, labels_dir, json_path)




###############################################################################################
# Testing
# Load a NIfTI file and change its orientation
import nibabel as nib
import numpy as np

def reorient_nifti_file(input_path, output_path, target_orientation):
    """
    Reorient a NIfTI file to the target orientation and save the result.

    The flip/transpose of the voxel array and the matching affine update are
    done by nibabel's ``as_reoriented``. This avoids building a float64 copy
    of the whole volume via ``get_fdata()``. For NIfTI images it also remaps
    the header's ``dim_info`` axes to follow the permutation. The output
    image keeps the class of the loaded image instead of always being
    forced to ``Nifti1Image``.

    Parameters:
    input_path (str): Path to the input NIfTI file.
    output_path (str): Path to save the reoriented NIfTI file.
    target_orientation (sequence of str): Target axis codes, e.g. ('R', 'S', 'A').

    Returns:
    The reoriented nibabel image (the object that was saved).
    """
    img = nib.load(input_path)

    current_orientation = nib.aff2axcodes(img.affine)
    print(f"Current orientation: {current_orientation}")

    # Orientation transform from the image's current axes to the target axes.
    transform = nib.orientations.ornt_transform(
        nib.orientations.io_orientation(img.affine),
        nib.orientations.axcodes2ornt(target_orientation),
    )

    # Equivalent to apply_orientation(data, transform) plus
    # affine @ inv_ornt_aff(transform, shape), but without get_fdata()'s
    # float64 copy and with header dim_info kept consistent.
    new_img = img.as_reoriented(transform)

    nib.save(new_img, output_path)
    print(f"Reoriented image saved to {output_path}")

    reoriented_orientation = nib.aff2axcodes(new_img.affine)
    print(f"Reoriented orientation: {reoriented_orientation}")
    return new_img

# Example usage (machine-specific paths; adjust before running)
input_path = 'E:/LIVER_TUMOR_DETECTION/IMAGES/liver_126.nii'
output_path = 'E:/LIVER_TUMOR_DETECTION/IMAGES/liver_126_reoriented.nii'
target_orientation = ('R', 'S', 'A')  # Desired orientation

# Only write the reoriented file when this module is executed directly,
# never as a side effect of importing it.
if __name__ == "__main__":
    reorient_nifti_file(input_path, output_path, target_orientation)
