Explorar o código

feat(func): lag func support ignoreNull (#2223)

Signed-off-by: Jiyong Huang <huangjy@emqx.io>
ngjaying hai 1 ano
pai
achega
f00bbc5329

+ 2 - 2
docs/en_US/sqls/functions/analytic_functions.md

@@ -30,7 +30,7 @@ AnalyticFuncName(<arguments>...) OVER ([WHEN <Expression>])
 ## LAG
 
 ```text
-lag(expr, [offset], [default value])
+lag(expr, [offset], [default value], [ignore null])
 ```
 
 Return the former result of expression at offset, if not found, return the default value specified, if default value not
@@ -52,7 +52,7 @@ Example function call to calculate duration of events: ts is timestamp, and stat
 status in the same event
 
 ```text
-select lag(Status) as Status, ts - lag(ts, 1, ts) OVER (WHEN had_changed(true, statusCode)) as duration from demo
+select lag(Status) as Status, ts - lag(ts, 1, ts, true) OVER (WHEN had_changed(true, statusCode)) as duration from demo
 ```
 
 ## LATEST

+ 2 - 2
docs/zh_CN/sqls/functions/analytic_functions.md

@@ -25,7 +25,7 @@ AnalyticFuncName(<arguments>...) OVER ([WHEN <Expression>])
 ## LAG
 
 ```text
-lag(expr, [offset], [default value])
+lag(expr, [offset], [default value], [ignore null])
 ```
 
 返回表达式前一个值在偏移 offset 处的结果,如果没有找到,则返回默认值,如果没有指定默认值则返回 nil。
@@ -45,7 +45,7 @@ lag(temperature) OVER (PARTITION BY deviceId)
 示例3:ts为时间戳,获取设备状态 statusCode1 和 statusCode2 不相等持续时间
 
 ```text
-select lag(Status) as Status, ts - lag(ts, 1, ts) OVER (WHEN had_changed(true, statusCode)) as duration from demo
+select lag(Status) as Status, ts - lag(ts, 1, ts, true) OVER (WHEN had_changed(true, statusCode)) as duration from demo
 ```
 
 ## LATEST

+ 23 - 10
internal/binder/function/funcs_analytic.go

@@ -125,7 +125,7 @@ func registerAnalyticFunc() {
 		fType: ast.FuncTypeScalar,
 		exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) {
 			l := len(args) - 2
-			if l != 1 && l != 2 && l != 3 {
+			if l != 1 && l != 2 && l != 3 && l != 4 {
 				return fmt.Errorf("expect one two or three args but got %d", l), false
 			}
 			key := args[len(args)-1].(string)
@@ -137,6 +137,13 @@ func registerAnalyticFunc() {
 			if !ok {
 				return fmt.Errorf("when arg is not a bool but got %v", args[len(args)-2]), false
 			}
+			ignoreNull := false
+			if l == 4 {
+				ignoreNull, ok = args[3].(bool)
+				if !ok {
+					return fmt.Errorf("The fourth arg is not a bool but got %v", args[0]), false
+				}
+			}
 			paraLen := len(args) - 2
 			var rq *ringqueue = nil
 			var rtnVal interface{} = nil
@@ -157,6 +164,7 @@ func registerAnalyticFunc() {
 					}
 				}
 				rq = newRingqueue(size)
+
 				rq.fill(dftVal)
 				err := ctx.PutState(key, rq)
 				if err != nil {
@@ -165,22 +173,22 @@ func registerAnalyticFunc() {
 			} else {
 				rq, _ = v.(*ringqueue)
 			}
-
+			rtnVal, _ = rq.peek()
 			if validData {
-				rtnVal, _ = rq.fetch()
-				rq.append(args[0])
-				err := ctx.PutState(key, rq)
-				if err != nil {
-					return fmt.Errorf("error setting state for %s: %v", key, err), false
+				if !ignoreNull || args[0] != nil {
+					rtnVal, _ = rq.fetch()
+					rq.append(args[0])
+					err := ctx.PutState(key, rq)
+					if err != nil {
+						return fmt.Errorf("error setting state for %s: %v", key, err), false
+					}
 				}
-			} else {
-				rtnVal, _ = rq.peek()
 			}
 			return rtnVal, true
 		},
 		val: func(_ api.FunctionContext, args []ast.Expr) error {
 			l := len(args)
-			if l != 1 && l != 2 && l != 3 {
+			if l != 1 && l != 2 && l != 3 && l != 4 {
 				return fmt.Errorf("expect one two or three args but got %d", l)
 			}
 			if l >= 2 {
@@ -193,6 +201,11 @@ func registerAnalyticFunc() {
 					}
 				}
 			}
+			if l == 4 {
+				if ast.IsNumericArg(args[3]) || ast.IsTimeArg(args[3]) || ast.IsStringArg(args[3]) {
+					return ProduceErrInfo(3, "bool")
+				}
+			}
 			return nil
 		},
 	}

+ 79 - 6
internal/binder/function/funcs_analytic_test.go

@@ -15,10 +15,12 @@
 package function
 
 import (
+	"errors"
 	"fmt"
 	"reflect"
 	"testing"
 
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
 	"github.com/lf-edge/ekuiper/internal/conf"
@@ -778,6 +780,51 @@ func TestHadChangedPartitionWithWhen(t *testing.T) {
 	}
 }
 
+func TestLagValidation(t *testing.T) {
+	f, ok := builtins["lag"]
+	if !ok {
+		t.Fatal("builtin not found")
+	}
+	tests := []struct {
+		args []ast.Expr
+		err  error
+	}{
+		{
+			args: []ast.Expr{
+				&ast.StringLiteral{Val: "foo"},
+			},
+			err: nil,
+		}, {
+			args: []ast.Expr{
+				&ast.StringLiteral{Val: "foo"},
+				&ast.StringLiteral{Val: "bar"},
+			},
+			err: fmt.Errorf("Expect int type for parameter 2"),
+		}, {
+			args: []ast.Expr{
+				&ast.StringLiteral{Val: "foo"},
+				&ast.StringLiteral{Val: "bar"},
+				&ast.StringLiteral{Val: "baz"},
+			},
+			err: fmt.Errorf("Expect int type for parameter 2"),
+		}, {
+			args: []ast.Expr{
+				&ast.BooleanLiteral{Val: true},
+				&ast.IntegerLiteral{Val: 23},
+				&ast.StringLiteral{Val: "baz"},
+				&ast.StringLiteral{Val: "baz"},
+			},
+			err: fmt.Errorf("Expect bool type for parameter 4"),
+		},
+	}
+	for i, tt := range tests {
+		err := f.val(nil, tt.args)
+		if !reflect.DeepEqual(err, tt.err) {
+			t.Errorf("%d result mismatch,\ngot:\t%v \nwant:\t%v", i, err, tt.err)
+		}
+	}
+}
+
 func TestLagExec(t *testing.T) {
 	f, ok := builtins["lag"]
 	if !ok {
@@ -794,6 +841,9 @@ func TestLagExec(t *testing.T) {
 		{ // 1
 			args: []interface{}{
 				"foo",
+				1,
+				"default",
+				true,
 				true,
 				"self",
 			},
@@ -801,7 +851,10 @@ func TestLagExec(t *testing.T) {
 		},
 		{ // 2
 			args: []interface{}{
-				"bar",
+				nil,
+				1,
+				"default",
+				true,
 				true,
 				"self",
 			},
@@ -810,14 +863,20 @@ func TestLagExec(t *testing.T) {
 		{ // 3
 			args: []interface{}{
 				"bar",
+				1,
+				"default",
+				true,
 				true,
 				"self",
 			},
-			result: "bar",
+			result: "foo",
 		},
 		{ // 4
 			args: []interface{}{
 				"foo",
+				1,
+				"default",
+				true,
 				true,
 				"self",
 			},
@@ -826,17 +885,31 @@ func TestLagExec(t *testing.T) {
 		{ // 4
 			args: []interface{}{
 				"foo",
+				1,
+				"default",
+				true,
 				true,
 				"self",
 			},
 			result: "foo",
 		},
+		{ // 5
+			args: []interface{}{
+				"foo",
+				1,
+				"default",
+				23,
+				true,
+				"self",
+			},
+			result: errors.New("The fourth arg is not a bool but got foo"),
+		},
 	}
 	for i, tt := range tests {
-		result, _ := f.exec(fctx, tt.args)
-		if !reflect.DeepEqual(result, tt.result) {
-			t.Errorf("%d result mismatch,\ngot:\t%v \nwant:\t%v", i, result, tt.result)
-		}
+		t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
+			result, _ := f.exec(fctx, tt.args)
+			assert.Equal(t, tt.result, result)
+		})
 	}
 }