Source code for sionna.phy.signal.downsampling
#
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Block implementing downsampling"""
from typing import Optional
import torch
from sionna.phy import Block
from sionna.phy.config import Precision
# Public names exported by this module on wildcard import.
__all__ = ["Downsampling"]
[docs]
class Downsampling(Block):
    """Downsamples a tensor along a specified axis by retaining one out of
    ``samples_per_symbol`` elements.

    The block keeps the elements at positions ``offset``,
    ``offset + samples_per_symbol``, ``offset + 2*samples_per_symbol``, ...
    of the selected axis and optionally truncates the result to the first
    ``num_symbols`` retained elements.

    :param samples_per_symbol: Downsampling factor. Must be a positive
        integer. If ``samples_per_symbol`` is equal to `n`, then the
        downsampled axis will be `n`-times shorter.
    :param offset: Index of the first element to be retained. Defaults to 0.
    :param num_symbols: Total number of symbols to be retained after
        downsampling. If `None`, all available symbols are retained.
        Defaults to `None`.
    :param axis: Dimension to be downsampled. Must not be the first dimension.
        Defaults to -1.
    :param precision: Precision used for internal calculations and outputs.
        If set to `None`, :attr:`~sionna.phy.config.Config.precision` is used.
    :param device: Device for computation. If `None`,
        :attr:`~sionna.phy.config.Config.device` is used.

    :raises ValueError: If ``samples_per_symbol`` is smaller than one.

    :input x: [..., n, ...], `torch.float` or `torch.complex`.
        Tensor to be downsampled. `n` is the size of the `axis` dimension.

    :output y: [..., k, ...], `torch.float` or `torch.complex`.
        Downsampled tensor. For a non-negative ``offset``, ``k`` is
        min(ceil(max(``n``-``offset``, 0)/``samples_per_symbol``),
        ``num_symbols``), i.e., a trailing partial period still contributes
        one retained element.

    .. rubric:: Examples

    .. code-block:: python

        import torch
        from sionna.phy.signal import Downsampling

        downsampler = Downsampling(samples_per_symbol=4, offset=2)
        x = torch.randn(32, 400)
        y = downsampler(x)
        print(y.shape)
        # torch.Size([32, 100])
    """

    def __init__(
        self,
        samples_per_symbol: int,
        offset: int = 0,
        num_symbols: Optional[int] = None,
        axis: int = -1,
        precision: Optional[Precision] = None,
        device: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(precision=precision, device=device, **kwargs)

        # Fail fast: a non-positive factor would otherwise only surface
        # later as a slice-step error inside ``call``.
        if samples_per_symbol < 1:
            raise ValueError(
                "`samples_per_symbol` must be a positive integer, "
                f"got {samples_per_symbol}."
            )

        self._samples_per_symbol = samples_per_symbol
        self._offset = offset
        self._num_symbols = num_symbols
        self._axis = axis

    @property
    def samples_per_symbol(self) -> int:
        """Downsampling factor"""
        return self._samples_per_symbol

    @property
    def offset(self) -> int:
        """Index of the first element to be retained"""
        return self._offset

    @property
    def num_symbols(self) -> Optional[int]:
        """Total number of symbols to be retained after downsampling"""
        return self._num_symbols

    @property
    def axis(self) -> int:
        """Dimension to be downsampled"""
        return self._axis

    def call(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample ``x`` along the configured axis.

        :param x: Tensor to be downsampled.

        :returns: Tensor with the same number of dimensions as ``x`` in
            which only every ``samples_per_symbol``-th element of the
            configured axis (starting at ``offset``) is kept, truncated to
            ``num_symbols`` elements if that limit is set.
        """
        # Move the target axis to the end so that it can be addressed
        # with a plain ``...`` slice regardless of the tensor rank.
        x = torch.swapaxes(x, self._axis, -1)

        # Keep one element per period, starting at the configured offset.
        x = x[..., self._offset :: self._samples_per_symbol]

        # Optionally limit the number of retained symbols.
        if self._num_symbols is not None:
            x = x[..., : self._num_symbols]

        # Restore the original axis order.
        return torch.swapaxes(x, -1, self._axis)