Source code for sionna.utils.plotting

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Plotting functions for the Sionna library."""

import numpy as np
import matplotlib.pyplot as plt
from sionna.utils import sim_ber
from itertools import compress # to "filter" list

def plot_ber(snr_db,
             ber,
             legend="",
             ylabel="BER",
             title="Bit Error Rate",
             ebno=True,
             is_bler=None,
             xlim=None,
             ylim=None,
             save_fig=False,
             path=""):
    """Plot error-rates.

    Input
    -----
    snr_db: ndarray
        Array of floats defining the simulated SNR points.
        Can be also a list of multiple arrays.

    ber: ndarray
        Array of floats defining the BER/BLER per SNR point.
        Can be also a list of multiple arrays. Note that any Python
        ``list`` is interpreted as a list of curves.

    legend: str
        Defaults to "". Defining the legend entries. Can be
        either a string or a list of strings.

    ylabel: str
        Defaults to "BER". Defining the y-label.

    title: str
        Defaults to "Bit Error Rate". Defining the title of the figure.

    ebno: bool
        Defaults to True. If True, the x-label is set to
        "EbNo [dB]" instead of "EsNo [dB]".

    is_bler: None, bool or list of bool
        Defaults to None, i.e., no curve is treated as BLER. If True, the
        corresponding curve is dashed. A single bool applies to all
        curves; a list must contain exactly one entry per curve.

    xlim: tuple of floats
        Defaults to None. A tuple of two floats defining x-axis limits.

    ylim: tuple of floats
        Defaults to None. A tuple of two floats defining y-axis limits.

    save_fig: bool
        Defaults to False. If True, the figure is saved to ``path`` and
        closed afterwards.

    path: str
        Defaults to "". Defining the path to save the figure
        (iff ``save_fig`` is True).

    Output
    ------
        (fig, ax) :
            Tuple:

        fig : matplotlib.figure.Figure
            A matplotlib figure handle.

        ax : matplotlib.axes.Axes
            A matplotlib axes object.

    Raises
    ------
    AssertionError
        If ``legend`` / ``title`` / ``is_bler`` have an invalid type or
        if ``is_bler`` does not provide one entry per curve.
    """

    # legend must be a list or string
    if not isinstance(legend, list):
        assert isinstance(legend, str)
        legend = [legend]

    assert isinstance(title, str), "title must be str."

    # Normalize the inputs to "one list entry per curve" such that single
    # and multiple curves share the same code path below.
    if isinstance(ber, list):
        curves = ber
        # broadcast snr if ber is list
        if not isinstance(snr_db, list):
            snr_db = [snr_db]*len(ber)
        snrs = snr_db
    else:
        curves = [ber]
        snrs = [snr_db]

    # is_bler must provide exactly one flag per curve
    if is_bler is None:
        is_bler = [False]*len(curves)
    elif isinstance(is_bler, list):
        assert (len(is_bler) == len(curves)), "is_bler has invalid size."
    else:
        assert isinstance(is_bler, bool), \
            "is_bler must be bool or list of bool."
        # a single flag applies to every curve
        is_bler = [is_bler]*len(curves)

    fig, ax = plt.subplots(figsize=(16,10))

    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)

    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)

    plt.title(title, fontsize=25)

    for idx, curve in enumerate(curves):
        # BLER curves are dashed; an empty format string keeps the
        # matplotlib default line style
        line_style = "--" if is_bler[idx] else ""
        plt.semilogy(snrs[idx], curve, line_style, linewidth=2)

    plt.grid(which="both")
    if ebno:
        plt.xlabel(r"$E_b/N_0$ (dB)", fontsize=25)
    else:
        plt.xlabel(r"$E_s/N_0$ (dB)", fontsize=25)
    plt.ylabel(ylabel, fontsize=25)
    plt.legend(legend, fontsize=20)

    if save_fig:
        plt.savefig(path)
        # the figure is only needed on disk; free its resources
        plt.close(fig)

    return fig, ax
###### Plotting classes #######
class PlotBER():
    """Provides a plotting object to simulate and store BER/BLER curves.

    Parameters
    ----------
    title: str
        A string defining the title of the figure. Defaults to
        `"Bit/Block Error Rate"`.

    Input
    -----
    snr_db: float
        Python array (or list of Python arrays) of additional SNR values to be
        plotted.

    ber: float
        Python array (or list of Python arrays) of additional BERs
        corresponding to ``snr_db``. Note that any Python ``list`` is
        interpreted as a list of curves.

    legend: str
        String (or list of strings) of legends entries.

    is_bler: bool
        A boolean (or list of booleans) defaults to False.
        If True, ``ber`` will be interpreted as BLER. A single boolean
        applies to all additional curves; missing entries are treated as
        False.

    show_ber: bool
        A boolean defaults to True. If True, BER curves will be plotted.

    show_bler: bool
        A boolean defaults to True. If True, BLER curves will be plotted.

    xlim: tuple of floats
        Defaults to None. A tuple of two floats defining x-axis limits.

    ylim: tuple of floats
        Defaults to None. A tuple of two floats defining y-axis limits.

    save_fig: bool
        A boolean defaults to False. If True, the figure
        is saved as file.

    path: str
        A string defining where to save the figure (if ``save_fig``
        is True).
    """

    def __init__(self, title="Bit/Block Error Rate"):

        assert isinstance(title, str), "title must be str."
        self._title = title

        # init lists; entry i of each list describes the i-th stored curve
        self._bers = []
        self._snrs = []
        self._legends = []
        self._is_bler = []

    # The list defaults are only read and never mutated.
    # pylint: disable=W0102
    def __call__(self,
                 snr_db=[],
                 ber=[],
                 legend=[],
                 is_bler=[],
                 show_ber=True,
                 show_bler=True,
                 xlim=None,
                 ylim=None,
                 save_fig=False,
                 path=""):
        """Plot all stored curves together with the given additional curves.

        The additional curves are only plotted and not stored.
        """

        assert isinstance(path, str), "path must be str."
        assert isinstance(save_fig, bool), "save_fig must be bool."

        snrs, bers, legends, is_bler = self._select_curves(
            snr_db, ber, legend, is_bler, show_ber, show_bler)

        # set ylabel
        ylabel = "BER / BLER"
        if all(is_bler): # only BLERs to plot
            ylabel = "BLER"
        if not any(is_bler): # only BERs to plot (or nothing to plot)
            ylabel = "BER"

        # and plot the results
        plot_ber(snr_db=snrs,
                 ber=bers,
                 legend=legends,
                 is_bler=is_bler,
                 title=self._title,
                 ylabel=ylabel,
                 xlim=xlim,
                 ylim=ylim,
                 save_fig=save_fig,
                 path=path)

    def _select_curves(self, snr_db, ber, legend, is_bler, show_ber,
                       show_bler):
        """Merge stored and additional curves and apply the BER/BLER filter.

        Returns the tuple ``(snrs, bers, legends, is_bler)`` of new lists;
        the stored curves are left untouched.
        """

        extra_bers = ber if isinstance(ber, list) else [ber]

        # broadcast snr to all additional curves if it is not a list
        if isinstance(snr_db, list):
            extra_snrs = snr_db
        else:
            extra_snrs = [snr_db]*len(extra_bers)

        extra_legends = legend if isinstance(legend, list) else [legend]

        # one flag per additional curve: broadcast a single flag and treat
        # missing flags as False (i.e., as BER curves)
        if isinstance(is_bler, list):
            extra_flags = list(is_bler)
        else:
            extra_flags = [is_bler]*len(extra_bers)
        extra_flags += [False]*(len(extra_bers) - len(extra_flags))

        snrs = self._snrs + extra_snrs
        bers = self._bers + extra_bers
        legends = self._legends + extra_legends
        flags = self._is_bler + extra_flags

        # deactivate BER/BLER
        if len(flags)>0: # ignore if object is empty
            if show_ber is False:
                # keep BLER curves only
                keep = list(flags)
                snrs = list(compress(snrs, keep))
                bers = list(compress(bers, keep))
                legends = list(compress(legends, keep))
                flags = list(compress(flags, keep))

            if show_bler is False:
                # keep BER curves only; plain "not" also works if the
                # previous filter already removed every curve
                keep = [not flag for flag in flags]
                snrs = list(compress(snrs, keep))
                bers = list(compress(bers, keep))
                legends = list(compress(legends, keep))
                flags = list(compress(flags, keep))

        return snrs, bers, legends, flags

    ####public methods
    @property
    def title(self):
        """Title of the plot."""
        return self._title

    @title.setter
    def title(self, title):
        """Set title of the plot."""
        assert isinstance(title, str), "title must be string"
        self._title = title

    @property
    def ber(self):
        """List containing all stored BER curves."""
        return self._bers

    @property
    def snr(self):
        """List containing all stored SNR curves."""
        return self._snrs

    @property
    def legend(self):
        """List containing all stored legend entries curves."""
        return self._legends

    @property
    def is_bler(self):
        """List of booleans indicating if ber shall be interpreted as BLER."""
        return self._is_bler

    def simulate(self,
                 mc_fun,
                 ebno_dbs,
                 batch_size,
                 max_mc_iter,
                 legend="",
                 add_ber=True,
                 add_bler=False,
                 soft_estimates=False,
                 num_target_bit_errors=None,
                 num_target_block_errors=None,
                 target_ber=None,
                 target_bler=None,
                 early_stop=True,
                 graph_mode=None,
                 distribute=None,
                 add_results=True,
                 forward_keyboard_interrupt=True,
                 show_fig=True,
                 verbose=True):
        # pylint: disable=line-too-long
        r"""Simulate BER/BLER curves for given Keras model and saves the
        results.

        Internally calls :class:`sionna.utils.sim_ber`.

        Input
        -----
        mc_fun:
            Callable that yields the transmitted bits `b` and the
            receiver's estimate `b_hat` for a given ``batch_size`` and
            ``ebno_db``. If ``soft_estimates`` is True, b_hat is interpreted as
            logit.

        ebno_dbs: ndarray of floats
            SNR points to be evaluated.

        batch_size: tf.int32
            Batch-size for evaluation.

        max_mc_iter: int
            Max. number of Monte-Carlo iterations per SNR point.

        legend: str
            Name to appear in legend.

        add_ber: bool
            Defaults to True. Indicate if BER should be added to plot.

        add_bler: bool
            Defaults to False. Indicate if BLER should be added
            to plot.

        soft_estimates: bool
            A boolean, defaults to False. If True, ``b_hat``
            is interpreted as logit and additional hard-decision is applied
            internally.

        num_target_bit_errors: int
            Target number of bit errors per SNR point until the simulation
            stops.

        num_target_block_errors: int
            Target number of block errors per SNR point until the simulation
            stops.

        target_ber: tf.float32
            Defaults to `None`. The simulation stops after the first SNR point
            which achieves a lower bit error rate as specified by
            ``target_ber``. This requires ``early_stop`` to be `True`.

        target_bler: tf.float32
            Defaults to `None`. The simulation stops after the first SNR point
            which achieves a lower block error rate as specified by
            ``target_bler``. This requires ``early_stop`` to be `True`.

        early_stop: bool
            A boolean defaults to True. If True, the simulation stops after the
            first error-free SNR point (i.e., no error occurred after
            ``max_mc_iter`` Monte-Carlo iterations).

        graph_mode: One of ["graph", "xla"], str
            A string describing the execution mode of ``mc_fun``.
            Defaults to `None`. In this case, ``mc_fun`` is executed as is.

        distribute: `None` (default) | "all" | list of indices | `tf.distribute.strategy`
            Distributes simulation on multiple parallel devices. If `None`,
            multi-device simulations are deactivated. If "all", the workload
            will be automatically distributed across all available GPUs via the
            `tf.distribute.MirroredStrategy`.
            If an explicit list of indices is provided, only the GPUs with the
            given indices will be used. Alternatively, a custom
            `tf.distribute.strategy` can be provided. Note that the same
            `batch_size` will be used for all GPUs in parallel, but the number
            of Monte-Carlo iterations ``max_mc_iter`` will be scaled by the
            number of devices such that the same number of total samples is
            simulated. However, all stopping conditions are still in-place
            which can cause slight differences in the total number of simulated
            samples.

        add_results: bool
            Defaults to True. If True, the simulation results will be appended
            to the internal list of results. If False, the results are only
            stored while the figure is plotted and removed afterwards.

        show_fig: bool
            Defaults to True. If True, a BER figure will be plotted.

        verbose: bool
            A boolean defaults to True. If True, the current progress will be
            printed.

        forward_keyboard_interrupt: bool
            A boolean defaults to True. If False, `KeyboardInterrupts` will be
            caught internally and not forwarded (e.g., will not stop outer
            loops). If False, the simulation ends and returns the intermediate
            simulation results.

        Output
        ------
            (ber, bler):
                Tuple:

            ber: float
                The simulated bit-error rate.

            bler: float
                The simulated block-error rate.
        """

        ber, bler = sim_ber(
                    mc_fun,
                    ebno_dbs,
                    batch_size,
                    soft_estimates=soft_estimates,
                    max_mc_iter=max_mc_iter,
                    num_target_bit_errors=num_target_bit_errors,
                    num_target_block_errors=num_target_block_errors,
                    target_ber=target_ber,
                    target_bler=target_bler,
                    early_stop=early_stop,
                    graph_mode=graph_mode,
                    distribute=distribute,
                    verbose=verbose,
                    forward_keyboard_interrupt=forward_keyboard_interrupt)

        # The new curves are always stored first such that they show up in
        # the figure; they are removed again below if add_results is False.
        num_added = 0
        if add_ber:
            self._bers += [ber]
            self._snrs += [ebno_dbs]
            self._legends += [legend]
            self._is_bler += [False]
            num_added += 1

        if add_bler:
            self._bers += [bler]
            self._snrs += [ebno_dbs]
            self._legends += [legend + " (BLER)"]
            self._is_bler += [True]
            num_added += 1

        try:
            if show_fig:
                self()
        finally:
            # remove current curve(s) if add_results=False; done in a
            # finally block such that a failing plot does not leave
            # temporary curves behind
            if add_results is False:
                for _ in range(num_added):
                    self.remove(-1)

        return ber, bler

    def add(self, ebno_db, ber, is_bler=False, legend=""):
        """Add static reference curves.

        Input
        -----
        ebno_db: float
            Python array or list of floats defining the SNR points.

        ber: float
            Python array or list of floats defining the BER corresponding
            to each SNR point.

        is_bler: bool
            A boolean defaults to False. If True, ``ber`` is interpreted as
            BLER.

        legend: str
            A string defining the text of the legend entry.
        """

        assert (len(ebno_db)==len(ber)), \
            "ebno_db and ber must have same number of elements."

        assert isinstance(legend, str), "legend must be str."
        assert isinstance(is_bler, bool), "is_bler must be bool."

        # concatenate curves
        self._bers += [ber]
        self._snrs += [ebno_db]
        self._legends += [legend]
        self._is_bler += [is_bler]

    def reset(self):
        """Remove all internal data."""
        self._bers = []
        self._snrs = []
        self._legends = []
        self._is_bler = []

    def remove(self, idx=-1):
        """Remove curve with index ``idx``.

        Input
        ------
        idx: int
            An integer defining the index of the dataset that should
            be removed. Negative indexing is possible.
        """

        assert isinstance(idx, int), "idx must be int."

        del self._bers[idx]
        del self._snrs[idx]
        del self._legends[idx]
        del self._is_bler[idx]