package ml

Version: v0.0.0-...-7b461c4 (Latest)
Warning: This package is not in the latest version of its module.

Go to latest
Published: Mar 12, 2024 · License: Apache-2.0 · Imports: 10 · Imported by: 0

Documentation

Index

Constants

This package declares no exported constants.

Variables

View Source
// Predefined Tensor element data types. Each one is constructed by
// newDataType from a name string, a zero value of the Go type backing the
// element (dtype.BFloat16, float32, uint16, int32, complex64 respectively),
// and the DataTypeFuncSet implementation that checks, converts, formats,
// reads and writes elements of that type.
//
// NOTE(review): newDataType is not shown on this page; presumably it derives
// DataType.GoType and DataType.ItemSize from the zero value — confirm.
var (
	DT_BF16    = newDataType("BF16", dtype.BFloat16(0), DataTypeFuncSet_BF16{})
	DT_F32     = newDataType("Float32", float32(0), DataTypeFuncSet_F32{})
	DT_UINT16  = newDataType("UInt16", uint16(0), DataTypeFuncSet_UINT16{})
	DT_INT32   = newDataType("Int32", int32(0), DataTypeFuncSet_INT32{})
	DT_COMPLEX = newDataType("Complex", complex64(complex(0.0, 0.0)), DataTypeFuncSet_COMPLEX{})
)
View Source
// TABLE_SILU is a float32 table with 1<<16 entries, i.e. one slot for every
// possible 16-bit value.
//
// NOTE(review): presumably this is a precomputed SiLU lookup indexed by the
// raw bit pattern of a 16-bit element (e.g. BF16) and used by Silu; how and
// when it is populated is not shown on this page — confirm against the source.
// It is mutable package-level state.
var (
	TABLE_SILU [1 << 16]float32
)

Functions

func CheckBroadcastable

func CheckBroadcastable(t1 *Tensor, t2 *Tensor, isCommutative bool) (refTensor *Tensor, expandingTensor *Tensor, err error)

func CheckBroadcastableOnce

func CheckBroadcastableOnce(size1 []int, size2 []int) bool

func CompareTestTensor

func CompareTestTensor(expected interface{}, expectedSize []int, actual *Tensor, floatThreshold float64, shorten bool) error

func CompareTestTensorDimension

func CompareTestTensorDimension(expected interface{}, actual *Tensor, currentDimension int, loc []int, floatThreshold float64, shorten bool) error

func CompareTestTensorSkippable

func CompareTestTensorSkippable(skip bool, expected interface{}, expectedSize []int, actual *Tensor, floatThreshold float64, shorten bool) error

Types

type DataType

// DataType describes a Tensor element type.
//
// Name is the type's name string (e.g. "Float32" for DT_F32). GoType is the
// reflect.Type of the Go value that represents one element. ItemSize is
// presumably the size of one element in bytes — confirm. FuncSet holds the
// operations used to check, convert, format, and read/write raw element data.
type DataType struct {
	Name     string
	GoType   reflect.Type
	ItemSize int
	FuncSet  DataTypeFuncSet
}

func (DataType) String

func (dt DataType) String() string

type DataTypeFuncSet

// DataTypeFuncSet is the set of per-element operations that a DataType
// delegates to: compatibility checking, float32 conversion, formatting, and
// loading/storing elements through raw pointers.
type DataTypeFuncSet interface {
	// IsCompatible presumably reports whether val can be used as an element
	// of this type — confirm the exact rule (exact type match vs. convertible).
	IsCompatible(val any) bool
	// FromFloat32 converts val to this type's element value, boxed as any.
	FromFloat32(val float32) any
	// ToFloat32 converts an element value of this type to float32.
	ToFloat32(val any) float32

	// ToString formats an element value of this type as a string.
	ToString(val any) string

	// ReadItem loads one element from rawDataPtr and returns it boxed as any.
	// rawDataPtr is assumed to point at a valid element of this type.
	ReadItem(rawDataPtr unsafe.Pointer) any
	// WriteItem stores val at rawDataPtr; it returns an error (presumably when
	// val is not compatible with this type — confirm).
	WriteItem(rawDataPtr unsafe.Pointer, val any) error

	// ReadItem_AsFloat32 loads one element from rawDataPtr as a float32.
	// NOTE(review): presumably a fast path that avoids boxing through any.
	ReadItem_AsFloat32(rawDataPtr unsafe.Pointer) float32
	// WriteItem_FromFloat32 stores val at rawDataPtr, converted to this type.
	// Unlike WriteItem it returns no error.
	WriteItem_FromFloat32(rawDataPtr unsafe.Pointer, val float32)
}

