import shutil
import matplotlib.pyplot as plt
import scipy
import numpy as np
import os
import typing

# Per-subject plot configuration.
#
# Each subject maps test-directory names to keyword arguments for plot_data().
# The special key "default" applies to every test of that subject; a
# test-specific entry (a dict, or a list of dicts for several figures) is
# merged on top of it, overriding individual default keys.
PLOT_MANIFEST = {
    "PD": {
        "default": {  # "default" applies to all tests of this subject
            # Recorded fields to plot; use a list of lists to get subplots.
            "fields": ["p", "e"],
            # The title is either a plain string, or a dict mapping variable
            # names looked up in the test's init file to display strings.
            "title": {
                "eig_1": "$\\lambda_1 = §$",  # "§" is replaced by the value
                "eig_2": "$\\lambda_2 = §$",
            },
            # Y-axis labels, one per subplot.
            "ylabel": ["pitch [rad]", "elevation rate [rad/s]"],
        },
    },
    "LQR": {
        "default": {
            "fields": [
                ["rp", "p"],
                ["red", "ed"],
            ],
            "title": {
                "Q": "Q = §",
            },
            "ylabel": ["pitch [rad]", "elevation rate [rad/s]"],
        },
    },
    "LQRI": {
        "default": {
            "fields": [
                ["rp", "p"],
                ["red", "ed"],
            ],
            "title": {
                "Q": "$Q = §$",
            },
            "ylabel": ["pitch [rad]", "elevation rate [rad/s]"],
        },
        # Example of a test-specific entry overriding the default manifest.
        "2024-09-26 09-41-51": {
            "title": "Eksempel på oveskrevet plot",
            "fields": [
                ["rp", "p"],
                ["pd"],
                ["red", "ed", "e"],
            ],
            "figsize": (6.4, 8),
            "ylabel": ["pitch [rad]", "pitch [rad/s]", "elevation [rad] / elevation rate [rad/s]"],
        },
    },
    "Kalman": {
        "default": {
            "fields": [
                ["p_hat", "p", "rp"],
                ["ed_hat", "ed", "red"],
            ],
            "title": {
                # A title entry may also be a callable receiving the raw value string.
                "a": lambda value: f"$Q_d = 10^{{{np.log10(float(value)):.0f}}} \\cdot I$",
            },
            "start_time": 1,
            "end_time": 11,
            "ylabel": ["pitch [rad]", "elevation rate [rad/s]"],
            "figsize": (6, 6),
            "color_offset": 1,
        },
        # Alternative "default" configurations kept for reference; swap one in
        # for the active "default" above to use it.
        # "default": {
        #     "fields": [
        #         ["IMU_p", "p_hat_kalman", "p"],
        #         # ["IMU_euler_pd", "pd_hat_kalman", "pd"],
        #         # ["IMU_e", "e_hat_kalman", "e"],
        #         ["IMU_euler_ed", "ed_hat_kalman", "ed"],
        #         # ["td_hat_kalman"],
        #     ],
        #     "title": {
        #         "a": lambda value: f"$Q_d = 10^{{{np.log10(float(value)):.0f}}} \\cdot I$",
        #     },
        #     "ylabel": [
        #         "pitch [rad]",
        #         # "pitch rate [rad/s]",
        #         # "elevation [rad]",
        #         "elevation rate [rad/s]",
        #         # "travel rate [rad/s]"
        #     ],
        #     "start_time": 1,
        #     "end_time": 16,
        #     "figsize": (5, 6),
        # },
        # "default": {
        #     "fields": [
        #         ["rp", "p_hat_kalman", "p"],
        #         ["ed", "ed_hat_kalman", "ed"],
        #     ],
        #     "title": {
        #         "a": lambda value: f"$Q_d = 10^{{{np.log10(float(value)):.0f}}} \\cdot I$",
        #     },
        #     "ylabel": ["pitch [rad]", "elevation rate [rad/s]"],
        #     "start_time": 1,
        #     "end_time": 16,
        #     "figsize": (8, 12),
        # },
        # "default": {
        #     "fields": [
        #         "P_hat_diag0",
        #         "P_hat_diag1",
        #         "P_hat_diag2",
        #         "P_hat_diag3",
        #         "P_hat_diag4",
        #     ],
        #     "start_time": 10,
        #     "end_time": 25,
        #     "ylabel": "Variance [$rad^2$]/[$(rad/s)^2$]",
        #     "title": {
        #         "a": lambda value: f"Values on the Diagonal of $P$ for a Flight with Data Disruption",
        #     },
        # }
    },
}

