Explorar o código

feat: add last_value function (#2108)

* add last_value func

Signed-off-by: Rui-Gan <1171530954@qq.com>

* fix ut

Signed-off-by: Rui-Gan <1171530954@qq.com>

* add ut

Signed-off-by: Rui-Gan <1171530954@qq.com>

---------

Signed-off-by: Rui-Gan <1171530954@qq.com>
Regina hai 1 ano
pai
achega
f9ca9b4a8c

A diferenza do arquivo foi suprimida porque é demasiado grande
+ 9 - 0
docs/en_US/sqls/functions/aggregate_functions.md


+ 9 - 0
docs/zh_CN/sqls/functions/aggregate_functions.md

@@ -75,6 +75,15 @@ collect(col)
     SELECT collect(*)[1]->a as r1 FROM test GROUP BY TumblingWindow(ss, 10)
     ```
 
+## LAST_VALUE
+
+```text
+last_value(*, true)
+last_value(col, false)
+```
+
+用于返回在组中指定列或整个消息中最后一行的值。该函数有两个参数,第一个参数用于指定列或整个消息,第二个参数用于指定是否需要忽略空值;如果第二个参数为 true,则该函数仅返回最后的非空值,如果没有非空值,则返回空值;如果第二个参数为 false,则函数将返回最后的值,无论它是否为空。
+
 ## MERGE_AGG
 
 ```text

+ 38 - 0
internal/binder/function/funcs_agg.go

@@ -404,6 +404,44 @@ func registerAggFunc() {
 		val:   ValidateTwoNumberArg,
 		check: returnNilIfHasAnyNil,
 	}
+	builtins["last_value"] = builtinFunc{
+		fType: ast.FuncTypeAgg,
+		exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) {
+			arg0, ok := args[0].([]interface{})
+			if !ok {
+				return fmt.Errorf("Invalid argument type found."), false
+			}
+			args1, ok := args[1].([]interface{})
+			if !ok {
+				return fmt.Errorf("Invalid argument type found."), false
+			}
+			arg1, ok := getFirstValidArg(args1).(bool)
+			if !ok {
+				return fmt.Errorf("Invalid argument type found."), false
+			}
+			if len(arg0) == 0 {
+				return nil, true
+			}
+			if arg1 {
+				for i := len(arg0) - 1; i >= 0; i-- {
+					if arg0[i] != nil {
+						return arg0[i], true
+					}
+				}
+			}
+			return arg0[len(arg0)-1], true
+		},
+		val: func(_ api.FunctionContext, args []ast.Expr) error {
+			if err := ValidateLen(2, len(args)); err != nil {
+				return err
+			}
+			if !ast.IsBooleanArg(args[1]) {
+				return ProduceErrInfo(1, "bool")
+			}
+			return nil
+		},
+		check: returnNilIfHasAnyNil,
+	}
 }
 
 func getCount(s []interface{}) int {

+ 220 - 0
internal/binder/function/funcs_agg_test.go

@@ -26,6 +26,7 @@ import (
 	kctx "github.com/lf-edge/ekuiper/internal/topo/context"
 	"github.com/lf-edge/ekuiper/internal/topo/state"
 	"github.com/lf-edge/ekuiper/pkg/api"
+	"github.com/lf-edge/ekuiper/pkg/ast"
 )
 
 func TestAggExec(t *testing.T) {
@@ -400,6 +401,19 @@ func TestAggFuncNil(t *testing.T) {
 			r, b := function.exec(fctx, []interface{}{nil})
 			require.True(t, b, fmt.Sprintf("%v failed", name))
 			require.Nil(t, r, fmt.Sprintf("%v failed", name))
+		case "last_value":
+			r, b := function.exec(fctx, []interface{}{[]interface{}{nil}, []interface{}{false}})
+			require.True(t, b, fmt.Sprintf("%v failed", name))
+			require.Nil(t, r, fmt.Sprintf("%v failed", name))
+			r, b = function.exec(fctx, []interface{}{[]interface{}{1, 2, nil}, []interface{}{true}})
+			require.True(t, b, fmt.Sprintf("%v failed", name))
+			require.Equal(t, r, 2, fmt.Sprintf("%v failed", name))
+			r, b = function.exec(fctx, []interface{}{[]interface{}{1, 2, nil}, []interface{}{false}})
+			require.True(t, b, fmt.Sprintf("%v failed", name))
+			require.Equal(t, r, nil, fmt.Sprintf("%v failed", name))
+			r, b = function.check([]interface{}{nil})
+			require.True(t, b, fmt.Sprintf("%v failed", name))
+			require.Nil(t, r, fmt.Sprintf("%v failed", name))
 		default:
 			r, b := function.check([]interface{}{nil})
 			require.True(t, b, fmt.Sprintf("%v failed", name))
@@ -407,3 +421,209 @@ func TestAggFuncNil(t *testing.T) {
 		}
 	}
 }
+
+func TestLastValue(t *testing.T) {
+	f, ok := builtins["last_value"]
+	if !ok {
+		t.Fatal("builtin not found")
+	}
+	contextLogger := conf.Log.WithField("rule", "testExec")
+	ctx := kctx.WithValue(kctx.Background(), kctx.LoggerKey, contextLogger)
+	tempStore, _ := state.CreateStore("mockRule0", api.AtMostOnce)
+	fctx := kctx.NewDefaultFuncContext(ctx.WithMeta("mockRule0", "test", tempStore), 2)
+	tests := []struct {
+		args   []interface{}
+		result interface{}
+	}{
+		{
+			args: []interface{}{
+				[]interface{}{
+					"foo",
+					"bar",
+					"self",
+				},
+				[]interface{}{
+					true,
+					true,
+					true,
+				},
+			},
+			result: "self",
+		},
+		{
+			args: []interface{}{
+				[]interface{}{
+					"foo",
+					"bar",
+					"self",
+				},
+				[]interface{}{
+					false,
+					false,
+					false,
+				},
+			},
+			result: "self",
+		},
+		{
+			args: []interface{}{
+				[]interface{}{
+					int64(100),
+					float64(3.14),
+					1,
+				},
+				[]interface{}{
+					true,
+					true,
+					true,
+				},
+			},
+			result: int(1),
+		},
+		{
+			args: []interface{}{
+				[]interface{}{
+					int64(100),
+					float64(3.14),
+					1,
+				},
+				[]interface{}{
+					false,
+					false,
+					false,
+				},
+			},
+			result: 1,
+		},
+		{
+			args: []interface{}{
+				[]interface{}{
+					int64(100),
+					float64(3.14),
+					nil,
+				},
+				[]interface{}{
+					true,
+					true,
+					true,
+				},
+			},
+			result: float64(3.14),
+		},
+		{
+			args: []interface{}{
+				[]interface{}{
+					int64(100),
+					float64(3.14),
+					nil,
+				},
+				[]interface{}{
+					false,
+					false,
+					false,
+				},
+			},
+			result: nil,
+		},
+		{
+			args: []interface{}{
+				[]interface{}{
+					nil,
+					nil,
+					nil,
+				},
+				[]interface{}{
+					true,
+					true,
+					true,
+				},
+			},
+			result: nil,
+		},
+		{
+			args: []interface{}{
+				[]interface{}{
+					nil,
+					nil,
+					nil,
+				},
+				[]interface{}{
+					false,
+					false,
+					false,
+				},
+			},
+			result: nil,
+		},
+		{
+			args: []interface{}{
+				1,
+				true,
+			},
+			result: fmt.Errorf("Invalid argument type found."),
+		},
+		{
+			args: []interface{}{
+				[]interface{}{1},
+				true,
+			},
+			result: fmt.Errorf("Invalid argument type found."),
+		},
+		{
+			args: []interface{}{
+				[]interface{}{1},
+				[]interface{}{1},
+			},
+			result: fmt.Errorf("Invalid argument type found."),
+		},
+		{
+			args: []interface{}{
+				[]interface{}{},
+				[]interface{}{true},
+			},
+			result: nil,
+		},
+	}
+
+	for i, tt := range tests {
+		r, _ := f.exec(fctx, tt.args)
+		if !reflect.DeepEqual(r, tt.result) {
+			t.Errorf("%d result mismatch,\ngot:\t%v \nwant:\t%v", i, r, tt.result)
+		}
+	}
+}
+
+func TestLastValueValidation(t *testing.T) {
+	f, ok := builtins["last_value"]
+	if !ok {
+		t.Fatal("builtin not found")
+	}
+	tests := []struct {
+		args []ast.Expr
+		err  error
+	}{
+		{
+			args: []ast.Expr{
+				&ast.BooleanLiteral{Val: true},
+			},
+			err: fmt.Errorf("Expect 2 arguments but found 1."),
+		}, {
+			args: []ast.Expr{
+				&ast.FieldRef{Name: "foo"},
+				&ast.FieldRef{Name: "bar"},
+			},
+			err: fmt.Errorf("Expect bool type for parameter 2"),
+		}, {
+			args: []ast.Expr{
+				&ast.StringLiteral{Val: "foo"},
+				&ast.BooleanLiteral{Val: true},
+			},
+		},
+	}
+	for i, tt := range tests {
+		err := f.val(nil, tt.args)
+		if !reflect.DeepEqual(err, tt.err) {
+			t.Errorf("%d result mismatch,\ngot:\t%v \nwant:\t%v", i, err, tt.err)
+		}
+	}
+}