Преглед изворни кода

fix: fix acc agg function for invalid data (#2182)

Signed-off-by: yisaer <disxiaofei@163.com>
Song Gao пре 1 година
родитељ
комит
93a018f87c

+ 21 - 25
internal/binder/function/funcs_analytic.go

@@ -247,12 +247,17 @@ func registerGlobalAggFunc() {
 			key := args[len(args)-1].(string)
 			keyCount := fmt.Sprintf("%s_count", key)
 			keySum := fmt.Sprintf("%s_sum", key)
+			keyAvg := fmt.Sprintf("%s_avg", key)
 
-			v1, err := ctx.GetState(keyCount)
+			vCount, err := ctx.GetState(keyCount)
 			if err != nil {
 				return err, false
 			}
-			v2, err := ctx.GetState(keySum)
+			vSum, err := ctx.GetState(keySum)
+			if err != nil {
+				return err, false
+			}
+			vAvg, err := ctx.GetState(keyAvg)
 			if err != nil {
 				return err, false
 			}
@@ -260,22 +265,16 @@ func registerGlobalAggFunc() {
 			if !ok {
 				return fmt.Errorf("when arg is not a bool but got %v", args[len(args)-2]), false
 			}
-			if v1 == nil && v2 == nil {
-				if args[0] == nil || !validData {
-					return 0, true
-				}
-
-				v1 = float64(0)
-				v2 = float64(0)
-			} else {
-				if args[0] == nil || !validData {
-					count := v1.(float64)
-					sum := v2.(float64)
-					return sum / count, true
-				}
+			if vSum == nil || vCount == nil || vAvg == nil {
+				vSum = float64(0)
+				vCount = float64(0)
+				vAvg = float64(0)
+			}
+			if args[0] == nil || !validData {
+				return vAvg.(float64), true
 			}
-			count := v1.(float64)
-			sum := v2.(float64)
+			count := vCount.(float64)
+			sum := vSum.(float64)
 			count = count + 1
 			switch v := args[0].(type) {
 			case int:
@@ -297,6 +296,9 @@ func registerGlobalAggFunc() {
 			if err := ctx.PutState(keySum, sum); err != nil {
 				return err, false
 			}
+			if err := ctx.PutState(keyAvg, sum/count); err != nil {
+				return err, false
+			}
 			return sum / count, true
 		},
 		val: func(ctx api.FunctionContext, args []ast.Expr) error {
@@ -317,7 +319,7 @@ func registerGlobalAggFunc() {
 			}
 			if val == nil {
 				if !validData {
-					return nil, false
+					return 0, true
 				}
 				val = float64(math.MinInt64)
 			}
@@ -366,7 +368,7 @@ func registerGlobalAggFunc() {
 			}
 			if val == nil {
 				if !validData {
-					return nil, false
+					return 0, true
 				}
 				val = float64(math.MaxInt64)
 			}
@@ -414,9 +416,6 @@ func registerGlobalAggFunc() {
 				return fmt.Errorf("when arg is not a bool but got %v", args[len(args)-2]), false
 			}
 			if val == nil {
-				if !validData {
-					return nil, false
-				}
 				val = float64(0)
 			}
 			accu := val.(float64)
@@ -459,9 +458,6 @@ func registerGlobalAggFunc() {
 				return fmt.Errorf("when arg is not a bool but got %v", args[len(args)-2]), false
 			}
 			if val == nil {
-				if !validData {
-					return nil, false
-				}
 				val = 0
 			}
 			cnt := val.(int)

+ 37 - 0
internal/binder/function/funcs_analytic_test.go

@@ -1461,4 +1461,41 @@ func TestAccumulateAgg(t *testing.T) {
 			require.Equal(t, test.results[i], result)
 		}
 	}
+
+	tests2 := []struct {
+		name   string
+		result interface{}
+	}{
+		{
+			"acc_sum",
+			float64(0),
+		},
+		{
+			"acc_max",
+			0,
+		},
+		{
+			"acc_min",
+			0,
+		},
+		{
+			"acc_avg",
+			float64(0),
+		},
+		{
+			"acc_count",
+			0,
+		},
+	}
+	for _, test := range tests2 {
+		f, ok := builtins[test.name]
+		require.True(t, ok)
+		contextLogger := conf.Log.WithField("rule", "testExec")
+		ctx := kctx.WithValue(kctx.Background(), kctx.LoggerKey, contextLogger)
+		tempStore, _ := state.CreateStore("mockRule0", api.AtMostOnce)
+		fctx := kctx.NewDefaultFuncContext(ctx.WithMeta("mockRule0", "test", tempStore), 2)
+		result, b := f.exec(fctx, []interface{}{1, false, fmt.Sprintf("%s_key", test.name)})
+		require.True(t, b)
+		require.Equal(t, test.result, result)
+	}
 }