type DataTypeFuncSet_BF16

// DataTypeFuncSet_BF16 is the DataTypeFuncSet implementation used by DT_BF16.
type DataTypeFuncSet_BF16 struct{}

func (DataTypeFuncSet_BF16) FromFloat32

func (dtfs DataTypeFuncSet_BF16) FromFloat32(val float32) any

func (DataTypeFuncSet_BF16) IsCompatible

func (dtfs DataTypeFuncSet_BF16) IsCompatible(val any) bool

func (DataTypeFuncSet_BF16) ReadItem

func (dtfs DataTypeFuncSet_BF16) ReadItem(rawDataPtr unsafe.Pointer) any

func (DataTypeFuncSet_BF16) ReadItem_AsFloat32

func (dtfs DataTypeFuncSet_BF16) ReadItem_AsFloat32(rawDataPtr unsafe.Pointer) float32

func (DataTypeFuncSet_BF16) ToFloat32

func (dtfs DataTypeFuncSet_BF16) ToFloat32(val any) float32

func (DataTypeFuncSet_BF16) ToString

func (dtfs DataTypeFuncSet_BF16) ToString(val any) string

func (DataTypeFuncSet_BF16) WriteItem

func (dtfs DataTypeFuncSet_BF16) WriteItem(rawDataPtr unsafe.Pointer, val any) error

func (DataTypeFuncSet_BF16) WriteItem_FromFloat32

func (dtfs DataTypeFuncSet_BF16) WriteItem_FromFloat32(rawDataPtr unsafe.Pointer, val float32)

type DataTypeFuncSet_COMPLEX

// DataTypeFuncSet_COMPLEX is the DataTypeFuncSet implementation used by
// DT_COMPLEX (complex64 elements).
type DataTypeFuncSet_COMPLEX struct{}

func (DataTypeFuncSet_COMPLEX) FromFloat32

func (dtfs DataTypeFuncSet_COMPLEX) FromFloat32(val float32) any

func (DataTypeFuncSet_COMPLEX) IsCompatible

func (dtfs DataTypeFuncSet_COMPLEX) IsCompatible(val any) bool

func (DataTypeFuncSet_COMPLEX) ReadItem

func (dtfs DataTypeFuncSet_COMPLEX) ReadItem(rawDataPtr unsafe.Pointer) any

func (DataTypeFuncSet_COMPLEX) ReadItem_AsFloat32

func (dtfs DataTypeFuncSet_COMPLEX) ReadItem_AsFloat32(rawDataPtr unsafe.Pointer) float32

func (DataTypeFuncSet_COMPLEX) ToFloat32

func (dtfs DataTypeFuncSet_COMPLEX) ToFloat32(val any) float32

func (DataTypeFuncSet_COMPLEX) ToString

func (dtfs DataTypeFuncSet_COMPLEX) ToString(val any) string

func (DataTypeFuncSet_COMPLEX) WriteItem

func (dtfs DataTypeFuncSet_COMPLEX) WriteItem(rawDataPtr unsafe.Pointer, val any) error

func (DataTypeFuncSet_COMPLEX) WriteItem_FromFloat32

func (dtfs DataTypeFuncSet_COMPLEX) WriteItem_FromFloat32(rawDataPtr unsafe.Pointer, val float32)

type DataTypeFuncSet_F32

// DataTypeFuncSet_F32 is the DataTypeFuncSet implementation used by DT_F32
// (float32 elements).
type DataTypeFuncSet_F32 struct{}

func (DataTypeFuncSet_F32) FromFloat32

func (dtfs DataTypeFuncSet_F32) FromFloat32(val float32) any

func (DataTypeFuncSet_F32) IsCompatible

func (dtfs DataTypeFuncSet_F32) IsCompatible(val any) bool

