#!/usr/bin/env python3

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

"""
Visualization node for sky views.

This node visualizes a sky view received on either of two topics:
- `sky_view` (type `gnss_info_msgs/SkyView`)
- `fix` (type `gps_common/GPSFix`)

The SkyView topic should also be accompanied by topic `satellites` (`gnss_info_msgs/SatellitesList`) to decode
which satellite belongs to what constellation.

The sky plot is published as `sensor_msgs/Image` on topic `sky_plot`.

For parameters of the node, see the bottom of this script.
"""

import matplotlib.pyplot as plt
import numpy as np
import rospy

from cv_bridge import CvBridge
from datetime import datetime
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg

from geographic_msgs.msg import GeoPoint
from gnss_info_msgs.msg import SatelliteInfo, SatelliteSkyPosition, SatellitesList, SkyView
from gps_common.msg import GPSFix
from sensor_msgs.msg import Image

# Lookup table of known satellites keyed by satcat_id. It stays None until one of the callbacks fills it in.
satlist = None
# Converter used to turn the rendered canvas into a sensor_msgs/Image.
cv = CvBridge()

# Fixed colors of the named constellations, taken from the first four entries of the "tab10" colormap.
# NOTE(review): `colors` is not referenced by any code visible in this file.
cmap = plt.get_cmap("tab10")
colors = {name: cmap(idx) for idx, name in enumerate(('GPS', 'BEIDOU', 'GLONASS', 'GALILEO'))}


def on_satellites(_satlist):
    """
    Cache the received list of satellites, indexed by their `satcat_id`.

    The previously cached table is replaced as a whole. If several entries share a `satcat_id`, the last one wins.

    :param SatellitesList _satlist: Message whose `satellites` entries become the new lookup table.
    """
    global satlist
    satlist = {sat.satcat_id: sat for sat in _satlist.satellites}


def on_sky_view(msg):
    """
    Group the satellites of a sky view by constellation and plot them.

    While no satellite lookup table is available, all satellites are plotted as one group called
    "All Constellations" and `only_constellations` is not applied. Otherwise each satellite is assigned to the
    constellation found in the lookup table ("Unknown" if it is missing there), and groups not listed in a
    non-empty `only_constellations` are skipped.

    :param SkyView msg: The received sky view.
    """

    # Start from a clean drawing area with the configured background.
    ax.cla()
    fig.patch.set_facecolor(plot_color)
    ax.set_facecolor(plot_color)

    if satlist is None:
        grouped = {"All Constellations": msg.satellites}
    else:
        grouped = {}
        for sky_pos in msg.satellites:
            sat_id = sky_pos.satcat_id
            group = satlist[sat_id].constellation if sat_id in satlist else "Unknown"

            # An empty filter list means "show everything".
            if only_constellations and group not in only_constellations:
                continue

            grouped.setdefault(group, []).append(sky_pos)

    sky_plot(grouped, msg.header, msg.reference_position, msg.elevation_mask_deg)


