xgboost

v0.1.4
Published: Oct 20, 2023 License: MIT Imports: 12 Imported by: 0

README

xgboost-go

xgboost-go runs XGBoost inference in Go: export a trained XGBoost model to JSON format, then load the model from that JSON file. Only DMLC XGBoost models are supported at the moment. For more information on how XGBoost inference works, refer to this Medium article.

Features

Currently, this repo supports only a few core features:

  • Read models from a JSON file (produced by the dump_model API call).
  • Support sigmoid and softmax activation transformations.
  • Support binary and multiclass predictions.
  • Support regression predictions.
  • Support missing values.
  • Support the libsvm data format.

NOTE: Results from the DMLC XGBoost model may differ slightly from this package's results because of floating-point precision.
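
To get a feel for the size of such differences, here is a small sketch that adds the same values once in float32 and once in float64. It assumes the discrepancy comes from accumulating tree outputs at different precisions, which the note above does not spell out; the leaf values and helper names are made up for the illustration.

```go
package main

import "fmt"

// sumLeaves32 accumulates the values in float32 arithmetic.
func sumLeaves32(leaves []float64) float32 {
	var total float32
	for _, v := range leaves {
		total += float32(v)
	}
	return total
}

// sumLeaves64 accumulates the same values in float64 arithmetic.
func sumLeaves64(leaves []float64) float64 {
	var total float64
	for _, v := range leaves {
		total += v
	}
	return total
}

func main() {
	// Hypothetical leaf outputs of ten trees for one input row.
	leaves := make([]float64, 10)
	for i := range leaves {
		leaves[i] = 0.1
	}
	s32 := sumLeaves32(leaves)
	s64 := sumLeaves64(leaves)
	fmt.Printf("float32 sum: %.10f\n", s32)
	fmt.Printf("float64 sum: %.10f\n", s64)
	fmt.Printf("difference:  %.3g\n", float64(s32)-s64)
}
```

When comparing predictions against the Python package, a small tolerance is therefore more appropriate than an exact equality check.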

How to use:

To use this repo, first you need to get it:

go get github.com/Elvenson/xgboost-go

Basic example:

package main

import (
	"fmt"

	xgb "github.com/Elvenson/xgboost-go"
	"github.com/Elvenson/xgboost-go/activation"
	"github.com/Elvenson/xgboost-go/mat"
)

func main() {
	ensemble, err := xgb.LoadXGBoostFromJSON("your model path",
		"", 1, 4, &activation.Logistic{})
	if err != nil {
		panic(err)
	}

	input, err := mat.ReadLibsvmFileToSparseMatrix("your libsvm input path")
	if err != nil {
		panic(err)
	}
	predictions, err := ensemble.PredictProba(input)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", predictions)
}

Here LoadXGBoostFromJSON takes 5 parameters:

  • The path to the JSON model file.
  • The path to a feature map in DMLC format; leave it blank if there is no feature map.
  • The number of classes (for binary classification, the number of classes should be 1).
  • The depth of the tree; if the tree depth is unknown, specify 0 (the model then takes slightly longer to build).
  • The activation function: currently Logistic for binary, Softmax for multiclass, and Raw for regression.
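
The sketch below shows why the number of classes matters for inference. It assumes the usual DMLC layout in which trees are dumped round by round, so tree i contributes to class i % numClasses; this is an assumption about the dump, not a description of this package's internals. The helper rawScores and the leaf outputs are hypothetical, and the base score is ignored.

```go
package main

import "fmt"

// rawScores sums per-tree leaf outputs into one raw score per class.
// Assumption: tree i belongs to class i % numClasses.
func rawScores(treeOutputs []float64, numClasses int) []float64 {
	scores := make([]float64, numClasses)
	for i, out := range treeOutputs {
		scores[i%numClasses] += out
	}
	return scores
}

func main() {
	// Hypothetical leaf outputs of 6 trees (2 boosting rounds x 3 classes) for one row.
	outputs := []float64{0.5, -0.25, -0.25, 0.25, 0.0, -0.5}

	// Multiclass: one raw score per class, to be passed through softmax.
	fmt.Println(rawScores(outputs, 3)) // prints [0.75 -0.25 -0.75]

	// With 1 class (binary or regression) every tree feeds the same score.
	fmt.Println(rawScores(outputs, 1)) // prints [-0.25]
}
```

With this layout, passing the wrong number of classes would mix scores that belong to different classes, which is why the value has to match the trained model.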

For more examples, take a look at xgbensemble_test.go or read this package's documentation.

NOTE: This repo has only been tested with the Python xgboost package, version 1.2.0.

Documentation

Overview

Package xgboost is a pure Go implementation for loading DMLC XGBoost JSON models generated by the dump_model Python API. It supports binary, multiclass, and regression inference. Note that this package is for inference only; for training, refer to https://github.com/dmlc/xgboost.

Training model

In order to have a json encoded model file, we need to train the model via Python first:

iris_xgboost.py:

import xgboost as xgb
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.datasets import dump_svmlight_file
import numpy as np

X, y = datasets.load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

dtrain = xgb.DMatrix(X_train, label=y_train)
param = {'max_depth': 4, 'eta': 1, 'objective': 'multi:softmax', 'nthread': 4,
		 'eval_metric': 'auc', 'num_class': 3}

num_round = 10
bst = xgb.train(param, dtrain, num_round)
y_pred = bst.predict(xgb.DMatrix(X_test))

# XGBClassifier infers the number of classes from y_train.
clf = xgb.XGBClassifier(max_depth=4, objective='multi:softprob', n_estimators=10)

clf.fit(X_train, y_train)
y_pred_proba = clf.predict_proba(X_test)

np.savetxt('../data/iris_xgboost_true_prediction.txt', y_pred, delimiter='\t')
np.savetxt('../data/iris_xgboost_true_prediction_proba.txt', y_pred_proba, delimiter='\t')
dump_svmlight_file(X_test, y_test, '../data/iris_test.libsvm')
bst.dump_model('../data/iris_xgboost_dump.json', dump_format='json')

Here is how to load the model exported from the above script:

package main

import (
	"fmt"

	xgb "github.com/Elvenson/xgboost-go"
	"github.com/Elvenson/xgboost-go/activation"
	"github.com/Elvenson/xgboost-go/mat"
)

func main() {
	// The iris model trained above has 3 classes and max_depth 4,
	// so load it with 3 classes and the Softmax activation.
	ensemble, err := xgb.LoadXGBoostFromJSON("your model path",
		"", 3, 4, &activation.Softmax{})
	if err != nil {
		panic(err)
	}

	// Read the test rows written by dump_svmlight_file.
	input, err := mat.ReadLibsvmFileToSparseMatrix("your libsvm input path")
	if err != nil {
		panic(err)
	}
	predictions, err := ensemble.PredictProba(input)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", predictions)
}

For more information, please take a look at xgbensemble_test.go.
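
The libsvm input used in the example is a plain-text sparse format with one row per line: a label followed by index:value pairs. The sketch below parses one such line with the standard library only. parseLibsvmLine is a hypothetical helper written for this illustration; it is not how mat.ReadLibsvmFileToSparseMatrix is implemented.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseLibsvmLine splits "label idx:value idx:value ..." into a label
// and a sparse map from feature index to value.
func parseLibsvmLine(line string) (float64, map[int]float64, error) {
	fields := strings.Fields(line)
	if len(fields) == 0 {
		return 0, nil, fmt.Errorf("empty line")
	}
	label, err := strconv.ParseFloat(fields[0], 64)
	if err != nil {
		return 0, nil, fmt.Errorf("bad label %q: %w", fields[0], err)
	}
	features := make(map[int]float64)
	for _, field := range fields[1:] {
		parts := strings.SplitN(field, ":", 2)
		if len(parts) != 2 {
			return 0, nil, fmt.Errorf("bad feature %q", field)
		}
		idx, err := strconv.Atoi(parts[0])
		if err != nil {
			return 0, nil, fmt.Errorf("bad index %q: %w", parts[0], err)
		}
		val, err := strconv.ParseFloat(parts[1], 64)
		if err != nil {
			return 0, nil, fmt.Errorf("bad value %q: %w", parts[1], err)
		}
		features[idx] = val
	}
	return label, features, nil
}

func main() {
	// Feature 2 does not appear on this line, so it is absent from the map.
	label, features, err := parseLibsvmLine("1 0:5.8 1:2.8 3:2.4")
	if err != nil {
		panic(err)
	}
	fmt.Println(label, features)
}
```

Only the features that appear on a line are stored, which is what makes the format a natural fit for sparse matrices.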

Functions

func LoadXGBoost (added in v0.1.3)

func LoadXGBoost(
	xgbEnsembleJSON []*xgboostJSON,
	featuresMapPath string,
	numClasses int,
	maxDepth int,
	activation activation.Activation) (*inference.Ensemble, error)

func LoadXGBoostFromJSON

func LoadXGBoostFromJSON(
	modelPath,
	featuresMapPath string,
	numClasses int,
	maxDepth int,
	activation activation.Activation) (*inference.Ensemble, error)

LoadXGBoostFromJSON loads an XGBoost model from a JSON file.

func LoadXGBoostFromJSONBytes (added in v0.1.3)

func LoadXGBoostFromJSONBytes(
	jsonBytes []byte,
	featuresMapPath string,
	numClasses int,
	maxDepth int,
	activation activation.Activation) (*inference.Ensemble, error)
