tnn

module
v1.0.1 Latest Latest
Warning

This package is not in the latest version of its module.

Go to latest
Published: Jun 17, 2023 License: MIT

README

tnn

go版本神经网络框架,支持模型训练和推理,libgotorch库的安装方式请看libgotorch安装

工具

  • minfo: 这是tnn框架中的一个工具,用于查看保存模型的定义信息

示例

  • xor: 该示例是一个四层的神经网络,用于训练异或运算(回归)
  • sin: 该示例使用rnn网络来进行sin函数的时序任务训练(回归)
  • sin_attention: 使用transformer来实现sin曲线的预测(回归)
  • couplet: 使用GPT模型来对对联(回归)

构造网络

首先定义网络的每个层

hiddenLayer := layer.NewDense(10)
relu := activation.NewReLU()
outputLayer := layer.NewDense(1)

选定一个优化器

optimizer := optimizer.NewAdam()

定义forward函数

func forward(x *tensor.Tensor) *tensor.Tensor {
    y := hiddenLayer.Forward(x)
    y = relu.Forward(y)
    y = outputLayer.Forward(y)
    return y
}

最后构造出模型并进行模型训练

for i := 0; i < 10; i++ {
    loss := loss.NewMse(forward(input), output)
    loss.Backward()
    optimizer.Step(net.Params())
    fmt.Printf("Epoch: %d, Loss: %.05f\n", i, loss.Value())
}

模型推理方法如下

pred := forward(input)
fmt.Println(pred.Float32Value())

感谢

Directories

Path Synopsis
example
sin
xor
internal
pb
nn
net

Jump to

Keyboard shortcuts

? : This menu
/ : Search site
f or F : Jump to
y or Y : Canonical URL