Documentation ¶
Overview ¶
Package model implements an LSTM-based network model for use in MCTS. The model takes as input the public game history and the player's current hand (both one-hot encoded), and predicts the advantages for each possible action in the current infoset.
Index ¶
- func EncodeHistory(h gamestate.History, result [][]float32)
- type LSTM
- type MCTSPSRO
- func (m *MCTSPSRO) AddCurrentExploiterToModel()
- func (m *MCTSPSRO) AddModel(policy mcts.Policy)
- func (m *MCTSPSRO) AddSample(s Sample)
- func (m *MCTSPSRO) AssignWeights(weights []float32)
- func (m *MCTSPSRO) Evaluate(rng *rand.Rand, node cfr.GameTreeNode, opponent mcts.Policy) ([]float32, float32)
- func (m *MCTSPSRO) GetPolicies() []mcts.Policy
- func (m *MCTSPSRO) Len() int
- func (m *MCTSPSRO) SamplePolicy() mcts.Policy
- func (m *MCTSPSRO) SaveTo(w io.Writer) error
- func (m *MCTSPSRO) TrainNetwork()
- type Params
- type PredictorPolicy
- type RefCountingTrainedLSTM
- type Sample
- type TrainedLSTM
- type UniformRandomPolicy
Constants ¶
This section is empty.
Variables ¶
This section is empty.
Functions ¶
func EncodeHistory ¶
func EncodeHistory(h gamestate.History, result [][]float32)
The game history is encoded as a matrix with MaxNumActions (58) rows. Each row is the concatenation of:
- One-hot encoded player (2)
- One-hot encoded action type (4)
- One-hot encoded Card (10)
- One-hot encoded position in draw pile (13)
- Concatenated one-hot cards seen (3x10)
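The row layout listed above can be sketched as follows. This is a minimal illustration, not the package's implementation: the `action` struct, the helper names, the segment order (assumed to follow the list), and the use of -1 for "not applicable" are all assumptions. Only the segment widths come from the documentation.

```go
package main

import "fmt"

// Segment widths taken from the documented encoding.
const (
	numPlayers     = 2
	numActionTypes = 4
	numCards       = 10
	numPositions   = 13
	numCardsSeen   = 3
	// rowWidth is the sum of all segments: 2 + 4 + 10 + 13 + 3*10.
	rowWidth = numPlayers + numActionTypes + numCards + numPositions + numCardsSeen*numCards
)

// action is a hypothetical stand-in for one entry of the game history.
// An index of -1 means "not applicable" and leaves that segment all zero.
type action struct {
	Player, Type, Card, Position int
	CardsSeen                    [numCardsSeen]int
}

// setOneHot sets seg[idx] to 1 if idx is within the segment.
func setOneHot(seg []float32, idx int) {
	if idx >= 0 && idx < len(seg) {
		seg[idx] = 1
	}
}

// encodeAction writes the concatenated one-hot segments for a into row,
// which must have length rowWidth.
func encodeAction(a action, row []float32) {
	for i := range row {
		row[i] = 0
	}
	off := 0
	// next returns the following segment of the given width and advances off.
	next := func(width int) []float32 {
		seg := row[off : off+width]
		off += width
		return seg
	}
	setOneHot(next(numPlayers), a.Player)
	setOneHot(next(numActionTypes), a.Type)
	setOneHot(next(numCards), a.Card)
	setOneHot(next(numPositions), a.Position)
	for _, c := range a.CardsSeen {
		setOneHot(next(numCards), c)
	}
}

func main() {
	row := make([]float32, rowWidth)
	encodeAction(action{Player: 1, Type: 2, Card: 3, Position: -1,
		CardsSeen: [numCardsSeen]int{5, -1, -1}}, row)
	fmt.Println(len(row), row)
}
```

Writing into a caller-supplied slice mirrors the `result [][]float32` parameter of EncodeHistory, which suggests the caller owns the buffer.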
Types ¶
type LSTM ¶
type LSTM struct {
// contains filtered or unexported fields
}
LSTM is a model for AlphaCats to be used in MCTS.
func (*LSTM) MarshalBinary ¶
func (*LSTM) UnmarshalBinary ¶
type MCTSPSRO ¶
type MCTSPSRO struct {
// contains filtered or unexported fields
}
func NewMCTSPSRO ¶
func (*MCTSPSRO) AddCurrentExploiterToModel ¶
func (m *MCTSPSRO) AddCurrentExploiterToModel()
func (*MCTSPSRO) AssignWeights ¶
func (m *MCTSPSRO) AssignWeights(weights []float32)
func (*MCTSPSRO) Evaluate ¶
func (m *MCTSPSRO) Evaluate(rng *rand.Rand, node cfr.GameTreeNode, opponent mcts.Policy) ([]float32, float32)
Evaluate implements mcts.Evaluator for one-sided IS-MCTS search rollouts when this policy is being trained as the exploiter.
func (*MCTSPSRO) GetPolicies ¶
func (m *MCTSPSRO) GetPolicies() []mcts.Policy
func (*MCTSPSRO) SamplePolicy ¶
func (m *MCTSPSRO) SamplePolicy() mcts.Policy
func (*MCTSPSRO) TrainNetwork ¶
func (m *MCTSPSRO) TrainNetwork()
type PredictorPolicy ¶
type PredictorPolicy struct {
// contains filtered or unexported fields
}
PredictorPolicy is an adaptor that allows a TrainedLSTM to be used as an mcts.Policy.
func NewPredictorPolicy ¶
func NewPredictorPolicy(model *TrainedLSTM, cacheSize int) *PredictorPolicy
func (*PredictorPolicy) GetPolicy ¶
func (pp *PredictorPolicy) GetPolicy(node cfr.GameTreeNode) []float32
func (*PredictorPolicy) GobDecode ¶
func (pp *PredictorPolicy) GobDecode(buf []byte) error
func (*PredictorPolicy) GobEncode ¶
func (pp *PredictorPolicy) GobEncode() ([]byte, error)
type RefCountingTrainedLSTM ¶
type RefCountingTrainedLSTM struct {
	TrainedLSTM *TrainedLSTM
	RefCount    int64
}
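The documentation does not describe how RefCount is maintained. One plausible usage, sketched here under that assumption, is to increment the count for each user of the model and call Close once the last user releases it. The `closer` interface, `refCounted` type, and `fakeModel` below are hypothetical stand-ins, not part of the package.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// closer is a hypothetical stand-in for *TrainedLSTM, which exposes Close().
type closer interface{ Close() }

// refCounted pairs a resource with a count of its current users.
type refCounted struct {
	res      closer
	refCount int64
}

// acquire registers one more user of the resource.
func (r *refCounted) acquire() { atomic.AddInt64(&r.refCount, 1) }

// release drops one user and closes the resource when none remain.
// It reports whether the resource was closed by this call.
func (r *refCounted) release() bool {
	if atomic.AddInt64(&r.refCount, -1) == 0 {
		r.res.Close()
		return true
	}
	return false
}

// fakeModel records whether Close was called.
type fakeModel struct{ closed bool }

func (f *fakeModel) Close() { f.closed = true }

func main() {
	m := &fakeModel{}
	rc := &refCounted{res: m}
	rc.acquire()
	rc.acquire()
	fmt.Println(rc.release(), m.closed) // first release: model still open
	fmt.Println(rc.release(), m.closed) // last release: model closed
}
```

Using an atomic counter lets several goroutines share one loaded model without a mutex around the count itself.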
type Sample ¶
type Sample struct {
	InfoSet alphacats.AbstractedInfoSet
	Policy  []float32
	Value   float32
}
type TrainedLSTM ¶
type TrainedLSTM struct {
// contains filtered or unexported fields
}
func LoadTrainedLSTM ¶
func LoadTrainedLSTM(dir string, params Params) (*TrainedLSTM, error)
func (*TrainedLSTM) Close ¶
func (m *TrainedLSTM) Close()
func (*TrainedLSTM) GobDecode ¶
func (m *TrainedLSTM) GobDecode(buf []byte) error
When serializing, we just serialize the path to the model already on disk.
func (*TrainedLSTM) GobEncode ¶
func (m *TrainedLSTM) GobEncode() ([]byte, error)
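The idea of serializing only the on-disk path can be sketched with the standard `encoding/gob` interfaces. The `diskModel` type and its `dir` field are hypothetical; how TrainedLSTM actually stores its path and reloads the model on decode is not shown in this documentation.

```go
package main

import (
	"bytes"
	"encoding/gob"
	"fmt"
)

// diskModel is a hypothetical stand-in for a model whose weights live on disk.
type diskModel struct {
	dir string // unexported: gob only sees it through the methods below
}

// GobEncode writes just the directory path, not the model weights.
func (m *diskModel) GobEncode() ([]byte, error) {
	return []byte(m.dir), nil
}

// GobDecode restores the path. A real implementation would presumably
// reload the model from that path here; that step is omitted.
func (m *diskModel) GobDecode(buf []byte) error {
	m.dir = string(buf)
	return nil
}

// roundTrip encodes m with gob and decodes it into a fresh value.
func roundTrip(m *diskModel) (*diskModel, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(m); err != nil {
		return nil, err
	}
	out := &diskModel{}
	if err := gob.NewDecoder(&buf).Decode(out); err != nil {
		return nil, err
	}
	return out, nil
}

func main() {
	out, err := roundTrip(&diskModel{dir: "/tmp/model-0"})
	if err != nil {
		panic(err)
	}
	fmt.Println(out.dir)
}
```

The trade-off of this approach is that the encoded value is only usable on a machine where the referenced path still exists.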
func (*TrainedLSTM) KerasWeightsFile ¶
func (m *TrainedLSTM) KerasWeightsFile() string
func (*TrainedLSTM) Predict ¶
func (m *TrainedLSTM) Predict(is *alphacats.AbstractedInfoSet) ([]float32, float32)
type UniformRandomPolicy ¶
type UniformRandomPolicy struct{}
UniformRandomPolicy is a policy that always plays randomly. It is used to bootstrap fictitious play, as an alternative to SmoothUCT.
func (*UniformRandomPolicy) GetPolicy ¶
func (u *UniformRandomPolicy) GetPolicy(node cfr.GameTreeNode) []float32
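As an illustration of what a uniform random policy vector looks like, the sketch below assigns equal probability to each available action. The `uniformPolicy` helper is hypothetical and takes the number of actions directly; it assumes GetPolicy returns one probability per child of the node, which the documentation does not state explicitly.

```go
package main

import "fmt"

// uniformPolicy returns a probability vector that gives every one of the
// numChildren actions the same weight. It returns nil if there are no actions.
func uniformPolicy(numChildren int) []float32 {
	if numChildren <= 0 {
		return nil
	}
	p := make([]float32, numChildren)
	w := 1 / float32(numChildren)
	for i := range p {
		p[i] = w
	}
	return p
}

func main() {
	fmt.Println(uniformPolicy(4)) // prints [0.25 0.25 0.25 0.25]
}
```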
Directories ¶
Path | Synopsis |
---|---|
internal | |
npyio | Package npyio is a fork of github.com/sbinet/npyio that is hard-coded for []float32s to avoid reflection. |
tffloats | Package tffloats constructs *tf.Tensors from []float32 slices, avoiding reflection. |