Training#

This notebook demonstrates how we train a Generative Adversarial Network (GAN) to generate stellar core-collapse signals. The GAN is trained on a dataset of 1764 stellar core-collapse signals, each sampled at 256 timestamps.

%load_ext autoreload
%autoreload 2
%matplotlib inline
! pip install starccato -q

Load training data#

from starccato.training.training_data import TrainingData

training_data = TrainingData()
training_data.summary()

Signal Dataset mean: -0.516 +/- 39.724
Signal Dataset scaling factor (to match noise in generator): 5
Signal Dataset shape: (256, 1684)

from starccato.plotting import overplot_signals

signals = training_data.standardize(training_data.raw_signals)[:, 130:-50]
fig = overplot_signals(signals, color="k", alpha=0.01, linewidth=0.2)
fig.axes[0].set_axis_off()
fig.axes[0].grid(False)
_ = fig.suptitle("Standardised Stellar Core Collapse Signals [Training Data]")

from starccato.plotting import plot_stacked_signals

fig = plot_stacked_signals(signals, norm="linear", cmap="inferno_r")

fig, axes = training_data.plot_waveforms(standardised=True)
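
The standardisation and cropping applied before plotting can be illustrated on synthetic data. This is a minimal sketch, not starccato's implementation. It assumes that the array passed to the plotting helpers holds one signal per row, that standardising means subtracting the dataset mean and dividing by the dataset standard deviation, and that `[:, 130:-50]` therefore keeps time samples 130 to 205. `standardize_signals` is a hypothetical helper name, and the random array is only a stand-in for the real signals.

```python
import numpy as np


def standardize_signals(signals: np.ndarray) -> np.ndarray:
    """Shift and scale the whole dataset to zero mean and unit variance.

    Hypothetical helper: the real TrainingData.standardize may differ.
    """
    return (signals - signals.mean()) / signals.std()


rng = np.random.default_rng(0)

# Synthetic stand-in for the raw signals (assumed layout: one signal per row).
raw = rng.normal(loc=-0.5, scale=40.0, size=(1684, 256))

standardized = standardize_signals(raw)

# Keep only the window around the main feature, as in the plotting cell.
cropped = standardized[:, 130:-50]

print(cropped.shape)  # (1684, 76)
```

Cropping only affects what is plotted here; the sketch leaves the full 256-sample signals untouched.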

Train generator and discriminator models#

For details on the model, see the model architecture.

For details on the training, see the training code.

from starccato.training import train

result = train(num_epochs=8)
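
For orientation, the quantities a GAN training loop usually tracks can be sketched with NumPy. The source does not state which objective `train` uses, so this assumes the standard binary cross-entropy formulation with a non-saturating generator loss. The function names and the example discriminator outputs are hypothetical.

```python
import numpy as np


def bce(predictions: np.ndarray, targets: np.ndarray, eps: float = 1e-12) -> float:
    """Mean binary cross-entropy between probabilities and 0/1 targets."""
    p = np.clip(predictions, eps, 1.0 - eps)  # avoid log(0)
    return float(-np.mean(targets * np.log(p) + (1.0 - targets) * np.log(1.0 - p)))


def discriminator_loss(d_real: np.ndarray, d_fake: np.ndarray) -> float:
    """Real samples should be scored 1, generated samples 0."""
    return bce(d_real, np.ones_like(d_real)) + bce(d_fake, np.zeros_like(d_fake))


def generator_loss(d_fake: np.ndarray) -> float:
    """Non-saturating form: the generator wants its samples scored as real."""
    return bce(d_fake, np.ones_like(d_fake))


# Hypothetical discriminator outputs (probabilities) for one batch.
d_real = np.array([0.9, 0.8, 0.95])
d_fake = np.array([0.1, 0.3, 0.2])

print("discriminator loss:", discriminator_loss(d_real, d_fake))
print("generator loss:", generator_loss(d_fake))
```

In this example the discriminator separates real from generated samples well, so its loss is small while the generator loss is large. When the discriminator outputs 0.5 for everything, the generator loss equals ln 2.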

Plots#

Signals#

[Figure: Signals]

Training Loss plot#

[Figure: Losses]

Gradients#

[Figure: Generator Gradients]

[Figure: Discriminator Gradients]
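
Gradient plots like these are commonly used to spot vanishing or exploding gradients during GAN training. As a rough sketch, and assuming the plots track a per-layer gradient magnitude (the source does not confirm this), such a summary could be computed as follows. `gradient_norms` and the layer names are hypothetical.

```python
import numpy as np


def gradient_norms(gradients: dict[str, np.ndarray]) -> dict[str, float]:
    """Return the L2 (Frobenius) norm of each layer's gradient array."""
    return {name: float(np.linalg.norm(grad)) for name, grad in gradients.items()}


# Hypothetical gradients for two parameters of one layer.
grads = {
    "layer1.weight": np.array([[3.0, 4.0], [0.0, 0.0]]),
    "layer1.bias": np.zeros(2),
}

norms = gradient_norms(grads)
print(norms)
```

Recording these norms once per training step gives one curve per layer. Norms that collapse towards zero or grow without bound are the usual warning signs.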
