#!/usr/bin/env python3

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

"""Magnetometer bias estimation node."""

import os
import time
import yaml

import numpy as np

import rospy
from std_msgs.msg import String
from sensor_msgs.msg import MagneticField
from std_srvs.srv import Trigger


class MagBiasObserver:
    """Estimate the magnetometer bias from the per-axis extremes seen while the robot rotates.

    The bias of each axis is the midpoint between the minimum and maximum reading collected during a
    time-limited measuring window. The result is published (latched) on ``imu/mag_bias``, stored on
    the parameter server as ``magnetometer_bias_{x,y,z}`` and optionally saved to a YAML file.

    A calibration run is started by calling the ``calibrate_magnetometer`` service (``std_srvs/Trigger``);
    the node then subscribes to ``imu/mag`` until the measuring window elapses.

    Private parameters read in the constructor:

    - ``~measuring_time`` (default 30): length of the measuring window in seconds.
    - ``~2d_mode`` (default True): only calibrate in a plane, leaving one axis with zero bias.
    - ``~2d_mode_ignore_axis`` (default None): "X", "Y" or "Z"; if not set, the axis is autodetected.
    - ``~load_from_file`` (default True) / ``~calibration_file_path``: load a stored calibration on startup.
    - ``~load_from_params`` (default False): load the calibration from ``magnetometer_bias_{x,y,z}``.
    - ``~save_to_file`` (default True): save a finished calibration to ``~calibration_file_path``.
    """

    _AXES = ("X", "Y", "Z")

    def __init__(self):
        """Read parameters, set up ROS interfaces and publish a stored calibration if one is available."""
        # Extremes start at +/- infinity so that any real reading replaces them.
        self._reset_extremes()

        self.x_mean = None
        self.y_mean = None
        self.z_mean = None

        self.last_mag = None
        self.msg = MagneticField()

        self.frame_id = "imu"
        self.started = False

        self.measuring_time = rospy.Duration(rospy.get_param('~measuring_time', 30))
        self.finish_measuring_time = rospy.Time(0)
        self.two_d_mode = bool(rospy.get_param('~2d_mode', True))  # only calibrate in a plane
        # If None, the axis will be autodetected; otherwise, specify either "X", "Y" or "Z"
        self.ignore_axis = rospy.get_param('~2d_mode_ignore_axis', None)

        # Loading from parameters is only possible if all three bias parameters exist.
        load_from_params = rospy.get_param("~load_from_params", False)
        load_from_params = load_from_params and rospy.has_param("magnetometer_bias_x")
        load_from_params = load_from_params and rospy.has_param("magnetometer_bias_y")
        load_from_params = load_from_params and rospy.has_param("magnetometer_bias_z")

        default_file = os.path.join(os.environ.get("HOME", ""), ".ros", "magnetometer_calib.yaml")
        self.calibration_file = rospy.get_param("~calibration_file_path", default_file)

        load_from_file = rospy.get_param("~load_from_file", True)
        load_from_file = load_from_file and len(self.calibration_file) > 0 and os.path.exists(self.calibration_file)

        self.save_to_file = rospy.get_param("~save_to_file", True)

        self.pub = rospy.Publisher('imu/mag_bias', MagneticField, latch=True, queue_size=1)
        self.speak_pub = rospy.Publisher('speak/warn', String, queue_size=1)

        self.sub = None  # will be initialized in subscribe()

        self.calibrate_server = rospy.Service("calibrate_magnetometer", Trigger, self.calibrate_service_cb)

        # The file has priority; parameters are used when the file is disabled, missing or unreadable.
        loaded = False
        if load_from_file:
            loaded = self._load_from_file()
            if loaded:
                time.sleep(0.01)  # give the publishers some time to get set up
                rospy.logwarn("Magnetometer calibration loaded from file {}.".format(self.calibration_file))
                # The values come from the file, so there is no point in writing them back.
                self.pub_msg(allow_save=False)
        if not loaded and load_from_params:
            self.x_mean = float(rospy.get_param("magnetometer_bias_x"))
            self.y_mean = float(rospy.get_param("magnetometer_bias_y"))
            self.z_mean = float(rospy.get_param("magnetometer_bias_z"))
            time.sleep(0.01)  # give the publishers some time to get set up
            rospy.logwarn("Magnetometer calibration loaded from parameters.")
            self.pub_msg()

    def _load_from_file(self):
        """Load the bias from `self.calibration_file`.

        Missing keys default to 0.0. The means are only changed when the whole file was parsed successfully.

        :return: True if the bias was loaded, False if the file could not be read or parsed (an error is logged).
        """
        try:
            with open(self.calibration_file, 'r') as f:
                data = yaml.safe_load(f)
            if not isinstance(data, dict):
                # E.g. an empty file yields None.
                raise ValueError("the file does not contain a YAML mapping")
            x_mean = float(data.get("magnetometer_bias_x", 0.0))
            y_mean = float(data.get("magnetometer_bias_y", 0.0))
            z_mean = float(data.get("magnetometer_bias_z", 0.0))
        except (OSError, ValueError, TypeError, yaml.YAMLError) as e:
            rospy.logerr("Could not load magnetometer calibration from file {}: {}".format(
                self.calibration_file, e))
            return False

        self.x_mean, self.y_mean, self.z_mean = x_mean, y_mean, z_mean
        return True

    def _reset_extremes(self):
        """Forget all previously observed per-axis minima and maxima."""
        self.x_min = self.y_min = self.z_min = float("inf")
        self.x_max = self.y_max = self.z_max = float("-inf")

    def _update_extremes(self, x, y, z):
        """Fold one magnetometer reading into the per-axis minima and maxima.

        The minimum and maximum are updated independently, so a single reading (or monotonically
        increasing readings) still yields a valid minimum. NaN components never compare smaller or
        larger than the stored value and are therefore ignored.
        """
        self.x_min = min(self.x_min, x)
        self.x_max = max(self.x_max, x)
        self.y_min = min(self.y_min, y)
        self.y_max = max(self.y_max, y)
        self.z_min = min(self.z_min, z)
        self.z_max = max(self.z_max, z)

    def _has_samples(self):
        """Return True if every axis has received at least one valid reading since the last reset."""
        return self.x_min <= self.x_max and self.y_min <= self.y_max and self.z_min <= self.z_max

    @staticmethod
    def _detect_free_axis(x_range, y_range, z_range):
        """Return the axis ("X", "Y" or "Z") whose range is below 75 % of both other ranges, or None."""
        if x_range < min(0.75 * y_range, 0.75 * z_range):
            return "X"
        if y_range < min(0.75 * x_range, 0.75 * z_range):
            return "Y"
        if z_range < min(0.75 * x_range, 0.75 * y_range):
            return "Z"
        return None

    def subscribe(self):
        """Start a new calibration run by subscribing to the magnetometer topic."""
        self.started = False
        # Extremes of a previous run must not leak into the new calibration.
        self._reset_extremes()
        self.sub = rospy.Subscriber('imu/mag', MagneticField, self.mag_callback)

    def calibrate_service_cb(self, _):
        """Handle the `calibrate_magnetometer` service call.

        :param _: The (unused) Trigger request.
        :return: Tuple (success, message); success is False if a calibration is already running.
        """
        if self.sub or rospy.Time.now() < self.finish_measuring_time:
            rospy.logwarn("Calibration in progress, ignoring another request")
            return False, "Calibration already running"

        self.subscribe()
        return True, "Calibration started, rotate the robot for the following %i seconds" % (self.measuring_time.secs,)

    def speak(self, message):
        """Publish `message` as a String on the `speak/warn` topic."""
        msg = String()
        msg.data = message
        self.speak_pub.publish(msg)

    def mag_callback(self, msg):
        """Process one magnetometer message of a running calibration.

        The first message received with a non-zero ROS time opens the measuring window. Messages inside
        the window update the per-axis extremes; the first message after the window finishes the
        calibration, unsubscribes and publishes the result.

        :param msg: The received MagneticField message.
        """
        if self.sub is None:
            # A message delivered after the subscriber was unregistered; the calibration is already finished.
            return

        self.frame_id = msg.header.frame_id
        self.last_mag = msg

        now = rospy.Time.now()

        if not self.started:
            if now.is_zero():
                # ROS time is not valid yet, so the end of the window could not be computed.
                return

            self.started = True
            self.finish_measuring_time = now + self.measuring_time

            log = "Started magnetometer calibration, rotate the robot several times in the following %i seconds." % \
                (self.measuring_time.secs,)
            rospy.logwarn(log)
            self.speak(log)

        if now < self.finish_measuring_time:
            mag = msg.magnetic_field
            self._update_extremes(mag.x, mag.y, mag.z)
            return

        # The measuring window has elapsed.
        self.sub.unregister()  # unsubscribe the mag messages
        self.sub = None

        if not self._has_samples():
            # Can happen e.g. with a non-positive ~measuring_time; there is nothing to compute the bias from.
            rospy.logerr("Magnetometer calibration failed, no valid samples were received.")
            return

        rospy.logwarn("Magnetometer calibrated")
        self.speak("Magnetometer calibrated")

        self.set_means()
        self.pub_msg()

    def set_means(self):
        """Compute the bias of each axis as the midpoint of its observed extremes.

        In 2D mode, the bias of one axis (given by `~2d_mode_ignore_axis` or autodetected as the axis with
        a clearly smaller range than the other two) is set to zero. Finally, the magnitude of the last
        received reading with the bias removed is logged as a sanity check.
        """
        self.x_mean = (self.x_min + self.x_max) / 2
        self.y_mean = (self.y_min + self.y_max) / 2
        self.z_mean = (self.z_min + self.z_max) / 2

        if self.two_d_mode:
            x_range = self.x_max - self.x_min
            y_range = self.y_max - self.y_min
            z_range = self.z_max - self.z_min

            rospy.loginfo("range %f %f %f" % (x_range, y_range, z_range))

            free_axis = None
            if self.ignore_axis is not None:
                requested = str(self.ignore_axis).strip().upper()
                if requested in self._AXES:
                    free_axis = requested
                else:
                    rospy.logwarn("Invalid value '%s' of parameter ~2d_mode_ignore_axis, autodetecting the axis." %
                                  (self.ignore_axis,))
            if free_axis is None:
                free_axis = self._detect_free_axis(x_range, y_range, z_range)

            if free_axis == "X":
                self.x_mean = 0.0
            elif free_axis == "Y":
                self.y_mean = 0.0
            elif free_axis == "Z":
                self.z_mean = 0.0

            if free_axis is not None:
                rospy.loginfo("Magnetometer calibration finished in 2D mode with %s axis uncalibrated." % (free_axis,))
            else:
                rospy.logwarn(
                    "Magnetometer calibration in 2D mode requested, but autodetection of the free axis failed. "
                    "Did you rotate the robot? Or did you do move with the robot in all 3 axes?")

        if self.last_mag is not None:
            magnitude = np.linalg.norm([
                self.last_mag.magnetic_field.x - self.x_mean,
                self.last_mag.magnetic_field.y - self.y_mean,
                self.last_mag.magnetic_field.z - self.z_mean])
            # Magnitudes of 1e-4 and more are reported as a warning.
            logfn = rospy.loginfo if magnitude < 1e-4 else rospy.logwarn
            logfn("Magnitude of the calibrated magnetic field is %f T." % (magnitude,))

    def pub_msg(self, allow_save=True):
        """Publish the current bias, store it on the parameter server and optionally save it to file.

        :param allow_save: If False, the calibration file is not written even if `~save_to_file` is set.
        """
        self.msg.header.stamp = rospy.Time.now()
        self.msg.header.frame_id = self.frame_id

        self.msg.magnetic_field.x = self.x_mean
        self.msg.magnetic_field.y = self.y_mean
        self.msg.magnetic_field.z = self.z_mean

        self.pub.publish(self.msg)

        rospy.set_param("magnetometer_bias_x", self.x_mean)
        rospy.set_param("magnetometer_bias_y", self.y_mean)
        rospy.set_param("magnetometer_bias_z", self.z_mean)

        if self.save_to_file and allow_save:
            self._save_to_file()

    def _save_to_file(self):
        """Write the current bias to `self.calibration_file`; failures are logged, never raised."""
        # A bare file name has an empty dirname, which means the current working directory.
        directory = os.path.dirname(self.calibration_file) or os.curdir
        if not os.path.isdir(directory):
            try:
                os.makedirs(directory)
            except OSError:
                rospy.logerr("Could not create folder for storing calibration " + directory)

        if not os.path.isdir(directory):
            rospy.logerr("Could not store magnetometer calibration to file {}. Cannot create the file.".format(
                self.calibration_file))
            return

        data = {
            "magnetometer_bias_x": self.x_mean,
            "magnetometer_bias_y": self.y_mean,
            "magnetometer_bias_z": self.z_mean,
        }
        try:
            with open(self.calibration_file, 'w') as f:
                yaml.safe_dump(data, f)
        except Exception as e:  # saving is best-effort and must not crash the node
            rospy.logerr("An error occured while saving magnetometer calibration to file " +
                         self.calibration_file + ": " + str(e))
        else:
            rospy.logwarn("Saved magnetometer calibration to file " + self.calibration_file)

    def run(self):
        """Block and process callbacks until the node is shut down."""
        rospy.spin()


def main():
    """Initialize the ROS node and run the magnetometer bias observer until shutdown."""
    rospy.init_node('magnetometer_bias_observer')
    MagBiasObserver().run()


# Start the node only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
