Jelajahi Sumber

feat: support accmulative agg funcs (#2125)

* support global agg

Signed-off-by: yisaer <disxiaofei@163.com>

* support global agg

Signed-off-by: yisaer <disxiaofei@163.com>

* add test

Signed-off-by: yisaer <disxiaofei@163.com>

* add test

Signed-off-by: yisaer <disxiaofei@163.com>

* add count

Signed-off-by: yisaer <disxiaofei@163.com>

* address the comment

Signed-off-by: yisaer <disxiaofei@163.com>

* add test

Signed-off-by: yisaer <disxiaofei@163.com>

* fix test

Signed-off-by: yisaer <disxiaofei@163.com>

* fix test

Signed-off-by: yisaer <disxiaofei@163.com>

* add doc

Signed-off-by: yisaer <disxiaofei@163.com>

* add doc

Signed-off-by: yisaer <disxiaofei@163.com>

---------

Signed-off-by: yisaer <disxiaofei@163.com>
Song Gao 1 tahun lalu
induk
melakukan
98df74a882

+ 92 - 0
docs/en_US/sqls/functions/analytic_functions.md

@@ -292,3 +292,95 @@ WHERE CHANGED_COL(true, temperature) > 24
 _________________________________________________________
 {"ts":4,temperature":25,"humidity":88}
 ```
+
+## ACC Functions
+
+The ACC Functions means the accumulate functions, which will perform cumulative calculations based on the obtained parameters, and the cumulative scope is the entire life cycle of the rule.
+
+For the next acc functions, we will simulate input and output with the following data:
+
+```text
+a
+```
+
+Enter 3 pieces of data in sequence, 1,2,3 respectively.
+
+### ACC_SUM
+
+```text
+acc_sum(expr)
+```
+
+The acc_sum function accumulates the expression results and returns the cumulative sum result.
+
+Example 1: Cumulative sums using acc_sum
+
+```text
+acc_sum(a)
+```
+
+The results are: 1 3 6
+
+### ACC_MAX
+
+```text
+acc_max(expr)
+```
+
+The acc_max function performs accumulative comparison on the result of the expression to take the larger value, and returns the result of the cumulative comparison to take the larger value.
+
+Example 1: Use acc_max for cumulative comparison to take the larger value
+
+```text
+acc_max(a)
+```
+
+The results are: 1 2 3
+
+### ACC_MIN
+
+```text
+acc_min(expr)
+```
+
+The acc_min function performs accumulative comparison on the result of the expression to take the smaller value, and returns the result of the cumulative comparison to take the smaller value.
+
+Example 1: Use acc_min for cumulative comparison to take the smaller value
+
+```text
+acc_min(a)
+```
+
+The results are: 1 1 1
+
+### ACC_COUNT
+
+```text
+acc_count(expr)
+```
+
+The acc_count function counts the cumulative number of expression results and returns the cumulative value.
+
+Example 1: Use acc_count for cumulative count statistics
+
+```text
+acc_count(a)
+```
+
+The results are: 1 2 3
+
+### ACC_AVG
+
+```text
+acc_avg(expr)
+```
+
+The acc_avg function performs cumulative average statistics on the expression result and returns the cumulative average.
+
+Example 1: Cumulative average statistics using acc_count
+
+```text
+acc_avg(a)
+```
+
+The results are: 1 1.5 2

+ 92 - 0
docs/zh_CN/sqls/functions/analytic_functions.md

@@ -275,3 +275,95 @@ WHERE CHANGED_COL(true, temperature) > 24
 _________________________________________________________
 {"ts":4,temperature":25,"humidity":88}
 ```
+
+## ACC 函数
+
+ACC 函数全称为 accumulate function,该函数将会根据所得的参数进行累计计算,累计范围为该规则的整个生命周期。
+
+对于接下来的 acc 函数,我们将用以下数据进行模拟输入输出:
+
+```text
+a
+```
+
+依次输入 3 条数据,分别为 1,2,3。
+
+### ACC_SUM
+
+```text
+acc_sum(expr)
+```
+
+acc_sum 函数对表达式结果进行累计加和,返回累计加和结果。
+
+示例1:使用 acc_sum 进行累计加和
+
+```text
+acc_sum(a)
+```
+
+结果为分别为: 1 3 6
+
+### ACC_MAX
+
+```text
+acc_max(expr)
+```
+
+acc_max 函数对表达式结果进行累计比较取较大值,返回累计比较取较大值的结果。
+
+示例1:使用 acc_max 进行累计比较取较大值
+
+```text
+acc_max(a)
+```
+
+结果为分别为: 1 2 3
+
+### ACC_MIN
+
+```text
+acc_min(expr)
+```
+
+acc_min 函数对表达式结果进行累计比较取较小值,返回累计比较取较小值的结果。
+
+示例1:使用 acc_min 进行累计比较取较小值
+
+```text
+acc_min(a)
+```
+
+结果为分别为: 1 1 1
+
+### ACC_COUNT
+
+```text
+acc_count(expr)
+```
+
+acc_count 函数对表达式结果进行累计个数统计,返回累计个数值。
+
+示例1:使用 acc_count 进行累计个数统计
+
+```text
+acc_count(a)
+```
+
+结果为分别为: 1 2 3
+
+### ACC_AVG
+
+```text
+acc_avg(expr)
+```
+
+acc_avg 函数对表达式结果进行累计平均值统计,返回累计平均值。
+
+示例1:使用 acc_count 进行累计平均值统计
+
+```text
+acc_avg(a)
+```
+
+结果为分别为: 1 1.5 2

+ 257 - 0
internal/binder/function/funcs_analytic.go

@@ -16,6 +16,7 @@ package function
 
 import (
 	"fmt"
+	"math"
 	"reflect"
 	"strconv"
 
@@ -238,3 +239,259 @@ func registerAnalyticFunc() {
 		},
 	}
 }
+
+func registerGlobalAggFunc() {
+	builtins["acc_avg"] = builtinFunc{
+		fType: ast.FuncTypeScalar,
+		exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) {
+			key := args[len(args)-1].(string)
+			keyCount := fmt.Sprintf("%s_count", key)
+			keySum := fmt.Sprintf("%s_sum", key)
+
+			v1, err := ctx.GetState(keyCount)
+			if err != nil {
+				return err, false
+			}
+			v2, err := ctx.GetState(keySum)
+			if err != nil {
+				return err, false
+			}
+			validData, ok := args[len(args)-2].(bool)
+			if !ok {
+				return fmt.Errorf("when arg is not a bool but got %v", args[len(args)-2]), false
+			}
+			if v1 == nil && v2 == nil {
+				if args[0] == nil || !validData {
+					return 0, true
+				}
+
+				v1 = float64(0)
+				v2 = float64(0)
+			} else {
+				if args[0] == nil || !validData {
+					count := v1.(float64)
+					sum := v2.(float64)
+					return sum / count, true
+				}
+			}
+			count := v1.(float64)
+			sum := v2.(float64)
+			count = count + 1
+			switch v := args[0].(type) {
+			case int:
+				sum += float64(v)
+			case int32:
+				sum += float64(v)
+			case int64:
+				sum += float64(v)
+			case float32:
+				sum += float64(v)
+			case float64:
+				sum += v
+			default:
+				return fmt.Errorf("the value should be number"), false
+			}
+			if err := ctx.PutState(keyCount, count); err != nil {
+				return err, false
+			}
+			if err := ctx.PutState(keySum, sum); err != nil {
+				return err, false
+			}
+			return sum / count, true
+		},
+		val: func(ctx api.FunctionContext, args []ast.Expr) error {
+			return nil
+		},
+	}
+	builtins["acc_max"] = builtinFunc{
+		fType: ast.FuncTypeScalar,
+		exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) {
+			key := args[len(args)-1].(string)
+			val, err := ctx.GetState(key)
+			if err != nil {
+				return err, false
+			}
+			validData, ok := args[len(args)-2].(bool)
+			if !ok {
+				return fmt.Errorf("when arg is not a bool but got %v", args[len(args)-2]), false
+			}
+			if val == nil {
+				if !validData {
+					return nil, false
+				}
+				val = float64(math.MinInt64)
+			}
+			m := val.(float64)
+			if !validData {
+				return m, true
+			}
+			switch v := args[0].(type) {
+			case int:
+				v1 := float64(v)
+				m = getMax(m, v1)
+			case int32:
+				v1 := float64(v)
+				m = getMax(m, v1)
+			case int64:
+				v1 := float64(v)
+				m = getMax(m, v1)
+			case float32:
+				v1 := float64(v)
+				m = getMax(m, v1)
+			case float64:
+				m = getMax(m, v)
+			default:
+				return fmt.Errorf("the value should be number"), false
+			}
+			if err := ctx.PutState(key, m); err != nil {
+				return err, false
+			}
+			return m, true
+		},
+		val: func(ctx api.FunctionContext, args []ast.Expr) error {
+			return ValidateLen(1, len(args))
+		},
+	}
+	builtins["acc_min"] = builtinFunc{
+		fType: ast.FuncTypeScalar,
+		exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) {
+			key := args[len(args)-1].(string)
+			val, err := ctx.GetState(key)
+			if err != nil {
+				return err, false
+			}
+			validData, ok := args[len(args)-2].(bool)
+			if !ok {
+				return fmt.Errorf("when arg is not a bool but got %v", args[len(args)-2]), false
+			}
+			if val == nil {
+				if !validData {
+					return nil, false
+				}
+				val = float64(math.MaxInt64)
+			}
+			m := val.(float64)
+			if !validData {
+				return m, true
+			}
+			switch v := args[0].(type) {
+			case int:
+				v1 := float64(v)
+				m = getMin(m, v1)
+			case int32:
+				v1 := float64(v)
+				m = getMin(m, v1)
+			case int64:
+				v1 := float64(v)
+				m = getMin(m, v1)
+			case float32:
+				v1 := float64(v)
+				m = getMin(m, v1)
+			case float64:
+				m = getMin(m, v)
+			default:
+				return fmt.Errorf("the value should be number"), false
+			}
+			if err := ctx.PutState(key, m); err != nil {
+				return err, false
+			}
+			return m, true
+		},
+		val: func(ctx api.FunctionContext, args []ast.Expr) error {
+			return ValidateLen(1, len(args))
+		},
+	}
+	builtins["acc_sum"] = builtinFunc{
+		fType: ast.FuncTypeScalar,
+		exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) {
+			key := args[len(args)-1].(string)
+			val, err := ctx.GetState(key)
+			if err != nil {
+				return err, false
+			}
+			validData, ok := args[len(args)-2].(bool)
+			if !ok {
+				return fmt.Errorf("when arg is not a bool but got %v", args[len(args)-2]), false
+			}
+			if val == nil {
+				if !validData {
+					return nil, false
+				}
+				val = float64(0)
+			}
+			accu := val.(float64)
+			if !validData {
+				return accu, true
+			}
+			switch sumValue := args[0].(type) {
+			case int:
+				accu += float64(sumValue)
+			case int32:
+				accu += float64(sumValue)
+			case int64:
+				accu += float64(sumValue)
+			case float32:
+				accu += float64(sumValue)
+			case float64:
+				accu += sumValue
+			default:
+				return fmt.Errorf("the value should be number"), false
+			}
+			if err := ctx.PutState(key, accu); err != nil {
+				return err, false
+			}
+			return accu, true
+		},
+		val: func(ctx api.FunctionContext, args []ast.Expr) error {
+			return ValidateLen(1, len(args))
+		},
+	}
+	builtins["acc_count"] = builtinFunc{
+		fType: ast.FuncTypeScalar,
+		exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) {
+			key := args[len(args)-1].(string)
+			val, err := ctx.GetState(key)
+			if err != nil {
+				return err, false
+			}
+			validData, ok := args[len(args)-2].(bool)
+			if !ok {
+				return fmt.Errorf("when arg is not a bool but got %v", args[len(args)-2]), false
+			}
+			if val == nil {
+				if !validData {
+					return nil, false
+				}
+				val = 0
+			}
+			cnt := val.(int)
+			if !validData {
+				return cnt, true
+			}
+			if args[0] != nil {
+				cnt = cnt + 1
+			}
+			if err := ctx.PutState(key, cnt); err != nil {
+				return err, false
+			}
+			return cnt, true
+		},
+		val: func(ctx api.FunctionContext, args []ast.Expr) error {
+			return ValidateLen(1, len(args))
+		},
+	}
+}
+
+func getMax(a, b float64) float64 {
+	if a > b {
+		return a
+	}
+	return b
+}
+
+func getMin(a, b float64) float64 {
+	if a < b {
+		return a
+	}
+	return b
+}

+ 113 - 0
internal/binder/function/funcs_analytic_test.go

@@ -19,6 +19,8 @@ import (
 	"reflect"
 	"testing"
 
+	"github.com/stretchr/testify/require"
+
 	"github.com/lf-edge/ekuiper/internal/conf"
 	kctx "github.com/lf-edge/ekuiper/internal/topo/context"
 	"github.com/lf-edge/ekuiper/internal/topo/state"
@@ -1349,3 +1351,114 @@ func TestLatestPartition(t *testing.T) {
 		}
 	}
 }
+
+func TestAccumulateAgg(t *testing.T) {
+	tests := []struct {
+		name     string
+		results  []interface{}
+		testargs []interface{}
+	}{
+		{
+			name: "acc_count",
+			testargs: []interface{}{
+				"1",
+				float64(1),
+				float32(1),
+				1,
+				int32(1),
+				int64(1),
+			},
+			results: []interface{}{
+				1, 2, 3, 4, 5, 6,
+			},
+		},
+		{
+			name: "acc_avg",
+			testargs: []interface{}{
+				"1",
+				float64(1),
+				float32(1),
+				1,
+				int32(1),
+				int64(1),
+			},
+			results: []interface{}{
+				fmt.Errorf("the value should be number"),
+				float64(1),
+				float64(1),
+				float64(1),
+				float64(1),
+				float64(1),
+			},
+		},
+		{
+			name: "acc_max",
+			testargs: []interface{}{
+				"1",
+				float64(1),
+				float32(2),
+				3,
+				int32(4),
+				int64(5),
+			},
+			results: []interface{}{
+				fmt.Errorf("the value should be number"),
+				float64(1),
+				float64(2),
+				float64(3),
+				float64(4),
+				float64(5),
+			},
+		},
+		{
+			name: "acc_min",
+			testargs: []interface{}{
+				"1",
+				float64(5),
+				float32(4),
+				3,
+				int32(2),
+				int64(1),
+			},
+			results: []interface{}{
+				fmt.Errorf("the value should be number"),
+				float64(5),
+				float64(4),
+				float64(3),
+				float64(2),
+				float64(1),
+			},
+		},
+		{
+			name: "acc_sum",
+			testargs: []interface{}{
+				"1",
+				float64(1),
+				float32(1),
+				1,
+				int32(1),
+				int64(1),
+			},
+			results: []interface{}{
+				fmt.Errorf("the value should be number"),
+				float64(1),
+				float64(2),
+				float64(3),
+				float64(4),
+				float64(5),
+			},
+		},
+	}
+	for _, test := range tests {
+		f, ok := builtins[test.name]
+		require.True(t, ok)
+		contextLogger := conf.Log.WithField("rule", "testExec")
+		ctx := kctx.WithValue(kctx.Background(), kctx.LoggerKey, contextLogger)
+		tempStore, _ := state.CreateStore("mockRule0", api.AtMostOnce)
+		fctx := kctx.NewDefaultFuncContext(ctx.WithMeta("mockRule0", "test", tempStore), 2)
+		for i, arg := range test.testargs {
+			result, _ := f.exec(fctx, []interface{}{arg, true, fmt.Sprintf("%s_key", test.name)})
+			require.Equal(t, test.results[i], result)
+		}
+	}
+}

