import matplotlib.pyplot as plt
import scipy
import numpy as np
import os
import typing

# Per-test overrides for plot_data, keyed first by subject (a key of
# subject_dict) and then by test folder name. Each innermost dict is unpacked
# as keyword arguments into plot_data (e.g. title, start_time, end_time);
# tests that are not listed are plotted with plot_data's defaults.
PLOT_MANIFEST = {
    "LQR": {
        "2024-09-12 13-28-26": {
            "title": "A cool plot",
            "end_time": 10,
        }
    }
}

# Matplotlib colormap used for the plotted fields; plot_data resamples it to
# one colour per plotted field.
# Colormaps: https://matplotlib.org/stable/users/explain/colors/colormaps.html#grayscale-conversion
COLMAP = "viridis"
# COLMAP = "tab20c"

# Root folder of the logged data; contains one sub-folder per subject
# (see subject_dict), each holding one folder per test.
FOLDER = "./data"

# Sampling period of the logged signals. Used to convert start_time/end_time
# into sample indices and to build the time axis of the plots.
# NOTE(review): presumably seconds (i.e. 500 Hz logging) — confirm.
DT = 0.002

# List of fields.
# Maps each signal key to the label shown in the plot legend. plot_data loads
# a key from "<key>.mat" in the test folder and then looks the key up in the
# loaded variables, so the variable inside each file is assumed to carry the
# same name as the file.
DATA_FIELDS = {
    # Elevation
    "e":                "Encoder Elevation",
    "e_hat":            "Estimated Elevation", # Usually equal to either luenberger or kalman
    "e_hat_luenberger": "Luenberger Estimated Elevation",
    "e_hat_kalman":     "Kalman Estimated Elevation",
    "IMU_e":            "IMU Elevation",

    # Elevation rate
    "ed":                "Encoder Elevation Rate",
    "ed_hat":            "Estimated Elevation Rate",
    "ed_hat_luenberger": "Luenberger Estimated Elevation Rate",
    "ed_hat_kalman":     "Kalman Estimated Elevation Rate",
    "IMU_euler_ed":      "IMU Euler Elevation Rate",

    # Pitch
    "p":                "Encoder Pitch",
    "p_hat":            "Estimated Pitch",
    "p_hat_luenberger": "Luenberger Estimated Pitch",
    "p_hat_kalman":     "Kalman Estimated Pitch",
    "IMU_p":            "IMU Pitch",

    # Pitch rate
    "pd":                "Encoder Pitch Rate",
    "pd_hat":            "Estimated Pitch Rate",
    "pd_hat_luenberger": "Luenberger Estimated Pitch Rate",
    "pd_hat_kalman":     "Kalman Estimated Pitch Rate",
    "IMU_euler_pd":      "IMU Euler Pitch Rate",

    # Travel rate
    "td_hat":            "Estimated Travel Rate",
    "IMU_euler_td":      "IMU Euler Travel Rate",
    "td_hat_luenberger": "Luenberger Estimated Travel Rate",
    "td_hat_kalman":     "Kalman Estimated Travel Rate",

    # References
    "red": "Reference Elevation Rate",
    "rp":  "Reference Pitch",

    # Motor voltages
    "vd":  "Motor Voltages Difference",
    "vs":  "Motor Voltages Sum",
}

# Utility function that reports every .mat file in the specified directory whose name is missing from DATA_FIELDS.
# Currently unused.
# def check_forgotten_fields():
#     import os
#     print()
#     # dir = "/Users/theodor/Documents/NTNU/5-h2024/linsys/lab4/3_luenberger_observer_tests/2024-10-17 13-17-47"
#     dir = "./lab4/4_kalman_filter_tests/2024-11-07 10-15-45"
#     files = os.listdir(dir)
#     any_forgotten = False
#     for file in files:
#         if file.endswith(".mat"):
#             # Get filename
#             filename = file.split(".")[0]
#             # print(filename)

#             # If not in any field, print filename
#             if not filename in DATA_FIELDS:
#                 any_forgotten = True
#                 print("Field forgotten:", filename)

#     if not any_forgotten:
#         print("All fields are included")
#     print()

# Subjects: the experiment series that can be plotted.
# subject_types lists the valid subject names for type checking;
# subject_dict maps each name to its folder inside FOLDER. Keep the two in
# sync when adding a subject.
subject_types = typing.Literal[
    "PD",
    "LQR",
    "LQRI",
    "Luenberger",
    "Kalman",
]
subject_dict = {
    "PD":         "1_pitch_regulator_tests_2",
    "LQR":        "2_lqr_regulator_tests",
    "LQRI":       "2_lqri_regulator_tests",
    "Luenberger": "3_luenberger_observer_tests",
    "Kalman":     "4_kalman_filter_tests",
}

# # Assign a colour to each field
# colours = plt.get_cmap('viridis', len(DATA_FIELDS))
# i = 0
# for key, val in DATA_FIELDS.items():
#     DATA_FIELDS[key] = (val, colours(i)) # type: ignore
#     i += 1