func (DataTypeFuncSet_F32) ReadItem

func (dtfs DataTypeFuncSet_F32) ReadItem(rawDataPtr unsafe.Pointer) any

func (DataTypeFuncSet_F32) ReadItem_AsFloat32

func (dtfs DataTypeFuncSet_F32) ReadItem_AsFloat32(rawDataPtr unsafe.Pointer) float32

func (DataTypeFuncSet_F32) ToFloat32

func (dtfs DataTypeFuncSet_F32) ToFloat32(val any) float32

func (DataTypeFuncSet_F32) ToString

func (dtfs DataTypeFuncSet_F32) ToString(val any) string

func (DataTypeFuncSet_F32) WriteItem

func (dtfs DataTypeFuncSet_F32) WriteItem(rawDataPtr unsafe.Pointer, val any) error

func (DataTypeFuncSet_F32) WriteItem_FromFloat32

func (dtfs DataTypeFuncSet_F32) WriteItem_FromFloat32(rawDataPtr unsafe.Pointer, val float32)

type DataTypeFuncSet_INT32

// DataTypeFuncSet_INT32 is the DataTypeFuncSet implementation used by
// DT_INT32 (int32 elements).
type DataTypeFuncSet_INT32 struct{}

func (DataTypeFuncSet_INT32) FromFloat32

func (dtfs DataTypeFuncSet_INT32) FromFloat32(val float32) any

func (DataTypeFuncSet_INT32) IsCompatible

func (dtfs DataTypeFuncSet_INT32) IsCompatible(val any) bool

func (DataTypeFuncSet_INT32) ReadItem

func (dtfs DataTypeFuncSet_INT32) ReadItem(rawDataPtr unsafe.Pointer) any

func (DataTypeFuncSet_INT32) ReadItem_AsFloat32

func (dtfs DataTypeFuncSet_INT32) ReadItem_AsFloat32(rawDataPtr unsafe.Pointer) float32

func (DataTypeFuncSet_INT32) ToFloat32

func (dtfs DataTypeFuncSet_INT32) ToFloat32(val any) float32

func (DataTypeFuncSet_INT32) ToString

func (dtfs DataTypeFuncSet_INT32) ToString(val any) string

func (DataTypeFuncSet_INT32) WriteItem

func (dtfs DataTypeFuncSet_INT32) WriteItem(rawDataPtr unsafe.Pointer, val any) error

func (DataTypeFuncSet_INT32) WriteItem_FromFloat32

func (dtfs DataTypeFuncSet_INT32) WriteItem_FromFloat32(rawDataPtr unsafe.Pointer, val float32)

type DataTypeFuncSet_UINT16

// DataTypeFuncSet_UINT16 is the DataTypeFuncSet implementation used by
// DT_UINT16 (uint16 elements).
type DataTypeFuncSet_UINT16 struct{}

func (DataTypeFuncSet_UINT16) FromFloat32

func (dtfs DataTypeFuncSet_UINT16) FromFloat32(val float32) any

func (DataTypeFuncSet_UINT16) IsCompatible

func (dtfs DataTypeFuncSet_UINT16) IsCompatible(val any) bool

func (DataTypeFuncSet_UINT16) ReadItem

func (dtfs DataTypeFuncSet_UINT16) ReadItem(rawDataPtr unsafe.Pointer) any

func (DataTypeFuncSet_UINT16) ReadItem_AsFloat32

func (dtfs DataTypeFuncSet_UINT16) ReadItem_AsFloat32(rawDataPtr unsafe.Pointer) float32

func (DataTypeFuncSet_UINT16) ToFloat32

func (dtfs DataTypeFuncSet_UINT16) ToFloat32(val any) float32

func (DataTypeFuncSet_UINT16) ToString

func (dtfs DataTypeFuncSet_UINT16) ToString(val any) string

func (DataTypeFuncSet_UINT16) WriteItem

func (dtfs DataTypeFuncSet_UINT16) WriteItem(rawDataPtr unsafe.Pointer, val any) error

func (DataTypeFuncSet_UINT16) WriteItem_FromFloat32

func (dtfs DataTypeFuncSet_UINT16) WriteItem_FromFloat32(rawDataPtr unsafe.Pointer, val float32)

