Documentation ¶
Overview ¶
Package diffusion contains an example diffusion model, trained on the Oxford Flowers 102 dataset.
See the accompanying Jupyter notebook for sample results and for how to call the model.
The subdirectory `train/` has the command line binary that can be executed for training.
It is based on the Keras tutorial at https://keras.io/examples/generative/ddim/, recreated for GoMLX with many small modifications.
Flags are defined in the files that use them, so they are spread across the code.
Index ¶
- Constants
- Variables
- func ActivationLayer(x *Node) *Node
- func AssertNoError(err error)
- func CompareModelPlots(modelNames ...string)
- func CreateInMemoryDatasets() (trainDS, validationDS *data.InMemoryDataset)
- func Denoise(ctx *context.Context, noisyImages, signalRatios, noiseRatios, flowerIds *Node) (predictedImages, predictedNoises *Node)
- func DenoiseStepGraph(ctx *context.Context, ...) (predictedImages, nextNoisyImages *Node)
- func DenormalizeImages(images *Node) *Node
- func DiffusionSchedule(times *Node, clipStart bool) (signalRatios, noiseRatios *Node)
- func DisplayImagesAcrossDiffusionSteps(numImages int, numDiffusionSteps int, displayEveryNSteps int)
- func DisplayTrainingPlots()
- func DownBlock(ctx *context.Context, x *Node, skips []*Node, numBlocks, outputChannels int) (*Node, []*Node)
- func DropdownFlowerTypes(cacheKey string, ctx *context.Context, numImages, numDiffusionSteps int, ...) *common.Latch
- func FlowerTypeEmbedding(ctx *context.Context, flowerIds, x *Node) *Node
- func GenerateFlowerIds(numImages int) tensor.Tensor
- func GenerateImagesOfAllFlowerTypes(numDiffusionSteps int) (predictedImages tensor.Tensor)
- func GenerateImagesOfFlowerType(numImages int, flowerType int32, numDiffusionSteps int) (predictedImages tensor.Tensor)
- func GenerateNoise(numImages int) tensor.Tensor
- func GetManager() *Manager
- func ImagesToHtml(images []image.Image) string
- func Init()
- func LoadCheckpointToContext(ctx *context.Context) (checkpoint *checkpoints.Handler, noise, flowerIds tensor.Tensor)
- func MustNoError[T any](value T, err error) T
- func NormalizationValues() (mean, stddev tensor.Tensor)
- func NormalizeLayer(ctx *context.Context, x *Node) *Node
- func PlotImages(images []image.Image)
- func PlotImagesTensor(imagesT tensor.Tensor)
- func PlotModelEvolution(imagesPerSample int, animate bool)
- func PreprocessImages(images *Node, normalize bool) *Node
- func ResidualBlock(ctx *context.Context, x *Node, outputChannels int) *Node
- func SinusoidalEmbedding(x *Node) *Node
- func SliderDiffusionSteps(cacheKey string, ctx *context.Context, numImages int, numDiffusionSteps int, ...) *common.Latch
- func TrainModel()
- func TrainingModelGraph(ctx *context.Context, _ any, inputs []*Node) []*Node
- func TrainingMonitor(checkpoint *checkpoints.Handler, loop *train.Loop, metrics []tensor.Tensor, ...) error
- func TransformerBlock(ctx *context.Context, x *Node) *Node
- func UNetModelGraph(ctx *context.Context, noisyImages, noiseVariances, flowerIds *Node) *Node
- func UpBlock(ctx *context.Context, x *Node, skips []*Node, numBlocks, outputChannels int) (*Node, []*Node)
- type ImagesGenerator
- type KidGenerator
Constants ¶
const (
NoiseSamplesFile = "noise_samples.tensor"
FlowerIdsSamplesFile = "flower_ids_samples.tensor"
GeneratedSamplesPrefix = "generated_samples_"
)
Variables ¶
var (
// DataDir is the directory where data is saved. This includes training data, saved models (checkpoints)
// and intermediary data (like normalization constants).
DataDir string // Directory where resources are stored
// ImageSize used everywhere. Images are square, the same size is used for height and width.
ImageSize int
// BatchSize used for training.
BatchSize int
// EvalBatchSize used for evaluation. The size is only affected by the limitation of the accelerator memory.
EvalBatchSize int
// PartitionSeed used for the dataset splitting into train/validation.
PartitionSeed = int64(42) // Some arbitrary number.
// ValidationFraction where the rest is used for training. There is no test set.
ValidationFraction = 0.2 // 20% of data.
)
var (
DType shapes.DType
)
var (
// NormalizationInfoFile where NormalizationValues results are saved (and loaded from).
NormalizationInfoFile = "normalization_data.bin"
)
Functions ¶
func AssertNoError ¶
func AssertNoError(err error)
AssertNoError will panic if the given error is not nil.
func CompareModelPlots ¶
func CompareModelPlots(modelNames ...string)
CompareModelPlots displays several model metrics on the same plots.
func CreateInMemoryDatasets ¶
func CreateInMemoryDatasets() (trainDS, validationDS *data.InMemoryDataset)
CreateInMemoryDatasets returns a train and a validation InMemoryDataset.
func Denoise ¶
func Denoise(ctx *context.Context, noisyImages, signalRatios, noiseRatios, flowerIds *Node) (predictedImages, predictedNoises *Node)
Denoise tries to separate the noise from the image. It is given the signal and noise ratios.
func DenoiseStepGraph ¶
func DenoiseStepGraph(ctx *context.Context, noisyImages, diffusionTime, nextDiffusionTime, flowerIds *Node) (predictedImages, nextNoisyImages *Node)
DenoiseStepGraph executes one step of separating the noise and images from noisy images.
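As an illustration only, the arithmetic of one such step can be sketched on a single scalar "pixel". This assumes the package follows the deterministic DDIM update of the Keras tutorial it is based on. The function `denoiseStep` and its parameters are hypothetical and not part of this package, and `predictedNoise` stands in for the output of the neural network.

```go
package main

import "fmt"

// denoiseStep is a hypothetical scalar sketch of one reverse-diffusion step.
// It assumes noisy = signalRatio*image + noiseRatio*noise, as in the Keras DDIM tutorial.
func denoiseStep(noisy, predictedNoise, signalRatio, noiseRatio, nextSignalRatio, nextNoiseRatio float64) (predictedImage, nextNoisy float64) {
	// Solve the mixing equation for the image, using the predicted noise.
	predictedImage = (noisy - noiseRatio*predictedNoise) / signalRatio
	// Re-mix the predicted components with the ratios of the next (less noisy) diffusion time.
	nextNoisy = nextSignalRatio*predictedImage + nextNoiseRatio*predictedNoise
	return predictedImage, nextNoisy
}

func main() {
	image, noise := 0.5, -1.0
	signalRatio, noiseRatio := 0.6, 0.8 // 0.6^2 + 0.8^2 = 1
	noisy := signalRatio*image + noiseRatio*noise

	// With a perfect noise prediction, the original image is recovered exactly.
	predictedImage, nextNoisy := denoiseStep(noisy, noise, signalRatio, noiseRatio, 0.8, 0.6)
	fmt.Printf("predicted image: %.3f, next noisy image: %.3f\n", predictedImage, nextNoisy)
}
```

In the real model the noise prediction is imperfect, which is why the step is repeated over several diffusion times.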
func DenormalizeImages ¶
func DenormalizeImages(images *Node) *Node
DenormalizeImages reverts images back to the 0 - 255 range. It keeps them as floats: it does not convert them back to bytes (== `shapes.S8` or `uint8`).
func DiffusionSchedule ¶
func DiffusionSchedule(times *Node, clipStart bool) (signalRatios, noiseRatios *Node)
DiffusionSchedule calculates the ratios of signal (image) and noise to be mixed, given the diffusion times in the range `[0.0, 1.0]`. Diffusion time 0 means minimum diffusion (the signal ratio is set to `-max_signal_ratio`, which defaults to 0.95), and diffusion time 1.0 means almost all noise (the signal ratio is set to `-min_signal_ratio`, which defaults to 0.02). The squares of the two returned ratios sum to 1.
Typically, the shape of `times` and of the returned ratios will be `[batch_size, 1, 1, 1]`.
If `clipStart` is set to false, the signal ratio is not clipped, and it can go all the way to 1.0.
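The schedule can be sketched in plain Go for a single scalar time. This is a minimal sketch under the assumption that the package uses the cosine schedule of the Keras tutorial it is based on (interpolating an angle, then taking its cosine and sine). The function `diffusionSchedule` and the constants below are hypothetical stand-ins for the graph function and its flags.

```go
package main

import (
	"fmt"
	"math"
)

// Hypothetical constants mirroring the documented flag defaults.
const (
	maxSignalRatio = 0.95
	minSignalRatio = 0.02
)

// diffusionSchedule is a scalar sketch: it interpolates an angle linearly in t
// and returns (cos, sin), so the squares of the two ratios always sum to 1.
func diffusionSchedule(t float64, clipStart bool) (signalRatio, noiseRatio float64) {
	startAngle := 0.0 // cos(0) = 1.0: the signal ratio is not clipped.
	if clipStart {
		startAngle = math.Acos(maxSignalRatio)
	}
	endAngle := math.Acos(minSignalRatio)
	angle := startAngle + t*(endAngle-startAngle)
	return math.Cos(angle), math.Sin(angle)
}

func main() {
	for _, t := range []float64{0, 0.5, 1} {
		s, n := diffusionSchedule(t, true)
		fmt.Printf("t=%.1f signal=%.4f noise=%.4f s^2+n^2=%.4f\n", t, s, n, s*s+n*n)
	}
}
```

Using an angle keeps the mixed image at unit variance when both the image and the noise have unit variance, which is one reason this parameterization is convenient.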
func DisplayImagesAcrossDiffusionSteps ¶ added in v0.7.0
func DisplayImagesAcrossDiffusionSteps(numImages int, numDiffusionSteps int, displayEveryNSteps int)
DisplayImagesAcrossDiffusionSteps generates and displays images using reverse diffusion. If `displayEveryNSteps` is not 0, it displays intermediary results every `displayEveryNSteps` steps. It also displays the initial noise and the final image.
Plotting results only works in a Jupyter notebook (with the GoNB kernel).
func DisplayTrainingPlots ¶
func DisplayTrainingPlots()
DisplayTrainingPlots displays the training plots of a model, without any training.
func DownBlock ¶
func DownBlock(ctx *context.Context, x *Node, skips []*Node, numBlocks, outputChannels int) (*Node, []*Node)
DownBlock applies `numBlocks` residual blocks followed by an average pooling of size 2, halving the spatial size. It pushes the value after each residual block onto the `skips` stack, to build the skip connections later.
It returns the transformed `x` and `skips` with newly stacked skip connections.
func DropdownFlowerTypes ¶ added in v0.7.0
func DropdownFlowerTypes(cacheKey string, ctx *context.Context, numImages, numDiffusionSteps int, htmlId string) *common.Latch
DropdownFlowerTypes creates a drop-down that shows images at different diffusion steps.
If `cacheKey` is empty, the cache is bypassed. Otherwise, it first tries to load the images from the cache, and if they are not available it saves the generated images in the cache for future use.
func FlowerTypeEmbedding ¶ added in v0.4.1
If configured (`--flower_type_dim` flag), FlowerTypeEmbedding concatenates a flower embedding to `x`. If `--flower_type_dim==0`, it returns `x` unchanged.
func GenerateFlowerIds ¶ added in v0.4.1
GenerateFlowerIds generates random flower ids. Each id identifies the type of flower, one of the 102.
func GenerateImagesOfAllFlowerTypes ¶ added in v0.4.1
GenerateImagesOfAllFlowerTypes takes one random noise sample and generates a flower image for each of the 102 types.
func GenerateImagesOfFlowerType ¶ added in v0.4.1
func GenerateImagesOfFlowerType(numImages int, flowerType int32, numDiffusionSteps int) (predictedImages tensor.Tensor)
GenerateImagesOfFlowerType is similar to DisplayImagesAcrossDiffusionSteps, but it limits itself to generating images of only one flower type.
func GenerateNoise ¶
GenerateNoise generates random noise that can be used to generate images.
func GetManager ¶
func GetManager() *Manager
func ImagesToHtml ¶ added in v0.7.0
ImagesToHtml converts a slice of images to HTML that lays the images out side by side, so that they can be easily displayed.
func Init ¶
func Init()
Init will parse and move flag values into global variables. It's idempotent, and can be called multiple times.
func LoadCheckpointToContext ¶
func LoadCheckpointToContext(ctx *context.Context) (checkpoint *checkpoints.Handler, noise, flowerIds tensor.Tensor)
LoadCheckpointToContext loads the model checkpoint into `ctx` and attaches the checkpoint handler to it, so that the context gets saved.
It also loads the noise (+flowerIds) samples for this model. The idea is that at each evaluation checkpoint we generate the images for these fixed noise samples, and one can observe the model quality evolving.
For new models (those whose directory did not previously exist), it does two things:
- It creates the noise + flowerIds samples used to monitor the model quality evolving.
- It creates the file `args.txt` with a copy of the arguments used to create the model. Later, if the same model is used, it checks that the arguments match (with some exceptions), and warns about mismatches.
func MustNoError ¶
MustNoError takes a value and an error. It returns the value if err is nil, otherwise it panics.
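The pattern behind these two helpers can be sketched as follows. This is an illustrative reimplementation based only on the documented signatures and behavior, not the package's actual source. The lowercase names are hypothetical.

```go
package main

import (
	"fmt"
	"strconv"
)

// assertNoError panics if err is not nil (sketch of the documented AssertNoError behavior).
func assertNoError(err error) {
	if err != nil {
		panic(err)
	}
}

// mustNoError returns value if err is nil, and panics otherwise
// (sketch of the documented MustNoError behavior).
func mustNoError[T any](value T, err error) T {
	assertNoError(err)
	return value
}

func main() {
	// A function returning (T, error) can be passed directly as the argument list.
	numFlowerTypes := mustNoError(strconv.Atoi("102"))
	fmt.Println(numFlowerTypes)
}
```

This style suits notebooks and example binaries, where aborting on the first error is acceptable. Library code would normally return the error instead.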
func NormalizationValues ¶
NormalizationValues returns the normalization values (mean and standard deviation) for the flowers dataset. Only the training data is used to calculate them.
func NormalizeLayer ¶
NormalizeLayer behaves according to the `--norm` flag. It works with `x` with rank 4 and rank 3.
func PlotImages ¶
PlotImages plots the images all in one row. The image size in the HTML is set to the value given.
This only works in a Jupyter (GoNB kernel) notebook.
func PlotImagesTensor ¶
PlotImagesTensor plots images in tensor format, all in one row. It assumes the images have a MaxValue of 255.
This only works in a Jupyter (GoNB kernel) notebook.
func PlotModelEvolution ¶
PlotModelEvolution plots the saved generated sample images of a model in the currently configured checkpoint.
It outputs at most `imagesPerSample` images per sampled checkpoint.
func PreprocessImages ¶
func PreprocessImages(images *Node, normalize bool) *Node
PreprocessImages converts the image to the model `DType` and optionally normalizes it according to `NormalizationValues()` calculated on the training dataset.
func ResidualBlock ¶
ResidualBlock applies a residual block to the input, with `outputChannels` channels (axis 3) in the output.
The parameter `x` must be of rank 4, shaped `[batchSize, height, width, channels]`.
func SinusoidalEmbedding ¶
func SinusoidalEmbedding(x *Node) *Node
SinusoidalEmbedding provides embeddings of `x` for different frequencies. This is applied to the variance of the noise, and makes it easier for the NN model to map different ranges of the signal/noise ratio.
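A scalar sketch of such an embedding is shown below. It assumes the scheme of the Keras tutorial this package is based on: frequencies spaced geometrically between a minimum and a maximum, with the sines and cosines concatenated. The function `sinusoidalEmbedding` and the parameter values (32 dimensions, frequencies from 1 to 1000) are assumptions for illustration, not values confirmed by this package.

```go
package main

import (
	"fmt"
	"math"
)

// sinusoidalEmbedding is a hypothetical scalar sketch: it embeds x into dims values,
// the first half sines and the second half cosines, at log-spaced frequencies.
func sinusoidalEmbedding(x float64, dims int, minFreq, maxFreq float64) []float64 {
	half := dims / 2
	embedding := make([]float64, 2*half)
	logMin, logMax := math.Log(minFreq), math.Log(maxFreq)
	for i := 0; i < half; i++ {
		// Interpolate linearly in log-space, so frequencies are spaced geometrically.
		fraction := 0.0
		if half > 1 {
			fraction = float64(i) / float64(half-1)
		}
		frequency := math.Exp(logMin + fraction*(logMax-logMin))
		angle := 2 * math.Pi * frequency * x
		embedding[i] = math.Sin(angle)
		embedding[half+i] = math.Cos(angle)
	}
	return embedding
}

func main() {
	embedding := sinusoidalEmbedding(0.3, 32, 1.0, 1000.0)
	fmt.Printf("len=%d first sine=%.4f first cosine=%.4f\n", len(embedding), embedding[0], embedding[16])
}
```

Low frequencies let the network tell coarse noise levels apart, while high frequencies make it sensitive to small changes in the input.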
func SliderDiffusionSteps ¶ added in v0.7.0
func SliderDiffusionSteps(cacheKey string, ctx *context.Context, numImages int, numDiffusionSteps int, htmlId string) *common.Latch
SliderDiffusionSteps creates and animates a slider that shows images at different diffusion steps. It handles the slider on a separate goroutine. Trigger the returned latch to stop it.
If `cacheKey` is empty, the cache is bypassed. Otherwise, it first tries to load the images from the cache, and if they are not available it saves the generated images in the cache for future use.
func TrainModel ¶
func TrainModel()
func TrainingModelGraph ¶
TrainingModelGraph builds the model for training and evaluation.
func TrainingMonitor ¶
func TrainingMonitor(checkpoint *checkpoints.Handler, loop *train.Loop, metrics []tensor.Tensor, plotter stdplots.Plotter, evalDatasets []train.Dataset, generator *ImagesGenerator, kid *KidGenerator) error
TrainingMonitor is periodically called during training, and is used to report metrics and generate sample images at the current training step.
func TransformerBlock ¶
TransformerBlock takes an embedding `x` shaped `[batchDim, spatialDim, embedDim]`, where the spatial dimension is the combined (flattened) spatial dimensions of the image.
func UNetModelGraph ¶
UNetModelGraph builds the U-Net model.
Parameters:
- noisyImages: image shaped `[batch_size, size, size, channels=3]`.
- noiseVariances: one value per example in the batch, shaped `[batch_size, 1, 1, 1]`.
- numChannelsList (static hyperparameter): number of channels to use in the model. For each value `numBlocks` are applied and then the image is pooled and reduced by a factor of 2 -- later to be up-sampled again. So at most `log2(size)` values.
- numBlocks (static hyperparameter): number of blocks to use per numChannelsList element.
func UpBlock ¶
func UpBlock(ctx *context.Context, x *Node, skips []*Node, numBlocks, outputChannels int) (*Node, []*Node)
UpBlock is the counterpart to DownBlock. It performs up-scaling convolutions and connects skip-connections popped from `skips`.
It returns `x` and `skips` after popping the consumed skip connections.
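The stack discipline shared by DownBlock and UpBlock can be sketched with plain slices. In this hypothetical sketch, strings stand in for the `*Node` values, and the helper names are not part of this package. It assumes last-in, first-out order, as the "stack" and "popped" wording suggests.

```go
package main

import "fmt"

// downBlock pushes one skip value per residual block (strings stand in for *Node).
func downBlock(level int, skips []string, numBlocks int) []string {
	for block := 0; block < numBlocks; block++ {
		skips = append(skips, fmt.Sprintf("down%d.block%d", level, block))
	}
	return skips
}

// upBlock pops numBlocks skip values, last pushed first, and returns
// the consumed values along with the remaining stack.
func upBlock(skips []string, numBlocks int) (consumed, remaining []string) {
	for block := 0; block < numBlocks; block++ {
		last := len(skips) - 1
		consumed = append(consumed, skips[last])
		skips = skips[:last]
	}
	return consumed, skips
}

func main() {
	var skips []string
	skips = downBlock(0, skips, 2)
	skips = downBlock(1, skips, 2)

	// The first UpBlock consumes the skips of the last (deepest) DownBlock.
	consumed, skips := upBlock(skips, 2)
	fmt.Println("consumed:", consumed, "remaining:", skips)
}
```

Because of this ordering, each up-sampling level is paired with the down-sampling level of the same spatial size, as long as both sides use the same `numBlocks`.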
Types ¶
type ImagesGenerator ¶
type ImagesGenerator struct {
// contains filtered or unexported fields
}
ImagesGenerator generates images given noise and the flowerIds. Create it with NewImagesGenerator.
func NewImagesGenerator ¶
func NewImagesGenerator(ctx *context.Context, noise, flowerIds tensor.Tensor, numDiffusionSteps int) *ImagesGenerator
NewImagesGenerator generates flowers given initial `noise` and `flowerIds`, in `numDiffusionSteps`. Typically, 20 diffusion steps will suffice.
func (*ImagesGenerator) Generate ¶
func (g *ImagesGenerator) Generate() (batchedImages tensor.Tensor)
Generate images from the original noise.
It can be called multiple times. If the context changed (for example, if the model was further trained), it generates new images. Otherwise, it always returns the same images.
func (*ImagesGenerator) GenerateEveryN ¶ added in v0.7.0
func (g *ImagesGenerator) GenerateEveryN(n int) (predictedImages []tensor.Tensor, diffusionSteps []int, diffusionTimes []float64)
GenerateEveryN generates images from the original noise. While iteratively undoing diffusion, it keeps every `n`-th intermediary image. It always returns the last image generated.
It can be called multiple times. If the context changed (for example, if the model was further trained), it generates new images. Otherwise, it always returns the same images.
It returns a slice of batches of images, one batch per intermediary diffusion step, a slice with the step used for each batch, and another slice with the "diffusionTime" of the intermediary images (it will be 1.0 for the last).
type KidGenerator ¶
type KidGenerator struct {
// contains filtered or unexported fields
}
KidGenerator generates the [Kernel Inception Distance (KID)](https://arxiv.org/abs/1801.01401) metric.
func NewKidGenerator ¶
func NewKidGenerator(ctx *context.Context, evalDS train.Dataset, numDiffusionStep int) *KidGenerator
NewKidGenerator creates a KidGenerator, which generates the KID metric. The ctx passed is the context for the diffusion model. It uses a different context for the InceptionV3 KID metric, so that its weights are not included in the generator model.
func (*KidGenerator) Eval ¶
func (kg *KidGenerator) Eval() (metric tensor.Tensor)
func (*KidGenerator) EvalStepGraph ¶
func (kg *KidGenerator) EvalStepGraph(ctx *context.Context, allImages []*Node) (metric *Node)