def plot_data(fields: list[str], subject: subject_types, test: str, start_time: float = 0, end_time: float | None = None, title: str | None = None):
    if title is None:
        title = test

    # Get all .mat files from the specified directory
    subject_dir = subject_dict[subject]
    dir = f"{FOLDER}/{subject_dir}/{test}"
    files = os.listdir(dir)

    # Load data into dictionary
    data = {}
    for field in fields:
        filename = f"{field}.mat"
        if filename in files:
            scipy.io.loadmat(f"{dir}/{filename}", data)
        else:
            print(f"Skipping test: {test} since field '{field}' not found")
            return

    # Remove all fields that are not in the files
    fields = [field for field in fields if f"{field}.mat" in files]
    # Filter the DATA_FIELDS dict to only include the fields that are in the files
    fields = [(field, DATA_FIELDS[field]) for field in fields]
    # fields = {key: val for key, val in DATA_FIELDS.items() if key in fields}

    # Generate colours
    N = len(fields)
    colours = plt.get_cmap(COLMAP, N)
    for i in range(len(fields)):
        key  = fields[i][0]
        name = fields[i][1]
        col  = colours(i)
        fields[i] = (key, (name, col)) # type: ignore
    
    # Read readme file
    # readme = ""
    # if ("README" in files):
    #     with open(f"{dir}/README", "r") as file:
    #         readme = file.read()
    #         # print("\nREADME:\n\r=======\n\r")
    #         # print(readme)
    #         # print()
    
    # Sort fields by variance
    # fields = sorted(fields, key=lambda pair: np.var(data[pair[0]].flatten()), reverse=True)
    # print(fields)

    start_index = int(np.round(start_time / DT))
    end_index = int(np.round(end_time / DT)) if end_time is not None else -1

    # Plot data
    plt.figure()
    for field in fields:
        # print(field)
        field_key = field[0]
        field_name = field[1][0]
        field_colour = field[1][1]
        if field_key in data: # Should always be true
            y = data[field_key].flatten()
            length = len(y)
            time = np.linspace(0, length*DT, length)
            
            plt.plot(time[start_index:end_index], y[start_index:end_index], label=field_name, color=field_colour)

    # Set plot title as value of Q matrix
    # a = None
    # with open(f"{dir}/init_heli_3_10.m", "r", encoding="latin-1") as file:
    #     for line in file:
    #         if "Q =" in line:
    #             a = line.rsplit("=")[-1].strip().rstrip(';')
    #             title = f"r = {a}"
    #             break

    # if a is None:
    #     title = "?"

    plt.title(title)

    # Set plot title as README content if available
    # import re
    # readme_content = readme.strip().replace('\n', '.')
    # readme_content = re.sub('\.\.+', ". ", readme_content) # Replace multiple dots with one
    # plt.title(readme_content)

    # Sort legends. We change the plotting order to preserve visibility, so we also want the legends to be the same across all plots
    # https://stackoverflow.com/a/46160465
    handles, labels = plt.gca().get_legend_handles_labels()
    labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
    plt.legend(handles, labels, loc='upper left')

    plt.grid(alpha=0.3)

    folder_path = os.path.join("./figures/", subject_dir)
    image_path = os.path.join(folder_path, f"{title}.svg")
    os.makedirs(folder_path, exist_ok=True)

    plt.savefig(image_path)

    plt.close()

# Plot every test of every subject. PLOT_MANIFEST may override plot_data's
# keyword arguments (title, start_time, end_time) for individual tests.
for subject, subject_path in subject_dict.items():
    subject_folder = os.path.join(FOLDER, subject_path)
    if not os.path.isdir(subject_folder):
        # Don't let one missing series abort plotting of all the others.
        print(f"Skipping subject: {subject} since folder '{subject_folder}' not found")
        continue

    # Only sub-directories are tests. Stray files (e.g. .DS_Store, README)
    # would otherwise make plot_data's os.listdir raise NotADirectoryError.
    # Sorted so tests are processed in a reproducible order.
    tests = sorted(
        entry
        for entry in os.listdir(subject_folder)
        if os.path.isdir(os.path.join(subject_folder, entry))
    )

    for test in tests:
        manifest = PLOT_MANIFEST.get(subject, {}).get(test, {})

        plot_data(
            [
                # "e", "e_hat_luenberger", "e_hat_kalman", "IMU_e",
                # "ed", "ed_hat_luenberger", "ed_hat_kalman", "IMU_euler_ed",
                # "p", "p_hat_luenberger", "p_hat_kalman", "IMU_p",
                # "pd", "pd_hat_luenberger", "pd_hat_kalman", "IMU_euler_pd",
                # "td_hat", "IMU_euler_td",
                "p",
                "rp",
                # "rp",
                # "vd", "vs",
            ],
            subject,
            test,
            **manifest
        )