package ops

v1.1.0 (Latest)
Warning

This package is not in the latest version of its module.

Published: Jan 30, 2024 License: MPL-2.0 Imports: 6 Imported by: 0

Documentation


Constants

const (
	ActivationAlphaAttr = "activation_alpha"
	ActivationBetaAttr  = "activation_beta"
	ActivationsAttr     = "activations"
	ClipAttr            = "clip"
	DirectionAttr       = "direction"
	HiddenSizeAttr      = "hidden_size"
)

These constants define attributes that are applicable to GRU, LSTM and RNN operators.

const InvalidInputCountErrTemplate = "%v: expected %d input tensors, got %d"

InvalidInputCountErrTemplate is used to format an error when an operator receives the wrong number of input tensors.

const InvalidInputErrTemplate = "invalid input tensor for %v: %v"

InvalidInputErrTemplate is used to format an error when an operator receives an invalid input tensor.

const InvalidOptionalInputCountErrTemplate = "%v: expected %d-%d input tensors, got %d"

InvalidOptionalInputCountErrTemplate is used to format an error when an operator with optional inputs receives a number of input tensors outside its allowed range.

const UnsupportedInputErrTemplate = "unsupported input for %v: %v"

UnsupportedInputErrTemplate is used to format an error when an operator receives an input it does not support.
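A minimal sketch of how these templates can be filled in with fmt.Errorf. The two constant values are copied from above; opName, inputCountError and optionalInputCountError are hypothetical helpers, and whether the package's own ErrInvalidInputCount and ErrInvalidOptionalInputCount build their messages exactly this way is an assumption.

```go
package main

import "fmt"

// Copies of two of the template constants documented above.
const (
	InvalidInputCountErrTemplate         = "%v: expected %d input tensors, got %d"
	InvalidOptionalInputCountErrTemplate = "%v: expected %d-%d input tensors, got %d"
)

// opName is a hypothetical stand-in for an Operator; %v prints it via String.
type opName string

func (o opName) String() string { return string(o) }

// inputCountError formats an error for an operator with a fixed number of inputs.
func inputCountError(op fmt.Stringer, expected, actual int) error {
	return fmt.Errorf(InvalidInputCountErrTemplate, op, expected, actual)
}

// optionalInputCountError formats an error for an operator with optional inputs.
func optionalInputCountError(op fmt.Stringer, lo, hi, actual int) error {
	return fmt.Errorf(InvalidOptionalInputCountErrTemplate, op, lo, hi, actual)
}

func main() {
	fmt.Println(inputCountError(opName("Add"), 2, 3))            // Add: expected 2 input tensors, got 3
	fmt.Println(optionalInputCountError(opName("GRU"), 3, 6, 1)) // GRU: expected 3-6 input tensors, got 1
}
```

Because the first verb is %v, any value with a String method (such as an Operator) can be passed as the operator argument.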

Variables

var (
	ErrCast         = errors.New("cast error")
	ErrInvalidShape = errors.New("invalid shape error")
)

ErrCast and ErrInvalidShape are generic errors for a failed cast and an invalid shape, respectively.

var ErrActivationNotImplementedBase = errors.New("the given activation function is not implemented")
var ErrAxisNotInRange = errors.New("axis out of range")
var ErrConversion = errors.New("unable to convert")
var ErrUnsupportedOperator = errors.New("unsupported operator")
var ErrUnsupportedOpsetVersion = errors.New("unsupported opset version")

Functions

func Abs

func Abs(x int) int

Abs returns the absolute value of an int.

func Add

func Add(A, B tensor.Tensor) (tensor.Tensor, error)

Add adds 2 tensors element-wise.

func AllInRange

func AllInRange(arr []int, min, max int) bool

AllInRange checks if all the entries in `arr` are in the inclusive range min <= x <= max.

func And

func And(A, B tensor.Tensor) (tensor.Tensor, error)

And applies the boolean 'and' operation on 2 tensors.

func AnyToIntSlice

func AnyToIntSlice(value interface{}) ([]int, error)

AnyToIntSlice converts value to an []int. The conversion only succeeds if the data has an integer type; otherwise an error is returned.

func ApplyBinaryOperation

func ApplyBinaryOperation(A, B tensor.Tensor, op BinaryOp, broadcastOption BroadcastType) ([]tensor.Tensor, error)

ApplyBinaryOperation applies a binary operation (an operation of arity 2) to 2 tensors, broadcasting them first as specified by broadcastOption. It returns a slice containing a single output tensor, so that operators can return the result directly from their Apply method.

func Arange

func Arange(size int, step float32) []float32

Arange returns a float32 slice of length size whose values start at 0 and increase by step.

func ConvertTensorDtype

func ConvertTensorDtype(t tensor.Tensor, newType int32) (tensor.Tensor, error)

ConvertTensorDtype converts tensor t to the new dtype identified by newType.

func Div

func Div(A, B tensor.Tensor) (tensor.Tensor, error)

Div divides tensor A by tensor B element-wise.

func EmptyNodeProto

func EmptyNodeProto() *onnx.NodeProto

EmptyNodeProto returns a node proto with no attributes.

func Equal

func Equal(A, B tensor.Tensor) (tensor.Tensor, error)

Equal applies the equality operator (=) on 2 tensors.

func ErrActivationNotImplemented

func ErrActivationNotImplemented(activation string) error

func ErrAxisOutOfRange

func ErrAxisOutOfRange(min, max, actual int) error

func ErrConversionInvalidType

func ErrConversionInvalidType(dType tensor.Dtype, newType int32) error

func ErrConversionNotSupported

func ErrConversionNotSupported(dType int32) error

func ErrDimension

func ErrDimension(reason string) error

func ErrIncompatibleDimensions

func ErrIncompatibleDimensions() error

func ErrInvalidAttributeCount

func ErrInvalidAttributeCount(expected, actual int, operator Operator) error

func ErrInvalidInput

func ErrInvalidInput(reason string, operator Operator) error

func ErrInvalidInputCount

func ErrInvalidInputCount(actual int, operator Operator) error

func ErrInvalidInputType

func ErrInvalidInputType(inputNumber int, dType string, operator Operator) error

func ErrInvalidOptionalInputCount

func ErrInvalidOptionalInputCount(actual int, operator Operator) error

func ErrInvalidTensor

func ErrInvalidTensor(reason string, operator Operator) error

func ErrMultidirBroadcast

func ErrMultidirBroadcast(shapeA, shapeB tensor.Shape, err error) error

func ErrNotAllAxesInRange

func ErrNotAllAxesInRange(min, max int) error

func ErrTypeAssert

func ErrTypeAssert(expected string, actual any) error

func ErrUnidirBroadcast

func ErrUnidirBroadcast(shapeA, shapeB tensor.Shape) error

func ErrUnknownOperatorType

func ErrUnknownOperatorType(operatorType string) error

func ErrUnsupportedAttribute

func ErrUnsupportedAttribute(attributeName string, operator Operator) error

func ErrUnsupportedInput

func ErrUnsupportedInput(inputName string, operator Operator) error

func ExtractMatrices

func ExtractMatrices(M tensor.Tensor, nMatrices, nDimensions, hiddenSize int) ([]tensor.Tensor, error)

ExtractMatrices extracts a given number of matrices from tensor M. M contains matrices concatenated along one dimension: it is assumed to have shape (num_directions, nMatrices * hidden_size, ...), and the matrices are extracted by slicing over the 'nMatrices * hidden_size' dimension. This function is specific to the recurrent operators RNN, GRU and LSTM.
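The slicing arithmetic can be sketched with plain Go slices instead of tensor.Tensor (an assumption made so the example has no third-party imports): matrix i occupies rows [i*hiddenSize, (i+1)*hiddenSize) of every direction. In the ONNX spec, for example, the LSTM weight input W has shape [num_directions, 4*hidden_size, input_size], so nMatrices would be 4 there. extractMatrices is a hypothetical helper, not the package's implementation.

```go
package main

import "fmt"

// extractMatrices splits m, shaped (numDirections, nMatrices*hiddenSize, cols),
// into nMatrices blocks shaped (numDirections, hiddenSize, cols).
// Block i holds rows [i*hiddenSize, (i+1)*hiddenSize) of every direction.
func extractMatrices(m [][][]float32, nMatrices, hiddenSize int) ([][][][]float32, error) {
	out := make([][][][]float32, nMatrices)
	for i := 0; i < nMatrices; i++ {
		out[i] = make([][][]float32, len(m))
	}
	for d, rows := range m {
		if len(rows) != nMatrices*hiddenSize {
			return nil, fmt.Errorf("direction %d: expected %d rows, got %d", d, nMatrices*hiddenSize, len(rows))
		}
		for i := 0; i < nMatrices; i++ {
			// Reslice without copying: block i of direction d.
			out[i][d] = rows[i*hiddenSize : (i+1)*hiddenSize]
		}
	}
	return out, nil
}

func main() {
	// One direction, 2 matrices, hiddenSize 2, one column per row.
	m := [][][]float32{{{0}, {1}, {2}, {3}}}
	blocks, err := extractMatrices(m, 2, 2)
	if err != nil {
		panic(err)
	}
	fmt.Println(blocks[0], blocks[1]) // [[[0] [1]]] [[[2] [3]]]
}
```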

func Float32TensorFixture

func Float32TensorFixture(shp ...int) tensor.Tensor

Float32TensorFixture returns a float32-backed tensor with shape shp. All of its values are initialized using tensor.Range.

func Full

func Full(size int, value float32) []float32

Full returns a float32 slice of length size with every element set to value. It is named after np.full.

func GetValueAsTensorType

func GetValueAsTensorType(value float64, dtype tensor.Dtype) (interface{}, error)

GetValueAsTensorType returns the given value as the given tensor type.

func Gt

func Gt(A, B tensor.Tensor) (tensor.Tensor, error)

Gt applies the greater than (>) operator on 2 tensors.

func Gte

func Gte(A, B tensor.Tensor) (tensor.Tensor, error)

Gte applies the greater than or equal to (>=) operator on 2 tensors.

func HasDuplicates

func HasDuplicates(arr []int) bool

HasDuplicates checks if there are duplicates in the sorted array `arr`.

func IfScalarToSlice

func IfScalarToSlice(value any) any

IfScalarToSlice wraps value in a single-element slice if it is a scalar; otherwise it returns value unchanged.
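A reflection-based sketch of the same idea, assuming the wrapped slice should keep the scalar's element type. ifScalarToSlice is hypothetical; the package may use a type switch over concrete types instead.

```go
package main

import (
	"fmt"
	"reflect"
)

// ifScalarToSlice wraps a scalar in a one-element slice of the same element
// type and returns slices and arrays unchanged. A nil value is returned as-is.
func ifScalarToSlice(value any) any {
	v := reflect.ValueOf(value)
	if !v.IsValid() {
		return value
	}
	switch v.Kind() {
	case reflect.Slice, reflect.Array:
		return value
	}
	s := reflect.MakeSlice(reflect.SliceOf(v.Type()), 1, 1)
	s.Index(0).Set(v)
	return s.Interface()
}

func main() {
	fmt.Printf("%#v\n", ifScalarToSlice(int64(3)))       // []int64{3}
	fmt.Printf("%#v\n", ifScalarToSlice([]int32{1, 2})) // []int32{1, 2}
}
```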

func Int64ToBool

func Int64ToBool(v int64) bool

Int64ToBool converts an int64 to a boolean.

func Lt

func Lt(A, B tensor.Tensor) (tensor.Tensor, error)

Lt applies the less than (<) operator on 2 tensors.

func Lte

func Lte(A, B tensor.Tensor) (tensor.Tensor, error)

Lte applies the less than or equal to (<=) operator on 2 tensors.

func Mul

func Mul(A, B tensor.Tensor) (tensor.Tensor, error)

Mul multiplies 2 tensors element-wise.

func MultidirectionalBroadcast

func MultidirectionalBroadcast(A, B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error)

MultidirectionalBroadcast broadcasts two tensors for a binary operator according to the ONNX standards.

func NElements

func NElements(shp ...int) int

NElements calculates the number of elements in a tensor based on its shape.

func NewSlicer

func NewSlicer(start int, options ...int) tensor.Slice

NewSlicer creates a new Slicer object. By default, end will be set to start + 1 and step will be set to 1. If options are given, it is assumed that the first element will be the value for the end index and the second element the value for the step of the Slicer.
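A sketch of the documented defaults using a hypothetical slicer type. Its Start, End and Step methods mirror the method set of gorgonia's tensor.Slice interface, but the type and newSlicer are illustrations, not the package's Slicer.

```go
package main

import "fmt"

// slicer is a hypothetical stand-in for Slicer.
type slicer struct {
	start, end, step int
}

func (s *slicer) Start() int { return s.start }
func (s *slicer) End() int   { return s.end }
func (s *slicer) Step() int  { return s.step }

// newSlicer applies the documented defaults (end = start+1, step = 1),
// overridden by options[0] (end) and options[1] (step) when given.
func newSlicer(start int, options ...int) *slicer {
	s := &slicer{start: start, end: start + 1, step: 1}
	if len(options) >= 1 {
		s.end = options[0]
	}
	if len(options) >= 2 {
		s.step = options[1]
	}
	return s
}

func main() {
	a := newSlicer(2)       // selects index 2 only
	b := newSlicer(0, 6, 2) // selects indices 0, 2, 4
	fmt.Println(a.Start(), a.End(), a.Step()) // 2 3 1
	fmt.Println(b.Start(), b.End(), b.Step()) // 0 6 2
}
```

With no options the slicer selects exactly one index, which is handy for picking a single direction or matrix out of a dimension.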

func OffsetArrayIfNegative

func OffsetArrayIfNegative(arr []int, offset int)

OffsetArrayIfNegative adds `offset` to negative elements in the array `arr`. `arr` is modified in place.
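One use for this helper is normalizing ONNX-style negative axes (where -1 means the last axis) before validating them. The sketch re-implements OffsetArrayIfNegative, AllInRange and HasDuplicates as hypothetical lowercase helpers and chains them in one plausible order; because HasDuplicates expects a sorted array, the axes are sorted before that check.

```go
package main

import (
	"fmt"
	"sort"
)

// offsetArrayIfNegative adds offset to every negative element of arr, in place.
func offsetArrayIfNegative(arr []int, offset int) {
	for i, v := range arr {
		if v < 0 {
			arr[i] = v + offset
		}
	}
}

// allInRange reports whether every element x satisfies lo <= x <= hi.
func allInRange(arr []int, lo, hi int) bool {
	for _, v := range arr {
		if v < lo || v > hi {
			return false
		}
	}
	return true
}

// hasDuplicates assumes arr is sorted, so any duplicates are adjacent.
func hasDuplicates(arr []int) bool {
	for i := 1; i < len(arr); i++ {
		if arr[i] == arr[i-1] {
			return true
		}
	}
	return false
}

// normalizeAxes maps negative axes onto [0, rank-1] and rejects
// out-of-range or repeated axes.
func normalizeAxes(axes []int, rank int) ([]int, error) {
	out := append([]int(nil), axes...) // copy, since the helpers work in place
	offsetArrayIfNegative(out, rank)
	if !allInRange(out, 0, rank-1) {
		return nil, fmt.Errorf("axes %v out of range for rank %d", axes, rank)
	}
	sort.Ints(out)
	if hasDuplicates(out) {
		return nil, fmt.Errorf("axes %v contain duplicates", axes)
	}
	return out, nil
}

func main() {
	fmt.Println(normalizeAxes([]int{-1, 0}, 3)) // [0 2] <nil>
	fmt.Println(normalizeAxes([]int{1, -2}, 3)) // [] axes [1 -2] contain duplicates
}
```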

func OffsetTensorIfNegative

func OffsetTensorIfNegative(t tensor.Tensor, offset int) error

OffsetTensorIfNegative adds an offset to every negative element in tensor t. Works only for tensors with Dtype int (same as offset).

func Ones

func Ones(size int) []float32

Ones fills a slice with float32 ones.

func OnesTensor

func OnesTensor(t tensor.Tensor) tensor.Tensor

OnesTensor returns a new tensor with the same shape as the given tensor, initialized with all ones.

func Or

func Or(A, B tensor.Tensor) (tensor.Tensor, error)

Or applies the boolean 'or' operation on 2 tensors.

func PairwiseAssign

func PairwiseAssign(t1, t2 tensor.Tensor) (err error)

PairwiseAssign assigns each element of t2 to the corresponding element of t1 (a pairwise t1 = t2), modifying t1 in place.

func RandomFloat32TensorFixture

func RandomFloat32TensorFixture(shp ...int) tensor.Tensor

func ReLU

func ReLU(X tensor.Tensor) (tensor.Tensor, error)

ReLU performs the ReLU operation on a tensor.

func ReshapeTensorsForMultidirBroadcast

func ReshapeTensorsForMultidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error)

ReshapeTensorsForMultidirBroadcast reshapes the 2 tensors so that they have the same number of dimensions. When the numbers of dimensions differ, the shape of the tensor with fewer dimensions is padded with leading 1's until it matches the number of dimensions of the other tensor, which is left unchanged. If both tensors already have the same number of dimensions, neither is changed. Example: shapeA=(3, 4) and shapeB=(1, 3, 5, 6) yields shapeNewA=(1, 1, 3, 4).

func Sigmoid

func Sigmoid(X tensor.Tensor) (tensor.Tensor, error)

Sigmoid performs the sigmoid operation on a tensor.

func Sub

func Sub(A, B tensor.Tensor) (tensor.Tensor, error)

Sub subtracts tensor B from tensor A element-wise.

func Tanh

func Tanh(X tensor.Tensor) (tensor.Tensor, error)

Tanh performs the tanh operation on a tensor.

func TensorInputsFixture

func TensorInputsFixture(nTensors int) []tensor.Tensor

TensorInputsFixture returns a list with a given number of tensors.

func TensorWithBackingFixture

func TensorWithBackingFixture(b interface{}, shp ...int) tensor.Tensor

TensorWithBackingFixture returns a tensor with shape shp that uses b as its backing data.

func UnidirectionalBroadcast

func UnidirectionalBroadcast(A, B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error)

UnidirectionalBroadcast tries to broadcast tensor B to tensor A according to the ONNX standards.
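Under ONNX unidirectional broadcasting only B is ever expanded: B may not have more dimensions than A, and after right-aligning the shapes, every dimension of B must equal the matching dimension of A or be 1. A hypothetical shape-only check for that rule is sketched below; it says nothing about how the package performs the actual broadcast.

```go
package main

import "fmt"

// canUnidirBroadcast reports whether shape b can be broadcast to shape a
// under ONNX unidirectional broadcasting.
func canUnidirBroadcast(a, b []int) bool {
	if len(b) > len(a) {
		return false
	}
	offset := len(a) - len(b) // b is right-aligned against a
	for i, db := range b {
		if db != 1 && db != a[offset+i] {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(canUnidirBroadcast([]int{2, 3, 4}, []int{3, 1})) // true
	fmt.Println(canUnidirBroadcast([]int{3, 1}, []int{3, 4}))    // false: A is never broadcast
}
```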

func ValidateInputs

func ValidateInputs(op Operator, inputs []tensor.Tensor) ([]tensor.Tensor, error)

ValidateInputs checks that op receives neither too few nor too many input tensors, that is, a number between its minimum and maximum number of inputs. When there are fewer inputs than the maximum, the list is padded with nils.
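A sketch of the count check and nil padding, assuming the bounds come from the operator's GetMinInputs and GetMaxInputs methods. tensorLike, inputCounter, fakeOp and validateInputCount are hypothetical names, and any dtype validation is left out.

```go
package main

import "fmt"

// tensorLike is a hypothetical stand-in for tensor.Tensor.
type tensorLike any

// inputCounter covers the two Operator methods the count check needs.
type inputCounter interface {
	GetMinInputs() int
	GetMaxInputs() int
}

// validateInputCount rejects too few or too many inputs, then pads with nils
// up to the maximum so optional inputs can be addressed by position.
func validateInputCount(op inputCounter, inputs []tensorLike) ([]tensorLike, error) {
	lo, hi := op.GetMinInputs(), op.GetMaxInputs()
	if len(inputs) < lo || len(inputs) > hi {
		return inputs, fmt.Errorf("expected %d-%d input tensors, got %d", lo, hi, len(inputs))
	}
	for len(inputs) < hi {
		inputs = append(inputs, nil)
	}
	return inputs, nil
}

// fakeOp is a hypothetical operator accepting between min and max inputs.
type fakeOp struct{ min, max int }

func (f fakeOp) GetMinInputs() int { return f.min }
func (f fakeOp) GetMaxInputs() int { return f.max }

func main() {
	padded, err := validateInputCount(fakeOp{3, 6}, []tensorLike{1, 2, 3})
	fmt.Println(len(padded), padded[5] == nil, err) // 6 true <nil>
}
```

Padding to the maximum lets an operator's Apply index optional inputs by position and test them against nil.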

func Xor

func Xor(A, B tensor.Tensor) (tensor.Tensor, error)

Xor applies the boolean 'xor' operation on 2 tensors.

func ZeroTensor

func ZeroTensor(shape ...int) tensor.Tensor

ZeroTensor returns a tensor filled with zeros with the given shape.

func Zeros

func Zeros(size int) []float32

Zeros fills a float32 slice with 0's.

Types

type Activation

type Activation func(n tensor.Tensor) (tensor.Tensor, error)

Activation is an activation function.

func GetActivation

func GetActivation(activation string) (Activation, error)

type AttributeError

type AttributeError struct {
	// contains filtered or unexported fields
}

func ErrInvalidAttribute

func ErrInvalidAttribute(attributeName string, operator Operator) *AttributeError

func (*AttributeError) Error

func (t *AttributeError) Error() string

type AttributeErrorKind

type AttributeErrorKind string
const (
	AttributeErrorCount       AttributeErrorKind = "count"
	AttributeErrorInvalid     AttributeErrorKind = "invalid"
	AttributeErrorUnsupported AttributeErrorKind = "unsupported"
)

type BinaryOp

type BinaryOp func(A, B tensor.Tensor) (tensor.Tensor, error)

BinaryOp describes a general operation between 2 tensors with 1 tensor as result.

type BooleanOp

type BooleanOp func(a, b bool) bool

BooleanOp describes a binary operation between two booleans that also returns a boolean.

type BroadcastError

type BroadcastError struct {
	// contains filtered or unexported fields
}

func (*BroadcastError) Error

func (b *BroadcastError) Error() string

type BroadcastType

type BroadcastType int
const (
	NoBroadcasting               BroadcastType = 0
	UnidirectionalBroadcasting   BroadcastType = 1
	MultidirectionalBroadcasting BroadcastType = 2
)

type DimensionError

type DimensionError struct {
	// contains filtered or unexported fields
}

func (*DimensionError) Error

func (d *DimensionError) Error() string

type DimensionErrorKind

type DimensionErrorKind string
const (
	DimensionErrorIncompatible DimensionErrorKind = "incompatible"
)

type FloatType

type FloatType interface {
	float32 | float64
}

FloatType is a type that describes a float value. Can be either float32 or float64.

type InputError

type InputError struct {
	// contains filtered or unexported fields
}

func (*InputError) Error

func (i *InputError) Error() string

type InputErrorKind

type InputErrorKind string
const (
	InputErrorType        InputErrorKind = "type"
	InputErrorCount       InputErrorKind = "count"
	InputErrorUnsupported InputErrorKind = "unsupported"
	InputErrorInvalid     InputErrorKind = "invalid"
)

type InputFixture

type InputFixture func() []tensor.Tensor

InputFixture is a function that generates inputs for ops. Useful in testing.

type InvalidTensorError

type InvalidTensorError struct {
	// contains filtered or unexported fields
}

func (*InvalidTensorError) Error

func (i *InvalidTensorError) Error() string

type Number

type Number interface {
	float32 | float64 | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
}

Number is a type constraint that allows the sized integer types (int8 to int64 and uint8 to uint64) as well as float32 and float64.

type Operator

type Operator interface {
	// String should return a simple string describing the operator
	String() string

	// Init should initialize the operator based on the given node.
	// This node contains attributes, which outputs are expected and more. How these
	// attributes influence the operator is defined by the ONNX standard, and can be
	// found in https://github.com/onnx/onnx/blob/main/docs/Operators.md
	Init(*onnx.NodeProto) error

	// Apply should apply the operator to the list of input tensors. It should return a
	// list with output tensors, the result of the operator.
	Apply([]tensor.Tensor) ([]tensor.Tensor, error)

	// GetMinInputs should return the minimum number of inputs this operator expects.
	GetMinInputs() int

	// GetMaxInputs should return the maximum number of inputs this operator expects.
	GetMaxInputs() int

	// GetInputTypeConstraints should return a list. Every element represents a set of
	// allowed tensor dtypes for the corresponding input tensor.
	GetInputTypeConstraints() [][]tensor.Dtype

	// ValidateInputs should validate the list of input tensors. It should check for both
	// the right amount of inputs and the correct dtypes of the tensors.
	ValidateInputs([]tensor.Tensor) ([]tensor.Tensor, error)
}

Operator is the base interface for all operators.

type SequenceProcessDirection

type SequenceProcessDirection string

SequenceProcessDirection is the direction in which a sequential input is processed. Sequential inputs can be processed forward (from first to last), in reverse (from last to first) or bidirectionally (both forward and in reverse).

const (
	Forward       SequenceProcessDirection = "forward"
	Reverse       SequenceProcessDirection = "reverse"
	Bidirectional SequenceProcessDirection = "bidirectional"
)

type Slicer

type Slicer struct {
	// contains filtered or unexported fields
}

Slicer implements the tensor.Slice interface. It describes how to slice a single dimension of a tensor.

func (*Slicer) End

func (s *Slicer) End() int

End returns the end of the slice.

func (*Slicer) Start

func (s *Slicer) Start() int

Start returns the start of the slice.

func (*Slicer) Step

func (s *Slicer) Step() int

Step returns the step of the slice.

type TypeAssertError

type TypeAssertError struct {
	// contains filtered or unexported fields
}

func (*TypeAssertError) Error

func (t *TypeAssertError) Error() string