type DstRow

// DstRow has only unexported fields and no exported methods.
// NOTE(review): its role is not evident from the exported API — document it
// at the definition site.
type DstRow struct {
	// contains filtered or unexported fields
}

type DstVal

// DstVal has only unexported fields and no exported methods.
// NOTE(review): its role is not evident from the exported API — document it
// at the definition site.
type DstVal struct {
	// contains filtered or unexported fields
}

type OneTensorIterator

// OneTensorIterator yields index locations over a single tensor shape.
// Obtain one from IterateOver or IterateOverSize, then call Next to get each
// location ([]int) while HasNext reports true.
// NOTE(review): the ignoreTrailingDimensions argument of the constructors
// presumably excludes the last N dimensions from iteration — confirm.
type OneTensorIterator struct {
	// contains filtered or unexported fields
}

func IterateOver

func IterateOver(tensor *Tensor, ignoreTrailingDimensions int) *OneTensorIterator

func IterateOverSize

func IterateOverSize(size []int, ignoreTrailingDimensions int) *OneTensorIterator

func (*OneTensorIterator) HasNext

func (it *OneTensorIterator) HasNext() bool

func (*OneTensorIterator) Next

func (it *OneTensorIterator) Next() (loc []int)

type Tensor

// Tensor is a multi-dimensional array whose elements are stored in RawData.
//
// Name is a label for the tensor. Size holds the extent of each dimension and
// DataType describes the element type. Stride and ByteStride hold per-dimension
// strides; presumably Stride is measured in elements and ByteStride in bytes —
// confirm. RawData is the backing byte buffer that element reads and writes go
// through.
type Tensor struct {
	Name     string
	Size     []int
	Stride   []int
	DataType DataType
	RawData  []byte

	ByteStride []int
}

func ARange

func ARange(start int, end int, step int, dataType DataType) (*Tensor, error)

func Add

func Add(input *Tensor, other *Tensor) (*Tensor, error)

func AddScalar

func AddScalar(input *Tensor, scalar any) (*Tensor, error)

func Argmax

func Argmax(input *Tensor, dim int) (*Tensor, error)

func DivToScalar

func DivToScalar(input *Tensor, scalar any) (*Tensor, error)

func DuplicateTensor

func DuplicateTensor(input *Tensor) *Tensor

func Full

func Full(size []int, dataType DataType, fillValue any) (*Tensor, error)

func Fwd_Get_Rows

func Fwd_Get_Rows(embedding *Tensor, tokens *Tensor) (*Tensor, error)

func LinearTransformation

func LinearTransformation(input *Tensor, weights *Tensor) (*Tensor, error)

func MatMul

func MatMul(input *Tensor, other *Tensor) (*Tensor, error)

func Mean

func Mean(input *Tensor, dim int, keepdim bool) (*Tensor, error)

func MultiplyElementwise

func MultiplyElementwise(input *Tensor, other *Tensor) (*Tensor, error)

func NewEmptyTensor

func NewEmptyTensor(size []int, dataType DataType) *Tensor

func NewEmptyTensorEx

func NewEmptyTensorEx(name string, size []int, dataType DataType, allocateRawData bool) *Tensor

func NewEmptyTensorLike

func NewEmptyTensorLike(input *Tensor, allocateRawData bool) *Tensor

func NewTensor

func NewTensor(name string, size []int, stride []int, dataType DataType, RawData []byte) *Tensor

func Ones

func Ones(size []int, dataType DataType) (*Tensor, error)

func OnesLike

func OnesLike(input *Tensor) (*Tensor, error)

func Outer

func Outer(vec1 *Tensor, vec2 *Tensor) (*Tensor, error)

func Polar

func Polar(abs *Tensor, angle *Tensor) (*Tensor, error)

func Pow

func Pow(input *Tensor, power float64) (*Tensor, error)

func RSqrt

func RSqrt(input *Tensor) (*Tensor, error)

func Silu

func Silu(input *Tensor) (*Tensor, error)

func Softmax

func Softmax(input *Tensor, dim int) (*Tensor, error)

func TriangularUpper

func TriangularUpper(input *Tensor, diagonal int) (*Tensor, error)

