model

Version: v0.0.0-...-9f00eee
Warning

This package is not in the latest version of its module.

Published: Apr 30, 2021 License: Apache-2.0 Imports: 11 Imported by: 0

Documentation

Constants

const (
	// Constants to save and retrieve the gradients
	WeightSuffix = ".weight"
	BiasSuffix   = ".bias"
)
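A minimal sketch of how these suffixes might be used. It assumes that the storage key for one tensor is formed by appending the suffix to the layer name; the real key format (for example, whether the job ID is included) is not documented here, and the helpers `layerKey` and `splitKey` are hypothetical, not part of this package.

```go
package main

import (
	"fmt"
	"strings"
)

// Local copies of the package constants so the sketch is self-contained.
const (
	WeightSuffix = ".weight"
	BiasSuffix   = ".bias"
)

// layerKey builds a hypothetical storage key for one tensor of a layer.
func layerKey(layer, suffix string) string {
	return layer + suffix
}

// splitKey recovers the layer name and suffix from a key built by layerKey.
// ok is false when the key ends in neither known suffix.
func splitKey(key string) (layer, suffix string, ok bool) {
	for _, s := range []string{WeightSuffix, BiasSuffix} {
		if strings.HasSuffix(key, s) {
			return strings.TrimSuffix(key, s), s, true
		}
	}
	return "", "", false
}

func main() {
	k := layerKey("fc1", WeightSuffix)
	layer, suffix, ok := splitKey(k)
	fmt.Println(k, layer, suffix, ok) // fc1.weight fc1 .weight true
}
```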

Variables

This section is empty.

Functions

This section is empty.

Types

type Layer

type Layer struct {
	Name    string
	Dtype   string
	Weights *tensor.Dense
}

Layer holds the Weights of a single layer of the neural network. Weights may contain either the layer's weights or its bias; both are stored in the same way.

type Model

type Model struct {
	Name string

	// StateDict holds the layer names
	// and the layers of the model. Each
	// layer has a bias and a weight
	StateDict map[string]*Layer
	// contains filtered or unexported fields
}

Model holds the layers of a model, indexed by name in StateDict.

func NewModel

func NewModel(
	logger *zap.Logger,
	jobId string,
	task api.TrainRequest,
	layerNames []string,
	pool *redis.Pool) *Model

NewModel creates a new model with the specified layer names.
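A minimal sketch of how a constructor like this could populate StateDict from the layer names. It assumes one weight entry and one bias entry per layer name, keyed by name plus suffix, which is an interpretation of the StateDict comment rather than the package's confirmed behavior. The logger, job ID, train request and Redis pool are omitted, `[]float32` stands in for `*tensor.Dense`, and `newModel` is a hypothetical name.

```go
package main

import "fmt"

const (
	WeightSuffix = ".weight"
	BiasSuffix   = ".bias"
)

// Layer is a simplified stand-in: []float32 replaces *tensor.Dense.
type Layer struct {
	Name    string
	Dtype   string
	Weights []float32
}

// Model keeps only the fields this sketch needs.
type Model struct {
	Name      string
	StateDict map[string]*Layer
}

// newModel creates one weight entry and one bias entry per layer name.
// Tensors stay nil until they are fetched from the database (e.g. by Build).
func newModel(name string, layerNames []string) *Model {
	m := &Model{Name: name, StateDict: make(map[string]*Layer, 2*len(layerNames))}
	for _, ln := range layerNames {
		for _, suffix := range []string{WeightSuffix, BiasSuffix} {
			key := ln + suffix
			m.StateDict[key] = &Layer{Name: key}
		}
	}
	return m
}

func main() {
	m := newModel("mnist", []string{"fc1", "fc2"})
	fmt.Println(len(m.StateDict)) // 4
}
```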

func (*Model) Build

func (m *Model) Build() error

Build fetches all the initialized layers from the database. Build should be called once, just after a worker has initialized the network.
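The fetch step can be sketched as follows. This is not the package's implementation: the Redis pool is replaced by a hypothetical in-memory `Store` interface, the state dict is a plain `map[string][]float32`, and each entry is assumed to be stored under its own StateDict key. `build`, `Store` and `memStore` are illustrative names.

```go
package main

import "fmt"

// Store is a hypothetical stand-in for the Redis-backed database.
type Store interface {
	Get(key string) ([]float32, bool)
}

// memStore is an in-memory Store used only for illustration.
type memStore map[string][]float32

func (s memStore) Get(key string) ([]float32, bool) {
	v, ok := s[key]
	return v, ok
}

// build fills every entry of state with the tensor stored under the same key.
// It fails if any layer has not been initialized in the store.
func build(state map[string][]float32, db Store) error {
	for key := range state {
		v, ok := db.Get(key)
		if !ok {
			return fmt.Errorf("layer %q not found in store", key)
		}
		// Copy so later in-place updates do not modify the stored tensor.
		state[key] = append([]float32(nil), v...)
	}
	return nil
}

func main() {
	db := memStore{"fc1.weight": {0.1, 0.2}, "fc1.bias": {0.0}}
	state := map[string][]float32{"fc1.weight": nil, "fc1.bias": nil}
	if err := build(state, db); err != nil {
		fmt.Println("build failed:", err)
		return
	}
	fmt.Println(state["fc1.weight"]) // [0.1 0.2]
}
```

Returning an error on the first missing key mirrors the `error` result of Build; whether the real method fails fast or collects all missing layers is not stated.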

func (*Model) Clear

func (m *Model) Clear()

Clear wipes the model's StateDict.

func (*Model) Save

func (m *Model) Save() error

Save stores the updated weights and biases in the database so that subsequent functions can retrieve them.
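A minimal sketch of the save step, using a hypothetical in-memory map in place of the Redis database and a plain `map[string][]float32` in place of StateDict. Unlike Save, the sketch returns no error because writing to a map cannot fail; `save` and `memStore` are illustrative names.

```go
package main

import "fmt"

// memStore is a hypothetical in-memory stand-in for the Redis database.
type memStore map[string][]float32

// save writes a copy of every state-dict entry to the store under its key,
// so that functions running later can read the updated values.
func save(state map[string][]float32, db memStore) {
	for key, v := range state {
		db[key] = append([]float32(nil), v...)
	}
}

func main() {
	db := memStore{}
	state := map[string][]float32{"fc1.weight": {0.5, -0.5}, "fc1.bias": {0.1}}
	save(state, db)
	fmt.Println(len(db), db["fc1.bias"]) // 2 [0.1]
}
```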

func (*Model) Summary

func (m *Model) Summary()

Summary iterates over the layers of the model and prints information about each one.
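The exact output of Summary is not documented. As an illustration only, the sketch below builds one line per state-dict entry with its element count, sorting the names first because Go map iteration order is randomized. `summary` and the line format are hypothetical.

```go
package main

import (
	"fmt"
	"sort"
)

// summary returns one line per entry, sorted by name, with its element count.
func summary(state map[string][]float32) []string {
	names := make([]string, 0, len(state))
	for name := range state {
		names = append(names, name)
	}
	sort.Strings(names) // map order is random; sort for stable output
	lines := make([]string, 0, len(names))
	for _, name := range names {
		lines = append(lines, fmt.Sprintf("%-12s %d params", name, len(state[name])))
	}
	return lines
}

func main() {
	state := map[string][]float32{
		"fc1.weight": make([]float32, 6),
		"fc1.bias":   make([]float32, 3),
	}
	for _, line := range summary(state) {
		fmt.Println(line)
	}
}
```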

func (*Model) Update

func (m *Model) Update(funcId int)

Update fetches the layers saved by the function identified by funcId and adds them to the StateDict.
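A minimal sketch of the accumulation step. It assumes that "adds them to the statedict" means an element-wise sum into the matching entry, so that Average can later divide by the number of functions; this is an inference from the Average description, not confirmed behavior. `accumulate` is a hypothetical helper working on plain `float32` slices instead of `*tensor.Dense`.

```go
package main

import "fmt"

// accumulate adds each tensor in fetched element-wise into the matching
// entry of state. Entries missing from state are initialized with a copy.
func accumulate(state, fetched map[string][]float32) error {
	for key, v := range fetched {
		cur, ok := state[key]
		if !ok {
			state[key] = append([]float32(nil), v...)
			continue
		}
		if len(cur) != len(v) {
			return fmt.Errorf("shape mismatch for %q: %d vs %d", key, len(cur), len(v))
		}
		for i := range v {
			cur[i] += v[i]
		}
	}
	return nil
}

func main() {
	state := map[string][]float32{}
	for _, f := range []map[string][]float32{
		{"fc1.bias": {1, 2}}, // layers saved by function 0
		{"fc1.bias": {3, 4}}, // layers saved by function 1
	} {
		if err := accumulate(state, f); err != nil {
			fmt.Println(err)
			return
		}
	}
	fmt.Println(state["fc1.bias"]) // [4 6]
}
```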

type ParallelSGD

type ParallelSGD struct {
	// contains filtered or unexported fields
}

ParallelSGD averages the weights of models that were trained independently. Because only the resulting weights are combined, each function is free to use any optimizer.

It simply fetches all the model weights and averages them.
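Written out, the averaging described above is plain model averaging. With $N$ functions that each trained their own copy of the model and produced a tensor $\theta_i$ for a given layer entry, the aggregated tensor is the element-wise mean (notation ours):

```latex
\theta = \frac{1}{N} \sum_{i=1}^{N} \theta_i
```

Computing the sum first and dividing once by $N$ at the end gives the same result as averaging incrementally.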

func MakeParallelSGD

func MakeParallelSGD(logger *zap.Logger) ParallelSGD

func (ParallelSGD) Average

func (psgd ParallelSGD) Average(m *Model, num int) error

Average averages the layers of m over num, the number of functions that have finished.
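A minimal sketch of the division step, assuming the state dict already holds the element-wise sum of the layers from `num` functions (an assumption, since the accumulation behavior is not documented). `average` is a hypothetical function on plain `float32` slices, not the package's method.

```go
package main

import (
	"errors"
	"fmt"
)

// average divides every accumulated entry of state by num, the number of
// functions whose layers were summed into it. Slices are updated in place.
func average(state map[string][]float32, num int) error {
	if num <= 0 {
		return errors.New("average: num must be positive")
	}
	n := float32(num)
	for _, v := range state {
		for i := range v {
			v[i] /= n
		}
	}
	return nil
}

func main() {
	// Element-wise sum of the bias tensors from 2 functions.
	state := map[string][]float32{"fc1.bias": {4, 6}}
	if err := average(state, 2); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(state["fc1.bias"]) // [2 3]
}
```

Guarding against `num <= 0` avoids silently producing `+Inf` or `NaN` values when no function has finished.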
