import matplotlib.pyplot as plt
import scipy
import numpy as np
import os
import typing

# Signals that can be plotted: key (file `<key>.mat`, holding a variable named
# `<key>`) -> human-readable legend label. The insertion order also determines
# the colour each signal gets in plot_data, so keep it stable across runs.
data_fields = {
    # Elevation
    "e":                "Encoder Elevation",
    "e_hat":            "Estimated Elevation", # Usually equal to either luenberger or kalman
    "e_hat_luenberger": "Luenberger Estimated Elevation",
    "e_hat_kalman":     "Kalman Estimated Elevation",
    "IMU_e":            "IMU Elevation",

    # Elevation rate
    "ed":                "Encoder Elevation rate",
    "ed_hat":            "Estimated Elevation rate",
    "ed_hat_luenberger": "Luenberger Estimated Elevation rate",
    "ed_hat_kalman":     "Kalman Estimated Elevation rate",
    "IMU_euler_ed":      "IMU Euler Elevation rate",

    # Pitch
    "p":                "Encoder Pitch",
    "p_hat":            "Estimated Pitch",
    "p_hat_luenberger": "Luenberger Estimated Pitch",
    "p_hat_kalman":     "Kalman Estimated Pitch",
    "IMU_p":            "IMU Pitch",

    # Pitch rate
    "pd":                "Encoder Pitch rate",
    "pd_hat":            "Estimated Pitch rate",
    "pd_hat_luenberger": "Luenberger Estimated Pitch rate",
    "pd_hat_kalman":     "Kalman Estimated Pitch rate",
    "IMU_euler_pd":      "IMU Euler Pitch rate",

    # Travel rate
    "td_hat":            "Estimated Travel rate",
    "IMU_euler_td":      "IMU Euler Travel rate",
    "td_hat_luenberger": "Luenberger Estimated Travel rate",
    "td_hat_kalman":     "Kalman Estimated Travel rate",

    # References
    "red": "Reference elevation rate",
    "rp":  "Reference pitch",

    # Motor voltages
    "vd":  "Difference of motor voltages",
    "vs":  "Sum of motor voltages",
}

# Colormap used to colour the plotted signals. Options:
# https://matplotlib.org/stable/users/explain/colors/colormaps.html#grayscale-conversion
COLMAP = "viridis"
# COLMAP = "tab20c"

# Root folder containing one sub-folder per subject (see subject_dict).
# Set the LINSYS_LAB4_FOLDER environment variable to use the script on another
# machine; otherwise the original hard-coded location is used.
FOLDER = os.environ.get(
    "LINSYS_LAB4_FOLDER",
    "/Users/theodor/Documents/NTNU/5-h2024/linsys/lab4",
)

# Utility to find .mat files in a test-run directory that have no entry in
# data_fields (i.e. logged signals that cannot be plotted with a label yet).
# Unused as of now
def check_forgotten_fields(
    directory: typing.Optional[str] = None,
    known_fields: typing.Optional[typing.Container[str]] = None,
) -> typing.List[str]:
    """Report ``.mat`` files in *directory* whose names are not known fields.

    A file ``<name>.mat`` counts as forgotten when ``<name>`` is not in
    *known_fields*. Each forgotten name is printed as
    ``Field forgotten: <name>``; if there are none, "All fields are included"
    is printed instead. The report is framed by blank lines.

    Args:
        directory: folder of one test run. Defaults to the Kalman-filter run
            ``2024-11-07 10-15-45`` under ``FOLDER``.
        known_fields: names considered known. Defaults to ``data_fields``.

    Returns:
        The forgotten field names, sorted alphabetically (empty if none).
    """
    if directory is None:
        # Alternative run: f"{FOLDER}/3_luenberger_observer_tests/2024-10-17 13-17-47"
        directory = f"{FOLDER}/4_kalman_filter_tests/2024-11-07 10-15-45"
    if known_fields is None:
        known_fields = data_fields

    # splitext (unlike split(".")[0]) keeps names that contain dots intact;
    # sorting makes the report independent of os.listdir's arbitrary order.
    forgotten = sorted(
        stem
        for stem, extension in map(os.path.splitext, os.listdir(directory))
        if extension == ".mat" and stem not in known_fields
    )

    print()
    for name in forgotten:
        print("Field forgotten:", name)
    if not forgotten:
        print("All fields are included")
    print()
    return forgotten

# Subjects: the experiment series that can be plotted. The literal values are
# the accepted values of plot_data's `subject` argument.
subject_types = typing.Literal["LQR", "LQRI", "Luenberger", "Kalman"]

# Subject -> its sub-folder of FOLDER. The keys are taken from subject_types
# itself (in declaration order) so the two definitions cannot drift apart;
# keep the folder names below in the same order as the literal values.
subject_dict = dict(
    zip(
        typing.get_args(subject_types),
        (
            "2_lqr_regulator_tests",
            "2_lqri_regulator_tests",
            "3_luenberger_observer_tests",
            "4_kalman_filter_tests",
        ),
    )
)

