linear_reg_vl

package
v0.0.0-...-ebe581b
Warning

This package is not in the latest version of its module.

Published: Mar 22, 2024 License: Apache-2.0 Imports: 16 Imported by: 0

Documentation


Constants

This section is empty.

Variables

This section is empty.

Functions

This section is empty.

Types

type Learner

type Learner struct {
	// contains filtered or unexported fields
}

func NewLearner

func NewLearner(id string, address string, params *pbCom.TrainParams, samplesFile []byte,
	parties []string, rpc RpcHandler, rh ResultHandler, le LiveEvaluator) (*Learner, error)

NewLearner returns a VerticalLinearRegression Learner. id is the ID assigned to the Learner; address identifies the local mpc-node; params are the parameters for training the model; samplesFile contains the samples for training the model; parties are the other learners participating in MPC, usually identified by their mpc-node addresses; rpc is used to send requests to remote mpc-nodes; rh handles the final result, whether successful or failed; le is a LiveEvaluator, and the learner performs live evaluation only if le is non-nil.

func NewLearnerWithoutSamples

func NewLearnerWithoutSamples(id string, address string, params *pbCom.TrainParams,
	parties []string, rpc RpcHandler, rh ResultHandler) (*Learner, error)

NewLearnerWithoutSamples returns a VerticalLinearRegression Learner but does not run it. Unlike NewLearner, it takes neither samplesFile nor a LiveEvaluator. id is the ID assigned to the Learner; address identifies the local mpc-node; params are the parameters for training the model; parties are the other learners participating in MPC, usually identified by their mpc-node addresses; rpc is used to send requests to remote mpc-nodes; rh handles the final result, whether successful or failed.

func (*Learner) Advance

func (l *Learner) Advance(payload []byte) (*pb.TrainResponse, error)

type LiveEvaluator

type LiveEvaluator interface {
	// Trigger triggers model evaluation.
	// The parameter carries one of two kinds of message:
	// one sets up the learner for evaluation with the training set and starts it;
	// the other drives the learner to continue training. When the condition is met
	// (the pause round is reached), it stops training and instantiates the model for validation.
	Trigger(*pb.LiveEvaluationTriggerMsg) error
}

LiveEvaluator performs staged evaluation during training. Its basic steps are:

1. Randomly divide the dataset into a training part and a validation part in a given proportion.
2. Initiate a learner for evaluation with the training part.
3. Train the model; when the pause round is reached, pause training and instantiate the staged model for validation.
4. Calculate the evaluation metric scores from the predictions obtained on the validation set.
5. Repeat train-pause-validate until the stop signal is received.
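The train-pause-validate loop above can be sketched in a self-contained way. This is a minimal illustration, not the package's implementation: the sample type, splitDataset, liveEvaluate, the single-feature gradient-descent model, and the use of mean squared error as the metric are all assumptions, and the "stop signal" is modeled simply as a fixed total number of rounds.

```go
package main

import (
	"fmt"
	"math/rand"
)

// sample is a hypothetical single-feature (x, y) pair used only in this sketch.
type sample struct{ x, y float64 }

// splitDataset shuffles a copy of samples with the given seed and puts the first
// trainRatio fraction into the training part and the rest into the validation part.
func splitDataset(samples []sample, trainRatio float64, seed int64) (train, valid []sample) {
	rng := rand.New(rand.NewSource(seed))
	shuffled := append([]sample(nil), samples...)
	rng.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] })
	n := int(float64(len(shuffled)) * trainRatio)
	return shuffled[:n], shuffled[n:]
}

// mse computes the mean squared error of y = w*x + b on data.
func mse(w, b float64, data []sample) float64 {
	var sum float64
	for _, s := range data {
		d := w*s.x + b - s.y
		sum += d * d
	}
	return sum / float64(len(data))
}

// liveEvaluate trains y = w*x + b by full-batch gradient descent and, every
// pauseRound rounds, scores the staged model (w, b) on the validation part.
// Training stops after totalRounds, standing in for the stop signal.
func liveEvaluate(train, valid []sample, lr float64, pauseRound, totalRounds int) []float64 {
	var w, b float64
	var scores []float64
	for round := 1; round <= totalRounds; round++ {
		var gw, gb float64
		for _, s := range train {
			d := w*s.x + b - s.y
			gw += 2 * d * s.x
			gb += 2 * d
		}
		w -= lr * gw / float64(len(train))
		b -= lr * gb / float64(len(train))
		if round%pauseRound == 0 {
			// Pause: validate the staged model, then continue training.
			scores = append(scores, mse(w, b, valid))
		}
	}
	return scores
}

func main() {
	var data []sample
	for i := 0; i < 100; i++ {
		x := float64(i) / 100
		data = append(data, sample{x, 3*x + 1})
	}
	train, valid := splitDataset(data, 0.8, 1)
	scores := liveEvaluate(train, valid, 0.5, 50, 500)
	fmt.Println(len(train), len(valid), len(scores))
	fmt.Printf("final validation MSE: %.3g\n", scores[len(scores)-1])
}
```

Each pause yields one validation score, so a caller can watch the metric evolve across stages instead of only seeing it after training ends.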

type PSI

type PSI interface {
	// EncryptSampleIDSet encrypts the local IDs.
	EncryptSampleIDSet() ([]byte, error)

	// SetReEncryptIDSet sets the re-encrypted IDs from another party
	// and tries to calculate the final re-encrypted IDs.
	// It returns true if the calculation is done, or false if it is still waiting for other parties' parts.
	// It returns an error if anything goes wrong.
	SetReEncryptIDSet(party string, reEncIDs []byte) (bool, error)

	// ReEncryptIDSet re-encrypts the encrypted IDs received from another party.
	ReEncryptIDSet(party string, encIDs []byte) ([]byte, error)

	// SetOtherFinalReEncryptIDSet sets the final re-encrypted IDs of another party.
	SetOtherFinalReEncryptIDSet(party string, reEncIDs []byte) error

	// IntersectParts tries to calculate the intersection of all parties' samples.
	// It returns true with the final result if the calculation is done, or false if it is still waiting for other parties' samples.
	// It returns an error if anything goes wrong.
	// Call it after SetReEncryptIDSet returns true or SetOtherFinalReEncryptIDSet has finished.
	IntersectParts() (bool, [][]string, []string, error)
}

PSI (private set intersection) is used for vertical learning; the Learner initializes it at the beginning of training.
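The encrypt / re-encrypt / intersect methods suggest a commutative-encryption PSI. The scheme this package actually uses is not documented here, so the following is only a two-party sketch of one common approach (Diffie-Hellman-style exponentiation, where (h^a)^b = (h^b)^a mod p), assuming that model. The toy modulus 2^127 - 1, the fixed keys, and the helpers hashAll, encrypt, intersect, and runPSI are hypothetical and not secure parameter choices. The steps roughly mirror EncryptSampleIDSet, ReEncryptIDSet, and IntersectParts.

```go
package main

import (
	"crypto/sha256"
	"fmt"
	"math/big"
)

// p is the Mersenne prime 2^127 - 1, a toy modulus for illustration only.
var p = new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 127), big.NewInt(1))

// hashAll maps each ID to an element modulo p via SHA-256.
func hashAll(ids []string) []*big.Int {
	out := make([]*big.Int, len(ids))
	for i, id := range ids {
		sum := sha256.Sum256([]byte(id))
		out[i] = new(big.Int).Mod(new(big.Int).SetBytes(sum[:]), p)
	}
	return out
}

// encrypt raises every value to a secret exponent: E_k(h) = h^k mod p.
// Because (h^a)^b = (h^b)^a mod p, encryption by two parties commutes.
func encrypt(vals []*big.Int, key *big.Int) []*big.Int {
	out := make([]*big.Int, len(vals))
	for i, v := range vals {
		out[i] = new(big.Int).Exp(v, key, p)
	}
	return out
}

// intersect returns the IDs in ownIDs whose doubly-encrypted value also
// appears in the other party's doubly-encrypted set.
func intersect(ownIDs []string, ownDouble, otherDouble []*big.Int) []string {
	seen := make(map[string]bool, len(otherDouble))
	for _, v := range otherDouble {
		seen[v.String()] = true
	}
	var common []string
	for i, v := range ownDouble {
		if seen[v.String()] {
			common = append(common, ownIDs[i])
		}
	}
	return common
}

// runPSI simulates parties A and B and returns the intersection as seen by A.
func runPSI(idsA, idsB []string) []string {
	// Hypothetical fixed keys; real parties would draw random secret exponents.
	keyA, keyB := big.NewInt(65537), big.NewInt(1000003)
	encA := encrypt(hashAll(idsA), keyA) // A encrypts its own IDs
	encB := encrypt(hashAll(idsB), keyB) // B encrypts its own IDs
	doubleA := encrypt(encA, keyB)       // B re-encrypts A's encrypted IDs
	doubleB := encrypt(encB, keyA)       // A re-encrypts B's encrypted IDs
	return intersect(idsA, doubleA, doubleB)
}

func main() {
	fmt.Println(runPSI([]string{"u1", "u2", "u3", "u5"}, []string{"u2", "u3", "u4", "u5"}))
}
```

Neither party ever sees the other's raw IDs, only values encrypted under a key it does not hold; matching happens on the doubly-encrypted values.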

type ResultHandler

type ResultHandler interface {
	SaveResult(*pbCom.TrainTaskResult)
}

ResultHandler handles the final result, whether successful or failed. It should be called when learning has finished.
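A caller typically needs a ResultHandler that records the final result and lets other code wait for it. The sketch below is one possible in-memory implementation. Because the fields of pbCom.TrainTaskResult are not shown here, it uses a hypothetical stand-in type TrainTaskResult whose fields (TaskID, Success, ErrMsg) are assumptions, and memResultHandler is likewise illustrative.

```go
package main

import (
	"fmt"
	"sync"
)

// TrainTaskResult is a hypothetical stand-in for pbCom.TrainTaskResult;
// its fields are assumptions made only for this sketch.
type TrainTaskResult struct {
	TaskID  string
	Success bool
	ErrMsg  string
}

// ResultHandler mirrors the documented interface, using the stand-in type.
type ResultHandler interface {
	SaveResult(*TrainTaskResult)
}

// memResultHandler keeps the final result in memory and lets callers wait for it.
type memResultHandler struct {
	once   sync.Once
	done   chan struct{}
	result *TrainTaskResult
}

func newMemResultHandler() *memResultHandler {
	return &memResultHandler{done: make(chan struct{})}
}

// SaveResult records the first result it receives and ignores later calls,
// since learning is expected to report its final result once.
func (h *memResultHandler) SaveResult(r *TrainTaskResult) {
	h.once.Do(func() {
		h.result = r
		close(h.done)
	})
}

// Wait blocks until SaveResult has been called, then returns the result.
func (h *memResultHandler) Wait() *TrainTaskResult {
	<-h.done
	return h.result
}

func main() {
	h := newMemResultHandler()
	var rh ResultHandler = h
	// The learner would call SaveResult from its own goroutine when it finishes.
	go rh.SaveResult(&TrainTaskResult{TaskID: "task-1", Success: true})
	r := h.Wait()
	fmt.Println(r.TaskID, r.Success)
}
```

Closing the channel after writing the result guarantees that Wait observes the stored value, and sync.Once makes repeated calls harmless.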

type RpcHandler

type RpcHandler interface {
	StepTrain(req *pb.TrainRequest, peerName string) (*pb.TrainResponse, error)

	// StepTrainWithRetry sends a training message to the remote mpc-node,
	// retrying at most 2 times.
	// inteSec is the interval between retry requests, in seconds.
	StepTrainWithRetry(req *pb.TrainRequest, peerName string, times int, inteSec int64) (*pb.TrainResponse, error)
}

RpcHandler is used to send requests to remote mpc-nodes.
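StepTrainWithRetry can be understood as StepTrain wrapped in a retry loop with a fixed interval. The sketch below assumes that `times` is the maximum number of retries after the first attempt; the doc comment's "2 times at most" may instead be a fixed limit, so treat this as one plausible reading. TrainRequest, TrainResponse, stepFunc, stepTrainWithRetry, and errPeerUnavailable are hypothetical stand-ins, not the package's pb types or implementation.

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// TrainRequest and TrainResponse are hypothetical stand-ins for pb.TrainRequest
// and pb.TrainResponse.
type TrainRequest struct{ Payload []byte }
type TrainResponse struct{ Payload []byte }

// stepFunc has the shape of RpcHandler.StepTrain.
type stepFunc func(req *TrainRequest, peerName string) (*TrainResponse, error)

// errPeerUnavailable is a hypothetical error used to simulate a failed request.
var errPeerUnavailable = errors.New("peer unavailable")

// stepTrainWithRetry calls step once and, while it fails, retries up to `times`
// more times, sleeping inteSec seconds before each retry.
func stepTrainWithRetry(step stepFunc, req *TrainRequest, peerName string, times int, inteSec int64) (*TrainResponse, error) {
	resp, err := step(req, peerName)
	for i := 0; err != nil && i < times; i++ {
		time.Sleep(time.Duration(inteSec) * time.Second)
		resp, err = step(req, peerName)
	}
	if err != nil {
		return nil, fmt.Errorf("step train to %s failed after %d retries: %w", peerName, times, err)
	}
	return resp, nil
}

func main() {
	calls := 0
	// flaky fails twice before succeeding, simulating a temporarily unreachable peer.
	flaky := func(req *TrainRequest, peerName string) (*TrainResponse, error) {
		calls++
		if calls <= 2 {
			return nil, errPeerUnavailable
		}
		return &TrainResponse{Payload: req.Payload}, nil
	}
	resp, err := stepTrainWithRetry(flaky, &TrainRequest{Payload: []byte("step")}, "mpc-node-2", 2, 0)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(string(resp.Payload), "after", calls, "calls")
}
```

Wrapping the last error with %w keeps it inspectable via errors.Is, so a caller can still tell a transport failure apart from other errors.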