func Zeros

func Zeros(size []int, dataType DataType) (*Tensor, error)

func ZerosLike

func ZerosLike(input *Tensor) (*Tensor, error)

func (*Tensor) Apply

func (t *Tensor) Apply(fn func(val any) any) error

func (*Tensor) Apply_AsFloat32

func (t *Tensor) Apply_AsFloat32(fn func(val float32) float32) error

func (*Tensor) GetBytesCount

func (t *Tensor) GetBytesCount() int

func (*Tensor) GetElementCount

func (t *Tensor) GetElementCount() int

func (*Tensor) GetItem

func (t *Tensor) GetItem(loc []int) (any, error)

func (*Tensor) GetItemByOffset

func (t *Tensor) GetItemByOffset(offset int) any

func (*Tensor) GetItemByOffset_AsFloat32

func (t *Tensor) GetItemByOffset_AsFloat32(offset int) (float32, error)

func (*Tensor) GetItem_AsFloat32

func (t *Tensor) GetItem_AsFloat32(loc []int) (float32, error)

func (*Tensor) IsMatrix

func (t *Tensor) IsMatrix() bool

func (*Tensor) IsVector

func (t *Tensor) IsVector() bool

func (*Tensor) Item

func (t *Tensor) Item() any

func (*Tensor) Reshape

func (t *Tensor) Reshape(newSize []int) (*Tensor, error)

func (*Tensor) SetItem

func (t *Tensor) SetItem(loc []int, val any) error

func (*Tensor) SetItemByOffset

func (t *Tensor) SetItemByOffset(offset int, val any) error

func (*Tensor) SetItemByOffset_FromFloat32

func (t *Tensor) SetItemByOffset_FromFloat32(offset int, val float32) error

func (*Tensor) SetItem_FromFloat32

func (t *Tensor) SetItem_FromFloat32(loc []int, val float32) error

func (*Tensor) SetSlice

func (t *Tensor) SetSlice(locStart []int, locEnd []int, val *Tensor) error

func (*Tensor) Slice

func (t *Tensor) Slice(locStart []int, locEnd []int) (*Tensor, error)

func (*Tensor) String

func (t *Tensor) String() string

func (*Tensor) StringLong

func (t *Tensor) StringLong() string

func (*Tensor) ToBFloat16

func (t *Tensor) ToBFloat16() (*Tensor, error)

func (*Tensor) ToFloat32

func (t *Tensor) ToFloat32() (*Tensor, error)

func (*Tensor) Transpose

func (t *Tensor) Transpose(dim1 int, dim2 int) (*Tensor, error)

func (*Tensor) ViewAsComplex64

func (t *Tensor) ViewAsComplex64() (*Tensor, error)

func (*Tensor) ViewAsComplex64WithReshape

func (t *Tensor) ViewAsComplex64WithReshape() (*Tensor, error)

func (*Tensor) ViewAsFloat32

func (t *Tensor) ViewAsFloat32() (*Tensor, error)

func (*Tensor) ViewAsFloat32WithReshape

func (t *Tensor) ViewAsFloat32WithReshape() (*Tensor, error)

type TwoTensorIterator

// TwoTensorIterator yields pairs of index locations over two tensor shapes,
// a reference shape and an "expanding" shape. Obtain one from IterateOverTwo
// or IterateOverTwoSize, then call Next to get (loc1, loc2) while HasNext
// reports true.
// NOTE(review): presumably loc2 is the broadcast-mapped location in the
// expanding tensor (cf. CheckBroadcastable's refTensor/expandingTensor
// results) — confirm.
type TwoTensorIterator struct {
	// contains filtered or unexported fields
}

func IterateOverTwo

func IterateOverTwo(refTensor *Tensor, expandingTensor *Tensor, ignoreTrailingDimensions int) *TwoTensorIterator

func IterateOverTwoSize

func IterateOverTwoSize(refSize []int, expandingSize []int, ignoreTrailingDimensions int) *TwoTensorIterator

func (*TwoTensorIterator) HasNext

func (it *TwoTensorIterator) HasNext() bool

func (*TwoTensorIterator) Next

func (it *TwoTensorIterator) Next() (loc1 []int, loc2 []int)

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL