tensorflow_service_apis

package module
v0.1.0 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Feb 8, 2022 License: MIT Imports: 22 Imported by: 0

README

tensorflow_service_apis

tensorflow_service的grpc客户端接口封装.

使用方法

  1. 创建并初始化SDK对象,有3种方式
    1. 可以使用tensorflow_service_apis.New()方式,然后调用sdk对象的Init方法传入参数tensorflow_service_apis.SDKConfig对象初始化
    2. 构造一个tensorflow_service_apis.SDKConfig对象,调用其NewSDK()方法创建一个初始化化好的SDK对象
    3. 直接使用默认SDK对象tensorflow_service_apis.DefaultSDK,调用sdk对象的Init方法传入参数tensorflow_service_apis.SDKConfig对象初始化
  2. 调用SDK对象的Get{Session|Model|Prediction}ServiceConn方法获取对应的连接
  3. 调用SDK对象的NewCtx方法获取请求时用的ctx对象
  4. 构造指定方法的请求体
  5. 请求指定方法并获得结果

例子

此处以调用ModelServiceGetModelStatus方法为例

import (
    tfsapis "github.com/Golang-Tools/tensorflow_service_apis"
    "google.golang.org/protobuf/types/known/wrapperspb"
    log "github.com/Golang-Tools/loggerhelper"
)

func main(){
    tfsapis.DefaultSDK.Init(&tensorflow_service_apis.SDKConfig{
        //你的配置
    })
    conn,err := tfsapis.DefaultSDK.GetModelServiceConn()
    if err != nil{
        panic(err)
    }
    // 获取模型元信息
    ctx,cancel := tfsapis.DefaultSDK.NewCtx()
    defer cancel()
    res, err := conn.GetModelStatus(ctx, &tfsapis.GetModelStatusRequest{
        ModelSpec:&tfsapis.ModelSpec{
            Name:          {modelName},//模型名
            VersionChoice: &tfsapis.ModelSpec_Version{Version: wrapperspb.Int64({version})},//指定版本号
        },
    })
    if err != nil{
        panic(err)
    }
    log.Info("get model status",log.Dict{"res":res})
}

注意事项

tensorflow.serving.PredictionService/GetModelMetadata常用来查看模型的元信息
  1. 请求这个方法必须填写参数MetadataField: []string{"signature_def"}

  2. 这个方法的返回中有any类型,其对应的是tensorflow_serving.SignatureDefMap使用如下方式获取:

    import (
        "github.com/golang/protobuf/ptypes"
        apispb "github.com/Golang-Tools/tensorflow_service_apis/tensorflow_serving"
    )
    
    func main(){
        sd := apispb.SignatureDefMap{}
        err = ptypes.UnmarshalAny(meta.Metadata["signature_def"], &sd)
    }
    

开发方式

  1. 下载指定版本的tensorflow和tfserving,将其中有用的文件夹(tensorflow/core和tensorflow_serving)留下其他都删除.

  2. 执行leave_proto.py文件

  3. 使用搜索工具,查找.go文件中的"tensorflow,找到import中的内容,前面加上github.com/Golang-Tools/tensorflow_service_apis/

  4. tensorflowtensorflow_serving两个文件夹下分别添加一个同名.go文件,在其中添加同名package声明

    package tensorflow|tensorflow_serving
    

一般情况下我们只需要修改项目根目录下的4个文件

  • tensorflow_service_apis.go,sdk对象的声明,包括各种设置项的处理等
  • modelserviceconn.go,predictionserviceconn.gosessionserviceconn.go,不同service的连接对象声明

Documentation

Index

Constants

This section is empty.

Variables

View Source
var DefaultCtxOpts = CtxOptions{}
View Source
var DefaultSDK = New()

DefaultSDK 默认的sdk客户端

Functions

This section is empty.

Types

type CtxOption added in v0.1.0

type CtxOption interface {
	Apply(*CtxOptions)
}

CtxOption configures how we set up the connection.

func UntilEnd added in v0.1.0

func UntilEnd() CtxOption

UntilEnd NewCtx方法的参数,用于设置ctx为不会超时

func WithTimeout added in v0.1.0

func WithTimeout(timeout time.Duration) CtxOption

WithTimeout NewCtx方法的参数,用于设置ctx为指定的超时时长

type CtxOptions added in v0.1.0

type CtxOptions struct {
	UntilEnd bool
	Timeout  time.Duration
}

CtxOptions 设置ctx行为的选项

type ModelServiceConn

type ModelServiceConn struct {
	tfserv.ModelServiceClient
	// contains filtered or unexported fields
}

ModelServiceConn ModelServiceClient客户端的连接类

func (*ModelServiceConn) Close

func (c *ModelServiceConn) Close() error

Close 断开连接

type PredictionServiceConn

type PredictionServiceConn struct {
	tfserv.PredictionServiceClient
	// contains filtered or unexported fields
}

PredictionServiceConn PredictionServiceClient客户端的连接类

func (*PredictionServiceConn) Close

func (c *PredictionServiceConn) Close() error

Close 断开连接

type SDK added in v0.0.2

type SDK struct {
	*SDKConfig
	// contains filtered or unexported fields
}

SDK 的客户端类型

func New added in v0.0.2

func New() *SDK

New 创建客户端对象

func (*SDK) GetModelServiceConn added in v0.1.0

func (c *SDK) GetModelServiceConn() (*ModelServiceConn, error)

GetModelServiceConn 获取ModelServiceClient客户端连接

func (*SDK) GetPredictionServiceConn added in v0.1.0

func (c *SDK) GetPredictionServiceConn() (*PredictionServiceConn, error)

GetPredictionServiceConn 获取PredictionServiceClient客户端连接

func (*SDK) GetSessionServiceConn added in v0.1.0

func (c *SDK) GetSessionServiceConn() (*SessionServiceConn, error)

GetPredictionServiceConn 获取PredictionServiceClient客户端连接

func (*SDK) Init added in v0.0.2

func (c *SDK) Init(conf *SDKConfig) error

Init 初始化sdk客户端的连接信息

func (*SDK) NewCtx added in v0.0.2

func (c *SDK) NewCtx(opts ...CtxOption) (ctx context.Context, cancel context.CancelFunc)

NewCtx 创建请求的上下文

func (*SDK) NewModelServiceConn added in v0.0.2

func (c *SDK) NewModelServiceConn() (*ModelServiceConn, error)

NewModelServiceConn 建立一个新的ModelServiceClient客户端的连接类

func (*SDK) NewPredictionServiceConn added in v0.0.2

func (c *SDK) NewPredictionServiceConn() (*PredictionServiceConn, error)

NewPredictionServiceConn 建立一个新的连接

func (*SDK) NewSessionServiceConn added in v0.1.0

func (c *SDK) NewSessionServiceConn() (*SessionServiceConn, error)

NewSessionServiceConn 建立一个新的连接

type SDKConfig

type SDKConfig struct {
	Query_Addresses       []string `json:"query_addresses" jsonschema:"required,description=连接服务的主机地址"`
	Requester_App_Name    string   `json:"requester_app_name,omitempty" jsonschema:"description=请求方服务名"`
	Requester_App_Version string   `json:"requester_app_version,omitempty" jsonschema:"description=请求方服务版本"`

	// 性能设置
	Initial_Window_Size                         int  `json:"initial_window_size,omitempty" jsonschema:"description=基于Stream的滑动窗口大小"`
	Initial_Conn_Window_Size                    int  `json:"initial_conn_window_size,omitempty" jsonschema:"description=基于Connection的滑动窗口大小"`
	Keepalive_Time                              int  `json:"keepalive_time,omitempty" jsonschema:"description=空闲连接每隔n秒ping一次客户端已确保连接存活"`
	Keepalive_Timeout                           int  `json:"keepalive_timeout,omitempty" jsonschema:"description=ping时长超过n则认为连接已死"`
	Keepalive_Enforcement_Permit_Without_Stream bool `` /* 134-byte string literal not displayed */
	Conn_With_Block                             bool `json:"conn_with_block,omitempty" jsonschema:"description=同步的连接建立"`
	Max_Recv_Msg_Size                           int  `json:"max_rec_msg_size,omitempty" jsonschema:"description=允许接收的最大消息长度"`
	Max_Send_Msg_Size                           int  `json:"max_send_msg_size,omitempty" jsonschema:"description=允许发送的最大消息长度"`

	//压缩设置,目前只支持gzip
	Compression string `json:"compression,omitempty" jsonschema:"description=使用哪种方式压缩发送的消息,enum=gzip"`

	// TLS设置
	Ca_Cert_Path     string `json:"ca_cert_path,omitempty" jsonschema:"description=如果要使用tls则需要指定根证书位置"`
	Client_Cert_Path string `json:"client_cert_path,omitempty" jsonschema:"description=客户端整数位置"`
	Client_Key_Path  string `json:"client_key_path,omitempty" jsonschema:"description=客户端证书对应的私钥位置"`

	// XDS设置
	XDS_CREDS bool `json:"xds_creds,omitempty" jsonschema:"description=当address的schema是xds时是否使用xds的令牌加密访问"`

	// 请求超时设置
	Query_Timeout int `json:"query_timeout,omitempty" jsonschema:"description=请求服务的最大超时时间单位ms"`
}

SDKConfig 的客户端类型

func (*SDKConfig) NewSDK added in v0.0.2

func (c *SDKConfig) NewSDK() *SDK

NewSDK 创建客户端对象

type SessionServiceConn added in v0.1.0

type SessionServiceConn struct {
	tfserv.SessionServiceClient
	// contains filtered or unexported fields
}

SessionServiceConn 客户端类

func (*SessionServiceConn) Close added in v0.1.0

func (c *SessionServiceConn) Close() error

Close 断开连接

Directories

Path Synopsis

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL