Browse Source

fix: fix alias cycle reference (#2184)

Signed-off-by: yisaer <disxiaofei@163.com>
Song Gao 1 year ago
parent
commit
9a176fd651
2 changed files with 28 additions and 7 deletions
  1. 24 7
      internal/topo/planner/analyzer.go
  2. 4 0
      internal/topo/planner/analyzer_test.go

+ 24 - 7
internal/topo/planner/analyzer.go

@@ -256,13 +256,9 @@ func checkAliasReferenceCycle(s *ast.SelectStatement) bool {
 						_, ok := aliasRef[f.Name]
 						if ok {
 							aliasRef[field.AName][f.Name] = struct{}{}
-							v, ok1 := aliasRef[f.Name]
-							if ok1 {
-								_, ok2 := v[field.AName]
-								if ok2 {
-									hasCycleAlias = true
-									return false
-								}
+							if dfsRef(aliasRef, map[string]struct{}{}, f.Name, field.AName) {
+								hasCycleAlias = true
+								return false
 							}
 						}
 					}
@@ -277,6 +273,27 @@ func checkAliasReferenceCycle(s *ast.SelectStatement) bool {
 	return false
 }
 
+func dfsRef(aliasRef map[string]map[string]struct{}, walked map[string]struct{}, currentName, targetName string) bool {
+	defer func() {
+		walked[currentName] = struct{}{}
+	}()
+	for refName := range aliasRef[currentName] {
+		if refName == targetName {
+			return true
+		}
+	}
+	for name := range aliasRef[currentName] {
+		_, ok := walked[name]
+		if ok {
+			continue
+		}
+		if dfsRef(aliasRef, walked, name, targetName) {
+			return true
+		}
+	}
+	return false
+}
+
 func aliasFieldTopoSort(s *ast.SelectStatement, streamStmts []*streamInfo) {
 	nonAliasFields := make([]ast.Field, 0)
 	aliasDegreeMap := make(map[string]*aliasTopoDegree)

+ 4 - 0
internal/topo/planner/analyzer_test.go

@@ -142,6 +142,10 @@ var tests = []struct {
 		sql: "select a + 1 as b, b + 1 as a from src1",
 		r:   newErrorStruct("select fields have cycled alias"),
 	},
+	{
+		sql: "select a + 1 as b, b * 2 as c, c + 1 as a from src1",
+		r:   newErrorStruct("select fields have cycled alias"),
+	},
 	//{ // 19 already captured in parser
 	//	sql: `SELECT * FROM src1 GROUP BY SlidingWindow(ss,5) Over (WHEN abs(sum(a)) > 1) HAVING last_agg_hit_count() < 3`,
 	//	r:   newErrorStruct("error compile sql: Not allowed to call aggregate functions in GROUP BY clause."),