mcts

package
v0.1.1 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Jan 18, 2021 License: MIT Imports: 14 Imported by: 2

Documentation

Overview

Example
package main

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"log"
	"math/rand"
	"time"

	"github.com/gorgonia/agogo/game"
	"github.com/gorgonia/agogo/game/mnk"
	"github.com/gorgonia/agogo/mcts"
)

// The two tic-tac-toe sides expressed as agogo players. Cross (X) makes the
// first move in main; Nought (O) replies.
var (
	Nought = game.Player(game.White) // O
	Cross  = game.Player(game.Black) // X
)

// opponent returns the side that moves after p: Cross yields Nought and
// Nought yields Cross. Any other player value is a programming error and
// causes a panic.
func opponent(p game.Player) game.Player {
	if p == Cross {
		return Nought
	}
	if p == Nought {
		return Cross
	}
	panic("Unreachable")
}

// dummyNN is a scripted stand-in for a neural network. It never inspects the
// board: the policy and value it returns depend only on the move number, which
// keeps the example deterministic.
type dummyNN struct{}

// Infer satisfies mcts.Inferencer.
//
// policy always has 10 entries (presumably the 9 board cells plus one extra
// action such as pass — TODO confirm against mnk's action space). For move
// numbers 0 through 8 exactly one entry is non-zero. For any other move
// number, policy is all zeros and value is 0.
func (dummyNN) Infer(state game.State) (policy []float32, value float32) {
	// Written as 8.0 / 9 so that Go performs floating-point constant
	// division. The integer constant expression 8 / 9 evaluates to 0.
	const eightNinths = 8.0 / 9

	policy = make([]float32, 10)
	switch state.MoveNumber() {
	case 0:
		policy[4] = 0.9
		value = 0.5
	case 1:
		policy[0] = 0.1
		value = 0.5
	case 2:
		policy[2] = 0.9
		value = eightNinths
	case 3:
		policy[6] = 0.1
		value = eightNinths
	case 4:
		policy[3] = 0.9
		value = eightNinths
	case 5:
		policy[5] = 0.1
		value = 0.5
	case 6:
		policy[1] = 0.9
		value = eightNinths
	case 7:
		policy[7] = 0.1
		value = 0
	case 8:
		policy[8] = 0.9
		value = 0
	}
	return policy, value
}

// main plays a complete game of tic-tac-toe in which both sides pick their
// moves with MCTS guided by dummyNN. It records the board after every move,
// writes the search tree as a Graphviz .dot file after move 2, logs the
// playout and prints the winner.
func main() {
	g := mnk.TicTacToe()
	conf := mcts.Config{
		PUCT:           1.0,
		M:              3,
		N:              3,
		Timeout:        500 * time.Millisecond,
		PassPreference: mcts.DontPreferPass,
		Budget:         10000,
		DumbPass:       true,
		RandomCount:    0, // this is a deterministic example
	}
	nn := dummyNN{}
	t := mcts.New(g, conf, nn)
	player := Cross

	var buf bytes.Buffer
	var ended bool
	var winner game.Player
	for ended, winner = g.Ended(); !ended; ended, winner = g.Ended() {
		moveNum := g.MoveNumber()
		best := t.Search(player)
		g = g.Apply(game.PlayerMove{player, best}).(*mnk.MNK)
		fmt.Fprintf(&buf, "Turn %d\n%v---\n", moveNum, g)
		if moveNum == 2 {
			// The .dot dump is only a debugging aid. A failed write is
			// reported on stderr (so stdout is unaffected), but the game
			// continues.
			if err := ioutil.WriteFile("fullGraph_tictactoe.dot", []byte(t.ToDot()), 0644); err != nil {
				log.Printf("writing search tree: %v", err)
			}
		}
		player = opponent(player)
	}

	log.Printf("Playout:\n%v", buf.String())
	fmt.Printf("WINNER %v\n", winner)

	// the outputs should look something like this (may differ due to random numbers)
	// Turn 0
	// ⎢ · · · ⎥
	// ⎢ · X · ⎥
	// ⎢ · · · ⎥
	// ---
	// Turn 1
	// ⎢ O · · ⎥
	// ⎢ · X · ⎥
	// ⎢ · · · ⎥
	// ---
	// Turn 2
	// ⎢ O · X ⎥
	// ⎢ · X · ⎥
	// ⎢ · · · ⎥
	// ---
	// Turn 3
	// ⎢ O · X ⎥
	// ⎢ · X · ⎥
	// ⎢ O · · ⎥
	// ---
	// Turn 4
	// ⎢ O · X ⎥
	// ⎢ X X · ⎥
	// ⎢ O · · ⎥
	// ---
	// Turn 5
	// ⎢ O · X ⎥
	// ⎢ X X O ⎥
	// ⎢ O · · ⎥
	// ---
	// Turn 6
	// ⎢ O X X ⎥
	// ⎢ X X O ⎥
	// ⎢ O · · ⎥
	// ---
	// Turn 7
	// ⎢ O X X ⎥
	// ⎢ X X O ⎥
	// ⎢ O O · ⎥
	// ---
	// Turn 8
	// ⎢ O X X ⎥
	// ⎢ X X O ⎥
	// ⎢ O O X ⎥
	// ---

}
Output:

WINNER None

Index

Examples

Constants

View Source
const (
	// Pass and Resign are special move values.
	// NOTE(review): presumably negative so they can never collide with a
	// board-cell index — confirm.
	Pass   game.Single = -1
	Resign game.Single = -2

	// White and Black re-export the game package's colours as game.Player values.
	White game.Player = game.Player(game.White)
	Black game.Player = game.Player(game.Black)
)
View Source
const (
	// MAXTREESIZE is the maximum number of nodes a tree is allowed to hold.
	// At about 56 bytes per node, a full tree needs roughly
	// 25,000,000 * 56 bytes ≈ 1.4 GB (≈ 1.3 GiB) of memory.
	MAXTREESIZE = 25000000
)

Variables

This section is empty.

Functions

This section is empty.

Types

type Config

// Config is the structure to configure the MCTS multitree (poorly named Tree).
type Config struct {
	// PUCT is the proportion of polynomial upper confidence trees to keep. Between 1 and 0
	// NOTE(review): in PUCT-style MCTS this name usually denotes the
	// exploration constant (c_puct) rather than a proportion; the package
	// example passes 1.0. Confirm the intended meaning and range.
	PUCT    float32
	// Timeout is a time.Duration; the package example uses 500ms.
	// NOTE(review): presumably bounds the wall-clock time of one search — confirm.
	Timeout time.Duration

	// M, N represent the board height and width.
	M, N              int
	RandomCount       int   // if the move number is less than this, we should randomize
	Budget            int32 // iteration budget
	// RandomMinVisits and RandomTemperature: presumably tune the randomized
	// move choice enabled by RandomCount — TODO confirm.
	RandomMinVisits   uint32
	RandomTemperature float32
	// DumbPass: NOTE(review): semantics not visible here — confirm.
	DumbPass          bool
	// ResignPercentage: NOTE(review): presumably a resignation threshold — confirm units.
	ResignPercentage  float32
	// PassPreference selects one of the PassPreference constants
	// (DontPreferPass, PreferPass, DontResign).
	PassPreference    PassPreference
}

Config is the structure to configure the MCTS multitree (poorly named Tree)

func DefaultConfig

func DefaultConfig(boardSize int) Config

func (Config) IsValid

func (c Config) IsValid() bool

type Inferencer

// Inferencer is essentially the neural network. Infer takes a game state and
// returns a policy vector and a scalar value estimate.
// NOTE(review): the policy is presumably indexed by move (the package example
// returns 10 entries for a 3x3 board) — confirm the expected length.
type Inferencer interface {
	Infer(state game.State) (policy []float32, value float32)
}

Inferencer is essentially the neural network

type MCTS

// MCTS is the manager for the search tree's memory.
type MCTS struct {
	// Embedded lock. NOTE(review): presumably guards the tree's shared state
	// during search — confirm which fields it protects.
	sync.RWMutex
	// Embedded so that Config fields (PUCT, Budget, ...) can be read directly
	// on an *MCTS value.
	Config
	// contains filtered or unexported fields
}

MCTS is essentially a "global" manager of sorts for the memories. The goal is to build MCTS without much pointer chasing.

func New

func New(game game.State, conf Config, nn Inferencer) *MCTS

func (*MCTS) Children

func (t *MCTS) Children(of naughty) []naughty

Children returns a list of children

func (MCTS) Log

func (l MCTS) Log() string

func (*MCTS) New

func (t *MCTS) New(move game.Single, score, value float32) (retVal naughty)

New creates a new node

func (*MCTS) Nodes

