Part 4: Toward Learned Receivers
This tutorial will guide you through Sionna, from its basic principles to the implementation of a point-to-point link with a 5G NR-compliant code and a 3GPP channel model. You will also learn how to write custom trainable layers by implementing a state-of-the-art neural receiver, and how to train and evaluate end-to-end communication systems.
The tutorial is structured in four notebooks:
Part I: Getting started with Sionna
Part II: Differentiable Communication Systems
Part III: Advanced Link-level Simulations
Part IV: Toward Learned Receivers
The official documentation provides key material on how to use Sionna and how its components are implemented.
Imports
[ ]:
# Import Sionna
try:
    import sionna.phy
except ImportError as e:
    import os
    import sys
    if 'google.colab' in sys.modules:
        # Install Sionna in Google Colab
        print("Installing Sionna and restarting the runtime. Please run the cell again.")
        os.system("pip install sionna")
        os.kill(os.getpid(), 5)
    else:
        raise e
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# For plotting
%matplotlib inline
import matplotlib.pyplot as plt
# Sionna imports
from sionna.phy import Block
from sionna.phy.channel.tr38901 import Antenna, AntennaArray, CDL
from sionna.phy.channel import OFDMChannel
from sionna.phy.mimo import StreamManagement
from sionna.phy.ofdm import (ResourceGrid, ResourceGridMapper, LSChannelEstimator,
                             LMMSEEqualizer, ResourceGridDemapper)
from sionna.phy.utils import ebnodb2no, insert_dims, expand_to_rank, PlotBER
from sionna.phy.fec.ldpc import LDPC5GEncoder, LDPC5GDecoder
from sionna.phy.mapping import Mapper, Demapper, BinarySource
# Set seed for reproducible random number generation
sionna.phy.config.seed = 42
device = sionna.phy.config.device
Simulation Parameters
[2]:
# Bits per channel use
NUM_BITS_PER_SYMBOL = 2 # QPSK
# Minimum value of Eb/N0 [dB] for simulations
EBN0_DB_MIN = -3.0
# Maximum value of Eb/N0 [dB] for simulations
EBN0_DB_MAX = 5.0
# How many examples are processed by Sionna in parallel
BATCH_SIZE = 128
# Coding rate
CODERATE = 0.5
# Define the number of UT and BS antennas
NUM_UT = 1
NUM_BS = 1
NUM_UT_ANT = 1
NUM_BS_ANT = 2
# The number of transmitted streams is equal to the number of UT antennas
# in both uplink and downlink
NUM_STREAMS_PER_TX = NUM_UT_ANT
# Create an RX-TX association matrix.
# RX_TX_ASSOCIATION[i,j]=1 means that receiver i gets at least one stream
# from transmitter j. Depending on the transmission direction (uplink or downlink),
# the role of UT and BS can change.
# For example, considering a system with 2 RX and 4 TX, the RX-TX
# association matrix could be
# [ [1 , 1, 0, 0],
# [0 , 0, 1, 1] ]
# which indicates that the RX 0 receives from TX 0 and 1, and RX 1 receives from
# TX 2 and 3.
#
# In this notebook, as we have only a single transmitter and receiver,
# the RX-TX association matrix is simply:
RX_TX_ASSOCIATION = np.array([[1]])
# Instantiate a StreamManagement object
# This determines which data streams are intended for which receiver.
# In this simple setup, this is fairly easy. However, it can get more involved
# for simulations with many transmitters and receivers.
STREAM_MANAGEMENT = StreamManagement(RX_TX_ASSOCIATION, NUM_STREAMS_PER_TX)
# Resource grid configuration
NUM_OFDM_SYMBOLS = 14
FFT_SIZE = 76
RESOURCE_GRID = ResourceGrid(num_ofdm_symbols=NUM_OFDM_SYMBOLS,
                             fft_size=FFT_SIZE,
                             subcarrier_spacing=30e3,
                             num_tx=NUM_UT,
                             num_streams_per_tx=NUM_STREAMS_PER_TX,
                             cyclic_prefix_length=6,
                             pilot_pattern="kronecker",
                             pilot_ofdm_symbol_indices=[2, 11])
# Carrier frequency in Hz.
CARRIER_FREQUENCY = 2.6e9
# Antenna setting
UT_ARRAY = Antenna(polarization="single",
                   polarization_type="V",
                   antenna_pattern="38.901",
                   carrier_frequency=CARRIER_FREQUENCY)
BS_ARRAY = AntennaArray(num_rows=1,
                        num_cols=int(NUM_BS_ANT / 2),
                        polarization="dual",
                        polarization_type="cross",
                        antenna_pattern="38.901",
                        carrier_frequency=CARRIER_FREQUENCY)
# Nominal delay spread in [s]. Please see the CDL documentation
# about how to choose this value.
DELAY_SPREAD = 100e-9
# The `direction` determines if the UT or BS is transmitting.
# In the `uplink`, the UT is transmitting.
DIRECTION = "uplink"
# Suitable values are ["A", "B", "C", "D", "E"]
CDL_MODEL = "C"
# UT speed [m/s]. BSs are always assumed to be fixed.
# The direction of travel will be chosen randomly within the x-y plane.
SPEED = 10.0
# Configure a channel impulse response (CIR) generator for the CDL model.
CDL_CHANNEL = CDL(CDL_MODEL,
                  DELAY_SPREAD,
                  CARRIER_FREQUENCY,
                  UT_ARRAY,
                  BS_ARRAY,
                  DIRECTION,
                  min_speed=SPEED)
# Number of coded bits and information bits
N = int(RESOURCE_GRID.num_data_symbols * NUM_BITS_PER_SYMBOL) # Number of coded bits
K = int(N * CODERATE) # Number of information bits
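As a sanity check on the values above, the following sketch recomputes N and K by hand and estimates the maximum Doppler shift implied by the UT speed. It assumes that the Kronecker pilot pattern, with a single transmitter and a single stream, fills OFDM symbols 2 and 11 entirely with pilots, and that no guard or DC carriers are nulled (the `ResourceGrid` call above sets none). All variable names are illustrative and not part of the Sionna API.

```python
# Hand computation of the codeword sizes (illustrative names, not Sionna API)
num_ofdm_symbols = 14
fft_size = 76
num_pilot_ofdm_symbols = 2      # OFDM symbols 2 and 11
num_bits_per_symbol = 2         # QPSK
coderate = 0.5

# Assumption: pilot symbols occupy every subcarrier, no guard/DC carriers
num_data_symbols = (num_ofdm_symbols - num_pilot_ofdm_symbols) * fft_size
n = num_data_symbols * num_bits_per_symbol   # coded bits per resource grid
k = int(n * coderate)                        # information bits per resource grid

# Maximum Doppler shift f_D = v * f_c / c for the configured UT speed
speed = 10.0                     # m/s
carrier_frequency = 2.6e9        # Hz
speed_of_light = 299_792_458.0   # m/s
max_doppler_hz = speed * carrier_frequency / speed_of_light

print(num_data_symbols, n, k)       # 912 1824 912
print(f"{max_doppler_hz:.1f} Hz")   # 86.7 Hz
```

The value k = 912 agrees with the evaluation logs later in this notebook, where 128 blocks correspond to 116736 information bits.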
Implementation of an Advanced Neural Receiver
We will implement a state-of-the-art neural receiver that operates over the entire resource grid of received symbols.
The neural receiver computes LLRs on the coded bits from the received resource grid of frequency-domain baseband symbols.
As shown in the following figure, the neural receiver replaces the channel estimator, equalizer, and demapper.
As in [1] and [2], a neural receiver using residual convolutional layers is implemented.
Convolutional layers are leveraged to efficiently process the 2D resource grid that is fed as an input to the neural receiver.
Residual (skip) connections are used to mitigate vanishing gradients [3].
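To make the input format concrete, the sketch below shows one way to turn a complex resource grid into real-valued channels for a 2D convolution. It uses NumPy and random stand-in data, and the shapes and the noise value are assumptions chosen for illustration. The PyTorch implementation in the next cell performs the equivalent steps with torch operations.

```python
import numpy as np

batch_size, num_rx_ant, num_ofdm_symbols, fft_size = 4, 2, 14, 76
rng = np.random.default_rng(0)
grid_shape = (batch_size, num_rx_ant, num_ofdm_symbols, fft_size)

# Stand-in for the received resource grid (complex, channels-first layout)
y = rng.standard_normal(grid_shape) + 1j * rng.standard_normal(grid_shape)
no = np.full(batch_size, 0.1)   # one noise variance per example (assumed value)

# Noise variance in log10 scale, broadcast over the whole time-frequency grid
no_map = np.broadcast_to(np.log10(no).reshape(-1, 1, 1, 1),
                         (batch_size, 1, num_ofdm_symbols, fft_size))

# Real parts, imaginary parts, and the noise map become convolution channels
z = np.concatenate([y.real, y.imag, no_map], axis=1)
print(z.shape)   # (4, 5, 14, 76)
```

The channel dimension has size 2 * num_rx_ant + 1, which is why the first convolution of the receiver expects that many input channels.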
For convenience, a PyTorch module that implements a residual block is first defined. The neural receiver module is built by stacking such blocks. The following figure shows the architecture of the neural receiver.
[3]:
class ResidualBlock(nn.Module):
    """
    Convolutional residual block with two convolutional layers, ReLU activation,
    layer normalization, and a skip connection.

    The number of convolutional channels of the input must match num_conv_channels
    for the skip connection to work.

    Input shape: [batch_size, num_conv_channels, num_ofdm_symbols, num_subcarriers]
    Output shape: [batch_size, num_conv_channels, num_ofdm_symbols, num_subcarriers]
    """

    def __init__(self, num_conv_channels: int = 128):
        super().__init__()
        # Layer normalization over the last three dimensions (C, H, W)
        self._layer_norm_1 = nn.LayerNorm([num_conv_channels, NUM_OFDM_SYMBOLS, FFT_SIZE])
        self._conv_1 = nn.Conv2d(
            in_channels=num_conv_channels,
            out_channels=num_conv_channels,
            kernel_size=3,
            padding=1,  # 'same' padding
        )
        self._layer_norm_2 = nn.LayerNorm([num_conv_channels, NUM_OFDM_SYMBOLS, FFT_SIZE])
        self._conv_2 = nn.Conv2d(
            in_channels=num_conv_channels,
            out_channels=num_conv_channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        z = self._layer_norm_1(inputs)
        z = F.relu(z)
        z = self._conv_1(z)
        z = self._layer_norm_2(z)
        z = F.relu(z)
        z = self._conv_2(z)
        # Skip connection
        z = z + inputs
        return z
class NeuralReceiver(nn.Module):
    """
    Residual convolutional neural receiver.

    This neural receiver is fed with the post-DFT received samples, forming a
    resource grid of size num_ofdm_symbols x fft_size, and computes LLRs on
    the transmitted coded bits.

    Input
    -----
    y : [batch_size, num_rx_antenna, num_ofdm_symbols, num_subcarriers], complex
        Received post-DFT samples.
    no : [batch_size], float
        Noise variance.

    Output
    ------
    llr : [batch_size, num_ofdm_symbols, num_subcarriers, num_bits_per_symbol], float
        LLRs on the transmitted bits.
    """

    def __init__(self, num_conv_channels: int = 128):
        super().__init__()
        self._num_bits_per_symbol = NUM_BITS_PER_SYMBOL
        # Input convolution: 2*num_rx_antenna + 1 input channels (real, imag, noise)
        num_input_channels = 2 * NUM_BS_ANT + 1
        self._input_conv = nn.Conv2d(
            in_channels=num_input_channels,
            out_channels=num_conv_channels,
            kernel_size=3,
            padding=1,
        )
        # Residual blocks
        self._res_block_1 = ResidualBlock(num_conv_channels)
        self._res_block_2 = ResidualBlock(num_conv_channels)
        self._res_block_3 = ResidualBlock(num_conv_channels)
        self._res_block_4 = ResidualBlock(num_conv_channels)
        # Output convolution
        self._output_conv = nn.Conv2d(
            in_channels=num_conv_channels,
            out_channels=NUM_BITS_PER_SYMBOL,
            kernel_size=3,
            padding=1,
        )

    def forward(self, y: torch.Tensor, no: torch.Tensor) -> torch.Tensor:
        # y: [batch, num_rx_ant, num_ofdm_symbols, num_subcarriers]
        # no: [batch]
        # Feeding the noise power in log10 scale helps with the performance
        no = torch.log10(no)
        # Stack real and imaginary components
        y_real = y.real  # [batch, num_rx_ant, time, freq]
        y_imag = y.imag  # [batch, num_rx_ant, time, freq]
        # Reshape noise to [batch, 1, 1, 1] and broadcast
        batch_size = y.shape[0]
        no = no.view(-1, 1, 1, 1)
        no = no.expand(batch_size, 1, y.shape[2], y.shape[3])  # [batch, 1, time, freq]
        # Concatenate: [batch, 2*num_rx_ant + 1, time, freq]
        # PyTorch Conv2d expects channels-first format
        z = torch.cat([y_real, y_imag, no], dim=1)
        # Input conv
        z = self._input_conv(z)
        # Residual blocks
        z = self._res_block_1(z)
        z = self._res_block_2(z)
        z = self._res_block_3(z)
        z = self._res_block_4(z)
        # Output conv: [batch, num_bits_per_symbol, time, freq]
        z = self._output_conv(z)
        # Transpose to [batch, time, freq, num_bits_per_symbol]
        z = z.permute(0, 2, 3, 1)
        return z
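To get a feeling for the size of this architecture, the following sketch counts its trainable parameters by hand. It assumes the PyTorch defaults used above (`bias=True` for `nn.Conv2d`, `elementwise_affine=True` for `nn.LayerNorm`) and the default of 128 convolutional channels. The helper functions are hypothetical and only serve this illustration.

```python
def conv2d_params(c_in, c_out, kernel_size=3):
    # Kernel weights plus one bias per output channel
    return c_in * c_out * kernel_size * kernel_size + c_out

def layer_norm_params(shape):
    # One scale and one shift per normalized element
    count = 1
    for dim in shape:
        count *= dim
    return 2 * count

num_conv_channels = 128
num_ofdm_symbols, fft_size = 14, 76
num_input_channels = 2 * 2 + 1   # real/imag for 2 BS antennas + noise map
num_bits_per_symbol = 2

norm_shape = [num_conv_channels, num_ofdm_symbols, fft_size]
res_block = (2 * layer_norm_params(norm_shape)
             + 2 * conv2d_params(num_conv_channels, num_conv_channels))
total = (conv2d_params(num_input_channels, num_conv_channels)
         + 4 * res_block
         + conv2d_params(num_conv_channels, num_bits_per_symbol))

print(res_block)   # 839936
print(total)       # 3367938
```

Under these assumptions, most parameters belong to the layer normalizations, because they normalize over the channel, time, and frequency dimensions. This also ties the model to the resource grid size it was built for.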
The task of the receiver is to jointly solve, for each resource element, NUM_BITS_PER_SYMBOL binary classification problems in order to reconstruct the transmitted bits. Therefore, a natural choice for the loss function is the binary cross-entropy (BCE) applied to each bit and to each received symbol.
Remark 1: The LLRs produced by the neural receiver, like those of a conventional demapper, are logits on the transmitted bits and can therefore be used as-is to compute the BCE, without any additional processing.
Remark 2: The BCE is closely related to an achievable information rate for bit-interleaved coded modulation systems [4,5].
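The two remarks can be illustrated with a small pure-Python sketch. It evaluates the BCE for a single bit directly from its logit, using the numerically stable form that `binary_cross_entropy_with_logits` is documented to implement. The rate formula in `rate_estimate_bits` is an assumption stated here for illustration: it is the commonly used estimate of one minus the BCE divided by log(2), and is not taken from this notebook.

```python
import math

def bce_with_logits(llr, bit):
    """Binary cross-entropy (in nats) for one bit, given its logit."""
    # Stable form of -[b*log(sigmoid(l)) + (1-b)*log(1-sigmoid(l))]
    return max(llr, 0.0) - llr * bit + math.log1p(math.exp(-abs(llr)))

def rate_estimate_bits(mean_bce_nats):
    """Assumed relation: rate per coded bit in bits, 1 - BCE / log(2)."""
    return 1.0 - mean_bce_nats / math.log(2.0)

# An uninformative LLR of 0 costs log(2) nats, i.e. it conveys no information
print(bce_with_logits(0.0, 1), rate_estimate_bits(bce_with_logits(0.0, 1)))
# A confident and correct LLR costs almost nothing ...
print(bce_with_logits(8.0, 1))
# ... while a confident but wrong LLR is penalized roughly linearly in |LLR|
print(bce_with_logits(8.0, 0))
```

Minimizing the average BCE over bits and resource elements therefore pushes the LLRs toward confident and correct values, which under the assumed relation corresponds to maximizing the rate estimate.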
The next cell defines an end-to-end communication system using the neural receiver layer.
At initialization, the parameter training indicates if the system is instantiated to be trained (True) or evaluated (False).
If the system is instantiated to be trained, the outer encoder and decoder are not used as they are not required for training. Moreover, the estimated BCE is returned. This significantly reduces the computational complexity of training.
If the system is instantiated to be evaluated, the outer encoder and decoder are used, and the transmitted information bits and the corresponding decoded bit estimates are returned.
[4]:
class OFDMSystemNeuralReceiver(Block):  # Inherits from Sionna Block
    """
    End-to-end OFDM system with neural receiver.

    Inherits from Sionna Block for lazy building and device/precision management.
    """

    def __init__(self, training: bool):
        super().__init__()  # Must call the parent class initializer
        self._training = training
        self._k = K
        self._n = N
        # Transmitter components
        self._binary_source = BinarySource()
        if not training:
            self._encoder = LDPC5GEncoder(K, N)
        self._mapper = Mapper("qam", NUM_BITS_PER_SYMBOL)
        self._rg_mapper = ResourceGridMapper(RESOURCE_GRID)
        # Channel
        self._channel = OFDMChannel(CDL_CHANNEL, RESOURCE_GRID, add_awgn=True,
                                    normalize_channel=True, return_channel=False)
        # Neural receiver
        self._neural_receiver = NeuralReceiver().to(device)
        self._rg_demapper = ResourceGridDemapper(RESOURCE_GRID, STREAM_MANAGEMENT)
        # Decoder (only for evaluation)
        if not training:
            self._decoder = LDPC5GDecoder(self._encoder, hard_out=True)

    def call(self, batch_size: int, ebno_db: torch.Tensor):
        """
        Forward pass through the end-to-end system.

        Parameters
        ----------
        batch_size : int
            Number of samples in the batch
        ebno_db : torch.Tensor
            Eb/N0 in dB, shape [batch_size] or scalar

        Returns
        -------
        If training: loss (scalar)
        If not training: (bits, bits_hat) tuple
        """
        no = ebnodb2no(ebno_db, num_bits_per_symbol=NUM_BITS_PER_SYMBOL,
                       coderate=CODERATE, resource_grid=RESOURCE_GRID)
        # Ensure no has shape [batch_size]
        if no.dim() == 0:
            no = no.expand(batch_size)
        # Transmitter
        if self._training:
            codewords = self._binary_source([batch_size, NUM_UT, NUM_UT_ANT, self._n])
        else:
            bits = self._binary_source([batch_size, NUM_UT, NUM_UT_ANT, self._k])
            codewords = self._encoder(bits)
        x = self._mapper(codewords)
        x_rg = self._rg_mapper(x)
        # Channel
        no_ = expand_to_rank(no, x_rg.ndim)
        y = self._channel(x_rg, no_)
        # Receiver
        y = y.squeeze(1)  # Remove num_rx dimension (assuming single receiver)
        llr = self._neural_receiver(y, no)
        llr = insert_dims(llr, 2, 1)  # Add dimensions for rg_demapper
        llr = self._rg_demapper(llr)
        llr = llr.reshape(batch_size, NUM_UT, NUM_UT_ANT, self._n)
        if self._training:
            # Compute BCE loss
            loss = F.binary_cross_entropy_with_logits(llr, codewords.float())
            return loss
        else:
            bits_hat = self._decoder(llr)
            return bits, bits_hat
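The conversion from Eb/N0 to the noise variance at the start of `call` can be sketched as follows. This is a simplified textbook version that assumes unit-energy symbols and ignores any resource-grid overhead. Sionna's `ebnodb2no` is documented to additionally account for pilot and cyclic prefix overhead when a resource grid is passed, so its values will differ from this sketch. The function name is hypothetical.

```python
def ebnodb_to_no_simplified(ebno_db, num_bits_per_symbol, coderate):
    """Noise variance for unit-energy symbols, ignoring resource-grid overhead."""
    ebno = 10.0 ** (ebno_db / 10.0)
    # With Es = 1, the energy per information bit is
    # Eb = Es / (coderate * num_bits_per_symbol), hence No = Eb / (Eb/N0)
    return 1.0 / (ebno * coderate * num_bits_per_symbol)

# QPSK with rate 1/2 carries one information bit per symbol, so No = 1 / (Eb/N0)
for ebno_db in (-3.0, 0.0, 5.0):
    print(ebno_db, ebnodb_to_no_simplified(ebno_db, 2, 0.5))
```

With this convention, a lower coding rate or a lower modulation order at the same Eb/N0 leads to a larger noise variance per symbol, since each symbol carries fewer information bits.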
Training the Neural Receiver
The next cell implements a training loop of NUM_TRAINING_ITERATIONS iterations.
At each iteration:
A batch of SNRs \(E_b/N_0\) is sampled
A forward pass through the end-to-end system is performed, with PyTorch autograd recording the operations
The gradients are computed by backpropagation with loss.backward() and applied using the Adam optimizer
The training progress is printed periodically
After training, the weights of the neural receiver are saved to a file using torch.save.
Executing the next cell will take quite a while. If you do not want to train your own neural receiver, you can download the weights here and use them later on.
[5]:
train = True # Set to False to skip training and use pre-trained weights instead
if train:
    # Number of iterations used for training
    NUM_TRAINING_ITERATIONS = 30000
    # Instantiating the end-to-end model for training
    model = OFDMSystemNeuralReceiver(training=True)
    # Use torch.compile for faster execution
    try:
        compiled_model = torch.compile(model, mode="reduce-overhead")
    except Exception:
        print("torch.compile not available, using eager mode")
        compiled_model = model
    # Adam optimizer (SGD variant)
    optimizer = torch.optim.Adam(model.parameters())
    # Training loop
    for i in range(NUM_TRAINING_ITERATIONS):
        # Sample a batch of SNRs
        ebno_db = torch.empty(BATCH_SIZE, device=device).uniform_(EBN0_DB_MIN, EBN0_DB_MAX)
        # Forward pass
        loss = compiled_model(BATCH_SIZE, ebno_db)
        # Computing and applying gradients
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Print progress
        if i % 100 == 0:
            print(f"{i}/{NUM_TRAINING_ITERATIONS} Loss: {loss.item():.2E}", end="\r")
    # Save the weights in a file
    torch.save(model._neural_receiver.state_dict(), 'weights-ofdm-neuralrx.pt')
29900/30000 Loss: 2.19E-01
Benchmarking the Neural Receiver
We evaluate the trained model and benchmark it against the previously introduced baselines.
We first define and evaluate the baselines.
[6]:
class OFDMSystem(Block):  # Inherits from Sionna Block
    """
    End-to-end OFDM system with conventional receiver (LS estimation + LMMSE equalization).

    Inherits from Sionna Block for lazy building and device/precision management.
    """

    def __init__(self, perfect_csi: bool):
        super().__init__()  # Must call the parent class initializer
        self._perfect_csi = perfect_csi
        self._k = K
        # Transmitter components
        self._binary_source = BinarySource()
        self._encoder = LDPC5GEncoder(K, N)
        self._mapper = Mapper("qam", NUM_BITS_PER_SYMBOL)
        self._rg_mapper = ResourceGridMapper(RESOURCE_GRID)
        # Channel (return_channel=True to get channel frequency response)
        self._channel = OFDMChannel(CDL_CHANNEL, RESOURCE_GRID, add_awgn=True,
                                    normalize_channel=True, return_channel=True)
        # Receiver components
        self._ls_est = LSChannelEstimator(RESOURCE_GRID, interpolation_type="nn")
        self._lmmse_equ = LMMSEEqualizer(RESOURCE_GRID, STREAM_MANAGEMENT)
        self._demapper = Demapper("app", "qam", NUM_BITS_PER_SYMBOL)
        self._decoder = LDPC5GDecoder(self._encoder, hard_out=True)

    def call(self, batch_size: int, ebno_db: torch.Tensor):
        """
        Forward pass through the end-to-end system.

        Parameters
        ----------
        batch_size : int
            Number of samples in the batch
        ebno_db : torch.Tensor
            Eb/N0 in dB, shape [batch_size] or scalar

        Returns
        -------
        (bits, bits_hat) tuple
        """
        no = ebnodb2no(ebno_db, num_bits_per_symbol=NUM_BITS_PER_SYMBOL,
                       coderate=CODERATE, resource_grid=RESOURCE_GRID)
        # Transmitter
        bits = self._binary_source([batch_size, NUM_UT, RESOURCE_GRID.num_streams_per_tx, self._k])
        codewords = self._encoder(bits)
        x = self._mapper(codewords)
        x_rg = self._rg_mapper(x)
        # Channel
        no_ = expand_to_rank(no, x_rg.ndim)
        y, h_freq = self._channel(x_rg, no_)
        # Receiver
        if self._perfect_csi:
            h_hat, err_var = h_freq, 0.0
        else:
            h_hat, err_var = self._ls_est(y, no)
        x_hat, no_eff = self._lmmse_equ(y, h_hat, err_var, no)
        no_eff_ = expand_to_rank(no_eff, x_hat.ndim)
        llr = self._demapper(x_hat, no_eff_)
        bits_hat = self._decoder(llr)
        return bits, bits_hat
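For intuition about what the baseline receiver does, the sketch below applies LS channel estimation and unbiased LMMSE equalization to a single resource element with one stream and two receive antennas, using NumPy. It is a simplified illustration and not Sionna's implementation: it uses no interpolation across the grid, it ignores the channel estimation error variance in the effective noise, and the function name and numeric values are assumptions.

```python
import numpy as np

def ls_lmmse_single_re(h, x, pilot, no, rng):
    """LS estimation on one pilot, then unbiased LMMSE equalization of one symbol."""
    def awgn():
        return np.sqrt(no / 2) * (rng.standard_normal(h.shape)
                                  + 1j * rng.standard_normal(h.shape))

    # LS channel estimate from a known pilot symbol: h_hat = y_p / p
    y_pilot = h * pilot + awgn()
    h_hat = y_pilot / pilot

    # Single-stream unbiased LMMSE: x_hat = h_hat^H y / ||h_hat||^2
    y_data = h * x + awgn()
    gain = np.vdot(h_hat, h_hat).real
    x_hat = np.vdot(h_hat, y_data) / gain
    no_eff = no / gain   # effective noise variance after equalization
    return h_hat, x_hat, no_eff

rng = np.random.default_rng(1)
h = (rng.standard_normal(2) + 1j * rng.standard_normal(2)) / np.sqrt(2)
pilot = (1 + 1j) / np.sqrt(2)
x = (1 - 1j) / np.sqrt(2)   # one QPSK symbol

h_hat, x_hat, no_eff = ls_lmmse_single_re(h, x, pilot, no=0.05, rng=rng)
print(abs(x_hat - x), no_eff)
```

The equalized symbol and the effective noise variance are what a demapper needs to compute LLRs. The neural receiver learns this whole chain jointly from the received grid instead.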
[7]:
ber_plots = PlotBER("Advanced neural receiver")
baseline_ls = OFDMSystem(False)
ber_plots.simulate(baseline_ls,
                   ebno_dbs=np.linspace(EBN0_DB_MIN, EBN0_DB_MAX, 20),
                   batch_size=BATCH_SIZE,
                   num_target_block_errors=100, # simulate until 100 block errors occurred
                   legend="Baseline: LS Estimation",
                   soft_estimates=True,
                   max_mc_iter=100, # run 100 Monte-Carlo simulations (each with batch_size samples)
                   show_fig=False)
baseline_pcsi = OFDMSystem(True)
ber_plots.simulate(baseline_pcsi,
                   ebno_dbs=np.linspace(EBN0_DB_MIN, EBN0_DB_MAX, 20),
                   batch_size=BATCH_SIZE,
                   num_target_block_errors=100, # simulate until 100 block errors occurred
                   legend="Baseline: Perfect CSI",
                   soft_estimates=True,
                   max_mc_iter=100, # run 100 Monte-Carlo simulations (each with batch_size samples)
                   show_fig=False)
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
-3.0 | 3.7178e-01 | 1.0000e+00 | 43400 | 116736 | 128 | 128 | 0.1 |reached target block errors
-2.579 | 3.6117e-01 | 1.0000e+00 | 42161 | 116736 | 128 | 128 | 0.0 |reached target block errors
-2.158 | 3.4515e-01 | 1.0000e+00 | 40292 | 116736 | 128 | 128 | 0.0 |reached target block errors
-1.737 | 3.3232e-01 | 1.0000e+00 | 38794 | 116736 | 128 | 128 | 0.0 |reached target block errors
-1.316 | 3.1903e-01 | 1.0000e+00 | 37242 | 116736 | 128 | 128 | 0.0 |reached target block errors
-0.895 | 3.0887e-01 | 1.0000e+00 | 36056 | 116736 | 128 | 128 | 0.0 |reached target block errors
-0.474 | 2.9605e-01 | 1.0000e+00 | 34560 | 116736 | 128 | 128 | 0.0 |reached target block errors
-0.053 | 2.7512e-01 | 1.0000e+00 | 32116 | 116736 | 128 | 128 | 0.0 |reached target block errors
0.368 | 2.6070e-01 | 1.0000e+00 | 30433 | 116736 | 128 | 128 | 0.0 |reached target block errors
0.789 | 2.4271e-01 | 1.0000e+00 | 28333 | 116736 | 128 | 128 | 0.0 |reached target block errors
1.211 | 2.2145e-01 | 1.0000e+00 | 25851 | 116736 | 128 | 128 | 0.0 |reached target block errors
1.632 | 1.9261e-01 | 1.0000e+00 | 22484 | 116736 | 128 | 128 | 0.0 |reached target block errors
2.053 | 1.5621e-01 | 9.9219e-01 | 18235 | 116736 | 127 | 128 | 0.0 |reached target block errors
2.474 | 9.8778e-02 | 8.9062e-01 | 11531 | 116736 | 114 | 128 | 0.0 |reached target block errors
2.895 | 2.6594e-02 | 4.0234e-01 | 6209 | 233472 | 103 | 256 | 0.1 |reached target block errors
3.316 | 1.7423e-03 | 4.3837e-02 | 3661 | 2101248 | 101 | 2304 | 0.6 |reached target block errors
3.737 | 1.2498e-04 | 2.1875e-03 | 1459 | 11673600 | 28 | 12800 | 3.1 |reached max iterations
4.158 | 1.5077e-05 | 7.8125e-05 | 176 | 11673600 | 1 | 12800 | 3.2 |reached max iterations
4.579 | 0.0000e+00 | 0.0000e+00 | 0 | 11673600 | 0 | 12800 | 3.2 |reached max iterations
Simulation stopped as no error occurred @ EbNo = 4.6 dB.
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
-3.0 | 2.1718e-01 | 1.0000e+00 | 25353 | 116736 | 128 | 128 | 0.0 |reached target block errors
-2.579 | 1.9822e-01 | 1.0000e+00 | 23139 | 116736 | 128 | 128 | 0.0 |reached target block errors
-2.158 | 1.7697e-01 | 1.0000e+00 | 20659 | 116736 | 128 | 128 | 0.0 |reached target block errors
-1.737 | 1.3037e-01 | 9.9219e-01 | 15219 | 116736 | 127 | 128 | 0.0 |reached target block errors
-1.316 | 7.4656e-02 | 8.8281e-01 | 8715 | 116736 | 113 | 128 | 0.0 |reached target block errors
-0.895 | 1.4948e-02 | 3.4635e-01 | 5235 | 350208 | 133 | 384 | 0.1 |reached target block errors
-0.474 | 5.6842e-04 | 2.6210e-02 | 2057 | 3618816 | 104 | 3968 | 1.0 |reached target block errors
-0.053 | 3.1181e-05 | 5.4688e-04 | 364 | 11673600 | 7 | 12800 | 3.1 |reached max iterations
0.368 | 2.1673e-05 | 1.5625e-04 | 253 | 11673600 | 2 | 12800 | 3.1 |reached max iterations
0.789 | 0.0000e+00 | 0.0000e+00 | 0 | 11673600 | 0 | 12800 | 3.1 |reached max iterations
Simulation stopped as no error occurred @ EbNo = 0.8 dB.
[7]:
(tensor([2.1718e-01, 1.9822e-01, 1.7697e-01, 1.3037e-01, 7.4656e-02, 1.4948e-02,
5.6842e-04, 3.1181e-05, 2.1673e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00], device='cuda:0'),
tensor([1.0000e+00, 1.0000e+00, 1.0000e+00, 9.9219e-01, 8.8281e-01, 3.4635e-01,
2.6210e-02, 5.4688e-04, 1.5625e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00], device='cuda:0'))
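The status column in the logs above reflects the two stopping criteria passed to `simulate`. The following toy loop mimics that behavior for a single Eb/N0 point. It is a simplified stand-in with a hypothetical helper name, not Sionna's implementation, and it replaces the end-to-end model with independent random block errors.

```python
import random

def simulate_point(block_error_prob, batch_size=128,
                   num_target_block_errors=100, max_mc_iter=100, seed=0):
    """Toy Monte Carlo loop for one Eb/N0 point (hypothetical helper)."""
    rng = random.Random(seed)
    block_errors = 0
    num_blocks = 0
    for _ in range(max_mc_iter):
        # Stand-in for one call of the end-to-end model on a batch
        block_errors += sum(rng.random() < block_error_prob
                            for _ in range(batch_size))
        num_blocks += batch_size
        if block_errors >= num_target_block_errors:
            return block_errors / num_blocks, num_blocks, "reached target block errors"
    return block_errors / num_blocks, num_blocks, "reached max iterations"

print(simulate_point(1.0))   # stops after the first batch of 128 blocks
print(simulate_point(0.0))   # runs all 100 iterations, i.e. 12800 blocks
```

This explains why low-SNR points finish after a single batch, while high-SNR points process 12800 blocks and may still end with very few observed block errors, which makes their BLER estimates noisy.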
We then instantiate and evaluate the end-to-end system equipped with the neural receiver.
[8]:
# Instantiating the end-to-end model for evaluation
model_neuralrx = OFDMSystemNeuralReceiver(training=False)
# Run one inference to build the model
model_neuralrx(1, torch.tensor(10.0))
# Load the trained weights
model_neuralrx._neural_receiver.load_state_dict(
    torch.load('weights-ofdm-neuralrx.pt', weights_only=True)
)
[8]:
<All keys matched successfully>
[9]:
# Computing and plotting BER
ber_plots.simulate(model_neuralrx,
                   ebno_dbs=np.linspace(EBN0_DB_MIN, EBN0_DB_MAX, 20),
                   batch_size=BATCH_SIZE,
                   num_target_block_errors=100,
                   legend="Neural Receiver",
                   soft_estimates=True,
                   max_mc_iter=100,
                   show_fig=True)
EbNo [dB] | BER | BLER | bit errors | num bits | block errors | num blocks | runtime [s] | status
---------------------------------------------------------------------------------------------------------------------------------------
-3.0 | 2.2578e-01 | 1.0000e+00 | 26357 | 116736 | 128 | 128 | 0.0 |reached target block errors
-2.579 | 2.0454e-01 | 1.0000e+00 | 23877 | 116736 | 128 | 128 | 0.0 |reached target block errors
-2.158 | 1.7964e-01 | 1.0000e+00 | 20971 | 116736 | 128 | 128 | 0.0 |reached target block errors
-1.737 | 1.4847e-01 | 1.0000e+00 | 17332 | 116736 | 128 | 128 | 0.0 |reached target block errors
-1.316 | 9.6765e-02 | 9.4531e-01 | 11296 | 116736 | 121 | 128 | 0.0 |reached target block errors
-0.895 | 3.1104e-02 | 5.3906e-01 | 7262 | 233472 | 138 | 256 | 0.1 |reached target block errors
-0.474 | 2.2294e-03 | 6.5755e-02 | 3123 | 1400832 | 101 | 1536 | 0.6 |reached target block errors
-0.053 | 2.4962e-04 | 4.6094e-03 | 2914 | 11673600 | 59 | 12800 | 4.6 |reached max iterations
0.368 | 1.0948e-04 | 1.0156e-03 | 1278 | 11673600 | 13 | 12800 | 4.6 |reached max iterations
0.789 | 5.1912e-05 | 3.9063e-04 | 606 | 11673600 | 5 | 12800 | 4.6 |reached max iterations
1.211 | 3.5379e-05 | 2.3437e-04 | 413 | 11673600 | 3 | 12800 | 4.7 |reached max iterations
1.632 | 1.7133e-05 | 1.5625e-04 | 200 | 11673600 | 2 | 12800 | 4.6 |reached max iterations
2.053 | 4.1118e-06 | 7.8125e-05 | 48 | 11673600 | 1 | 12800 | 4.7 |reached max iterations
2.474 | 1.0879e-05 | 7.8125e-05 | 127 | 11673600 | 1 | 12800 | 4.6 |reached max iterations
2.895 | 4.1461e-05 | 2.3437e-04 | 484 | 11673600 | 3 | 12800 | 4.6 |reached max iterations
3.316 | 0.0000e+00 | 0.0000e+00 | 0 | 11673600 | 0 | 12800 | 4.6 |reached max iterations
Simulation stopped as no error occurred @ EbNo = 3.3 dB.
[9]:
(tensor([2.2578e-01, 2.0454e-01, 1.7964e-01, 1.4847e-01, 9.6765e-02, 3.1104e-02,
2.2294e-03, 2.4962e-04, 1.0948e-04, 5.1912e-05, 3.5379e-05, 1.7133e-05,
4.1118e-06, 1.0879e-05, 4.1461e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00], device='cuda:0'),
tensor([1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 9.4531e-01, 5.3906e-01,
6.5755e-02, 4.6094e-03, 1.0156e-03, 3.9063e-04, 2.3437e-04, 1.5625e-04,
7.8125e-05, 7.8125e-05, 2.3437e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00], device='cuda:0'))
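One way to compare the three curves numerically is to interpolate the Eb/N0 required to reach a target BLER. The sketch below does this for a BLER of 0.1, using (Eb/N0, BLER) pairs copied from the logs in this notebook. The choice of target and the interpolation in the log10(BLER) domain are assumptions made for illustration.

```python
import math

# (Eb/N0 [dB], BLER) pairs copied from the logs above, around the waterfall region
curves = {
    "Baseline: LS Estimation": [(2.474, 8.9062e-01), (2.895, 4.0234e-01), (3.316, 4.3837e-02)],
    "Baseline: Perfect CSI": [(-1.316, 8.8281e-01), (-0.895, 3.4635e-01), (-0.474, 2.6210e-02)],
    "Neural Receiver": [(-1.316, 9.4531e-01), (-0.895, 5.3906e-01), (-0.474, 6.5755e-02)],
}

def ebno_at_bler(points, target_bler):
    """Interpolate Eb/N0 linearly over log10(BLER) between bracketing points."""
    for (e0, b0), (e1, b1) in zip(points, points[1:]):
        if b0 >= target_bler >= b1:
            t = ((math.log10(target_bler) - math.log10(b0))
                 / (math.log10(b1) - math.log10(b0)))
            return e0 + t * (e1 - e0)
    raise ValueError("target BLER is not bracketed by the given points")

required = {name: ebno_at_bler(points, 0.1) for name, points in curves.items()}
for name, ebno_db in required.items():
    print(f"{name}: {ebno_db:.2f} dB")
```

With these logged values, the neural receiver sits within a fraction of a dB of the perfect-CSI baseline at a BLER of 0.1, while the LS baseline needs several dB more. The logs come from a single run, so such numbers are indicative only.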
Conclusion
We hope you are excited about Sionna - there is much more to be discovered:
PyTorch profiling and debugging tools available
Scaling to multi-GPU simulation is simple
See the available tutorials for more examples
And if something is still missing - the project is open-source: you can modify, add, and extend any component at any time.