Source code for sionna.phy.utils.metrics

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Functions to compute frequently used metrics in Sionna PHY"""

import torch
from sionna.phy.config import dtypes, Precision

__all__ = [
    "compute_ber",
    "compute_ser",
    "compute_bler",
    "count_errors",
    "count_block_errors",
]


def compute_ber(
    b: torch.Tensor, b_hat: torch.Tensor, precision: Precision = "double"
) -> torch.Tensor:
    """Computes the bit error rate (BER) between two binary tensors.

    ``b_hat`` is first cast to the dtype of ``b``. The BER is the fraction
    of (broadcast) positions at which ``b`` and the cast ``b_hat`` differ.

    :param b: A tensor of arbitrary shape filled with ones and zeros.
    :param b_hat: A tensor like ``b``.
    :param precision: Precision used for internal calculations and outputs.
        Defaults to ``"double"``.

    :output ber: Scalar tensor whose real dtype is the one that
        ``precision`` maps to in ``sionna.phy.config.dtypes``. BER.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import compute_ber

        b = torch.tensor([0, 1, 0, 1])
        b_hat = torch.tensor([0, 1, 1, 0])
        print(compute_ber(b, b_hat).item())
        # 0.5
    """
    # Compare in the reference dtype so that e.g. float and int inputs match.
    b_hat_cast = b_hat.to(b.dtype)
    real_dtype = dtypes[precision]["torch"]["dtype"]
    mismatch = b != b_hat_cast
    return mismatch.to(real_dtype).mean()
def compute_ser(
    s: torch.Tensor, s_hat: torch.Tensor, precision: Precision = "double"
) -> torch.Tensor:
    """Computes the symbol error rate (SER) between two integer tensors.

    The SER is the fraction of positions at which ``s`` and ``s_hat``
    differ. This is the same element-wise mismatch rate that
    :func:`compute_ber` computes, so this function delegates to it.

    :param s: A tensor of arbitrary shape filled with integers.
    :param s_hat: A tensor like ``s``.
    :param precision: Precision used for internal calculations and outputs.
        Defaults to ``"double"``.

    :output ser: Scalar tensor whose real dtype is the one that
        ``precision`` maps to in ``sionna.phy.config.dtypes``. SER.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import compute_ser

        s = torch.tensor([0, 1, 2, 3])
        s_hat = torch.tensor([0, 1, 3, 2])
        print(compute_ser(s, s_hat).item())
        # 0.5
    """
    # compute_ber only checks element-wise inequality, so it is valid for
    # arbitrary integer symbols, not just bits.
    return compute_ber(s, s_hat, precision=precision)
def compute_bler(
    b: torch.Tensor, b_hat: torch.Tensor, precision: Precision = "double"
) -> torch.Tensor:
    """Computes the block error rate (BLER) between two binary tensors.

    A block consists of all elements along the last dimension of the input.
    A block is in error if at least one of its elements differs between
    ``b`` and ``b_hat``. The BLER is the fraction of erroneous blocks. This
    is also sometimes referred to as `word error rate` or
    `frame error rate`.

    :param b: A tensor of arbitrary shape filled with ones and zeros.
    :param b_hat: A tensor like ``b``.
    :param precision: Precision used for internal calculations and outputs.
        Defaults to ``"double"``.

    :output bler: Scalar tensor whose real dtype is the one that
        ``precision`` maps to in ``sionna.phy.config.dtypes``. BLER.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import compute_bler

        b = torch.tensor([[0, 1], [1, 0]])
        b_hat = torch.tensor([[0, 1], [1, 1]])
        # The first block is correct, the second block is incorrect
        print(compute_bler(b, b_hat).item())
        # 0.5
    """
    b_hat_cast = b_hat.to(b.dtype)
    real_dtype = dtypes[precision]["torch"]["dtype"]
    # Collapse the last axis: one flag per block, True if any entry differs.
    block_has_error = (b != b_hat_cast).any(dim=-1)
    return block_has_error.to(real_dtype).mean()
def count_errors(b: torch.Tensor, b_hat: torch.Tensor) -> torch.Tensor:
    """Counts the number of bit errors between two binary tensors.

    ``b_hat`` is first cast to the dtype of ``b``. Every (broadcast)
    position at which the two tensors differ counts as one error.

    :param b: A tensor of arbitrary shape filled with ones and zeros.
    :param b_hat: A tensor like ``b``.

    :output num_errors: Scalar tensor of dtype `torch.int64`. Number of
        bit errors.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import count_errors

        b = torch.tensor([0, 1, 0, 1])
        b_hat = torch.tensor([0, 1, 1, 0])
        print(count_errors(b, b_hat).item())
        # 2
    """
    mismatch = b != b_hat.to(b.dtype)
    # Accumulate in int64 so large tensors cannot overflow the count.
    return mismatch.to(torch.int64).sum()
def count_block_errors(b: torch.Tensor, b_hat: torch.Tensor) -> torch.Tensor:
    """Counts the number of block errors between two binary tensors.

    A block consists of all elements along the last dimension of the input.
    A block is in error if at least one of its elements differs between
    ``b`` and ``b_hat``; this function returns how many such blocks there
    are. Block errors are also sometimes referred to as `word errors` or
    `frame errors`.

    :param b: A tensor of arbitrary shape filled with ones and zeros.
    :param b_hat: A tensor like ``b``.

    :output num_errors: Scalar tensor of dtype `torch.int64`. Number of
        block errors.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.utils import count_block_errors

        b = torch.tensor([[0, 1], [1, 0]])
        b_hat = torch.tensor([[0, 1], [1, 1]])
        print(count_block_errors(b, b_hat).item())
        # 1
    """
    mismatch = b != b_hat.to(b.dtype)
    # One flag per block (last axis reduced), then count the flagged blocks.
    block_has_error = mismatch.any(dim=-1)
    return block_has_error.to(torch.int64).sum()