astw

v0.3.0
Published: Jun 28, 2021 License: MIT Imports: 5 Imported by: 0

README

Astw - Enhanced abstract-syntax-tree walker for Go

The Go standard library includes the package go/ast, which defines functions and types for understanding parsed Go code. Among them is the function Walk, which calls a visitor’s Visit method on each Node in a syntax tree.

For many applications, the visitor needs more context than just the node it’s currently visiting. Implementations therefore typically maintain a stack of the nodes being visited and distinguish the recursive visit of one subnode from that of another.

This package, astw, provides an enhanced Walk function with a Visitor based on callbacks, one for each syntax-tree node type. It tracks the extra details that syntax-tree walkers typically need.

Each callback for a given node is called twice: once before visiting its children (the “pre” visit) and once after (the “post” visit). This is true even for nodes that have no children.

Each callback receives as arguments:

  • The node being visited;
  • An enumerated constant describing which child of its parent this is;
  • A slice index, in case this node is in a slice of its parent’s children;
  • The stack of nodes above this node in the syntax tree;
  • A boolean telling whether this is a “pre” (true) visit or a “post” (false) visit;
  • The error, if any, produced by visiting the node’s children (always nil during “pre” visits).

Callbacks are never invoked on nil nodes.

In a “post” visit, when the received error is non-nil, a callback may decide whether and how to propagate the error to the caller. A node callback should typically begin with

if err != nil {
  return err
}

unless it wants to ignore, decorate, or otherwise alter errors in its subtree.

Nodes in a Go syntax tree have concrete types like *ast.IfStmt and *ast.BinaryExpr. Each concrete type has its own callback in Visitor: Visitor.IfStmt, Visitor.BinaryExpr, and so on.

Most node types also implement one of these abstract interfaces: ast.Expr, ast.Stmt, ast.Decl, and ast.Spec. Visitor has callbacks for these types too. If the concrete-type callback for a given node is set, then that’s used when the node is visited. On the other hand, if it’s not set but the abstract-type callback is, then that callback is used when the node is visited.

For example, if your program sets Visitor.BinaryExpr and Visitor.Expr, then your BinaryExpr callback will be called for every *ast.BinaryExpr node you visit, and your Expr callback will be called for every other ast.Expr node you visit.

All syntax node types implement the interface ast.Node. Visitor.Node is a callback for this catch-all type. If you define a Node callback then it is called for every node that doesn’t have a more-specific callback.

Detailed example

Consider this Go program fragment:

if x == 7 {
  y++
  return z
}

After parsing, this is represented by a *ast.IfStmt.

Now imagine this IfStmt is passed to this package’s Walk function, together with a Visitor with suitable callbacks defined. Here is the sequence of events that will occur, assuming all the relevant callbacks in the Visitor, v, are non-nil. (Callbacks that are nil are simply skipped.)

  1. v.IfStmt( the IfStmt, Top, 0, nil, true, nil )

This is the “pre” visit of the top IfStmt node. Because this node was reached directly via the Walk function, there is no information about its parent. So the Which value is Top and the stack is empty.

  2. v.BinaryExpr( the x==7 node, IfStmt_Cond, 0, [ the IfStmt ], true, nil )

Now the IfStmt’s condition subnode is visited via its concrete-type callback, v.BinaryExpr.

The Which value, IfStmt_Cond, tells which child of the IfStmt this is. (It’s the Cond field of the IfStmt type.)

Its parent, the IfStmt itself, is in the stack passed to v.BinaryExpr.

Note also that the IfStmt type includes an optional Init sub-statement, but this IfStmt doesn’t use one, so that callback is skipped.

  3. v.Ident( the x node, BinaryExpr_X, 0, [ the IfStmt, the x==7 node ], true, nil )

The x part of x==7 is visited.

  4. v.Ident( the x node, BinaryExpr_X, 0, [ the IfStmt, the x==7 node ], false, err )

An *ast.Ident has no children, so now it’s time for the “post” visit of the same node on the way out of this subtree.

The value of err in step 4 is whatever error was returned in step 3.

  5. v.BasicLit( the 7 node, BinaryExpr_Y, 0, [ the IfStmt, the x==7 node ], true, nil )
  6. v.BasicLit( the 7 node, BinaryExpr_Y, 0, [ the IfStmt, the x==7 node ], false, err )

The 7 is pre-visited and post-visited. It’s a “basic literal.”

  7. v.BinaryExpr( the x==7 node, IfStmt_Cond, 0, [ the IfStmt ], false, err )

Continuing to unwind the call stack, the x==7 node is now post-visited.

  8. v.BlockStmt( the { ... } node, IfStmt_Body, 0, [ the IfStmt ], true, nil )

The next child of the IfStmt, the body, is now pre-visited.

  9. v.IncDecStmt( the y++ node, BlockStmt_List, 0, [ the IfStmt, the { ... } node ], true, nil )

A BlockStmt contains a field, List, whose value is a slice of ast.Stmt. So now we visit each statement in the BlockStmt’s list, starting with the y++ statement.

The value of the index parameter, 0, which is irrelevant for child nodes that aren’t part of a slice, now tells us that this is the first element in the BlockStmt’s list.

  10. v.Ident( the y node, IncDecStmt_X, 0, [ the IfStmt, the { ... } node, the y++ node ], true, nil )
  11. v.Ident( the y node, IncDecStmt_X, 0, [ the IfStmt, the { ... } node, the y++ node ], false, err )

We descend into and then out of the sole child of the y++ node.

  12. v.IncDecStmt( the y++ node, BlockStmt_List, 0, [ the IfStmt, the { ... } node ], false, err )

Post-visiting the y++ node.

  13. v.ReturnStmt( the return z node, BlockStmt_List, 1, [ the IfStmt, the { ... } node ], true, nil )

We now visit the second child of the BlockStmt: the return z node.

The value of the index parameter, 1, tells us that this is the second element in the BlockStmt’s list.

  14. v.Ident( the z node, ReturnStmt_Results, 0, [ the IfStmt, the { ... } node, the return z node ], true, nil )
  15. v.Ident( the z node, ReturnStmt_Results, 0, [ the IfStmt, the { ... } node, the return z node ], false, err )

Descending into and out of the sole child of the return z node.

  16. v.ReturnStmt( the return z node, BlockStmt_List, 1, [ the IfStmt, the { ... } node ], false, err )
  17. v.BlockStmt( the { ... } node, IfStmt_Body, 0, [ the IfStmt ], false, err )
  18. v.IfStmt( the IfStmt, Top, 0, nil, false, err )

Post-visiting everything on the way out of the tree, all the way back to the top.

Documentation

Overview

Package astw implements an enhanced walker/visitor for Go abstract syntax trees.

Example
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"log"
	"os"
	"sort"

	"github.com/bobg/astw"
)

func main() {
	// In this example,
	// a Go file is parsed from the standard input.
	// Its functions are then scanned for parameters that are used only for their method calls.
	// These are candidates for being declared with (possibly smaller) interface types,
	// to improve composability and simplify testing.
	fset := token.NewFileSet()
	tree, err := parser.ParseFile(fset, "", os.Stdin, 0)
	if err != nil {
		log.Fatal(err)
	}

	// Callbacks in the *astw.Visitor below are methods on this object,
	// which allows extra state
	// (namely, the fset object)
	// to be carried through the syntax-tree walk.
	mv := &myVisitor{fset: fset}

	// This visitor inspects both top-level (named) function declarations
	// and inline (anonymous) function literals.
	v := &astw.Visitor{
		FuncDecl: mv.onFuncDecl,
		FuncLit:  mv.onFuncLit,
	}

	err = astw.Walk(v, tree)
	if err != nil {
		log.Fatal(err)
	}
}

type myVisitor struct {
	fset *token.FileSet
}

func (mv *myVisitor) onFuncDecl(funcDecl *ast.FuncDecl, which astw.Which, index int, stack []astw.StackItem, pre bool, err error) error {
	if err != nil {
		return err
	}
	if !pre {
		return nil
	}
	return mv.walkFunc(funcDecl.Type, funcDecl.Body)
}

func (mv *myVisitor) onFuncLit(funcLit *ast.FuncLit, which astw.Which, index int, stack []astw.StackItem, pre bool, err error) error {
	if err != nil {
		return err
	}
	if !pre {
		return nil
	}
	return mv.walkFunc(funcLit.Type, funcLit.Body)
}

func (mv *myVisitor) walkFunc(typ *ast.FuncType, body *ast.BlockStmt) error {
	for _, field := range typ.Params.List {
		for _, ident := range field.Names {
			// Start a separate, new walk of the function body
			// looking for uses of this parameter.

			var (
				methodsOnly = true
				methodNames = make(map[string]bool)
			)

			iv := &identVisitor{
				paramIdent:  ident,
				methodsOnly: &methodsOnly,
				methodNames: methodNames,
			}

			v := &astw.Visitor{Ident: iv.onIdent}
			err := astw.Walk(v, body)
			if err != nil {
				return err
			}

			if methodsOnly && len(methodNames) > 0 {
				var sorted []string
				for method := range methodNames {
					sorted = append(sorted, method)
				}
				sort.Strings(sorted)
				fmt.Printf("%s: Parameter %s is used only for these methods: %v\n", mv.fset.PositionFor(typ.Pos(), false), ident.Name, sorted)
			}
		}
	}
	return nil
}

type identVisitor struct {
	paramIdent  *ast.Ident
	methodsOnly *bool
	methodNames map[string]bool
}

func (iv *identVisitor) onIdent(ident *ast.Ident, which astw.Which, index int, stack []astw.StackItem, pre bool, err error) error {
	if err != nil {
		return err
	}
	if !pre {
		return nil
	}
	if !*iv.methodsOnly {
		// If we've already determined that the parameter is used for something other than its methods, return early.
		return nil
	}
	if ident.Name != iv.paramIdent.Name {
		// Identifier must match the parameter we're searching for.
		return nil
	}
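	if ident.Obj == nil || iv.paramIdent.Obj == nil || ident.Obj.Pos() != iv.paramIdent.Obj.Pos() {
		// Check not only that ident and paramIdent have the same name,
		// but that they refer to the same thing.
		// An unresolved identifier (such as the Sel part of a SelectorExpr) has a nil Obj
		// and cannot be the parameter, so it is skipped rather than dereferenced.
		return nil
	}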

	// The following logic checks that the identifier X appears in an expression of the form X.method(...)
	// (i.e., it's the left-hand side of a SelectorExpr, which is the left-hand part of a CallExpr).
	// Anything else sets methodsOnly to false.

	if len(stack) < 2 {
		*iv.methodsOnly = false
		return nil
	}
	parent := stack[len(stack)-1]
	parentNode := parent.N
	sel, ok := parentNode.(*ast.SelectorExpr)
	if !ok {
		*iv.methodsOnly = false
		return nil
	}
	if which != astw.SelectorExpr_X {
		*iv.methodsOnly = false
		return nil
	}
	grandparentNode := stack[len(stack)-2].N
	if _, ok := grandparentNode.(*ast.CallExpr); !ok {
		*iv.methodsOnly = false
		return nil
	}
	if parent.W != astw.CallExpr_Fun {
		*iv.methodsOnly = false
		return nil
	}

	iv.methodNames[sel.Sel.Name] = true
	return nil
}
Output:


Variables

var ErrSkip = errors.New("skip")

ErrSkip is an error that a pre-visit callback can return to cause Walk to skip its children. The post-visit of the same callback is also skipped. Unlike other errors, it is not propagated up the call stack (i.e., the post-visit of the parent callback will receive a value of nil for its err argument).

Functions

func Walk

func Walk(v *Visitor, n ast.Node) error

Walk walks the syntax tree rooted at n using the Visitor v.

Types

type StackItem

type StackItem struct {
	// N is a node in the syntax tree.
	N ast.Node

	// W is the Which value for node N.
	W Which

	// I is the index value for node N.
	I int
}

StackItem is an item in the stack that callbacks receive.

type Visitor

type Visitor struct {
	// Node is the catch-all callback for any node that has no more-specific callback defined
	// (i.e., a concrete-type callback, or one of Expr, Stmt, Decl, or Spec).
	Node func(node ast.Node, which Which, index int, stack []StackItem, pre bool, err error) error

	Expr func(expr ast.Expr, which Which, index int, stack []StackItem, pre bool, err error) error
	Stmt func(stmt ast.Stmt, which Which, index int, stack []StackItem, pre bool, err error) error
	Decl func(decl ast.Decl, which Which, index int, stack []StackItem, pre bool, err error) error
	Spec func(spec ast.Spec, which Which, index int, stack []StackItem, pre bool, err error) error

	// Filename is the name of the current file,
	// when the File callback is invoked as a child of a Package node.
	Filename string

	Package      func(pkg *ast.Package, which Which, index int, stack []StackItem, pre bool, err error) error
	File         func(file *ast.File, which Which, index int, stack []StackItem, pre bool, err error) error
	Comment      func(comment *ast.Comment, which Which, index int, stack []StackItem, pre bool, err error) error
	CommentGroup func(commentGroup *ast.CommentGroup, which Which, index int, stack []StackItem, pre bool, err error) error
	FieldList    func(fieldList *ast.FieldList, which Which, index int, stack []StackItem, pre bool, err error) error
	Field        func(field *ast.Field, which Which, index int, stack []StackItem, pre bool, err error) error

	BadExpr        func(badExpr *ast.BadExpr, which Which, index int, stack []StackItem, pre bool, err error) error
	Ident          func(ident *ast.Ident, which Which, index int, stack []StackItem, pre bool, err error) error
	Ellipsis       func(ellipsis *ast.Ellipsis, which Which, index int, stack []StackItem, pre bool, err error) error
	BasicLit       func(basicLit *ast.BasicLit, which Which, index int, stack []StackItem, pre bool, err error) error
	FuncLit        func(funcLit *ast.FuncLit, which Which, index int, stack []StackItem, pre bool, err error) error
	CompositeLit   func(compositeLit *ast.CompositeLit, which Which, index int, stack []StackItem, pre bool, err error) error
	ParenExpr      func(parenExpr *ast.ParenExpr, which Which, index int, stack []StackItem, pre bool, err error) error
	SelectorExpr   func(selectorExpr *ast.SelectorExpr, which Which, index int, stack []StackItem, pre bool, err error) error
	IndexExpr      func(indexExpr *ast.IndexExpr, which Which, index int, stack []StackItem, pre bool, err error) error
	SliceExpr      func(sliceExpr *ast.SliceExpr, which Which, index int, stack []StackItem, pre bool, err error) error
	TypeAssertExpr func(typeAssertExpr *ast.TypeAssertExpr, which Which, index int, stack []StackItem, pre bool, err error) error
	CallExpr       func(callExpr *ast.CallExpr, which Which, index int, stack []StackItem, pre bool, err error) error
	StarExpr       func(starExpr *ast.StarExpr, which Which, index int, stack []StackItem, pre bool, err error) error
	UnaryExpr      func(unaryExpr *ast.UnaryExpr, which Which, index int, stack []StackItem, pre bool, err error) error
	BinaryExpr     func(binaryExpr *ast.BinaryExpr, which Which, index int, stack []StackItem, pre bool, err error) error
	KeyValueExpr   func(keyValueExpr *ast.KeyValueExpr, which Which, index int, stack []StackItem, pre bool, err error) error
	ArrayType      func(arrayType *ast.ArrayType, which Which, index int, stack []StackItem, pre bool, err error) error
	StructType     func(structType *ast.StructType, which Which, index int, stack []StackItem, pre bool, err error) error
	FuncType       func(funcType *ast.FuncType, which Which, index int, stack []StackItem, pre bool, err error) error
	InterfaceType  func(interfaceType *ast.InterfaceType, which Which, index int, stack []StackItem, pre bool, err error) error
	MapType        func(mapType *ast.MapType, which Which, index int, stack []StackItem, pre bool, err error) error
	ChanType       func(chanType *ast.ChanType, which Which, index int, stack []StackItem, pre bool, err error) error

	BadStmt        func(badStmt *ast.BadStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	DeclStmt       func(declStmt *ast.DeclStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	EmptyStmt      func(emptyStmt *ast.EmptyStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	LabeledStmt    func(labeledStmt *ast.LabeledStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	ExprStmt       func(exprStmt *ast.ExprStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	SendStmt       func(sendStmt *ast.SendStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	IncDecStmt     func(incDecStmt *ast.IncDecStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	AssignStmt     func(assignStmt *ast.AssignStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	GoStmt         func(goStmt *ast.GoStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	DeferStmt      func(deferStmt *ast.DeferStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	ReturnStmt     func(returnStmt *ast.ReturnStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	BranchStmt     func(branchStmt *ast.BranchStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	BlockStmt      func(blockStmt *ast.BlockStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	IfStmt         func(ifStmt *ast.IfStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	CaseClause     func(caseClause *ast.CaseClause, which Which, index int, stack []StackItem, pre bool, err error) error
	SwitchStmt     func(switchStmt *ast.SwitchStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	TypeSwitchStmt func(typeSwitchStmt *ast.TypeSwitchStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	CommClause     func(commClause *ast.CommClause, which Which, index int, stack []StackItem, pre bool, err error) error
	SelectStmt     func(selectStmt *ast.SelectStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	ForStmt        func(forStmt *ast.ForStmt, which Which, index int, stack []StackItem, pre bool, err error) error
	RangeStmt      func(rangeStmt *ast.RangeStmt, which Which, index int, stack []StackItem, pre bool, err error) error

	BadDecl  func(badDecl *ast.BadDecl, which Which, index int, stack []StackItem, pre bool, err error) error
	GenDecl  func(genDecl *ast.GenDecl, which Which, index int, stack []StackItem, pre bool, err error) error
	FuncDecl func(funcDecl *ast.FuncDecl, which Which, index int, stack []StackItem, pre bool, err error) error

	ImportSpec func(importSpec *ast.ImportSpec, which Which, index int, stack []StackItem, pre bool, err error) error
	ValueSpec  func(valueSpec *ast.ValueSpec, which Which, index int, stack []StackItem, pre bool, err error) error
	TypeSpec   func(typeSpec *ast.TypeSpec, which Which, index int, stack []StackItem, pre bool, err error) error
}

Visitor is a structure full of callbacks, for walking a syntax tree via the Walk function.

Depending on the application, it is usually necessary to set values only for a few callbacks. A zero Visitor is usable, but won't do anything interesting.

The name of each callback corresponds to a type in the go/ast package. That callback is called twice for each node of that type to be encountered in the walk: once when descending into the tree, before child nodes are visited (the "pre" visit), and once when traversing back out of the tree, after child nodes are visited (the "post" visit).

Each callback takes the following arguments:

  • the node (of the appropriate type);
  • a Which value, telling which child of the parent this node is;
  • an index value, for when this node is one member of a slice in its parent, telling which member it is;
  • a stack of nodes above this one in the tree, with stack[0] the root of the tree and stack[len(stack)-1] the node's immediate parent;
  • a boolean telling whether this is the pre visit (true) or the post visit (false);
  • the error value, if any, produced by visiting this node's children. This is always nil in a "pre" visit.

Each node in a syntax tree has a concrete type like *ast.IfStmt or *ast.BinaryExpr. Visitor contains a callback for each possible concrete type. Many of these types also implement one of these abstract interfaces: ast.Expr, ast.Stmt, ast.Decl, and ast.Spec. Visitor additionally contains callbacks for these abstract types. If the concrete-type callback for a given node is not set, but the abstract-type callback is, then that will be the callback for the node.

All node types implement the abstract interface ast.Node. Visitor contains a callback for that catch-all type too. If it is defined, it is used for all nodes that do not have a more-specific callback defined.

Package nodes are handled specially. Unique among go/ast Node types, a Package node's children exist not as individual fields or in a slice, but in a map, mapping filenames to *ast.File nodes. When File nodes are visited via a Package node, they are visited in lexical filename order. The index value passed to the File callback reflects this ordering. When the File callback is invoked via a Package node parent (i.e., the Which value is Package_Files), the filename from the map can be found in the Visitor's Filename field.

type Which

type Which int

Which describes a specific child of a specific go/ast Node type.

const (
	// Top is the Which value that Walk uses for the node at the top of the syntax tree.
	Top Which = iota

	Package_Files

	File_Doc
	File_Name
	File_Decls
	File_Imports
	File_Unresolved
	File_Comments

	CommentGroup_List

	Ellipsis_Elt
	FuncLit_Type
	FuncLit_Body
	CompositeLit_Type
	CompositeLit_Elts
	ParenExpr_X
	SelectorExpr_X
	SelectorExpr_Sel
	IndexExpr_X
	IndexExpr_Index
	SliceExpr_X
	SliceExpr_Low
	SliceExpr_High
	SliceExpr_Max
	TypeAssertExpr_X
	TypeAssertExpr_Type
	CallExpr_Fun
	CallExpr_Args
	StarExpr_X
	UnaryExpr_X
	BinaryExpr_X
	BinaryExpr_Y
	KeyValueExpr_Key
	KeyValueExpr_Value
	ArrayType_Len
	ArrayType_Elt
	StructType_Fields
	FuncType_Params
	FuncType_Results
	InterfaceType_Methods
	MapType_Key
	MapType_Value
	ChanType_Value

	FieldList_List

	Field_Doc
	Field_Names
	Field_Type
	Field_Tag
	Field_Comment

	DeclStmt_Decl
	LabeledStmt_Label
	LabeledStmt_Stmt
	ExprStmt_Expr
	SendStmt_Chan
	SendStmt_Value
	IncDecStmt_X
	AssignStmt_Lhs
	AssignStmt_Rhs
	GoStmt_Call
	DeferStmt_Call
	ReturnStmt_Results
	BranchStmt_Label
	BlockStmt_List
	IfStmt_Init
	IfStmt_Cond
	IfStmt_Body
	IfStmt_Else
	CaseClause_List
	CaseClause_Body
	SwitchStmt_Init
	SwitchStmt_Tag
	SwitchStmt_Body
	TypeSwitchStmt_Init
	TypeSwitchStmt_Assign
	TypeSwitchStmt_Body
	CommClause_Comm
	CommClause_Body
	SelectStmt_Body
	ForStmt_Init
	ForStmt_Cond
	ForStmt_Post
	ForStmt_Body
	RangeStmt_Key
	RangeStmt_Value
	RangeStmt_X
	RangeStmt_Body

	GenDecl_Doc
	GenDecl_Specs

	FuncDecl_Doc
	FuncDecl_Recv
	FuncDecl_Name
	FuncDecl_Type
	FuncDecl_Body

	ImportSpec_Doc
	ImportSpec_Name
	ImportSpec_Path
	ImportSpec_Comment

	ValueSpec_Doc
	ValueSpec_Names
	ValueSpec_Type
	ValueSpec_Values
	ValueSpec_Comment

	TypeSpec_Doc
	TypeSpec_Name
	TypeSpec_Type
	TypeSpec_Comment
)

Values for Which. The names of these values have the form x_y, where x is the base name of a concrete go/ast struct type (e.g., ForStmt) and y is the name of a field in that struct (e.g., Init, Cond, Post, and Body in the case of ForStmt).
