Pārlūkot izejas kodu

feat(func): add collect function (#442)

ngjaying 4 gadi atpakaļ
vecāks
revīzija
8dffe32e21

+ 1 - 0
docs/en_US/sqls/built-in_functions.md

@@ -14,6 +14,7 @@ Aggregate functions perform a calculation on a set of values and return a single
 | max      | max(col1)   | The maximum value in a group. The null values will be ignored.                  |
 | min      | min(col1)   | The minimum value in a group. The null values will be ignored.                   |
 | sum      | sum(col1)   | The sum of all the values in a group. The null values will be ignored.           |
+| collect   | collect(*), collect(col1)   | Returns an array of the values of the specified column, or of the whole records when the parameter is *, collected from all messages in the group.    |
 
 ## Mathematical Functions
 | Function | Example     | Description                                    |

+ 1 - 0
docs/zh_CN/sqls/built-in_functions.md

@@ -14,6 +14,7 @@ Kuiper 具有许多内置函数,可以对数据执行计算。
 | max      | max(col1) | 组中的最大值。空值不参与计算。     |
 | min      | min(col1) | 组中的最小值。空值不参与计算。     |
 | sum      | sum(col1) | 组中所有值的总和。空值不参与计算。 |
+| collect   | collect(*), collect(col1)   | 返回组中指定的列或整个消息(参数为*时)的值组成的数组。    |
 
 ## 数学函数
 | 函数 | 示例   | 说明                                  |

+ 17 - 13
xsql/ast.go

@@ -54,6 +54,8 @@ const (
 	CROSS_JOIN
 )
 
+var AsteriskExpr = StringLiteral{Val: "*"}
+
 var COLUMN_SEPARATOR = tokens[COLSEP]
 
 type Join struct {
@@ -622,7 +624,7 @@ func (t *Tuple) All(stream string) (interface{}, bool) {
 }
 
 func (t *Tuple) AggregateEval(expr Expr, v CallValuer) []interface{} {
-	return []interface{}{Eval(expr, t, v)}
+	return []interface{}{Eval(expr, MultiValuer(t, v, &WildcardValuer{t}))}
 }
 
 func (t *Tuple) GetTimestamp() int64 {
@@ -709,7 +711,7 @@ func (w WindowTuplesSet) AggregateEval(expr Expr, v CallValuer) []interface{} {
 		return nil
 	}
 	for _, t := range w[0].Tuples {
-		result = append(result, Eval(expr, &t, v))
+		result = append(result, Eval(expr, MultiValuer(&t, v, &WildcardValuer{&t})))
 	}
 	return result
 }
@@ -810,7 +812,7 @@ func (s JoinTupleSets) Index(i int) Valuer { return &(s[i]) }
 func (s JoinTupleSets) AggregateEval(expr Expr, v CallValuer) []interface{} {
 	var result []interface{}
 	for _, t := range s {
-		result = append(result, Eval(expr, &t, v))
+		result = append(result, Eval(expr, MultiValuer(&t, v, &WildcardValuer{&t})))
 	}
 	return result
 }
@@ -820,7 +822,7 @@ type GroupedTuples []DataValuer
 func (s GroupedTuples) AggregateEval(expr Expr, v CallValuer) []interface{} {
 	var result []interface{}
 	for _, t := range s {
-		result = append(result, Eval(expr, t, v))
+		result = append(result, Eval(expr, MultiValuer(t, v, &WildcardValuer{t})))
 	}
 	return result
 }
@@ -970,8 +972,8 @@ type EvalResultMessage struct {
 type ResultsAndMessages []EvalResultMessage
 
 // Eval evaluates expr against a map.
-func Eval(expr Expr, m Valuer, v CallValuer) interface{} {
-	eval := ValuerEval{Valuer: MultiValuer(m, v)}
+func Eval(expr Expr, m Valuer) interface{} {
+	eval := ValuerEval{Valuer: m}
 	return eval.Eval(expr)
 }
 
@@ -1109,16 +1111,16 @@ func (v *ValuerEval) Eval(expr Expr) interface{} {
 	case *Call:
 		if valuer, ok := v.Valuer.(CallValuer); ok {
 			var args []interface{}
-
 			if len(expr.Args) > 0 {
 				args = make([]interface{}, len(expr.Args))
-				if aggreValuer, ok := valuer.(AggregateCallValuer); ok {
-					for i := range expr.Args {
-						args[i] = aggreValuer.GetAllTuples().AggregateEval(expr.Args[i], aggreValuer.GetSingleCallValuer())
+				for i, arg := range expr.Args {
+					if expr.Name == "collect" && reflect.DeepEqual(arg, &AsteriskExpr) {
+						arg = &Wildcard{Token: ASTERISK}
 					}
-				} else {
-					for i := range expr.Args {
-						args[i] = v.Eval(expr.Args[i])
+					if aggreValuer, ok := valuer.(AggregateCallValuer); ok {
+						args[i] = aggreValuer.GetAllTuples().AggregateEval(arg, aggreValuer.GetSingleCallValuer())
+					} else {
+						args[i] = v.Eval(arg)
 						if _, ok := args[i].(error); ok {
 							return args[i]
 						}
@@ -1160,6 +1162,8 @@ func (v *ValuerEval) evalBinaryExpr(expr *BinaryExpr) interface{} {
 	switch val := lhs.(type) {
 	case map[string]interface{}:
 		return v.evalJsonExpr(val, expr.OP, expr.RHS)
+	case Message:
+		return v.evalJsonExpr(map[string]interface{}(val), expr.OP, expr.RHS)
 	case error:
 		return val
 	}

+ 2 - 0
xsql/funcs_aggregate.go

@@ -170,6 +170,8 @@ func (v *AggregateFunctionValuer) Call(name string, args []interface{}) (interfa
 			}
 		}
 		return 0, true
+	case "collect":
+		return args[0], true
 	default:
 		common.Log.Debugf("run aggregate func %s", name)
 		nf, fctx, err := v.funcPlugins.GetFuncFromPlugin(name)

+ 4 - 0
xsql/funcs_ast_validator.go

@@ -347,6 +347,10 @@ func validateAggFunc(name string, args []Expr) error {
 		if err := validateLen(name, 1, len); err != nil {
 			return err
 		}
+	case "collect":
+		if err := validateLen(name, 1, len); err != nil {
+			return err
+		}
 	}
 	return nil
 }

+ 5 - 0
xsql/funcs_ast_validator_test.go

@@ -450,6 +450,11 @@ func TestFuncValidator(t *testing.T) {
 			stmt: nil,
 			err:  "Expect string type for 2 parameter of function json_path_query.",
 		},
+		{
+			s:    `SELECT collect() from tbl`,
+			stmt: nil,
+			err:  "The arguments for collect should be 1.",
+		},
 	}
 
 	fmt.Printf("The test bucket size is %d.\n\n", len(tests))

+ 2 - 1
xsql/functions.go

@@ -28,7 +28,8 @@ func (*FunctionValuer) Meta(_ string) (interface{}, bool) {
 var aggFuncMap = map[string]string{"avg": "",
 	"count": "",
 	"max":   "", "min": "",
-	"sum": "",
+	"sum":     "",
+	"collect": "",
 }
 
 var mathFuncMap = map[string]string{"abs": "", "acos": "", "asin": "", "atan": "", "atan2": "",

+ 3 - 0
xsql/parser.go

@@ -604,6 +604,9 @@ func (p *Parser) parseCall(name string) (Expr, error) {
 	var args []Expr
 	for {
 		if tok, _ := p.scanIgnoreWhitespace(); tok == RPAREN {
+			if valErr := validateFuncs(name, nil); valErr != nil {
+				return nil, valErr
+			}
 			return &Call{Name: name, Args: args}, nil
 		} else if tok == ASTERISK {
 			if tok2, lit2 := p.scanIgnoreWhitespace(); tok2 != RPAREN {

+ 3 - 15
xsql/parser_test.go

@@ -1205,21 +1205,9 @@ func TestParser_ParseStatement(t *testing.T) {
 		},
 
 		{
-			s: `SELECT sample(-.3,) FROM tbl`,
-			stmt: &SelectStatement{
-				Fields: []Field{
-					{
-						Expr: &Call{
-							Name: "sample",
-							Args: []Expr{
-								&NumberLiteral{Val: -0.3},
-							},
-						},
-						Name:  "sample",
-						AName: ""},
-				},
-				Sources: []Source{&Table{Name: "tbl"}},
-			},
+			s:    `SELECT sample(-.3,) FROM tbl`,
+			stmt: nil,
+			err:  "cannot get the plugin file name: invalid name sample: not exist",
 		},
 
 		{

+ 148 - 0
xsql/plans/project_test.go

@@ -1707,6 +1707,154 @@ func TestProjectPlan_AggFuncs(t *testing.T) {
 				"c": float64(2),
 				"d": "devicec",
 			}},
+		}, {
+			sql: "SELECT * FROM test Inner Join test1 on test.id = test1.id GROUP BY TumblingWindow(ss, 10), test1.color",
+			data: xsql.GroupedTuplesSet{
+				{
+					&xsql.JoinTuple{
+						Tuples: []xsql.Tuple{
+							{Emitter: "test", Message: xsql.Message{"id": 1, "a": 122.33, "c": 2, "r": 122}},
+							{Emitter: "src2", Message: xsql.Message{"id": 1, "color": "w2"}},
+						},
+					},
+					&xsql.JoinTuple{
+						Tuples: []xsql.Tuple{
+							{Emitter: "test", Message: xsql.Message{"id": 5, "a": 177.51}},
+							{Emitter: "src2", Message: xsql.Message{"id": 5, "color": "w2"}},
+						},
+					},
+				},
+				{
+					&xsql.JoinTuple{
+						Tuples: []xsql.Tuple{
+							{Emitter: "test", Message: xsql.Message{"id": 2, "a": 89.03, "c": 2, "r": 89}},
+							{Emitter: "src2", Message: xsql.Message{"id": 2, "color": "w1"}},
+						},
+					},
+					&xsql.JoinTuple{
+						Tuples: []xsql.Tuple{
+							{Emitter: "test", Message: xsql.Message{"id": 4, "a": 14.6}},
+							{Emitter: "src2", Message: xsql.Message{"id": 4, "color": "w1"}},
+						},
+					},
+				},
+			},
+			result: []map[string]interface{}{{
+				"a":     float64(122.33),
+				"c":     float64(2),
+				"color": "w2",
+				"id":    float64(1),
+				"r":     float64(122),
+			}, {
+				"a":     float64(89.03),
+				"c":     float64(2),
+				"color": "w1",
+				"id":    float64(2),
+				"r":     float64(89),
+			}},
+		}, {
+			sql: "SELECT collect(a) as r1 FROM test Inner Join test1 on test.id = test1.id GROUP BY TumblingWindow(ss, 10), test1.color",
+			data: xsql.GroupedTuplesSet{
+				{
+					&xsql.JoinTuple{
+						Tuples: []xsql.Tuple{
+							{Emitter: "test", Message: xsql.Message{"id": 1, "a": 122.33, "c": 2, "r": 122}},
+							{Emitter: "src2", Message: xsql.Message{"id": 1, "color": "w2"}},
+						},
+					},
+					&xsql.JoinTuple{
+						Tuples: []xsql.Tuple{
+							{Emitter: "test", Message: xsql.Message{"id": 5, "a": 177.51}},
+							{Emitter: "src2", Message: xsql.Message{"id": 5, "color": "w2"}},
+						},
+					},
+				},
+				{
+					&xsql.JoinTuple{
+						Tuples: []xsql.Tuple{
+							{Emitter: "test", Message: xsql.Message{"id": 2, "a": 89.03, "c": 2, "r": 89}},
+							{Emitter: "src2", Message: xsql.Message{"id": 2, "color": "w1"}},
+						},
+					},
+					&xsql.JoinTuple{
+						Tuples: []xsql.Tuple{
+							{Emitter: "test", Message: xsql.Message{"id": 4, "a": 14.6}},
+							{Emitter: "src2", Message: xsql.Message{"id": 4, "color": "w1"}},
+						},
+					},
+				},
+			},
+			result: []map[string]interface{}{{
+				"r1": []interface{}{122.33, 177.51},
+			}, {"r1": []interface{}{89.03, 14.6}}},
+		}, {
+			sql: "SELECT collect(*)[1] as c1 FROM test GROUP BY TumblingWindow(ss, 10)",
+			data: xsql.WindowTuplesSet{
+				xsql.WindowTuples{
+					Emitter: "test",
+					Tuples: []xsql.Tuple{
+						{
+							Emitter: "src1",
+							Message: xsql.Message{"a": 53, "s": 123203},
+						}, {
+							Emitter: "src1",
+							Message: xsql.Message{"a": 27},
+						}, {
+							Emitter: "src1",
+							Message: xsql.Message{"a": 123123},
+						},
+					},
+				},
+			},
+			result: []map[string]interface{}{{
+				"c1": map[string]interface{}{
+					"a": float64(27),
+				},
+			}},
+		}, {
+			sql: "SELECT collect(*)[1]->a as c1 FROM test GROUP BY TumblingWindow(ss, 10)",
+			data: xsql.WindowTuplesSet{
+				xsql.WindowTuples{
+					Emitter: "test",
+					Tuples: []xsql.Tuple{
+						{
+							Emitter: "src1",
+							Message: xsql.Message{"a": 53, "s": 123203},
+						}, {
+							Emitter: "src1",
+							Message: xsql.Message{"a": 27},
+						}, {
+							Emitter: "src1",
+							Message: xsql.Message{"a": 123123},
+						},
+					},
+				},
+			},
+			result: []map[string]interface{}{{
+				"c1": float64(27),
+			}},
+		}, {
+			sql: "SELECT collect(*)[1]->sl[0] as c1 FROM test GROUP BY TumblingWindow(ss, 10)",
+			data: xsql.WindowTuplesSet{
+				xsql.WindowTuples{
+					Emitter: "test",
+					Tuples: []xsql.Tuple{
+						{
+							Emitter: "src1",
+							Message: xsql.Message{"a": 53, "sl": []string{"hello", "world"}},
+						}, {
+							Emitter: "src1",
+							Message: xsql.Message{"a": 27, "sl": []string{"new", "horizon"}},
+						}, {
+							Emitter: "src1",
+							Message: xsql.Message{"a": 123123, "sl": []string{"south", "africa"}},
+						},
+					},
+				},
+			},
+			result: []map[string]interface{}{{
+				"c1": "new",
+			}},
 		},
 	}
 

+ 32 - 0
xsql/processors/window_rule_test.go

@@ -564,6 +564,38 @@ func TestWindow(t *testing.T) {
 				"op_window_0_records_in_total":   int64(3),
 				"op_window_0_records_out_total":  int64(4),
 			},
+		}, {
+			name: `TestCountWindowRule1`,
+			sql:  `SELECT collect(*)[0]->color as c FROM demo GROUP BY COUNTWINDOW(3)`,
+			r: [][]map[string]interface{}{
+				{{
+					"c": "red",
+				}},
+			},
+			m: map[string]interface{}{
+				"op_preprocessor_demo_0_exceptions_total":   int64(0),
+				"op_preprocessor_demo_0_process_latency_ms": int64(0),
+				"op_preprocessor_demo_0_records_in_total":   int64(5),
+				"op_preprocessor_demo_0_records_out_total":  int64(5),
+
+				"op_project_0_exceptions_total":   int64(0),
+				"op_project_0_process_latency_ms": int64(0),
+				"op_project_0_records_in_total":   int64(1),
+				"op_project_0_records_out_total":  int64(1),
+
+				"sink_mockSink_0_exceptions_total":  int64(0),
+				"sink_mockSink_0_records_in_total":  int64(1),
+				"sink_mockSink_0_records_out_total": int64(1),
+
+				"source_demo_0_exceptions_total":  int64(0),
+				"source_demo_0_records_in_total":  int64(5),
+				"source_demo_0_records_out_total": int64(5),
+
+				"op_window_0_exceptions_total":   int64(0),
+				"op_window_0_process_latency_ms": int64(0),
+				"op_window_0_records_in_total":   int64(5),
+				"op_window_0_records_out_total":  int64(1),
+			},
 		},
 	}
 	handleStream(true, streamList, t)