Quellcode durchsuchen

feat(sql): support limit clause (#2066)

* support limit parser

Signed-off-by: yisaer <disxiaofei@163.com>

* support limit plan

Signed-off-by: yisaer <disxiaofei@163.com>

* support limit operator

Signed-off-by: yisaer <disxiaofei@163.com>

* fix lint

Signed-off-by: yisaer <disxiaofei@163.com>

* address the comment

Signed-off-by: yisaer <disxiaofei@163.com>

* add test

Signed-off-by: yisaer <disxiaofei@163.com>

* add doc

Signed-off-by: yisaer <disxiaofei@163.com>

* add doc

Signed-off-by: yisaer <disxiaofei@163.com>

---------

Signed-off-by: yisaer <disxiaofei@163.com>
Song Gao vor 1 Jahr
Ursprung
Commit
050a276b5d

+ 9 - 0
docs/en_US/sqls/query_language_elements.md

@@ -12,6 +12,7 @@ eKuiper provides a variety of elements for building queries. They are summarized
 | [GROUP BY](#group-by) | GROUP BY groups a selected set of rows into a set of summary rows grouped by the values of one or more columns or expressions. It must run within a [window](./windows.md).                                                                   |
 | [ORDER BY](#order-by) | Order the rows by values of one or more columns.                                                                                                                                                                                              |
 | [HAVING](#having)     | HAVING specifies a search condition for a group or an aggregate. HAVING can be used only with the SELECT expression.                                                                                                                          |
+| [LIMIT](#limit) | LIMIT will limit the number of output data. |
 
 ## SELECT
 
@@ -357,6 +358,14 @@ FROM table_name
 ORDER BY column1, column2, ... ASC|DESC;
 ```
 
+## LIMIT
+
+Limit the number of output data
+
+```sql
+LIMIT 1
+```
+
 ## Case Expression
 
 The case expression evaluates a list of conditions and returns one of multiple possible result expressions. It let you use IF ... THEN ... ELSE logic in SQL statements without having to invoke procedures.

+ 9 - 0
docs/zh_CN/sqls/query_language_elements.md

@@ -13,6 +13,7 @@ eKuiper 提供了用于构建查询的各种元素。 总结如下。
 | [ORDER BY](#order-by) | 按一列或多列的值对行进行排序。                                                                                                                |
 | [HAVING](#having)     | HAVING 为组或集合指定搜索条件。 HAVING 只能与 SELECT 表达式一起使用。                                                                                 |
 |                       |                                                                                                                                |
+| [LIMIT](#limit)       | LIMIT 将输出的数据条数进行数量上的限制 |
 
 ## SELECT
 
@@ -359,6 +360,14 @@ FROM table_name
 ORDER BY column1, column2, ... ASC|DESC;
 ```
 
+## LIMIT
+
+将输出的数据条数进行限制
+
+```sql
+LIMIT 1
+```
+
 ## Case 表达式
 
 Case 表达式评估一系列条件,并返回多个可能的结果表达式之一。它允许你在 SQL 语句中使用 IF ... THEN ... ELSE 逻辑,而无需调用过程。

+ 21 - 0
internal/topo/operator/project_operator.go

@@ -34,6 +34,8 @@ type ProjectOp struct {
 	AliasFields      ast.Fields
 	ExprFields       ast.Fields
 	IsAggregate      bool
+	EnableLimit      bool
+	LimitCount       int
 
 	SendMeta bool
 
@@ -49,6 +51,9 @@ type ProjectOp struct {
 func (pp *ProjectOp) Apply(ctx api.StreamContext, data interface{}, fv *xsql.FunctionValuer, afv *xsql.AggregateFunctionValuer) interface{} {
 	log := ctx.GetLogger()
 	log.Debugf("project plan receive %v", data)
+	if pp.LimitCount == 0 && pp.EnableLimit {
+		return []xsql.TupleRow{}
+	}
 	switch input := data.(type) {
 	case error:
 		return input
@@ -88,6 +93,14 @@ func (pp *ProjectOp) Apply(ctx api.StreamContext, data interface{}, fv *xsql.Fun
 		if err != nil {
 			return err
 		}
+		if pp.EnableLimit && pp.LimitCount > 0 && input.Len() > pp.LimitCount {
+			var sel []int
+			sel = make([]int, pp.LimitCount, pp.LimitCount)
+			for i := 0; i < pp.LimitCount; i++ {
+				sel[i] = i
+			}
+			return input.Filter(sel)
+		}
 	case xsql.GroupedCollection: // The order is important, because single collection usually is also a groupedCollection
 		err := input.GroupRange(func(_ int, aggRow xsql.CollectionRow) (bool, error) {
 			ve := pp.getVE(aggRow, aggRow, input.GetWindowRange(), fv, afv)
@@ -99,6 +112,14 @@ func (pp *ProjectOp) Apply(ctx api.StreamContext, data interface{}, fv *xsql.Fun
 		if err != nil {
 			return err
 		}
+		if pp.EnableLimit && pp.LimitCount > 0 && input.Len() > pp.LimitCount {
+			var sel []int
+			sel = make([]int, pp.LimitCount, pp.LimitCount)
+			for i := 0; i < pp.LimitCount; i++ {
+				sel[i] = i
+			}
+			return input.Filter(sel)
+		}
 	default:
 		return fmt.Errorf("run Select error: invalid input %[1]T(%[1]v)", input)
 	}

+ 17 - 1
internal/topo/operator/projectset_operator.go

@@ -22,7 +22,9 @@ import (
 )
 
 type ProjectSetOperator struct {
-	SrfMapping map[string]struct{}
+	SrfMapping  map[string]struct{}
+	EnableLimit bool
+	LimitCount  int
 }
 
 // Apply implement UnOperation
@@ -32,6 +34,9 @@ type ProjectSetOperator struct {
 // For Collection, ProjectSetOperator will do the following transform:
 // [{"a":[1,2],"b":3},{"a":[1,2],"b":4}] = > [{"a":"1","b":3},{"a":"2","b":3},{"a":"1","b":4},{"a":"2","b":4}]
 func (ps *ProjectSetOperator) Apply(_ api.StreamContext, data interface{}, _ *xsql.FunctionValuer, _ *xsql.AggregateFunctionValuer) interface{} {
+	if ps.LimitCount == 0 && ps.EnableLimit {
+		return []xsql.TupleRow{}
+	}
 	switch input := data.(type) {
 	case error:
 		return input
@@ -40,11 +45,22 @@ func (ps *ProjectSetOperator) Apply(_ api.StreamContext, data interface{}, _ *xs
 		if err != nil {
 			return err
 		}
+		if ps.EnableLimit && ps.LimitCount > 0 && len(results.rows) > ps.LimitCount {
+			return results.rows[:ps.LimitCount]
+		}
 		return results.rows
 	case xsql.Collection:
 		if err := ps.handleSRFRowForCollection(input); err != nil {
 			return err
 		}
+		if ps.EnableLimit && ps.LimitCount > 0 && input.Len() > ps.LimitCount {
+			var sel []int
+			sel = make([]int, ps.LimitCount, ps.LimitCount)
+			for i := 0; i < ps.LimitCount; i++ {
+				sel[i] = i
+			}
+			return input.Filter(sel)
+		}
 		return input
 	default:
 		return fmt.Errorf("run Select error: invalid input %[1]T(%[1]v)", input)

+ 20 - 4
internal/topo/planner/planner.go

@@ -177,9 +177,9 @@ func buildOps(lp LogicalPlan, tp *topo.Topo, options *api.RuleOption, sources []
 	case *OrderPlan:
 		op = Transform(&operator.OrderOp{SortFields: t.SortFields}, fmt.Sprintf("%d_order", newIndex), options)
 	case *ProjectPlan:
-		op = Transform(&operator.ProjectOp{ColNames: t.colNames, AliasNames: t.aliasNames, AliasFields: t.aliasFields, ExprFields: t.exprFields, ExceptNames: t.exceptNames, IsAggregate: t.isAggregate, AllWildcard: t.allWildcard, WildcardEmitters: t.wildcardEmitters, ExprNames: t.exprNames, SendMeta: t.sendMeta}, fmt.Sprintf("%d_project", newIndex), options)
+		op = Transform(&operator.ProjectOp{ColNames: t.colNames, AliasNames: t.aliasNames, AliasFields: t.aliasFields, ExprFields: t.exprFields, ExceptNames: t.exceptNames, IsAggregate: t.isAggregate, AllWildcard: t.allWildcard, WildcardEmitters: t.wildcardEmitters, ExprNames: t.exprNames, SendMeta: t.sendMeta, LimitCount: t.limitCount, EnableLimit: t.enableLimit}, fmt.Sprintf("%d_project", newIndex), options)
 	case *ProjectSetPlan:
-		op = Transform(&operator.ProjectSetOperator{SrfMapping: t.SrfMapping}, fmt.Sprintf("%d_projectset", newIndex), options)
+		op = Transform(&operator.ProjectSetOperator{SrfMapping: t.SrfMapping, LimitCount: t.limitCount, EnableLimit: t.enableLimit}, fmt.Sprintf("%d_projectset", newIndex), options)
 	default:
 		err = fmt.Errorf("unknown logical plan %v", t)
 	}
@@ -463,20 +463,36 @@ func createLogicalPlan(stmt *ast.SelectStatement, opt *api.RuleOption, store kv.
 		children = []LogicalPlan{p}
 	}
 
+	srfMapping := extractSRFMapping(stmt)
 	if stmt.Fields != nil {
+		enableLimit := false
+		limitCount := 0
+		if stmt.Limit != nil && len(srfMapping) == 0 {
+			enableLimit = true
+			limitCount = stmt.Limit.(*ast.LimitExpr).LimitCount.Val
+		}
 		p = ProjectPlan{
 			fields:      stmt.Fields,
 			isAggregate: xsql.WithAggFields(stmt),
 			sendMeta:    opt.SendMetaToSink,
+			enableLimit: enableLimit,
+			limitCount:  limitCount,
 		}.Init()
 		p.SetChildren(children)
 		children = []LogicalPlan{p}
 	}
 
-	srfMapping := extractSRFMapping(stmt)
 	if len(srfMapping) > 0 {
+		enableLimit := false
+		limitCount := 0
+		if stmt.Limit != nil {
+			enableLimit = true
+			limitCount = stmt.Limit.(*ast.LimitExpr).LimitCount.Val
+		}
 		p = ProjectSetPlan{
-			SrfMapping: srfMapping,
+			SrfMapping:  srfMapping,
+			enableLimit: enableLimit,
+			limitCount:  limitCount,
 		}.Init()
 		p.SetChildren(children)
 	}

+ 74 - 1
internal/topo/planner/planner_test.go

@@ -98,11 +98,84 @@ func Test_createLogicalPlan(t *testing.T) {
 		err string
 	}{
 		{
-			sql: "select unnest(myarray) as col from src1",
+			sql: "select name from src1 where true limit 1",
+			p: ProjectPlan{
+				baseLogicalPlan: baseLogicalPlan{
+					children: []LogicalPlan{
+						FilterPlan{
+							baseLogicalPlan: baseLogicalPlan{
+								children: []LogicalPlan{
+									DataSourcePlan{
+										baseLogicalPlan: baseLogicalPlan{},
+										name:            "src1",
+										streamFields: map[string]*ast.JsonStreamField{
+											"name": {
+												Type: "string",
+											},
+										},
+										streamStmt: streams["src1"],
+										metaFields: []string{},
+									}.Init(),
+								},
+							},
+							condition: &ast.BooleanLiteral{
+								Val: true,
+							},
+						}.Init(),
+					},
+				},
+				fields: []ast.Field{
+					{
+						Name: "name",
+						Expr: &ast.FieldRef{
+							StreamName: "src1",
+							Name:       "name",
+						},
+					},
+				},
+				limitCount:  1,
+				enableLimit: true,
+			}.Init(),
+		},
+		{
+			sql: "select name from src1 limit 1",
+			p: ProjectPlan{
+				baseLogicalPlan: baseLogicalPlan{
+					children: []LogicalPlan{
+						DataSourcePlan{
+							baseLogicalPlan: baseLogicalPlan{},
+							name:            "src1",
+							streamFields: map[string]*ast.JsonStreamField{
+								"name": {
+									Type: "string",
+								},
+							},
+							streamStmt: streams["src1"],
+							metaFields: []string{},
+						}.Init(),
+					},
+				},
+				fields: []ast.Field{
+					{
+						Name: "name",
+						Expr: &ast.FieldRef{
+							StreamName: "src1",
+							Name:       "name",
+						},
+					},
+				},
+				limitCount:  1,
+				enableLimit: true,
+			}.Init(),
+		},
+		{
+			sql: "select unnest(myarray) as col from src1 limit 1",
 			p: ProjectSetPlan{
 				SrfMapping: map[string]struct{}{
 					"col": {},
 				},
+				limitCount:  1,
+				enableLimit: true,
 				baseLogicalPlan: baseLogicalPlan{
 					children: []LogicalPlan{
 						ProjectPlan{

+ 2 - 0
internal/topo/planner/projectPlan.go

@@ -29,6 +29,8 @@ type ProjectPlan struct {
 	wildcardEmitters map[string]bool
 	aliasFields      ast.Fields
 	exprFields       ast.Fields
+	enableLimit      bool
+	limitCount       int
 }
 
 func (p ProjectPlan) Init() *ProjectPlan {

+ 3 - 1
internal/topo/planner/projectset_plan.go

@@ -16,7 +16,9 @@ package planner
 
 type ProjectSetPlan struct {
 	baseLogicalPlan
-	SrfMapping map[string]struct{}
+	SrfMapping  map[string]struct{}
+	enableLimit bool
+	limitCount  int
 }
 
 func (p ProjectSetPlan) Init() *ProjectSetPlan {

+ 65 - 0
internal/topo/topotest/rule_test.go

@@ -22,6 +22,71 @@ import (
 	"github.com/lf-edge/ekuiper/pkg/api"
 )
 
+func TestLimitSQL(t *testing.T) {
+	// Reset
+	streamList := []string{"demo", "demoArr"}
+	HandleStream(false, streamList, t)
+	var r [][]map[string]interface{}
+	tests := []RuleTest{
+		{
+			Name: "TestLimitSQL0",
+			Sql:  `SELECT unnest(demoArr.arr3) as col, demo.size FROM demo inner join demoArr on demo.size = demoArr.x group by SESSIONWINDOW(ss, 2, 1) limit 1;`,
+			R: [][]map[string]interface{}{
+				{
+					{
+						"col":  float64(1),
+						"size": float64(1),
+					},
+				},
+			},
+		},
+		{
+			Name: "TestLimitSQL1",
+			Sql:  `SELECT unnest(demoArr.arr3) as col, demo.size FROM demo inner join demoArr on demo.size = demoArr.x group by SESSIONWINDOW(ss, 2, 1) limit 0;`,
+			R:    r,
+		},
+		{
+			Name: "TestLimitSQL2",
+			Sql:  `SELECT demo.size FROM demo inner join demoArr on demo.size = demoArr.x group by SESSIONWINDOW(ss, 2, 1) limit 1;`,
+			R: [][]map[string]interface{}{
+				{
+					{
+						"size": float64(1),
+					},
+				},
+			},
+		},
+		{
+			Name: "TestLimitSQL3",
+			Sql:  `SELECT demo.size FROM demo inner join demoArr on demo.size = demoArr.x group by SESSIONWINDOW(ss, 2, 1) limit 0;`,
+			R:    r,
+		},
+	}
+	// Data setup
+	HandleStream(true, streamList, t)
+	options := []*api.RuleOption{
+		{
+			BufferLength: 100,
+			SendError:    true,
+		},
+		{
+			BufferLength:       100,
+			SendError:          true,
+			Qos:                api.AtLeastOnce,
+			CheckpointInterval: 5000,
+		},
+		{
+			BufferLength:       100,
+			SendError:          true,
+			Qos:                api.ExactlyOnce,
+			CheckpointInterval: 5000,
+		},
+	}
+	for j, opt := range options {
+		DoRuleTest(t, tests, j, opt, 0)
+	}
+}
+
 func TestSRFSQL(t *testing.T) {
 	// Reset
 	streamList := []string{"demo", "demoArr"}

+ 2 - 0
internal/xsql/lexical.go

@@ -245,6 +245,8 @@ func (s *Scanner) ScanIdent() (tok ast.Token, lit string) {
 		return ast.SS, lit
 	case "MS":
 		return ast.MS, lit
+	case "LIMIT":
+		return ast.LIMIT, lit
 	}
 	return ast.IDENT, word
 }

+ 25 - 0
internal/xsql/parser.go

@@ -57,6 +57,22 @@ func (p *Parser) ParseCondition() (ast.Expr, error) {
 	return expr, nil
 }
 
+func (p *Parser) ParseLimit() (ast.Expr, error) {
+	if tok, _ := p.scanIgnoreWhitespace(); tok != ast.LIMIT {
+		p.unscan()
+		return nil, nil
+	}
+	expr, err := p.ParseExpr()
+	if err != nil {
+		return nil, err
+	}
+	limitCount, ok := expr.(*ast.IntegerLiteral)
+	if !ok {
+		return nil, fmt.Errorf("limit should be integer")
+	}
+	return &ast.LimitExpr{LimitCount: limitCount}, nil
+}
+
 func (p *Parser) scan() (tok ast.Token, lit string) {
 	if p.n > 0 {
 		p.n--
@@ -167,6 +183,7 @@ func (p *Parser) Parse() (*ast.SelectStatement, error) {
 			selects.Condition = exp
 		}
 	}
+
 	p.clause = "groupby"
 	if dims, err := p.parseDimensions(); err != nil {
 		return nil, err
@@ -185,6 +202,14 @@ func (p *Parser) Parse() (*ast.SelectStatement, error) {
 	} else {
 		selects.SortFields = sorts
 	}
+	p.clause = "limit"
+	if expr, err := p.ParseLimit(); err != nil {
+		return nil, err
+	} else {
+		if expr != nil {
+			selects.Limit = expr
+		}
+	}
 	p.clause = ""
 	if tok, lit := p.scanIgnoreWhitespace(); tok == ast.SEMICOLON {
 		validateFields(selects, p.sourceNames)

+ 57 - 0
internal/xsql/parser_test.go

@@ -22,6 +22,8 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/stretchr/testify/require"
+
 	"github.com/lf-edge/ekuiper/internal/testx"
 	"github.com/lf-edge/ekuiper/pkg/ast"
 )
@@ -4481,3 +4483,58 @@ func TestParser_ParseStatements(t *testing.T) {
 		}
 	}
 }
+
+func TestParser_ParseLimit(t *testing.T) {
+	tests := []struct {
+		s    string
+		stmt *ast.SelectStatement
+		err  string
+	}{
+		{
+			s: "SELECT name FROM tbl LIMIT 1;",
+			stmt: &ast.SelectStatement{
+				Fields: []ast.Field{
+					{
+						Expr:  &ast.FieldRef{Name: "name", StreamName: ast.DefaultStream},
+						Name:  "name",
+						AName: "",
+					},
+				},
+				Sources: []ast.Source{&ast.Table{Name: "tbl"}},
+				Limit: &ast.LimitExpr{
+					LimitCount: &ast.IntegerLiteral{
+						Val: 1,
+					},
+				},
+			},
+		},
+		{
+			s: "SELECT name FROM tbl where true LIMIT 1;",
+			stmt: &ast.SelectStatement{
+				Fields: []ast.Field{
+					{
+						Expr:  &ast.FieldRef{Name: "name", StreamName: ast.DefaultStream},
+						Name:  "name",
+						AName: "",
+					},
+				},
+				Sources: []ast.Source{&ast.Table{Name: "tbl"}},
+				Condition: &ast.BooleanLiteral{
+					Val: true,
+				},
+				Limit: &ast.LimitExpr{
+					LimitCount: &ast.IntegerLiteral{
+						Val: 1,
+					},
+				},
+			},
+		},
+	}
+
+	fmt.Printf("The test bucket size is %d.\n\n", len(tests))
+	for _, tt := range tests {
+		stmt, err := NewParser(strings.NewReader(tt.s)).Parse()
+		require.NoError(t, err)
+		require.Equal(t, tt.stmt, stmt)
+	}
+}

+ 7 - 0
pkg/ast/expr.go

@@ -198,6 +198,13 @@ type BetweenExpr struct {
 func (b *BetweenExpr) expr() {}
 func (b *BetweenExpr) node() {}
 
+type LimitExpr struct {
+	LimitCount *IntegerLiteral
+}
+
+func (l *LimitExpr) expr() {}
+func (l *LimitExpr) node() {}
+
 type StreamName string
 
 func (sn *StreamName) node() {}

+ 1 - 0
pkg/ast/statement.go

@@ -24,6 +24,7 @@ type SelectStatement struct {
 	Sources    Sources
 	Joins      Joins
 	Condition  Expr
+	Limit      Expr
 	Dimensions Dimensions
 	Having     Expr
 	SortFields SortFields

+ 2 - 0
pkg/ast/token.go

@@ -94,6 +94,7 @@ const (
 	CROSS
 	ON
 	WHERE
+	LIMIT
 	GROUP
 	ORDER
 	HAVING
@@ -170,6 +171,7 @@ var Tokens = []string{
 	INNER:     "INNER",
 	ON:        "ON",
 	WHERE:     "WHERE",
+	LIMIT:     "LIMIT",
 	GROUP:     "GROUP",
 	ORDER:     "ORDER",
 	HAVING:    "HAVING",