瀏覽代碼

feat(func): lag func support ignoreNull (#2223)

Signed-off-by: Jiyong Huang <huangjy@emqx.io>
ngjaying 1 年之前
父節點
當前提交
f00bbc5329

+ 2 - 2
docs/en_US/sqls/functions/analytic_functions.md

@@ -30,7 +30,7 @@ AnalyticFuncName(<arguments>...) OVER ([WHEN <Expression>])
 ## LAG
 ## LAG
 
 
 ```text
 ```text
-lag(expr, [offset], [default value])
+lag(expr, [offset], [default value], [ignore null])
 ```
 ```
 
 
 Return the former result of expression at offset, if not found, return the default value specified, if default value not
 Return the former result of expression at offset, if not found, return the default value specified, if default value not
@@ -52,7 +52,7 @@ Example function call to calculate duration of events: ts is timestamp, and stat
 status in the same event
 status in the same event
 
 
 ```text
 ```text
-select lag(Status) as Status, ts - lag(ts, 1, ts) OVER (WHEN had_changed(true, statusCode)) as duration from demo
+select lag(Status) as Status, ts - lag(ts, 1, ts, true) OVER (WHEN had_changed(true, statusCode)) as duration from demo
 ```
 ```
 
 
 ## LATEST
 ## LATEST

+ 2 - 2
docs/zh_CN/sqls/functions/analytic_functions.md

@@ -25,7 +25,7 @@ AnalyticFuncName(<arguments>...) OVER ([WHEN <Expression>])
 ## LAG
 ## LAG
 
 
 ```text
 ```text
-lag(expr, [offset], [default value])
+lag(expr, [offset], [default value], [ignore null])
 ```
 ```
 
 
 返回表达式前一个值在偏移 offset 处的结果,如果没有找到,则返回默认值,如果没有指定默认值则返回 nil。
 返回表达式前一个值在偏移 offset 处的结果,如果没有找到,则返回默认值,如果没有指定默认值则返回 nil。
@@ -45,7 +45,7 @@ lag(temperature) OVER (PARTITION BY deviceId)
 示例3:ts为时间戳,获取设备状态 statusCode1 和 statusCode2 不相等持续时间
 示例3:ts为时间戳,获取设备状态 statusCode1 和 statusCode2 不相等持续时间
 
 
 ```text
 ```text
-select lag(Status) as Status, ts - lag(ts, 1, ts) OVER (WHEN had_changed(true, statusCode)) as duration from demo
+select lag(Status) as Status, ts - lag(ts, 1, ts, true) OVER (WHEN had_changed(true, statusCode)) as duration from demo
 ```
 ```
 
 
 ## LATEST
 ## LATEST

+ 23 - 10
internal/binder/function/funcs_analytic.go

@@ -125,7 +125,7 @@ func registerAnalyticFunc() {
 		fType: ast.FuncTypeScalar,
 		fType: ast.FuncTypeScalar,
 		exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) {
 		exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) {
 			l := len(args) - 2
 			l := len(args) - 2
-			if l != 1 && l != 2 && l != 3 {
+			if l != 1 && l != 2 && l != 3 && l != 4 {
 				return fmt.Errorf("expect one two or three args but got %d", l), false
 				return fmt.Errorf("expect one two or three args but got %d", l), false
 			}
 			}
 			key := args[len(args)-1].(string)
 			key := args[len(args)-1].(string)
@@ -137,6 +137,13 @@ func registerAnalyticFunc() {
 			if !ok {
 			if !ok {
 				return fmt.Errorf("when arg is not a bool but got %v", args[len(args)-2]), false
 				return fmt.Errorf("when arg is not a bool but got %v", args[len(args)-2]), false
 			}
 			}
+			ignoreNull := false
+			if l == 4 {
+				ignoreNull, ok = args[3].(bool)
+				if !ok {
+					return fmt.Errorf("The fourth arg is not a bool but got %v", args[0]), false
+				}
+			}
 			paraLen := len(args) - 2
 			paraLen := len(args) - 2
 			var rq *ringqueue = nil
 			var rq *ringqueue = nil
 			var rtnVal interface{} = nil
 			var rtnVal interface{} = nil
@@ -157,6 +164,7 @@ func registerAnalyticFunc() {
 					}
 					}
 				}
 				}
 				rq = newRingqueue(size)
 				rq = newRingqueue(size)
+
 				rq.fill(dftVal)
 				rq.fill(dftVal)
 				err := ctx.PutState(key, rq)
 				err := ctx.PutState(key, rq)
 				if err != nil {
 				if err != nil {
@@ -165,22 +173,22 @@ func registerAnalyticFunc() {
 			} else {
 			} else {
 				rq, _ = v.(*ringqueue)
 				rq, _ = v.(*ringqueue)
 			}
 			}
-
+			rtnVal, _ = rq.peek()
 			if validData {
 			if validData {
-				rtnVal, _ = rq.fetch()
-				rq.append(args[0])
-				err := ctx.PutState(key, rq)
-				if err != nil {
-					return fmt.Errorf("error setting state for %s: %v", key, err), false
+				if !ignoreNull || args[0] != nil {
+					rtnVal, _ = rq.fetch()
+					rq.append(args[0])
+					err := ctx.PutState(key, rq)
+					if err != nil {
+						return fmt.Errorf("error setting state for %s: %v", key, err), false
+					}
 				}
 				}
-			} else {
-				rtnVal, _ = rq.peek()
 			}
 			}
 			return rtnVal, true
 			return rtnVal, true
 		},
 		},
 		val: func(_ api.FunctionContext, args []ast.Expr) error {
 		val: func(_ api.FunctionContext, args []ast.Expr) error {
 			l := len(args)
 			l := len(args)
-			if l != 1 && l != 2 && l != 3 {
+			if l != 1 && l != 2 && l != 3 && l != 4 {
 				return fmt.Errorf("expect one two or three args but got %d", l)
 				return fmt.Errorf("expect one two or three args but got %d", l)
 			}
 			}
 			if l >= 2 {
 			if l >= 2 {
@@ -193,6 +201,11 @@ func registerAnalyticFunc() {
 					}
 					}
 				}
 				}
 			}
 			}
+			if l == 4 {
+				if ast.IsNumericArg(args[3]) || ast.IsTimeArg(args[3]) || ast.IsStringArg(args[3]) {
+					return ProduceErrInfo(3, "bool")
+				}
+			}
 			return nil
 			return nil
 		},
 		},
 	}
 	}

+ 79 - 6
internal/binder/function/funcs_analytic_test.go

@@ -15,10 +15,12 @@
 package function
 package function
 
 
 import (
 import (
+	"errors"
 	"fmt"
 	"fmt"
 	"reflect"
 	"reflect"
 	"testing"
 	"testing"
 
 
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"github.com/stretchr/testify/require"
 
 
 	"github.com/lf-edge/ekuiper/internal/conf"
 	"github.com/lf-edge/ekuiper/internal/conf"
@@ -778,6 +780,51 @@ func TestHadChangedPartitionWithWhen(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestLagValidation(t *testing.T) {
+	f, ok := builtins["lag"]
+	if !ok {
+		t.Fatal("builtin not found")
+	}
+	tests := []struct {
+		args []ast.Expr
+		err  error
+	}{
+		{
+			args: []ast.Expr{
+				&ast.StringLiteral{Val: "foo"},
+			},
+			err: nil,
+		}, {
+			args: []ast.Expr{
+				&ast.StringLiteral{Val: "foo"},
+				&ast.StringLiteral{Val: "bar"},
+			},
+			err: fmt.Errorf("Expect int type for parameter 2"),
+		}, {
+			args: []ast.Expr{
+				&ast.StringLiteral{Val: "foo"},
+				&ast.StringLiteral{Val: "bar"},
+				&ast.StringLiteral{Val: "baz"},
+			},
+			err: fmt.Errorf("Expect int type for parameter 2"),
+		}, {
+			args: []ast.Expr{
+				&ast.BooleanLiteral{Val: true},
+				&ast.IntegerLiteral{Val: 23},
+				&ast.StringLiteral{Val: "baz"},
+				&ast.StringLiteral{Val: "baz"},
+			},
+			err: fmt.Errorf("Expect bool type for parameter 4"),
+		},
+	}
+	for i, tt := range tests {
+		err := f.val(nil, tt.args)
+		if !reflect.DeepEqual(err, tt.err) {
+			t.Errorf("%d result mismatch,\ngot:\t%v \nwant:\t%v", i, err, tt.err)
+		}
+	}
+}
+
 func TestLagExec(t *testing.T) {
 func TestLagExec(t *testing.T) {
 	f, ok := builtins["lag"]
 	f, ok := builtins["lag"]
 	if !ok {
 	if !ok {
@@ -794,6 +841,9 @@ func TestLagExec(t *testing.T) {
 		{ // 1
 		{ // 1
 			args: []interface{}{
 			args: []interface{}{
 				"foo",
 				"foo",
+				1,
+				"default",
+				true,
 				true,
 				true,
 				"self",
 				"self",
 			},
 			},
@@ -801,7 +851,10 @@ func TestLagExec(t *testing.T) {
 		},
 		},
 		{ // 2
 		{ // 2
 			args: []interface{}{
 			args: []interface{}{
-				"bar",
+				nil,
+				1,
+				"default",
+				true,
 				true,
 				true,
 				"self",
 				"self",
 			},
 			},
@@ -810,14 +863,20 @@ func TestLagExec(t *testing.T) {
 		{ // 3
 		{ // 3
 			args: []interface{}{
 			args: []interface{}{
 				"bar",
 				"bar",
+				1,
+				"default",
+				true,
 				true,
 				true,
 				"self",
 				"self",
 			},
 			},
-			result: "bar",
+			result: "foo",
 		},
 		},
 		{ // 4
 		{ // 4
 			args: []interface{}{
 			args: []interface{}{
 				"foo",
 				"foo",
+				1,
+				"default",
+				true,
 				true,
 				true,
 				"self",
 				"self",
 			},
 			},
@@ -826,17 +885,31 @@ func TestLagExec(t *testing.T) {
 		{ // 4
 		{ // 4
 			args: []interface{}{
 			args: []interface{}{
 				"foo",
 				"foo",
+				1,
+				"default",
+				true,
 				true,
 				true,
 				"self",
 				"self",
 			},
 			},
 			result: "foo",
 			result: "foo",
 		},
 		},
+		{ // 5
+			args: []interface{}{
+				"foo",
+				1,
+				"default",
+				23,
+				true,
+				"self",
+			},
+			result: errors.New("The fourth arg is not a bool but got foo"),
+		},
 	}
 	}
 	for i, tt := range tests {
 	for i, tt := range tests {
-		result, _ := f.exec(fctx, tt.args)
-		if !reflect.DeepEqual(result, tt.result) {
-			t.Errorf("%d result mismatch,\ngot:\t%v \nwant:\t%v", i, result, tt.result)
-		}
+		t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
+			result, _ := f.exec(fctx, tt.args)
+			assert.Equal(t, tt.result, result)
+		})
 	}
 	}
 }
 }