torch

package module
v0.0.0-...-8e0162c

Warning: This package is not in the latest version of its module.
Published: Jan 8, 2023 License: MIT Imports: 6 Imported by: 0

README

GoTorch

A Golang front-end for libtorch.

Development

Initialize

To initialize the Go environment (e.g., download requirements and clean up):

./main.sh init

Compile

To build the libcgotorch bridging library for libtorch and compile go code:

./main.sh build

Unit Test

To run test cases (after compiling code with ./main.sh build):

./main.sh test

Continuous Integration

For convenience, the three commands above can be executed in series using:

./main.sh ci

Docker Container

By default, development code is compiled directly for the host operating system (assumed to be either macOS or Linux). Development functions can optionally be executed within a Docker container using the -d flag. First, build the development Docker image:

./main.sh builddocker

For example, to run CI within the Docker container, use:

./main.sh -d ci

Documentation

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func AllClose

func AllClose(tensor, other Tensor, rtol, atol float64) bool
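The rtol and atol parameters mirror libtorch's allclose, which treats two elements as close when |a - b| <= atol + rtol * |b|. The following is a minimal pure-Go sketch of that rule on float64 slices. It does not call this package, and the helper name allClose is hypothetical.

```go
package main

import (
	"fmt"
	"math"
)

// allClose is a hypothetical helper, not part of this package. It reports
// whether |a[i]-b[i]| <= atol + rtol*|b[i]| holds for every element.
func allClose(a, b []float64, rtol, atol float64) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i] == b[i] {
			continue // also covers equal infinities
		}
		// Written as a negated <= so that NaN never counts as close.
		if !(math.Abs(a[i]-b[i]) <= atol+rtol*math.Abs(b[i])) {
			return false
		}
	}
	return true
}

func main() {
	a := []float64{1.0, 2.0, 3.0}
	b := []float64{1.0, 2.0, 3.0005}
	fmt.Println(allClose(a, b, 1e-3, 1e-8)) // true
	fmt.Println(allClose(a, b, 1e-5, 1e-8)) // false
}
```

Note that the rule is asymmetric: the relative tolerance scales with the second argument only.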

func Equal

func Equal(tensor, other Tensor) bool

func IsComplex

func IsComplex(tensor Tensor) bool

func IsConj

func IsConj(tensor Tensor) bool

func IsFloatingPoint

func IsFloatingPoint(tensor Tensor) bool

func IsGradEnabled

func IsGradEnabled() bool

Return the global gradient generation state.

func IsNonzero

func IsNonzero(tensor Tensor) bool

func ManualSeed

func ManualSeed(seed int64)

Set the random number generator seed.

func Numel

func Numel(tensor Tensor) int64

func SetGradEnabled

func SetGradEnabled(value bool)

Set the global gradient generation state.

func SetNumThreads

func SetNumThreads(numThreads int32)

Set the number of threads used for intraop parallelism on CPU.

func StdMean

func StdMean(tensor Tensor) (Tensor, Tensor)

Reduce a tensor to its mean value and standard deviation.

func StdMeanByDim

func StdMeanByDim(tensor Tensor, dim int, unbiased bool, keep_dims bool) (Tensor, Tensor)

Reduce a tensor to its mean value and standard deviation along the given dimension.

func ToSlice

func ToSlice(tensor Tensor) interface{}

Convert a torch Tensor to a Go slice. The tensor is flattened, so the result is always a 1-dimensional slice.
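To illustrate what flattening means here, the sketch below concatenates the rows of a 2-dimensional Go slice and then recovers the typed slice from an interface{} with a type assertion, as a caller of ToSlice would need to. It is pure Go and does not call this package. The helper flatten is hypothetical, and row-major element order is an assumption (it is the layout of contiguous libtorch tensors).

```go
package main

import "fmt"

// flatten is a hypothetical helper that concatenates the rows of a
// 2-dimensional slice in row-major order.
func flatten(rows [][]float32) []float32 {
	out := make([]float32, 0)
	for _, row := range rows {
		out = append(out, row...)
	}
	return out
}

func main() {
	m := [][]float32{{1, 2, 3}, {4, 5, 6}}

	// Store the result as interface{} to mimic an untyped return value.
	var result interface{} = flatten(m)

	// A type assertion recovers the concrete element type.
	if values, ok := result.([]float32); ok {
		fmt.Println(values) // [1 2 3 4 5 6]
	}
}
```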

func VarMean

func VarMean(tensor Tensor) (Tensor, Tensor)

Reduce a tensor to its mean value and variance.

func VarMeanByDim

func VarMeanByDim(tensor Tensor, dim int, unbiased bool, keep_dims bool) (Tensor, Tensor)

Reduce a tensor to its mean value and variance along the given dimension.
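The effect of the unbiased flag can be sketched in plain Go, assuming it follows the libtorch convention: when true, the sum of squared deviations is divided by n-1 (Bessel's correction) rather than n. The helper varMean below is hypothetical and works on a single float64 slice. Its return order (variance, then mean) is a choice made for this sketch and is not taken from the package.

```go
package main

import "fmt"

// varMean is a hypothetical helper, not part of this package. It returns the
// variance and mean of xs. With unbiased=true the squared deviations are
// divided by n-1 (Bessel's correction); otherwise they are divided by n.
// xs is assumed to hold at least two elements.
func varMean(xs []float64, unbiased bool) (variance, mean float64) {
	n := float64(len(xs))
	for _, x := range xs {
		mean += x
	}
	mean /= n
	for _, x := range xs {
		d := x - mean
		variance += d * d
	}
	if unbiased {
		n--
	}
	return variance / n, mean
}

func main() {
	xs := []float64{1, 2, 3, 4}
	v, m := varMean(xs, false)
	fmt.Println(v, m) // 1.25 2.5
	v, m = varMean(xs, true)
	fmt.Printf("%.4f %.1f\n", v, m) // 1.6667 2.5
}
```

The standard deviation reported by the Std variants is the square root of the corresponding variance.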

Types

type Device

type Device struct {
	T C.Device
}

Device wraps a pointer to a C.Device.

func NewDevice

func NewDevice(deviceType string) Device

NewDevice returns a Device for the given device type.

type Dtype

type Dtype int8

An enumeration of the data-types in libtorch.

const (
	Byte          Dtype = iota // Byte Dtype 0
	Char                       // Char Dtype 1
	Short                      // Short Dtype 2
	Int                        // Int Dtype 3
	Long                       // Long Dtype 4
	Half                       // Half Dtype 5
	Float                      // Float Dtype 6
	Double                     // Double Dtype 7
	ComplexHalf                // ComplexHalf Dtype 8
	ComplexFloat               // ComplexFloat Dtype 9
	ComplexDouble              // ComplexDouble Dtype 10
	Bool                       // Bool Dtype 11
	QInt8                      // QInt8 Dtype 12
	QUInt8                     // QUInt8 Dtype 13
	QInt32                     // QInt32 Dtype 14
	BFloat16                   // BFloat16 Dtype 15
	Invalid       Dtype = -1   // Invalid Dtype
)

func GetDtypeOfKind

func GetDtypeOfKind(kind reflect.Kind) Dtype

Map an element type kind to its associated Dtype.
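As a rough illustration of how a reflect.Kind could map onto the Dtype enumeration above, the sketch below redeclares part of the enumeration locally and pairs each Go kind with the Dtype of the same width. The mapping table and the helpers dtypeOfKind and dtypeOfSlice are assumptions made for this example; the package's actual table is not shown in this documentation.

```go
package main

import (
	"fmt"
	"reflect"
)

// Dtype and its constants are copied from the enumeration above so that
// this sketch compiles on its own.
type Dtype int8

const (
	Byte Dtype = iota
	Char
	Short
	Int
	Long
	Half
	Float
	Double
	ComplexHalf
	ComplexFloat
	ComplexDouble
	Bool
	Invalid Dtype = -1
)

// dtypeOfKind is a hypothetical mapping from Go kinds to Dtype values.
func dtypeOfKind(kind reflect.Kind) Dtype {
	switch kind {
	case reflect.Uint8:
		return Byte
	case reflect.Int8:
		return Char
	case reflect.Int16:
		return Short
	case reflect.Int32:
		return Int
	case reflect.Int64:
		return Long
	case reflect.Float32:
		return Float
	case reflect.Float64:
		return Double
	case reflect.Bool:
		return Bool
	default:
		return Invalid
	}
}

// dtypeOfSlice looks up the Dtype for the element type of a Go slice.
func dtypeOfSlice(data interface{}) Dtype {
	t := reflect.TypeOf(data)
	if t == nil || t.Kind() != reflect.Slice {
		return Invalid
	}
	return dtypeOfKind(t.Elem().Kind())
}

func main() {
	fmt.Println(dtypeOfSlice([]float32{1, 2, 3})) // 6
	fmt.Println(dtypeOfSlice("not a slice"))      // -1
}
```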

type IValue

type IValue struct {
	// A pointer to a C.IValue.
	T *unsafe.Pointer
}

IValue wraps a pointer to a C.IValue as an unsafe Pointer.

func NewIValue

func NewIValue(data interface{}) IValue

Create a new IValue from arbitrary data.

func (IValue) IsBool

func (ivalue IValue) IsBool() bool

func (IValue) IsBoolList

func (ivalue IValue) IsBoolList() bool

func (IValue) IsCapsule

func (ivalue IValue) IsCapsule() bool

func (IValue) IsComplexDouble

func (ivalue IValue) IsComplexDouble() bool

func (IValue) IsComplexDoubleList

func (ivalue IValue) IsComplexDoubleList() bool

func (IValue) IsCustomClass

func (ivalue IValue) IsCustomClass() bool

func (IValue) IsDevice

func (ivalue IValue) IsDevice() bool

func (IValue) IsDouble

func (ivalue IValue) IsDouble() bool

func (IValue) IsDoubleList

func (ivalue IValue) IsDoubleList() bool

func (IValue) IsEnum

func (ivalue IValue) IsEnum() bool

func (IValue) IsFuture

func (ivalue IValue) IsFuture() bool

func (IValue) IsGenerator

func (ivalue IValue) IsGenerator() bool

func (IValue) IsGenericDict

func (ivalue IValue) IsGenericDict() bool

func (IValue) IsInt

func (ivalue IValue) IsInt() bool

func (IValue) IsIntList

func (ivalue IValue) IsIntList() bool

func (IValue) IsList

func (ivalue IValue) IsList() bool

func (IValue) IsModule

func (ivalue IValue) IsModule() bool

func (IValue) IsNil

func (ivalue IValue) IsNil() bool

func (IValue) IsObject

func (ivalue IValue) IsObject() bool

func (IValue) IsPtrType

func (ivalue IValue) IsPtrType() bool

func (IValue) IsPyObject

func (ivalue IValue) IsPyObject() bool

func (IValue) IsQuantizer

func (ivalue IValue) IsQuantizer() bool

func (IValue) IsRRef

func (ivalue IValue) IsRRef() bool

func (IValue) IsScalar

func (ivalue IValue) IsScalar() bool

func (IValue) IsStorage

func (ivalue IValue) IsStorage() bool

func (IValue) IsStream

func (ivalue IValue) IsStream() bool

func (IValue) IsString

func (ivalue IValue) IsString() bool

func (IValue) IsTensor

func (ivalue IValue) IsTensor() bool

func (IValue) IsTensorList

func (ivalue IValue) IsTensorList() bool

func (IValue) IsTuple

func (ivalue IValue) IsTuple() bool

func (IValue) LengthDict

func (ivalue IValue) LengthDict() int64

func (IValue) LengthList

func (ivalue IValue) LengthList() int64

func (IValue) LengthTuple

func (ivalue IValue) LengthTuple() int64

func (IValue) ToBool

func (ivalue IValue) ToBool() bool

func (IValue) ToComplexDouble

func (ivalue IValue) ToComplexDouble() complex128

func (IValue) ToDevice

func (ivalue IValue) ToDevice() Device

func (IValue) ToDouble

func (ivalue IValue) ToDouble() float64

func (IValue) ToGenericDict

func (ivalue IValue) ToGenericDict() map[interface{}]IValue

func (IValue) ToInt

func (ivalue IValue) ToInt() int

func (IValue) ToList

func (ivalue IValue) ToList() []IValue

func (IValue) ToNone

func (ivalue IValue) ToNone() string

func (IValue) ToString

func (ivalue IValue) ToString() string

func (IValue) ToTensor

func (ivalue IValue) ToTensor() Tensor

func (IValue) ToTensorList

func (ivalue IValue) ToTensorList() []Tensor

func (IValue) ToTuple

func (ivalue IValue) ToTuple() []IValue

type Tensor

type Tensor struct {
	T *unsafe.Pointer
}

Tensor wraps a pointer to a C.Tensor as an unsafe Pointer.

func Abs

func Abs(tensor Tensor) Tensor

func Abs_

func Abs_(tensor Tensor) Tensor

func Add

func Add(tensor, other Tensor, alpha float32) Tensor

func All

func All(tensor Tensor) Tensor

Check if all values in the tensor evaluate to true.

func AllByDim

func AllByDim(tensor Tensor, dim int, keep_dims bool) Tensor

Check if all values in the tensor evaluate to true along the given dimension.

func Any

func Any(tensor Tensor) Tensor

Check if any values in the tensor evaluate to true.

func AnyByDim

func AnyByDim(tensor Tensor, dim int, keep_dims bool) Tensor

Check if any values in the tensor evaluate to true along the given dimension.

func Arange

func Arange(begin, end, step float32, options TensorOptions) Tensor

func Argmax

func Argmax(tensor Tensor) Tensor

Reduce a tensor to its maximum index.

func ArgmaxByDim

func ArgmaxByDim(tensor Tensor, dim int, keep_dims bool) Tensor

Reduce a tensor to its maximum index along the given dimension.

func Argmin

func Argmin(tensor Tensor) Tensor

Reduce a tensor to its minimum index.

func ArgminByDim

func ArgminByDim(tensor Tensor, dim int, keep_dims bool) Tensor

Reduce a tensor to its minimum index along the given dimension.

func Cat

func Cat(tensors []Tensor, dim int64) Tensor

func Decode

func Decode(buffer []byte) (Tensor, error)

Decode a pickled tensor back into a structured numerical format.

func Div

func Div(tensor, other Tensor) Tensor

func Empty

func Empty(size []int64, options TensorOptions) Tensor

func EmptyLike

func EmptyLike(reference Tensor) Tensor

func Eq

func Eq(tensor, other Tensor) Tensor

func Eye

func Eye(n, m int64, options TensorOptions) Tensor

func Flatten

func Flatten(tensor Tensor, startDim, endDim int64) Tensor

func Full

func Full(size []int64, value float32, options TensorOptions) Tensor

func FullLike

func FullLike(reference Tensor, value float32) Tensor

func Greater

func Greater(tensor, other Tensor) Tensor

func GreaterEqual

func GreaterEqual(tensor, other Tensor) Tensor

func IndexSelect

func IndexSelect(tensor Tensor, dim int64, index Tensor) Tensor

func IsClose

func IsClose(tensor, other Tensor, rtol, atol float64) Tensor

func IsFinite

func IsFinite(tensor Tensor) Tensor

func IsIn

func IsIn(tensor, other Tensor) Tensor

func IsInf

func IsInf(tensor Tensor) Tensor

func IsNaN

func IsNaN(tensor Tensor) Tensor

func IsNegInf

func IsNegInf(tensor Tensor) Tensor

func IsPosInf

func IsPosInf(tensor Tensor) Tensor

func IsReal

func IsReal(tensor Tensor) Tensor

func Less

func Less(tensor, other Tensor) Tensor

func LessEqual

func LessEqual(tensor, other Tensor) Tensor

func Linspace

func Linspace(begin, end float32, steps int64, options TensorOptions) Tensor

func Load

func Load(path string) (Tensor, error)

Load a tensor from the given path.

func LogSoftmax

func LogSoftmax(t Tensor, dim int64) Tensor

func Logspace

func Logspace(begin, end float32, steps int64, base float64, options TensorOptions) Tensor

func MM

func MM(a, b Tensor) Tensor

func Max

func Max(tensor Tensor) Tensor

Reduce a tensor to its maximum value.

func Maximum

func Maximum(tensor, other Tensor) Tensor

func Mean

func Mean(tensor Tensor) Tensor

Reduce a tensor to its mean value.

func MeanByDim

func MeanByDim(tensor Tensor, dim int, keep_dims bool) Tensor

Reduce a tensor to its mean value along the given dimension.

func Median

func Median(tensor Tensor) Tensor

Reduce a tensor to its median value.

func Min

func Min(tensor Tensor) Tensor

Reduce a tensor to its minimum value.

func Minimum

func Minimum(tensor, other Tensor) Tensor

func Mul

func Mul(tensor, other Tensor) Tensor

func NewTensor

func NewTensor(data interface{}) Tensor

Create a tensor from a Go slice.

func NewTensorFromBlob

func NewTensorFromBlob(data unsafe.Pointer, dtype Dtype, sizes []int64) Tensor

Create a new tensor that clones existing contiguous memory pointed to by data, of given data-type, and with given size. This function copies the input data, so subsequent in-place operations performed on the tensor will not mutate the input data.

func NewTorchTensor

func NewTorchTensor(tensor *unsafe.Pointer) Tensor

Create a new tensor and configure garbage collection.

func NotEqual

func NotEqual(tensor, other Tensor) Tensor

func Ones

func Ones(size []int64, options TensorOptions) Tensor

func OnesLike

func OnesLike(reference Tensor) Tensor

func Permute

func Permute(tensor Tensor, dims ...int64) Tensor

func Pow

func Pow(tensor Tensor, exponent float64) Tensor

func Rand

func Rand(size []int64, options TensorOptions) Tensor

func RandInt

func RandInt(size []int64, low int64, high int64, options TensorOptions) Tensor

func RandIntLike

func RandIntLike(reference Tensor, low int64, high int64) Tensor

func RandLike

func RandLike(reference Tensor) Tensor

func RandN

func RandN(size []int64, options TensorOptions) Tensor

func RandNLike

func RandNLike(reference Tensor) Tensor

func Range

func Range(begin, end, step float32, options TensorOptions) Tensor

func Reshape

func Reshape(tensor Tensor, shape ...int64) Tensor

func Sigmoid

func Sigmoid(t Tensor) Tensor

Sigmoid returns the element-wise sigmoid of t.

func Slice

func Slice(tensor Tensor, dim, start, stop, step int64) Tensor

func Sqrt

func Sqrt(tensor Tensor) Tensor

func Sqrt_

func Sqrt_(tensor Tensor) Tensor

func Square

func Square(tensor Tensor) Tensor

func Square_

func Square_(tensor Tensor) Tensor

func Squeeze

func Squeeze(tensor Tensor, dim ...int64) Tensor

func Stack

func Stack(tensors []Tensor, dim int64) Tensor

func Std

func Std(tensor Tensor) Tensor

Reduce a tensor to its standard deviation.

func StdByDim

func StdByDim(tensor Tensor, dim int, unbiased bool, keep_dims bool) Tensor

Reduce a tensor to its standard deviation along the given dimension.

func Sub

func Sub(tensor, other Tensor, alpha float32) Tensor

func Sum

func Sum(tensor Tensor) Tensor

Reduce a tensor to its sum.

func SumByDim

func SumByDim(tensor Tensor, dim int, keep_dims bool) Tensor

Reduce a tensor to its sum along the given dimension.
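The keep_dims flag on the ByDim reductions affects only the shape of the result. Assuming it follows the libtorch keepdim convention, the reduced dimension is either retained with size 1 or removed entirely. The sketch below computes the resulting shape in plain Go; reducedShape is a hypothetical helper and not part of this package.

```go
package main

import "fmt"

// reducedShape is a hypothetical helper that returns the shape produced by
// reducing a tensor of the given shape along dim. With keepDims=true the
// reduced dimension stays in place with size 1; otherwise it is dropped.
func reducedShape(shape []int64, dim int, keepDims bool) []int64 {
	out := make([]int64, 0, len(shape))
	for i, n := range shape {
		switch {
		case i != dim:
			out = append(out, n)
		case keepDims:
			out = append(out, 1)
		}
	}
	return out
}

func main() {
	shape := []int64{2, 3, 4}
	fmt.Println(reducedShape(shape, 1, false)) // [2 4]
	fmt.Println(reducedShape(shape, 1, true))  // [2 1 4]
}
```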

func Tanh

func Tanh(t Tensor) Tensor

Tanh returns the element-wise tanh of t.

func TensorFromBlob

func TensorFromBlob(data unsafe.Pointer, dtype Dtype, sizes []int64) Tensor

Create a tensor view that wraps around existing contiguous memory pointed to by data, of given data-type, and with given size. This function does not copy the data buffer so in-place operations performed on the tensor will mutate the input data.
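The difference between a view (TensorFromBlob) and a clone (NewTensorFromBlob) can be demonstrated without libtorch. The sketch below builds both kinds of slice over the same buffer through an unsafe.Pointer and shows which writes reach the original data. The helpers viewOf, cloneOf, and mutateDemo are hypothetical, and unsafe.Slice requires Go 1.17 or later.

```go
package main

import (
	"fmt"
	"unsafe"
)

// viewOf returns a slice that aliases the n float32 values starting at data.
// No copy is made, so writes through the view reach the original buffer.
func viewOf(data unsafe.Pointer, n int) []float32 {
	return unsafe.Slice((*float32)(data), n)
}

// cloneOf returns a slice holding a private copy of the same n values.
func cloneOf(data unsafe.Pointer, n int) []float32 {
	out := make([]float32, n)
	copy(out, viewOf(data, n))
	return out
}

// mutateDemo writes through a view and through a clone, then returns the
// original buffer so the caller can see which write was visible.
func mutateDemo() []float32 {
	buf := []float32{1, 2, 3, 4}
	ptr := unsafe.Pointer(&buf[0])

	view := viewOf(ptr, len(buf))
	clone := cloneOf(ptr, len(buf))

	view[0] = 10  // mutates buf
	clone[1] = 20 // leaves buf untouched
	return buf
}

func main() {
	fmt.Println(mutateDemo()) // [10 2 3 4]
}
```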

func Transpose

func Transpose(tensor Tensor, dim0, dim1 int64) Tensor

func Unsqueeze

func Unsqueeze(tensor Tensor, dim int64) Tensor

func Var

func Var(tensor Tensor) Tensor

Reduce a tensor to its variance.

func VarByDim

func VarByDim(tensor Tensor, dim int, unbiased bool, keep_dims bool) Tensor

Reduce a tensor to its variance along the given dimension.

func Zeros

func Zeros(size []int64, options TensorOptions) Tensor

func ZerosLike

func ZerosLike(reference Tensor) Tensor

func (Tensor) Abs

func (tensor Tensor) Abs() Tensor

func (Tensor) Abs_

func (tensor Tensor) Abs_() Tensor

func (Tensor) Add

func (tensor Tensor) Add(other Tensor, alpha float32) Tensor

func (Tensor) Add_

func (tensor Tensor) Add_(other Tensor, alpha float32) Tensor

func (Tensor) All

func (tensor Tensor) All() Tensor

Check if all values in the tensor evaluate to true.

func (Tensor) AllByDim

func (tensor Tensor) AllByDim(dim int, keep_dims bool) Tensor

Check if all values in the tensor evaluate to true along the given dimension.

func (Tensor) AllClose

func (tensor Tensor) AllClose(other Tensor, rtol, atol float64) bool

func (Tensor) Any

func (tensor Tensor) Any() Tensor

Check if any values in the tensor evaluate to true.

func (Tensor) AnyByDim

func (tensor Tensor) AnyByDim(dim int, keep_dims bool) Tensor

Check if any values in the tensor evaluate to true along the given dimension.

func (Tensor) Argmax

func (tensor Tensor) Argmax() Tensor

Reduce a tensor to its maximum index.

func (Tensor) ArgmaxByDim

func (tensor Tensor) ArgmaxByDim(dim int, keep_dims bool) Tensor

Reduce a tensor to its maximum index along the given dimension.

func (Tensor) Argmin

func (tensor Tensor) Argmin() Tensor

Reduce a tensor to its minimum index.

func (Tensor) ArgminByDim

func (tensor Tensor) ArgminByDim(dim int, keep_dims bool) Tensor

Reduce a tensor to its minimum index along the given dimension.

func (Tensor) Backward

func (tensor Tensor) Backward()

func (Tensor) CastTo

func (tensor Tensor) CastTo(dtype Dtype) Tensor

func (Tensor) Clone

func (tensor Tensor) Clone() Tensor

func (Tensor) CopyTo

func (tensor Tensor) CopyTo(device Device) Tensor

func (Tensor) Copy_

func (tensor Tensor) Copy_(b Tensor)

Copy_ copies the data held by b into the tensor in place.

func (Tensor) Detach

func (tensor Tensor) Detach() Tensor

func (Tensor) Dim

func (tensor Tensor) Dim() int64

Return the number of dimensions the tensor occupies.

func (Tensor) Div

func (tensor Tensor) Div(other Tensor) Tensor

func (Tensor) Div_

func (tensor Tensor) Div_(other Tensor) Tensor

func (Tensor) Dtype

func (tensor Tensor) Dtype() Dtype

Return the data-type of the tensor.

func (Tensor) Encode

func (tensor Tensor) Encode() ([]byte, error)

Encode a tensor into a pickled representation. Tensors are copied to the CPU before encoding.

func (Tensor) Eq

func (tensor Tensor) Eq(other Tensor) Tensor

func (Tensor) Equal

func (tensor Tensor) Equal(other Tensor) bool

func (Tensor) Expand

func (tensor Tensor) Expand(shape ...int64) Tensor

func (Tensor) ExpandAs

func (tensor Tensor) ExpandAs(other Tensor) Tensor

func (Tensor) Flatten

func (tensor Tensor) Flatten(startDim, endDim int64) Tensor

func (Tensor) Grad

func (tensor Tensor) Grad() Tensor

func (Tensor) Greater

func (tensor Tensor) Greater(other Tensor) Tensor

func (Tensor) GreaterEqual

func (tensor Tensor) GreaterEqual(other Tensor) Tensor

func (Tensor) Index

func (tensor Tensor) Index(index Tensor) Tensor

func (Tensor) IndexSelect

func (tensor Tensor) IndexSelect(dim int64, index Tensor) Tensor

func (Tensor) IsClose

func (tensor Tensor) IsClose(other Tensor, rtol, atol float64) Tensor

func (Tensor) IsComplex

func (tensor Tensor) IsComplex() bool

func (Tensor) IsConj

func (tensor Tensor) IsConj() bool

func (Tensor) IsFinite

func (tensor Tensor) IsFinite() Tensor

func (Tensor) IsFloatingPoint

func (tensor Tensor) IsFloatingPoint() bool

func (Tensor) IsIn

func (tensor Tensor) IsIn(other Tensor) Tensor

func (Tensor) IsInf

func (tensor Tensor) IsInf() Tensor

func (Tensor) IsNaN

func (tensor Tensor) IsNaN() Tensor

func (Tensor) IsNegInf

func (tensor Tensor) IsNegInf() Tensor

func (Tensor) IsNonzero

func (tensor Tensor) IsNonzero() bool

func (Tensor) IsPosInf

func (tensor Tensor) IsPosInf() Tensor

func (Tensor) IsReal

func (tensor Tensor) IsReal() Tensor

func (Tensor) Item

func (tensor Tensor) Item() interface{}

Return the value of this tensor as a standard Go number. This only works for tensors with one element. Use a type assertion to retrieve the value:

```
if value, ok := a.Item().(float64); ok {
    // process the value
}
```

This function currently only supports signed data types.
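When the element type is not known in advance, a type switch can handle several possibilities at once. The sketch below is pure Go and operates on any interface{} value. The helper toFloat64 is hypothetical, and the set of concrete types that Item actually returns is an assumption here.

```go
package main

import "fmt"

// toFloat64 is a hypothetical helper that widens a scalar held in an
// interface{} to float64. It covers signed integer and floating-point types
// and reports false for anything else.
func toFloat64(v interface{}) (float64, bool) {
	switch x := v.(type) {
	case float64:
		return x, true
	case float32:
		return float64(x), true
	case int64:
		return float64(x), true
	case int32:
		return float64(x), true
	case int16:
		return float64(x), true
	case int8:
		return float64(x), true
	default:
		return 0, false
	}
}

func main() {
	// Stand-ins for values that a call returning interface{} might yield.
	items := []interface{}{float32(1.5), int64(7), "unsupported"}
	for _, item := range items {
		value, ok := toFloat64(item)
		fmt.Println(value, ok)
	}
}
```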

func (Tensor) Less

func (tensor Tensor) Less(other Tensor) Tensor

func (Tensor) LessEqual

func (tensor Tensor) LessEqual(other Tensor) Tensor

func (Tensor) LogSoftmax

func (a Tensor) LogSoftmax(dim int64) Tensor

func (Tensor) MM

func (tensor Tensor) MM(other Tensor) Tensor

func (Tensor) Max

func (tensor Tensor) Max() Tensor

Reduce a tensor to its maximum value.

func (Tensor) MaxByDim

func (tensor Tensor) MaxByDim(dim int, keep_dims bool) ValueIndexPair

Reduce a tensor to its maximum value along the given dimension.

func (Tensor) Maximum

func (tensor Tensor) Maximum(other Tensor) Tensor

func (Tensor) Mean

func (tensor Tensor) Mean() Tensor

Reduce a tensor to its mean value.

func (Tensor) MeanByDim

func (tensor Tensor) MeanByDim(dim int, keep_dims bool) Tensor

Reduce a tensor to its mean value along the given dimension.

func (Tensor) Median

func (tensor Tensor) Median() Tensor

Reduce a tensor to its median value.

func (Tensor) MedianByDim

func (tensor Tensor) MedianByDim(dim int, keep_dims bool) ValueIndexPair

Reduce a tensor to its median value along the given dimension.

func (Tensor) Min

func (tensor Tensor) Min() Tensor

Reduce a tensor to its minimum value.

func (Tensor) MinByDim

func (tensor Tensor) MinByDim(dim int, keep_dims bool) ValueIndexPair

Reduce a tensor to its minimum value along the given dimension.

func (Tensor) Minimum

func (tensor Tensor) Minimum(other Tensor) Tensor

func (Tensor) Mul

func (tensor Tensor) Mul(other Tensor) Tensor

func (Tensor) Mul_

func (tensor Tensor) Mul_(other Tensor) Tensor

func (Tensor) NotEqual

func (tensor Tensor) NotEqual(other Tensor) Tensor

func (Tensor) Numel

func (tensor Tensor) Numel() int64

func (Tensor) Permute

func (tensor Tensor) Permute(dims ...int64) Tensor

func (Tensor) Pow

func (tensor Tensor) Pow(exponent float64) Tensor

func (Tensor) RequiresGrad

func (tensor Tensor) RequiresGrad() bool

func (Tensor) Reshape

func (tensor Tensor) Reshape(shape ...int64) Tensor

func (Tensor) ReshapeAs

func (tensor Tensor) ReshapeAs(other Tensor) Tensor

func (Tensor) Save

func (tensor Tensor) Save(path string) error

Save the tensor to the given path.

func (Tensor) SetData

func (tensor Tensor) SetData(b Tensor)

SetData sets the tensor's data to the data held by b.

func (Tensor) SetRequiresGrad

func (tensor Tensor) SetRequiresGrad(requiresGrad bool)

func (Tensor) Shape

func (tensor Tensor) Shape() []int64

Return the shape of the tensor data.

func (Tensor) Sigmoid

func (a Tensor) Sigmoid() Tensor

Sigmoid returns the element-wise sigmoid of the tensor.

func (Tensor) Slice

func (tensor Tensor) Slice(dim, start, stop, step int64) Tensor

func (Tensor) Sort

func (tensor Tensor) Sort(dim int64, descending bool) ValueIndexPair

func (Tensor) Sqrt

func (tensor Tensor) Sqrt() Tensor

func (Tensor) Sqrt_

func (tensor Tensor) Sqrt_() Tensor

func (Tensor) Square

func (tensor Tensor) Square() Tensor

func (Tensor) Square_

func (tensor Tensor) Square_() Tensor

func (Tensor) Squeeze

func (tensor Tensor) Squeeze(dim ...int64) Tensor

func (Tensor) Std

func (tensor Tensor) Std() Tensor

Reduce a tensor to its standard deviation.

func (Tensor) StdByDim

func (tensor Tensor) StdByDim(dim int, unbiased bool, keep_dims bool) Tensor

Reduce a tensor to its standard deviation along the given dimension.

func (Tensor) StdMean

func (tensor Tensor) StdMean() (Tensor, Tensor)

Reduce a tensor to its mean value and standard deviation.

func (Tensor) StdMeanByDim

func (tensor Tensor) StdMeanByDim(dim int, unbiased bool, keep_dims bool) (Tensor, Tensor)

Reduce a tensor to its mean value and standard deviation along the given dimension.

func (Tensor) String

func (tensor Tensor) String() string

Convert the tensor to a string representation.

func (Tensor) Sub

func (tensor Tensor) Sub(other Tensor, alpha float32) Tensor

func (Tensor) Sub_

func (tensor Tensor) Sub_(other Tensor, alpha float32) Tensor

func (Tensor) Sum

func (tensor Tensor) Sum() Tensor

Reduce a tensor to its sum.

func (Tensor) SumByDim

func (tensor Tensor) SumByDim(dim int, keep_dims bool) Tensor

Reduce a tensor to its sum along the given dimension.

func (Tensor) Tanh

func (a Tensor) Tanh() Tensor

Tanh returns the element-wise tanh of the tensor.

func (Tensor) To

func (tensor Tensor) To(device Device, dtype Dtype) Tensor

func (Tensor) ToSlice

func (tensor Tensor) ToSlice() interface{}

func (Tensor) TopK

func (tensor Tensor) TopK(k, dim int64, largest, sorted bool) ValueIndexPair

func (Tensor) Transpose

func (tensor Tensor) Transpose(dim0, dim1 int64) Tensor

func (Tensor) Unsqueeze

func (tensor Tensor) Unsqueeze(dim int64) Tensor

func (Tensor) Var

func (tensor Tensor) Var() Tensor

Reduce a tensor to its variance.

func (Tensor) VarByDim

func (tensor Tensor) VarByDim(dim int, unbiased bool, keep_dims bool) Tensor

Reduce a tensor to its variance along the given dimension.

func (Tensor) VarMean

func (tensor Tensor) VarMean() (Tensor, Tensor)

Reduce a tensor to its mean value and variance.

func (Tensor) VarMeanByDim

func (tensor Tensor) VarMeanByDim(dim int, unbiased bool, keep_dims bool) (Tensor, Tensor)

Reduce a tensor to its mean value and variance along the given dimension.

func (Tensor) View

func (tensor Tensor) View(shape ...int64) Tensor

func (Tensor) ViewAs

func (tensor Tensor) ViewAs(other Tensor) Tensor

type TensorOptions

type TensorOptions struct {
	// A pointer to a C.TensorOptions.
	T *unsafe.Pointer
}

TensorOptions wraps a pointer to a C.TensorOptions as an unsafe Pointer.

func NewTensorOptions

func NewTensorOptions() TensorOptions

Create a new TensorOptions.

func (TensorOptions) Device

func (options TensorOptions) Device(device Device) TensorOptions

Create a new TensorOptions with the given compute device.

func (TensorOptions) Dtype

func (options TensorOptions) Dtype(value Dtype) TensorOptions

Create a new TensorOptions with the given data type.

func (TensorOptions) PinnedMemory

func (options TensorOptions) PinnedMemory(pinnedMemory bool) TensorOptions

Create a new TensorOptions with the given memory pinning state.

func (TensorOptions) RequiresGrad

func (options TensorOptions) RequiresGrad(requiresGrad bool) TensorOptions

Create a new TensorOptions with the given gradient taping state.
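Because each setter is described as creating a new TensorOptions, calls can be chained while the starting value stays unchanged. The sketch below illustrates that pattern with a plain Go struct and value receivers. The options type and its fields are hypothetical stand-ins; the real TensorOptions wraps a C pointer and is not used here.

```go
package main

import "fmt"

// options is a hypothetical stand-in for TensorOptions.
type options struct {
	dtype        string
	requiresGrad bool
}

// Dtype returns a copy of o with the data type replaced. The value receiver
// means the caller's copy is never modified.
func (o options) Dtype(dtype string) options {
	o.dtype = dtype
	return o
}

// RequiresGrad returns a copy of o with the gradient flag replaced.
func (o options) RequiresGrad(requiresGrad bool) options {
	o.requiresGrad = requiresGrad
	return o
}

func main() {
	base := options{dtype: "float32"}
	derived := base.Dtype("float64").RequiresGrad(true)

	fmt.Println(base)    // {float32 false}
	fmt.Println(derived) // {float64 true}
}
```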

type ValueIndexPair

type ValueIndexPair struct {
	Values, Indices Tensor
}

A representation of the return type for a paired value/index selection call.
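To make the value/index pairing concrete, the sketch below reduces each row of a 2-dimensional Go slice to its maximum and records the column where that maximum occurs, which is the kind of result a call such as MaxByDim along dimension 1 would carry in Values and Indices. It is pure Go, and maxByRow is a hypothetical helper rather than part of this package.

```go
package main

import "fmt"

// maxByRow is a hypothetical helper. For each row of m it returns the
// maximum value and the column index where it first occurs. Every row is
// assumed to be non-empty.
func maxByRow(m [][]float64) (values []float64, indices []int64) {
	for _, row := range m {
		best := 0
		for j := range row {
			if row[j] > row[best] {
				best = j
			}
		}
		values = append(values, row[best])
		indices = append(indices, int64(best))
	}
	return values, indices
}

func main() {
	m := [][]float64{
		{1, 5, 2},
		{7, 3, 4},
	}
	values, indices := maxByRow(m)
	fmt.Println(values)  // [5 7]
	fmt.Println(indices) // [1 0]
}
```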

func MaxByDim

func MaxByDim(tensor Tensor, dim int, keep_dims bool) ValueIndexPair

Reduce a tensor to its maximum value along the given dimension.

func MedianByDim

func MedianByDim(tensor Tensor, dim int, keep_dims bool) ValueIndexPair

Reduce a tensor to its median value along the given dimension.

func MinByDim

func MinByDim(tensor Tensor, dim int, keep_dims bool) ValueIndexPair

Reduce a tensor to its minimum value along the given dimension.

func Sort

func Sort(tensor Tensor, dim int64, descending bool) ValueIndexPair

func TopK

func TopK(tensor Tensor, k, dim int64, largest, sorted bool) ValueIndexPair

Directories

Path Synopsis
cmd
nn
vision
