Neural Receiver for OFDM SIMO Systems#

In this notebook, you will learn how to train a neural receiver that implements OFDM detection. The considered setup is shown in the figure below. As one can see, the neural receiver substitutes channel estimation, equalization, and demapping. It takes as input the post-DFT (discrete Fourier transform) received samples, which form the received resource grid, and computes log-likelihood ratios (LLRs) on the transmitted coded bits. These LLRs are then fed to the outer decoder to reconstruct the transmitted information bits.

System Model

Two baselines are considered for benchmarking, which are shown in the figure above. Both baselines use linear minimum mean square error (LMMSE) equalization and demapping assuming additive white Gaussian noise (AWGN). They differ by how channel estimation is performed:

  • Perfect CSI: Perfect channel state information (CSI) knowledge is assumed.

  • LS estimation: Uses the transmitted pilots to perform least squares (LS) estimation of the channel with nearest-neighbor interpolation.
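For intuition, the LS-plus-nearest-neighbor step can be sketched as follows. This is a simplified, standalone NumPy version with a hypothetical helper name; the notebook itself relies on `sionna.phy.ofdm.LSChannelEstimator` with `interpolation_type="nn"`.

```python
import numpy as np

# Simplified sketch of LS channel estimation with nearest-neighbor
# interpolation over subcarriers (hypothetical helper; the notebook uses
# sionna.phy.ofdm.LSChannelEstimator with interpolation_type="nn").
def ls_nn_estimate(y_pilots, x_pilots, pilot_idx, num_subcarriers):
    h_ls = y_pilots / x_pilots  # LS estimates at the pilot positions
    all_idx = np.arange(num_subcarriers)
    # For each subcarrier, pick the LS estimate of the nearest pilot
    nearest = np.abs(all_idx[:, None] - np.asarray(pilot_idx)[None, :]).argmin(axis=1)
    return h_ls[nearest]
```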

All the considered end-to-end systems use an LDPC outer code from the 5G NR specification, QPSK modulation, and a 3GPP CDL channel model simulated in the frequency domain.
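The per-resource-element SIMO equalization step shared by both baselines can similarly be sketched with a minimal NumPy function, assuming unit-energy symbols and a known channel vector. The notebook itself uses `sionna.phy.ofdm.LMMSEEqualizer`, which additionally returns an effective noise power.

```python
import numpy as np

# Minimal sketch of LMMSE equalization for one SIMO resource element,
# assuming E[|x|^2] = 1 and a known channel vector h (the notebook uses
# sionna.phy.ofdm.LMMSEEqualizer, which also returns an effective noise power).
def lmmse_simo_equalize(y, h, no):
    # y, h: [num_rx_ant] complex received samples and channel; no: noise power
    return (np.conj(h) @ y) / ((np.conj(h) @ h).real + no)
```

In the noiseless limit (no → 0) this reduces to maximum-ratio combining followed by zero-forcing.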

Configuration and Imports#

[1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Use GPU 1 (0-indexed)

# Import Sionna
try:
    import sionna.phy
except ImportError as e:
    import os
    import sys
    if 'google.colab' in sys.modules:
       # Install Sionna in Google Colab
       print("Installing Sionna and restarting the runtime. Please run the cell again.")
       os.system("pip install sionna")
       os.kill(os.getpid(), 5)
    else:
       raise e

# Set seed for reproducible random number generation
sionna.phy.config.seed = 42
device = sionna.phy.config.device
[2]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from sionna.phy import Block
from sionna.phy.channel.tr38901 import Antenna, AntennaArray, CDL
from sionna.phy.channel import OFDMChannel
from sionna.phy.mimo import StreamManagement
from sionna.phy.ofdm import ResourceGrid, ResourceGridMapper, LSChannelEstimator, \
                            LMMSEEqualizer, RemoveNulledSubcarriers, ResourceGridDemapper
from sionna.phy.utils import ebnodb2no, insert_dims, expand_to_rank, sim_ber
from sionna.phy.fec.ldpc import LDPC5GEncoder, LDPC5GDecoder
from sionna.phy.mapping import Mapper, Demapper, BinarySource

Simulation Parameters#

[3]:
############################################
## Channel configuration
carrier_frequency = 3.5e9 # Hz
delay_spread = 100e-9 # s
cdl_model = "C" # CDL model to use
speed = 10.0 # Speed for evaluation and training [m/s]
# SNR range for evaluation and training [dB]
ebno_db_min = -5.0
ebno_db_max = 10.0

############################################
## OFDM waveform configuration
subcarrier_spacing = 30e3 # Hz
fft_size = 128 # Number of subcarriers forming the resource grid, including the null-subcarrier and the guard bands
num_ofdm_symbols = 14 # Number of OFDM symbols forming the resource grid
dc_null = True # Null the DC subcarrier
num_guard_carriers = [5, 6] # Number of guard carriers on each side
pilot_pattern = "kronecker" # Pilot pattern
pilot_ofdm_symbol_indices = [2, 11] # Index of OFDM symbols carrying pilots
cyclic_prefix_length = 0 # Not needed as the simulation is performed in the frequency domain

############################################
## Modulation and coding configuration
num_bits_per_symbol = 2 # QPSK
coderate = 0.5 # Coderate for LDPC code

############################################
## Neural receiver configuration
num_conv_channels = 128 # Number of convolutional channels for the convolutional layers forming the neural receiver

############################################
## Training configuration
num_training_iterations = 30000 # Number of training iterations
training_batch_size = 128 # Training batch size
model_weights_path = "neural_receiver_weights" # Location to save the neural receiver weights once training is done

############################################
## Evaluation configuration
results_filename = "neural_receiver_results" # Location to save the results

The StreamManagement class is used to configure the receiver-transmitter association and the number of streams per transmitter. A SIMO system is considered, with a single transmitter equipped with a single non-polarized antenna. Therefore, there is only a single stream, and the receiver-transmitter association matrix is \([1]\). The receiver is equipped with an antenna array.

[4]:
stream_manager = StreamManagement(np.array([[1]]), # Receiver-transmitter association matrix
                                  1)               # One stream per transmitter

The ResourceGrid class is used to configure the OFDM resource grid. It is initialized with the parameters defined above.

[5]:
resource_grid = ResourceGrid(num_ofdm_symbols = num_ofdm_symbols,
                             fft_size = fft_size,
                             subcarrier_spacing = subcarrier_spacing,
                             num_tx = 1,
                             num_streams_per_tx = 1,
                             cyclic_prefix_length = cyclic_prefix_length,
                             dc_null = dc_null,
                             pilot_pattern = pilot_pattern,
                             pilot_ofdm_symbol_indices = pilot_ofdm_symbol_indices,
                             num_guard_carriers = num_guard_carriers)

Outer coding is performed such that all the data bits carried by the resource grid of size fft_size x num_ofdm_symbols form a single codeword.

[6]:
# Codeword length, computed from the total number of data bits carried by the resource grid and the number of bits per resource element
n = int(resource_grid.num_data_symbols*num_bits_per_symbol)
# Number of information bits per codeword
k = int(n*coderate)
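As a sanity check, n and k can be reproduced by hand from the waveform parameters. This standalone recomputation assumes that the Kronecker pilot pattern fully occupies the two pilot-carrying OFDM symbols:

```python
fft_size = 128
num_guard_carriers = [5, 6]         # guard bands on each side
num_ofdm_symbols = 14
pilot_ofdm_symbol_indices = [2, 11]
num_bits_per_symbol = 2             # QPSK
coderate = 0.5

# Effective subcarriers: remove the guard bands and the nulled DC subcarrier
effective_subcarriers = fft_size - sum(num_guard_carriers) - 1  # 116
# Data-carrying OFDM symbols: all symbols except the pilot-only ones
num_data_symbols = effective_subcarriers * (num_ofdm_symbols - len(pilot_ofdm_symbol_indices))
n = num_data_symbols * num_bits_per_symbol  # codeword length
k = int(n * coderate)                       # information bits
# n = 2784 and k = 1392, matching the tensor shapes printed further below
```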

The SIMO link is set up by considering an uplink transmission with one user terminal (UT) equipped with a single non-polarized antenna, and a base station (BS) equipped with an antenna array. One can try other configurations for the BS antenna array.

[7]:
ut_antenna = Antenna(polarization="single",
                     polarization_type="V",
                     antenna_pattern="38.901",
                     carrier_frequency=carrier_frequency)

bs_array = AntennaArray(num_rows=1,
                        num_cols=1,
                        polarization="dual",
                        polarization_type="VH",
                        antenna_pattern="38.901",
                        carrier_frequency=carrier_frequency)
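As suggested above, other BS array geometries can be tried. For instance, a hypothetical variant with two dual-polarized columns (4 RX antennas in total) could be configured as below; note that the NeuralReceiver defined further down hard-codes 5 input channels for 2 RX antennas, so its input layer would need to be adapted accordingly.

```python
# Hypothetical alternative: a 1x2 dual-polarized BS array (4 RX antennas).
# The neural receiver's input convolution (2*num_rx_ant + 1 channels)
# would need to be adjusted to match.
bs_array_alt = AntennaArray(num_rows=1,
                            num_cols=2,
                            polarization="dual",
                            polarization_type="VH",
                            antenna_pattern="38.901",
                            carrier_frequency=carrier_frequency)
```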

Neural Receiver#

The next cell defines the PyTorch modules that implement the neural receiver. As in [1] and [2], a neural receiver using residual convolutional layers is implemented. Convolutional layers are leveraged to efficiently process the 2D resource grid, which is fed as input to the neural receiver. Residual (skip) connections are used to avoid vanishing gradients [3].

For convenience, a PyTorch module that implements a residual block is first defined. The module that implements the neural receiver is built by stacking such blocks. The following figure shows the architecture of the neural receiver.

Neural RX

[8]:
class ResidualBlock(nn.Module):
    """
    Convolutional residual block with two convolutional layers, ReLU activation,
    layer normalization, and a skip connection.

    The number of convolutional channels of the input must match num_conv_channels
    for the skip connection to work.

    Input shape: [batch_size, num_conv_channels, num_ofdm_symbols, num_subcarriers]
    Output shape: [batch_size, num_conv_channels, num_ofdm_symbols, num_subcarriers]
    """

    def __init__(self, num_conv_channels: int):
        super().__init__()
        # Layer normalization over the last three dimensions (C, H, W)
        self._layer_norm_1 = nn.LayerNorm([num_conv_channels, num_ofdm_symbols, fft_size])
        self._conv_1 = nn.Conv2d(
            in_channels=num_conv_channels,
            out_channels=num_conv_channels,
            kernel_size=3,
            padding=1,  # 'same' padding
        )
        self._layer_norm_2 = nn.LayerNorm([num_conv_channels, num_ofdm_symbols, fft_size])
        self._conv_2 = nn.Conv2d(
            in_channels=num_conv_channels,
            out_channels=num_conv_channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        z = self._layer_norm_1(inputs)
        z = F.relu(z)
        z = self._conv_1(z)
        z = self._layer_norm_2(z)
        z = F.relu(z)
        z = self._conv_2(z)
        # Skip connection
        z = z + inputs
        return z


class NeuralReceiver(nn.Module):
    """
    Residual convolutional neural receiver.

    This neural receiver is fed with the post-DFT received samples, forming a
    resource grid of size num_ofdm_symbols x fft_size, and computes LLRs on
    the transmitted coded bits.

    Input
    -----
    y : [batch_size, num_rx_antenna, num_ofdm_symbols, num_subcarriers], complex
        Received post-DFT samples.
    no : [batch_size], float
        Noise variance.

    Output
    ------
    llr : [batch_size, num_ofdm_symbols, num_subcarriers, num_bits_per_symbol], float
        LLRs on the transmitted bits.
    """

    def __init__(self, num_conv_channels: int, num_bits_per_symbol: int):
        super().__init__()
        self._num_bits_per_symbol = num_bits_per_symbol

        # Input convolution: 2*num_rx_antenna + 1 input channels
        # (real and imaginary parts of each antenna, plus the noise power).
        # For a dual-polarized single element: 2*2 + 1 = 5 input channels
        num_input_channels = 2 * 2 + 1
        self._input_conv = nn.Conv2d(
            in_channels=num_input_channels,
            out_channels=num_conv_channels,
            kernel_size=3,
            padding=1,
        )
        # Residual blocks
        self._res_block_1 = ResidualBlock(num_conv_channels)
        self._res_block_2 = ResidualBlock(num_conv_channels)
        self._res_block_3 = ResidualBlock(num_conv_channels)
        self._res_block_4 = ResidualBlock(num_conv_channels)
        # Output convolution
        self._output_conv = nn.Conv2d(
            in_channels=num_conv_channels,
            out_channels=num_bits_per_symbol,
            kernel_size=3,
            padding=1,
        )

    def forward(self, y: torch.Tensor, no: torch.Tensor) -> torch.Tensor:
        # y: [batch, num_rx_ant, num_ofdm_symbols, num_subcarriers]
        # no: [batch]

        # Feeding the noise power in log10 scale helps with the performance
        no = torch.log10(no)

        # Stack real and imaginary components
        y_real = y.real  # [batch, num_rx_ant, time, freq]
        y_imag = y.imag  # [batch, num_rx_ant, time, freq]

        # Reshape noise to [batch, 1, 1, 1] and broadcast to match y's batch size
        batch_size = y.shape[0]
        no = no.view(-1, 1, 1, 1)
        no = no.expand(batch_size, 1, y.shape[2], y.shape[3])  # [batch, 1, time, freq]

        # Concatenate: [batch, 2*num_rx_ant + 1, time, freq]
        z = torch.cat(
            [
                y_real[:, 0:1],  # real part of antenna 0
                y_real[:, 1:2],  # real part of antenna 1
                y_imag[:, 0:1],  # imag part of antenna 0
                y_imag[:, 1:2],  # imag part of antenna 1
                no,
            ],
            dim=1,
        )

        # Input conv
        z = self._input_conv(z)
        # Residual blocks
        z = self._res_block_1(z)
        z = self._res_block_2(z)
        z = self._res_block_3(z)
        z = self._res_block_4(z)
        # Output conv: [batch, num_bits_per_symbol, time, freq]
        z = self._output_conv(z)

        # Transpose to [batch, time, freq, num_bits_per_symbol]
        z = z.permute(0, 2, 3, 1)

        return z

End-to-end System#

The following cell defines the end-to-end system.

Training is done on the bit-metric decoding (BMD) rate, which is computed from the transmitted bits and LLRs:

\begin{equation} R = 1 - \frac{1}{SNMK} \sum_{s = 0}^{S-1} \sum_{n = 0}^{N-1} \sum_{m = 0}^{M-1} \sum_{k = 0}^{K-1} \texttt{BCE} \left( B_{s,n,m,k}, \texttt{LLR}_{s,n,m,k} \right) \end{equation}

where

  • \(S\) is the batch size

  • \(N\) the number of subcarriers

  • \(M\) the number of OFDM symbols

  • \(K\) the number of bits per symbol

  • \(B_{s,n,m,k}\) the \(k^{th}\) coded bit transmitted on the resource element \((n,m)\) and for the \(s^{th}\) batch example

  • \(\texttt{LLR}_{s,n,m,k}\) the LLR (logit) computed by the neural receiver corresponding to the \(k^{th}\) coded bit transmitted on the resource element \((n,m)\) and for the \(s^{th}\) batch example

  • \(\texttt{BCE} \left( \cdot, \cdot \right)\) the binary cross-entropy in log base 2

Since the outer code is not needed to compute this rate, the outer encoder and decoder are skipped during training, which reduces computational complexity.

The BMD rate is known to be an achievable information rate for BICM systems, which motivates its use as the objective function [4].
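A standalone NumPy sketch of this rate computation is given below (a hypothetical helper; the notebook computes it with torch's `binary_cross_entropy_with_logits`, treating each LLR as the logit of the bit being 1):

```python
import numpy as np

# BMD rate from bits and LLRs (logits): BCE in nats, converted to bits.
# For a bit b and logit l, BCE = log(1 + exp(-(2b - 1) * l)).
def bmd_rate(bits, llrs):
    bce_nats = np.log1p(np.exp(-(2.0 * bits - 1.0) * llrs))
    return 1.0 - np.mean(bce_nats) / np.log(2.0)
```

Uninformative LLRs (all zero) give a rate of exactly 0 bit, while confident, correct LLRs push the rate toward 1 bit per coded bit.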

[9]:
## Transmitter
binary_source = BinarySource()
mapper = Mapper("qam", num_bits_per_symbol)
rg_mapper = ResourceGridMapper(resource_grid)

## Channel
cdl = CDL(cdl_model, delay_spread, carrier_frequency,
          ut_antenna, bs_array, "uplink", min_speed=speed)
channel = OFDMChannel(cdl, resource_grid, normalize_channel=True, return_channel=True)

## Receiver
neural_receiver = NeuralReceiver(num_conv_channels, num_bits_per_symbol).to(device)
rg_demapper = ResourceGridDemapper(resource_grid, stream_manager) # Used to extract data-carrying resource elements

The following cell performs one forward step through the end-to-end system:

[10]:
batch_size = 64
ebno_db = torch.full((batch_size,), 5.0)
no = ebnodb2no(ebno_db, num_bits_per_symbol, coderate)


## Transmitter
# Generate codewords
c = binary_source([batch_size, 1, 1, n])
print("c shape: ", c.shape)
# Map bits to QAM symbols
x = mapper(c)
print("x shape: ", x.shape)
# Map the QAM symbols to a resource grid
x_rg = rg_mapper(x)
print("x_rg shape: ", x_rg.shape)

######################################
## Channel
# A batch of new channel realizations is sampled and applied at every inference
no_ = expand_to_rank(no, x_rg.ndim)
y, _ = channel(x_rg, no_)
print("y shape: ", y.shape)

######################################
## Receiver
# The neural receiver computes LLRs from the frequency domain received symbols and N0
y = y.squeeze(1)
llr = neural_receiver(y, no)
print("llr shape: ", llr.shape)
# Reshape the LLRs to fit what the resource grid demapper expects
llr = insert_dims(llr, 2, 1)
# Extract data-carrying resource elements. The other LLRs are discarded
llr = rg_demapper(llr)
llr = llr.reshape(batch_size, 1, 1, n)
print("Post RG-demapper LLRs: ", llr.shape)
c shape:  torch.Size([64, 1, 1, 2784])
x shape:  torch.Size([64, 1, 1, 1392])
x_rg shape:  torch.Size([64, 1, 1, 14, 128])
y shape:  torch.Size([64, 1, 2, 14, 128])
llr shape:  torch.Size([64, 14, 128, 2])
Post RG-demapper LLRs:  torch.Size([64, 1, 1, 2784])
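For reference, the Eb/N0-to-noise-power conversion used above can be sketched as follows. This is a simplified standalone version assuming unit symbol energy; Sionna's `ebnodb2no` can additionally account for resource-grid overhead such as pilots when a resource grid is passed.

```python
# Simplified Eb/N0 [dB] -> noise power N0 conversion, assuming unit
# symbol energy (sionna's ebnodb2no can also account for resource-grid
# overhead when a resource grid is passed).
def ebnodb_to_no(ebno_db, num_bits_per_symbol, coderate):
    ebno = 10.0 ** (ebno_db / 10.0)
    return 1.0 / (ebno * num_bits_per_symbol * coderate)
```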

The BMD rate is computed from the LLRs and transmitted bits as follows:

[11]:
bce = torch.nn.functional.binary_cross_entropy_with_logits(llr, c.float(), reduction='none')
bce = bce.mean()
rate = 1.0 - bce / math.log(2.)
print(f"Rate: {rate:.2E} bit")
Rate: -2.92E-02 bit

The rate is very poor (a negative value corresponds to an effective rate of 0 bit) as the neural receiver has not been trained yet.

End-to-end System as a Sionna Block#

The following Sionna block implements the three considered end-to-end systems (perfect CSI baseline, LS estimation baseline, and neural receiver).

When instantiating the end-to-end model, the parameter system specifies the system to set up, and the parameter training specifies whether the system is instantiated for training or for evaluation. The training parameter is only relevant when the neural receiver is used.

At each call of this model:

  • A batch of codewords is randomly sampled, modulated, and mapped to resource grids to form the channel inputs

  • A batch of channel realizations is randomly sampled and applied to the channel inputs

  • The receiver is executed on the post-DFT received samples to compute LLRs on the coded bits. Which receiver is executed (baseline with perfect CSI knowledge, baseline with LS estimation, or neural receiver) depends on the specified system parameter.

  • If not training, the outer decoder is applied to reconstruct the information bits

  • If training, the BMD rate is estimated over the batch from the LLRs and the transmitted bits

[12]:
class E2ESystem(Block):
    r"""
    Sionna Block that implements the end-to-end system

    As the three considered end-to-end systems (perfect CSI baseline, LS estimation baseline, and neural receiver) share most of
    the link components (transmitter, channel model, outer code...), they are implemented using the same end-to-end model.

    When instantiating the Sionna block, the parameter ``system`` specifies the system to set up,
    and the parameter ``training`` specifies whether the system is instantiated for training or for evaluation.
    The ``training`` parameter is only relevant when the neural receiver is used.

    At each call of this model:
    * A batch of codewords is randomly sampled, modulated, and mapped to resource grids to form the channel inputs
    * A batch of channel realizations is randomly sampled and applied to the channel inputs
    * The receiver is executed on the post-DFT received samples to compute LLRs on the coded bits.
      Which receiver is executed (baseline with perfect CSI knowledge, baseline with LS estimation, or neural receiver) depends
      on the specified ``system`` parameter.
    * If not training, the outer decoder is applied to reconstruct the information bits
    * If training, the BMD rate is estimated over the batch from the LLRs and the transmitted bits

    Parameters
    -----------
    system : str
        Specify the receiver to use. Should be one of 'baseline-perfect-csi', 'baseline-ls-estimation' or 'neural-receiver'

    training : bool
        Set to `True` if the system is instantiated to be trained. Set to `False` otherwise. Defaults to `False`.
        If the system is instantiated to be trained, the outer encoder and decoder are not instantiated as they are not required for training.
        This significantly reduces the computational complexity of training.
        If training, the bit-metric decoding (BMD) rate is computed from the transmitted bits and the LLRs. The BMD rate is known to be
        an achievable information rate for BICM systems, and therefore training of the neural receiver aims at maximizing this rate.

    Input
    ------
    batch_size : int
        Batch size

    ebno_db : scalar or [batch_size], torch.Tensor
        Eb/N0 in dB.
        At training, a different Eb/N0 should be sampled for each batch example.

    Output
    -------
    If ``training`` is set to `True`, then the output is a single scalar, an estimate of the BMD rate computed over the batch,
    which should be used as the training objective.
    If ``training`` is set to `False`, the transmitted information bits and their reconstruction on the receiver side are returned to
    compute the block/bit error rate.
    """

    def __init__(self, system, training=False):
        super().__init__()
        self._system = system
        self._training = training

        ######################################
        ## Transmitter
        self._binary_source = BinarySource()
        # To reduce the computational complexity of training, the outer code is not used when training,
        # as it is not required
        if not training:
            self._encoder = LDPC5GEncoder(k, n)
        self._mapper = Mapper("qam", num_bits_per_symbol)
        self._rg_mapper = ResourceGridMapper(resource_grid)

        ######################################
        ## Channel
        # A 3GPP CDL channel model is used
        cdl = CDL(cdl_model, delay_spread, carrier_frequency,
                  ut_antenna, bs_array, "uplink", min_speed=speed)
        self._channel = OFDMChannel(cdl, resource_grid, normalize_channel=True, return_channel=True)

        ######################################
        ## Receiver
        # Three options for the receiver depending on the value of `system`
        if "baseline" in system:
            if system == 'baseline-perfect-csi': # Perfect CSI
                self._removed_null_subc = RemoveNulledSubcarriers(resource_grid)
            elif system == 'baseline-ls-estimation': # LS estimation
                self._ls_est = LSChannelEstimator(resource_grid, interpolation_type="nn")
            # Components required by both baselines
            self._lmmse_equ = LMMSEEqualizer(resource_grid, stream_manager)
            self._demapper = Demapper("app", "qam", num_bits_per_symbol)
        elif system == "neural-receiver": # Neural receiver
            self._neural_receiver = NeuralReceiver(num_conv_channels, num_bits_per_symbol).to(device)
            self._rg_demapper = ResourceGridDemapper(resource_grid, stream_manager) # Used to extract data-carrying resource elements
        # To reduce the computational complexity of training, the outer code is not used when training,
        # as it is not required
        if not training:
            self._decoder = LDPC5GDecoder(self._encoder, hard_out=True)

    def forward(self, batch_size, ebno_db):

        # Ensure ebno_db has shape [batch_size]
        ebno_db = ebno_db.expand(batch_size)

        ######################################
        ## Transmitter
        no = ebnodb2no(ebno_db, num_bits_per_symbol, coderate)
        # Outer coding is only performed if not training
        if self._training:
            c = self._binary_source([batch_size, 1, 1, n])
        else:
            b = self._binary_source([batch_size, 1, 1, k])
            c = self._encoder(b)
        # Modulation
        x = self._mapper(c)
        x_rg = self._rg_mapper(x)

        ######################################
        ## Channel
        # A batch of new channel realizations is sampled and applied at every inference
        no_ = expand_to_rank(no, x_rg.ndim)
        y, h = self._channel(x_rg, no_)

        ######################################
        ## Receiver
        # Three options for the receiver depending on the value of ``system``
        if "baseline" in self._system:
            if self._system == 'baseline-perfect-csi':
                h_hat = self._removed_null_subc(h) # Extract non-null subcarriers
                err_var = 0.0 # No channel estimation error when perfect CSI knowledge is assumed
            elif self._system == 'baseline-ls-estimation':
                h_hat, err_var = self._ls_est(y, no) # LS channel estimation with nearest-neighbor
            x_hat, no_eff = self._lmmse_equ(y, h_hat, err_var, no) # LMMSE equalization
            no_eff_ = expand_to_rank(no_eff, x_hat.ndim)
            llr = self._demapper(x_hat, no_eff_) # Demapping
        elif self._system == "neural-receiver":
            # The neural receiver computes LLRs from the frequency domain received symbols and N0
            y = y.squeeze(1)
            llr = self._neural_receiver(y, no)
            llr = insert_dims(llr, 2, 1) # Reshape the LLRs to fit what the resource grid demapper expects
            llr = self._rg_demapper(llr) # Extract data-carrying resource elements. The other LLRs are discarded
            llr = llr.reshape(batch_size, 1, 1, n) # Reshape the LLRs to fit what the outer decoder expects

        # Outer coding is not needed if the information rate is returned
        if self._training:
            # Compute and return BMD rate (in bit), which is known to be an achievable
            # information rate for BICM systems.
            # Training aims at maximizing the BMD rate
            bce = torch.nn.functional.binary_cross_entropy_with_logits(llr, c.float(), reduction='none')
            bce = bce.mean()
            rate = 1.0 - bce / math.log(2.)
            return rate
        else:
            # Outer decoding
            b_hat = self._decoder(llr)
            return b, b_hat # Ground truth and reconstructed information bits returned for BER/BLER computation

Evaluation of the Baselines#

We evaluate the BLERs achieved by the baselines in the next cell.

Note: Evaluation of the two systems can take a while. Therefore, we provide pre-computed results at the end of this notebook.

[13]:
# Range of SNRs over which the systems are evaluated
ebno_dbs = np.arange(ebno_db_min, # Min SNR for evaluation
                     ebno_db_max, # Max SNR for evaluation
                     0.5) # Step
[14]:
# Dictionary storing the evaluation results
BLER = {}

model = E2ESystem('baseline-perfect-csi')
_,bler = sim_ber(model, ebno_dbs, batch_size=128, num_target_block_errors=100, max_mc_iter=1000, target_bler=1e-4, compile_mode="reduce-overhead")
BLER['baseline-perfect-csi'] = bler.cpu().numpy()

model = E2ESystem('baseline-ls-estimation')
_,bler = sim_ber(model, ebno_dbs, batch_size=128, num_target_block_errors=100, max_mc_iter=1000, target_bler=1e-4, compile_mode="reduce-overhead")
BLER['baseline-ls-estimation'] = bler.cpu().numpy()
/home/faycal/work/sionna-torch/sionna/venv/lib/python3.12/site-packages/torch/_inductor/lowering.py:2156: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
  warnings.warn(
/home/faycal/work/sionna-torch/sionna/venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:321: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
EbNo [dB] |        BER |       BLER |  bit errors |    num bits | block errors |  num blocks | runtime [s] |    status
---------------------------------------------------------------------------------------------------------------------------------------
     -5.0 | 2.5235e-01 | 1.0000e+00 |       44962 |      178176 |          128 |         128 |        47.4 |reached target block errors
     -4.5 | 2.3640e-01 | 1.0000e+00 |       42120 |      178176 |          128 |         128 |         0.0 |reached target block errors
     -4.0 | 2.1780e-01 | 1.0000e+00 |       38806 |      178176 |          128 |         128 |         0.0 |reached target block errors
     -3.5 | 1.9618e-01 | 1.0000e+00 |       34954 |      178176 |          128 |         128 |         0.0 |reached target block errors
     -3.0 | 1.6602e-01 | 1.0000e+00 |       29581 |      178176 |          128 |         128 |         0.0 |reached target block errors
     -2.5 | 1.0070e-01 | 9.8438e-01 |       17943 |      178176 |          126 |         128 |         0.0 |reached target block errors
     -2.0 | 1.8841e-02 | 5.1953e-01 |        6714 |      356352 |          133 |         256 |         0.0 |reached target block errors
     -1.5 | 1.2466e-03 | 3.1875e-02 |        5553 |     4454400 |          102 |        3200 |         0.2 |reached target block errors
     -1.0 | 1.5830e-04 | 2.2258e-03 |        9900 |    62539776 |          100 |       44928 |         2.6 |reached target block errors
     -0.5 | 4.5174e-05 | 5.7812e-04 |        8049 |   178176000 |           74 |      128000 |         7.6 |reached max iterations
      0.0 | 1.8532e-05 | 1.7187e-04 |        3302 |   178176000 |           22 |      128000 |         7.6 |reached max iterations
      0.5 | 4.9221e-06 | 8.5937e-05 |         877 |   178176000 |           11 |      128000 |         7.5 |reached max iterations

Simulation stopped as target BLER is reached @ EbNo = 0.5 dB.

EbNo [dB] |        BER |       BLER |  bit errors |    num bits | block errors |  num blocks | runtime [s] |    status
---------------------------------------------------------------------------------------------------------------------------------------
     -5.0 | 3.8921e-01 | 1.0000e+00 |       69347 |      178176 |          128 |         128 |        20.5 |reached target block errors
     -4.5 | 3.8127e-01 | 1.0000e+00 |       67933 |      178176 |          128 |         128 |         0.0 |reached target block errors
     -4.0 | 3.6546e-01 | 1.0000e+00 |       65116 |      178176 |          128 |         128 |         0.0 |reached target block errors
     -3.5 | 3.5705e-01 | 1.0000e+00 |       63618 |      178176 |          128 |         128 |         0.0 |reached target block errors
     -3.0 | 3.3959e-01 | 1.0000e+00 |       60506 |      178176 |          128 |         128 |         0.0 |reached target block errors
     -2.5 | 3.2678e-01 | 1.0000e+00 |       58225 |      178176 |          128 |         128 |         0.0 |reached target block errors
     -2.0 | 3.1074e-01 | 1.0000e+00 |       55367 |      178176 |          128 |         128 |         0.0 |reached target block errors
     -1.5 | 2.9492e-01 | 1.0000e+00 |       52548 |      178176 |          128 |         128 |         0.0 |reached target block errors
     -1.0 | 2.7710e-01 | 1.0000e+00 |       49372 |      178176 |          128 |         128 |         0.0 |reached target block errors
     -0.5 | 2.5576e-01 | 1.0000e+00 |       45571 |      178176 |          128 |         128 |         0.0 |reached target block errors
      0.0 | 2.3130e-01 | 1.0000e+00 |       41212 |      178176 |          128 |         128 |         0.0 |reached target block errors
      0.5 | 2.0719e-01 | 1.0000e+00 |       36917 |      178176 |          128 |         128 |         0.0 |reached target block errors
      1.0 | 1.6566e-01 | 1.0000e+00 |       29517 |      178176 |          128 |         128 |         0.0 |reached target block errors
      1.5 | 7.8720e-02 | 8.8281e-01 |       14026 |      178176 |          113 |         128 |         0.0 |reached target block errors
      2.0 | 1.0404e-02 | 2.3047e-01 |        7415 |      712704 |          118 |         512 |         0.0 |reached target block errors
      2.5 | 7.0989e-04 | 1.0959e-02 |        9107 |    12828672 |          101 |        9216 |         0.6 |reached target block errors
      3.0 | 1.7045e-04 | 1.6413e-03 |       14456 |    84811776 |          100 |       60928 |         3.6 |reached target block errors
      3.5 | 9.9217e-05 | 7.4219e-04 |       17678 |   178176000 |           95 |      128000 |         7.7 |reached max iterations
      4.0 | 4.2228e-05 | 2.9687e-04 |        7524 |   178176000 |           38 |      128000 |         7.7 |reached max iterations
      4.5 | 2.0272e-05 | 1.7187e-04 |        3612 |   178176000 |           22 |      128000 |         7.7 |reached max iterations
      5.0 | 9.8835e-06 | 7.8125e-05 |        1761 |   178176000 |           10 |      128000 |         7.7 |reached max iterations

Simulation stopped as target BLER is reached @ EbNo = 5.0 dB.

Training the Neural Receiver#

In the next cell, one forward pass through the end-to-end system is performed with gradient tracking enabled, which allows PyTorch's autograd to compute gradients and thereby optimize the neural network through stochastic gradient descent (SGD).

Note: For an introduction to the implementation of differentiable communication systems and their optimization through SGD and backpropagation with Sionna, please refer to Part 2 of the Sionna tutorial for Beginners.

[15]:
# The end-to-end system equipped with the neural receiver is instantiated for training.
# When called, it therefore returns the estimated BMD rate
model = E2ESystem('neural-receiver', training=True)

# Sample a random SNR for the batch
ebno_db = torch.empty(1, device=device).uniform_(ebno_db_min, ebno_db_max)
# Forward pass
rate = model(training_batch_size, ebno_db)
# Optimizers minimize loss functions, so we define loss as the negative BMD rate
loss = -rate
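The BMD rate returned by the model can be understood as one minus the average binary cross-entropy (BCE) between the transmitted coded bits and their LLRs, expressed in bits. As a rough sketch (not the notebook's exact implementation), assuming LLRs that follow the logit convention \(\log\left(\Pr(b=1)/\Pr(b=0)\right)\):

```python
import math
import torch
import torch.nn.functional as F

def estimate_bmd_rate(bits, llrs):
    """Estimate the BMD rate in bits per coded bit (a sketch).

    Assumes `llrs` are logits for the event bit = 1.
    """
    # Average binary cross-entropy between bits and LLRs (in nats)
    bce = F.binary_cross_entropy_with_logits(llrs, bits)
    # Convert nats to bits and subtract from 1 bit per coded bit
    return 1.0 - bce / math.log(2.0)
```

With perfectly confident and correct LLRs the estimate approaches 1 bit per coded bit, while all-zero LLRs (no information) yield a rate of 0, which is why maximizing this rate is a natural training objective.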

Next, one can perform one step of stochastic gradient descent (SGD). The Adam optimizer is used.

[16]:
optimizer = torch.optim.Adam(model.parameters())

# Computing and applying gradients
optimizer.zero_grad()
loss.backward()
optimizer.step()

Training consists of looping over SGD steps. The next cell implements a training loop.

At each iteration:

  • A batch of SNRs \(E_b/N_0\) is sampled

  • A forward pass through the end-to-end system is performed with gradient tracking enabled

  • The gradients are computed through backpropagation and applied using the Adam optimizer

  • The achieved BMD rate is periodically shown
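Inside the end-to-end model, each sampled \(E_b/N_0\) value must be converted to a noise power before the channel is simulated. Sionna provides a utility for this; a minimal sketch of the underlying formula, assuming unit-energy QPSK symbols (2 bits per symbol) and a hypothetical coderate of 0.5, is:

```python
import torch

def ebnodb2no(ebno_db, num_bits_per_symbol=2, coderate=0.5):
    # Convert Eb/N0 from dB to linear scale
    ebno = 10.0 ** (ebno_db / 10.0)
    # N0 = Es / (Eb/N0 * bits per symbol * coderate), with Es = 1
    return 1.0 / (ebno * num_bits_per_symbol * coderate)
```

For example, at 0 dB with these parameters the noise power equals 1.0, since the coderate of 0.5 exactly cancels the 2 bits per symbol.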

After training, the weights of the model are saved to a file.

Note: Training can take a while. Therefore, we have made pre-trained weights available. Do not execute the next cell if you don’t want to train the model from scratch.

[17]:
training = True # Set to False to skip training and use the pre-trained weights
if training:
    model = torch.compile(E2ESystem('neural-receiver', training=True), mode="reduce-overhead")

    optimizer = torch.optim.Adam(model.parameters())

    for i in range(num_training_iterations):
        # Sampling a batch of SNRs
        ebno_db = torch.empty(1, device=device).uniform_(ebno_db_min, ebno_db_max)
        # Forward pass
        rate = model(training_batch_size, ebno_db)
        # Optimizers minimize loss functions, so we define loss as the negative BMD rate
        loss = -rate
        # Computing and applying gradients
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Periodically printing the progress
        if i % 100 == 0:
            print('Iteration {}/{}  Rate: {:.4f} bit'.format(i, num_training_iterations, rate.item()), end='\r')

    # Save the weights in a file (access _orig_mod since model is wrapped by torch.compile)
    torch.save(model._orig_mod._neural_receiver.state_dict(), model_weights_path)
Iteration 29900/30000  Rate: 0.6487 bit
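Note that the cell above saves the weights through `model._orig_mod`, because `torch.compile` wraps the module and would otherwise prefix every `state_dict` key with `_orig_mod.`. If you instead save the compiled model's `state_dict` directly, a small helper (a sketch, not part of the notebook) can strip the prefix before loading into an uncompiled model:

```python
def strip_compile_prefix(state_dict):
    # Remove the "_orig_mod." prefix that torch.compile adds to parameter names
    return {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
```

Keys without the prefix are left untouched, so the helper is safe to apply to state dicts from uncompiled models as well.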

Evaluation of the Neural Receiver#

The next cell evaluates the neural receiver.

Note: Evaluation of the system can take a while and requires having the trained weights of the neural receiver.

[18]:
model = E2ESystem('neural-receiver')

# Run one inference to build the layers before loading the trained weights
model(1, torch.tensor(10.0))
model._neural_receiver.load_state_dict(torch.load(model_weights_path))
model.eval()

# Evaluations
_, bler = sim_ber(model, ebno_dbs, batch_size=128, num_target_block_errors=100, max_mc_iter=1000, target_bler=1e-4)
BLER['neural-receiver'] = bler.cpu().numpy()

EbNo [dB] |        BER |       BLER |  bit errors |    num bits | block errors |  num blocks | runtime [s] |    status
---------------------------------------------------------------------------------------------------------------------------------------
     -5.0 | 2.5959e-01 | 1.0000e+00 |       46252 |      178176 |          128 |         128 |         0.1 |reached target block errors
     -4.5 | 2.4163e-01 | 1.0000e+00 |       43052 |      178176 |          128 |         128 |         0.1 |reached target block errors
     -4.0 | 2.2403e-01 | 1.0000e+00 |       39917 |      178176 |          128 |         128 |         0.1 |reached target block errors
     -3.5 | 2.0239e-01 | 1.0000e+00 |       36061 |      178176 |          128 |         128 |         0.1 |reached target block errors
     -3.0 | 1.7484e-01 | 1.0000e+00 |       31153 |      178176 |          128 |         128 |         0.1 |reached target block errors
     -2.5 | 1.2488e-01 | 1.0000e+00 |       22250 |      178176 |          128 |         128 |         0.1 |reached target block errors
     -2.0 | 4.0786e-02 | 7.4219e-01 |       14534 |      356352 |          190 |         256 |         0.1 |reached target block errors
     -1.5 | 2.5356e-03 | 9.2882e-02 |        4066 |     1603584 |          107 |        1152 |         0.6 |reached target block errors
     -1.0 | 6.6897e-04 | 8.4005e-03 |       11085 |    16570368 |          100 |       11904 |         6.0 |reached target block errors
     -0.5 | 3.2312e-04 | 3.0048e-03 |       14969 |    46325760 |          100 |       33280 |        16.8 |reached target block errors
      0.0 | 1.4118e-04 | 1.1748e-03 |       16728 |   118487040 |          100 |       85120 |        43.0 |reached target block errors
      0.5 | 8.5893e-05 | 6.4844e-04 |       15304 |   178176000 |           83 |      128000 |        64.6 |reached max iterations
      1.0 | 3.7205e-05 | 3.2031e-04 |        6629 |   178176000 |           41 |      128000 |        64.6 |reached max iterations
      1.5 | 2.9448e-05 | 2.2656e-04 |        5247 |   178176000 |           29 |      128000 |        64.6 |reached max iterations
      2.0 | 1.9453e-05 | 1.4062e-04 |        3466 |   178176000 |           18 |      128000 |        64.6 |reached max iterations
      2.5 | 1.1792e-05 | 6.2500e-05 |        2101 |   178176000 |            8 |      128000 |        64.6 |reached max iterations

Simulation stopped as target BLER is reached @ EbNo = 2.5 dB.

Pre-computed Results#

Finally, we plot the BLERs.

[19]:
plt.figure(figsize=(10,6))
# Baseline - Perfect CSI
plt.semilogy(ebno_dbs, BLER['baseline-perfect-csi'], 'o-', c='C0', label='Baseline - Perfect CSI')
# Baseline - LS Estimation
plt.semilogy(ebno_dbs, BLER['baseline-ls-estimation'], 'x--', c='C1', label='Baseline - LS Estimation')
# Neural receiver
plt.semilogy(ebno_dbs, BLER['neural-receiver'], 's-.', c='C2', label='Neural receiver')
#
plt.xlabel(r"$E_b/N_0$ (dB)")
plt.ylabel("BLER")
plt.grid(which="both")
plt.ylim((1e-4, 1.0))
plt.legend()
plt.tight_layout()
[Figure: BLER vs. \(E_b/N_0\) for the two baselines and the neural receiver]

References#

[1] M. Honkala, D. Korpi and J. M. J. Huttunen, “DeepRx: Fully Convolutional Deep Learning Receiver,” in IEEE Transactions on Wireless Communications, vol. 20, no. 6, pp. 3925-3940, June 2021, doi: 10.1109/TWC.2021.3054520.

[2] F. Ait Aoudia and J. Hoydis, “End-to-end Learning for OFDM: From Neural Receivers to Pilotless Communication,” in IEEE Transactions on Wireless Communications, doi: 10.1109/TWC.2021.3101364.

[3] K. He, X. Zhang, S. Ren and J. Sun, “Deep Residual Learning for Image Recognition,” Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, pp. 770-778.

[4] G. Böcherer, “Achievable Rates for Probabilistic Shaping”, arXiv:1707.01134, 2017.