Documentation ¶
Overview ¶
Package tftrain runs TensorFlow training graphs.
Index ¶
- func BatchTrain(graph *tf.Graph, numBatches int, ...) (frozen *tf.Graph, err error)
- func Freeze(tfGraph *tf.Graph, sess *tf.Session, headNames []string) (frozenTfGraph *tf.Graph, err error)
- func TrainableFreeze(tfGraph *tf.Graph, sess *tf.Session, headNames []string) (frozenTfGraph *tf.Graph, err error)
- type OnlineSess
- func NewOnlineSess(graph *tf.Graph, ...) (oSess OnlineSess, err error)
- func (oSess OnlineSess) Infer(inputTensor *tf.Tensor) (outputTensor *tf.Tensor, err error)
- func (oSess OnlineSess) Run(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output) (results []*tf.Tensor, err error)
- func (oSess OnlineSess) Save() (graph *tf.Graph, err error)
- func (oSess OnlineSess) Train(inputTensor *tf.Tensor, targetTensor *tf.Tensor) (loss float32, err error)
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func BatchTrain ¶
func BatchTrain(
	graph *tf.Graph,
	numBatches int,
	inputName, targetName, trainOpName, initOpName, lossOpName string,
	outputOpNames []string,
	lossHandlerFunction func(int, float32),
	getTrainingData func(int) (*tf.Tensor, *tf.Tensor, error),
) (frozen *tf.Graph, err error)
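BatchTrain trains the model in graph for numBatches batches. It takes each batch's input and target tensors from getTrainingData and passes each batch's loss to lossHandlerFunction. It returns the trained graph, frozen.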
func Freeze ¶
func Freeze(tfGraph *tf.Graph, sess *tf.Session, headNames []string) (frozenTfGraph *tf.Graph, err error)
Freeze replaces every variable with a constant holding its value in sess, and strips off nodes that are not required to compute the nodes named in headNames.
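The idea behind freezing can be shown on a toy graph. In this sketch, node, exampleGraph, and freezeSketch are hypothetical and model TensorFlow only loosely. Each Variable is swapped for a Const holding its trained value, and then only the nodes that the heads depend on, directly or transitively, are kept. It illustrates the technique, not how Freeze walks a real tf.Graph.

```go
package main

import (
	"fmt"
	"sort"
)

// node is a hypothetical, simplified graph node, not a TensorFlow type.
type node struct {
	Op     string   // e.g. "Variable", "Const", "Placeholder", "Mul"
	Inputs []string // names of the nodes this node reads
	Value  float32  // meaningful only for Const nodes
}

// exampleGraph builds a tiny training graph: y = x * w, plus a loss,
// a training op, and w's initialiser, none of which inference needs.
func exampleGraph() map[string]node {
	return map[string]node{
		"x":        {Op: "Placeholder"},
		"target":   {Op: "Placeholder"},
		"w":        {Op: "Variable"},
		"w/init":   {Op: "Const", Value: 0},
		"w/assign": {Op: "Assign", Inputs: []string{"w", "w/init"}},
		"y":        {Op: "Mul", Inputs: []string{"x", "w"}},
		"loss":     {Op: "SquaredDifference", Inputs: []string{"y", "target"}},
		"train":    {Op: "ApplyGradientDescent", Inputs: []string{"w", "loss"}},
	}
}

// freezeSketch turns each reachable Variable into a Const holding its value
// and keeps only the nodes that the heads depend on.
func freezeSketch(graph map[string]node, values map[string]float32, heads []string) (map[string]node, error) {
	frozen := make(map[string]node)
	stack := append([]string(nil), heads...)
	for len(stack) > 0 {
		name := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		if _, done := frozen[name]; done {
			continue
		}
		n, ok := graph[name]
		if !ok {
			return nil, fmt.Errorf("unknown node %q", name)
		}
		if n.Op == "Variable" {
			v, ok := values[name]
			if !ok {
				return nil, fmt.Errorf("no value for variable %q", name)
			}
			n = node{Op: "Const", Value: v} // the variable becomes a constant
		}
		frozen[name] = n
		stack = append(stack, n.Inputs...) // visit everything this node needs
	}
	return frozen, nil
}

func main() {
	// Pretend training left w at 2.5.
	frozen, err := freezeSketch(exampleGraph(), map[string]float32{"w": 2.5}, []string{"y"})
	if err != nil {
		panic(err)
	}
	names := make([]string, 0, len(frozen))
	for name := range frozen {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		fmt.Printf("%s: %s\n", name, frozen[name].Op)
	}
}
```

With heads set to ["y"], only y, x, and w survive. The initialiser, loss, and training op are stripped because inference never reads them.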
func TrainableFreeze ¶
func TrainableFreeze(tfGraph *tf.Graph, sess *tf.Session, headNames []string) (frozenTfGraph *tf.Graph, err error)
TrainableFreeze replaces the constant that each variable's initialiser assigns with that variable's trained value, and strips off nodes that are not required. The resulting graph still requires initialisation before use and is larger than a fully frozen graph. However, it can be used just like a training graph generated from Python.
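TrainableFreeze differs from Freeze in where the trained values end up. This hypothetical toy uses node, exampleGraph, and trainableFreezeSketch for illustration only, not as tftrain code. Variables stay variables, and the Const read by each variable's Assign initialiser is overwritten with the trained value. Nodes that the heads do not depend on are then pruned. Running the initialiser on the result restores the trained weights, which is why the output still needs initialisation.

```go
package main

import "fmt"

// node is a hypothetical, simplified graph node, not a TensorFlow type.
type node struct {
	Op     string   // e.g. "Variable", "Const", "Assign", "Mul"
	Inputs []string // names of the nodes this node reads
	Value  float32  // meaningful only for Const nodes
}

// exampleGraph: y = x * w, with w initialised by w/assign from w/init.
func exampleGraph() map[string]node {
	return map[string]node{
		"x":        {Op: "Placeholder"},
		"target":   {Op: "Placeholder"},
		"w":        {Op: "Variable"},
		"w/init":   {Op: "Const", Value: 0},
		"w/assign": {Op: "Assign", Inputs: []string{"w", "w/init"}},
		"y":        {Op: "Mul", Inputs: []string{"x", "w"}},
		"loss":     {Op: "SquaredDifference", Inputs: []string{"y", "target"}},
		"train":    {Op: "ApplyGradientDescent", Inputs: []string{"w", "loss"}},
	}
}

func trainableFreezeSketch(graph map[string]node, values map[string]float32, heads []string) (map[string]node, error) {
	// Step 1: overwrite the constant that each variable's Assign initialiser reads.
	updated := make(map[string]node, len(graph))
	for name, n := range graph {
		updated[name] = n
	}
	for _, n := range graph {
		if n.Op != "Assign" || len(n.Inputs) != 2 {
			continue
		}
		varName, initName := n.Inputs[0], n.Inputs[1]
		if graph[varName].Op != "Variable" || graph[initName].Op != "Const" {
			continue
		}
		v, ok := values[varName]
		if !ok {
			return nil, fmt.Errorf("no trained value for variable %q", varName)
		}
		c := updated[initName]
		c.Value = v
		updated[initName] = c
	}
	// Step 2: keep only the nodes the heads depend on; variables stay variables.
	kept := make(map[string]node)
	stack := append([]string(nil), heads...)
	for len(stack) > 0 {
		name := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		if _, done := kept[name]; done {
			continue
		}
		n, ok := updated[name]
		if !ok {
			return nil, fmt.Errorf("unknown node %q", name)
		}
		kept[name] = n
		stack = append(stack, n.Inputs...)
	}
	return kept, nil
}

func main() {
	// Keep the output and the initialiser, so the result can be initialised and run.
	g, err := trainableFreezeSketch(exampleGraph(), map[string]float32{"w": 2.5}, []string{"y", "w/assign"})
	if err != nil {
		panic(err)
	}
	fmt.Println(len(g), g["w"].Op, g["w/init"].Value)
}
```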
Types ¶
type OnlineSess ¶
type OnlineSess struct {
	Graph *tf.Graph
	Sess  *tf.Session
	// contains filtered or unexported fields
}
OnlineSess stores a model in the process of training.
func NewOnlineSess ¶
func NewOnlineSess(
	graph *tf.Graph,
	inputName, targetName, trainOpName, initOpName, lossOpName, outputOpName, inferInputName string,
	outputOpNames []string,
) (oSess OnlineSess, err error)
NewOnlineSess creates a new OnlineSess for training graph, using the operations named by the remaining arguments.
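OnlineSess suits training in which examples arrive one at a time. The sketch below imitates that usage pattern with a hypothetical onlineModel: a linear model y = w*x + b updated by one gradient-descent step per Train call. It is not tftrain code, and it uses float32 values instead of *tf.Tensor. newOnlineModel stands in for NewOnlineSess, and its Train and Infer methods only mirror the method names listed in the index.

```go
package main

import "fmt"

// onlineModel is a hypothetical stand-in for OnlineSess, not tftrain code.
// It holds a linear model y = w*x + b that is updated one example at a time.
type onlineModel struct {
	w, b         float32
	learningRate float32
}

// newOnlineModel plays the role of NewOnlineSess: it returns a model that is
// ready to train, with its parameters initialised (here, to zero).
func newOnlineModel(learningRate float32) *onlineModel {
	return &onlineModel{learningRate: learningRate}
}

// Train performs one gradient-descent step on a single (input, target) pair
// and returns the squared error measured before the step.
func (m *onlineModel) Train(input, target float32) float32 {
	diff := m.w*input + m.b - target
	m.w -= m.learningRate * 2 * diff * input
	m.b -= m.learningRate * 2 * diff
	return diff * diff
}

// Infer evaluates the current model without updating it.
func (m *onlineModel) Infer(input float32) float32 {
	return m.w*input + m.b
}

func main() {
	m := newOnlineModel(0.05)
	// Examples arrive one at a time from y = 2x + 1.
	xs := []float32{0, 1, 2, 3}
	for step := 0; step < 200; step++ {
		x := xs[step%len(xs)]
		loss := m.Train(x, 2*x+1)
		if step%50 == 0 {
			fmt.Printf("step %d: loss %.4f\n", step, loss)
		}
	}
	fmt.Printf("Infer(4) = %.3f\n", m.Infer(4))
}
```

OnlineSess's own methods use value receivers. This presumably works because its training state lives in the *tf.Session it points to. The stand-in keeps w and b as plain fields, so it needs a pointer receiver.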
func (OnlineSess) Run ¶
func (oSess OnlineSess) Run(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output) (results []*tf.Tensor, err error)
Run runs the graph. It is a thin wrapper around tf.Session.Run; for simple cases, you probably want Infer instead.
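The difference between Run and Infer is mostly how much the caller has to spell out. In this hypothetical toySession, string names stand in for tf.Output, and the code is neither tftrain nor TensorFlow. Run takes explicit feeds and fetches, while Infer fixes both and takes a single input. The sketch assumes that Infer feeds the node named by inferInputName and fetches outputOpName, based only on NewOnlineSess's parameter names.

```go
package main

import (
	"errors"
	"fmt"
)

// toySession is a hypothetical stand-in, not tftrain or TensorFlow code.
// Its "graph" has a feedable input "x", parameters "w" and "b", and an
// output "y" = w*x + b.
type toySession struct {
	w, b float32
}

// Run mirrors the general shape of OnlineSess.Run: the caller names every
// feed and every fetch, and gets results back in fetch order.
func (s toySession) Run(feeds map[string]float32, fetches []string) ([]float32, error) {
	results := make([]float32, 0, len(fetches))
	for _, name := range fetches {
		switch name {
		case "w":
			results = append(results, s.w)
		case "b":
			results = append(results, s.b)
		case "y":
			x, ok := feeds["x"]
			if !ok {
				return nil, errors.New(`fetching "y" requires a feed for "x"`)
			}
			results = append(results, s.w*x+s.b)
		default:
			return nil, fmt.Errorf("unknown fetch %q", name)
		}
	}
	return results, nil
}

// Infer is the convenience path: it fixes the input and output names, so
// callers pass one value and get one value back.
func (s toySession) Infer(input float32) (float32, error) {
	results, err := s.Run(map[string]float32{"x": input}, []string{"y"})
	if err != nil {
		return 0, err
	}
	return results[0], nil
}

func main() {
	s := toySession{w: 2, b: 1}
	y, err := s.Infer(3)
	if err != nil {
		panic(err)
	}
	params, err := s.Run(nil, []string{"w", "b"})
	if err != nil {
		panic(err)
	}
	fmt.Println(y, params)
}
```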