tfgo: TensorFlow in Go

TensorFlow's Go bindings are hard to use: tfgo makes it easy!

No more problems like:

  • Scoping: each new node will have a new and unique name
  • Typing: attributes are automatically converted to a supported type instead of throwing errors at runtime

Also, it uses method chaining, making it possible to write pleasant Go code.

Dependencies

  1. The TensorFlow 2.9.1 C library (see the TensorFlow installation section below).
  2. TensorFlow bindings: github.com/galeone/tensorflow. To work correctly with TensorFlow 2.9.1 in Go, we have to use a fork I created with some fixes for the Go bindings. The bindings can be too large for the Go module proxy, so you may want to bypass it by executing go env -w GONOSUMDB="github.com/galeone/tensorflow", which pulls the code directly using the system-installed git. This changes nothing in the user interface -- you can use Go modules as usual.

Installation

go get github.com/galeone/tfgo

Getting started

The core data structure of TensorFlow's Go bindings is the op.Scope struct. tfgo allows creating new *op.Scope objects that solve the scoping issue mentioned above.

Since we're defining a graph, let's start from its root (empty graph):

root := tg.NewRoot()

We can now place nodes into this graph and connect them. Let's say we want to multiply a matrix by a column vector and then add another column vector to the result.

Here's the complete source code.

package main

import (
        "fmt"
        tg "github.com/galeone/tfgo"
        tf "github.com/galeone/tensorflow/tensorflow/go"
)

func main() {
        root := tg.NewRoot()
        A := tg.NewTensor(root, tg.Const(root, [2][2]int32{{1, 2}, {-1, -2}}))
        x := tg.NewTensor(root, tg.Const(root, [2][1]int64{{10}, {100}}))
        b := tg.NewTensor(root, tg.Const(root, [2][1]int32{{-10}, {10}}))
        Y := A.MatMul(x.Output).Add(b.Output)
        // Please note that Y is just a pointer to A!

        // If we want to create a different node in the graph, we have to clone Y
        // or equivalently A
        Z := A.Clone()
        results := tg.Exec(root, []tf.Output{Y.Output, Z.Output}, nil, &tf.SessionOptions{})
        fmt.Println("Y: ", results[0].Value(), "Z: ", results[1].Value())
        fmt.Println("Y == A", Y == A) // ==> true
        fmt.Println("Z == A", Z == A) // ==> false
}

that produces

Y:  [[200] [-200]] Z:  [[200] [-200]]
Y == A true
Z == A false

The complete list of available methods can be found on GoDoc: http://godoc.org/github.com/galeone/tfgo

Computer Vision using data flow graph

TensorFlow is rich in methods for performing operations on images. tfgo provides the image package, which allows using the Go bindings to perform computer vision tasks in an elegant way.

For instance, it's possible to read an image, compute its directional derivatives along the horizontal and vertical directions, compute the gradient, and save it.

The code below does that, showing the different results achieved using correlation and convolution operations.

package main

import (
        tg "github.com/galeone/tfgo"
        "github.com/galeone/tfgo/image"
        "github.com/galeone/tfgo/image/filter"
        "github.com/galeone/tfgo/image/padding"
        tf "github.com/galeone/tensorflow/tensorflow/go"
        "os"
)

func main() {
        root := tg.NewRoot()
        grayImg := image.Read(root, "/home/pgaleone/airplane.png", 1)
        grayImg = grayImg.Scale(0, 255)

        // Edge detection using sobel filter: convolution
        Gx := grayImg.Clone().Convolve(filter.SobelX(root), image.Stride{X: 1, Y: 1}, padding.SAME)
        Gy := grayImg.Clone().Convolve(filter.SobelY(root), image.Stride{X: 1, Y: 1}, padding.SAME)
        convoluteEdges := image.NewImage(root.SubScope("edge"), Gx.Square().Add(Gy.Square().Value()).Sqrt().Value()).EncodeJPEG()

        Gx = grayImg.Clone().Correlate(filter.SobelX(root), image.Stride{X: 1, Y: 1}, padding.SAME)
        Gy = grayImg.Clone().Correlate(filter.SobelY(root), image.Stride{X: 1, Y: 1}, padding.SAME)
        correlateEdges := image.NewImage(root.SubScope("edge"), Gx.Square().Add(Gy.Square().Value()).Sqrt().Value()).EncodeJPEG()

        results := tg.Exec(root, []tf.Output{convoluteEdges, correlateEdges}, nil, &tf.SessionOptions{})

        file, _ := os.Create("convolved.png")
        file.WriteString(results[0].Value().(string))
        file.Close()

        file, _ = os.Create("correlated.png")
        file.WriteString(results[1].Value().(string))
        file.Close()
}

airplane.png (input image)

convolved.png (edges via convolution)

correlated.png (edges via correlation)

The complete list of available methods can be found on GoDoc: http://godoc.org/github.com/galeone/tfgo/image

Train in Python, Serve in Go

TensorFlow 2 comes with a lot of easy ways to export a computational graph (e.g. a Keras model, or a function decorated with @tf.function) to the SavedModel serialization format (the only one officially supported).

Using TensorFlow 2 (with Keras or tf.function) + tfgo, exporting a trained model (or a generic computational graph) and using it in Go is straightforward.

Just dig into the example to understand how to serve a trained model with tfgo.

Python code
import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.layers.Conv2D(
            8,
            (3, 3),
            strides=(2, 2),
            padding="valid",
            input_shape=(28, 28, 1),
            activation=tf.nn.relu,
            name="inputs",
        ),  # 14x14x8
        tf.keras.layers.Conv2D(
            16, (3, 3), strides=(2, 2), padding="valid", activation=tf.nn.relu
        ),  # 7x7x16
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, name="logits"),  # linear
    ]
)

tf.saved_model.save(model, "output/keras")

Go code
package main

import (
        "fmt"
        tg "github.com/galeone/tfgo"
        tf "github.com/galeone/tensorflow/tensorflow/go"
)

func main() {
        // A model exported with tf.saved_model.save()
        // automatically comes with the "serve" tag because the SavedModel
        // file format is designed for serving.
        // This tag contains the various exported functions. Among these, the
        // "serving_default" signature_def is always present. This signature def
        // works exactly like a TF 1.x graph: get the input and output tensors,
        // and use them as the placeholders to feed and the outputs to fetch, respectively.

        // To get info inside a SavedModel the best tool is saved_model_cli
        // that comes with the TensorFlow Python package.

        // e.g. saved_model_cli show --all --dir output/keras
        // gives, among the others, this info:

        // signature_def['serving_default']:
        // The given SavedModel SignatureDef contains the following input(s):
        //   inputs['inputs_input'] tensor_info:
        //       dtype: DT_FLOAT
        //       shape: (-1, 28, 28, 1)
        //       name: serving_default_inputs_input:0
        // The given SavedModel SignatureDef contains the following output(s):
        //   outputs['logits'] tensor_info:
        //       dtype: DT_FLOAT
        //       shape: (-1, 10)
        //       name: StatefulPartitionedCall:0
        // Method name is: tensorflow/serving/predict

        model := tg.LoadModel("test_models/output/keras", []string{"serve"}, nil)

        fakeInput, _ := tf.NewTensor([1][28][28][1]float32{})
        results := model.Exec([]tf.Output{
                model.Op("StatefulPartitionedCall", 0),
        }, map[tf.Output]*tf.Tensor{
                model.Op("serving_default_inputs_input", 0): fakeInput,
        })

        predictions := results[0]
        fmt.Println(predictions.Value())
}

Why?

Thinking about computation represented as graphs, and describing it this way, is, in one word, challenging.

Also, tfgo brings GPU computations to Go and allows writing parallel code without worrying about the device that executes it (just place the graph on the device you desire: that's it!).

Contribute

I love contributions. Seriously. Having people who share your interests and want to face the same challenges is awesome.

If you'd like to contribute, just dig into the code and see what can be added or improved. Start a discussion by opening an issue and let's talk about it.

Just follow the same design I use in the image package ("override" the same Tensor methods, document the methods, test your changes, ...).

There are a lot of packages that can be added, like the image package. Feel free to work on a brand new package: I'd love to see this kind of contribution!

TensorFlow installation

Manual

On macOS you can brew install libtensorflow (assuming you have brew installed; brew is a package manager, and if you need help installing it, follow the instructions here: https://docs.brew.sh/Installation).

Download and install the C library from https://www.tensorflow.org/install/lang_c:

curl -L "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.9.1.tar.gz" | sudo tar -C /usr/local -xz
sudo ldconfig

Docker

docker pull tensorflow/tensorflow:2.9.1

Or you can use your system package manager.

Documentation

Overview

Package tfgo simplifies the usage of TensorFlow's Go bindings, wrapping the most common functions as methods of new, logically separated objects. These objects transparently handle the naming issues that can arise when describing a tf.Graph. Additional features are also added. Why this package is needed is explained in this blog post: https://pgaleone.eu/tensorflow/go/2017/05/29/understanding-tensorflow-using-go/

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func Batchify

func Batchify(scope *op.Scope, tensors []tf.Output) tf.Output

Batchify creates a batch of tensors, concatenating them along the first dimension
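
For illustration, a minimal sketch that batches three constant vectors; the shape of the result is printed at run time rather than assumed, since it follows from how Batchify lays the inputs out along the first dimension:

package main

import (
        "fmt"
        tg "github.com/galeone/tfgo"
        tf "github.com/galeone/tensorflow/tensorflow/go"
)

func main() {
        root := tg.NewRoot()
        // Three tensors with identical shape, combined into a single batch.
        a := tg.Const(root, [3]int32{1, 2, 3})
        b := tg.Const(root, [3]int32{4, 5, 6})
        c := tg.Const(root, [3]int32{7, 8, 9})
        batch := tg.Batchify(root, []tf.Output{a, b, c})
        results := tg.Exec(root, []tf.Output{batch}, nil, &tf.SessionOptions{})
        // Print the shape and the value of the evaluated batch.
        fmt.Println(results[0].Shape(), results[0].Value())
}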

func Cast

func Cast(scope *op.Scope, value tf.Output, dtype tf.DataType) tf.Output

Cast casts value to the specified dtype

func Const

func Const(scope *op.Scope, value interface{}) tf.Output

Const creates a constant value within the specified scope

func Exec

func Exec(scope *op.Scope, tensors []tf.Output, feedDict map[tf.Output]*tf.Tensor, options *tf.SessionOptions) []*tf.Tensor

Exec creates the computation graph from the scope, then executes the operations required to compute each element of tensors. Nodes in the graph can be overridden with feedDict. The session options can be specified using the options parameter. Returns the evaluated tensors. Panics on error.
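
For illustration, a minimal sketch that overrides a placeholder node through feedDict (the Placeholder op comes from the op package of the Go bindings, github.com/galeone/tensorflow/tensorflow/go/op):

package main

import (
        "fmt"
        tg "github.com/galeone/tfgo"
        tf "github.com/galeone/tensorflow/tensorflow/go"
        "github.com/galeone/tensorflow/tensorflow/go/op"
)

func main() {
        root := tg.NewRoot()
        // A placeholder whose value is supplied at run time through feedDict.
        x := op.Placeholder(root.SubScope("input"), tf.Float)
        y := tg.NewTensor(root, x).Square()

        in, err := tf.NewTensor([3]float32{1, 2, 3})
        if err != nil {
                panic(err)
        }
        results := tg.Exec(root, []tf.Output{y.Output}, map[tf.Output]*tf.Tensor{
                x: in,
        }, &tf.SessionOptions{})
        fmt.Println(results[0].Value()) // [1 4 9]
}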

func IsClose

func IsClose(scope *op.Scope, a, b tf.Output, relTol, absTol tf.Output) tf.Output

IsClose defines the isclose operation between a and b. Returns a conditional node that is true when a is close to b. relTol is the relative tolerance; absTol is the absolute tolerance.
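
For illustration, a minimal sketch, assuming the usual comparison of the form |a-b| <= absTol + relTol*|b|:

package main

import (
        "fmt"
        tg "github.com/galeone/tfgo"
        tf "github.com/galeone/tensorflow/tensorflow/go"
)

func main() {
        root := tg.NewRoot()
        a := tg.Const(root, float32(1.0))
        b := tg.Const(root, float32(1.000001))
        relTol := tg.Const(root, float32(1e-5))
        absTol := tg.Const(root, float32(1e-8))
        isClose := tg.IsClose(root, a, b, relTol, absTol)
        results := tg.Exec(root, []tf.Output{isClose}, nil, &tf.SessionOptions{})
        // |a-b| = 1e-6, well within the relative tolerance of 1e-5: prints true.
        fmt.Println(results[0].Value())
}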

func IsFloat

func IsFloat(dtype tf.DataType) bool

IsFloat returns true if dtype is a tensorflow float type

func IsInteger

func IsInteger(dtype tf.DataType) bool

IsInteger returns true if dtype is a tensorflow integer type

func MaxValue

func MaxValue(dtype tf.DataType) float64

MaxValue returns the maximum value accepted for the dtype

func MinValue

func MinValue(dtype tf.DataType) float64

MinValue returns the minimum representable value for the specified dtype
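
For illustration, a minimal sketch of the dtype helpers; it assumes tf.Float and tf.Int32 are among the dtypes handled by MaxValue and MinValue, and prints the limits rather than asserting them:

package main

import (
        "fmt"
        tg "github.com/galeone/tfgo"
        tf "github.com/galeone/tensorflow/tensorflow/go"
)

func main() {
        fmt.Println(tg.IsFloat(tf.Float))   // true
        fmt.Println(tg.IsInteger(tf.Int32)) // true
        // Representable limits for the two dtypes (assumed to be supported).
        fmt.Println(tg.MinValue(tf.Float), tg.MaxValue(tf.Float))
        fmt.Println(tg.MinValue(tf.Int32), tg.MaxValue(tf.Int32))
}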

func NewRoot

func NewRoot() *op.Scope

NewRoot creates a new, empty *op.Scope

func NewScope

func NewScope(root *op.Scope) *op.Scope

NewScope returns a unique scope in the format input_<suffix>, where suffix is a counter. This function is thread safe and can be called in parallel for DIFFERENT scopes.
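
For illustration, a minimal sketch; the generated scope names (e.g. input_0 and input_1) follow the documented format:

package main

import (
        "fmt"
        tg "github.com/galeone/tfgo"
)

func main() {
        root := tg.NewRoot()
        // Two independent scopes: their unique names guarantee that nodes
        // defined in them never clash.
        s1 := tg.NewScope(root)
        s2 := tg.NewScope(root)
        a := tg.Const(s1, int32(1))
        b := tg.Const(s2, int32(2))
        // Prints the fully scoped node names, e.g. input_0/Const input_1/Const.
        fmt.Println(a.Op.Name(), b.Op.Name())
}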

Types

type Model

type Model struct {
	// contains filtered or unexported fields
}

Model represents a trained model

func ImportModel

func ImportModel(serializedModel, prefix string, options *tf.SessionOptions) (model *Model)

ImportModel creates a new *Model, loading the graph from the serialized representation. This operation creates a session with the specified `options`. Panics if the model can't be loaded.

func LoadModel

func LoadModel(exportDir string, tags []string, options *tf.SessionOptions) (model *Model)

LoadModel creates a new *Model, loading it from the exportDir. The loaded graph is identified by the set of tags specified when exporting it. This operation creates a session with the specified `options`. Panics if the model can't be loaded.

func (*Model) Exec

func (model *Model) Exec(tensors []tf.Output, feedDict map[tf.Output]*tf.Tensor) (results []*tf.Tensor)

Exec executes the nodes/tensors, which must be present in the loaded model. feedDict contains the values to feed to the placeholders (which must have been saved in the model definition). Panics on error.

func (*Model) Op

func (model *Model) Op(name string, idx int) tf.Output

Op extracts the output in position idx of the tensor with the specified name from the model graph

type Tensor

type Tensor struct {
	// Root: Each tensor maintains a pointer to the graph root
	Root *op.Scope
	// Path is the current Tensor full path
	Path *op.Scope
	// Output is the Tensor content
	Output tf.Output
}

Tensor is a high-level abstraction for the tf.Output structure, associating a scope to the Tensor

func NewTensor

func NewTensor(scope *op.Scope, tfout tf.Output) (tensor *Tensor)

NewTensor creates a *Tensor from a tf.Output, placing the new tensor within the specified scope.

func (*Tensor) Add

func (tensor *Tensor) Add(tfout tf.Output) *Tensor

Add defines the add operation between the tensor and `tfout`. `tfout` dtype is converted to tensor.Dtype() before adding.

func (*Tensor) Cast

func (tensor *Tensor) Cast(dtype tf.DataType) *Tensor

Cast casts the current tensor to the requested dtype

func (*Tensor) Check

func (tensor *Tensor) Check()

Check checks whether the previous operation caused an error, in which case tensor.Path.Err is not nil. If so, it panics, because we're defining the graph in a wrong way.

func (*Tensor) Clone

func (tensor *Tensor) Clone() *Tensor

Clone returns a copy of the current tensor in a new scope. Clone is used to create a different tensor from the output of an operation. The new node is placed at the same level as the current tensor; it can be seen as a twin tensor.

func (*Tensor) Dtype

func (tensor *Tensor) Dtype() tf.DataType

Dtype returns the tensor dtype

func (*Tensor) MatMul

func (tensor *Tensor) MatMul(tfout tf.Output) *Tensor

MatMul defines the matrix multiplication operation between the tensor and `tfout`. `tfout` dtype is converted to tensor.Dtype() before multiplying

func (*Tensor) Mul

func (tensor *Tensor) Mul(tfout tf.Output) *Tensor

Mul defines the multiplication operation between the tensor and `tfout`. It's the multiplication element-wise with broadcasting support. `tfout` dtype is converted to tensor.Dtype() before multiplying
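
For illustration, a minimal sketch of the element-wise multiplication with broadcasting:

package main

import (
        "fmt"
        tg "github.com/galeone/tfgo"
        tf "github.com/galeone/tensorflow/tensorflow/go"
)

func main() {
        root := tg.NewRoot()
        m := tg.NewTensor(root, tg.Const(root, [2][2]float32{{1, 2}, {3, 4}}))
        // Element-wise multiplication by a scalar: broadcasting expands the
        // scalar to the shape of the matrix.
        scaled := m.Mul(tg.Const(root, float32(10)))
        results := tg.Exec(root, []tf.Output{scaled.Output}, nil, &tf.SessionOptions{})
        fmt.Println(results[0].Value()) // [[10 20] [30 40]]
}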

func (*Tensor) Pow

func (tensor *Tensor) Pow(y tf.Output) *Tensor

Pow defines the pow operation x^y, where x are the tensor values. y dtype is converted to tensor.Dtype() before executing Pow.

func (*Tensor) Scope

func (tensor *Tensor) Scope() *op.Scope

Scope returns the scope associated to the tensor

func (*Tensor) Shape32

func (tensor *Tensor) Shape32(firstDimension bool) []int32

Shape32 returns the shape of the tensor as []int32. If firstDimension is true, a 4-element slice is returned. Otherwise, a 3-element slice is returned.

func (*Tensor) Shape64

func (tensor *Tensor) Shape64(firstDimension bool) []int64

Shape64 returns the shape of the tensor as []int64. If firstDimension is true, a 4-element slice is returned. Otherwise, a 3-element slice is returned.

func (*Tensor) Sqrt

func (tensor *Tensor) Sqrt() *Tensor

Sqrt defines the square root operation for the tensor values

func (*Tensor) Square

func (tensor *Tensor) Square() *Tensor

Square defines the square operation for the tensor values
