Преглед изворни кода

opt(tflite): support all dimensions

Signed-off-by: Jiyong Huang <huangjy@emqx.io>
Jiyong Huang пре 2 година
родитељ
комит
102808833a
1 измењених фајлова са 67 додато и 148 уклоњено
  1. 67 148
      extensions/functions/tfLite/tfLite.go

+ 67 - 148
extensions/functions/tfLite/tfLite.go

@@ -19,7 +19,6 @@ import (
 	"github.com/lf-edge/ekuiper/pkg/api"
 	"github.com/lf-edge/ekuiper/pkg/api"
 	"github.com/lf-edge/ekuiper/pkg/cast"
 	"github.com/lf-edge/ekuiper/pkg/cast"
 	"github.com/mattn/go-tflite"
 	"github.com/mattn/go-tflite"
-	"strconv"
 )
 )
 
 
 type Tffunc struct {
 type Tffunc struct {
@@ -57,12 +56,6 @@ func (f *Tffunc) Exec(args []interface{}, ctx api.FunctionContext) (interface{},
 	// Set input tensors
 	// Set input tensors
 	for i := 1; i < len(args); i++ {
 	for i := 1; i < len(args); i++ {
 		input := interpreter.GetInputTensor(i - 1)
 		input := interpreter.GetInputTensor(i - 1)
-		dims := "("
-		for j := 0; j < input.NumDims(); j++ {
-			dims += strconv.Itoa(input.Dim(j)) + ","
-		}
-		dims += ")"
-		ctx.GetLogger().Debugf("tensorflow function %s input %d shape %s", model, i, dims)
 		var arg []interface{}
 		var arg []interface{}
 		switch v := args[i].(type) {
 		switch v := args[i].(type) {
 		case []byte:
 		case []byte:
@@ -71,156 +64,82 @@ func (f *Tffunc) Exec(args []interface{}, ctx api.FunctionContext) (interface{},
 			}
 			}
 			input.CopyFromBuffer(v)
 			input.CopyFromBuffer(v)
 			continue
 			continue
-		case []interface{}:
+		case []interface{}: // only supports one dimensional arg. Even dim 0 must be an array of 1 element
 			arg = v
 			arg = v
 		default:
 		default:
 			return fmt.Errorf("tensorflow function parameter %d must be a bytea or array of bytea, but got %[1]T(%[1]v)", i), false
 			return fmt.Errorf("tensorflow function parameter %d must be a bytea or array of bytea, but got %[1]T(%[1]v)", i), false
 		}
 		}
-		t := input.Type()
-		switch input.NumDims() {
-		case 0, 1:
-			return fmt.Errorf("tensorflow function input tensor %d must have at least 2 dimensions but got 1", i-1), false
-		case 2:
-			if input.Dim(1) != len(arg) {
-				return fmt.Errorf("tensorflow function input tensor %d must have %d elements but got %d", i-1, input.Dim(1), len(arg)), false
+		paraLen := 1
+		for j := 0; j < input.NumDims(); j++ {
+			paraLen = paraLen * input.Dim(j)
+		}
+		ctx.GetLogger().Debugf("receive tensor %v, require %d length", arg, paraLen)
+		if paraLen != len(arg) {
+			return fmt.Errorf("tensorflow function input tensor %d must have %d elements but got %d", i-1, paraLen, len(arg)), false
+		}
+		switch input.Type() {
+		case tflite.Float32:
+			v, err := cast.ToFloat32Slice(arg, cast.CONVERT_SAMEKIND)
+			if err != nil {
+				return fmt.Errorf("invalid %d parameter, expect float32 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
 			}
 			}
-			switch t {
-			case tflite.Float32:
-				v, err := cast.ToFloat32Slice(arg, cast.CONVERT_SAMEKIND)
-				if err != nil {
-					return fmt.Errorf("invalid %d parameter, expect float32 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
-				}
-				err = input.SetFloat32s(v)
-				if err != nil {
-					return nil, false
-				}
-			case tflite.Int64:
-				v, err := cast.ToInt64Slice(arg, cast.CONVERT_SAMEKIND)
-				if err != nil {
-					return fmt.Errorf("invalid %d parameter, expect int64 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
-				}
-				err = input.SetInt64s(v)
-				if err != nil {
-					return nil, false
-				}
-			case tflite.Int32:
-				v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
-					return cast.ToInt32(input, sn)
-				}, "int32", cast.CONVERT_SAMEKIND)
-				if err != nil {
-					return fmt.Errorf("invalid %d parameter, expect int32 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
-				}
-				err = input.SetInt32s(v.([]int32))
-				if err != nil {
-					return nil, false
-				}
-			case tflite.Int16:
-				v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
-					return cast.ToInt16(input, sn)
-				}, "int16", cast.CONVERT_SAMEKIND)
-				if err != nil {
-					return fmt.Errorf("invalid %d parameter, expect int16 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
-				}
-				err = input.SetInt16s(v.([]int16))
-				if err != nil {
-					return nil, false
-				}
-			case tflite.Int8:
-				v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
-					return cast.ToInt8(input, sn)
-				}, "int8", cast.CONVERT_SAMEKIND)
-				if err != nil {
-					return fmt.Errorf("invalid %d parameter, expect int8 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
-				}
-				err = input.SetInt8s(v.([]int8))
-				if err != nil {
-					return nil, false
-				}
-			case tflite.UInt8:
-				v, err := cast.ToBytes(args, cast.CONVERT_SAMEKIND)
-				if err != nil {
-					return fmt.Errorf("invalid %d parameter, expect uint8 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
-				}
-				err = input.SetUint8s(v)
-				if err != nil {
-					return nil, false
-				}
-			default:
-				return fmt.Errorf("invalid %d parameter, unsupported type %v in the model", i, t), false
+			err = input.SetFloat32s(v)
+			if err != nil {
+				return nil, false
 			}
 			}
-		default:
-			// support multiple dimensions. Here assume user passes a 1D array.
-			var paraLen int = 1
-			for j := 1; j < input.NumDims(); j++ {
-				paraLen = paraLen * input.Dim(j)
+		case tflite.Int64:
+			v, err := cast.ToInt64Slice(arg, cast.CONVERT_SAMEKIND)
+			if err != nil {
+				return fmt.Errorf("invalid %d parameter, expect int64 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
+			}
+			err = input.SetInt64s(v)
+			if err != nil {
+				return nil, false
+			}
+		case tflite.Int32:
+			v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
+				return cast.ToInt32(input, sn)
+			}, "int32", cast.CONVERT_SAMEKIND)
+			if err != nil {
+				return fmt.Errorf("invalid %d parameter, expect int32 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
+			}
+			err = input.SetInt32s(v.([]int32))
+			if err != nil {
+				return nil, false
+			}
+		case tflite.Int16:
+			v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
+				return cast.ToInt16(input, sn)
+			}, "int16", cast.CONVERT_SAMEKIND)
+			if err != nil {
+				return fmt.Errorf("invalid %d parameter, expect int16 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
+			}
+			err = input.SetInt16s(v.([]int16))
+			if err != nil {
+				return nil, false
 			}
 			}
-			if paraLen != len(arg) {
-				return fmt.Errorf("tensorflow function input tensor %d must have %d elements but got %d", i-1, paraLen, len(arg)), false
+		case tflite.Int8:
+			v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
+				return cast.ToInt8(input, sn)
+			}, "int8", cast.CONVERT_SAMEKIND)
+			if err != nil {
+				return fmt.Errorf("invalid %d parameter, expect int8 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
+			}
+			err = input.SetInt8s(v.([]int8))
+			if err != nil {
+				return nil, false
 			}
 			}
-			switch t {
-			case tflite.Float32:
-				v, err := cast.ToFloat32Slice(args[i], cast.CONVERT_SAMEKIND)
-				if err != nil {
-					return fmt.Errorf("invalid %d parameter, expect float32 but got %[2]T(%[2]v)", i, args[i]), false
-				}
-				err = input.SetFloat32s(v)
-				if err != nil {
-					return nil, false
-				}
-			case tflite.Int64:
-				v, err := cast.ToInt64Slice(arg, cast.CONVERT_SAMEKIND)
-				if err != nil {
-					return fmt.Errorf("invalid %d parameter, expect int64 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
-				}
-				err = input.SetInt64s(v)
-				if err != nil {
-					return nil, false
-				}
-			case tflite.Int32:
-				v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
-					return cast.ToInt32(input, sn)
-				}, "int32", cast.CONVERT_SAMEKIND)
-				if err != nil {
-					return fmt.Errorf("invalid %d parameter, expect int32 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
-				}
-				err = input.SetInt32s(v.([]int32))
-				if err != nil {
-					return nil, false
-				}
-			case tflite.Int16:
-				v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
-					return cast.ToInt16(input, sn)
-				}, "int16", cast.CONVERT_SAMEKIND)
-				if err != nil {
-					return fmt.Errorf("invalid %d parameter, expect int16 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
-				}
-				err = input.SetInt16s(v.([]int16))
-				if err != nil {
-					return nil, false
-				}
-			case tflite.Int8:
-				v, err := cast.ToTypedSlice(args, func(input interface{}, sn cast.Strictness) (interface{}, error) {
-					return cast.ToInt8(input, sn)
-				}, "int8", cast.CONVERT_SAMEKIND)
-				if err != nil {
-					return fmt.Errorf("invalid %d parameter, expect int8 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
-				}
-				err = input.SetInt8s(v.([]int8))
-				if err != nil {
-					return nil, false
-				}
-			case tflite.UInt8:
-				v, err := cast.ToBytes(args, cast.CONVERT_SAMEKIND)
-				if err != nil {
-					return fmt.Errorf("invalid %d parameter, expect uint8 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
-				}
-				err = input.SetUint8s(v)
-				if err != nil {
-					return nil, false
-				}
-			default:
-				return fmt.Errorf("invalid %d parameter, unsupported type %v in the model", i, t), false
+		case tflite.UInt8:
+			v, err := cast.ToBytes(args, cast.CONVERT_SAMEKIND)
+			if err != nil {
+				return fmt.Errorf("invalid %d parameter, expect uint8 but got %[2]T(%[2]v) with err %v", i, args[i], err), false
 			}
 			}
+			err = input.SetUint8s(v)
+			if err != nil {
+				return nil, false
+			}
+		default:
+			return fmt.Errorf("invalid %d parameter, unsupported type %v in the model", i, input.Type()), false
 		}
 		}
 	}
 	}
 	status := interpreter.Invoke()
 	status := interpreter.Invoke()