|
@@ -19,7 +19,6 @@ import (
|
|
"github.com/lf-edge/ekuiper/pkg/api"
|
|
"github.com/lf-edge/ekuiper/pkg/api"
|
|
"github.com/lf-edge/ekuiper/pkg/cast"
|
|
"github.com/lf-edge/ekuiper/pkg/cast"
|
|
"github.com/mattn/go-tflite"
|
|
"github.com/mattn/go-tflite"
|
|
- "strconv"
|
|
|
|
)
|
|
)
|
|
|
|
|
|
type Tffunc struct {
|
|
type Tffunc struct {
|
|
@@ -57,12 +56,6 @@ func (f *Tffunc) Exec(args []interface{}, ctx api.FunctionContext) (interface{},
|
|
// Set input tensors
|
|
// Set input tensors
|
|
for i := 1; i < len(args); i++ {
|
|
for i := 1; i < len(args); i++ {
|
|
input := interpreter.GetInputTensor(i - 1)
|
|
input := interpreter.GetInputTensor(i - 1)
|
|
- dims := "("
|
|
|
|
- for j := 0; j < input.NumDims(); j++ {
|
|
|
|
- dims += strconv.Itoa(input.Dim(j)) + ","
|
|
|
|
- }
|
|
|
|
- dims += ")"
|
|
|
|
- ctx.GetLogger().Debugf("tensorflow function %s input %d shape %s", model, i, dims)
|
|
|
|
var arg []interface{}
|
|
var arg []interface{}
|
|
switch v := args[i].(type) {
|
|
switch v := args[i].(type) {
|
|
case []byte:
|
|
case []byte:
|
|
@@ -71,156 +64,82 @@ func (f *Tffunc) Exec(args []interface{}, ctx api.FunctionContext) (interface{},
|
|
}
|
|
}
|
|
input.CopyFromBuffer(v)
|
|
input.CopyFromBuffer(v)
|
|
continue
|
|
continue
|
|
- case []interface{}:
|
|
|
|
|
|
+ case []interface{}: // only supports one dimensional arg. Even dim 0 must be an array of 1 element
|
|
arg = v
|
|
arg = v
|
|
default:
|
|
default:
|
|
return fmt.Errorf("tensorflow function parameter %d must be a bytea or array of bytea, but got %[1]T(%[1]v)", i), false
|
|
return fmt.Errorf("tensorflow function parameter %d must be a bytea or array of bytea, but got %[1]T(%[1]v)", i), false
|
|
}
|
|
}
|
|
- t := input.Type()
|
|
|
|
- switch input.NumDims() {
|
|
|
|
- case 0, 1:
|
|
|
|
- return fmt.Errorf("tensorflow function input tensor %d must have at least 2 dimensions but got 1", i-1), false
|
|
|
|
- case 2:
|
|
|
|
- if input.Dim(1) != len(arg) {
|
|
|
|
- return fmt.Errorf("tensorflow function input tensor %d must have %d elements but got %d", i-1, input.Dim(1), len(arg)), false
|
|
|
|
|
|
+ paraLen := 1
|
|
|
|
+ for j := 0; j < input.NumDims(); j++ {
|
|
|
|
+ paraLen = paraLen * input.Dim(j)
|
|
|
|
+ }
|
|
|
|
+ ctx.GetLogger().Debugf("receive tensor %v, require %d length", arg, paraLen)
|
|
|
|
+ if paraLen != len(arg) {
|
|
|
|
+ return fmt.Errorf("tensorflow function input tensor %d must have %d elements but got %d", i-1, paraLen, len(arg)), false
|
|
|
|
+ }
|
|
|
|
+ switch input.Type() {
|
|
|
|
+ case tflite.Float32:
|
|
|
|
+ v, err := cast.ToFloat32Slice(arg, cast.CONVERT_SAMEKIND)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return fmt.Errorf("invalid %d parameter, expect float32 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
}
|
|
}
|
|
- switch t {
|
|
|
|
- case tflite.Float32:
|
|
|
|
- v, err := cast.ToFloat32Slice(arg, cast.CONVERT_SAMEKIND)
|
|
|
|
- if err != nil {
|
|
|
|
- return fmt.Errorf("invalid %d parameter, expect float32 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
- }
|
|
|
|
- err = input.SetFloat32s(v)
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, false
|
|
|
|
- }
|
|
|
|
- case tflite.Int64:
|
|
|
|
- v, err := cast.ToInt64Slice(arg, cast.CONVERT_SAMEKIND)
|
|
|
|
- if err != nil {
|
|
|
|
- return fmt.Errorf("invalid %d parameter, expect int64 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
- }
|
|
|
|
- err = input.SetInt64s(v)
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, false
|
|
|
|
- }
|
|
|
|
- case tflite.Int32:
|
|
|
|
- v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
|
|
|
|
- return cast.ToInt32(input, sn)
|
|
|
|
- }, "int32", cast.CONVERT_SAMEKIND)
|
|
|
|
- if err != nil {
|
|
|
|
- return fmt.Errorf("invalid %d parameter, expect int32 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
- }
|
|
|
|
- err = input.SetInt32s(v.([]int32))
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, false
|
|
|
|
- }
|
|
|
|
- case tflite.Int16:
|
|
|
|
- v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
|
|
|
|
- return cast.ToInt16(input, sn)
|
|
|
|
- }, "int16", cast.CONVERT_SAMEKIND)
|
|
|
|
- if err != nil {
|
|
|
|
- return fmt.Errorf("invalid %d parameter, expect int16 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
- }
|
|
|
|
- err = input.SetInt16s(v.([]int16))
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, false
|
|
|
|
- }
|
|
|
|
- case tflite.Int8:
|
|
|
|
- v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
|
|
|
|
- return cast.ToInt8(input, sn)
|
|
|
|
- }, "int8", cast.CONVERT_SAMEKIND)
|
|
|
|
- if err != nil {
|
|
|
|
- return fmt.Errorf("invalid %d parameter, expect int8 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
- }
|
|
|
|
- err = input.SetInt8s(v.([]int8))
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, false
|
|
|
|
- }
|
|
|
|
- case tflite.UInt8:
|
|
|
|
- v, err := cast.ToBytes(args, cast.CONVERT_SAMEKIND)
|
|
|
|
- if err != nil {
|
|
|
|
- return fmt.Errorf("invalid %d parameter, expect uint8 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
- }
|
|
|
|
- err = input.SetUint8s(v)
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, false
|
|
|
|
- }
|
|
|
|
- default:
|
|
|
|
- return fmt.Errorf("invalid %d parameter, unsupported type %v in the model", i, t), false
|
|
|
|
|
|
+ err = input.SetFloat32s(v)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, false
|
|
}
|
|
}
|
|
- default:
|
|
|
|
- // support multiple dimensions. Here assume user passes a 1D array.
|
|
|
|
- var paraLen int = 1
|
|
|
|
- for j := 1; j < input.NumDims(); j++ {
|
|
|
|
- paraLen = paraLen * input.Dim(j)
|
|
|
|
|
|
+ case tflite.Int64:
|
|
|
|
+ v, err := cast.ToInt64Slice(arg, cast.CONVERT_SAMEKIND)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return fmt.Errorf("invalid %d parameter, expect int64 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
+ }
|
|
|
|
+ err = input.SetInt64s(v)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, false
|
|
|
|
+ }
|
|
|
|
+ case tflite.Int32:
|
|
|
|
+ v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
|
|
|
|
+ return cast.ToInt32(input, sn)
|
|
|
|
+ }, "int32", cast.CONVERT_SAMEKIND)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return fmt.Errorf("invalid %d parameter, expect int32 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
+ }
|
|
|
|
+ err = input.SetInt32s(v.([]int32))
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, false
|
|
|
|
+ }
|
|
|
|
+ case tflite.Int16:
|
|
|
|
+ v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
|
|
|
|
+ return cast.ToInt16(input, sn)
|
|
|
|
+ }, "int16", cast.CONVERT_SAMEKIND)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return fmt.Errorf("invalid %d parameter, expect int16 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
+ }
|
|
|
|
+ err = input.SetInt16s(v.([]int16))
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, false
|
|
}
|
|
}
|
|
- if paraLen != len(arg) {
|
|
|
|
- return fmt.Errorf("tensorflow function input tensor %d must have %d elements but got %d", i-1, paraLen, len(arg)), false
|
|
|
|
|
|
+ case tflite.Int8:
|
|
|
|
+ v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
|
|
|
|
+ return cast.ToInt8(input, sn)
|
|
|
|
+ }, "int8", cast.CONVERT_SAMEKIND)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return fmt.Errorf("invalid %d parameter, expect int8 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
+ }
|
|
|
|
+ err = input.SetInt8s(v.([]int8))
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, false
|
|
}
|
|
}
|
|
- switch t {
|
|
|
|
- case tflite.Float32:
|
|
|
|
- v, err := cast.ToFloat32Slice(args[i], cast.CONVERT_SAMEKIND)
|
|
|
|
- if err != nil {
|
|
|
|
- return fmt.Errorf("invalid %d parameter, expect float32 but got %[2]T(%[2]v)", i, args[i]), false
|
|
|
|
- }
|
|
|
|
- err = input.SetFloat32s(v)
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, false
|
|
|
|
- }
|
|
|
|
- case tflite.Int64:
|
|
|
|
- v, err := cast.ToInt64Slice(arg, cast.CONVERT_SAMEKIND)
|
|
|
|
- if err != nil {
|
|
|
|
- return fmt.Errorf("invalid %d parameter, expect int64 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
- }
|
|
|
|
- err = input.SetInt64s(v)
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, false
|
|
|
|
- }
|
|
|
|
- case tflite.Int32:
|
|
|
|
- v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
|
|
|
|
- return cast.ToInt32(input, sn)
|
|
|
|
- }, "int32", cast.CONVERT_SAMEKIND)
|
|
|
|
- if err != nil {
|
|
|
|
- return fmt.Errorf("invalid %d parameter, expect int32 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
- }
|
|
|
|
- err = input.SetInt32s(v.([]int32))
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, false
|
|
|
|
- }
|
|
|
|
- case tflite.Int16:
|
|
|
|
- v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
|
|
|
|
- return cast.ToInt16(input, sn)
|
|
|
|
- }, "int16", cast.CONVERT_SAMEKIND)
|
|
|
|
- if err != nil {
|
|
|
|
- return fmt.Errorf("invalid %d parameter, expect int16 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
- }
|
|
|
|
- err = input.SetInt16s(v.([]int16))
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, false
|
|
|
|
- }
|
|
|
|
- case tflite.Int8:
|
|
|
|
- v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
|
|
|
|
- return cast.ToInt8(input, sn)
|
|
|
|
- }, "int8", cast.CONVERT_SAMEKIND)
|
|
|
|
- if err != nil {
|
|
|
|
- return fmt.Errorf("invalid %d parameter, expect int8 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
- }
|
|
|
|
- err = input.SetInt8s(v.([]int8))
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, false
|
|
|
|
- }
|
|
|
|
- case tflite.UInt8:
|
|
|
|
- v, err := cast.ToBytes(args, cast.CONVERT_SAMEKIND)
|
|
|
|
- if err != nil {
|
|
|
|
- return fmt.Errorf("invalid %d parameter, expect uint8 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
|
|
- }
|
|
|
|
- err = input.SetUint8s(v)
|
|
|
|
- if err != nil {
|
|
|
|
- return nil, false
|
|
|
|
- }
|
|
|
|
- default:
|
|
|
|
- return fmt.Errorf("invalid %d parameter, unsupported type %v in the model", i, t), false
|
|
|
|
|
|
+ case tflite.UInt8:
|
|
|
|
+ v, err := cast.ToBytes(args, cast.CONVERT_SAMEKIND)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return fmt.Errorf("invalid %d parameter, expect uint8 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
|
|
}
|
|
}
|
|
|
|
+ err = input.SetUint8s(v)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, false
|
|
|
|
+ }
|
|
|
|
+ default:
|
|
|
|
+ return fmt.Errorf("invalid %d parameter, unsupported type %v in the model", i, input.Type()), false
|
|
}
|
|
}
|
|
}
|
|
}
|
|
status := interpreter.Invoke()
|
|
status := interpreter.Invoke()
|