Procházet zdrojové kódy

fix: analyze funcs reference alias analyze (#2113)

* separate analyze funcs

Signed-off-by: yisaer <disxiaofei@163.com>

* separate analyze funcs

Signed-off-by: yisaer <disxiaofei@163.com>

* fix

Signed-off-by: yisaer <disxiaofei@163.com>

* fix

Signed-off-by: yisaer <disxiaofei@163.com>

* add test

Signed-off-by: yisaer <disxiaofei@163.com>

* rebase

Signed-off-by: yisaer <disxiaofei@163.com>

* fix lint

Signed-off-by: yisaer <disxiaofei@163.com>

* add test

Signed-off-by: yisaer <disxiaofei@163.com>

---------

Signed-off-by: yisaer <disxiaofei@163.com>
Song Gao před 1 rokem
rodič
revize
39d3a3c74b

+ 48 - 21
internal/topo/operator/analyticfuncs_operator.go

@@ -23,41 +23,68 @@ import (
 )
 
 type AnalyticFuncsOp struct {
-	Funcs []*ast.Call // Must range from end to start, because the later one may use the result of the former one
+	Funcs      []*ast.Call
+	FieldFuncs []*ast.Call
+}
+
+func (p *AnalyticFuncsOp) evalTupleFunc(calls []*ast.Call, ve *xsql.ValuerEval, input xsql.TupleRow) (xsql.TupleRow, error) {
+	for _, call := range calls {
+		f := call
+		result := ve.Eval(f)
+		if e, ok := result.(error); ok {
+			return nil, e
+		}
+		input.Set(f.CachedField, result)
+	}
+	return input, nil
+}
+
+func (p *AnalyticFuncsOp) evalCollectionFunc(calls []*ast.Call, fv *xsql.FunctionValuer, input xsql.SingleCollection) (xsql.SingleCollection, error) {
+	err := input.RangeSet(func(_ int, row xsql.Row) (bool, error) {
+		ve := &xsql.ValuerEval{Valuer: xsql.MultiValuer(row, &xsql.WindowRangeValuer{WindowRange: input.GetWindowRange()}, fv, &xsql.WildcardValuer{Data: row})}
+		for _, call := range calls {
+			f := call
+			result := ve.Eval(f)
+			if e, ok := result.(error); ok {
+				return false, e
+			}
+			row.Set(f.CachedField, result)
+		}
+		return true, nil
+	})
+	if err != nil {
+		return nil, err
+	}
+	return input, nil
 }
 
 func (p *AnalyticFuncsOp) Apply(ctx api.StreamContext, data interface{}, fv *xsql.FunctionValuer, _ *xsql.AggregateFunctionValuer) interface{} {
 	ctx.GetLogger().Debugf("AnalyticFuncsOp receive: %v", data)
+	var err error
 	switch input := data.(type) {
 	case error:
 		return input
 	case xsql.TupleRow:
 		ve := &xsql.ValuerEval{Valuer: xsql.MultiValuer(input, fv)}
-		// Must range from end to start, because the later one may use the result of the former one
-		for i := len(p.Funcs) - 1; i >= 0; i-- {
-			f := p.Funcs[i]
-			result := ve.Eval(f)
-			if e, ok := result.(error); ok {
-				return e
-			}
-			input.Set(f.CachedField, result)
+		input, err = p.evalTupleFunc(p.FieldFuncs, ve, input)
+		if err != nil {
+			return err
 		}
+		input, err = p.evalTupleFunc(p.Funcs, ve, input)
+		if err != nil {
+			return err
+		}
+		data = input
 	case xsql.SingleCollection:
-		err := input.RangeSet(func(_ int, row xsql.Row) (bool, error) {
-			ve := &xsql.ValuerEval{Valuer: xsql.MultiValuer(row, &xsql.WindowRangeValuer{WindowRange: input.GetWindowRange()}, fv, &xsql.WildcardValuer{Data: row})}
-			for i := len(p.Funcs) - 1; i >= 0; i-- {
-				f := p.Funcs[i]
-				result := ve.Eval(f)
-				if e, ok := result.(error); ok {
-					return false, e
-				}
-				row.Set(f.CachedField, result)
-			}
-			return true, nil
-		})
+		input, err = p.evalCollectionFunc(p.FieldFuncs, fv, input)
+		if err != nil {
+			return err
+		}
+		input, err = p.evalCollectionFunc(p.Funcs, fv, input)
 		if err != nil {
 			return err
 		}
+		data = input
 	default:
 		return fmt.Errorf("run analytic funcs op error: invalid input %[1]T(%[1]v)", input)
 	}

+ 2 - 1
internal/topo/planner/analyticFuncsPlan.go

@@ -18,7 +18,8 @@ import "github.com/lf-edge/ekuiper/pkg/ast"
 
 type AnalyticFuncsPlan struct {
 	baseLogicalPlan
-	funcs []*ast.Call
+	funcs      []*ast.Call
+	fieldFuncs []*ast.Call
 }
 
 func (p AnalyticFuncsPlan) Init() *AnalyticFuncsPlan {

+ 37 - 30
internal/topo/planner/analyzer.go

@@ -33,18 +33,18 @@ type streamInfo struct {
 
 // Analyze the select statement by decorating the info from stream statement.
 // Typically, set the correct stream name for fieldRefs
-func decorateStmt(s *ast.SelectStatement, store kv.KeyValue) ([]*streamInfo, []*ast.Call, error) {
+func decorateStmt(s *ast.SelectStatement, store kv.KeyValue) ([]*streamInfo, []*ast.Call, []*ast.Call, error) {
 	streamsFromStmt := xsql.GetStreams(s)
 	streamStmts := make([]*streamInfo, len(streamsFromStmt))
 	isSchemaless := false
 	for i, s := range streamsFromStmt {
 		streamStmt, err := xsql.GetDataSource(store, s)
 		if err != nil {
-			return nil, nil, fmt.Errorf("fail to get stream %s, please check if stream is created", s)
+			return nil, nil, nil, fmt.Errorf("fail to get stream %s, please check if stream is created", s)
 		}
 		si, err := convertStreamInfo(streamStmt)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, nil, err
 		}
 		streamStmts[i] = si
 		if si.schema == nil {
@@ -52,7 +52,7 @@ func decorateStmt(s *ast.SelectStatement, store kv.KeyValue) ([]*streamInfo, []*
 		}
 	}
 	if checkAliasReferenceCycle(s) {
-		return nil, nil, fmt.Errorf("select fields have cycled alias")
+		return nil, nil, nil, fmt.Errorf("select fields have cycled alias")
 	}
 	if !isSchemaless {
 		aliasFieldTopoSort(s, streamStmts)
@@ -71,9 +71,10 @@ func decorateStmt(s *ast.SelectStatement, store kv.KeyValue) ([]*streamInfo, []*
 		}
 	}
 	var (
-		walkErr       error
-		aliasFields   []*ast.Field
-		analyticFuncs []*ast.Call
+		walkErr            error
+		aliasFields        []*ast.Field
+		analyticFieldFuncs []*ast.Call
+		analyticFuncs      []*ast.Call
 	)
 	// Scan columns fields: bind all field refs, collect alias
 	for i, f := range s.Fields {
@@ -85,7 +86,7 @@ func decorateStmt(s *ast.SelectStatement, store kv.KeyValue) ([]*streamInfo, []*
 			return true
 		})
 		if walkErr != nil {
-			return nil, nil, walkErr
+			return nil, nil, nil, walkErr
 		}
 		if f.AName != "" {
 			aliasFields = append(aliasFields, &s.Fields[i])
@@ -144,7 +145,7 @@ func decorateStmt(s *ast.SelectStatement, store kv.KeyValue) ([]*streamInfo, []*
 		return true
 	})
 	if walkErr != nil {
-		return nil, nil, walkErr
+		return nil, nil, nil, walkErr
 	}
 	walkErr = validate(s)
 	// Collect all analytic function calls so that we can let them run firstly
@@ -170,33 +171,39 @@ func decorateStmt(s *ast.SelectStatement, store kv.KeyValue) ([]*streamInfo, []*
 		return true
 	})
 	if walkErr != nil {
-		return nil, nil, walkErr
+		return nil, nil, nil, walkErr
 	}
 	// walk sources at last to let them run firstly
 	// because another clause may depend on the alias defined here
-	ast.WalkFunc(s.Fields, func(n ast.Node) bool {
-		switch f := n.(type) {
-		case *ast.Call:
-			if function.IsAnalyticFunc(f.Name) {
-				f.CachedField = fmt.Sprintf("%s_%s_%d", function.AnalyticPrefix, f.Name, f.FuncId)
-				f.Cached = true
-				analyticFuncs = append(analyticFuncs, &ast.Call{
-					Name:        f.Name,
-					FuncId:      f.FuncId,
-					FuncType:    f.FuncType,
-					Args:        f.Args,
-					CachedField: f.CachedField,
-					Partition:   f.Partition,
-					WhenExpr:    f.WhenExpr,
-				})
+	for _, field := range s.Fields {
+		var calls []*ast.Call
+		ast.WalkFunc(&field, func(n ast.Node) bool {
+			switch f := n.(type) {
+			case *ast.Call:
+				if function.IsAnalyticFunc(f.Name) {
+					f.CachedField = fmt.Sprintf("%s_%s_%d", function.AnalyticPrefix, f.Name, f.FuncId)
+					f.Cached = true
+					calls = append([]*ast.Call{
+						{
+							Name:        f.Name,
+							FuncId:      f.FuncId,
+							FuncType:    f.FuncType,
+							Args:        f.Args,
+							CachedField: f.CachedField,
+							Partition:   f.Partition,
+							WhenExpr:    f.WhenExpr,
+						},
+					}, calls...)
+				}
 			}
-		}
-		return true
-	})
+			return true
+		})
+		analyticFieldFuncs = append(analyticFieldFuncs, calls...)
+	}
 	if walkErr != nil {
-		return nil, nil, walkErr
+		return nil, nil, nil, walkErr
 	}
-	return streamStmts, analyticFuncs, walkErr
+	return streamStmts, analyticFuncs, analyticFieldFuncs, walkErr
 }
 
 type aliasTopoDegree struct {

+ 5 - 4
internal/topo/planner/planner.go

@@ -130,7 +130,7 @@ func buildOps(lp LogicalPlan, tp *topo.Topo, options *api.RuleOption, sources []
 	case *WatermarkPlan:
 		op = node.NewWatermarkOp(fmt.Sprintf("%d_watermark", newIndex), t.SendWatermark, t.Emitters, options)
 	case *AnalyticFuncsPlan:
-		op = Transform(&operator.AnalyticFuncsOp{Funcs: t.funcs}, fmt.Sprintf("%d_analytic", newIndex), options)
+		op = Transform(&operator.AnalyticFuncsOp{Funcs: t.funcs, FieldFuncs: t.fieldFuncs}, fmt.Sprintf("%d_analytic", newIndex), options)
 	case *WindowPlan:
 		if t.condition != nil {
 			wfilterOp := Transform(&operator.FilterOp{Condition: t.condition}, fmt.Sprintf("%d_windowFilter", newIndex), options)
@@ -303,7 +303,7 @@ func createLogicalPlan(stmt *ast.SelectStatement, opt *api.RuleOption, store kv.
 		ds                  ast.Dimensions
 	)
 
-	streamStmts, analyticFuncs, err := decorateStmt(stmt, store)
+	streamStmts, analyticFuncs, analyticFieldFuncs, err := decorateStmt(stmt, store)
 	if err != nil {
 		return nil, err
 	}
@@ -341,9 +341,10 @@ func createLogicalPlan(stmt *ast.SelectStatement, opt *api.RuleOption, store kv.
 		p.SetChildren(children)
 		children = []LogicalPlan{p}
 	}
-	if len(analyticFuncs) > 0 {
+	if len(analyticFuncs) > 0 || len(analyticFieldFuncs) > 0 {
 		p = AnalyticFuncsPlan{
-			funcs: analyticFuncs,
+			funcs:      analyticFuncs,
+			fieldFuncs: analyticFieldFuncs,
 		}.Init()
 		p.SetChildren(children)
 		children = []LogicalPlan{p}

Rozdílová data souboru nebyla zobrazena, protože soubor je příliš velký
+ 27 - 20
internal/topo/planner/planner_test.go


+ 78 - 2
internal/topo/topotest/rule_test.go

@@ -295,6 +295,80 @@ func TestSingleSQL(t *testing.T) {
 	// Data setup
 	tests := []RuleTest{
 		{
+			Name: `TestAnalyzeFuncAlias1`,
+			Sql:  `SELECT lag(size,1,0) + 1 as b, lag(b,1,0),size FROM demo Group by COUNTWINDOW(5)`,
+			R: [][]map[string]interface{}{
+				{
+					{
+						"b":    float64(1),
+						"lag":  float64(0),
+						"size": float64(3),
+					},
+					{
+						"b":    float64(4),
+						"lag":  float64(1),
+						"size": float64(6),
+					},
+					{
+						"b":    float64(7),
+						"lag":  float64(4),
+						"size": float64(2),
+					},
+					{
+						"b":    float64(3),
+						"lag":  float64(7),
+						"size": float64(4),
+					},
+					{
+						"b":    float64(5),
+						"lag":  float64(3),
+						"size": float64(1),
+					},
+				},
+			},
+		},
+		{
+			Name: `TestAnalyzeFuncAlias2`,
+			Sql:  `SELECT lag(size,1,0) + 1 as b, lag(b,1,0),size FROM demo`,
+			R: [][]map[string]interface{}{
+				{
+					{
+						"b":    float64(1),
+						"lag":  float64(0),
+						"size": float64(3),
+					},
+				},
+				{
+					{
+						"b":    float64(4),
+						"lag":  float64(1),
+						"size": float64(6),
+					},
+				},
+				{
+					{
+						"b":    float64(7),
+						"lag":  float64(4),
+						"size": float64(2),
+					},
+				},
+				{
+					{
+						"b":    float64(3),
+						"lag":  float64(7),
+						"size": float64(4),
+					},
+				},
+				{
+					{
+						"b":    float64(5),
+						"lag":  float64(3),
+						"size": float64(1),
+					},
+				},
+			},
+		},
+		{
 			Name: `TestSingleSQLRule0`,
 			Sql:  `SELECT arr[x:y+1] as col1 FROM demoArr where x=1`,
 			R: [][]map[string]interface{}{
@@ -971,12 +1045,14 @@ func TestSingleSQL(t *testing.T) {
 		{
 			BufferLength: 100,
 			SendError:    true,
-		}, {
+		},
+		{
 			BufferLength:       100,
 			SendError:          true,
 			Qos:                api.AtLeastOnce,
 			CheckpointInterval: 5000,
-		}, {
+		},
+		{
 			BufferLength:       100,
 			SendError:          true,
 			Qos:                api.ExactlyOnce,