Browse Source

feat: optimize the limit processing (#2084)

* optimize limit

Signed-off-by: yisaer <disxiaofei@163.com>

* add test

Signed-off-by: yisaer <disxiaofei@163.com>

* address the comment

Signed-off-by: yisaer <disxiaofei@163.com>

* address the comment

Signed-off-by: yisaer <disxiaofei@163.com>

---------

Signed-off-by: yisaer <disxiaofei@163.com>
Song Gao 1 năm trước cách đây
mục cha
commit
73b9783679

+ 18 - 11
internal/topo/operator/project_operator.go

@@ -77,7 +77,23 @@ func (pp *ProjectOp) Apply(ctx api.StreamContext, data interface{}, fv *xsql.Fun
 				}
 				return true, nil
 			})
+			if pp.EnableLimit && pp.LimitCount > 0 && input.Len() > pp.LimitCount {
+				var sel []int
+				sel = make([]int, pp.LimitCount, pp.LimitCount)
+				for i := 0; i < pp.LimitCount; i++ {
+					sel[i] = i
+				}
+				input = input.Filter(sel).(xsql.SingleCollection)
+			}
 		} else {
+			if pp.EnableLimit && pp.LimitCount > 0 && input.Len() > pp.LimitCount {
+				var sel []int
+				sel = make([]int, pp.LimitCount, pp.LimitCount)
+				for i := 0; i < pp.LimitCount; i++ {
+					sel[i] = i
+				}
+				input = input.Filter(sel).(xsql.SingleCollection)
+			}
 			err = input.RangeSet(func(_ int, row xsql.Row) (bool, error) {
 				aggData, ok := input.(xsql.AggregateData)
 				if !ok {
@@ -93,15 +109,15 @@ func (pp *ProjectOp) Apply(ctx api.StreamContext, data interface{}, fv *xsql.Fun
 		if err != nil {
 			return err
 		}
+	case xsql.GroupedCollection: // The order is important, because single collection usually is also a groupedCollection
 		if pp.EnableLimit && pp.LimitCount > 0 && input.Len() > pp.LimitCount {
 			var sel []int
 			sel = make([]int, pp.LimitCount, pp.LimitCount)
 			for i := 0; i < pp.LimitCount; i++ {
 				sel[i] = i
 			}
-			return input.Filter(sel)
+			input = input.Filter(sel).(xsql.GroupedCollection)
 		}
-	case xsql.GroupedCollection: // The order is important, because single collection usually is also a groupedCollection
 		err := input.GroupRange(func(_ int, aggRow xsql.CollectionRow) (bool, error) {
 			ve := pp.getVE(aggRow, aggRow, input.GetWindowRange(), fv, afv)
 			if err := pp.project(aggRow, ve); err != nil {
@@ -112,18 +128,9 @@ func (pp *ProjectOp) Apply(ctx api.StreamContext, data interface{}, fv *xsql.Fun
 		if err != nil {
 			return err
 		}
-		if pp.EnableLimit && pp.LimitCount > 0 && input.Len() > pp.LimitCount {
-			var sel []int
-			sel = make([]int, pp.LimitCount, pp.LimitCount)
-			for i := 0; i < pp.LimitCount; i++ {
-				sel[i] = i
-			}
-			return input.Filter(sel)
-		}
 	default:
 		return fmt.Errorf("run Select error: invalid input %[1]T(%[1]v)", input)
 	}
-
 	return data
 }
 

+ 7 - 0
internal/topo/operator/projectset_operator.go

@@ -50,6 +50,13 @@ func (ps *ProjectSetOperator) Apply(_ api.StreamContext, data interface{}, _ *xs
 		}
 		return results.rows
 	case xsql.Collection:
+		if ps.EnableLimit && ps.LimitCount > 0 && input.Len() > ps.LimitCount {
+			sel := make([]int, 0, ps.LimitCount)
+			for i := 0; i < ps.LimitCount; i++ {
+				sel = append(sel, i)
+			}
+			input = input.Filter(sel)
+		}
 		if err := ps.handleSRFRowForCollection(input); err != nil {
 			return err
 		}

+ 2 - 0
internal/topo/topotest/mock_topo.go

@@ -272,6 +272,8 @@ func HandleStream(createOrDrop bool, names []string, t *testing.T) {
 			switch name {
 			case "demoE2":
 				sql = `CREATE STREAM demoE2 () WITH (DATASOURCE="demoE2", TYPE="mock", FORMAT="json", KEY="ts", TIMESTAMP="ts");`
+			case "demoArr2":
+				sql = `CREATE STREAM demoArr2 () WITH (DATASOURCE="demoArr2", TYPE="mock", FORMAT="json", KEY="ts");`
 			case "demoArr":
 				sql = `CREATE STREAM demoArr () WITH (DATASOURCE="demoArr", TYPE="mock", FORMAT="json", KEY="ts");`
 			case "demo":

+ 18 - 0
internal/topo/topotest/mocknode/mock_data.go

@@ -1025,6 +1025,24 @@ var TestData = map[string][]*xsql.Tuple{
 			Timestamp: 1541152489253,
 		},
 	},
+	"demoArr2": {
+		{
+			Emitter: "demoArr2",
+			Message: map[string]interface{}{
+				"arr": []interface{}{1, 2, 3, 4, 5},
+				"x":   1,
+			},
+			Timestamp: 1541152489253,
+		},
+		{
+			Emitter: "demoArr2",
+			Message: map[string]interface{}{
+				"arr": []interface{}{1, 2, 3, 4, 5},
+				"x":   1,
+			},
+			Timestamp: 1541152489254,
+		},
+	},
 	"demoE2": {
 		{
 			Emitter: "demoE2",

+ 13 - 1
internal/topo/topotest/rule_test.go

@@ -24,11 +24,23 @@ import (
 
 func TestLimitSQL(t *testing.T) {
 	// Reset
-	streamList := []string{"demo", "demoArr"}
+	streamList := []string{"demo", "demoArr", "demoArr2"}
 	HandleStream(false, streamList, t)
 	var r [][]map[string]interface{}
 	tests := []RuleTest{
 		{
+			Name: "TestLimitSQL01",
+			Sql:  `SELECT unnest(demoArr2.arr) as col, demo.size FROM demo inner join demoArr2 on demo.size = demoArr2.x group by SESSIONWINDOW(ss, 2, 1) limit 1;`,
+			R: [][]map[string]interface{}{
+				{
+					{
+						"col":  float64(1),
+						"size": float64(1),
+					},
+				},
+			},
+		},
+		{
 			Name: "TestLimitSQL0",
 			Sql:  `SELECT unnest(demoArr.arr3) as col, demo.size FROM demo inner join demoArr on demo.size = demoArr.x group by SESSIONWINDOW(ss, 2, 1) limit 1;`,
 			R: [][]map[string]interface{}{