# # Assign a colour to each field
# colours = plt.get_cmap('viridis', len(data_fields))
# i = 0
# for key, val in data_fields.items():
#     data_fields[key] = (val, colours(i)) # type: ignore
#     i += 1

def plot_data(fields, subject: subject_types, test):
    global data_fields

    DT = 0.002

    # Get all .mat files from the specified directory
    subject_dir = subject_dict[subject]
    dir = f"{FOLDER}/{subject_dir}/{test}"
    files = os.listdir(dir)

    # Load data into dictionary
    data = {}
    for field in fields:
        filename = f"{field}.mat"
        if filename in files:
            scipy.io.loadmat(f"{dir}/{filename}", data)
        else:
            # print(f"Field '{field}' not found in subject: {test}")
            print(f"Skipping test: {test} since field '{field}' not found")
            return

    # Remove all fields that are not in the files
    fields = [field for field in fields if f"{field}.mat" in files]

    # Filter the data_fields dict to only include the fields that are in the files
    fields = [(key, val) for key, val in data_fields.items() if key in fields]
    # fields = {key: val for key, val in data_fields.items() if key in fields}

    # Generate colours
    N = len(fields)
    colours = plt.get_cmap(COLMAP, N)
    for i in range(len(fields)):
        key  = fields[i][0]
        name = fields[i][1]
        col  = colours(i)
        fields[i] = (key, (name, col)) # type: ignore
    
    # Read readme file
    readme = ""
    if ("README" in files):
        with open(f"{dir}/README", "r") as file:
            readme = file.read()
            # print("\nREADME:\n\r=======\n\r")
            # print(readme)
            # print()
    
    # Sort fields by variance
    fields = sorted(fields, key=lambda pair: np.var(data[pair[0]].flatten()), reverse=True)
    print(fields)

    # Plot data
    plt.figure()
    time = None
    for field in fields:
        field_key = field[0]
        field_name = field[1][0]
        field_colour = field[1][1]
        if field_key in data: # Should always be true
            y = data[field_key].flatten()
            if time is None:
                length = len(y)
                time = np.arange(0, length*DT, DT)
            
            plt.plot(time, y, label=field_name, color=field_colour)

    # Set plot title as value of Q matrix
    a = None
    with open(f"{dir}/init_heli_3_10.m", "r", encoding="latin-1") as file:
        for line in file:
            if "a =" in line:
                a = float(line.split("=")[1].strip().rstrip(';'))
                plt.title(f"Q = {a}")
                break
    if a is None:
        plt.title("Q = ?")

    # Set plot title as README content if available
    # import re
    # readme_content = readme.strip().replace('\n', '.')
    # readme_content = re.sub('\.\.+', ". ", readme_content) # Replace multiple dots with one
    # plt.title(readme_content)

    # Sort legends. We change the plotting order to preserve visibility, so we also want the legends to be the same across all plots
    # https://stackoverflow.com/a/46160465
    handles, labels = plt.gca().get_legend_handles_labels()
    labels, handles = zip(*sorted(zip(labels, handles), key=lambda t: t[0]))
    plt.legend(handles, labels, loc='upper left')

    plt.grid(alpha=0.3)
    # plt.show() # Commented since we want to show multiple plot windows, not just one. Call plt.show() after all plots are made

# Test-run folders (inside the Kalman subject folder) to plot; each gets its own figure.
tests = [
    "2024-10-24 08-59-33",
    "2024-10-24 09-16-40",
    "2024-10-24 12-36-08",
    "2024-10-24 12-41-15",
    "2024-10-24 12-45-25",
    "2024-10-24 12-45-43",
    "2024-11-07 09-39-37",
    "2024-11-07 09-43-34",
    "2024-11-07 09-49-43",
    "2024-11-07 09-53-26",
    "2024-11-07 10-04-11",
    "2024-11-07 10-15-45",
]

# Signals overlaid in every figure. Uncomment a line to include that group.
_PLOTTED_FIELDS = [
    # "e", "e_hat_luenberger", "e_hat_kalman", "IMU_e",
    # "ed", "ed_hat_luenberger", "ed_hat_kalman", "IMU_euler_ed",
    # "p", "p_hat_luenberger", "p_hat_kalman", "IMU_p",
    "pd", "pd_hat_luenberger", "pd_hat_kalman", "IMU_euler_pd",
    # "td_hat", "IMU_euler_td",
    "red",
    # "rp",
    # "vd", "vs",
]

for test in tests:
    plot_data(fields=_PLOTTED_FIELDS, subject="Kalman", test=test)

# plot_data only creates figures; show them all together at the end.
plt.show()