+ 6 - 0
internal/binder/function/function.go

@@ -53,6 +53,7 @@ func init() {
 	registerArrayFunc()
 	registerObjectFunc()
 	registerGlobalStateFunc()
+	registerGlobalAggFunc()
 }
 
 //var funcWithAsteriskSupportMap = map[string]string{
@@ -65,6 +66,11 @@ var analyticFuncs = map[string]struct{}{
 	"changed_col": {},
 	"had_changed": {},
 	"latest":      {},
+	"acc_sum":     {},
+	"acc_min":     {},
+	"acc_max":     {},
+	"acc_avg":     {},
+	"acc_count":   {},
 }
 
 const AnalyticPrefix = "$$a"

+ 0 - 1
internal/topo/planner/planner.go

@@ -463,7 +463,6 @@ func createLogicalPlan(stmt *ast.SelectStatement, opt *api.RuleOption, store kv.
 		p.SetChildren(children)
 		children = []LogicalPlan{p}
 	}
-
 	srfMapping := extractSRFMapping(stmt)
 	if stmt.Fields != nil {
 		enableLimit := false

+ 247 - 0
internal/topo/topotest/rule_test.go

@@ -22,6 +22,253 @@ import (
 	"github.com/lf-edge/ekuiper/pkg/api"
 )
 
+func TestAccAggSQL(t *testing.T) {
+	// Reset
+	streamList := []string{"demo"}
+	HandleStream(false, streamList, t)
+	tests := []RuleTest{
+		{
+			Name: "TestAccAggSql1",
+			Sql:  `select acc_sum(size) over (partition by color), color from demo`,
+			R: [][]map[string]interface{}{
+				{
+					{
+						"acc_sum": float64(3),
+						"color":   "red",
+					},
+				},
+				{
+					{
+						"acc_sum": float64(6),
+						"color":   "blue",
+					},
+				},
+				{
+					{
+						"acc_sum": float64(8),
+						"color":   "blue",
+					},
+				},
+				{
+					{
+						"acc_sum": float64(4),
+						"color":   "yellow",
+					},
+				},
+				{
+					{
+						"acc_sum": float64(4),
+						"color":   "red",
+					},
+				},
+			},
+		},
+		{
+			Name: "TestAccAggSql2",
+			Sql:  `select acc_sum(size) over (when color = "red"), color from demo`,
+			R: [][]map[string]interface{}{
+				{
+					{
+						"acc_sum": float64(3),
+						"color":   "red",
+					},
+				},
+				{
+					{
+						"acc_sum": float64(3),
+						"color":   "blue",
+					},
+				},
+				{
+					{
+						"acc_sum": float64(3),
+						"color":   "blue",
+					},
+				},
+				{
+					{
+						"acc_sum": float64(3),
+						"color":   "yellow",
+					},
+				},
+				{
+					{
+						"acc_sum": float64(4),
+						"color":   "red",
+					},
+				},
+			},
+		},
+		{
+			Name: "TestAccAggSql3",
+			Sql:  `select acc_min(size) over (when color = "red"), color from demo`,
+			R: [][]map[string]interface{}{
+				{
+					{
+						"acc_min": float64(3),
+						"color":   "red",
+					},
+				},
+				{
+					{
+						"acc_min": float64(3),
+						"color":   "blue",
+					},
+				},
+				{
+					{
+						"acc_min": float64(3),
+						"color":   "blue",
+					},
+				},
+				{
+					{
+						"acc_min": float64(3),
+						"color":   "yellow",
+					},
+				},
+				{
+					{
+						"acc_min": float64(1),
+						"color":   "red",
+					},
+				},
+			},
+		},
+		{
+			Name: "TestAccAggSql4",
+			Sql:  `select acc_max(size) over (when color = "red"), color from demo`,
+			R: [][]map[string]interface{}{
+				{
+					{
+						"acc_max": float64(3),
+						"color":   "red",
+					},
+				},
+				{
+					{
+						"acc_max": float64(3),
+						"color":   "blue",
+					},
+				},
+				{
+					{
+						"acc_max": float64(3),
+						"color":   "blue",
+					},
+				},
+				{
+					{
+						"acc_max": float64(3),
+						"color":   "yellow",
+					},
+				},
+				{
+					{
+						"acc_max": float64(3),
+						"color":   "red",
+					},
+				},
+			},
+		},
+		{
+			Name: "TestAccAggSql5",
+			Sql:  `select acc_count(size) over (when color = "red"), color from demo`,
+			R: [][]map[string]interface{}{
+				{
+					{
+						"acc_count": float64(1),
+						"color":     "red",
+					},
+				},
+				{
+					{
+						"acc_count": float64(1),
+						"color":     "blue",
+					},
+				},
+				{
+					{
+						"acc_count": float64(1),
+						"color":     "blue",
+					},
+				},
+				{
+					{
+						"acc_count": float64(1),
+						"color":     "yellow",
+					},
+				},
+				{
+					{
+						"acc_count": float64(2),
+						"color":     "red",
+					},
+				},
+			},
+		},
+		{
+			Name: "TestAccAggSql6",
+			Sql:  `select acc_avg(size) over (when color = "red"), color from demo`,
+			R: [][]map[string]interface{}{
+				{
+					{
+						"acc_avg": float64(3),
+						"color":   "red",
+					},
+				},
+				{
+					{
+						"acc_avg": float64(3),
+						"color":   "blue",
+					},
+				},
+				{
+					{
+						"acc_avg": float64(3),
+						"color":   "blue",
+					},
+				},
+				{
+					{
+						"acc_avg": float64(3),
+						"color":   "yellow",
+					},
+				},
+				{
+					{
+						"acc_avg": float64(2),
+						"color":   "red",
+					},
+				},
+			},
+		},
+	}
+	// Data setup
+	HandleStream(true, streamList, t)
+	options := []*api.RuleOption{
+		{
+			BufferLength: 100,
+			SendError:    true,
+		},
+		{
+			BufferLength:       100,
+			SendError:          true,
+			Qos:                api.AtLeastOnce,
+			CheckpointInterval: 5000,
+		},
+		{
+			BufferLength:       100,
+			SendError:          true,
+			Qos:                api.ExactlyOnce,
+			CheckpointInterval: 5000,
+		},
+	}
+	for j, opt := range options {
+		DoRuleTest(t, tests, j, opt, 0)
+	}
+}
+
 func TestLimitSQL(t *testing.T) {
 	// Reset
 	streamList := []string{"demo", "demoArr", "demoArr2"}