def on_gps_fix(msg):
    """
    Plot the satellites listed in the status part of a GPSFix message.

    The satellites are split into the groups "Used" and "Unused". As a side effect, the global `satlist` is
    replaced by synthetic entries (one per plotted satellite) which carry the PRN of each satellite.

    :param GPSFix msg: The received fix.
    """

    ax.cla()
    fig.patch.set_facecolor(plot_color)
    ax.set_facecolor(plot_color)

    status = msg.status
    satellites = {
        "Used": [],
        "Unused": [],
    }

    # `satellites_visible` is just a counter and nothing guarantees that the per-satellite arrays have that many
    # entries. Only process satellites for which all fields are present instead of failing with IndexError.
    num_sats = min(
        status.satellites_visible,
        len(status.satellite_visible_prn),
        len(status.satellite_visible_azimuth),
        len(status.satellite_visible_z),
    )

    # Build the lookup table locally and publish it to the global in one assignment, so that other callbacks
    # never observe a half-filled table.
    infos = {}
    used_i = 0
    for i in range(num_sats):
        prn = status.satellite_visible_prn[i]
        # The used PRNs are matched by walking `satellite_used_prn` in step with the visible list, i.e. a match
        # is only found when the used PRNs appear in the same relative order as the visible ones.
        is_used = used_i < len(status.satellite_used_prn) and status.satellite_used_prn[used_i] == prn
        if is_used:
            used_i += 1

        # Synthetic ID; presumably the index is mixed in to keep equal PRNs at different positions apart.
        satlist_idx = 1000 * i + prn

        sat = SatelliteSkyPosition()
        sat.satcat_id = satlist_idx
        sat.azimuth_deg = status.satellite_visible_azimuth[i]
        sat.elevation_deg = status.satellite_visible_z[i]
        satellites["Used" if is_used else "Unused"].append(sat)

        info = SatelliteInfo()
        info.stamp = msg.header.stamp
        info.prn = prn
        info.satcat_id = satlist_idx
        info.active = True
        info.name = str(prn)
        infos[satlist_idx] = info

    global satlist
    satlist = infos

    ref_pos = GeoPoint()
    ref_pos.latitude = msg.latitude
    ref_pos.longitude = msg.longitude
    ref_pos.altitude = msg.altitude
    sky_plot(satellites, msg.header, ref_pos, 0.0)


def sky_plot(satellites, header, reference_position, elevation_mask_deg=0.0):
    """
    Draw the given satellites into the shared polar axes and publish the rendered image.

    :param dict satellites: Groups to plot. Maps the legend name of a group to a list of objects with attributes
                            `satcat_id`, `azimuth_deg` and `elevation_deg`.
    :param header: Header passed on to the published image. Its stamp is shown in the title.
    :param reference_position: Object with `latitude` and `longitude`, both shown in the title.
    :param float elevation_mask_deg: Elevation mask in degrees. It is shaded when it is positive and
                                     `show_elevation_mask` is set.
    """
    # Read the lookup table once. It is None while no satellite list has been received, and it does not have to
    # contain every plotted satellite; satellites without an entry are drawn without a label.
    known_sats = satlist if satlist is not None else {}

    for constellation, sats in satellites.items():
        # Azimuth MUST be given to the plotting functions in radians.
        azimuths = np.array([s.azimuth_deg * np.pi / 180.0 for s in sats])
        # The radial coordinate is the zenith angle, which puts the zenith into the center of the plot.
        zenith_angles = np.array([90.0 - s.elevation_deg for s in sats])

        # Satellites with negative elevation are not drawn.
        above_horizon = zenith_angles <= 90.0
        if not above_horizon.any():
            continue

        az_data = azimuths[above_horizon]
        elev_data = zenith_angles[above_horizon]

        # The number in the legend counts all satellites of the group, including those which are not drawn.
        label = constellation + " (%i)" % (len(sats),)
        ax.scatter(az_data, elev_data, label=label, marker='o')

        if show_labels:
            # Pair the labels with the drawn satellites only; otherwise the positions get out of sync as soon
            # as one satellite has been filtered out.
            drawn_sats = [s for s, keep in zip(sats, above_horizon) if keep]
            for sat, azimuth, zenith_angle in zip(drawn_sats, az_data, elev_data):
                info = known_sats.get(sat.satcat_id)
                if info is None:
                    continue
                ax.annotate(str(info.prn), (azimuth, zenith_angle), fontsize='small', color='gray')

    if show_elevation_mask and elevation_mask_deg > 0:
        ax.fill_between(np.linspace(0, 2*np.pi, 100), 90, 90 - elevation_mask_deg, color='gray', zorder=0, alpha=0.5)

    # North is up and the azimuth grows clockwise.
    ax.set_theta_zero_location('N')
    ax.set_theta_direction(-1)
    ax.set_rlim(0, 90)
    ax.grid(True, which='major')

    elev_labels = ('90', '75', '60', '45', '30', '15', '')

    # One label per 45 degree grid line. The number of labels has to match the number of grid angles exactly,
    # recent matplotlib versions raise ValueError otherwise.
    azimuth_labels = ['N', '45', 'E', '135', 'S', '225', 'W', '315']

    # NOTE(review): the radii start at 1 instead of 0, so each ring is 1 degree further out than its label says.
    # Presumably this works around matplotlib versions which reject non-positive radii; confirm.
    ax.set_rgrids(range(1, 106, 15), elev_labels, angle=0, color='gray')
    ax.set_thetagrids(range(0, 360, 45), azimuth_labels, color='gray')

    if show_legend:
        ax.legend(loc=(0.9, -0.05), framealpha=0.3, frameon=False)

    if show_title:
        stamp = datetime.utcfromtimestamp(header.stamp.secs).strftime('%Y-%m-%d %H:%M:%S')
        lat = reference_position.latitude
        lon = reference_position.longitude
        title = "%s UTC, %0.2f° %s, %0.2f° %s" % (
            stamp, abs(lat), "N" if lat >= 0 else "S", abs(lon), "E" if lon >= 0 else "W")
        ax.set_title(title, y=1.0, color='gray')

    # Render the figure off-screen and publish the RGBA buffer as an image.
    canvas.draw()
    img = cv.cv2_to_imgmsg(np.asarray(canvas.buffer_rgba()), encoding='rgba8', header=header)
    pub.publish(img)


