tf

package
v0.0.0-...-d9bce91 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Jul 11, 2018 License: MIT Imports: 21 Imported by: 0

Documentation

Index

Constants

View Source
// File and directory names used for the entries of a serialized model
// archive (see LoadModel / Model.Save).
const (
	// SaverDefName is the file holding the serialized SaverDef.
	SaverDefName   = "saver_def.pb"
	// GraphDefName is the file holding the serialized GraphDef.
	GraphDefName   = "graph_def.pb"
	// SavedModelName is the prefix of the checkpoint files
	// ("saved_model-<iteration>.{index,meta,data-*}").
	SavedModelName = "saved_model"
	// ModelMetaName is the file holding the model metadata.
	// NOTE(review): presumably a serialized clientpb.ModelMeta — confirm.
	ModelMetaName  = "model_meta.pb"
	// FilePerm is the file permission all the model files use.
	// 0600 grants read/write to the owner only.
	FilePerm = 0600
)
View Source
// Names of graph operations and variables under the "pok/update/" scope.
// NOTE(review): presumably added to the model graph so that SetWeights /
// AddWeights can assign, add, and scale weights — confirm against the
// implementation.
const (
	// PokAssignOp is the name of the assign operation.
	PokAssignOp    = "pok/update/assign"
	// PokAssignAddOp is the name of the assign-add operation.
	PokAssignAddOp = "pok/update/assign_add"
	// PokVarPrefix is the name prefix for the update variables.
	PokVarPrefix   = "pok/update/var/"
	// PokVarScaleOp is the name of the scale operation.
	PokVarScaleOp  = "pok/update/scale"
)

Variables

This section is empty.

Functions

func DecodeDataType

func DecodeDataType(r io.Reader) (tensorflow.DataType, error)

DecodeDataType decodes the data type from the reader.

func DecodeInt64Array

func DecodeInt64Array(r io.Reader) ([]int64, error)

DecodeInt64Array decodes an int64 array from the reader. Format is: int64(num) + repeated (int64)

func DecodeString

func DecodeString(r io.Reader) (string, error)

DecodeString reads a string from a reader. Format is: int64(len) + body

func DecodeStringArray

func DecodeStringArray(r io.Reader) ([]string, error)

DecodeStringArray decodes a string array from the bytes. Format is: int64(num strings) + repeated (string)

func DecodeStringND

func DecodeStringND(t reflect.Type, shape []int64, r io.Reader) (reflect.Value, error)

func DecodeTensor

func DecodeTensor(r io.Reader) (*tensorflow.Tensor, error)

DecodeTensor decodes a tensor from the reader and returns it. See EncodeTensor.

func DecodeTensorMap

func DecodeTensorMap(r io.Reader) (map[string]*tensorflow.Tensor, error)

DecodeTensorMap decodes a map[string]*tensorflow.Tensor from the reader and returns it. See EncodeTensorMap.

func EncodeDataType

func EncodeDataType(w io.Writer, dt tensorflow.DataType) error

EncodeDataType writes the data type to the writer.

func EncodeInt64Array

func EncodeInt64Array(w io.Writer, arr []int64) error

EncodeInt64Array encodes an int64 array into the writer. Format is: int64(num) + repeated (int64)

func EncodeString

func EncodeString(w io.Writer, body string) error

EncodeString encodes a string into the writer.

func EncodeStringArray

func EncodeStringArray(w io.Writer, arr []string) error

EncodeStringArray encodes a string array into the writer. Format is: int64(num strings) + repeated (string)

func EncodeStringND

func EncodeStringND(w io.Writer, val reflect.Value) error

func EncodeTensor

func EncodeTensor(w io.Writer, val *tensorflow.Tensor) error

EncodeTensor encodes a tensor into the writer. See DecodeTensor.

func EncodeTensorMap

func EncodeTensorMap(w io.Writer, m map[string]*tensorflow.Tensor) error

EncodeTensorMap encodes a map[string]*tensorflow.Tensor into the writer. See DecodeTensorMap.

func LoadWeights

func LoadWeights(r io.Reader) (map[string]*tensorflow.Tensor, error)

func ParseNodeOutput

func ParseNodeOutput(path string) (string, int, error)

ParseNodeOutput splits a "<name>:<output #>" pair and returns the node name and the output index.

Types

type Batcher

// Batcher concatenates a fixed number of tensors into one larger tensor.
// Create it with NewTensorBatcher, and call Close when it is no longer needed.
type Batcher struct {
	// contains filtered or unexported fields
}

Batcher takes a fixed number of tensors and concatenates them into one larger tensor. This is mostly used for creating mini-batches of tensors for SGD.

func NewTensorBatcher

func NewTensorBatcher(n int, dtype tensorflow.DataType, shape []int64) (*Batcher, error)

NewTensorBatcher returns a new Batcher with the specified params. Shape should have exactly one dimension that is unspecified (-1) and the tensors will be concatenated along that axis.

func (*Batcher) Batch

func (m *Batcher) Batch(values []*tensorflow.Tensor) (*tensorflow.Tensor, error)

Batch takes the given values and returns a single output tensor with all the values concatenated.

func (*Batcher) Close

func (m *Batcher) Close() error

type Model

// Model bundles a TensorFlow graph, the session running it, and the metadata
// needed to save and restore it. Obtain one via LoadModel or GetModel, and
// call Close when done.
type Model struct {
	// Graph is the loaded computation graph.
	Graph    *tensorflow.Graph
	// Session is the session used to run operations on Graph.
	Session  *tensorflow.Session
	// SaverDef describes how to save and restore checkpoints.
	SaverDef tensorflowpb.SaverDef
	// Meta holds model metadata.
	Meta     clientpb.ModelMeta
	// Prefix is the name prefix applied to operation names by ApplyPrefix.
	// NOTE(review): exact joining rule not visible here — confirm.
	Prefix   string
}

func GetModel

func GetModel(url string) (*Model, error)

GetModel fetches a model from a URL and loads it.

func LoadModel

func LoadModel(reader io.Reader) (*Model, error)

LoadModel loads a model from a provided .tar.gz io stream. The returned session must be closed when done using.

The .tar.gz file should contain the following files:

- saver_def.pb
- graph_def.pb
- checkpoint
- saved_model-<iteration>.{index,meta,data-*}

func (*Model) AddWeights

func (model *Model) AddWeights(scale float64, weights map[string]*tensorflow.Tensor) error

AddWeights imports weights and then adds them to the model's current weights, using scale as the scaling factor.

func (Model) ApplyPrefix

func (model Model) ApplyPrefix(op string) string

func (*Model) Close

func (model *Model) Close() error

Close closes the model.

func (*Model) ExportWeights

func (model *Model) ExportWeights(w io.Writer) error

func (*Model) Load

func (model *Model) Load(reader io.Reader) error

func (*Model) NumWeights

func (model *Model) NumWeights() (int64, error)

NumWeights returns the total number of weights the model has.

func (*Model) Operation

func (m *Model) Operation(path string) (*tensorflow.Operation, error)

func (*Model) Output

func (m *Model) Output(path string) (tensorflow.Output, error)

func (*Model) Save

func (model *Model) Save(writer io.Writer) error

Save saves the model to the writer. See LoadModel.

func (*Model) SetWeights

func (model *Model) SetWeights(weights map[string]*tensorflow.Tensor) error

SetWeights sets the weights of the model.

func (*Model) WeightsMap

func (model *Model) WeightsMap() (map[string]*tensorflow.Tensor, error)

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL