Selaa lähdekoodia

fix: fix loop error when checking schema (#2225)

Signed-off-by: Song Gao <disxiaofei@163.com>
Song Gao 1 vuosi sitten
vanhempi
commit
1657b4cd48
2 muutettua tiedostoa jossa 73 lisäystä ja 6 poistoa
  1. 24 6
      internal/topo/planner/analyzer.go
  2. 49 0
      internal/topo/planner/analyzer_test.go

+ 24 - 6
internal/topo/planner/analyzer.go

@@ -55,7 +55,9 @@ func decorateStmt(s *ast.SelectStatement, store kv.KeyValue) ([]*streamInfo, []*
 		return nil, nil, nil, fmt.Errorf("select fields have cycled alias")
 	}
 	if !isSchemaless {
-		aliasFieldTopoSort(s, streamStmts)
+		if err := aliasFieldTopoSort(s, streamStmts); err != nil {
+			return nil, nil, nil, err
+		}
 	}
 	dsn := ast.DefaultStream
 	if len(streamsFromStmt) == 1 {
@@ -294,7 +296,7 @@ func dfsRef(aliasRef map[string]map[string]struct{}, walked map[string]struct{},
 	return false
 }
 
-func aliasFieldTopoSort(s *ast.SelectStatement, streamStmts []*streamInfo) {
+func aliasFieldTopoSort(s *ast.SelectStatement, streamStmts []*streamInfo) error {
 	nonAliasFields := make([]ast.Field, 0)
 	aliasDegreeMap := make(map[string]*aliasTopoDegree)
 	for _, field := range s.Fields {
@@ -311,7 +313,7 @@ func aliasFieldTopoSort(s *ast.SelectStatement, streamStmts []*streamInfo) {
 	for !isAliasFieldTopoSortFinish(aliasDegreeMap) {
 		for _, field := range s.Fields {
 			if field.AName != "" && aliasDegreeMap[field.AName].degree < 0 {
-				skip := false
+				unknownFieldRefName := ""
 				degree := 0
 				ast.WalkFunc(field.Expr, func(node ast.Node) bool {
 					switch f := node.(type) {
@@ -323,15 +325,30 @@ func aliasFieldTopoSort(s *ast.SelectStatement, streamStmts []*streamInfo) {
 							return true
 						}
 						if !isFieldRefNameExists(f.Name, streamStmts) {
-							skip = true
+							unknownFieldRefName = f.Name
 							return false
 						}
 					}
 					return true
 				})
-				if !skip {
-					aliasDegreeMap[field.AName].degree = degree
+
+				if len(unknownFieldRefName) > 0 {
+					unknownField := true
+					for _, otherField := range s.Fields {
+						if field == otherField {
+							continue
+						}
+						// the unknownFieldRef name belongs to a alias
+						if otherField.AName == unknownFieldRefName {
+							unknownField = false
+							break
+						}
+					}
+					if unknownField {
+						return fmt.Errorf("unknown field %s", unknownFieldRefName)
+					}
 				}
+				aliasDegreeMap[field.AName].degree = degree
 			}
 		}
 	}
@@ -345,6 +362,7 @@ func aliasFieldTopoSort(s *ast.SelectStatement, streamStmts []*streamInfo) {
 		s.Fields = append(s.Fields, d.field)
 	}
 	s.Fields = append(s.Fields, nonAliasFields...)
+	return nil
 }
 
 func isFieldRefNameExists(name string, streamStmts []*streamInfo) bool {

+ 49 - 0
internal/topo/planner/analyzer_test.go

@@ -16,12 +16,14 @@ package planner
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 	"reflect"
 	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"github.com/lf-edge/ekuiper/internal/pkg/store"
 	"github.com/lf-edge/ekuiper/internal/testx"
@@ -152,6 +154,53 @@ var tests = []struct {
 	//},
 }
 
+func TestCheckTopoSort(t *testing.T) {
+	store, err := store.GetKV("stream")
+	require.NoError(t, err)
+	streamSqls := map[string]string{
+		"src1": `CREATE STREAM src1 (
+					id1 BIGINT,
+					temp BIGINT,
+					name string,
+					next STRUCT(NAME STRING, NID BIGINT)
+				) WITH (DATASOURCE="src1", FORMAT="json", KEY="ts");`,
+	}
+	types := map[string]ast.StreamType{
+		"src1": ast.TypeStream,
+	}
+	for name, sql := range streamSqls {
+		s, err := json.Marshal(&xsql.StreamInfo{
+			StreamType: types[name],
+			Statement:  sql,
+		})
+		require.NoError(t, err)
+		store.Set(name, string(s))
+	}
+	streams := make(map[string]*ast.StreamStmt)
+	for n := range streamSqls {
+		streamStmt, err := xsql.GetDataSource(store, n)
+		if err != nil {
+			t.Errorf("fail to get stream %s, please check if stream is created", n)
+			return
+		}
+		streams[n] = streamStmt
+	}
+	sql := "select latest(a) as a from src1"
+	stmt, err := xsql.NewParser(strings.NewReader(sql)).Parse()
+	require.NoError(t, err)
+	_, err = createLogicalPlan(stmt, &api.RuleOption{
+		IsEventTime:        false,
+		LateTol:            0,
+		Concurrency:        0,
+		BufferLength:       0,
+		SendMetaToSink:     false,
+		Qos:                0,
+		CheckpointInterval: 0,
+		SendError:          true,
+	}, store)
+	require.Equal(t, errors.New("unknown field a"), err)
+}
+
 func Test_validation(t *testing.T) {
 	tests[10].r = newErrorStruct("invalid argument for func sum: aggregate argument is not allowed")
 	store, err := store.GetKV("stream")