tensorflow

package
v0.12.0
Warning

This package is not in the latest version of its module.

Published: Dec 19, 2016 License: Apache-2.0 Imports: 10 Imported by: 0

README

TensorFlow Go API

Construct and execute TensorFlow graphs in Go.

WARNING: The API defined in this package is not stable and can change without notice. The same goes for the awkward package path (github.com/tensorflow/tensorflow/tensorflow/go).

Requirements

  • Go version 1.7+

  • bazel

  • Environment to build TensorFlow from source code (Linux or Mac OS X). If you'd like to skip reading those details and do not care about GPU support, try the following:

    # On Linux
    sudo apt-get install python swig python-numpy
    
    # On Mac OS X with homebrew
    brew install swig
    

Installation

  1. Download the TensorFlow source code:

    go get -d github.com/tensorflow/tensorflow/tensorflow/go
    
  2. Build the TensorFlow library (libtensorflow.so):

    cd ${GOPATH}/src/github.com/tensorflow/tensorflow
    ./configure
    bazel build -c opt //tensorflow:libtensorflow.so
    

    This can take a while (tens of minutes, more if also building for GPU).

  3. Make libtensorflow.so available to the linker. This can be done by either:

    a. Copying it to a system location, e.g.,

    cp ${GOPATH}/src/github.com/tensorflow/tensorflow/bazel-bin/tensorflow/libtensorflow.so /usr/local/lib
    

    OR

    b. Setting the LD_LIBRARY_PATH=${GOPATH}/src/github.com/tensorflow/tensorflow/bazel-bin/tensorflow environment variable (DYLD_LIBRARY_PATH on Mac OS X).

  4. Generate wrapper functions for TensorFlow ops:

    go generate github.com/tensorflow/tensorflow/tensorflow/go/op
    

After this, the go tool should be usable as normal. For example:

go test -v github.com/tensorflow/tensorflow/tensorflow/go

Contributions

This API has been built on top of the C API, which is intended for building language bindings for TensorFlow functionality. However, this is far from complete. Contributions are welcome. To monitor progress follow issue 10.

Documentation

Overview

Package tensorflow is a Go binding to TensorFlow.

The API is subject to change and may break at any time.

TensorFlow (www.tensorflow.org) is an open source software library for numerical computation using data flow graphs. This package provides functionality to build and execute such graphs and depends on TensorFlow being available. For installation instructions see https://www.tensorflow.org/code/tensorflow/go/README.md

Example
package main

import (
	"archive/zip"
	"bufio"
	"flag"
	"fmt"
	"image"
	_ "image/jpeg"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"path/filepath"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	// An example for using the TensorFlow Go API for image recognition
	// using a pre-trained inception model (http://arxiv.org/abs/1512.00567).
	//
	// The pre-trained model takes input in the form of a 4-dimensional
	// tensor with shape [ BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3 ],
	// where:
	// - BATCH_SIZE allows for inference of multiple images in one pass through the graph
	// - IMAGE_HEIGHT is the height of the images on which the model was trained
	// - IMAGE_WIDTH is the width of the images on which the model was trained
	// - 3 is the (R, G, B) values of the pixel colors represented as a float.
	//
	// And produces as output a vector with shape [ NUM_LABELS ].
	// output[i] is the probability that the input image was recognized as
	// having the i-th label.
	//
	// A separate file contains a list of string labels corresponding to the
	// integer indices of the output.
	//
	// This example:
	// - Loads the serialized representation of the pre-trained model into a Graph
	// - Creates a Session to execute operations on the Graph
	// - Converts an image file to a Tensor to provide as input to a Session run
	// - Executes the Session and prints out the label with the highest probability
	//
	// To convert an image file to a Tensor suitable for input to the Inception model,
	// this example:
	// - Constructs another TensorFlow graph to normalize the image into a
	//   form suitable for the model (for example, resizing the image)
	// - Creates and executes a Session to obtain a Tensor in this normalized form.
	modeldir := flag.String("dir", "", "Directory containing the trained model files. The directory will be created and the model downloaded into it if necessary")
	imagefile := flag.String("image", "", "Path of the image to extract labels for")
	flag.Parse()
	if *modeldir == "" || *imagefile == "" {
		flag.Usage()
		return
	}
	// Load the serialized GraphDef from a file.
	modelfile, labelsfile, err := modelFiles(*modeldir)
	if err != nil {
		log.Fatal(err)
	}
	model, err := ioutil.ReadFile(modelfile)
	if err != nil {
		log.Fatal(err)
	}

	// Construct an in-memory graph from the serialized form.
	graph := tf.NewGraph()
	if err := graph.Import(model, ""); err != nil {
		log.Fatal(err)
	}

	// Create a session for inference over graph.
	session, err := tf.NewSession(graph, nil)
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()

	// Run inference on *imagefile.
	// For multiple images, session.Run() can be called in a loop (and
	// concurrently). Furthermore, images can be batched together since the
	// model accepts batches of image data as input.
	tensor, err := makeTensorFromImage(*imagefile)
	if err != nil {
		log.Fatal(err)
	}
	output, err := session.Run(
		map[tf.Output]*tf.Tensor{
			graph.Operation("input").Output(0): tensor,
		},
		[]tf.Output{
			graph.Operation("output").Output(0),
		},
		nil)
	if err != nil {
		log.Fatal(err)
	}
	// output[0].Value() is a vector containing probabilities of
	// labels for each image in the "batch". The batch size was 1.
	// Find the most probable label index.
	probabilities := output[0].Value().([][]float32)[0]
	printBestLabel(probabilities, labelsfile)
}

func printBestLabel(probabilities []float32, labelsFile string) {
	bestIdx := 0
	for i, p := range probabilities {
		if p > probabilities[bestIdx] {
			bestIdx = i
		}
	}
	// Found a best match, now read the string from the labelsFile where
	// there is one line per label.
	file, err := os.Open(labelsFile)
	if err != nil {
		log.Fatal(err)
	}
	defer file.Close()
	scanner := bufio.NewScanner(file)
	var labels []string
	for scanner.Scan() {
		labels = append(labels, scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		log.Printf("ERROR: failed to read %s: %v", labelsFile, err)
	}
	fmt.Printf("BEST MATCH: (%2.0f%% likely) %s\n", probabilities[bestIdx]*100.0, labels[bestIdx])
}

// Convert the image in filename to a Tensor suitable as input to the Inception model.
func makeTensorFromImage(filename string) (*tf.Tensor, error) {
	// Load the pixels from the file
	file, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	img, _, err := image.Decode(file)
	file.Close()
	if err != nil {
		return nil, err
	}
	// Represent the image as [H][W][B,G,R]byte
	contents := make([][][3]byte, img.Bounds().Size().Y)
	for y := 0; y < len(contents); y++ {
		contents[y] = make([][3]byte, img.Bounds().Size().X)
		for x := 0; x < len(contents[y]); x++ {
			px := x + img.Bounds().Min.X
			py := y + img.Bounds().Min.Y
			r, g, b, _ := img.At(px, py).RGBA()
			// image.Image uses 16-bits for each color.
			// We want 8-bits.
			contents[y][x][0] = byte(b >> 8)
			contents[y][x][1] = byte(g >> 8)
			contents[y][x][2] = byte(r >> 8)
		}
	}
	tensor, err := tf.NewTensor(contents)
	if err != nil {
		return nil, err
	}
	// Construct a graph to normalize the image
	graph, input, output, err := constructGraphToNormalizeImage()
	if err != nil {
		return nil, err
	}
	// Execute that graph to normalize this one image
	session, err := tf.NewSession(graph, nil)
	if err != nil {
		return nil, err
	}
	defer session.Close()
	normalized, err := session.Run(
		map[tf.Output]*tf.Tensor{input: tensor},
		[]tf.Output{output},
		nil)
	if err != nil {
		return nil, err
	}
	return normalized[0], nil
}

// The inception model takes as input the image described by a Tensor in a very
// specific normalized format (a particular image size, shape of the input tensor,
// normalized pixel values etc.).
//
// This function constructs a graph of TensorFlow operations which takes as input
// the raw pixel values of an image in the form of a Tensor of shape [Height, Width, 3]
// and returns a tensor suitable for input to the inception model.
//
// T[y][x] is the (Blue, Green, Red) values of the pixel at position (x, y) in the image,
// with each color value represented as a single byte.
func constructGraphToNormalizeImage() (graph *tf.Graph, input, output tf.Output, err error) {
	// Some constants specific to the pre-trained model at:
	// https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
	//
	// - The model was trained with images scaled to 224x224 pixels.
	// - The colors, represented as R, G, B in 1-byte each were converted to
	//   float using (value - Mean)/Scale.
	//
	// If using a different pre-trained model, the values will have to be adjusted.
	const (
		H, W  = 224, 224
		Mean  = float32(117)
		Scale = float32(1)
	)
	// - input is a 3D tensor of shape [Height, Width, Colors=3], where
	//   each pixel is represented as a triplet of 1-byte colors
	// - ResizeBilinear (and the inception model) takes a 4D tensor of shape
	//   [BatchSize, Height, Width, Colors=3], where each pixel is
	//   represented as a triplet of floats
	// - Apply normalization on each pixel and use ExpandDims to make
	//   this single image be a "batch" of size 1 for ResizeBilinear.
	s := op.NewScope()
	input = op.Placeholder(s, tf.Uint8)
	output = op.Div(s,
		op.Sub(s,
			op.ResizeBilinear(s,
				op.ExpandDims(s,
					op.Cast(s, input, tf.Float),
					op.Const(s.SubScope("make_batch"), int32(0))),
				op.Const(s.SubScope("size"), []int32{H, W})),
			op.Const(s.SubScope("mean"), Mean)),
		op.Const(s.SubScope("scale"), Scale))
	graph, err = s.Finalize()
	return graph, input, output, err
}

func modelFiles(dir string) (modelfile, labelsfile string, err error) {
	const URL = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip"
	var (
		model   = filepath.Join(dir, "tensorflow_inception_graph.pb")
		labels  = filepath.Join(dir, "imagenet_comp_graph_label_strings.txt")
		zipfile = filepath.Join(dir, "inception5h.zip")
	)
	if filesExist(model, labels) == nil {
		return model, labels, nil
	}
	log.Println("Did not find model in", dir, "downloading from", URL)
	if err := os.MkdirAll(dir, 0755); err != nil {
		return "", "", err
	}
	if err := download(URL, zipfile); err != nil {
		return "", "", fmt.Errorf("failed to download %v - %v", URL, err)
	}
	if err := unzip(dir, zipfile); err != nil {
		return "", "", fmt.Errorf("failed to extract contents from model archive: %v", err)
	}
	os.Remove(zipfile)
	return model, labels, filesExist(model, labels)
}

func filesExist(files ...string) error {
	for _, f := range files {
		if _, err := os.Stat(f); err != nil {
			return fmt.Errorf("unable to stat %s: %v", f, err)
		}
	}
	return nil
}

func download(URL, filename string) error {
	resp, err := http.Get(URL)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0644)
	if err != nil {
		return err
	}
	defer file.Close()
	_, err = io.Copy(file, resp.Body)
	return err
}

func unzip(dir, zipfile string) error {
	r, err := zip.OpenReader(zipfile)
	if err != nil {
		return err
	}
	defer r.Close()
	for _, f := range r.File {
		src, err := f.Open()
		if err != nil {
			return err
		}
		log.Println("Extracting", f.Name)
		dst, err := os.OpenFile(filepath.Join(dir, f.Name), os.O_WRONLY|os.O_CREATE, 0644)
		if err != nil {
			src.Close()
			return err
		}
		// Close both files before checking the copy error so that
		// neither handle leaks when extraction fails part-way.
		_, err = io.Copy(dst, src)
		src.Close()
		dst.Close()
		if err != nil {
			return err
		}
	}
	return nil
}
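
The example above prints only the single most likely label, and it indexes the labels slice without checking its length. As a rough sketch of a variation, the standalone program below ranks the k most probable indices and guards against a labels file that is shorter than the model output. The names topK and scored are hypothetical helpers, not part of this package, the probabilities are made-up values rather than real model output, and sort.SliceStable assumes Go 1.8 or newer.

```go
package main

import (
	"fmt"
	"sort"
)

// scored pairs a label index with its probability (hypothetical helper type).
type scored struct {
	Index int
	Prob  float32
}

// topK returns up to k entries with the highest probabilities, in
// descending order. k is clamped to the range [0, len(probabilities)].
func topK(probabilities []float32, k int) []scored {
	all := make([]scored, len(probabilities))
	for i, p := range probabilities {
		all[i] = scored{Index: i, Prob: p}
	}
	// A stable sort keeps the lower index first when probabilities tie.
	sort.SliceStable(all, func(a, b int) bool { return all[a].Prob > all[b].Prob })
	if k > len(all) {
		k = len(all)
	}
	if k < 0 {
		k = 0
	}
	return all[:k]
}

func main() {
	// Made-up probabilities standing in for output[0].Value().([][]float32)[0].
	probs := []float32{0.05, 0.70, 0.10, 0.15}
	// Deliberately one label short to exercise the bounds check.
	labels := []string{"cat", "dog", "fox"}

	for _, s := range topK(probs, 3) {
		name := "(no label for this index)"
		if s.Index < len(labels) {
			name = labels[s.Index]
		}
		fmt.Printf("%2.0f%% likely: %s\n", s.Prob*100, name)
	}
}
```

Sorting a copy of the indices leaves the probabilities slice untouched, so the same slice can still be indexed by label position afterwards.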

Constants

This section is empty.

Variables

This section is empty.

Functions

func Version added in v0.12.0

func Version() string

Version returns a string describing the version of the underlying TensorFlow runtime.

Types

type DataType

type DataType C.TF_DataType

DataType holds the type for a scalar value. E.g., one slot in a tensor.

const (
	Float      DataType = C.TF_FLOAT
	Double     DataType = C.TF_DOUBLE
	Int32      DataType = C.TF_INT32
	Uint8      DataType = C.TF_UINT8
	Int16      DataType = C.TF_INT16
	Int8       DataType = C.TF_INT8
	String     DataType = C.TF_STRING
	Complex64  DataType = C.TF_COMPLEX64
	Complex    DataType = C.TF_COMPLEX
	Int64      DataType = C.TF_INT64
	Bool       DataType = C.TF_BOOL
	Qint8      DataType = C.TF_QINT8
	Quint8     DataType = C.TF_QUINT8
	Qint32     DataType = C.TF_QINT32
	Bfloat16   DataType = C.TF_BFLOAT16
	Qint16     DataType = C.TF_QINT16
	Quint16    DataType = C.TF_QUINT16
	Uint16     DataType = C.TF_UINT16
	Complex128 DataType = C.TF_COMPLEX128
	Half       DataType = C.TF_HALF
)

Types of scalar values in the TensorFlow type system.

type Graph

type Graph struct {
	// contains filtered or unexported fields
}

Graph represents a computation graph. Graphs may be shared between sessions.

func NewGraph

func NewGraph() *Graph

NewGraph returns a new Graph.

func (*Graph) AddOperation added in v0.12.0

func (g *Graph) AddOperation(args OpSpec) (*Operation, error)

AddOperation adds an operation to g.

func (*Graph) Import added in v0.12.0

func (g *Graph) Import(def []byte, prefix string) error

Import imports the nodes and edges from a serialized representation of another Graph into g.

Names of imported nodes will be prefixed with prefix.

func (*Graph) Operation added in v0.12.0

func (g *Graph) Operation(name string) *Operation

Operation returns the Operation named name in the Graph, or nil if no such operation is present.

func (*Graph) WriteTo added in v0.12.0

func (g *Graph) WriteTo(w io.Writer) (int64, error)

WriteTo writes out a serialized representation of g to w.

Implements the io.WriterTo interface.

type Input added in v0.12.0

type Input interface {
	// contains filtered or unexported methods
}

Input is the interface for specifying inputs to an operation being added to a Graph.

Operations can have multiple inputs, each of which could be either a tensor produced by another operation (an Output object), or a list of tensors produced by other operations (an OutputList). Thus, this interface is implemented by both Output and OutputList.

See OpSpec.Input for more information.

type OpSpec added in v0.12.0

type OpSpec struct {
	// Type of the operation (e.g., "Add", "MatMul").
	Type string

	// Name by which the added operation will be referred to in the Graph.
	// If omitted, defaults to Type.
	Name string

	// Inputs to this operation, which in turn must be outputs
	// of other operations already added to the Graph.
	//
	// An operation may have multiple inputs with individual inputs being
	// either a single tensor produced by another operation or a list of
	// tensors produced by multiple operations. For example, the "Concat"
	// operation takes two inputs: (1) the dimension along which to
	// concatenate and (2) a list of tensors to concatenate. Thus, for
	// Concat, len(Input) must be 2, with the first element being an Output
	// and the second being an OutputList.
	Input []Input

	// Map from attribute name to its value that will be attached to this
	// operation.
	Attrs map[string]interface{}
}

OpSpec is the specification of an Operation to be added to a Graph (using Graph.AddOperation).

type Operation

type Operation struct {
	// contains filtered or unexported fields
}

Operation that has been added to the graph.

func (*Operation) Name added in v0.12.0

func (op *Operation) Name() string

Name returns the name of the operation.

func (*Operation) NumOutputs added in v0.12.0

func (op *Operation) NumOutputs() int

NumOutputs returns the number of outputs of op.

func (*Operation) Output added in v0.12.0

func (op *Operation) Output(i int) Output

Output returns the i-th output of op.

func (*Operation) OutputListSize added in v0.12.0

func (op *Operation) OutputListSize(output string) (int, error)

OutputListSize returns the size of the list of Outputs that is produced by a named output of op.

An Operation has multiple named outputs, each of which produces either a single tensor or a list of tensors. This method returns the size of the list of tensors for a specific output of the operation, identified by its name.

func (*Operation) Type added in v0.12.0

func (op *Operation) Type() string

Type returns the name of the operator used by this operation.

type Output

type Output struct {
	// Op is the Operation that produces this Output.
	Op *Operation

	// Index specifies the index of the output within the Operation.
	Index int
}

Output represents one of the outputs of an operation in the graph. Has a DataType (and eventually a Shape). May be passed as an input argument to a function for adding operations to a graph, or to a Session's Run() method to fetch that output as a tensor.

func (Output) Shape added in v0.12.0

func (p Output) Shape() (shape []int64, err error)

Shape returns the (possibly incomplete) shape of the tensor produced by p.

Returns a slice of length 0 if the tensor is a scalar. Returns a slice where shape[i] is the size of the i-th dimension of the tensor, or -1 if the size of that dimension is not known.

Returns an error if the number of dimensions of the tensor is not known.
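
To make the shape convention concrete, here is a small sketch that interprets a []int64 in the form described above: an empty slice is a scalar, and -1 marks a dimension whose size is not known. The helper numElements is hypothetical and uses only the standard library; it does not call into this package.

```go
package main

import "fmt"

// numElements returns the number of elements implied by shape.
// ok is false when any dimension is unknown (negative), in which
// case the element count cannot be determined.
func numElements(shape []int64) (n int64, ok bool) {
	n = 1 // A scalar (zero dimensions) holds exactly one element.
	for _, d := range shape {
		if d < 0 {
			return 0, false
		}
		n *= d
	}
	return n, true
}

func main() {
	shapes := [][]int64{
		{},                // scalar
		{1, 224, 224, 3},  // fully known: one 224x224 image with 3 channels
		{-1, 224, 224, 3}, // batch size not known
	}
	for _, s := range shapes {
		n, ok := numElements(s)
		fmt.Printf("shape=%v elements=%d known=%v\n", s, n, ok)
	}
}
```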

type OutputList added in v0.12.0

type OutputList []Output

OutputList represents a list of Outputs that can be provided as input to another operation.

type Session

type Session struct {
	// contains filtered or unexported fields
}

Session drives a TensorFlow graph computation.

When a Session is created with a given target, a new Session object is bound to the universe of resources specified by that target. Those resources are available to this session to perform computation described in the GraphDef. After creating the session with a graph, the caller uses the Run() API to perform the computation and potentially fetch outputs as Tensors. A Session allows concurrent calls to Run().

func NewSession

func NewSession(graph *Graph, options *SessionOptions) (*Session, error)

NewSession creates a new execution session with the associated graph. options may be nil to use the default options.

func (*Session) Close

func (s *Session) Close() error

Close a session. This contacts any other processes associated with this session, if applicable. Blocks until all previous calls to Run have returned.

func (*Session) Run

func (s *Session) Run(inputs map[Output]*Tensor, outputs []Output, targets []*Operation) ([]*Tensor, error)

Run the graph with the associated session starting with the supplied inputs. inputs and outputs may be set to nil. Runs, but does not return Tensors for operations specified in targets.

On success, returns the Tensor outputs in the same order as supplied in the outputs argument. If outputs is set to nil, the returned Tensor outputs is empty.

type SessionOptions

type SessionOptions struct {
	// Target indicates the TensorFlow runtime to connect to.
	//
	// If 'target' is empty or unspecified, the local TensorFlow runtime
	// implementation will be used.  Otherwise, the TensorFlow engine
	// defined by 'target' will be used to perform all computations.
	//
	// "target" can be either a single entry or a comma separated list
	// of entries. Each entry is a resolvable address of one of the
	// following formats:
	//   local
	//   ip:port
	//   host:port
	//   ... other system-specific formats to identify tasks and jobs ...
	//
	// NOTE: at the moment 'local' maps to an in-process service-based
	// runtime.
	//
	// Upon creation, a single session affines itself to one of the
	// remote processes, with possible load balancing choices when the
	// "target" resolves to a list of possible processes.
	//
	// If the session disconnects from the remote process during its
	// lifetime, session calls may fail immediately.
	Target string
}

SessionOptions contains configuration information for a session.
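
As an illustration of the Target format only, the sketch below splits a comma-separated target string into entries and labels each by the format it appears to use. The helper splitTarget is hypothetical, and trimming whitespace around entries is an assumption; the TensorFlow runtime does its own parsing, which may differ and which accepts other system-specific formats.

```go
package main

import (
	"fmt"
	"net"
	"strings"
)

// splitTarget breaks a comma-separated target string into entries and
// describes each one. It is an illustration, not the runtime's parser.
func splitTarget(target string) []string {
	if strings.TrimSpace(target) == "" {
		// An empty target selects the local runtime.
		return []string{"local runtime (empty target)"}
	}
	var out []string
	for _, entry := range strings.Split(target, ",") {
		entry = strings.TrimSpace(entry) // assumption: surrounding spaces are not significant
		if entry == "local" {
			out = append(out, "local")
			continue
		}
		if host, port, err := net.SplitHostPort(entry); err == nil {
			out = append(out, fmt.Sprintf("host=%s port=%s", host, port))
			continue
		}
		// Anything else is left for the runtime to interpret.
		out = append(out, "other: "+entry)
	}
	return out
}

func main() {
	for _, d := range splitTarget("worker0:2222, 10.0.0.2:2223") {
		fmt.Println(d)
	}
}
```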

type Tensor

type Tensor struct {
	// contains filtered or unexported fields
}

Tensor holds a multi-dimensional array of elements of a single data type.

func NewTensor

func NewTensor(value interface{}) (*Tensor, error)

NewTensor converts from a Go value to a Tensor. Valid values are scalars, slices, and arrays. Every element of a slice must have the same length so that the resulting Tensor has a valid shape.

func (*Tensor) DataType

func (t *Tensor) DataType() DataType

DataType returns the scalar datatype of the Tensor.

func (*Tensor) Shape

func (t *Tensor) Shape() []int64

Shape returns the shape of the Tensor.

func (*Tensor) Value

func (t *Tensor) Value() interface{}

Value converts the Tensor to a Go value. For now, not all Tensor types are supported, and this function may panic if it encounters an unsupported DataType.

The type of the output depends on the Tensor type and dimensions. For example:

    Tensor(int64, 0): int64
    Tensor(float64, 3): [][][]float64
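
Because Value returns interface{}, an unchecked type assertion panics when the dynamic type is not the one expected. One way to handle this, sketched below with only the standard library, is a type switch that returns an error instead. The helper firstRow is hypothetical, and the literal passed to it merely stands in for a value obtained from Value().

```go
package main

import "fmt"

// firstRow extracts the first row of a float32 matrix, or the vector
// itself if v is already one-dimensional. Any other dynamic type
// yields an error rather than a panic.
func firstRow(v interface{}) ([]float32, error) {
	switch t := v.(type) {
	case [][]float32:
		if len(t) == 0 {
			return nil, fmt.Errorf("empty batch")
		}
		return t[0], nil
	case []float32:
		return t, nil
	default:
		return nil, fmt.Errorf("unexpected value type %T", v)
	}
}

func main() {
	// Stand-in for a value such as output[0].Value() with a batch of one.
	var value interface{} = [][]float32{{0.1, 0.9}}

	row, err := firstRow(value)
	fmt.Println(row, err)

	// A value of a different element type is reported, not asserted.
	_, err = firstRow([][]float64{{0.1, 0.9}})
	fmt.Println(err)
}
```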

Directories

Path Synopsis
Command genop generates a Go source file with functions for TensorFlow ops.
internal
Package internal generates Go source code with functions for TensorFlow operations.
Package op defines functions for adding TensorFlow operations to a Graph.
