import shutil
import matplotlib.pyplot as plt
import scipy
import numpy as np
import os
import typing

# Per-subject plot configuration, keyed by subject name (see subject_dict).
# Inside a subject, the key "default" is used for every test directory that has
# no entry of its own; an entry keyed by a test directory name overrides it.
# A manifest may also be a list of manifests (one figure per manifest).
# Every manifest is passed as keyword arguments to plot_data.
PLOT_MANIFEST = {
    "PD": {
        "default": { # "default" applies to every test that has no entry of its own
            "fields": ["p", "e"], # Measured signals to plot; for subplots, use a list of lists (one inner list per subplot)
            "title": { # The title can be a plain string, or variables looked up automatically in the init file ("§" is replaced by the value)
                "eig_1": "$\\lambda_1 = §$",
                "eig_2": "$\\lambda_2 = §$",
            },
            # NOTE(review): "fields" above is a flat list, so plot_data draws a single
            # subplot and only the first label below is used; also "e" is elevation,
            # not elevation rate - confirm the intended labels.
            "ylabel": ["pitch [rad]", "elevation rate [rad/s]"] # Y-axis labels, one per subplot
        },
    },
    "LQR": {
        "default": {
            # Top subplot: reference vs. encoder pitch; bottom: reference vs. encoder elevation rate
            "fields": [
                ["rp", "p"],
                ["red", "ed"],
            ],
            # Title built from the value assigned to Q in the init file
            "title": {
                "Q": "Q = §",
            },
            "ylabel": ["pitch [rad]", "elevation rate [rad/s]"]
        },
    },
    "LQRI": {
        "default": {
            # Top subplot: reference vs. encoder pitch; bottom: reference vs. encoder elevation rate
            "fields": [
                ["rp", "p"],
                ["red", "ed"],
            ],
            # Title built from the value assigned to Q in the init file
            "title": {
                "Q": "Q = §",
            },
            "ylabel": ["pitch [rad]", "elevation rate [rad/s]"]
        },
    }
}

# If True, the whole FIGURE_DIR is deleted before new figures are generated.
CLEAR_OLD_PLOTS = True

# Colormaps: https://matplotlib.org/stable/users/explain/colors/colormaps.html#grayscale-conversion
# COLMAP = "viridis"
COLMAP = "rainbow"  # matplotlib colormap used to colour the plotted fields

# Directories: logged .mat data is read from DATA_DIR, generated figures are written to FIGURE_DIR
DATA_DIR = "./data"
FIGURE_DIR = "./figures"

# Time between consecutive logged samples, in seconds. Used to build the time
# axis and to convert start/end times into sample indices.
# NOTE(review): assumes every logged signal uses this same sample period - confirm.
DT = 0.002

# Fields plotted when a manifest does not specify "fields"
DEFAULT_FIELDS = ["p", "e"]

# Known fields: maps a .mat file / variable name to its legend label.
# The ORDER of this dict matters: a field's plot colour is derived from its
# position here, so reordering entries changes the colours in every figure.
DATA_FIELDS = {
    # Elevation
    "e":                "Encoder Elevation",
    "e_hat":            "Estimated Elevation", # Usually equal to either luenberger or kalman
    "e_hat_luenberger": "Luenberger Estimated Elevation",
    "e_hat_kalman":     "Kalman Estimated Elevation",
    "IMU_e":            "IMU Elevation",
    
    # Elevation rate
    "ed":                "Encoder Elevation Rate",
    "ed_hat":            "Estimated Elevation Rate",
    "ed_hat_luenberger": "Luenberger Estimated Elevation Rate",
    "ed_hat_kalman":     "Kalman Estimated Elevation Rate",
    "IMU_euler_ed":      "IMU Euler Elevation Rate",

    # Pitch
    "p":                "Encoder Pitch",
    "p_hat":            "Estimated Pitch",
    "p_hat_luenberger": "Luenberger Estimated Pitch",
    "p_hat_kalman":     "Kalman Estimated Pitch",
    "IMU_p":            "IMU Pitch",

    # Pitch rate
    "pd":                "Encoder Pitch Rate",
    "pd_hat":            "Estimated Pitch Rate",
    "pd_hat_luenberger": "Luenberger Estimated Pitch Rate",
    "pd_hat_kalman":     "Kalman Estimated Pitch Rate",
    "IMU_euler_pd":      "IMU Euler Pitch Rate",

    # Travel rate
    "td_hat":            "Estimated Travel Rate",
    "IMU_euler_td":      "IMU Euler Travel Rate",
    "td_hat_luenberger": "Luenberger Estimated Travel Rate",
    "td_hat_kalman":     "Kalman Estimated Travel Rate",

    # References
    "red": "Reference Elevation Rate",
    "re":  "Reference Elevation",
    "rp":  "Reference Pitch",

    # Motor voltages
    "vd":  "Motor Voltages Difference",
    "vs":  "Motor Voltages Sum",
}

# Utility function to check if there exists any .mat file in the specified directory that is not included in the DATA_FIELDS list
# Unused as of now
# def check_forgotten_fields():
#     import os
#     print()
#     # dir = "/Users/theodor/Documents/NTNU/5-h2024/linsys/lab4/3_luenberger_observer_tests/2024-10-17 13-17-47"
#     dir = "./lab4/4_kalman_filter_tests/2024-11-07 10-15-45"
#     files = os.listdir(dir)
#     any_forgotten = False
#     for file in files:
#         if file.endswith(".mat"):
#             # Get filename
#             filename = file.split(".")[0]
#             # print(filename)

#             # If not in any field, print filename
#             if not filename in DATA_FIELDS:
#                 any_forgotten = True
#                 print("Field forgotten:", filename)

#     if not any_forgotten:
#         print("All fields are included")
#     print()

# Subjects (experiment series).
# subject_types lists the valid `subject` values accepted by plot_data.
# subject_dict maps each subject to its sub-directory of DATA_DIR (figures go to
# the same sub-directory of FIGURE_DIR). Subjects are processed in this order by
# the script below. Keep the two definitions in sync.
subject_types = typing.Literal[
    "PD",
    "LQR",
    "LQRI",
    "Luenberger",
    "Kalman",
]
subject_dict = {
    "PD":         "1_pitch_regulator_tests_2",
    "LQR":        "2_lqr_regulator_tests",
    "LQRI":       "2_lqri_regulator_tests",
    "Luenberger": "3_luenberger_observer_tests",
    "Kalman":     "4_kalman_filter_tests",
}

# # Assign a colour to each field
# colours = plt.get_cmap('viridis', len(DATA_FIELDS))
# i = 0
# for key, val in DATA_FIELDS.items():
#     DATA_FIELDS[key] = (val, colours(i)) # type: ignore
#     i += 1

def plot_data(
    subject: subject_types, 
    test: str, 
    fields: list[str] | None = None, 
    start_time: float = 0, 
    end_time: float | None = None, 
    title: str | None = None,
    xlabel: str = "time[s]",
    ylabel: str = "DENNE MÅ ENDRES",
    figsize: tuple[float, float] = (6.4, 4.8)
):
    if fields is None:
        fields = DEFAULT_FIELDS

    if len(fields) > 0 and not isinstance(fields[0], list):
        fields = [fields]

    if title is None:
        title = test

    if fields is None:
        fields = DEFAULT_FIELDS

    # Get all .mat files from the specified directory
    subject_dir = subject_dict[subject]
    dir = f"{DATA_DIR}/{subject_dir}/{test}"
    files = os.listdir(dir)

    subplot_count = len(fields)

    plt.subplots(subplot_count, 1, figsize=figsize)

    for i, fields in enumerate(fields):
        plt.subplot(subplot_count, 1, i+1)

        # Load data into dictionary
        data = {}
        for field in fields:
            filename = f"{field}.mat"
            if filename in files:
                scipy.io.loadmat(f"{dir}/{filename}", data)
            else:
                print(f"Skipping test: {test} since field '{field}' not found")
                return

        # Remove all fields that are not in the files
        fields = [field for field in fields if f"{field}.mat" in files]
        # Filter the DATA_FIELDS dict to only include the fields that are in the files
        fields = [(field, DATA_FIELDS[field]) for field in fields]
        # fields = {key: val for key, val in DATA_FIELDS.items() if key in fields}

        # Generate colours
        colours = plt.get_cmap(COLMAP)
        for j in range(len(fields)):
            key  = fields[j][0]
            name = fields[j][1]
            col  = colours(int(list(DATA_FIELDS.keys()).index(key)/(len(list(DATA_FIELDS.keys())))*255))
            fields[j] = (key, (name, col)) # type: ignore
        
        # Read readme file
        # readme = ""
        # if ("README" in files):
        #     with open(f"{dir}/README", "r") as file:
        #         readme = file.read()
        #         # print("\nREADME:\n\r=======\n\r")
        #         # print(readme)
        #         # print()
        
        # Sort fields by variance
        # fields = sorted(fields, key=lambda pair: np.var(data[pair[0]].flatten()), reverse=True)
        # print(fields)

        start_index = int(np.round(start_time / DT))
        end_index = int(np.round(end_time / DT)) if end_time is not None else -1

        for field in fields:
            # print(field)
            field_key = field[0]
            field_name = field[1][0]
            field_colour = field[1][1]
            if field_key in data: # Should always be true
                y = data[field_key].flatten()
                length = len(y)
                time = np.linspace(0, length*DT, length)
                
                plt.plot(time[start_index:end_index], y[start_index:end_index], label=field_name, color=field_colour)

        plt.grid(alpha=0.3)
        
        # Sort legends. We change the plotting order to preserve visibility, so we also want the legends to be the same across all plots
        # https://stackoverflow.com/a/46160465
        handles, labels = plt.gca().get_legend_handles_labels()
        labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
        plt.legend(handles, labels, loc="upper left") # bbox_to_anchor=(0, -0.1),
        
        plt.xlabel(xlabel[i] if isinstance(xlabel, list) else xlabel)
        plt.ylabel(ylabel[i] if isinstance(ylabel, list) else ylabel)

    if isinstance(title, dict):
        # Set plot title as value of Q matrix
        titles = []

        with open(f"{dir}/init_heli_3_10.m", "r", encoding="latin-1") as file:
            for line in file:
                for symbol, display_str in title.items():
                    search_str = f"{symbol} ="

                    if search_str in line:
                        value = line.rsplit(search_str, 1)[-1]
                        value = value.strip().strip(";")
                        titles.append(display_str.replace("§", value))
        
        title = ", ".join(titles)

    plt.suptitle(title)

    # Set plot title as README content if available
    # import re
    # readme_content = readme.strip().replace('\n', '.')
    # readme_content = re.sub('\.\.+', ". ", readme_content) # Replace multiple dots with one
    # plt.title(readme_content)

    folder_path = os.path.join(FIGURE_DIR, subject_dir)
    image_path = os.path.join(folder_path, f"{title.replace("\\", "").replace("$", "")}.svg")
    os.makedirs(folder_path, exist_ok=True)

    plt.savefig(image_path)

    plt.close()

# Script driver: regenerate the figures for every subject and test.
# NOTE: this runs at import time as well, since it is not guarded by
# `if __name__ == "__main__":`.
if CLEAR_OLD_PLOTS:
    # Remove previously generated figures so stale plots do not linger.
    shutil.rmtree(FIGURE_DIR, ignore_errors=True)

for subject, subject_path in subject_dict.items():
    subject_root = os.path.join(DATA_DIR, subject_path)
    if not os.path.isdir(subject_root):
        print(f"Skipping subject: {subject} since directory '{subject_root}' not found")
        continue

    # Only sub-directories are tests; stray files such as .DS_Store would make
    # plot_data fail when it lists them. Sorted for a reproducible order.
    tests = sorted(
        entry
        for entry in os.listdir(subject_root)
        if os.path.isdir(os.path.join(subject_root, entry))
    )

    subject_manifest = PLOT_MANIFEST.get(subject, {})
    default_manifest = subject_manifest.get("default", {})

    for test in tests:
        # A test-specific entry overrides "default"; it may be a single
        # manifest or a list of manifests (one figure each).
        manifests = subject_manifest.get(test, default_manifest)

        if not isinstance(manifests, list):
            manifests = [manifests]

        for manifest in manifests:
            plot_data(
                subject,
                test,
                **manifest
            )