# --- Plot behaviour -----------------------------------------------------------
AUTOSORT_FIELDS_BY_VARIANCE: typing.Final = False  # draw highest-variance series first
CLEAR_OLD_PLOTS: typing.Final = True               # delete FIGURE_DIR before plotting
FILE_TYPE: typing.Final = "svg"                    # extension of the saved figures
LEGEND_POS: typing.Final = None                    # passed as the legend ``loc``

# Colormap used to colour the series within each subplot; see
# https://matplotlib.org/stable/users/explain/colors/colormaps.html#grayscale-conversion
# (alternative: "viridis").
COLMAP: typing.Final = "rainbow"

# --- Directories --------------------------------------------------------------
DATA_DIR: typing.Final = "./data"       # recorded tests, one sub-directory per subject
FIGURE_DIR: typing.Final = "./figures"  # generated figures

# Time between two consecutive recorded samples, in seconds (500 Hz).
DT: typing.Final = 1 / 500

# Fields plotted when plot_data is called without explicit fields.
DEFAULT_FIELDS: typing.Final = ["p", "e"]

# Legend name for every recorded signal, keyed by its variable name in the
# test's .mat files (multi-column arrays are split into "<name><column>").
DATA_FIELDS = dict(
    # Elevation
    e="Encoder Elevation",
    e_hat="Estimated Elevation",  # usually equal to either the Luenberger or the Kalman estimate
    e_hat_luenberger="Luenberger Elevation",
    e_hat_kalman="Kalman Elevation",
    IMU_e="IMU Elevation",
    # Elevation rate
    ed="Encoder Elevation Rate",
    ed_hat="Estimated Elevation Rate",
    ed_hat_luenberger="Luenberger Elevation Rate",
    ed_hat_kalman="Kalman Elevation Rate",
    IMU_euler_ed="IMU Elevation Rate",
    # Pitch
    p="Encoder Pitch",
    p_hat="Estimated Pitch",
    p_hat_luenberger="Luenberger Pitch",
    p_hat_kalman="Kalman Pitch",
    IMU_p="IMU Pitch",
    # Pitch rate
    pd="Encoder Pitch Rate",
    pd_hat="Estimated Pitch Rate",
    pd_hat_luenberger="Luenberger Pitch Rate",
    pd_hat_kalman="Kalman Pitch Rate",
    IMU_euler_pd="IMU Pitch Rate",
    # Travel rate
    td_hat="Estimated Travel Rate",
    IMU_euler_td="IMU Travel Rate",
    td_hat_luenberger="Luenberger Travel Rate",
    td_hat_kalman="Kalman Travel Rate",
    # References
    red="Reference Elevation Rate",
    re="Reference Elevation",
    rp="Reference Pitch",
    # Motor voltages
    vd="Motor Voltages Difference",
    vs="Motor Voltages Sum",
    # Kalman covariance matrix (full diagonal and its split-out columns)
    P_hat_diag="Diagonal of $P$",
    P_hat_diag0="$P_{1,1}$ (Overlapped by $P_{3,3}$)",
    P_hat_diag1="$P_{2,2}$ (Overlapped by $P_{4,4}$)",
    P_hat_diag2="$P_{3,3}$",
    P_hat_diag3="$P_{4,4}$",
    P_hat_diag4="$P_{5,5}$",
)

