trainer

package
v0.0.0-...-ebe581b
Warning

This package is not in the latest version of its module.

Published: Mar 22, 2024 License: Apache-2.0 Imports: 10 Imported by: 0

Documentation


Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type Callback

type Callback interface {
	// SaveModel persists a model.
	SaveModel(*pbCom.TrainTaskResult) error

	// StartTask starts a specific task of training or prediction
	StartTask(*pbCom.StartTaskRequest) error

	// StopTask stops a training task.
	// It is better to call it asynchronously to avoid deadlock.
	StopTask(*pbCom.StopTaskRequest) error

	// Train trains a model.
	Train(*pb.TrainRequest) (*pb.TrainResponse, error)
}

Callback contains methods that are called when training finishes, such as saving the trained model and stopping a training task. It also contains methods that are called during the evaluation phase, such as starting a specific training or prediction task and training a model. It is set into the Trainer instance during the initialization phase.

type Evaluator

type Evaluator interface {
	// Start starts model evaluation: it segments the training set according to a certain strategy (cross-validation or proportional random division),
	// then starts the training-validation process.
	// fileRows is returned by psi.IntersectParts after sample alignment.
	Start(fileRows [][]string) error

	// Stop deletes all the learners created by Evaluator as well as other objects.
	Stop()

	// SaveModel collects the results of the training in the evaluation phase, that is, the model.
	// If the model is successfully trained,
	// it will trigger the local creation of a Model instance for validation.
	SaveModel(*pbCom.TrainTaskResult) error

	// SavePredictOut collects the prediction results in the evaluation phase.
	// If the prediction result is obtained, it checks how many prediction results have been obtained so far
	// and determines whether to start calculating the average scores for each metric.
	SavePredictOut(*pbCom.PredictTaskResult) error
}

Evaluator performs model evaluation and supports cross-validation, leave-one-out (LOO), and validation by proportional random division. The basic steps of evaluation:

1. Divide the dataset in some way.
2. Train the model.
3. Validate.
4. Calculate the evaluation metric scores with the prediction result obtained on the validation set.
5. Calculate the average scores for each metric.
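The splitting and averaging steps above can be sketched as follows. This is a minimal illustration under stated assumptions, not the package's implementation: kFoldSplit and foldCollector are hypothetical names, rows are assumed to contain samples only (no header row), and every fold is assumed to report the same set of metrics. Setting k to the number of rows turns the split into leave-one-out (LOO).

```go
package main

import "fmt"

// kFoldSplit is a hypothetical helper that divides rows into k folds for
// cross-validation. For fold i, rows[start:end] is the validation set and the
// remaining rows form the training set. Assumes 1 < k <= len(rows) and that
// rows contain samples only; k == len(rows) gives leave-one-out (LOO).
func kFoldSplit(rows [][]string, k int) (train, valid [][][]string) {
	n := len(rows)
	for i := 0; i < k; i++ {
		// Integer boundaries i*n/k keep fold sizes within one row of each other.
		start, end := i*n/k, (i+1)*n/k
		tr := make([][]string, 0, n-(end-start))
		tr = append(tr, rows[:start]...)
		tr = append(tr, rows[end:]...)
		train = append(train, tr)
		valid = append(valid, rows[start:end])
	}
	return train, valid
}

// foldCollector is a hypothetical sketch of the bookkeeping described for
// SavePredictOut: it stores each fold's metric scores and, once all k folds
// have reported, returns the average score for each metric.
type foldCollector struct {
	k      int
	scores []map[string]float64
}

// add records one fold's scores. It returns the averages and true once all
// k folds are in, otherwise nil and false.
func (c *foldCollector) add(s map[string]float64) (map[string]float64, bool) {
	c.scores = append(c.scores, s)
	if len(c.scores) < c.k {
		return nil, false
	}
	avg := make(map[string]float64)
	for _, fold := range c.scores {
		for name, v := range fold {
			avg[name] += v
		}
	}
	for name := range avg {
		avg[name] /= float64(len(c.scores))
	}
	return avg, true
}

func main() {
	rows := [][]string{{"a"}, {"b"}, {"c"}, {"d"}, {"e"}}
	train, valid := kFoldSplit(rows, 2)
	fmt.Println(len(train[0]), len(valid[0])) // 3 2
	fmt.Println(len(train[1]), len(valid[1])) // 2 3

	c := &foldCollector{k: 2}
	_, done := c.add(map[string]float64{"auc": 0.5})
	fmt.Println(done) // false
	avg, done := c.add(map[string]float64{"auc": 0.75})
	fmt.Println(done, avg["auc"]) // true 0.625
}
```

Collecting scores and averaging only when the expected count is reached mirrors the "check how many results have been obtained so far" behaviour described for SavePredictOut; the real trigger condition may differ.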

type Learner

type Learner interface {
	// Advance does calculation with local data and communicates with other nodes in the cluster to train a model step by step.
	// payload is resolved by the Learner implementation of the specific algorithm.
	// It is better to call this method asynchronously to avoid blocking the main goroutine.
	Advance(payload []byte) (*pb.TrainResponse, error)
}

Learner is assigned a specific algorithm and the data used to train a model, and participates in the multi-party computation during the training process.

type LiveEvaluator

type LiveEvaluator interface {
	// Trigger triggers model evaluation.
	// The parameter contains two types of messages.
	// One is to set the learner for evaluation with training set and start it.
	// The other is to drive the learner to continue training. When the conditions are met (the pause round is reached),
	// it stops training and instantiates the model for validation.
	Trigger(*pb.LiveEvaluationTriggerMsg) error

	// Stop deletes all the learners created by LiveEvaluator as well as other objects.
	Stop()

	// SaveModel collects the training result of the evaluation phase,
	// that is, the model, for live evaluation.
	// If the model is successfully trained,
	// it will trigger the local creation of a Model instance for validation.
	SaveModel(*pbCom.TrainTaskResult) error

	// SavePredictOut collects the prediction results in the evaluation phase.
	// If the prediction result is obtained, it starts calculating metric scores,
	// then reports the results to the visualization system.
	SavePredictOut(*pbCom.PredictTaskResult) error
}

LiveEvaluator performs staged evaluation during training. The basic steps of LiveEvaluator:

1. Divide the dataset by proportional random division.
2. Initiate a learner for evaluation with the training part.
3. Train the model, pause training when the pause round is reached, and instantiate the staged model for validation; then calculate the evaluation metric scores with the prediction result obtained on the validation set.
4. Repeat Train-Pause-Validate until the stop signal is received.
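The LiveEvaluator steps above can be sketched as a seeded proportional split followed by a Train-Pause-Validate loop. This is a hedged illustration only: randomSplit, liveEvaluate, the seed, and the maxRounds bound are assumptions introduced here, and the callbacks stand in for the real learner and staged model.

```go
package main

import (
	"fmt"
	"math/rand"
)

// randomSplit is a hypothetical helper for proportional random division: it
// shuffles a copy of rows with a seeded RNG (so the split is reproducible)
// and puts the first int(ratio*len(rows)) rows in the training part.
func randomSplit(rows [][]string, ratio float64, seed int64) (train, valid [][]string) {
	shuffled := make([][]string, len(rows))
	copy(shuffled, rows)
	rng := rand.New(rand.NewSource(seed))
	rng.Shuffle(len(shuffled), func(i, j int) {
		shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
	})
	cut := int(ratio * float64(len(shuffled)))
	return shuffled[:cut], shuffled[cut:]
}

// liveEvaluate sketches the Train-Pause-Validate loop: it runs training
// rounds and, every pauseRound rounds, validates the staged model. It stops
// when stop is closed or after maxRounds rounds (maxRounds is an added safety
// bound, not part of the documented behaviour). trainStep and validate are
// hypothetical stand-ins for the learner and the staged model.
func liveEvaluate(pauseRound, maxRounds int, stop <-chan struct{},
	trainStep, validate func(round int)) {
	for round := 1; round <= maxRounds; round++ {
		select {
		case <-stop:
			return // stop signal received
		default:
		}
		trainStep(round)
		if round%pauseRound == 0 {
			validate(round) // pause: score the staged model on the validation part
		}
	}
}

func main() {
	rows := make([][]string, 10)
	for i := range rows {
		rows[i] = []string{fmt.Sprint(i)}
	}
	train, valid := randomSplit(rows, 0.8, 42)
	fmt.Println(len(train), len(valid)) // 8 2

	var validated []int
	stop := make(chan struct{})
	liveEvaluate(3, 10, stop, func(int) {}, func(r int) { validated = append(validated, r) })
	fmt.Println(validated) // [3 6 9]
}
```

Checking the stop channel at the top of each round means a stop signal takes effect before the next training step rather than interrupting one in progress.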

type RpcHandler

type RpcHandler interface {
	StepTrain(req *pb.TrainRequest, peerName string) (*pb.TrainResponse, error)

	// StepTrainWithRetry sends a training message to a remote mpc-node,
	// retrying 2 times at most.
	// inteSec indicates the interval between retry requests, in seconds.
	StepTrainWithRetry(req *pb.TrainRequest, peerName string, times int, inteSec int64) (*pb.TrainResponse, error)
}

RpcHandler performs remote procedure calls to remote cluster nodes. It is set into the Trainer instance during the initialization phase.

type TrainResponse

type TrainResponse struct {
	Resp *pb.TrainResponse
	Err  error
}

type Trainer

type Trainer struct {
	// contains filtered or unexported fields
}

Trainer manages Learners: it creates and deletes Learners, dispatches requests to different Learners by taskId, and keeps the number of Learners within a proper range to avoid high memory usage.

func NewTrainer

func NewTrainer(address string, rh RpcHandler, cb Callback, learnerLimit int) *Trainer

NewTrainer creates a Trainer instance. address indicates the local mpc-node address, learnerLimit indicates the upper limit on the number of Learners, rh is the handler for sending RPC requests, and cb holds the callback methods called when training finishes.

func (*Trainer) DeleteLearner

func (t *Trainer) DeleteLearner(req *pbCom.StopTaskRequest) error

DeleteLearner deletes a task from Memory Storage

func (*Trainer) NewLearner

func (t *Trainer) NewLearner(req *pbCom.StartTaskRequest) error

NewLearner creates a Learner instance related to TaskId and stores it in Memory Storage, keeping the number of Learners within a proper range to avoid high memory usage.

func (*Trainer) SavePredictAndEvaluatResult

func (t *Trainer) SavePredictAndEvaluatResult(result *pbCom.TrainTaskResult)

SavePredictAndEvaluatResult saves the training result and evaluation result for a Learner and stops the related task. It is called only by Evaluator.

func (*Trainer) SaveResult

func (t *Trainer) SaveResult(result *pbCom.TrainTaskResult)

SaveResult saves the training result (failed or successful status) for a Learner and stops the related task. It analyzes the TaskID to determine whether the training task is a common task from the user or a task from Evaluator. If it is a user task and the user did not ask for evaluation, it persists the prediction results locally; otherwise it calls Evaluator.Start() to start the evaluation process.

If it is a task from Evaluator, it calls Evaluator.SaveModel().

func (*Trainer) Train

func (t *Trainer) Train(req *pb.TrainRequest, resC chan *TrainResponse)

Train dispatches requests to different Learners by taskId. The result is returned through resC, which must not be nil.
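Because Train reports its result on a channel rather than returning it, a caller typically runs it in a goroutine and waits on the channel. The sketch below illustrates that pattern with a hypothetical fakeTrain in place of (*Trainer).Train and a local trainResponse whose Resp field is []byte instead of *pb.TrainResponse; the 5-second timeout is an arbitrary assumption.

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// trainResponse mirrors the documented TrainResponse struct, with []byte as a
// hypothetical stand-in for *pb.TrainResponse.
type trainResponse struct {
	Resp []byte
	Err  error
}

// fakeTrain is a hypothetical stand-in for (*Trainer).Train: it reports its
// outcome on resC instead of returning it. A send on a nil channel blocks
// forever, which is why resC must not be nil.
func fakeTrain(taskID string, resC chan *trainResponse) {
	if taskID == "" {
		resC <- &trainResponse{Err: errors.New("empty taskId")}
		return
	}
	resC <- &trainResponse{Resp: []byte("trained " + taskID)}
}

func main() {
	// Capacity 1 lets the sender finish even if the caller has already
	// given up waiting, so the goroutine does not leak.
	resC := make(chan *trainResponse, 1)
	go fakeTrain("task-1", resC)

	select {
	case res := <-resC:
		if res.Err != nil {
			fmt.Println("train failed:", res.Err)
			return
		}
		fmt.Println(string(res.Resp)) // trained task-1
	case <-time.After(5 * time.Second):
		fmt.Println("timed out waiting for Train")
	}
}
```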

func (*Trainer) Validate

func (t *Trainer) Validate(req *pb.ValidateRequest, resC chan *TrainResponse)

Validate saves the prediction results to the Evaluator or LiveEvaluator, then triggers the subsequent verification process.
