Просмотр исходного кода

fix(func): avg/max/min should support int64 inputs

Signed-off-by: Jiyong Huang <huangjy@emqx.io>
Jiyong Huang 2 лет назад
Родитель
Сommit
15627dfbd5

+ 21 - 17
internal/binder/function/funcs_agg.go

@@ -18,6 +18,7 @@ import (
 	"fmt"
 	"github.com/lf-edge/ekuiper/pkg/api"
 	"github.com/lf-edge/ekuiper/pkg/ast"
+	"github.com/lf-edge/ekuiper/pkg/cast"
 )
 
 func registerAggFunc() {
@@ -33,7 +34,7 @@ func registerAggFunc() {
 					if r, err := sliceIntTotal(arg0); err != nil {
 						return err, false
 					} else {
-						return r / c, true
+						return r / int64(c), true
 					}
 				case float64:
 					if r, err := sliceFloatTotal(arg0); err != nil {
@@ -67,13 +68,13 @@ func registerAggFunc() {
 				v := getFirstValidArg(arg0)
 				switch t := v.(type) {
 				case int:
-					if r, err := sliceIntMax(arg0, t); err != nil {
+					if r, err := sliceIntMax(arg0, int64(t)); err != nil {
 						return err, false
 					} else {
 						return r, true
 					}
 				case int64:
-					if r, err := sliceIntMax(arg0, int(t)); err != nil {
+					if r, err := sliceIntMax(arg0, t); err != nil {
 						return err, false
 					} else {
 						return r, true
@@ -108,13 +109,13 @@ func registerAggFunc() {
 				v := getFirstValidArg(arg0)
 				switch t := v.(type) {
 				case int:
-					if r, err := sliceIntMin(arg0, t); err != nil {
+					if r, err := sliceIntMin(arg0, int64(t)); err != nil {
 						return err, false
 					} else {
 						return r, true
 					}
 				case int64:
-					if r, err := sliceIntMin(arg0, int(t)); err != nil {
+					if r, err := sliceIntMin(arg0, t); err != nil {
 						return err, false
 					} else {
 						return r, true
@@ -227,10 +228,11 @@ func getFirstValidArg(s []interface{}) interface{} {
 	return nil
 }
 
-func sliceIntTotal(s []interface{}) (int, error) {
-	var total int
+func sliceIntTotal(s []interface{}) (int64, error) {
+	var total int64
 	for _, v := range s {
-		if vi, ok := v.(int); ok {
+		vi, err := cast.ToInt64(v, cast.CONVERT_SAMEKIND)
+		if err == nil {
 			total += vi
 		} else if v != nil {
 			return 0, fmt.Errorf("requires int but found %[1]T(%[1]v)", v)
@@ -250,14 +252,15 @@ func sliceFloatTotal(s []interface{}) (float64, error) {
 	}
 	return total, nil
 }
-func sliceIntMax(s []interface{}, max int) (int, error) {
+func sliceIntMax(s []interface{}, max int64) (int64, error) {
 	for _, v := range s {
-		if vi, ok := v.(int); ok {
-			if max < vi {
+		vi, err := cast.ToInt64(v, cast.CONVERT_SAMEKIND)
+		if err == nil {
+			if vi > max {
 				max = vi
 			}
 		} else if v != nil {
-			return 0, fmt.Errorf("requires int but found %[1]T(%[1]v)", v)
+			return 0, fmt.Errorf("requires int64 but found %[1]T(%[1]v)", v)
 		}
 	}
 	return max, nil
@@ -287,14 +290,15 @@ func sliceStringMax(s []interface{}, max string) (string, error) {
 	}
 	return max, nil
 }
-func sliceIntMin(s []interface{}, min int) (int, error) {
+func sliceIntMin(s []interface{}, min int64) (int64, error) {
 	for _, v := range s {
-		if vi, ok := v.(int); ok {
-			if min > vi {
+		vi, err := cast.ToInt64(v, cast.CONVERT_SAMEKIND)
+		if err == nil {
+			if vi < min {
 				min = vi
 			}
 		} else if v != nil {
-			return 0, fmt.Errorf("requires int but found %[1]T(%[1]v)", v)
+			return 0, fmt.Errorf("requires int64 but found %[1]T(%[1]v)", v)
 		}
 	}
 	return min, nil
@@ -315,7 +319,7 @@ func sliceFloatMin(s []interface{}, min float64) (float64, error) {
 func sliceStringMin(s []interface{}, min string) (string, error) {
 	for _, v := range s {
 		if vs, ok := v.(string); ok {
-			if min < vs {
+			if vs < min {
 				min = vs
 			}
 		} else if v != nil {

+ 108 - 0
internal/binder/function/funcs_agg_test.go

@@ -0,0 +1,108 @@
+// Copyright 2022 EMQ Technologies Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package function
+
+import (
+	"fmt"
+	"github.com/lf-edge/ekuiper/internal/conf"
+	kctx "github.com/lf-edge/ekuiper/internal/topo/context"
+	"github.com/lf-edge/ekuiper/internal/topo/state"
+	"github.com/lf-edge/ekuiper/pkg/api"
+	"reflect"
+	"testing"
+)
+
+func TestAggExec(t *testing.T) {
+	fAvg, ok := builtins["avg"]
+	if !ok {
+		t.Fatal("builtin not found")
+	}
+	fMax, ok := builtins["max"]
+	if !ok {
+		t.Fatal("builtin not found")
+	}
+	fMin, ok := builtins["min"]
+	if !ok {
+		t.Fatal("builtin not found")
+	}
+	contextLogger := conf.Log.WithField("rule", "testExec")
+	ctx := kctx.WithValue(kctx.Background(), kctx.LoggerKey, contextLogger)
+	tempStore, _ := state.CreateStore("mockRule0", api.AtMostOnce)
+	fctx := kctx.NewDefaultFuncContext(ctx.WithMeta("mockRule0", "test", tempStore), 2)
+	var tests = []struct {
+		args []interface{}
+		avg  interface{}
+		max  interface{}
+		min  interface{}
+	}{
+		{ // 0
+			args: []interface{}{
+				[]interface{}{
+					"foo",
+					"bar",
+					"self",
+				},
+			},
+			avg: fmt.Errorf("run avg function error: found invalid arg string(foo)"),
+			max: "self",
+			min: "bar",
+		}, { // 1
+			args: []interface{}{
+				[]interface{}{
+					int64(100),
+					int64(150),
+					int64(200),
+				},
+			},
+			avg: int64(150),
+			max: int64(200),
+			min: int64(100),
+		}, { // 2
+			args: []interface{}{
+				[]interface{}{
+					float64(100),
+					float64(150),
+					float64(200),
+				},
+			},
+			avg: float64(150),
+			max: float64(200),
+			min: float64(100),
+		}, { // 3
+			args: []interface{}{
+				[]interface{}{
+					100, 150, 200,
+				},
+			},
+			avg: int64(150),
+			max: int64(200),
+			min: int64(100),
+		},
+	}
+	for i, tt := range tests {
+		rAvg, _ := fAvg.exec(fctx, tt.args)
+		if !reflect.DeepEqual(rAvg, tt.avg) {
+			t.Errorf("%d result mismatch,\ngot:\t%v \nwant:\t%v", i, rAvg, tt.avg)
+		}
+		rMax, _ := fMax.exec(fctx, tt.args)
+		if !reflect.DeepEqual(rMax, tt.max) {
+			t.Errorf("%d result mismatch,\ngot:\t%v \nwant:\t%v", i, rMax, tt.max)
+		}
+		rMin, _ := fMin.exec(fctx, tt.args)
+		if !reflect.DeepEqual(rMin, tt.min) {
+			t.Errorf("%d result mismatch,\ngot:\t%v \nwant:\t%v", i, rMin, tt.min)
+		}
+	}
+}

+ 9 - 9
internal/topo/operator/project_test.go

@@ -1714,7 +1714,7 @@ func TestProjectPlan_AggFuncs(t *testing.T) {
 				WindowRange: xsql.NewWindowRange(1541152486013, 1541152487013),
 			},
 			result: []map[string]interface{}{{
-				"sum":        123203,
+				"sum":        int64(123203),
 				"ws":         int64(1541152486013),
 				"window_end": int64(1541152487013),
 			}},
@@ -1735,7 +1735,7 @@ func TestProjectPlan_AggFuncs(t *testing.T) {
 			},
 
 			result: []map[string]interface{}{{
-				"s": 123203,
+				"s": int64(123203),
 			}},
 		},
 		//8
@@ -1753,7 +1753,7 @@ func TestProjectPlan_AggFuncs(t *testing.T) {
 				},
 			},
 			result: []map[string]interface{}{{
-				"sum": 123203,
+				"sum": int64(123203),
 			}},
 		},
 		//9
@@ -1773,10 +1773,10 @@ func TestProjectPlan_AggFuncs(t *testing.T) {
 			result: []map[string]interface{}{{
 				"all": 3,
 				"c":   2,
-				"a":   40,
-				"s":   80,
-				"min": 27,
-				"max": 53,
+				"a":   int64(40),
+				"s":   int64(80),
+				"min": int64(27),
+				"max": int64(53),
 			}},
 		},
 		//10
@@ -2158,8 +2158,8 @@ func TestProjectPlan_AggFuncs(t *testing.T) {
 			},
 			result: []map[string]interface{}{{
 				"var2": "moduleB topic",
-				"max2": 1,
-				"max3": 100,
+				"max2": int64(1),
+				"max3": int64(100),
 			}},
 		},
 	}