# Utility that reports every .mat file in the given directory whose name is not a key of DATA_FIELDS.
# Currently unused.
# def check_forgotten_fields():
#     import os
#     print()
#     # dir = "/Users/theodor/Documents/NTNU/5-h2024/linsys/lab4/3_luenberger_observer_tests/2024-10-17 13-17-47"
#     dir = "./lab4/4_kalman_filter_tests/2024-11-07 10-15-45"
#     files = os.listdir(dir)
#     any_forgotten = False
#     for file in files:
#         if file.endswith(".mat"):
#             # Get filename
#             filename = file.split(".")[0]
#             # print(filename)

#             # If not in any field, print filename
#             if not filename in DATA_FIELDS:
#                 any_forgotten = True
#                 print("Field forgotten:", filename)

#     if not any_forgotten:
#         print("All fields are included")
#     print()

# Test subjects: each name maps to its sub-directory below DATA_DIR (figures
# are written to the same-named sub-directory below FIGURE_DIR).
subject_dict = dict(
    PD="1_pitch_regulator_tests_2",
    LQR="2_lqr_regulator_tests",
    LQRI="2_lqri_regulator_tests",
    Luenberger="3_luenberger_observer_tests",
    Kalman="4_kalman_filter_tests",
)

# Static type of the ``subject`` argument of plot_data; keep in sync with
# the keys of subject_dict.
subject_types = typing.Literal["PD", "LQR", "LQRI", "Luenberger", "Kalman"]

# # Assign a colour to each field
# colours = plt.get_cmap('viridis', len(DATA_FIELDS))
# i = 0
# for key, val in DATA_FIELDS.items():
#     DATA_FIELDS[key] = (val, colours(i)) # type: ignore
#     i += 1

def plot_data(
    subject: subject_types, 
    test: str, 
    fields: list[str] | None = None, 
    start_time: float = 0, 
    end_time: float | None = None, 
    title: str | None = None,
    xlabel: str = "time [s]",
    ylabel: str = "DENNE MÅ ENDRES",
    figsize: tuple[float, float] = (6.4, 4.8),
    color_offset: int = 0,
):
    if fields is None:
        fields = DEFAULT_FIELDS

    if len(fields) > 0 and not isinstance(fields[0], list):
        fields = [fields]

    if title is None:
        title = test

    if fields is None:
        fields = DEFAULT_FIELDS

    # Get all .mat files from the specified directory
    subject_dir = subject_dict[subject]
    dir = f"{DATA_DIR}/{subject_dir}/{test}"
    files = os.listdir(dir)

    # Load data into dictionary
    data = {}
    
    for file in files:
        if file.endswith(".mat"):
            scipy.io.loadmat(f"{dir}/{file}", data)

    data_extension = {}
    for field_key, field in data.items():
        if not np.shape(field) or np.shape(field)[-1] == 1:
            continue

        for i in range(np.shape(field)[-1]):
            data_extension[f"{field_key}{i}"] = field[..., i]

    data = {**data, **data_extension}

    subplot_count = len(fields)

    plt.subplots(subplot_count, 1, figsize=figsize, sharex=True)

    for i, fields in enumerate(fields):
        plt.subplot(subplot_count, 1, i+1)

        for field in fields:
            if field not in data.keys():
                print(f"Skipping test: {test} since field '{field}' not found")
                return

        # Remove all fields that are not in the files
        # fields = [field for field in fields if f"{field}.mat" in files]
        # Filter the DATA_FIELDS dict to only include the fields that are in the files
        fields = [(field, DATA_FIELDS[field]) for field in fields]
        # fields = {key: val for key, val in DATA_FIELDS.items() if key in fields}

        # Generate colours
        colours = plt.get_cmap(COLMAP, len(fields)+1+color_offset)
        for j in range(len(fields)):
            key  = fields[j][0]
            name = fields[j][1]
            col  = colours(j + color_offset)
            fields[j] = (key, (name, col)) # type: ignore
        
        # Read readme file
        # readme = ""
        # if ("README" in files):
        #     with open(f"{dir}/README", "r") as file:
        #         readme = file.read()
        #         # print("\nREADME:\n\r=======\n\r")
        #         # print(readme)
        #         # print()
        
        # Sort fields by variance
        if AUTOSORT_FIELDS_BY_VARIANCE:
            fields = sorted(fields, key=lambda pair: np.var(data[pair[0]].flatten()), reverse=True)
        # print(fields)

        start_index = int(np.round(start_time / DT))
        end_index = int(np.round(end_time / DT)) if end_time is not None else -1

        for field in fields:
            # print(field)
            field_key = field[0]
            field_name = field[1][0]
            field_colour = field[1][1]
            # if field_key in data: # Should always be true
            y = data[field_key]
            length = len(y)
            time = np.linspace(0, length*DT, length)
    
            plt.plot(time[start_index:end_index], y[start_index:end_index], label=field_name, color=field_colour)

        plt.grid(alpha=0.3)

        # Sort legends. We change the plotting order to preserve visibility, so we also want the legends to be the same across all plots
        # https://stackoverflow.com/a/46160465
        handles, labels = plt.gca().get_legend_handles_labels()
        labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
        plt.legend(handles, labels, loc=LEGEND_POS) # bbox_to_anchor=(0, -0.1),
        
        if isinstance(xlabel, list):
            plt.xlabel(xlabel[i])
        plt.ylabel(ylabel[i] if isinstance(ylabel, list) else ylabel)

    if isinstance(xlabel, str):
        plt.xlabel(xlabel)

    if isinstance(title, dict):
        # Set plot title as value of Q matrix
        titles = []

        with open(f"{dir}/init_heli_3_10.m", "r", encoding="latin-1") as file:
            for line in file:
                for symbol, display_str in title.items():
                    search_str = f"{symbol} ="

                    if search_str in line:
                        value = line.rsplit(search_str, 1)[-1]
                        value = value.strip().strip(";")
                        if isinstance(display_str, str):
                            disp = display_str.replace("§", value)
                        else:
                            disp = display_str(value)
                        titles.append(disp)
        
        title = ", ".join(titles)

    plt.suptitle(title, math_fontfamily='dejavuserif')

    plt.tight_layout()

    # Set plot title as README content if available
    # import re
    # readme_content = readme.strip().replace('\n', '.')
    # readme_content = re.sub('\.\.+', ". ", readme_content) # Replace multiple dots with one
    # plt.title(readme_content)

    folder_path = os.path.join(FIGURE_DIR, subject_dir)
    image_path = os.path.join(folder_path, f"{test}_{title.replace("\\", "").replace("$", "")}.{FILE_TYPE}")
    os.makedirs(folder_path, exist_ok=True)

    plt.savefig(image_path)

    plt.close()

# Script entry: optionally wipe FIGURE_DIR, then render every configured
# figure for every test directory of every subject.
if CLEAR_OLD_PLOTS:
    shutil.rmtree(FIGURE_DIR, ignore_errors=True)

for subject, subject_path in subject_dict.items():
    subject_root = os.path.join(DATA_DIR, subject_path)
    try:
        entries = os.listdir(subject_root)
    except FileNotFoundError:
        print(f"Skipping '{subject}' as data could not be found.")
        continue

    # Only sub-directories are tests. Stray files (e.g. macOS .DS_Store) would
    # make plot_data's os.listdir raise NotADirectoryError. Sorted for a
    # deterministic processing order.
    tests = sorted(entry for entry in entries if os.path.isdir(os.path.join(subject_root, entry)))

    subject_manifests = PLOT_MANIFEST.get(subject, {})
    default_manifest = subject_manifests.get("default", {})

    for test in tests:
        # A test may map to one manifest dict or to a list of them (one figure each).
        manifests = subject_manifests.get(test, {})
        if not isinstance(manifests, list):
            manifests = [manifests]

        for manifest in manifests:
            # Test-specific keys override the subject's default values.
            plot_data(
                subject,
                test,
                **{**default_manifest, **manifest},
            )