#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Plotting functions for Sionna PHY"""
from itertools import compress
from typing import Callable, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
from sionna.phy.utils import sim_ber
# Names exported by ``from <this module> import *``.
__all__ = ["plot_ber", "PlotBER"]
def plot_ber(
    snr_db: Union[np.ndarray, List[np.ndarray]],
    ber: Union[np.ndarray, List[np.ndarray]],
    legend: Union[str, List[str]] = "",
    ylabel: str = "BER",
    title: str = "Bit Error Rate",
    ebno: bool = True,
    is_bler: Optional[Union[bool, List[bool]]] = None,
    xlim: Optional[Tuple[float, float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    save_fig: bool = False,
    path: str = "",
) -> Tuple[plt.Figure, plt.Axes]:
    """Plot error-rates.

    Creates a new figure and draws one or several error-rate curves with a
    logarithmic y-axis.

    :param snr_db: Array defining the simulated SNR points. If ``ber`` is a
        list and ``snr_db`` is not, the same SNR points are used for every
        curve.
    :param ber: Array defining the BER/BLER per SNR point, or a list of
        such arrays (one entry per curve).
    :param legend: Legend entries.
    :param ylabel: Label for the y-axis.
    :param title: Figure title.
    :param ebno: If `True`, the x-label is set to
        "EbNo [dB]" instead of "EsNo [dB]".
    :param is_bler: If `True`, the corresponding curve is dashed. A single
        `bool` applies to every curve; a list must contain one entry per
        curve. `None` means that no curve is dashed.
    :param xlim: x-axis limits.
    :param ylim: y-axis limits.
    :param save_fig: If `True`, the figure is written to ``path`` with
        `matplotlib.pyplot.savefig` and closed afterwards.
    :param path: Path to save the figure if ``save_fig`` is `True`.

    :output fig: `matplotlib.figure.Figure`.
        Figure handle.

    :output ax: `matplotlib.axes.Axes`.
        Axes object.

    :raises TypeError: If ``legend`` is neither `str` nor `list`, if
        ``title`` is not a `str`, or if ``is_bler`` is neither `None`,
        `bool` nor `list`.
    :raises ValueError: If ``is_bler`` is a list whose length differs from
        the number of curves.

    .. rubric:: Examples

    .. code-block:: python

        import numpy as np
        from sionna.phy.utils import plot_ber

        snr = np.array([0, 2, 4, 6, 8, 10])
        ber = np.array([0.2, 0.1, 0.05, 0.01, 0.001, 0.0001])
        fig, ax = plot_ber(snr, ber, legend="AWGN", title="BER vs SNR")
    """
    # legend must be a list or string
    if not isinstance(legend, list):
        if not isinstance(legend, str):
            raise TypeError("legend must be str or list of str.")
        legend = [legend]
    if not isinstance(title, str):
        raise TypeError("title must be str.")

    # Treat a single curve as a one-element list so that every input shape
    # goes through the same plotting path below.
    if isinstance(ber, list):
        curves = ber
        # broadcast snr if ber is list
        snrs = snr_db if isinstance(snr_db, list) else [snr_db] * len(curves)
    else:
        curves = [ber]
        snrs = [snr_db]

    # Normalize is_bler to exactly one flag per curve. Testing the
    # truthiness of a wrapped flag such as ``[False]`` would always be
    # True, so the flags are kept per curve and indexed individually.
    if is_bler is None:
        bler_flags = [False] * len(curves)
    elif isinstance(is_bler, list):
        if len(is_bler) != len(curves):
            raise ValueError("is_bler has invalid size.")
        bler_flags = is_bler
    elif isinstance(is_bler, bool):
        bler_flags = [is_bler] * len(curves)
    else:
        raise TypeError("is_bler must be bool or list of bool.")

    fig, ax = plt.subplots(figsize=(16, 10))
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
    plt.title(title, fontsize=25)

    for idx, curve in enumerate(curves):
        # BLER curves are dashed; an empty format string keeps the default
        line_style = "--" if bler_flags[idx] else ""
        plt.semilogy(snrs[idx], curve, line_style, linewidth=2)

    plt.grid(which="both")
    if ebno:
        plt.xlabel(r"$E_b/N_0$ (dB)", fontsize=25)
    else:
        plt.xlabel(r"$E_s/N_0$ (dB)", fontsize=25)
    plt.ylabel(ylabel, fontsize=25)
    plt.legend(legend, fontsize=20)
    if save_fig:
        plt.savefig(path)
        plt.close(fig)
    return fig, ax
class PlotBER:
    """Provides a plotting object to simulate and store BER/BLER curves.

    Curves are kept in four parallel lists (error rates, SNR points, legend
    entries and BLER flags). Calling the object plots all stored curves,
    optionally together with additional curves passed to the call.

    :param title: Figure title.

    :input snr_db: `numpy.ndarray` or `list` of `numpy.ndarray`.
        SNR values.

    :input ber: `numpy.ndarray` or `list` of `numpy.ndarray`.
        BER values corresponding to ``snr_db``.

    :input legend: `str` or `list` of `str`.
        Legend entries.

    :input is_bler: `bool` or `list` of `bool`.
        If `True`, ``ber`` will be interpreted as BLER.

    :input show_ber: `bool`.
        If `True`, BER curves will be plotted.

    :input show_bler: `bool`.
        If `True`, BLER curves will be plotted.

    :input xlim: `None` | (`float`, `float`).
        x-axis limits.

    :input ylim: `None` | (`float`, `float`).
        y-axis limits.

    :input save_fig: `bool`.
        If `True`, the figure is saved to ``path``.

    :input path: `str`.
        Path to save the figure if ``save_fig`` is `True`.

    .. rubric:: Examples

    .. code-block:: python

        import numpy as np
        from sionna.phy.utils import PlotBER

        ber_plot = PlotBER(title="My BER Plot")
        snr = np.array([0, 2, 4, 6, 8])
        ber = np.array([0.1, 0.05, 0.01, 0.001, 0.0001])
        ber_plot.add(snr, ber, legend="Curve 1")
        ber_plot()  # Display the plot
    """

    def __init__(self, title: str = "Bit/Block Error Rate"):
        if not isinstance(title, str):
            raise TypeError("title must be str.")
        self._title = title
        # Parallel lists: entry i of each list describes curve i.
        self._bers: List[np.ndarray] = []
        self._snrs: List[np.ndarray] = []
        self._legends: List[str] = []
        self._is_bler: List[bool] = []

    def __call__(
        self,
        snr_db: Optional[Union[np.ndarray, List[np.ndarray], float]] = None,
        ber: Optional[Union[np.ndarray, List[np.ndarray], float]] = None,
        legend: Optional[Union[str, List[str]]] = None,
        is_bler: Optional[Union[bool, List[bool]]] = None,
        show_ber: bool = True,
        show_bler: bool = True,
        xlim: Optional[Tuple[float, float]] = None,
        ylim: Optional[Tuple[float, float]] = None,
        save_fig: bool = False,
        path: str = "",
    ) -> None:
        """Plot BER curves.

        Plots all stored curves followed by the curves given as arguments.
        The arguments are only used for this call and are not stored.
        Nothing is plotted if no curve remains after filtering.

        :raises TypeError: If ``path`` is not a `str` or ``save_fig`` is
            not a `bool`.
        """
        if not isinstance(path, str):
            raise TypeError("path must be str.")
        if not isinstance(save_fig, bool):
            raise TypeError("save_fig must be bool.")

        # Normalize the additional curves to lists.
        if ber is None:
            extra_bers = []
        elif isinstance(ber, list):
            extra_bers = ber
        else:
            extra_bers = [ber]
        num_extra = len(extra_bers)

        # broadcast snr over all additional curves if it is not a list
        if snr_db is None:
            extra_snrs = []
        elif isinstance(snr_db, list):
            extra_snrs = snr_db
        else:
            extra_snrs = [snr_db] * num_extra

        if legend is None:
            extra_legends = []
        elif isinstance(legend, list):
            extra_legends = legend
        else:
            extra_legends = [legend]

        # One flag per additional curve: a missing flag defaults to BER
        # (False) and a scalar flag applies to every additional curve, so
        # the flag list stays aligned with the curve list.
        if is_bler is None:
            extra_flags = [False] * num_extra
        elif isinstance(is_bler, list):
            extra_flags = is_bler
        else:
            extra_flags = [is_bler] * max(num_extra, 1)

        # Concatenation builds new lists; the stored curves are not modified.
        snrs = self._snrs + extra_snrs
        bers = self._bers + extra_bers
        legends = self._legends + extra_legends
        is_bler_list = self._is_bler + extra_flags

        # deactivate BER/BLER
        # Plain Python masks are used on purpose: they also work when no
        # curve is left and never bit-invert non-bool flags.
        if show_ber is False:
            keep = [bool(flag) for flag in is_bler_list]
            snrs = list(compress(snrs, keep))
            bers = list(compress(bers, keep))
            legends = list(compress(legends, keep))
            is_bler_list = list(compress(is_bler_list, keep))
        if show_bler is False:
            keep = [not flag for flag in is_bler_list]
            snrs = list(compress(snrs, keep))
            bers = list(compress(bers, keep))
            legends = list(compress(legends, keep))
            is_bler_list = list(compress(is_bler_list, keep))

        # set ylabel
        ylabel = "BER / BLER"
        if is_bler_list:
            if all(is_bler_list):
                ylabel = "BLER"
            elif not any(is_bler_list):
                ylabel = "BER"

        # plot the results
        if len(bers) > 0:
            plot_ber(
                snr_db=snrs,
                ber=bers,
                legend=legends,
                is_bler=is_bler_list,
                title=self._title,
                ylabel=ylabel,
                xlim=xlim,
                ylim=ylim,
                save_fig=save_fig,
                path=path,
            )

    @property
    def title(self) -> str:
        """Get/set title of the plot."""
        return self._title

    @title.setter
    def title(self, title: str) -> None:
        if not isinstance(title, str):
            raise TypeError("title must be string.")
        self._title = title

    @property
    def ber(self) -> List[np.ndarray]:
        """Stored BER/BLER values."""
        return self._bers

    @property
    def snr(self) -> List[np.ndarray]:
        """Stored SNR values."""
        return self._snrs

    @property
    def legend(self) -> List[str]:
        """Legend entries."""
        return self._legends

    @property
    def is_bler(self) -> List[bool]:
        """Indicates if a curve shall be interpreted as BLER."""
        return self._is_bler

    def simulate(
        self,
        mc_fun: Callable,
        ebno_dbs: torch.Tensor,
        batch_size: int,
        max_mc_iter: int,
        legend: str = "",
        add_ber: bool = True,
        add_bler: bool = False,
        soft_estimates: bool = False,
        num_target_bit_errors: Optional[int] = None,
        num_target_block_errors: Optional[int] = None,
        target_ber: Optional[float] = None,
        target_bler: Optional[float] = None,
        early_stop: bool = True,
        compile_mode: Optional[str] = None,
        add_results: bool = True,
        forward_keyboard_interrupt: bool = True,
        show_fig: bool = True,
        verbose: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Simulate BER/BLER curves for a given model and saves the results.

        Internally calls :func:`sionna.phy.utils.sim_ber`.

        :param mc_fun: Callable that yields the transmitted bits `b` and the
            receiver's estimate `b_hat` for a given ``batch_size`` and
            ``ebno_db``. If ``soft_estimates`` is `True`, `b_hat` is
            interpreted as logit.
        :param ebno_dbs: SNR points to be evaluated.
        :param batch_size: Batch-size for evaluation.
        :param max_mc_iter: Max. number of Monte-Carlo iterations per
            SNR point.
        :param legend: Name to appear in legend.
        :param add_ber: If `True`, BER will be added to plot.
        :param add_bler: If `True`, BLER will be added to plot.
        :param soft_estimates: If `True`, ``b_hat`` is interpreted as logit
            and additional hard-decision is applied internally.
        :param num_target_bit_errors: Target number of bit errors per SNR
            point until the simulation stops.
        :param num_target_block_errors: Target number of block errors per
            SNR point until the simulation stops.
        :param target_ber: The simulation stops after the first SNR point
            which achieves a lower bit error rate as specified by
            ``target_ber``. This requires ``early_stop`` to be `True`.
        :param target_bler: The simulation stops after the first SNR point
            which achieves a lower block error rate as specified by
            ``target_bler``. This requires ``early_stop`` to be `True`.
        :param early_stop: If `True`, the simulation stops after the first
            error-free SNR point (i.e., no error occurred after
            ``max_mc_iter`` Monte-Carlo iterations).
        :param compile_mode: Compilation mode for ``mc_fun``.
            If `None`, ``mc_fun`` is executed as is.
            Options: `None`, ``"default"``, ``"reduce-overhead"``,
            ``"max-autotune"``.
        :param add_results: If `True`, the simulation results will be
            appended to the internal list of results.
        :param forward_keyboard_interrupt: If `False`, `KeyboardInterrupts`
            will be caught internally and not forwarded (e.g., will not stop
            outer loops). If `True`, the simulation ends and returns the
            intermediate simulation results.
        :param show_fig: If `True`, a BER figure will be plotted.
        :param verbose: If `True`, the current progress will be printed.

        :output ber: `torch.float`.
            Simulated bit-error rates.

        :output bler: `torch.float`.
            Simulated block-error rates.
        """
        ber, bler = sim_ber(
            mc_fun,
            ebno_dbs,
            batch_size,
            soft_estimates=soft_estimates,
            max_mc_iter=max_mc_iter,
            num_target_bit_errors=num_target_bit_errors,
            num_target_block_errors=num_target_block_errors,
            target_ber=target_ber,
            target_bler=target_bler,
            early_stop=early_stop,
            compile_mode=compile_mode,
            verbose=verbose,
            forward_keyboard_interrupt=forward_keyboard_interrupt,
        )

        # Convert to numpy for storage; detach() keeps the conversion valid
        # for tensors that are part of an autograd graph.
        ber_np = ber.detach().cpu().numpy()
        bler_np = bler.detach().cpu().numpy()
        if isinstance(ebno_dbs, torch.Tensor):
            ebno_np = ebno_dbs.detach().cpu().numpy()
        else:
            ebno_np = ebno_dbs

        # The curves are stored first so that the figure includes them.
        num_added = 0
        if add_ber:
            self._bers.append(ber_np)
            self._snrs.append(ebno_np)
            self._legends.append(legend)
            self._is_bler.append(False)
            num_added += 1
        if add_bler:
            self._bers.append(bler_np)
            self._snrs.append(ebno_np)
            self._legends.append(legend + " (BLER)")
            self._is_bler.append(True)
            num_added += 1

        try:
            if show_fig:
                self()
        finally:
            # remove current curve(s) if add_results=False, even if
            # plotting raised, so no temporary curve stays stored
            if add_results is False:
                for _ in range(num_added):
                    self.remove(-1)
        return ber, bler

    def add(
        self,
        ebno_db: np.ndarray,
        ber: np.ndarray,
        is_bler: bool = False,
        legend: str = "",
    ) -> None:
        """Add static reference curves.

        :param ebno_db: SNR points.
        :param ber: BER corresponding to each SNR point.
        :param is_bler: If `True`, ``ber`` is interpreted as BLER.
        :param legend: Legend entry.

        :raises ValueError: If ``ebno_db`` and ``ber`` differ in length.
        :raises TypeError: If ``legend`` is not a `str` or ``is_bler`` is
            not a `bool`.
        """
        if len(ebno_db) != len(ber):
            raise ValueError("ebno_db and ber must have same number of elements.")
        if not isinstance(legend, str):
            raise TypeError("legend must be str.")
        if not isinstance(is_bler, bool):
            raise TypeError("is_bler must be bool.")
        self._bers.append(ber)
        self._snrs.append(ebno_db)
        self._legends.append(legend)
        self._is_bler.append(is_bler)

    def reset(self) -> None:
        """Remove all internal data."""
        self._bers = []
        self._snrs = []
        self._legends = []
        self._is_bler = []

    def remove(self, idx: int = -1) -> None:
        """Remove curve with index ``idx``.

        :param idx: Index of the dataset that should be removed. Negative
            indexing is possible.

        :raises TypeError: If ``idx`` is not an `int`.
        """
        if not isinstance(idx, int):
            raise TypeError("idx must be int.")
        del self._bers[idx]
        del self._snrs[idx]
        del self._legends[idx]
        del self._is_bler[idx]