if __name__ == '__main__':
    rospy.init_node("sky_plot")

    # The plot is published on a latched topic.
    pub = rospy.Publisher("sky_plot", Image, queue_size=10, latch=True)

    # ~plot_color (default "white"): Background color of the figure and of the axes.
    plot_color = rospy.get_param("~plot_color", "white")
    # ~show_legend (default True): Whether to draw the legend with the plotted groups.
    show_legend = bool(rospy.get_param("~show_legend", True))
    # ~show_labels (default True): Whether to write the PRN next to each satellite.
    show_labels = bool(rospy.get_param("~show_labels", True))
    # ~show_title (default True): Whether to draw a title with the time and the reference position.
    show_title = bool(rospy.get_param("~show_title", True))
    # ~show_elevation_mask (default True): Whether to shade the area below a positive elevation mask.
    show_elevation_mask = bool(rospy.get_param("~show_elevation_mask", True))

    # ~only_constellations (default empty): If non-empty, only the listed constellations from `sky_view` messages
    # are plotted. It can be given either as a list or as a comma-separated string.
    only_constellations = rospy.get_param("~only_constellations", [])
    if isinstance(only_constellations, str):
        only_constellations = only_constellations.split(",")
    # Strip the whitespace around the names (so that "GPS, GALILEO" works) and drop empty entries (so that an
    # empty string or a trailing comma does not create a filter entry that matches nothing).
    cleaned_constellations = []
    for entry in only_constellations:
        if isinstance(entry, str):
            entry = entry.strip()
        if entry != "":
            cleaned_constellations.append(entry)
    only_constellations = cleaned_constellations

    # ~width, ~height (defaults 840, 720): Size of the rendered image in pixels.
    # ~dpi (default 120): Resolution of the figure; together with the size in pixels it gives the figure size.
    width = rospy.get_param("~width", 840)
    height = rospy.get_param("~height", 720)
    dpi = rospy.get_param("~dpi", 120)

    # The figure is rendered off-screen with the Agg backend.
    fig = Figure(figsize=(width / dpi, height / dpi), dpi=dpi)
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(projection='polar')

    # Subscribe only after everything the callbacks need has been created.
    sub = rospy.Subscriber("satellites", SatellitesList, on_satellites, queue_size=1)
    sub2 = rospy.Subscriber("sky_view", SkyView, on_sky_view, queue_size=1)
    sub3 = rospy.Subscriber("fix", GPSFix, on_gps_fix, queue_size=1)

    rospy.spin()
