SymbolLogits2Moments
- class sionna.phy.mapping.SymbolLogits2Moments(constellation_type: str | None = None, num_bits_per_symbol: int | None = None, constellation: sionna.phy.mapping.Constellation | None = None, precision: Literal['single', 'double'] | None = None, device: str | None = None, **kwargs: Any)
Bases: sionna.phy.block.Block
Computes the mean and variance of a constellation from logits (unnormalized log-probabilities) on the constellation points.
More precisely, given a constellation \(\mathcal{C} = \left[ c_0,\dots,c_{N-1} \right]\) of size \(N\), this block computes the mean and variance according to
\[\begin{split}\begin{aligned} \mu &= \sum_{n = 0}^{N-1} c_n \Pr \left(c_n \mid \mathbf{\ell} \right)\\ \nu &= \sum_{n = 0}^{N-1} \left\lvert c_n - \mu \right\rvert^2 \Pr \left(c_n \mid \mathbf{\ell} \right) \end{aligned}\end{split}\]
where \(\mathbf{\ell} = \left[ \ell_0, \dots, \ell_{N-1} \right]\) are the logits, and
\[\Pr \left(c_n \mid \mathbf{\ell} \right) = \frac{\exp \left( \ell_n \right)}{\sum_{i=0}^{N-1} \exp \left( \ell_i \right) }.\]
- Parameters:
constellation_type (str | None) – Type of constellation. One of “qam”, “pam”, or “custom”. For “custom”, an instance of Constellation must be provided.
num_bits_per_symbol (int | None) – The number of bits per constellation symbol, e.g., 4 for QAM16. Only required for constellation_type in [“qam”, “pam”].
constellation (sionna.phy.mapping.Constellation | None) – If no constellation is provided, constellation_type and num_bits_per_symbol must be provided. Defaults to None.
precision (Literal['single', 'double'] | None) – Precision used for internal calculations and outputs. If set to None, precision is used.
device (str | None) – Device for tensor operations. If None, device is used.
kwargs (Any)
- Inputs:
logits – […, n, num_points], torch.float. Logits on constellation points.
- Outputs:
mean – […, n], torch.complex. Mean of the constellation.
var – […, n], torch.float. Variance of the constellation.
Examples
    import torch
    from sionna.phy.mapping import SymbolLogits2Moments

    converter = SymbolLogits2Moments("qam", 4)
    logits = torch.randn(10, 25, 16)  # 10 batches, 25 symbols, 16 constellation points
    mean, var = converter(logits)
    print(mean.shape, var.shape)
    # torch.Size([10, 25]) torch.Size([10, 25])
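The mean and variance formulas above can also be checked by hand for a single logit vector. The following is a minimal, dependency-free sketch and not Sionna's implementation. The helper name `logits_to_moments` and the QPSK point list are hypothetical and chosen for illustration only. The sketch assumes the variance is taken over the squared magnitude \(\lvert c_n - \mu \rvert^2\), which matches the real-valued (torch.float) output documented for var.

```python
import math


def logits_to_moments(points, logits):
    """Mean and variance of `points` under the softmax of `logits`.

    Hypothetical reference helper, not part of Sionna.
    """
    # Shift by the largest logit before exponentiating; the softmax is
    # invariant to this shift and it avoids overflow for large logits.
    shift = max(logits)
    weights = [math.exp(l - shift) for l in logits]
    total = sum(weights)
    probs = [w / total for w in weights]

    # mu = sum_n c_n * Pr(c_n | logits)
    mean = sum(c * p for c, p in zip(points, probs))
    # nu = sum_n |c_n - mu|^2 * Pr(c_n | logits), a real number
    var = sum(abs(c - mean) ** 2 * p for c, p in zip(points, probs))
    return mean, var


# Illustrative unit-energy QPSK points (assumption: the ordering here does
# not follow Sionna's bit labeling).
s = 1 / math.sqrt(2)
qpsk = [complex(s, s), complex(s, -s), complex(-s, s), complex(-s, -s)]

# Equal logits: uniform probabilities over the four points.
mean_u, var_u = logits_to_moments(qpsk, [0.0, 0.0, 0.0, 0.0])

# One dominant logit: probability mass concentrates on the first point.
mean_p, var_p = logits_to_moments(qpsk, [50.0, 0.0, 0.0, 0.0])
```

With equal logits the mean is zero and the variance equals the average symbol energy. With one dominant logit the mean approaches that constellation point and the variance approaches zero.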