tftrain

package
v0.0.0-...-e9349c8
Warning

This package is not in the latest version of its module.

Published: Oct 18, 2018 License: Apache-2.0 Imports: 10 Imported by: 0

Documentation

Overview

Package tftrain runs TensorFlow training graphs.

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func BatchTrain

func BatchTrain(
	graph *tf.Graph,
	numBatches int,
	inputName, targetName, trainOpName, initOpName, lossOpName string,
	outputOpNames []string,
	lossHandlerFunction func(int, float32),
	getTrainingData func(int) (*tf.Tensor, *tf.Tensor, error),
) (frozen *tf.Graph, err error)

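BatchTrain trains the model in graph over numBatches mini-batches. It obtains each batch's input and target tensors from getTrainingData, passes each batch's loss to lossHandlerFunction, and returns the trained graph in frozen form.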
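
BatchTrain is driven by two callbacks. The sketch below shows one plausible way to write them, assuming that the int argument passed to both is the zero-based batch index (the signature does not say so explicitly). To stay runnable without the TensorFlow C library it returns plain Go slices; in real use each batch would be converted with tf.NewTensor before being returned as *tf.Tensor. The names batchProvider and lossRecorder are hypothetical.

```go
package main

import (
	"errors"
	"fmt"
)

// batchProvider returns a function with the shape of BatchTrain's
// getTrainingData callback, except that it yields plain slices instead of
// *tf.Tensor values. Assumption: the int argument is the batch index.
func batchProvider(inputs [][]float32, targets []float32, batchSize int) func(int) ([][]float32, []float32, error) {
	return func(batch int) ([][]float32, []float32, error) {
		start := batch * batchSize
		if batch < 0 || start >= len(inputs) {
			return nil, nil, errors.New("batch index out of range")
		}
		end := start + batchSize
		if end > len(inputs) {
			end = len(inputs) // the last batch may be smaller
		}
		return inputs[start:end], targets[start:end], nil
	}
}

// lossRecorder returns a callback with the shape of lossHandlerFunction that
// appends each reported loss to *history and prints progress.
func lossRecorder(history *[]float32) func(int, float32) {
	return func(batch int, loss float32) {
		*history = append(*history, loss)
		fmt.Printf("batch %d: loss %.4f\n", batch, loss)
	}
}

func main() {
	inputs := [][]float32{{0, 0}, {0, 1}, {1, 0}, {1, 1}, {2, 2}}
	targets := []float32{0, 1, 1, 0, 0}
	next := batchProvider(inputs, targets, 2)

	var history []float32
	record := lossRecorder(&history)

	for b := 0; ; b++ {
		x, _, err := next(b)
		if err != nil {
			break // no more batches
		}
		// Stand-in for the loss BatchTrain would report after running the train op.
		record(b, float32(len(x))*0.1)
	}
	fmt.Println("batches seen:", len(history)) // 3 batches: sizes 2, 2, 1
}
```

Returning an error for an out-of-range index gives the caller a clear signal if numBatches is larger than the data allows.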

func Freeze

func Freeze(tfGraph *tf.Graph, sess *tf.Session, headNames []string) (frozenTfGraph *tf.Graph, err error)

Freeze replaces all variables with constants holding their values in sess, and strips off nodes that are not required to compute headNames.
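
The two steps can be illustrated on a toy graph. The sketch below uses a hand-rolled node type, not the TensorFlow graph format, and is not this package's implementation: every Variable node becomes a Const carrying the variable's value, and only nodes reachable from the requested heads are kept. The node type and freeze function are hypothetical.

```go
package main

import (
	"fmt"
	"sort"
)

// node is a toy stand-in for a graph operation (hypothetical).
type node struct {
	Op     string   // e.g. "Const", "Variable", "Mul"
	Inputs []string // names of input nodes
	Value  float64  // constant value (Const) or current value (Variable)
}

// freeze returns a new graph in which every Variable is replaced by a Const
// with the same value, keeping only the nodes needed to compute heads.
func freeze(g map[string]node, heads []string) (map[string]node, error) {
	out := make(map[string]node)
	var visit func(name string) error
	visit = func(name string) error {
		if _, done := out[name]; done {
			return nil // already kept
		}
		n, ok := g[name]
		if !ok {
			return fmt.Errorf("unknown node %q", name)
		}
		if n.Op == "Variable" {
			n = node{Op: "Const", Value: n.Value} // drop any initialiser inputs
		}
		out[name] = n
		for _, in := range n.Inputs {
			if err := visit(in); err != nil {
				return err
			}
		}
		return nil
	}
	for _, h := range heads {
		if err := visit(h); err != nil {
			return nil, err
		}
	}
	return out, nil
}

func main() {
	g := map[string]node{
		"x":     {Op: "Placeholder"},
		"w":     {Op: "Variable", Value: 0.5},
		"y":     {Op: "Mul", Inputs: []string{"x", "w"}},
		"loss":  {Op: "Square", Inputs: []string{"y"}},
		"train": {Op: "ApplyGradient", Inputs: []string{"w", "loss"}},
	}
	frozen, err := freeze(g, []string{"y"})
	if err != nil {
		panic(err)
	}
	names := make([]string, 0, len(frozen))
	for name := range frozen {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		fmt.Println(name, frozen[name].Op) // prints "w Const", "x Placeholder", "y Mul"
	}
}
```

The loss and training nodes disappear because the head y does not depend on them.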

func TrainableFreeze

func TrainableFreeze(tfGraph *tf.Graph, sess *tf.Session, headNames []string) (frozenTfGraph *tf.Graph, err error)

TrainableFreeze replaces the constant assigned to each variable with that variable's trained value, and strips off nodes that are not required. The resulting graph still requires initialisation before use and is larger than a graph produced by Freeze, but it can be used just like a training graph generated in Python.
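
The difference from Freeze can be shown on a toy graph where each variable is initialised by an Assign node that reads a Const. In the sketch below (a hand-rolled model, not the TensorFlow graph format and not this package's implementation) variables and their Assign nodes are kept, and only the Const feeding each Assign is overwritten with the trained value; the pruning step is omitted. The node type, the "name/initial_value" naming and trainableFreeze are hypothetical.

```go
package main

import "fmt"

// node is a toy stand-in for a graph operation (hypothetical).
type node struct {
	Op     string   // e.g. "Variable", "Const", "Assign"
	Inputs []string // for Assign: [variable, initialiser]
	Value  float64  // value of a Const
}

// trainableFreeze returns a copy of g in which the Const feeding each
// variable's Assign node holds the trained value for that variable. The
// variables remain, so the graph must still be initialised, but running the
// initialisers now restores the trained weights.
func trainableFreeze(g map[string]node, trained map[string]float64) (map[string]node, error) {
	out := make(map[string]node, len(g))
	for name, n := range g {
		out[name] = n // node is a struct, so this copies it
	}
	for _, n := range g {
		if n.Op != "Assign" || len(n.Inputs) != 2 {
			continue
		}
		variable, initName := n.Inputs[0], n.Inputs[1]
		v, ok := trained[variable]
		if !ok {
			return nil, fmt.Errorf("no trained value for variable %q", variable)
		}
		initNode := out[initName]
		if initNode.Op != "Const" {
			return nil, fmt.Errorf("initialiser %q of %q is not a Const", initName, variable)
		}
		initNode.Value = v
		out[initName] = initNode
	}
	return out, nil
}

func main() {
	g := map[string]node{
		"w":               {Op: "Variable"},
		"w/initial_value": {Op: "Const", Value: 0},
		"w/Assign":        {Op: "Assign", Inputs: []string{"w", "w/initial_value"}},
	}
	out, err := trainableFreeze(g, map[string]float64{"w": 0.73})
	if err != nil {
		panic(err)
	}
	fmt.Println(out["w/initial_value"].Value) // 0.73
	fmt.Println(g["w/initial_value"].Value)   // 0: the input graph is left unchanged
}
```

Keeping the variables is what makes the result larger than a fully frozen graph, and also what lets training continue from it.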

Types

type OnlineSess

type OnlineSess struct {
	Graph *tf.Graph
	Sess  *tf.Session
	// contains filtered or unexported fields
}

OnlineSess stores a model in the process of training.

func NewOnlineSess

func NewOnlineSess(
	graph *tf.Graph,
	inputName, targetName, trainOpName, initOpName, lossOpName, outputOpName, inferInputName string,
	outputOpNames []string,
) (oSess OnlineSess, err error)

NewOnlineSess creates a new OnlineSess that trains graph using the named operations.

func (OnlineSess) Infer

func (oSess OnlineSess) Infer(inputTensor *tf.Tensor) (outputTensor *tf.Tensor, err error)

Infer runs the graph on inputTensor and returns the output tensor.

func (OnlineSess) Run

func (oSess OnlineSess) Run(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output) (results []*tf.Tensor, err error)

Run runs the graph. It is a thin wrapper around tf.Session.Run; for simple cases you probably want Infer.

func (OnlineSess) Save

func (oSess OnlineSess) Save() (graph *tf.Graph, err error)

Save writes the current variable values into their initialisers and returns the resulting graph.

func (OnlineSess) Train

func (oSess OnlineSess) Train(inputTensor *tf.Tensor, targetTensor *tf.Tensor) (loss float32, err error)

Train runs one training step on a single mini-batch and returns its loss.
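
A caller typically drives Train in a loop over mini-batches and tracks the loss. The sketch below assumes a small interface with the same shape as Train; to stay runnable without the TensorFlow C library it uses plain slices in place of *tf.Tensor, so the trainer interface, fakeTrainer and trainEpoch are hypothetical.

```go
package main

import (
	"errors"
	"fmt"
)

// trainer mirrors the shape of OnlineSess.Train, with plain slices standing
// in for *tf.Tensor (hypothetical, for illustration only).
type trainer interface {
	Train(input [][]float32, target []float32) (loss float32, err error)
}

// trainEpoch calls t.Train once per mini-batch and returns the mean loss.
// It stops at the first error.
func trainEpoch(t trainer, inputs [][]float32, targets []float32, batchSize int) (float32, error) {
	if batchSize <= 0 {
		return 0, errors.New("batchSize must be positive")
	}
	if len(inputs) != len(targets) {
		return 0, errors.New("inputs and targets differ in length")
	}
	var total float32
	batches := 0
	for start := 0; start < len(inputs); start += batchSize {
		end := start + batchSize
		if end > len(inputs) {
			end = len(inputs) // the last batch may be smaller
		}
		loss, err := t.Train(inputs[start:end], targets[start:end])
		if err != nil {
			return 0, fmt.Errorf("batch %d: %w", batches, err)
		}
		total += loss
		batches++
	}
	if batches == 0 {
		return 0, errors.New("no data")
	}
	return total / float32(batches), nil
}

// fakeTrainer reports a loss equal to the batch size, so results are predictable.
type fakeTrainer struct{ calls int }

func (f *fakeTrainer) Train(input [][]float32, target []float32) (float32, error) {
	f.calls++
	return float32(len(input)), nil
}

func main() {
	inputs := [][]float32{{0}, {1}, {2}, {3}, {4}}
	targets := []float32{0, 1, 2, 3, 4}
	ft := &fakeTrainer{}
	mean, err := trainEpoch(ft, inputs, targets, 2)
	if err != nil {
		panic(err)
	}
	fmt.Println(ft.calls, mean) // 3 batches, of sizes 2, 2 and 1
}
```

Wrapping the error with the batch number makes a failure in a long epoch easier to locate.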
