model

package
v0.0.0-...-014f423
Warning

This package is not in the latest version of its module.

Published: Nov 6, 2020 | License: GPL-3.0 | Imports: 24 | Imported by: 0

Documentation

Overview

Package model implements an LSTM-based network model for use in MCTS. As input, the model takes the public game history (one-hot encoded) as well as the player's current hand (one-hot encoded), and it predicts the advantages for each possible action in the current infoset.

Index

Constants

This section is empty.

Variables

This section is empty.

Functions

func EncodeHistory

func EncodeHistory(h gamestate.History, result [][]float32)

Game history is encoded as a MaxNumActions (58) x 59 matrix; each row encodes one action as the concatenation of:

  • One-hot encoded player (2)
  • One-hot encoded action type (4)
  • One-hot encoded card (10)
  • One-hot encoded position in draw pile (13)
  • Concatenated one-hot cards seen (3x10)
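The row layout above can be illustrated with a minimal sketch. It assumes each action carries a player index, an action-type index, an optional card, an optional draw-pile position, and up to three seen cards. The action type, its field names, and the helper functions are hypothetical and not part of package model; the real EncodeHistory may order or pad fields differently.

```go
package main

// Field widths taken from the documented layout. These constant names are
// illustrative and are not exported by package model.
const (
	numPlayers     = 2
	numActionTypes = 4
	numCardTypes   = 10
	drawPileSize   = 13
	numCardsSeen   = 3
	maxNumActions  = 58
	// 2 + 4 + 10 + 13 + 3*10 = 59 features per action.
	rowWidth = numPlayers + numActionTypes + numCardTypes + drawPileSize + numCardsSeen*numCardTypes
)

// action is a hypothetical, simplified stand-in for one entry of a
// gamestate.History. Card and DrawPilePos are -1 when not applicable.
type action struct {
	Player      int
	Type        int
	Card        int
	DrawPilePos int
	CardsSeen   []int // up to numCardsSeen card indices
}

// newHistoryMatrix allocates a zeroed maxNumActions x rowWidth matrix.
func newHistoryMatrix() [][]float32 {
	m := make([][]float32, maxNumActions)
	for i := range m {
		m[i] = make([]float32, rowWidth)
	}
	return m
}

// encodeAction sets the one-hot bits for a in row, which must be zeroed and
// have length rowWidth. Each field occupies its own block of columns.
func encodeAction(a action, row []float32) {
	offset := 0
	row[offset+a.Player] = 1
	offset += numPlayers
	row[offset+a.Type] = 1
	offset += numActionTypes
	if a.Card >= 0 {
		row[offset+a.Card] = 1
	}
	offset += numCardTypes
	if a.DrawPilePos >= 0 {
		row[offset+a.DrawPilePos] = 1
	}
	offset += drawPileSize
	for i, card := range a.CardsSeen {
		if i >= numCardsSeen {
			break
		}
		row[offset+i*numCardTypes+card] = 1
	}
}

// encodeHistory zeroes result and writes one row per action. Rows beyond the
// end of the history remain zero, acting as padding.
func encodeHistory(history []action, result [][]float32) {
	for _, row := range result {
		for j := range row {
			row[j] = 0
		}
	}
	for i, a := range history {
		if i >= len(result) {
			break
		}
		encodeAction(a, result[i])
	}
}
```

With this layout the column blocks start at offsets 0 (player), 2 (action type), 6 (card), 16 (draw-pile position), and 29 (cards seen).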

Types

type LSTM

type LSTM struct {
	// contains filtered or unexported fields
}

LSTM is a model for AlphaCats to be used in MCTS.

func NewLSTM

func NewLSTM(p Params) *LSTM

func (*LSTM) MarshalBinary

func (m *LSTM) MarshalBinary() ([]byte, error)

func (*LSTM) Train

func (m *LSTM) Train(initialWeightsFile string, samples []Sample) *TrainedLSTM

func (*LSTM) UnmarshalBinary

func (m *LSTM) UnmarshalBinary(buf []byte) error

type MCTSPSRO

type MCTSPSRO struct {
	// contains filtered or unexported fields
}

func LoadMCTSPSRO

func LoadMCTSPSRO(r io.Reader) (*MCTSPSRO, error)

func NewMCTSPSRO

func NewMCTSPSRO(model *LSTM, maxSamples, predictionCacheSize int) *MCTSPSRO

func (*MCTSPSRO) AddCurrentExploiterToModel

func (m *MCTSPSRO) AddCurrentExploiterToModel()

func (*MCTSPSRO) AddModel

func (m *MCTSPSRO) AddModel(policy mcts.Policy)

func (*MCTSPSRO) AddSample

func (m *MCTSPSRO) AddSample(s Sample)

func (*MCTSPSRO) AssignWeights

func (m *MCTSPSRO) AssignWeights(weights []float32)

func (*MCTSPSRO) Evaluate

func (m *MCTSPSRO) Evaluate(rng *rand.Rand, node cfr.GameTreeNode, opponent mcts.Policy) ([]float32, float32)

Evaluate implements mcts.Evaluator for one-sided IS-MCTS search rollouts when this policy is being trained as the exploiter.

func (*MCTSPSRO) GetPolicies

func (m *MCTSPSRO) GetPolicies() []mcts.Policy

func (*MCTSPSRO) Len

func (m *MCTSPSRO) Len() int

func (*MCTSPSRO) SamplePolicy

func (m *MCTSPSRO) SamplePolicy() mcts.Policy

func (*MCTSPSRO) SaveTo

func (m *MCTSPSRO) SaveTo(w io.Writer) error

func (*MCTSPSRO) TrainNetwork

func (m *MCTSPSRO) TrainNetwork()

type Params

type Params struct {
	OutputDir             string
	NumEncodingWorkers    int
	MaxInferenceBatchSize int
	NumPredictionWorkers  int
}
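Params does not document how its fields are used. One plausible reading of MaxInferenceBatchSize and NumPredictionWorkers is that pending prediction requests are grouped into batches of bounded size and handed to a fixed pool of goroutines. The sketch below assumes that reading; params mirrors the documented fields, and predictBatch is a hypothetical stand-in for a batched TensorFlow call that returns one output per input.

```go
package main

import "sync"

// params mirrors the documented model.Params fields.
type params struct {
	OutputDir             string
	NumEncodingWorkers    int
	MaxInferenceBatchSize int
	NumPredictionWorkers  int
}

// splitIntoBatches returns [start, end) ranges covering n requests, each
// containing at most maxBatch requests.
func splitIntoBatches(n, maxBatch int) [][2]int {
	if maxBatch < 1 {
		maxBatch = 1
	}
	var out [][2]int
	for start := 0; start < n; start += maxBatch {
		end := start + maxBatch
		if end > n {
			end = n
		}
		out = append(out, [2]int{start, end})
	}
	return out
}

// predictAll runs predictBatch over inputs in bounded batches using a pool of
// p.NumPredictionWorkers goroutines.
func predictAll(p params, inputs []float32, predictBatch func([]float32) []float32) []float32 {
	workers := p.NumPredictionWorkers
	if workers < 1 {
		workers = 1
	}
	outputs := make([]float32, len(inputs))
	jobs := make(chan [2]int)
	var wg sync.WaitGroup
	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for r := range jobs {
				// Each batch writes a disjoint slice of outputs, so no lock is needed.
				copy(outputs[r[0]:r[1]], predictBatch(inputs[r[0]:r[1]]))
			}
		}()
	}
	for _, r := range splitIntoBatches(len(inputs), p.MaxInferenceBatchSize) {
		jobs <- r
	}
	close(jobs)
	wg.Wait()
	return outputs
}
```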

type PredictorPolicy

type PredictorPolicy struct {
	// contains filtered or unexported fields
}

PredictorPolicy is an adaptor that allows a TrainedLSTM to be used as an mcts.Policy.

func NewPredictorPolicy

func NewPredictorPolicy(model *TrainedLSTM, cacheSize int) *PredictorPolicy
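NewPredictorPolicy takes a cacheSize, which suggests that predictions are memoized. A minimal sketch of such a cache follows, assuming least-recently-used eviction and a string key derived from the infoset; the package documents neither choice. The names lruCache, prediction, and getPolicy are hypothetical, and predict stands in for a call such as TrainedLSTM.Predict. This cache is not safe for concurrent use; a policy shared across search goroutines would need a mutex.

```go
package main

import "container/list"

// prediction holds a cached model output, matching the ([]float32, float32)
// shape returned by TrainedLSTM.Predict.
type prediction struct {
	policy []float32
	value  float32
}

type entry struct {
	key  string
	pred prediction
}

// lruCache is a minimal least-recently-used cache.
type lruCache struct {
	capacity int
	order    *list.List // front = most recently used
	items    map[string]*list.Element
}

func newLRUCache(capacity int) *lruCache {
	return &lruCache{capacity: capacity, order: list.New(), items: make(map[string]*list.Element)}
}

// Get returns the cached prediction for key and marks it as recently used.
func (c *lruCache) Get(key string) (prediction, bool) {
	if el, ok := c.items[key]; ok {
		c.order.MoveToFront(el)
		return el.Value.(*entry).pred, true
	}
	return prediction{}, false
}

// Put stores p under key, evicting the least recently used entry if full.
func (c *lruCache) Put(key string, p prediction) {
	if c.capacity <= 0 {
		return
	}
	if el, ok := c.items[key]; ok {
		el.Value.(*entry).pred = p
		c.order.MoveToFront(el)
		return
	}
	c.items[key] = c.order.PushFront(&entry{key: key, pred: p})
	if c.order.Len() > c.capacity {
		oldest := c.order.Back()
		c.order.Remove(oldest)
		delete(c.items, oldest.Value.(*entry).key)
	}
}

// getPolicy consults the cache before falling back to predict.
func getPolicy(c *lruCache, key string, predict func() ([]float32, float32)) []float32 {
	if p, ok := c.Get(key); ok {
		return p.policy
	}
	policy, value := predict()
	c.Put(key, prediction{policy: policy, value: value})
	return policy
}
```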

func (*PredictorPolicy) GetPolicy

func (pp *PredictorPolicy) GetPolicy(node cfr.GameTreeNode) []float32

func (*PredictorPolicy) GobDecode

func (pp *PredictorPolicy) GobDecode(buf []byte) error

func (*PredictorPolicy) GobEncode

func (pp *PredictorPolicy) GobEncode() ([]byte, error)

type RefCountingTrainedLSTM

type RefCountingTrainedLSTM struct {
	TrainedLSTM *TrainedLSTM
	RefCount    int64
}
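RefCountingTrainedLSTM pairs a model with an int64 reference count. A common pattern for such a pair, assumed here rather than taken from the package, is to increment the count atomically for each user and call Close when the last user releases it. The closer interface, refCounted, and fakeModel are hypothetical stand-ins; refCounted mirrors the documented fields with the model replaced by an interface.

```go
package main

import "sync/atomic"

// closer stands in for *TrainedLSTM; only Close matters here.
type closer interface{ Close() }

// refCounted mirrors RefCountingTrainedLSTM's fields.
type refCounted struct {
	Model    closer
	RefCount int64
}

// Acquire records a new user of the model.
func (r *refCounted) Acquire() { atomic.AddInt64(&r.RefCount, 1) }

// Release drops one reference and closes the model when none remain.
func (r *refCounted) Release() {
	switch n := atomic.AddInt64(&r.RefCount, -1); {
	case n == 0:
		r.Model.Close()
	case n < 0:
		panic("refCounted: Release without matching Acquire")
	}
}

// fakeModel records whether Close was called.
type fakeModel struct{ closed bool }

func (f *fakeModel) Close() { f.closed = true }
```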

type Sample

type Sample struct {
	InfoSet alphacats.AbstractedInfoSet
	Policy  []float32
	Value   float32
}

func (Sample) String

func (s Sample) String() string

type TrainedLSTM

type TrainedLSTM struct {
	// contains filtered or unexported fields
}

func LoadTrainedLSTM

func LoadTrainedLSTM(dir string, params Params) (*TrainedLSTM, error)

func (*TrainedLSTM) Close

func (m *TrainedLSTM) Close()

func (*TrainedLSTM) GobDecode

func (m *TrainedLSTM) GobDecode(buf []byte) error

When serializing, only the path to the model already on disk is encoded, not the model itself.

func (*TrainedLSTM) GobEncode

func (m *TrainedLSTM) GobEncode() ([]byte, error)
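Encoding only the on-disk path means a gob payload carries a short string rather than the model's weights. The sketch below shows the general technique with encoding/gob. pathOnlyModel is a hypothetical stand-in for TrainedLSTM; a real GobDecode would reload the model from the decoded path (for example with something like LoadTrainedLSTM), which is omitted here.

```go
package main

import (
	"bytes"
	"encoding/gob"
)

// pathOnlyModel is a hypothetical stand-in for TrainedLSTM.
type pathOnlyModel struct {
	dir    string
	loaded bool // stands in for in-memory model state that is not serialized
}

// GobEncode writes only the model directory.
func (m *pathOnlyModel) GobEncode() ([]byte, error) {
	return []byte(m.dir), nil
}

// GobDecode restores the directory; reloading the model from it is omitted.
func (m *pathOnlyModel) GobDecode(buf []byte) error {
	m.dir = string(buf)
	m.loaded = true // a real implementation would load from m.dir here
	return nil
}

// roundTrip encodes m with gob and decodes it into a fresh value.
func roundTrip(m *pathOnlyModel) (*pathOnlyModel, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(m); err != nil {
		return nil, err
	}
	out := &pathOnlyModel{}
	if err := gob.NewDecoder(&buf).Decode(out); err != nil {
		return nil, err
	}
	return out, nil
}
```

The trade-off is that a decoded value is only valid on a machine where that path still exists.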

func (*TrainedLSTM) KerasWeightsFile

func (m *TrainedLSTM) KerasWeightsFile() string

func (*TrainedLSTM) Predict

func (m *TrainedLSTM) Predict(is *alphacats.AbstractedInfoSet) ([]float32, float32)

type UniformRandomPolicy

type UniformRandomPolicy struct{}

UniformRandomPolicy is a policy that always plays uniformly at random. It is used to bootstrap fictitious play, as an alternative to SmoothUCT.

func (*UniformRandomPolicy) GetPolicy

func (u *UniformRandomPolicy) GetPolicy(node cfr.GameTreeNode) []float32
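A uniform random policy assigns equal probability to each available action. The sketch assumes the number of available actions is already known; in the package, GetPolicy would derive it from the cfr.GameTreeNode, which is not modeled here. uniformPolicy is a hypothetical helper.

```go
package main

// uniformPolicy returns a probability vector that gives equal mass to each
// of numActions actions. It returns an empty slice when numActions is 0.
func uniformPolicy(numActions int) []float32 {
	p := make([]float32, numActions)
	if numActions == 0 {
		return p
	}
	for i := range p {
		p[i] = 1 / float32(numActions)
	}
	return p
}
```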

Directories

Path          Synopsis
internal
  npyio       Package npyio is a fork of github.com/sbinet/npyio that is hard-coded for []float32s to avoid reflection.
  tffloats    Package tffloats constructs *tf.Tensors from []float32 slices, avoiding reflection.