func (t *MCTS) Nodes() int

func (*MCTS) Policies

func (t *MCTS) Policies(g game.State) []float32

func (*MCTS) Reset

func (t *MCTS) Reset()

func (*MCTS) Search

func (t *MCTS) Search(player game.Player) (retVal game.Single)

func (*MCTS) SetGame

func (t *MCTS) SetGame(g game.State)

SetGame sets the game

func (*MCTS) ToDot

func (t *MCTS) ToDot() string

type Node

// Node is a node of the search tree (see AddChild, HasChildren, BestChild).
// All of its fields are unexported; access goes through its methods.
type Node struct {
	// contains filtered or unexported fields
}

func (*Node) Activate

func (n *Node) Activate()

Activate activates the node

func (*Node) AddChild

func (n *Node) AddChild(child naughty)

AddChild adds a child to the node

func (*Node) BestChild

func (n *Node) BestChild(player game.Player) naughty

BestChild returns the best-scoring child. Note that the underlying ranking (fancySort) applies a variety of heuristics.

func (*Node) BlackScores

func (n *Node) BlackScores() float32

BlackScores returns the scores for black

func (*Node) Evaluate

func (n *Node) Evaluate(player game.Player) float32

Evaluate evaluates a move made by a player

func (*Node) Format

func (n *Node) Format(s fmt.State, c rune)

func (*Node) HasChildren

func (n *Node) HasChildren() bool

HasChildren returns true if the node has children

func (*Node) ID

func (n *Node) ID() int

func (*Node) Invalidate

func (n *Node) Invalidate()

Invalidate invalidates the node

func (*Node) IsActive

func (n *Node) IsActive() bool

IsActive returns true if the node is active

func (*Node) IsExpandable

func (n *Node) IsExpandable(minPsaRatio float32) bool

IsExpandable returns true if the node is expandable. A node may be unexpandable for memory reasons.

func (*Node) IsNotVisited

func (n *Node) IsNotVisited() bool

IsNotVisited returns true if this node has never been visited.

func (*Node) IsPruned

func (n *Node) IsPruned() bool

IsPruned returns true if the node has been pruned.

func (*Node) IsValid

func (n *Node) IsValid() bool

IsValid returns true if it's valid

func (*Node) MinPsaRatio

func (n *Node) MinPsaRatio() float32

func (*Node) Move

func (n *Node) Move() game.Single

Move gets the move associated with the node

func (*Node) NNEvaluate

func (n *Node) NNEvaluate(player game.Player) float32

NNEvaluate returns the result of the NN evaluation for the given player's colour.

func (*Node) Prune

func (n *Node) Prune()

Prune prunes the node

func (*Node) Score

func (n *Node) Score() float32

Score returns the score

func (*Node) Select

func (n *Node) Select(of game.Player) naughty

Select selects a child for the given player's colour.

func (*Node) Update

func (n *Node) Update(score float32)

Update updates the accumulated score

func (*Node) Value

func (n *Node) Value() float32

Value returns the predicted value (probability of winning from the NN) of the given node

func (*Node) VirtualLoss

func (n *Node) VirtualLoss() float32

func (*Node) Visits

func (n *Node) Visits() uint32

type PassPreference

// PassPreference selects how the search treats passing and resigning; see
// DontPreferPass, PreferPass and DontResign.
type PassPreference int

PassPreference selects how the search treats passing and resigning.

const (
	// DontPreferPass is 0, the zero value of PassPreference (used by the package example).
	DontPreferPass PassPreference = iota
	PreferPass     // 1
	DontResign     // 2
	// MAXPASSPREFERENCE equals the number of preferences declared above (3).
	// NOTE(review): presumably an exclusive upper bound for validating a
	// PassPreference — confirm.
	MAXPASSPREFERENCE
)

type Result

// Result is a float32 used to represent results; per its documentation it is
// NaN-tagged. NOTE(review): the tagging scheme is not visible here — confirm
// before relying on specific NaN payloads.
type Result float32

Result is a NaN-tagged floating-point value used to represent results.

type Status

// Status is a uint32 state enumeration whose zero value is Invalid.
// NOTE(review): presumably the state of a Node (cf. Node.Activate,
// Node.Invalidate, Node.Prune) — confirm.
type Status uint32
const (
	Invalid Status = iota // 0, the zero value
	Active                // 1
	Pruned                // 2
)

func (Status) String

func (a Status) String() string

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL