shapes

package module
v0.0.0-...-8faea76 Latest
Published: Feb 26, 2024 License: MIT Imports: 11 Imported by: 0

README

About

Package shapes provides the algebra and machinery for dealing with the metainformation of shapes of a tensor.

Why a shape package?

The shape package defines a syntax and semantics that describe the shapes of values.

The goals are:

  • to describe the notion of shape;
  • to implement automation for working with shapes (for example, a tool that verifies whether the shapes of different inputs are compatible with a given operator).

What is a Shape?

A shape is a piece of meta-information that tells you something about a tensor.

A rank-2 tensor is also commonly known as a matrix (mathematicians and physicists, please bear with the inaccuracies and informalities).

Example

Let's consider the following matrix:

⎡1  2  3⎤
⎣4  5  6⎦

We can describe this matrix by saying "it's a matrix with 2 rows and 3 columns". We can write this as (2, 3).

(2, 3) is the shape of the matrix.

This is a very convenient way of describing N-dimensional tensors. We don't have to give the dimensions names like "layer", "row" or "column". We would rapidly run out of names! Instead, we just index them by their number. So in a 3-dimensional shape, instead of saying "layer", we say "dimension 0". In a 2-dimensional shape, "row" would be "dimension 0".

Components of a Shape

Given a shape, let's explore the components of the shape:

 +--+--+--------- size
 v  v  v
(2, 3, 4) <------ shape
 ^  ^  ^
 +--+--+--------- dimension/axis

Each "slot" in a shape is called a dimension or axis. The number of dimensions in a shape is called its rank (though, confusingly, the Gorgonia family of libraries calls it .Dims()). There are 3 slots in the example, so it is a 3-dimensional shape. The number in each slot is called the size of that dimension. When referring to dimensions by their number, the preferred term is "axis". Axes are numbered from 0, so axis 1 has a size of 3; equivalently, dimension 1 is of size 3.

To use the traditional named dimensions (recall that in this instance dimension 1 is "rows"), we say there are 3 rows.
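
The vocabulary above (rank, axis, size) can be made concrete with a small sketch. The shape type and its rank, sizeOf and totalSize helpers below are hypothetical names chosen for illustration; they are not the package's API, which has its own Shape type.

```go
package main

import "fmt"

// shape is a hypothetical stand-in for a tensor shape: one size per axis.
type shape []int

// rank returns the number of dimensions (axes) in the shape.
func (s shape) rank() int { return len(s) }

// sizeOf returns the size of the given axis, or an error if the axis is out of range.
func (s shape) sizeOf(axis int) (int, error) {
	if axis < 0 || axis >= len(s) {
		return 0, fmt.Errorf("axis %d out of range for rank-%d shape", axis, len(s))
	}
	return s[axis], nil
}

// totalSize returns the number of elements a tensor of this shape holds.
func (s shape) totalSize() int {
	total := 1
	for _, size := range s {
		total *= size
	}
	return total
}

func main() {
	s := shape{2, 3, 4}
	size, _ := s.sizeOf(1)
	fmt.Println("rank:", s.rank())            // rank: 3
	fmt.Println("size of axis 1:", size)      // size of axis 1: 3
	fmt.Println("total size:", s.totalSize()) // total size: 24
}
```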

Shape Expr: Syntax

The primary data structure that the package provides is the shape expression. A shape expression is given by the following BNF:

<shape>    ::= <unit> | "("<integer>",)" | "("<integer>","<shape>")" |
               "("<shape>","<integer>")" | "("<variable>",)" |
               <binaryOperationExpression> | <nameOfT>

<integer>  ::= <digit> [ <integer> ]
<digit>    ::= 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
<unit>     ::= "()"
<variable> ::= ...
<binaryOperationExpression> ::= "("<variable> <binop> <integer>",)" |
                                "("<variable> <binop> <variable>",)" |
                                "("<integer> <binop> <variable>",)"
<binop>    ::= + | *

<expression> ::= <shape> | I n <expression> | D <expression> |
                 K <expression> | Σ <expression> | Π <expression> |
                 <variable> | <expression> → <expression> |
                 (<expression> → <expression>) @ <expression> // TODO
T ::= (I n E,) | (D E,) | (Σ E,) | (Π E,) // TODO

A compact BNF

The original design for the language of shapes was written in a compact BNF. This was expanded by Olivier Wulveryck into something that traditional computer scientists are more familiar with.

Here, the original compact BNF is preserved.

E ::= S | I n E | D E | K E | Σ E | Π E | a | E -> E | (E -> E) @ E
S ::= () | (n,) | (n, S) | (S, n) | (a,) | B | T
T ::= (I n E,) | (D E,) | (Σ E,) | (Π E,)
B ::=  (a O n,) | (a O b,) | (n O a,)
O ::= + | ×
n,m ::= ℕ
a,b ::= variables

The BNF might be brief, but it is dense, so let's break it down.

Primitive Shape Construction

<shape>    ::= <unit> | "("<integer>",)" | "("<integer>","<shape>")" |
               "("<shape>","<integer>")" | "("<variable>",)" |
               <binaryOperationExpression> | <nameOfT>

What this snippet says is that a shape (<shape>) can be built in several ways. The first four are:

  1. <unit> - This is the shape of scalar values. () is pronounced "unit".
  2. "("<integer>",)" - <integer> is any positive integer. Example of a shape: (10,).
  3. "("<integer>","<shape>")" - an <integer> followed by another <shape>. Example: (8, (10,)).
  4. "("<shape>","<integer>")" - a <shape> followed by another <integer>. Example: ((8, (10,)), 20).

From the snippet, we have generated 4 examples of shapes: (), (10,), (8, (10,)) and ((8, (10,)), 20). This isn't how we normally write shapes. We would normally write them as (10,), (8, 10) and (8, 10, 20). So what's the difference?

In fact, the following are equivalent:

  • (8, (10,)) = (8, 10)
  • ((8, (10,)), 20) = ((8, 10), 20) = (8, (10, 20)) = (8, 10, 20)

What this says is that the primitive shape construction is associative. This is useful as we can now omit the additional parentheses.
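
One way to see the associativity is to flatten a nested shape and compare the results. The sketch below uses a hypothetical node type with n and tuple constructors (none of these names come from the package) and shows that two differently parenthesised shapes flatten to the same list of sizes. The unit () is deliberately left out of the illustration.

```go
package main

import "fmt"

// node is a hypothetical nested shape: either a single size (a leaf)
// or a tuple of sub-shapes.
type node struct {
	isLeaf   bool
	size     int    // used when isLeaf is true
	children []node // used when isLeaf is false
}

// n builds a leaf holding a single size.
func n(size int) node { return node{isLeaf: true, size: size} }

// tuple builds a nested shape out of sub-shapes.
func tuple(children ...node) node { return node{children: children} }

// flatten drops the internal parentheses and returns the sizes left to right.
func flatten(s node) []int {
	if s.isLeaf {
		return []int{s.size}
	}
	var sizes []int
	for _, c := range s.children {
		sizes = append(sizes, flatten(c)...)
	}
	return sizes
}

func main() {
	left := tuple(tuple(n(8), tuple(n(10))), n(20)) // ((8, (10,)), 20)
	right := tuple(n(8), tuple(n(10), n(20)))       // (8, (10, 20))
	fmt.Println(flatten(left))                      // [8 10 20]
	fmt.Println(flatten(right))                     // [8 10 20]
}
```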

On <unit>

The unit () is particularly interesting. In a while we'll see why it's called a "unit".

For now, given the associativity rules, what would (10, ()) be equivalent to? One possible answer is (10,) - after all, we're just removing the internal parentheses. Another possible answer is (10, 1).

TODO write more.

Introducing Variables

Now, let us focus once again on line 2 of the BNF, but on the latter parts:

<shape> ::= ... | (<variable>,) | ...

What this says is that you can create a <shape> using a variable. I'll use x and y for real variables.

Example: (x,) is a shape. Combining this rule with the rules from above, we can see that (x, y) is also a valid shape. So is (10, x) or (x, 10).

To recap using the examples we have seen so far, these are examples of valid shapes:

  • (x,)
  • (x, 10)
  • (10, x)
  • (10, 20, x)
  • (10, x, 20)

Introducing Binary Operations

In the following snippet (still on line 2 of the BNF), this is introduced:

<shape> ::= ... | <binaryOperationExpression> | ...

And <binaryOperationExpression> is defined in the following line as:

<binaryOperationExpression> ::= "("<variable> <binop> <integer>",)" |
                                "("<variable> <binop> <variable>",)" |
                                "("<integer> <binop> <variable>",)"

What this says is that any valid <binaryOperationExpression> is also a valid <shape>.

<binop> is defined as:

<binop>   ::= + | *

That is, a valid mathematical operation is an addition or a multiplication.

So what line 3 of the BNF says is that these are valid shapes:

  • (x + 10,)
  • (x + y,)
  • (10 + x,)

An astute reader will observe that "("<integer> <binop> <integer>",)" (example: (10 × 20,)) isn't allowed. The reasoning is simple: if you know ahead of time what the resulting size will be, don't put a mathematical expression in. Note that although this restriction exists in the formal language, it is not enforced by the package machinery.

Recap 1

To recap, these are all examples of valid shapes:

  • ()
  • (10,)
  • (10, 20)
  • (x,)
  • (x, 10)
  • (10, x)
  • (x, y)
  • (x+20,)
  • (x×y,)
  • (x+20, 10)
  • (10, x×y)

We will leave the last part of the definition of S (S ::= ... | T) until after we've introduced the notion of expressions.
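
To illustrate how a shape containing variables and binary operations becomes concrete once its variables are known, here is a minimal sketch. The dim interface, the lit, variable and binop types, and the resolve function are all hypothetical and are not the package's data structures; the bindings x = 5 and y = 3 are chosen arbitrarily.

```go
package main

import "fmt"

// dim is one slot of a hypothetical symbolic shape.
type dim interface {
	eval(env map[string]int) (int, error)
}

// lit is a known size, e.g. 10.
type lit int

// variable is a named unknown, e.g. x.
type variable string

// binop is a sum or product of two slots, e.g. x+20 or x×y.
type binop struct {
	op   rune // '+' or '×'
	a, b dim
}

func (l lit) eval(map[string]int) (int, error) { return int(l), nil }

func (v variable) eval(env map[string]int) (int, error) {
	n, ok := env[string(v)]
	if !ok {
		return 0, fmt.Errorf("unbound variable %q", string(v))
	}
	return n, nil
}

func (b binop) eval(env map[string]int) (int, error) {
	x, err := b.a.eval(env)
	if err != nil {
		return 0, err
	}
	y, err := b.b.eval(env)
	if err != nil {
		return 0, err
	}
	switch b.op {
	case '+':
		return x + y, nil
	case '×':
		return x * y, nil
	}
	return 0, fmt.Errorf("unknown operator %q", b.op)
}

// resolve turns a symbolic shape into concrete sizes using the bindings in env.
func resolve(shape []dim, env map[string]int) ([]int, error) {
	out := make([]int, 0, len(shape))
	for _, d := range shape {
		n, err := d.eval(env)
		if err != nil {
			return nil, err
		}
		out = append(out, n)
	}
	return out, nil
}

func main() {
	// (x+20, 10) and (10, x×y) from the recap, with x = 5 and y = 3.
	env := map[string]int{"x": 5, "y": 3}
	s1, _ := resolve([]dim{binop{'+', variable("x"), lit(20)}, lit(10)}, env)
	s2, _ := resolve([]dim{lit(10), binop{'×', variable("x"), variable("y")}}, env)
	fmt.Println(s1, s2) // [25 10] [10 15]
}
```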

Shape Expr: Semantics

TODO: write more

Why So Complicated?

Why introduce all these complicated things? We introduce these complicated things because we want to do things with shape expressions.

It is a System of Constraints

Ideally, the shape expression should be enough to tell you what an operation does to the shape of its inputs.

Consider, for example, the following shape expression:

(a, b) → (b, c) → (a, c)

What does this expression say? It says the operation takes a matrix with shape (a, b), then takes another matrix with shape (b, c), and finally returns a matrix of shape (a, c).

This simple expression contains a lot of information:

  • The inputs are matrices only.
  • The inner dimension of the first input must be the same as the outer dimension of the second input.
  • The output is a matrix only, not of any other rank.
  • The matching dimensions disappear.

In fact, there is precisely one operation that is described by this expression: Matrix Multiplication.

Here we see one of the functions of the shape expression: it provides constraints on the inputs, e.g. the inputs must be matrices, and the matching dimensions must agree.
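
For comparison, this is roughly what the same constraints look like when written by hand. The matMulShape function is a hypothetical helper, not part of the package; it only spells out the checks that the expression (a, b) → (b, c) → (a, c) encodes.

```go
package main

import (
	"errors"
	"fmt"
)

// matMulShape checks the constraints encoded by (a, b) → (b, c) → (a, c)
// and returns the output shape.
func matMulShape(x, y []int) ([]int, error) {
	if len(x) != 2 || len(y) != 2 {
		return nil, errors.New("both inputs must be matrices (rank 2)")
	}
	if x[1] != y[0] {
		return nil, fmt.Errorf("inner dimension %d does not match outer dimension %d", x[1], y[0])
	}
	return []int{x[0], y[1]}, nil
}

func main() {
	out, err := matMulShape([]int{2, 3}, []int{3, 4})
	fmt.Println(out, err) // [2 4] <nil>

	_, err = matMulShape([]int{2, 3}, []int{4, 5})
	fmt.Println(err) // inner dimension 3 does not match outer dimension 4
}
```

Every operation would need its own hand-written function of this kind, which is the repetition a shape expression is meant to replace.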

A Second Example

Here's another example. Consider this shape expression. Can you guess what it does?

(a, b) → (a, c) → (a, b+c)

Answer

It's a concatenation on axis 1. A concrete example is given:

t :=
shape: (2, 3)
⎡0  1  2⎤
⎣3  4  5⎦

u :=
shape: (2, 2)
⎡100  200⎤
⎣300  400⎦

Concat(1, t, u) =
shape: (2, 5)
⎡  0    1    2  100  200⎤
⎣  3    4    5  300  400⎦
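
The shape arithmetic of this example can be sketched as follows. The concatShape function is a hypothetical helper written for this illustration (it is not the Concat shown above and not part of the package); it generalises (a, b) → (a, c) → (a, b+c) to any axis by adding the sizes on the chosen axis and requiring all other axes to match.

```go
package main

import "fmt"

// concatShape returns the shape of concatenating tensors of shapes a and b
// along the given axis. All other axes must match.
func concatShape(axis int, a, b []int) ([]int, error) {
	if len(a) != len(b) {
		return nil, fmt.Errorf("rank mismatch: %d vs %d", len(a), len(b))
	}
	if axis < 0 || axis >= len(a) {
		return nil, fmt.Errorf("axis %d out of range for rank %d", axis, len(a))
	}
	out := make([]int, len(a))
	for i := range a {
		switch {
		case i == axis:
			out[i] = a[i] + b[i]
		case a[i] != b[i]:
			return nil, fmt.Errorf("axis %d: size %d does not match %d", i, a[i], b[i])
		default:
			out[i] = a[i]
		}
	}
	return out, nil
}

func main() {
	// t has shape (2, 3) and u has shape (2, 2), as in the example above.
	out, err := concatShape(1, []int{2, 3}, []int{2, 2})
	fmt.Println(out, err) // [2 5] <nil>
}
```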

It is a System of Evolving Constraints

The shape expression system can not only define constraints, it can also evolve them.

Going back to the matrix multiplication example, now let's extend it.

(a, b) → (b, c) → (a, c) @ (2, 3)

When this is run through the interpreter of the expression, the result will be the following shape expression:

(3, c) → (2, c)

The constraints have now evolved.

Recall that the @ symbol is an application of a function to an input. So when given an input matrix of a known shape, (2, 3), we can evolve the constraints, so that the next input requires only one check instead of two.
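
The idea of evolving constraints resembles partial application. In the sketch below, which uses closures rather than the package's interpreter, applying the first input returns a function that already knows a and b. The applyFirst function and the matMulStep2 type are hypothetical names used only for this illustration.

```go
package main

import "fmt"

// matMulStep2 is what remains of (a, b) → (b, c) → (a, c) after the first
// input has been applied: it takes the second shape and returns the result.
type matMulStep2 func(second []int) ([]int, error)

// applyFirst consumes the first input of matrix multiplication. The rank
// check on the first input happens here, once.
func applyFirst(first []int) (matMulStep2, error) {
	if len(first) != 2 {
		return nil, fmt.Errorf("first input must be a matrix, got rank %d", len(first))
	}
	a, b := first[0], first[1]
	// The returned closure plays the role of the evolved expression
	// (b, c) → (a, c), with a and b already bound.
	return func(second []int) ([]int, error) {
		if len(second) != 2 || second[0] != b {
			return nil, fmt.Errorf("second input must have shape (%d, c), got %v", b, second)
		}
		return []int{a, second[1]}, nil
	}, nil
}

func main() {
	next, _ := applyFirst([]int{2, 3}) // behaves like (3, c) → (2, c)
	fmt.Println(next([]int{3, 4}))     // [2 4] <nil>
	_, err := next([]int{4, 5})
	fmt.Println(err) // second input must have shape (3, c), got [4 5]
}
```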

Documentation

Overview

Package shapes provides an algebra for dealing with shapes.

Example (Add)

This example shows how to describe a particular operation (addition), and how to use InferApp to infer the shape expression that results from applying each input.

// Consider the idea of adding two tensors - A and B - together.
// If they are matrices, then both A and B must have the same shape.
// Thus we can use the following shape expression to describe addition:
//	Add: a → a → a
//
// If A has shape `a`, then B also has to have shape `a`. The result is also shaped `a`.

add := Arrow{
	Var('a'),
	Arrow{
		Var('a'),
		Var('a'),
	},
}
fmt.Printf("Add: %v\n", add)

// pass in the first input
fst := Shape{5, 2, 3, 1, 10}
retExpr, err := InferApp(add, fst)
if err != nil {
	fmt.Printf("Error %v\n", err)
}
fmt.Printf("Applying %v to Add:\n", fst)
fmt.Printf("%v @ %v ↠ %v\n", add, fst, retExpr)

// pass in the second input
snd := Shape{5, 2, 3, 1, 10}
retExpr2, err := InferApp(retExpr, snd)
if err != nil {
	fmt.Printf("Error %v\n", err)
}
fmt.Printf("Applying %v to the result\n", snd)
fmt.Printf("%v @ %v ↠ %v\n", retExpr, snd, retExpr2)

// bad example:
bad2nd := Shape{2, 3}
_, err = InferApp(retExpr, bad2nd)

fmt.Printf("Passing in a bad second input\n")
fmt.Printf("%v @ %v ↠ %v", retExpr, bad2nd, err)
Output:

Add: a → a → a
Applying (5, 2, 3, 1, 10) to Add:
a → a → a @ (5, 2, 3, 1, 10) ↠ (5, 2, 3, 1, 10) → (5, 2, 3, 1, 10)
Applying (5, 2, 3, 1, 10) to the result
(5, 2, 3, 1, 10) → (5, 2, 3, 1, 10) @ (5, 2, 3, 1, 10) ↠ (5, 2, 3, 1, 10)
Passing in a bad second input
(5, 2, 3, 1, 10) → (5, 2, 3, 1, 10) @ (2, 3) ↠ Failed to solve [{(5, 2, 3, 1, 10) → (5, 2, 3, 1, 10) = (2, 3) → a}] | a: Unification Fail. (5, 2, 3, 1, 10) ~ (2, 3) cannot proceed as they do not contain the same amount of sub-expressions. (5, 2, 3, 1, 10) has 5 subexpressions while (2, 3) has 2 subexpressions
Example (Broadcast)
add := Arrow{Var('a'), Arrow{Var('b'), BroadcastOf{Var('a'), Var('b')}}}
expr := Compound{add, SubjectTo{
	Bc,
	UnaryOp{Const, Var('a')},
	UnaryOp{Const, Var('b')},
}}

fst := Shape{2, 3, 4}
snd := Shape{2, 1, 4}

retExpr1, err := InferApp(expr, fst)
if err != nil {
	fmt.Println(err)
}

retExpr2, err := InferApp(retExpr1, snd)
if err != nil {
	fmt.Println(err)
}

fmt.Printf("Add: %v\n", expr)
fmt.Printf("Applying %v to %v\n", fst, expr)
fmt.Printf("\t%v @ %v → %v\n", expr, fst, retExpr1)

fmt.Printf("Applying %v to %v\n", snd, retExpr1)
fmt.Printf("\t%v @ %v → %v", retExpr1, snd, retExpr2)
Output:

Add: { a → b → (a||b) | (K a ⚟ K b) }
Applying (2, 3, 4) to { a → b → (a||b) | (K a ⚟ K b) }
	{ a → b → (a||b) | (K a ⚟ K b) } @ (2, 3, 4) → { b → ((2, 3, 4)||b) | (K (2, 3, 4) ⚟ K b) }
Applying (2, 1, 4) to { b → ((2, 3, 4)||b) | (K (2, 3, 4) ⚟ K b) }
	{ b → ((2, 3, 4)||b) | (K (2, 3, 4) ⚟ K b) } @ (2, 1, 4) → (2, 3, 4)
Example (ColwiseSumMatrix)

The following shape expressions describe a columnwise summing of a matrix.

// Given A:
// 	1 2 3
//	4 5 6
//
// The columnwise sum is:
// 	5 7 9
//
// The basic description can be explained as such:
//	(r, c) → (1, c)
//
// Here a matrix is given as (r, c). After a columnwise sum, the result is 1 row of c columns.
// However, to keep compatibility with Numpy, a colwise sum would look like this:
// 	(r, c) → (c, )
//
// Lastly a generalized single-axis sum that would work across all tensors would be:
// 	a → b | (D b = D a - 1)
//
// Here, it says that the sum is a function that takes a tensor of any shape called `a`, and returns a tensor with a different shape, called `b`.
// The constraint, however, is that the number of dimensions of `b` must be the number of dimensions of `a` minus 1.

basic := MakeArrow(
	Abstract{Var('r'), Var('c')},
	Abstract{Size(1), Var('c')},
)
fmt.Printf("Basic: %v\n", basic)

compat := MakeArrow(
	Abstract{Var('r'), Var('c')},
	Abstract{Var('c')},
)
fmt.Printf("Numpy Compatible: %v\n", compat)

general := Arrow{Var('a'), ReductOf{Var('a'), Axis(1)}}

fmt.Printf("General: %v\n", general)

fst := Shape{2, 3}
retExpr, err := InferApp(general, fst)
if err != nil {
	fmt.Println(err)
}
fmt.Printf("Applying %v to %v:\n", fst, general)
fmt.Printf("\t%v @ %v → %v", general, fst, retExpr)
Output:

Basic: (r, c) → (1, c)
Numpy Compatible: (r, c) → (c)
General: a → /¹a
Applying (2, 3) to a → /¹a:
	a → /¹a @ (2, 3) → (2)
Example (Im2col)
type im2col struct {
	h, w                 int // kernel height and width
	padH, padW           int
	strideH, strideW     int
	dilationH, dilationW int
}
op := im2col{
	3, 3,
	1, 1,
	1, 1,
	1, 1,
}

b := Var('b')
c := Var('c')
h := Var('h')
w := Var('w')
in := Abstract{b, c, h, w}

// h2 = (h+2*op.padH-(op.dilationH*(op.h-1)+1))/op.strideH + 1
h2 := BinOp{
	Op: Add,
	A: E2{BinOp{
		Op: Div,
		A: E2{BinOp{
			Op: Sub,
			A: E2{BinOp{
				Op: Add,
				A:  h,
				B: E2{BinOp{
					Op: Mul,
					A:  Size(2),
					B:  Size(op.padH),
				}},
			}},
			B: Size(op.dilationH*(op.h-1) + 1),
		}},
		B: Size(op.strideH),
	}},
	B: Size(1),
}

//  w2 = (w+2*op.padW-(op.dilationW*(op.w-1)+1))/op.strideW + 1
w2 := BinOp{
	Op: Add,
	A: E2{BinOp{
		Op: Div,
		A: E2{BinOp{
			Op: Sub,
			A: E2{BinOp{
				Op: Add,
				A:  w,
				B: E2{BinOp{
					Op: Mul,
					A:  Size(2),
					B:  Size(op.padW),
				}},
			}},
			B: Size(op.dilationW*(op.w-1) + 1),
		}},
		B: Size(op.strideW),
	}},
	B: Size(1),
}

c2 := BinOp{
	Op: Mul,
	A:  c,
	B:  Size(op.w * op.h),
}

out := Abstract{b, h2, w2, c2}

im2colExpr := MakeArrow(in, out)
s := Shape{100, 3, 90, 120}
s2, err := InferApp(im2colExpr, s)
if err != nil {
	fmt.Printf("Unable to infer %v @ %v", im2colExpr, s)
}

fmt.Printf("im2col: %v\n", im2colExpr)
fmt.Printf("Applying %v to %v:\n", s, im2colExpr)
fmt.Printf("\t%v @ %v → %v", im2colExpr, s, s2)
Output:

im2col: (b, c, h, w) → (b, (((h + 2 × 1) - 3) ÷ 1) + 1, (((w + 2 × 1) - 3) ÷ 1) + 1, c × 9)
Applying (100, 3, 90, 120) to (b, c, h, w) → (b, (((h + 2 × 1) - 3) ÷ 1) + 1, (((w + 2 × 1) - 3) ÷ 1) + 1, c × 9):
	(b, c, h, w) → (b, (((h + 2 × 1) - 3) ÷ 1) + 1, (((w + 2 × 1) - 3) ÷ 1) + 1, c × 9) @ (100, 3, 90, 120) → (100, 90, 120, 27)
Example (Index)
sizes := Sizes{0, 0, 1, 0}
simple := Arrow{
	Var('a'),
	Arrow{
		Var('b'),
		Abstract{},
	},
}
fmt.Printf("Unconstrained Indexing: %v\n", simple)

st := SubjectTo{
	And,
	SubjectTo{
		Eq,
		UnaryOp{Dims, Var('a')},
		UnaryOp{Dims, Var('b')},
	},
	SubjectTo{
		Lt,
		UnaryOp{ForAll, Var('b')},
		UnaryOp{ForAll, Var('a')},
	},
}
index := Compound{Expr: simple, SubjectTo: st}
fmt.Printf("Indexing: %v\n", index)

fst := Shape{1, 2, 3, 4}
retExpr, err := InferApp(index, fst)
if err != nil {
	fmt.Printf("Error: %v\n", err)
}
fmt.Printf("Applying %v to %v:\n", fst, index)
fmt.Printf("\t%v @ %v ↠ %v\n", index, fst, retExpr)

snd := sizes
retExpr2, err := InferApp(retExpr, snd)
if err != nil {
	fmt.Printf("Error: %v\n", err)
}
fmt.Printf("Applying %v to %v:\n", snd, retExpr)
fmt.Printf("\t%v @ %v ↠ %v\n", retExpr, snd, retExpr2)
Output:

Unconstrained Indexing: a → b → ()
Indexing: { a → b → () | ((D a = D b) ∧ (∀ b < ∀ a)) }
Applying (1, 2, 3, 4) to { a → b → () | ((D a = D b) ∧ (∀ b < ∀ a)) }:
	{ a → b → () | ((D a = D b) ∧ (∀ b < ∀ a)) } @ (1, 2, 3, 4) ↠ { b → () | ((D (1, 2, 3, 4) = D b) ∧ (∀ b < ∀ (1, 2, 3, 4))) }
Applying Sz[0 0 1 0] to { b → () | ((D (1, 2, 3, 4) = D b) ∧ (∀ b < ∀ (1, 2, 3, 4))) }:
	{ b → () | ((D (1, 2, 3, 4) = D b) ∧ (∀ b < ∀ (1, 2, 3, 4))) } @ Sz[0 0 1 0] ↠ ()
Example (Index_scalar)
sizes := Sizes{}
simple := Arrow{
	Var('a'),
	Arrow{
		Var('b'),
		Abstract{},
	},
}
fmt.Printf("Unconstrained Indexing: %v\n", simple)

st := SubjectTo{
	And,
	SubjectTo{
		Eq,
		UnaryOp{Dims, Var('a')},
		UnaryOp{Dims, Var('b')},
	},
	SubjectTo{
		Lt,
		UnaryOp{ForAll, Var('b')},
		UnaryOp{ForAll, Var('a')},
	},
}
index := Compound{Expr: simple, SubjectTo: st}
fmt.Printf("Indexing: %v\n", index)

fst := Shape{}
retExpr, err := InferApp(index, fst)
if err != nil {
	fmt.Printf("Error: %v\n", err)
}
fmt.Printf("Applying %v to %v:\n", fst, index)
fmt.Printf("\t%v @ %v ↠ %v\n", index, fst, retExpr)

snd := sizes
retExpr2, err := InferApp(retExpr, snd)
if err != nil {
	fmt.Printf("Error: %v\n", err)
}
fmt.Printf("Applying %v to %v:\n", snd, retExpr)
fmt.Printf("\t%v @ %v ↠ %v\n", retExpr, snd, retExpr2)
Output:

Unconstrained Indexing: a → b → ()
Indexing: { a → b → () | ((D a = D b) ∧ (∀ b < ∀ a)) }
Applying () to { a → b → () | ((D a = D b) ∧ (∀ b < ∀ a)) }:
	{ a → b → () | ((D a = D b) ∧ (∀ b < ∀ a)) } @ () ↠ { b → () | ((D () = D b) ∧ (∀ b < ∀ ())) }
Applying Sz[] to { b → () | ((D () = D b) ∧ (∀ b < ∀ ())) }:
	{ b → () | ((D () = D b) ∧ (∀ b < ∀ ())) } @ Sz[] ↠ ()
Example (Index_unidimensional)
sizes := Sizes{0}
simple := Arrow{
	Var('a'),
	Arrow{
		Var('b'),
		Abstract{},
	},
}
fmt.Printf("Unconstrained Indexing: %v\n", simple)

st := SubjectTo{
	And,
	SubjectTo{
		Eq,
		UnaryOp{Dims, Var('a')},
		UnaryOp{Dims, Var('b')},
	},
	SubjectTo{
		Lt,
		UnaryOp{ForAll, Var('b')},
		UnaryOp{ForAll, Var('a')},
	},
}
index := Compound{Expr: simple, SubjectTo: st}
fmt.Printf("Indexing: %v\n", index)

fst := Shape{5}
retExpr, err := InferApp(index, fst)
if err != nil {
	fmt.Printf("Error: %v\n", err)
}
fmt.Printf("Applying %v to %v:\n", fst, index)
fmt.Printf("\t%v @ %v ↠ %v\n", index, fst, retExpr)

snd := sizes
retExpr2, err := InferApp(retExpr, snd)
if err != nil {
	fmt.Printf("Error: %v\n", err)
}
fmt.Printf("Applying %v to %v:\n", snd, retExpr)
fmt.Printf("\t%v @ %v ↠ %v\n", retExpr, snd, retExpr2)
Output:

Unconstrained Indexing: a → b → ()
Indexing: { a → b → () | ((D a = D b) ∧ (∀ b < ∀ a)) }
Applying (5) to { a → b → () | ((D a = D b) ∧ (∀ b < ∀ a)) }:
	{ a → b → () | ((D a = D b) ∧ (∀ b < ∀ a)) } @ (5) ↠ { b → () | ((D (5) = D b) ∧ (∀ b < ∀ (5))) }
Applying Sz[0] to { b → () | ((D (5) = D b) ∧ (∀ b < ∀ (5))) }:
	{ b → () | ((D (5) = D b) ∧ (∀ b < ∀ (5))) } @ Sz[0] ↠ ()
Example (KeepDims)
keepDims1 := MakeArrow(
	MakeArrow(Var('a'), Var('b')),
	Var('a'),
	Var('c'),
)

expr := Compound{keepDims1, SubjectTo{
	Eq,
	UnaryOp{Dims, Var('a')},
	UnaryOp{Dims, Var('c')},
}}

fmt.Printf("KeepDims1: %v", expr)
Output:

KeepDims1: { (a → b) → a → c | (D a = D c) }
Example (MatMul)
matmul := Arrow{
	Abstract{Var('a'), Var('b')},
	Arrow{
		Abstract{Var('b'), Var('c')},
		Abstract{Var('a'), Var('c')},
	},
}
fmt.Printf("MatMul: %v\n", matmul)

// Apply the first input to MatMul
fst := Shape{2, 3}
expr2, err := InferApp(matmul, fst)
if err != nil {
	fmt.Printf("Error: %v\n", err)
	return
}
fmt.Printf("Applying %v to MatMul:\n", fst)
fmt.Printf("%v @ %v ↠ %v\n", matmul, fst, expr2)

// Apply the second input
snd := Shape{3, 4}
expr3, err := InferApp(expr2, snd)
if err != nil {
	fmt.Printf("Error: %v\n", err)
}
fmt.Printf("Applying %v to the result:\n", snd)
fmt.Printf("%v @ %v ↠ %v\n", expr2, snd, expr3)

// Bad example:
bad2nd := Shape{4, 5}
_, err = InferApp(expr2, bad2nd)
fmt.Printf("What happens when you pass in a bad value (e.g. %v instead of %v):\n", bad2nd, snd)
fmt.Printf("%v @ %v ↠ %v", expr2, bad2nd, err)
Output:

MatMul: (a, b) → (b, c) → (a, c)
Applying (2, 3) to MatMul:
(a, b) → (b, c) → (a, c) @ (2, 3) ↠ (3, c) → (2, c)
Applying (3, 4) to the result:
(3, c) → (2, c) @ (3, 4) ↠ (2, 4)
What happens when you pass in a bad value (e.g. (4, 5) instead of (3, 4)):
(3, c) → (2, c) @ (4, 5) ↠ Failed to solve [{(3, c) → (2, c) = (4, 5) → d}] | d: Unification Fail. 3 ~ 4 cannot proceed
Example (Package)
package main

import (
	"fmt"

	"github.com/pkg/errors"
	"gorgonia.org/shapes"
)

type Dense struct {
	data  []float64
	shape shapes.Shape
}

func (d *Dense) Shape() shapes.Shape { return d.shape }

type elemwiseFn func(a, b *Dense) (*Dense, error)

func (f elemwiseFn) Shape() shapes.Expr {
	return shapes.Arrow{
		shapes.Var('a'),
		shapes.Arrow{shapes.Var('a'), shapes.Var('a')},
	}
}

type matmulFn func(a, b *Dense) (*Dense, error)

func (f matmulFn) Shape() shapes.Expr {
	return shapes.Arrow{
		shapes.Abstract{shapes.Var('a'), shapes.Var('b')},
		shapes.Arrow{
			shapes.Abstract{shapes.Var('b'), shapes.Var('c')},
			shapes.Abstract{shapes.Var('a'), shapes.Var('c')},
		},
	}
}

type applyFn func(*Dense, func(float64) float64) (*Dense, error)

func (f applyFn) Shape() shapes.Expr {
	return shapes.Arrow{
		shapes.Var('a'),
		shapes.Arrow{
			shapes.Arrow{
				shapes.Var('b'),
				shapes.Var('b'),
			},
			shapes.Var('a'),
		},
	}
}

func (d *Dense) MatMul(other *Dense) (*Dense, error) {
	expr := matmulFn((*Dense).MatMul).Shape()
	sh, err := infer(expr, d.Shape(), other.Shape())
	if err != nil {
		return nil, err
	}

	retVal := &Dense{
		data:  make([]float64, sh.TotalSize()),
		shape: sh,
	}
	return retVal, nil
}

func (d *Dense) Add(other *Dense) (*Dense, error) {
	expr := elemwiseFn((*Dense).Add).Shape()
	sh, err := infer(expr, d.Shape(), other.Shape())
	if err != nil {
		return nil, err
	}
	retVal := &Dense{
		data:  make([]float64, sh.TotalSize()),
		shape: sh,
	}
	return retVal, nil
}

func (d *Dense) Apply(fn func(float64) float64) (*Dense, error) {
	expr := applyFn((*Dense).Apply).Shape()
	fnShape := shapes.ShapeOf(fn)
	sh, err := infer(expr, d.Shape(), fnShape)
	if err != nil {
		return nil, err
	}
	return &Dense{shape: sh}, nil
}

func infer(fn shapes.Expr, others ...shapes.Expr) (shapes.Shape, error) {
	retShape, err := shapes.InferApp(fn, others...)
	if err != nil {
		return nil, err
	}
	sh, err := shapes.ToShape(retShape)
	if err != nil {
		return nil, errors.Wrapf(err, "Expected a Shape in retShape. Got %v of %T instead", retShape, retShape)
	}
	return sh, nil
}

func main() {
	A := &Dense{
		data:  []float64{1, 2, 3, 4, 5, 6},
		shape: shapes.Shape{2, 3},
	}
	B := &Dense{
		data:  []float64{10, 20, 30, 40, 50, 60},
		shape: shapes.Shape{3, 2},
	}
	var fn shapes.Exprer = matmulFn((*Dense).MatMul)
	C, err := A.MatMul(B)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Printf("×: %v\n", fn.Shape())
	fmt.Printf("\t   A   ×    B   =    C\n\t%v × %v = %v\n", A.Shape(), B.Shape(), C.Shape())
	fmt.Println("---")

	fn = elemwiseFn((*Dense).Add)
	D := &Dense{
		data:  []float64{0, 0, 0, 0},
		shape: shapes.Shape{2, 2},
	}
	E, err := C.Add(D)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Printf("+: %v\n", fn.Shape())
	fmt.Printf("\t   C   +    D   =    E\n")
	fmt.Printf("\t%v + %v = %v\n", C.Shape(), D.Shape(), E.Shape())
	fmt.Println("---")

	square := func(a float64) float64 { return a * a }
	squareShape := shapes.ShapeOf(square)
	fn = applyFn((*Dense).Apply)
	F, err := E.Apply(square)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Printf("@: %v\n", fn.Shape())
	fmt.Printf("square: %v\n", squareShape)
	fmt.Printf("\t   E   @  square =    F\n")
	fmt.Printf("\t%v @ (%v) = %v\n", E.Shape(), squareShape, F.Shape())
	fmt.Println("---")

	// trying to do a bad add (e.g. adding two matrices with different shapes) will yield an error
	fn = elemwiseFn((*Dense).Add)
	_, err = E.Add(A)
	fmt.Printf("+: %v\n", fn.Shape())
	fmt.Printf("\t   E   +   A    =\n")
	fmt.Printf("\t%v + %v = ", E.Shape(), A.Shape())
	fmt.Println(err)

}
Output:

×: (a, b) → (b, c) → (a, c)
	   A   ×    B   =    C
	(2, 3) × (3, 2) = (2, 2)
---
+: a → a → a
	   C   +    D   =    E
	(2, 2) + (2, 2) = (2, 2)
---
@: a → (b → b) → a
square: () → ()
	   E   @  square =    F
	(2, 2) @ (() → ()) = (2, 2)
---
+: a → a → a
	   E   +   A    =
	(2, 2) + (2, 3) = Failed to solve [{(2, 2) → (2, 2) = (2, 3) → a}] | a: Unification Fail. 2 ~ 3 cannot proceed
Example (Ravel)
ravel := Arrow{
	Var('a'),
	Abstract{UnaryOp{Prod, Var('a')}},
}
fmt.Printf("Ravel: %v\n", ravel)

fst := Shape{2, 3, 4}
retExpr, err := InferApp(ravel, fst)
if err != nil {
	fmt.Printf("Error %v\n", err)
}
fmt.Printf("Applying %v to Ravel:\n", fst)
fmt.Printf("%v @ %v ↠ %v", ravel, fst, retExpr)
Output:

Ravel: a → (Π a)
Applying (2, 3, 4) to Ravel:
a → (Π a) @ (2, 3, 4) ↠ (24)
Example (Reshape)
expr := Compound{
	Arrow{
		Var('a'),
		Arrow{
			Var('b'),
			Var('b'),
		},
	},
	SubjectTo{
		Eq,
		UnaryOp{Prod, Var('a')},
		UnaryOp{Prod, Var('b')},
	},
}

fmt.Printf("Reshape: %v\n", expr)

fst := Shape{2, 3}
snd := Shape{3, 2}

retExpr, err := InferApp(expr, fst)
if err != nil {
	fmt.Println(err)
}
fmt.Printf("Applying %v to %v:\n", fst, expr)
fmt.Printf("\t%v @ %v ↠ %v\n", expr, fst, retExpr)

retExpr2, err := InferApp(retExpr, snd)
if err != nil {
	fmt.Println(err)
}

fmt.Printf("Applying %v to %v:\n", snd, retExpr)
fmt.Printf("\t%v @ %v ↠ %v\n", retExpr, snd, retExpr2)

bad := Shape{6, 2}
_, err = InferApp(retExpr, bad)
fmt.Printf("Applying a bad shape %v to %v:\n", bad, retExpr)
fmt.Printf("\t%v\n", err)
Output:

Reshape: { a → b → b | (Π a = Π b) }
Applying (2, 3) to { a → b → b | (Π a = Π b) }:
	{ a → b → b | (Π a = Π b) } @ (2, 3) ↠ { b → b | (Π (2, 3) = Π b) }
Applying (3, 2) to { b → b | (Π (2, 3) = Π b) }:
	{ b → b | (Π (2, 3) = Π b) } @ (3, 2) ↠ (3, 2)
Applying a bad shape (6, 2) to { b → b | (Π (2, 3) = Π b) }:
	SubjectTo (Π (2, 3) = Π (6, 2)) resolved to false. Cannot continue
Example (Slice)
sli := Range{0, 2, 1}
simple := Arrow{
	Var('a'),
	Arrow{
		Var('b'),
		SliceOf{
			Var('b'),
			Var('a'),
		},
	},
}
slice := Compound{
	Expr: simple,
	SubjectTo: SubjectTo{
		OpType: Gte,
		A:      IndexOf{I: 0, A: Var('a')},
		B:      Size(2),
	},
}

fmt.Printf("slice: %v\n", slice)

fst := Shape{5, 3, 4}
retExpr, err := InferApp(slice, fst)
if err != nil {
	fmt.Println(err)
}
fmt.Printf("Applying %v to %v:\n", fst, slice)
fmt.Printf("\t%v @ %v ↠ %v\n", slice, fst, retExpr)

snd := sli
retExpr2, err := InferApp(retExpr, snd)
if err != nil {
	fmt.Println(err)
}
fmt.Printf("Applying %v to %v:\n", snd, retExpr)
fmt.Printf("\t%v @ %v ↠ %v\n", retExpr, snd, retExpr2)
Output:

slice: { a → b → a[b] | (a[0] ≥ 2) }
Applying (5, 3, 4) to { a → b → a[b] | (a[0] ≥ 2) }:
	{ a → b → a[b] | (a[0] ≥ 2) } @ (5, 3, 4) ↠ b → (5, 3, 4)[b]
Applying [0:2] to b → (5, 3, 4)[b]:
	b → (5, 3, 4)[b] @ [0:2] ↠ (2, 3, 4)
Example (Slice_basic)
sli := Range{0, 2, 1}
last := Range{3, 4, 1}
so := SliceOf{
	Slices{sli, sli, last},
	Var('a'),
}
simple := Arrow{
	Var('a'),
	so,
}

fmt.Printf("Simple Slice: %v\n", simple)
fst := Shape{5, 3, 4}
retExpr, err := InferApp(simple, fst)
if err != nil {
	fmt.Println(err)
}

fmt.Printf("Applying %v to %v:\n\t%v @ %v ↠ %v\n", fst, simple, simple, fst, retExpr)
Output:

Simple Slice: a → a[0:2, 0:2, 3]
Applying (5, 3, 4) to a → a[0:2, 0:2, 3]:
	a → a[0:2, 0:2, 3] @ (5, 3, 4) ↠ (2, 2)
Example (Sum_allAxes)
sum := Arrow{Var('a'), Reduce(Var('a'), Axes{0, 1})}
fmt.Printf("Sum: %v\n", sum)
fst := Shape{2, 3}
retExpr, err := InferApp(sum, fst)
if err != nil {
	fmt.Println(err)
}

fmt.Printf("Applying %v to %v\n", fst, sum)
fmt.Printf("\t%v @ %v → %v\n", sum, fst, retExpr)

sum1 := Arrow{Var('a'), ReductOf{Var('a'), AllAxes}}
retExpr1, err := InferApp(sum1, fst)
if err != nil {
	fmt.Println(err)
}
fmt.Printf("Applying %v to %v\n", fst, sum1)
fmt.Printf("\t%v @ %v → %v", sum1, fst, retExpr1)
Output:

Sum: a → /⁰/¹a
Applying (2, 3) to a → /⁰/¹a
	a → /⁰/¹a @ (2, 3) → ()
Applying (2, 3) to a → /⁼a
	a → /⁼a @ (2, 3) → ()
Example (Trace)
expr := Arrow{
	Abstract{Var('a'), Var('a')},
	Shape{},
}
fmt.Printf("Trace: %v\n", expr)
Output:

Trace: (a, a) → ()
Example (Transpose)
axes := Axes{0, 1, 3, 2}
simple := Arrow{
	Var('a'),
	Arrow{
		axes,
		TransposeOf{
			axes,
			Var('a'),
		},
	},
}
fmt.Printf("Unconstrained Transpose: %v\n", simple)

st := SubjectTo{
	Eq,
	UnaryOp{Dims, axes},
	UnaryOp{Dims, Var('a')},
}
transpose := Compound{
	Expr:      simple,
	SubjectTo: st,
}
fmt.Printf("Transpose: %v\n", transpose)

fst := Shape{1, 2, 3, 4}
retExpr, err := InferApp(transpose, fst)
if err != nil {
	fmt.Printf("Error: %v\n", err)
}
fmt.Printf("Applying %v to %v:\n", fst, transpose)
fmt.Printf("\t%v @ %v ↠ %v\n", transpose, fst, retExpr)
snd := axes
retExpr2, err := InferApp(retExpr, snd)
if err != nil {
	fmt.Printf("Error: %v\n", err)
}
fmt.Printf("Applying %v to %v:\n", snd, retExpr)
fmt.Printf("\t%v @ %v ↠ %v\n", retExpr, snd, retExpr2)

// bad axes
bad2nd := Axes{0, 2, 1, 3} // not the original axes {0,1,3,2}
_, err = InferApp(retExpr, bad2nd)
fmt.Printf("Bad Axes causes error: %v\n", err)

// bad first input
bad1st := Shape{2, 3, 4}
_, err = InferApp(transpose, bad1st)
fmt.Printf("Bad first input causes error: %v", err)
Output:

Unconstrained Transpose: a → X[0 1 3 2] → T⁽⁰ ¹ ³ ²⁾ a
Transpose: { a → X[0 1 3 2] → T⁽⁰ ¹ ³ ²⁾ a | (D X[0 1 3 2] = D a) }
Applying (1, 2, 3, 4) to { a → X[0 1 3 2] → T⁽⁰ ¹ ³ ²⁾ a | (D X[0 1 3 2] = D a) }:
	{ a → X[0 1 3 2] → T⁽⁰ ¹ ³ ²⁾ a | (D X[0 1 3 2] = D a) } @ (1, 2, 3, 4) ↠ X[0 1 3 2] → (1, 2, 4, 3)
Applying X[0 1 3 2] to X[0 1 3 2] → (1, 2, 4, 3):
	X[0 1 3 2] → (1, 2, 4, 3) @ X[0 1 3 2] ↠ (1, 2, 4, 3)
Bad Axes causes error: Failed to solve [{X[0 1 3 2] → (1, 2, 4, 3) = X[0 2 1 3] → a}] | a: Unification Fail. X[0 1 3 2] ~ X[0 2 1 3] cannot proceed
Bad first input causes error: SubjectTo (D X[0 1 3 2] = D (2, 3, 4)) resolved to false. Cannot continue
Example (Transpose_NoOp)
axes := Axes{0, 1, 2, 3} // when the axes are monotonic, there is NoOp.
transpose := Arrow{
	Var('a'),
	TransposeOf{
		axes,
		Var('a'),
	},
}
fst := Shape{1, 2, 3, 4}
retExpr, err := InferApp(transpose, fst)
if err != nil {
	fmt.Printf("Error %v\n", err)
}
fmt.Printf("Applying %v to %v\n", fst, transpose)
fmt.Printf("\t%v @ %v ↠ %v", transpose, fst, retExpr)
Output:

Applying (1, 2, 3, 4) to a → T⁽⁰ ¹ ² ³⁾ a
	a → T⁽⁰ ¹ ² ³⁾ a @ (1, 2, 3, 4) ↠ (1, 2, 3, 4)
Example (UnsafePermute)
pattern := []int{2, 1, 3, 4, 0}
x1 := []int{1, 2, 3, 4, 5}
x2 := []int{5, 4, 3, 2, 1}
fmt.Printf("Before:\nx1: %v\nx2: %v\n", x1, x2)

err := UnsafePermute(pattern, x1, x2)
if err != nil {
	fmt.Println(err)
}
fmt.Printf("After:\nx1: %v\nx2: %v\n\n", x1, x2)

// when patterns are monotonic and increasing, it is noop

pattern = []int{0, 1, 2, 3, 4}
err = UnsafePermute(pattern, x1, x2)
if _, ok := err.(NoOpError); ok {
	fmt.Printf("NoOp with %v:\nx1: %v\nx2: %v\n\n", pattern, x1, x2)
}

// special cases for 2 dimensions

x1 = x1[:2]
x2 = x2[:2]
fmt.Printf("Before:\nx1: %v\nx2: %v\n", x1, x2)
pattern = []int{1, 0} // the only valid pattern with 2 dimensions
if err := UnsafePermute(pattern, x1, x2); err != nil {
	fmt.Println(err)
}
fmt.Printf("After:\nx1: %v\nx2: %v\n\n", x1, x2)

// Bad patterns

// invalid axis
pattern = []int{2, 1}
err = UnsafePermute(pattern, x1, x2)
fmt.Printf("Invalid axis in pattern %v: %v\n", pattern, err)

// repeated axes
pattern = []int{1, 1}
err = UnsafePermute(pattern, x1, x2)
fmt.Printf("Repeated axes in pattern %v: %v\n", pattern, err)

// dimension mismatches
pattern = []int{1}
err = UnsafePermute(pattern, x1, x2)
fmt.Printf("Pattern %v has a smaller dimension than xs: %v\n", pattern, err)
Output:

Before:
x1: [1 2 3 4 5]
x2: [5 4 3 2 1]
After:
x1: [3 2 4 5 1]
x2: [3 4 2 1 5]

NoOp with [0 1 2 3 4]:
x1: [3 2 4 5 1]
x2: [3 4 2 1 5]

Before:
x1: [3 2]
x2: [3 4]
After:
x1: [2 3]
x2: [4 3]

Invalid axis in pattern [2 1]: Invalid axis 2 for ndarray with 2 dimensions.
Repeated axes in pattern [1 1]: repeated axis 1 in permutation pattern.
Pattern [1] has a smaller dimension than xs: Dimension mismatch. Expected 2. Got  1 instead.

Constants

This section is empty.

Variables

This section is empty.

Functions

func AreBroadcastable

func AreBroadcastable(a, b Shape) (err error)

AreBroadcastable checks that two shapes are mutually broadcastable.

func CheckSlice

func CheckSlice(s Slice, size int) error

CheckSlice checks a slice to see if it's sane

func IsMonotonicInts

func IsMonotonicInts(a []int) (monotonic bool, incr1 bool)

IsMonotonicInts returns true if the slice of ints is monotonically increasing. incr1 is also true if every element is exactly 1 greater than its predecessor.
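The two return values are easy to confuse, so here is a minimal stand-alone sketch of the behaviour described. The function name `isMonotonic` is hypothetical, and the sketch assumes that equal neighbours still count as monotonic; it is not the package's own implementation.

```go
package main

import "fmt"

// isMonotonic is a hypothetical re-implementation for illustration only.
// monotonic: no element is smaller than its predecessor
// (assumption: equal neighbours are allowed).
// incr1: every element is exactly one greater than its predecessor.
func isMonotonic(a []int) (monotonic, incr1 bool) {
	incr1 = true
	for i := 1; i < len(a); i++ {
		if a[i] < a[i-1] {
			return false, false
		}
		if a[i] != a[i-1]+1 {
			incr1 = false
		}
	}
	return true, incr1
}

func main() {
	for _, xs := range [][]int{{0, 1, 2, 3}, {0, 2, 5}, {2, 1, 0}} {
		m, i := isMonotonic(xs)
		fmt.Printf("%v: monotonic=%t incr1=%t\n", xs, m, i)
	}
}
```

A pattern such as {0, 1, 2, 3} satisfies both conditions, which is the case the transpose examples treat as a no-op.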

func SliceDetails

func SliceDetails(s Slice, size int) (start, end, step int, err error)

SliceDetails takes a slice and returns its start, end and step. It exists mainly to handle the nil Slice, which is equivalent to a[:]
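A rough sketch of the nil-handling idea follows. It assumes that a nil slice stands for the whole dimension (start 0, end size, step 1); the `slice` interface mirrors the three methods of Slice, while `rng` and `details` are hypothetical names used only here.

```go
package main

import (
	"errors"
	"fmt"
)

// slice mirrors the three-method Slice interface documented in this package.
type slice interface {
	Start() int
	End() int
	Step() int
}

// rng is a hypothetical concrete slice used only for this sketch.
type rng struct{ start, end, step int }

func (r rng) Start() int { return r.start }
func (r rng) End() int   { return r.end }
func (r rng) Step() int  { return r.step }

// details resolves a possibly-nil slice against a dimension of the given size.
// Assumption: a nil slice means "everything", i.e. [0:size] with step 1.
func details(s slice, size int) (start, end, step int, err error) {
	if s == nil {
		return 0, size, 1, nil
	}
	start, end, step = s.Start(), s.End(), s.Step()
	if start < 0 || end > size || start > end || step < 1 {
		return 0, 0, 0, errors.New("slice out of range")
	}
	return start, end, step, nil
}

func main() {
	fmt.Println(details(nil, 10))          // 0 10 1 <nil>
	fmt.Println(details(rng{1, 5, 1}, 10)) // 1 5 1 <nil>
}
```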

func UnsafePermute

func UnsafePermute(pattern []int, xs ...[]int) (err error)

UnsafePermute permutes the xs according to the pattern. Each x in xs must have the same length as the pattern's length.
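The permutation rule can be read off the UnsafePermute example: after the call, position i holds the value that used to sit at position pattern[i]. The sketch below reproduces that rule with a temporary buffer. The name `permute` is hypothetical, and the validation order and error messages are assumptions rather than the package's behaviour.

```go
package main

import (
	"errors"
	"fmt"
)

// permute is a hypothetical, allocation-based sketch; it is not UnsafePermute itself.
// After the call, x[i] holds the value that was previously at x[pattern[i]].
func permute(pattern []int, xs ...[]int) error {
	// Every axis must be in range and appear exactly once.
	seen := make([]bool, len(pattern))
	for _, p := range pattern {
		if p < 0 || p >= len(pattern) {
			return fmt.Errorf("invalid axis %d", p)
		}
		if seen[p] {
			return fmt.Errorf("repeated axis %d", p)
		}
		seen[p] = true
	}
	// Every slice must be as long as the pattern.
	for _, x := range xs {
		if len(x) != len(pattern) {
			return errors.New("dimension mismatch")
		}
	}
	for _, x := range xs {
		tmp := make([]int, len(x))
		for i, p := range pattern {
			tmp[i] = x[p]
		}
		copy(x, tmp)
	}
	return nil
}

func main() {
	x1 := []int{1, 2, 3, 4, 5}
	x2 := []int{5, 4, 3, 2, 1}
	if err := permute([]int{2, 1, 3, 4, 0}, x1, x2); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(x1, x2) // [3 2 4 5 1] [3 4 2 1 5]
}
```

Unlike the documented function, this sketch does not report an identity pattern as a no-op; it simply copies each slice onto itself.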

func Verify

func Verify(any interface{}) error

Verify checks that the values of a data structure conform to the expected shape, given in struct tags.

Example - consider the following struct:

type Foo struct {
	A tensor.Tensor `shape:"(a, b)"`
	B tensor.Tensor `shape:"(b, c)"`
}

At runtime, the fields A and B would be populated with a Tensor of arbitrary shape. Verify can verify that A and B have the expected patterns of shapes.

So, if A has a shape of (2, 3), B's shape cannot be (4, 5). It can be (3, 5).
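The reasoning behind that rule is that a variable such as b must stand for the same size everywhere it appears. A minimal sketch of this binding check, without struct tags or reflection, might look as follows; `check` and `env` are hypothetical names and this is not how Verify is implemented.

```go
package main

import "fmt"

// check is a hypothetical helper, not part of this package. It matches a
// pattern of variable names against a concrete shape and records each
// variable's size in env. A variable that is already bound must keep its size.
func check(env map[string]int, pattern []string, shape []int) error {
	if len(pattern) != len(shape) {
		return fmt.Errorf("pattern %v and shape %v differ in rank", pattern, shape)
	}
	for i, name := range pattern {
		if bound, ok := env[name]; ok && bound != shape[i] {
			return fmt.Errorf("%s is already %d, cannot also be %d", name, bound, shape[i])
		}
		env[name] = shape[i]
	}
	return nil
}

func main() {
	env := map[string]int{}
	fmt.Println(check(env, []string{"a", "b"}, []int{2, 3})) // <nil>
	fmt.Println(check(env, []string{"b", "c"}, []int{4, 5})) // b is already 3, cannot also be 4
	fmt.Println(check(env, []string{"b", "c"}, []int{3, 5})) // <nil>
}
```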

Example
package main

import (
	"fmt"

	. "gorgonia.org/shapes"
)

type T int

func (t T) Shape() Shape {
	switch t {
	case 0:
		return Shape{5, 4, 2}
	case 1:
		return Shape{4, 2}
	case 2:
		return Shape{2, 10}
	default:
		return Shape{10, 10, 10, 10} // bad shape
	}

}

func main() {
	// the actual field type doesn't matter for now
	type A struct {
		A T `shape:"(a, b, c)"`
		B T `shape:"(b, c)"`
	}

	type B struct {
		A
		C T `shape:"(c, d)"`
	}

	a1 := A{0, 1}
	a2 := A{0, 100}

	if err := Verify(a1); err != nil {
		fmt.Printf("a1 is a correct value. No errors expected. Got %v instead\n", err)
	}
	err := Verify(a2)
	if err == nil {
		fmt.Printf("a2 is an incorrect value. Errors expected but none was returned\n")
	}

}
Output:

Types

type Abstract

type Abstract []Sizelike

Abstract is an abstract shape

func Gen

func Gen(d int) (retVal Abstract)

Gen creates an abstract with the provided dims. This is particularly useful for generating abstracts for higher order functions.

Example

Gen is a generator for Abstracts

a := Gen(2)
fmt.Printf("Gen(2): %v\n", a)

// Gen is not a stateful generator
b := Gen(2)
fmt.Printf("Gen(2): %v\n", b)

// Gen handles a maximum of 50 characters (so far)
c := Gen(50)
fmt.Printf("Gen(50): %v\n", c)

defer func() {
	if r := recover(); r != nil {
		fmt.Println("Gen will panic if a `d` >= 51 is passed in.")
	}
}()
Gen(51)
Output:

Gen(2): (a, b)
Gen(2): (a, b)
Gen(50): (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, α, β, γ, δ, ε, ζ, η, θ, ι, κ, λ, μ, ν, ξ, ο, π, ρ, ς, σ, τ, υ, φ, χ, ψ)
Gen will panic if a `d` >= 51 is passed in.
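The limit of 50 follows from the naming scheme visible in the output above: 26 Latin letters followed by 24 Greek letters. A small sketch of such a scheme, assuming that is all there is to it, is given below. `genNames` is a hypothetical name, and it returns an error where Gen panics.

```go
package main

import "fmt"

// genNames is a hypothetical sketch of the naming scheme seen in the output:
// 26 Latin letters followed by 24 Greek letters (α through ψ, which are
// contiguous code points in Unicode). It returns an error instead of panicking.
func genNames(d int) ([]rune, error) {
	if d < 0 || d > 50 {
		return nil, fmt.Errorf("cannot generate %d names; the limit is 50", d)
	}
	names := make([]rune, 0, d)
	for i := 0; i < d; i++ {
		if i < 26 {
			names = append(names, 'a'+rune(i))
		} else {
			names = append(names, 'α'+rune(i-26))
		}
	}
	return names, nil
}

func main() {
	names, _ := genNames(3)
	fmt.Println(string(names)) // abc
	_, err := genNames(51)
	fmt.Println(err != nil) // true
}
```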

func (Abstract) Clone

func (a Abstract) Clone() Abstract

func (Abstract) Concat

func (a Abstract) Concat(axis Axis, others ...Shapelike) (newShape Shapelike, err error)

func (Abstract) Cons

func (a Abstract) Cons(other Conser) (retVal Conser)

func (Abstract) DimSize

func (a Abstract) DimSize(dim int) (Sizelike, error)

func (Abstract) Dims

func (a Abstract) Dims() int

Dims returns the number of dimensions in the shape

func (Abstract) Format

func (s Abstract) Format(st fmt.State, r rune)

Format implements fmt.Formatter, and formats a shape nicely

func (Abstract) Repeat

func (a Abstract) Repeat(axis Axis, repeats ...int) (retVal Shapelike, finalRepeats []int, size int, err error)

func (Abstract) S

func (a Abstract) S(slices ...Slice) (newShape Shapelike, err error)

func (Abstract) T

func (a Abstract) T(axes ...Axis) (newShape Shapelike, err error)

func (Abstract) ToShape

func (a Abstract) ToShape() (s Shape, ok bool)

func (Abstract) TotalSize

func (a Abstract) TotalSize() int

type Arrow

type Arrow struct {
	A, B Expr
}

Arrow represents a function of shapes, from A → B. Arrows are right associative.

func MakeArrow

func MakeArrow(exprs ...Expr) Arrow

MakeArrow is a utility function for writing correct Arrow expressions.

Consider for example, matrix multiplication. It is written plainly as follows:

MatMul: (a, b) → (b, c) → (a, c)

However, note that because arrows are right associative and they're a binary operator, it actually is more correctly written like this:

MatMul: (a, b) → ((b, c) → (a, c))

This makes writing plain Arrow{} expressions a bit fraught with errors (see example). Thus, the MakeArrow function is created to help write more correct expressions.

Example

MakeArrow is a utility function

matmul := MakeArrow(
	Abstract{Var('a'), Var('b')},
	Abstract{Var('b'), Var('c')},
	Abstract{Var('a'), Var('c')},
)
fmt.Printf("Correct MatMul: %v\n", matmul)

wrong := Arrow{
	Arrow{Abstract{Var('a'), Var('b')},
		Abstract{Var('b'), Var('c')},
	},
	Abstract{Var('a'), Var('c')},
}
fmt.Printf("Wrong MatMul: %v\n", wrong)

// it doesn't mean that you should use MakeArrow mindlessly.
// Consider the higher order function Map: (a → a) → b → b
// p.s equiv Go function signature:
// 	func Map(f func(a int) int, b Tensor) Tensor
Map := MakeArrow(
	Arrow{Var('a'), Var('a')}, // you can also use MakeArrow here
	Var('b'),
	Var('b'),
)
fmt.Printf("Correct Map: %v\n", Map)

wrong = MakeArrow(
	Var('a'), Var('a'),
	Var('b'),
	Var('b'),
)
fmt.Printf("Wrong Map: %v\n", wrong)
Output:

Correct MatMul: (a, b) → (b, c) → (a, c)
Wrong MatMul: ((a, b) → (b, c)) → (a, c)
Correct Map: (a → a) → b → b
Wrong Map: a → a → b → b
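To make the grouping explicit, the following stand-alone sketch builds arrows by folding the arguments from the right and prints them fully parenthesised. The types `expr` and `arrow` and the functions `makeArrow` and `show` are hypothetical miniatures, not the package's types.

```go
package main

import "fmt"

// expr and arrow are hypothetical miniature types for this sketch only.
type expr interface{}

type arrow struct{ a, b expr }

// makeArrow folds its arguments from the right, so x, y, z becomes x → (y → z).
// It needs at least two arguments.
func makeArrow(es ...expr) arrow {
	if len(es) < 2 {
		panic("makeArrow needs at least two expressions")
	}
	acc := arrow{es[len(es)-2], es[len(es)-1]}
	for i := len(es) - 3; i >= 0; i-- {
		acc = arrow{es[i], acc}
	}
	return acc
}

// show prints every arrow fully parenthesised so the grouping is visible.
func show(e expr) string {
	if ar, ok := e.(arrow); ok {
		return "(" + show(ar.a) + " → " + show(ar.b) + ")"
	}
	return fmt.Sprint(e)
}

func main() {
	fmt.Println(show(makeArrow("(a, b)", "(b, c)", "(a, c)")))
	// ((a, b) → ((b, c) → (a, c)))
}
```

Nesting a literal on the left, as in the "wrong" MatMul, yields ((x → y) → z) instead, which describes a function that takes a function as its argument.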

func (Arrow) Format

func (a Arrow) Format(s fmt.State, r rune)

type Axes

type Axes []Axis

Axes represents a list of axes. Despite being a container type (i.e. an Axis is an Expr), it returns nil for Exprs(). This is because we want to treat Axes as a monolithic entity.

func (Axes) AsInts

func (a Axes) AsInts() []int

func (Axes) Dims

func (a Axes) Dims() int

func (Axes) Eq

func (a Axes) Eq(other Axes) bool

func (Axes) Format

func (a Axes) Format(s fmt.State, r rune)

type Axis

type Axis int

Axis represents an axis in doing shape stuff.

const (
	AllAxes Axis = -65535
)

func ResolveAxis

func ResolveAxis(a Axis, s Shapelike) Axis

func (Axis) Format

func (a Axis) Format(s fmt.State, c rune)

type BinOp

type BinOp struct {
	Op OpType
	A  Expr
	B  Expr
}

BinOp represents a binary operation. It is only an Expr for the purposes of being a value in a shape. At the top level, a BinOp on its own is meaningless. This is enforced in the `unify` function.

func (BinOp) Format

func (op BinOp) Format(s fmt.State, r rune)

Format formats the BinOp into a nice string.

type BroadcastOf

type BroadcastOf struct {
	A, B Expr
}

BroadcastOf represents the results of mutually broadcasting A and B expr.

func (BroadcastOf) Format

func (b BroadcastOf) Format(s fmt.State, r rune)

type Compound

type Compound struct {
	Expr
	SubjectTo
}

func (Compound) Format

func (c Compound) Format(s fmt.State, r rune)

type ConcatOf

type ConcatOf struct {
	Along Axis
	A, B  Expr
}

func (ConcatOf) Format

func (c ConcatOf) Format(s fmt.State, r rune)

type Conser

type Conser interface {
	Cons(Conser) Conser
	// contains filtered or unexported methods
}

Conser is anything that can be used to construct another Conser. The following types are Conser:

Shape | Abstract

type ConstraintsExpr

type ConstraintsExpr struct {
	// contains filtered or unexported fields
}

ConstraintsExpr is a tuple of a list of constraints and an expression.

func App

func App(ar Expr, b Expr) ConstraintsExpr

App applies an expression to a function/Arrow expression. This function will aggressively perform alpha renaming on the expression.

Example. Given an application of the following:

((a,b) → (b, c) → (a, c)) @ (2, a)

The variables in the latter will be aggressively renamed, to become:

((a,b) → (b, c) → (a, c)) @ (2, d)

Normally this wouldn't be a concern, as you would be passing in concrete shapes, something like:

((a,b) → (b, c) → (a, c)) @ (2, 3)

which will then yield:

(3, c) → (2, c)
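The last step can be pictured as binding the variables of the first parameter to the concrete sizes and substituting them into the rest of the signature. The sketch below does exactly that on plain strings. It leaves out alpha renaming and constraints entirely, and `apply` and `show` are hypothetical names, not the package's App or Infer.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// apply is a hypothetical sketch of function application on shape signatures.
// sig lists the parameter patterns followed by the result pattern; every
// pattern is a list of variable names. arg is a concrete shape for the first
// parameter. Its sizes are bound to the variables of the first pattern and
// substituted into the remaining patterns.
func apply(sig [][]string, arg []int) ([][]string, error) {
	if len(sig) < 2 || len(sig[0]) != len(arg) {
		return nil, fmt.Errorf("cannot apply %v to %v", sig, arg)
	}
	env := map[string]string{}
	for i, name := range sig[0] {
		size := strconv.Itoa(arg[i])
		if old, ok := env[name]; ok && old != size {
			return nil, fmt.Errorf("%s cannot be both %s and %s", name, old, size)
		}
		env[name] = size
	}
	rest := make([][]string, 0, len(sig)-1)
	for _, pat := range sig[1:] {
		out := make([]string, len(pat))
		for i, name := range pat {
			out[i] = name
			if size, ok := env[name]; ok {
				out[i] = size
			}
		}
		rest = append(rest, out)
	}
	return rest, nil
}

// show renders a signature such as (3, c) → (2, c).
func show(sig [][]string) string {
	parts := make([]string, len(sig))
	for i, pat := range sig {
		parts[i] = "(" + strings.Join(pat, ", ") + ")"
	}
	return strings.Join(parts, " → ")
}

func main() {
	matmul := [][]string{{"a", "b"}, {"b", "c"}, {"a", "c"}}
	rest, err := apply(matmul, []int{2, 3})
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(show(rest)) // (3, c) → (2, c)
}
```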

func (ConstraintsExpr) Format

func (ce ConstraintsExpr) Format(f fmt.State, r rune)

type E2

type E2 struct{ BinOp }

type Expr

type Expr interface {
	// contains filtered or unexported methods
}

Expr represents an expression. The following types are Expr:

Shape | Abstract | Arrow | Compound
Var | Size | UnaryOp
IndexOf | TransposeOf | SliceOf | RepeatOf | ConcatOf
Sli | Axis | Axes

A compact BNF is as follows:

E := S | A | E → E | (E s.t. X)
a | Sz | Π E | Σ E | D E
I n E | T []Ax E | L : E | R Ax n E | C Ax E E
: | Ax | []Ax

func Infer

func Infer(ce ConstraintsExpr) (Expr, error)

func InferApp

func InferApp(a Expr, others ...Expr) (retVal Expr, err error)

func Parse

func Parse(a string) (retVal Expr, err error)

Parse parses a string and returns a shape expression.

func ShapeOf

func ShapeOf(a interface{}) Expr

ShapeOf returns the shape of a given datum.

type Exprer

type Exprer interface {
	Shape() Expr
}

Exprer is anything that can return a Shape Expr.

type IndexOf

type IndexOf struct {
	I Size
	A Expr
}

IndexOf gets the size of a given shape (expression) at the given index.

IndexOf is the symbolic version of doing s[i], where s is a Shape.

func (IndexOf) Format

func (i IndexOf) Format(s fmt.State, r rune)

type NoOpError

type NoOpError interface {
	NoOp()
}

NoOpError is useful for signalling that an operation turned out to be a no-op.
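A caller typically wants to tell "nothing needed doing" apart from a real failure. The sketch below shows one way to do that with a marker method, using the standard library's errors.As. The interface `noOper` mirrors the NoOp method documented above; `identityErr`, `permuteCheck` and `isNoOp` are hypothetical.

```go
package main

import (
	"errors"
	"fmt"
)

// noOper mirrors the single-method NoOpError interface documented above.
type noOper interface{ NoOp() }

// identityErr is a hypothetical error type that also satisfies noOper.
type identityErr struct{}

func (identityErr) Error() string { return "identity permutation: nothing to do" }
func (identityErr) NoOp()         {}

// permuteCheck is a hypothetical operation that reports a no-op through its
// error value when the pattern is the identity.
func permuteCheck(pattern []int) error {
	for i, p := range pattern {
		if i != p {
			return nil // real work would happen here
		}
	}
	return identityErr{}
}

// isNoOp reports whether err signals a no-op rather than a failure.
func isNoOp(err error) bool {
	var n noOper
	return errors.As(err, &n)
}

func main() {
	fmt.Println(isNoOp(permuteCheck([]int{0, 1, 2}))) // true
	fmt.Println(isNoOp(permuteCheck([]int{1, 0})))    // false
}
```

A plain type assertion, as used in the UnsafePermute example, works as well; errors.As additionally sees through wrapped errors.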

type OpType

type OpType byte

OpType represents the type of operation that is being performed

const (
	// Unary: a → b
	Const OpType = iota // K
	Dims
	Prod
	Sum
	ForAll // special use
	// Binary: a → a → a
	Add
	Sub
	Mul
	Div
	// Cmp: a → a → Bool
	Eq
	Ne
	Lt
	Gt
	Lte
	Gte
	// Logic: bool → bool → bool
	And
	Or

	// Broadcast
	Bc
)

func (OpType) String

func (o OpType) String() string

String returns the string representation

type Operation

type Operation interface {
	// contains filtered or unexported methods
}

Operation represents an operation (BinOp or UnaryOp)

type Range

type Range struct {
	// contains filtered or unexported fields
}

Range is a shape expression representing a slicing range. Coincidentally, Range also implements Slice.

A Range is a shape expression but it doesn't stand alone - resolving it will yield an error.

func S

func S(start int, opt ...int) *Range

S creates a Slice. Internally it uses the Range type provided.

func (Range) End

func (s Range) End() int

End returns the end of the slicing range

func (Range) Format

func (s Range) Format(st fmt.State, r rune)

Format allows Range to implement fmt.Formatter

func (Range) Start

func (s Range) Start() int

Start returns the start of the slicing range

func (Range) Step

func (s Range) Step() int

Step returns the steps/jumps to make in the slicing range.

type ReductOf

type ReductOf struct {
	A     Expr
	Along Axis
}

ReductOf represents the results of reducing A along the given axis.

func Reduce

func Reduce(a Expr, along Axes) ReductOf

func (ReductOf) Format

func (r ReductOf) Format(s fmt.State, c rune)

type RepeatOf

type RepeatOf struct {
	Along   Axis
	Repeats []Size
	A       Expr
}

func (RepeatOf) Format

func (r RepeatOf) Format(s fmt.State, ru rune)

type Shape

type Shape []int

Shape represents the shape of a multidimensional array.

func CalcBroadcastShape

func CalcBroadcastShape(a, b Shape) Shape

CalcBroadcastShape computes the final shape of two mutually broadcastable shapes. This function does not check that the shapes are mutually broadcastable. Use `AreBroadcastable` for that functionality.
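The split between checking and calculating can be illustrated with a small stand-alone sketch. It assumes that both shapes have the same rank and that a size of 1 stretches to match the other side; whether this package also aligns shapes of different ranks is not stated here. `broadcastable` and `broadcastShape` are hypothetical names.

```go
package main

import (
	"errors"
	"fmt"
)

// broadcastable is a hypothetical stand-in for a check like AreBroadcastable.
// Assumption: both shapes have the same rank, and two sizes are compatible
// when they are equal or when one of them is 1.
func broadcastable(a, b []int) error {
	if len(a) != len(b) {
		return errors.New("rank mismatch")
	}
	for i := range a {
		if a[i] != b[i] && a[i] != 1 && b[i] != 1 {
			return fmt.Errorf("axis %d: sizes %d and %d are incompatible", i, a[i], b[i])
		}
	}
	return nil
}

// broadcastShape is a hypothetical stand-in for a calculation like
// CalcBroadcastShape. It performs no validation and simply takes the larger
// size on each axis, so callers are expected to run broadcastable first.
func broadcastShape(a, b []int) []int {
	out := make([]int, len(a))
	for i := range a {
		out[i] = a[i]
		if b[i] > out[i] {
			out[i] = b[i]
		}
	}
	return out
}

func main() {
	a, b := []int{2, 1, 4}, []int{1, 3, 4}
	if err := broadcastable(a, b); err != nil {
		fmt.Println("not broadcastable:", err)
		return
	}
	fmt.Println(broadcastShape(a, b)) // [2 3 4]
}
```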

func ScalarShape

func ScalarShape() Shape

ScalarShape returns a shape that represents a scalar shape.

Usually `nil` will also be considered a scalar shape (because a `nil` of type `Shape` has a length of 0 and will return true when `.IsScalar` is called)

func ToShape

func ToShape(a Expr) (Shape, error)

func (Shape) AsInts

func (s Shape) AsInts() []int

func (Shape) Clone

func (s Shape) Clone() Shape

func (Shape) Concat

func (s Shape) Concat(axis Axis, ss ...Shapelike) (retVal Shapelike, err error)

func (Shape) Cons

func (s Shape) Cons(other Conser) Conser

Cons is an associative construction of shapes

func (Shape) Dim

func (s Shape) Dim(d int) (retVal int, err error)

Dim returns the dimension wanted.

func (Shape) DimSize

func (s Shape) DimSize(d int) (retVal Sizelike, err error)

DimSize returns the dimension wanted.

func (Shape) Dims

func (s Shape) Dims() int

Dims returns the number of dimensions in the shape

func (Shape) Eq

func (s Shape) Eq(other Shape) bool

Eq indicates if a shape is equal with another. There is a soft concept of equality when it comes to vectors.

If s is a column vector and other is a vanilla vector, they're considered equal if the size of the column dimension is the same as the vector size; if s is a row vector and other is a vanilla vector, they're considered equal if the size of the row dimension is the same as the vector size
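As a rough illustration of this soft equality, the sketch below compares a column vector (x, 1) or a row vector (1, x) with a vanilla vector (x) by their single meaningful size, and falls back to strict equality otherwise. `softEq` is a hypothetical name, and the fallback behaviour is an assumption.

```go
package main

import "fmt"

// softEq is a hypothetical sketch of the "soft" vector equality described above.
// Assumption: a column vector is (x, 1), a row vector is (1, x), and a
// vanilla vector is (x).
func softEq(s, other []int) bool {
	// Column or row vector against a vanilla vector: compare the vector size.
	if len(s) == 2 && len(other) == 1 {
		switch {
		case s[1] == 1: // column vector (x, 1)
			return s[0] == other[0]
		case s[0] == 1: // row vector (1, x)
			return s[1] == other[0]
		}
		return false
	}
	// Otherwise fall back to strict element-wise equality.
	if len(s) != len(other) {
		return false
	}
	for i := range s {
		if s[i] != other[i] {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(softEq([]int{3, 1}, []int{3})) // true
	fmt.Println(softEq([]int{1, 3}, []int{3})) // true
	fmt.Println(softEq([]int{2, 3}, []int{3})) // false
}
```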

func (Shape) Format

func (s Shape) Format(st fmt.State, r rune)

Format implements fmt.Formatter, and formats a shape nicely

func (Shape) IsColVec

func (s Shape) IsColVec() bool

IsColVec returns true when the access pattern has the shape (x, 1)

func (Shape) IsMatrix

func (s Shape) IsMatrix() bool

IsMatrix returns true if it's a matrix. This is mostly a convenience method. RowVec and ColVecs are also considered matrices

func (Shape) IsRowVec

func (s Shape) IsRowVec() bool

IsRowVec returns true when the access pattern has the shape (1, x)

func (Shape) IsScalar

func (s Shape) IsScalar() bool

IsScalar returns true if the access pattern indicates it's a scalar value

func (Shape) IsScalarEquiv

func (s Shape) IsScalarEquiv() bool

IsScalarEquiv returns true if the access pattern indicates it's a scalar-like value

Example
s := Shape{1, 1, 1, 1, 1, 1}
fmt.Printf("%v is scalar equiv: %t\n", s, s.IsScalarEquiv())

s = Shape{}
fmt.Printf("%v is scalar equiv: %t\n", s, s.IsScalarEquiv())

s = Shape{2, 3}
fmt.Printf("%v is scalar equiv: %t\n", s, s.IsScalarEquiv())

s = Shape{0, 0, 0}
fmt.Printf("%v is scalar equiv: %t\n", s, s.IsScalarEquiv())

s = Shape{1, 2, 0, 3}
fmt.Printf("%v is scalar equiv: %t\n", s, s.IsScalarEquiv())
Output:

(1, 1, 1, 1, 1, 1) is scalar equiv: true
() is scalar equiv: true
(2, 3) is scalar equiv: false
(0, 0, 0) is scalar equiv: true
(1, 2, 0, 3) is scalar equiv: false

func (Shape) IsVector

func (s Shape) IsVector() bool

IsVector returns whether the access pattern falls into one of three possible definitions of vectors:

vanilla vector (not a row or a col)
column vector
row vector

func (Shape) IsVectorLike

func (s Shape) IsVectorLike() bool

IsVectorLike returns true when the shape looks like a vector e.g. a number that is surrounded by 1s:

(1, 1, ... 1, 10, 1, 1... 1)

func (Shape) Repeat

func (s Shape) Repeat(axis Axis, repeats ...int) (retVal Shapelike, finalRepeats []int, size int, err error)

Repeat returns the expected new shape given the repetition parameters

func (Shape) S

func (s Shape) S(slices ...Slice) (newShape Shapelike, err error)

S gives the new shape after a shape has been sliced.
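For a half-open range with a positive step, each axis keeps ceil((end - start) / step) elements. The sketch below applies that arithmetic to a concrete shape. `rng` and `sliced` are hypothetical names, and the sketch ignores nil slices, missing trailing slices and any rule that drops axes.

```go
package main

import "fmt"

// rng is a hypothetical half-open range [start, end) with a positive step.
type rng struct{ start, end, step int }

// sliced returns the shape that results from applying one range per axis.
// Each axis keeps ceil((end-start)/step) elements.
func sliced(shape []int, rs ...rng) ([]int, error) {
	if len(rs) != len(shape) {
		return nil, fmt.Errorf("need %d ranges, got %d", len(shape), len(rs))
	}
	out := make([]int, len(shape))
	for i, r := range rs {
		if r.step < 1 || r.start < 0 || r.start > r.end || r.end > shape[i] {
			return nil, fmt.Errorf("range %d:%d:%d is invalid for axis %d of size %d",
				r.start, r.end, r.step, i, shape[i])
		}
		// Integer ceiling division.
		out[i] = (r.end - r.start + r.step - 1) / r.step
	}
	return out, nil
}

func main() {
	got, err := sliced([]int{10, 20, 30, 20},
		rng{1, 5, 1}, rng{1, 5, 1}, rng{1, 5, 1}, rng{2, 5, 1})
	fmt.Println(got, err) // [4 4 4 3] <nil>
}
```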

func (Shape) Shape

func (s Shape) Shape() Shape

Shape implements Shaper - it returns itself

func (Shape) T

func (s Shape) T(axes ...Axis) (newShape Shapelike, err error)

func (Shape) TotalSize

func (s Shape) TotalSize() int

type Shapelike

type Shapelike interface {
	Dims() int
	TotalSize() int // needed?
	DimSize(dim int) (Sizelike, error)
	T(axes ...Axis) (newShape Shapelike, err error)
	S(slices ...Slice) (newShape Shapelike, err error)
	Repeat(axis Axis, repeats ...int) (newShape Shapelike, finalRepeats []int, size int, err error)
	Concat(axis Axis, others ...Shapelike) (newShape Shapelike, err error)
}

Shapelike is anything that performs all the things you can do with a Shape. The following types provided by this library are Shapelike:

Shape | Abstract

func ShapesToShapelikes

func ShapesToShapelikes(ss []Shape) []Shapelike

ShapesToShapelikes is a utility function that returns a []Shapelike given a []Shape.

type Shaper

type Shaper interface {
	Shape() Shape
}

Shaper is anything that can return a Shape.

type Size

type Size int

Size represents a size of a dimension/axis

func (Size) Format

func (s Size) Format(f fmt.State, r rune)

type Sizelike

type Sizelike interface {
	// contains filtered or unexported methods
}

Sizelike represents something that can go into an Abstract. The following types are Sizelike:

Size | Var | BinOp | UnaryOp

type Sizes

type Sizes []Size

Sizes are a list of sizes.

func (Sizes) AsInts

func (s Sizes) AsInts() []int

func (Sizes) Format

func (s Sizes) Format(f fmt.State, r rune)

type Slice

type Slice interface {
	Start() int
	End() int
	Step() int
}

Slice represents a slicing range.

Example (S)
param0 := Abstract{Var('a'), Var('b')}
param1 := Abstract{Var('a'), Var('b'), BinOp{Add, Var('a'), Var('b')}, UnaryOp{Const, Var('b')}}
expected, err := param1.S(S(1, 5), S(1, 5), S(1, 5), S(2, 5))
if err != nil {
	fmt.Printf("Err %v\n", err)
	return
}
expr := MakeArrow(param0, param1, expected.(Expr))
fmt.Printf("expr: %v\n", expr)

fst := Shape{10, 20}
result, err := InferApp(expr, fst)
if err != nil {
	fmt.Printf("Err %v\n", err)
	return
}
fmt.Printf("%v @ %v ↠ %v\n", expr, fst, result)

snd := Shape{10, 20, 30, 20}
result2, err := InferApp(result, snd)
if err != nil {
	fmt.Printf("Err %v\n", err)
	return
}
fmt.Printf("%v @ %v ↠ %v", result, snd, result2)
Output:

expr: (a, b) → (a, b, a + b, K b) → (a[1:5], b[1:5], a + b[1:5], K b[2:5])
(a, b) → (a, b, a + b, K b) → (a[1:5], b[1:5], a + b[1:5], K b[2:5]) @ (10, 20) ↠ (10, 20, 30, 20) → (4, 4, 4, 3)
(10, 20, 30, 20) → (4, 4, 4, 3) @ (10, 20, 30, 20) ↠ (4, 4, 4, 3)

type SliceOf

type SliceOf struct {
	Slice Slicelike
	A     Expr
}

SliceOf is an intrinsic operation, symbolically representing a slicing operation.

func (SliceOf) Format

func (s SliceOf) Format(st fmt.State, r rune)

type Slicelike

type Slicelike interface {
	// contains filtered or unexported methods
}

Slicelike is anything like a slice. The following types implement Slicelike:

Range | Var

func ToSlicelike

func ToSlicelike(s Slice) Slicelike

ToSlicelike is a utility function for turning a slice into a Slicelike.

type Slices

type Slices []Slice

Slices is a list of slices.

func (Slices) Format

func (ss Slices) Format(st fmt.State, r rune)

type SubjectTo

type SubjectTo struct {
	OpType
	A, B Operation
}

SubjectTo describes a constraint

func (SubjectTo) Format

func (s SubjectTo) Format(st fmt.State, r rune)

type TransposeOf

type TransposeOf struct {
	Axes Axes
	A    Expr
}

TransposeOf is the symbolic version of doing s.T(axes...)

func (TransposeOf) Format

func (t TransposeOf) Format(s fmt.State, r rune)

type UnaryOp

type UnaryOp struct {
	Op OpType
	A  Expr
}

UnaryOp represents a unary operation on a shape expression. Unlike BinOp, UnaryOp is an expression.

func (UnaryOp) Format

func (op UnaryOp) Format(s fmt.State, r rune)

Format makes UnaryOp implement fmt.Formatter.

type Var

type Var rune

Var represents a variable. A variable can represent:

  • a Shape (e.g. a in the expression a → b)
  • a Size (e.g. (a, b))
  • a Slice (in Arrow expressions)

func (Var) Format

func (v Var) Format(s fmt.State, r rune)
