package diffusion

v0.9.1
Published: Apr 20, 2024 License: Apache-2.0 Imports: 45 Imported by: 0

README

Denoising Diffusion Implicit Models for the Oxford Flowers 102 Dataset

See description and details in the accompanying Jupyter Notebook file in the directory above.

This library can be used from a Jupyter Notebook or from the command line, with the binary located in the train subdirectory.

It includes the model code, the training code, generation functions, and plotting functions that work in a Jupyter Notebook.

Documentation

Overview

Package diffusion contains an example diffusion model, trained on the Oxford Flowers 102 dataset.

See the accompanying Jupyter notebook for some results and for how to call it.

The subdirectory `train/` contains the command-line binary that can be executed for training.

Based on the Keras tutorial at https://keras.io/examples/generative/ddim/, recreated for GoMLX with many small modifications.

Flags are defined in the files that use them, so they are spread across the code.
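
For example, a minimal generation sketch from a notebook cell could look as follows. Names are left unqualified for brevity, the creation of the `*context.Context` (done by the usual GoMLX setup) is assumed, and only functions documented on this page are used:

// Sketch: generate images from a trained checkpoint, using only functions
// documented below. Assumes `ctx` was already created for this model.
func generateSamples(ctx *context.Context) {
	Init() // Parses flags into the package's global variables; idempotent.

	// Attach the checkpoint and recover the fixed noise/flower-id samples.
	_, noise, flowerIds := LoadCheckpointToContext(ctx)

	// Generate images in 20 diffusion steps and plot them (GoNB notebooks only).
	generator := NewImagesGenerator(ctx, noise, flowerIds, 20)
	images := generator.Generate()
	PlotImagesTensor(images)
}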

Constants

const (
	NoiseSamplesFile       = "noise_samples.tensor"
	FlowerIdsSamplesFile   = "flower_ids_samples.tensor"
	GeneratedSamplesPrefix = "generated_samples_"
)

Variables

var (
	// DataDir is the directory where data is saved. This includes training data, saved models (checkpoints)
	// and intermediary data (like normalization constants).
	DataDir string // Directory where resources are stored

	// ImageSize used everywhere. Images are square, the same size is used for height and width.
	ImageSize int

	// BatchSize used for training.
	BatchSize int

	// EvalBatchSize used for evaluation. Its size is limited only by the accelerator's memory.
	EvalBatchSize int

	// PartitionSeed used for splitting the dataset into train/validation.
	PartitionSeed = int64(42) // Some arbitrary number.

	// ValidationFraction of the data used for validation; the rest is used for training. There is no test set.
	ValidationFraction = 0.2 // 20% of data.
)
var (
	DType shapes.DType
)
var (

	// NormalizationInfoFile where NormalizationValues results are saved (and loaded from).
	NormalizationInfoFile = "normalization_data.bin"
)

Functions

func ActivationLayer

func ActivationLayer(x *Node) *Node

ActivationLayer applies the configured activation to `x`.

func AssertNoError

func AssertNoError(err error)

AssertNoError will panic if the given error is not nil.

func CompareModelPlots

func CompareModelPlots(modelNames ...string)

CompareModelPlots displays several model metrics on the same plots.

func CreateInMemoryDatasets

func CreateInMemoryDatasets() (trainDS, validationDS *data.InMemoryDataset)

CreateInMemoryDatasets returns a train and a validation InMemoryDataset.

func Denoise

func Denoise(ctx *context.Context, noisyImages, signalRatios, noiseRatios, flowerIds *Node) (
	predictedImages, predictedNoises *Node)

Denoise tries to separate the noise from the image. It is given the signal and noise ratios.

func DenoiseStepGraph

func DenoiseStepGraph(ctx *context.Context, noisyImages, diffusionTime, nextDiffusionTime, flowerIds *Node) (
	predictedImages, nextNoisyImages *Node)

DenoiseStepGraph executes one step of separating the noise and images from noisy images.

func DenormalizeImages

func DenormalizeImages(images *Node) *Node

DenormalizeImages reverts images back to the 0 - 255 range. It keeps them as floats; it does not convert them back to bytes (`shapes.S8` or `uint8`).

func DiffusionSchedule

func DiffusionSchedule(times *Node, clipStart bool) (signalRatios, noiseRatios *Node)

DiffusionSchedule calculates the ratios of signal (image) and noise to be mixed, given the diffusion time in `[0.0, 1.0]`. Diffusion time 0 means minimum diffusion -- the signal ratio is set to its maximum (`max_signal_ratio`, default 0.95) -- and diffusion time 1.0 means almost all noise -- the signal ratio is set to its minimum (`min_signal_ratio`, default 0.02). The returned ratios are such that the sum of their squares is 1.

Typically, the shape of `times` and the returned ratios will be `[batch_size, 1, 1, 1]`.

If `clipStart` is set to false, the signal ratio is not clipped, and it can go all the way to 1.0.
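
As an illustration of that constraint only (not the package's actual graph code), a scalar cosine-style schedule with the same property could look like this; the formula is an assumption, only the default ratios come from the description above, and it uses the standard `math` package:

// Illustrative scalar schedule: signalRatio^2 + noiseRatio^2 == 1 for any t in [0, 1].
// The real DiffusionSchedule operates on graph Nodes and may use a different formula.
func scheduleScalar(t float64) (signalRatio, noiseRatio float64) {
	const maxSignalRatio, minSignalRatio = 0.95, 0.02
	startAngle := math.Acos(maxSignalRatio) // t == 0.0: mostly signal.
	endAngle := math.Acos(minSignalRatio)   // t == 1.0: mostly noise.
	angle := startAngle + t*(endAngle-startAngle)
	return math.Cos(angle), math.Sin(angle)
}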

func DisplayImagesAcrossDiffusionSteps added in v0.7.0

func DisplayImagesAcrossDiffusionSteps(numImages int, numDiffusionSteps int, displayEveryNSteps int)

DisplayImagesAcrossDiffusionSteps generates and displays images using reverse diffusion. If displayEveryNSteps is not 0, it displays intermediary results every n steps -- it also displays the initial noise and the final image.

Plotting results only works in a Jupyter notebook (with the GoNB kernel).

func DisplayTrainingPlots

func DisplayTrainingPlots()

DisplayTrainingPlots simply displays the training plots of a model, without any training.

func DownBlock

func DownBlock(ctx *context.Context, x *Node, skips []*Node, numBlocks, outputChannels int) (*Node, []*Node)

DownBlock applies `numBlocks` residual blocks followed by an average pooling of size 2, halving the spatial size. It pushes the intermediate values after each residual block onto the `skips` stack, to build the skip connections later.

It returns the transformed `x` and `skips` with newly stacked skip connections.

func DropdownFlowerTypes

func DropdownFlowerTypes(cacheKey string, ctx *context.Context, numImages, numDiffusionSteps int, htmlId string) *common.Latch

DropdownFlowerTypes creates a drop-down that shows images at different diffusion steps.

If `cacheKey` is empty, the cache is bypassed. Otherwise, it first tries to load the images from the cache if available, and saves newly generated images to the cache for future use.

func FlowerTypeEmbedding added in v0.4.1

func FlowerTypeEmbedding(ctx *context.Context, flowerIds, x *Node) *Node

FlowerTypeEmbedding will, if configured (via the `--flower_type_dim` flag), concatenate a flower type embedding to `x`. If `--flower_type_dim==0`, it returns `x` unchanged.

func GenerateFlowerIds added in v0.4.1

func GenerateFlowerIds(numImages int) tensor.Tensor

GenerateFlowerIds generates random flower ids: these are the flower types, one of the 102.

func GenerateImagesOfAllFlowerTypes added in v0.4.1

func GenerateImagesOfAllFlowerTypes(numDiffusionSteps int) (predictedImages tensor.Tensor)

GenerateImagesOfAllFlowerTypes takes one random noise sample and generates a flower for each of the 102 types.
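
A minimal usage sketch (in a GoNB notebook, names unqualified for brevity):

// Generate one image per flower type, from a single noise sample, and plot them.
images := GenerateImagesOfAllFlowerTypes(20) // 20 diffusion steps.
PlotImagesTensor(images)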

func GenerateImagesOfFlowerType added in v0.4.1

func GenerateImagesOfFlowerType(numImages int, flowerType int32, numDiffusionSteps int) (predictedImages tensor.Tensor)

GenerateImagesOfFlowerType is similar to DisplayImagesAcrossDiffusionSteps, but it limits itself to generating images of only one flower type.

func GenerateNoise

func GenerateNoise(numImages int) tensor.Tensor

GenerateNoise generates random noise that can be used to generate images.

func GetManager

func GetManager() *Manager

func ImagesToHtml added in v0.7.0

func ImagesToHtml(images []image.Image) string

ImagesToHtml converts a slice of images to HTML that displays them side by side, so they can be easily rendered.

func Init

func Init()

Init will parse and move flag values into global variables. It's idempotent, and can be called multiple times.

func LoadCheckpointToContext

func LoadCheckpointToContext(ctx *context.Context) (checkpoint *checkpoints.Handler, noise, flowerIds tensor.Tensor)

LoadCheckpointToContext loads the model checkpoint into the given context and attaches to it, so that the context gets saved.

It also loads the noise (+flowerIds) samples for this model. The idea is that at each evaluation checkpoint we generate the images for these fixed noise samples, and one can observe the model quality evolving.

For new models -- whose directory didn't previously exist -- it does two things:

  • It creates the noise + flowerIds samples used to monitor the model quality evolving.
  • It creates the file `args.txt` with a copy of the arguments used to create the model. Later, if the same model is used, it checks that the arguments match (with some exceptions), and warns about mismatches.

func MustNoError

func MustNoError[T any](value T, err error) T

MustNoError takes a value and an error. It returns the value if err is nil; otherwise it panics.
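
A hypothetical usage example (the file name simply reuses the constant defined above):

// Panics if os.Open fails, otherwise returns the *os.File.
f := MustNoError(os.Open(FlowerIdsSamplesFile))
defer f.Close()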

func NormalizationValues

func NormalizationValues() (mean, stddev tensor.Tensor)

NormalizationValues for the flowers dataset -- computed only on the training data.

func NormalizeLayer

func NormalizeLayer(ctx *context.Context, x *Node) *Node

NormalizeLayer behaves according to the `--norm` flag. It works with `x` of rank 4 or rank 3.

func PlotImages

func PlotImages(images []image.Image)

PlotImages plots all images in one row. The image size in the HTML is set to the value given.

This only works in a Jupyter (GoNB kernel) notebook.

func PlotImagesTensor

func PlotImagesTensor(imagesT tensor.Tensor)

PlotImagesTensor plots images in tensor format, all in one row. It assumes the images' MaxValue is 255.

This only works in a Jupyter (GoNB kernel) notebook.

func PlotModelEvolution

func PlotModelEvolution(imagesPerSample int, animate bool)

PlotModelEvolution plots the generated sample images saved for a model in the currently configured checkpoint.

It outputs at most imagesPerSample images per sampled checkpoint.

func PreprocessImages

func PreprocessImages(images *Node, normalize bool) *Node

PreprocessImages converts the images to the model `DType` and optionally normalizes them according to `NormalizationValues()`, calculated on the training dataset.

func ResidualBlock

func ResidualBlock(ctx *context.Context, x *Node, outputChannels int) *Node

ResidualBlock applies a residual block to the input, producing `outputChannels` channels (axis 3) in the output.

The parameter `x` must be of rank 4, shaped `[batchSize, height, width, channels]`.

func SinusoidalEmbedding

func SinusoidalEmbedding(x *Node) *Node

SinusoidalEmbedding provides embeddings of `x` for different frequencies. This is applied to the variance of the noise, and makes it easier for the NN model to map different ranges of the signal/noise ratio.
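
The idea, sketched on scalars: the actual implementation operates on graph Nodes, and the frequency range below is an assumption for illustration only (standard `math` package):

// Illustrative only: embed x as sines and cosines at numFrequencies (>= 2)
// geometrically spaced frequencies. The constants are made up for the sketch.
func sinusoidalEmbeddingScalar(x float64, numFrequencies int) []float64 {
	const minFreq, maxFreq = 1.0, 1000.0
	embedding := make([]float64, 0, 2*numFrequencies)
	for i := 0; i < numFrequencies; i++ {
		freq := minFreq * math.Pow(maxFreq/minFreq, float64(i)/float64(numFrequencies-1))
		embedding = append(embedding, math.Sin(2*math.Pi*freq*x), math.Cos(2*math.Pi*freq*x))
	}
	return embedding
}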

func SliderDiffusionSteps added in v0.7.0

func SliderDiffusionSteps(cacheKey string, ctx *context.Context, numImages int, numDiffusionSteps int, htmlId string) *common.Latch

SliderDiffusionSteps creates and animates a slider that shows images at different diffusion steps. It handles the slider on a separate goroutine. Trigger the returned latch to stop it.

If `cacheKey` is empty, the cache is bypassed. Otherwise, it first tries to load the images from the cache if available, and saves newly generated images to the cache for future use.

func TrainModel

func TrainModel()

func TrainingModelGraph

func TrainingModelGraph(ctx *context.Context, _ any, inputs []*Node) []*Node

TrainingModelGraph builds the model for training and evaluation.

func TrainingMonitor

func TrainingMonitor(checkpoint *checkpoints.Handler, loop *train.Loop, metrics []tensor.Tensor,
	plotter stdplots.Plotter, evalDatasets []train.Dataset, generator *ImagesGenerator, kid *KidGenerator) error

TrainingMonitor is periodically called during training, and is used to report metrics and generate sample images at the current training step.

func TransformerBlock

func TransformerBlock(ctx *context.Context, x *Node) *Node

TransformerBlock takes an embedding shaped `[batchDim, spatialDim, embedDim]`, where the spatial dimension is the combined spatial dimensions of the image.

func UNetModelGraph

func UNetModelGraph(ctx *context.Context, noisyImages, noiseVariances, flowerIds *Node) *Node

UNetModelGraph builds the U-Net model.

Parameters:

  • noisyImages: image shaped `[batch_size, size, size, channels=3]`.
  • noiseVariances: one value per example in the batch, shaped `[batch_size, 1, 1, 1]`.
  • numChannelsList (static hyperparameter): number of channels to use in the model. For each value, `numBlocks` residual blocks are applied and then the image is pooled, reducing its spatial size by a factor of 2 -- later to be up-sampled again. So there can be at most `log2(size)` values.
  • numBlocks (static hyperparameter): number of blocks to use per numChannelsList element.

func UpBlock

func UpBlock(ctx *context.Context, x *Node, skips []*Node, numBlocks, outputChannels int) (*Node, []*Node)

UpBlock is the counterpart to DownBlock. It performs up-scaling convolutions and connects the skip connections popped from `skips`.

It returns `x` and `skips` after popping the consumed skip connections.
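
Together, DownBlock and UpBlock pair up through the `skips` stack. A simplified, hypothetical sketch of that pairing -- the real UNetModelGraph also handles embeddings, normalization and context scoping, and `numChannelsList`/`numBlocks` stand in for the static hyperparameters described above:

// Simplified sketch of the down/up path pairing via the skips stack.
func uNetBodySketch(ctx *context.Context, x *Node, numChannelsList []int, numBlocks int) *Node {
	var skips []*Node
	for _, channels := range numChannelsList { // Down-sampling path.
		x, skips = DownBlock(ctx, x, skips, numBlocks, channels)
	}
	for i := len(numChannelsList) - 1; i >= 0; i-- { // Up-sampling path.
		x, skips = UpBlock(ctx, x, skips, numBlocks, numChannelsList[i])
	}
	return x
}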

Types

type ImagesGenerator

type ImagesGenerator struct {
	// contains filtered or unexported fields
}

ImagesGenerator generates images given the noise and the flowerIds. Create it with NewImagesGenerator.

func NewImagesGenerator

func NewImagesGenerator(ctx *context.Context, noise, flowerIds tensor.Tensor, numDiffusionSteps int) *ImagesGenerator

NewImagesGenerator generates flowers given initial `noise` and `flowerIds`, in `numDiffusionSteps`. Typically, 20 diffusion steps will suffice.

func (*ImagesGenerator) Generate

func (g *ImagesGenerator) Generate() (batchedImages tensor.Tensor)

Generate images from the original noise.

It can be called multiple times if the context changed -- for instance, if the model was further trained. Otherwise, it will always return the same images.

func (*ImagesGenerator) GenerateEveryN added in v0.7.0

func (g *ImagesGenerator) GenerateEveryN(n int) (predictedImages []tensor.Tensor,
	diffusionSteps []int, diffusionTimes []float64)

GenerateEveryN generates images from the original noise. While iteratively undoing diffusion, it keeps every `n`-th intermediary batch of images. It always returns the last image generated.

It can be called multiple times if the context changed -- for instance, if the model was further trained. Otherwise, it will always return the same images.

It returns a slice of batches of images, one batch per kept intermediary diffusion step, a slice with the step used for each batch, and another slice with the "diffusionTime" of the intermediary images (it will be 1.0 for the last).
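
For example (assuming `generator` was built with NewImagesGenerator, in a GoNB notebook, using the standard `fmt` package):

// Plot every 5th intermediary batch of the reverse diffusion.
batches, steps, times := generator.GenerateEveryN(5)
for i, batch := range batches {
	fmt.Printf("diffusion step %d (time %.2f):\n", steps[i], times[i])
	PlotImagesTensor(batch)
}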

type KidGenerator

type KidGenerator struct {
	// contains filtered or unexported fields
}

KidGenerator generates the [Kernel Inception Distance (KID)](https://arxiv.org/abs/1801.01401) metric.

func NewKidGenerator

func NewKidGenerator(ctx *context.Context, evalDS train.Dataset, numDiffusionStep int) *KidGenerator

NewKidGenerator creates a generator for the KID metric. The ctx passed is the context for the diffusion model. It uses a different context for the InceptionV3 KID metric, so that its weights are not included in the generator model.
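
A hypothetical usage sketch, reusing the validation split from CreateInMemoryDatasets and assuming `ctx` holds the trained diffusion model (standard `fmt` package for printing):

// Evaluate the KID metric against the validation dataset, with 10 diffusion steps.
_, validationDS := CreateInMemoryDatasets()
kidGenerator := NewKidGenerator(ctx, validationDS, 10)
metric := kidGenerator.Eval()
fmt.Println("KID:", metric)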

func (*KidGenerator) Eval

func (kg *KidGenerator) Eval() (metric tensor.Tensor)

func (*KidGenerator) EvalStepGraph

func (kg *KidGenerator) EvalStepGraph(ctx *context.Context, allImages []*Node) (metric *Node)

Directories

Path	Synopsis
train	Command-line binary to train the model.
