Browse Source

feat: support ref alias in select (#2101)

* support alias

Signed-off-by: yisaer <disxiaofei@163.com>
Song Gao 1 year ago
parent
commit
9087cae9e4

+ 125 - 7
internal/topo/planner/analyzer.go

@@ -16,6 +16,7 @@ package planner
 
 import (
 	"fmt"
+	"sort"
 	"strings"
 
 	"github.com/lf-edge/ekuiper/internal/binder/function"
@@ -50,7 +51,9 @@ func decorateStmt(s *ast.SelectStatement, store kv.KeyValue) ([]*streamInfo, []*
 			isSchemaless = true
 		}
 	}
-
+	if !isSchemaless {
+		aliasFieldTopoSort(s, streamStmts)
+	}
 	dsn := ast.DefaultStream
 	if len(streamsFromStmt) == 1 {
 		dsn = streamStmts[0].stmt.Name
@@ -83,6 +86,7 @@ func decorateStmt(s *ast.SelectStatement, store kv.KeyValue) ([]*streamInfo, []*
 		}
 		if f.AName != "" {
 			aliasFields = append(aliasFields, &s.Fields[i])
+			fieldsMap.bindAlias(f.AName)
 		}
 	}
 	// bind alias field expressions
@@ -97,6 +101,19 @@ func decorateStmt(s *ast.SelectStatement, store kv.KeyValue) ([]*streamInfo, []*
 				AliasRef:   ar,
 			}
 			walkErr = fieldsMap.save(f.AName, ast.AliasStream, ar)
+			for _, subF := range aliasFields {
+				ast.WalkFunc(subF, func(node ast.Node) bool {
+					switch fr := node.(type) {
+					case *ast.FieldRef:
+						if fr.Name == f.AName {
+							fr.StreamName = ast.AliasStream
+							fr.AliasRef = ar
+						}
+						return false
+					}
+					return true
+				})
+			}
 		}
 	}
 	// Bind field ref for alias AND set StreamName for all field ref
@@ -179,6 +196,99 @@ func decorateStmt(s *ast.SelectStatement, store kv.KeyValue) ([]*streamInfo, []*
 	return streamStmts, analyticFuncs, walkErr
 }
 
+type (
+	// aliasTopoDegree is the bookkeeping record kept for one aliased select
+	// field while the aliased fields are being ordered by dependency.
+	aliasTopoDegree struct {
+		// alias is the alias name (the AName of the field).
+		alias string
+		// degree is the ordering key of the alias; it is negative until it
+		// has been computed.
+		degree int
+		// field is a copy of the select field that declares the alias.
+		field ast.Field
+	}
+
+	// aliasTopoDegrees implements sort.Interface and orders the records by
+	// ascending degree.
+	aliasTopoDegrees []*aliasTopoDegree
+)
+
+// Len returns the number of records.
+func (ds aliasTopoDegrees) Len() int {
+	return len(ds)
+}
+
+// Less reports whether the record at index i has a smaller degree than the
+// record at index j.
+func (ds aliasTopoDegrees) Less(i, j int) bool {
+	lhs, rhs := ds[i], ds[j]
+	return lhs.degree < rhs.degree
+}
+
+// Swap exchanges the records at indexes i and j.
+func (ds aliasTopoDegrees) Swap(i, j int) {
+	tmp := ds[i]
+	ds[i] = ds[j]
+	ds[j] = tmp
+}
+
+func aliasFieldTopoSort(s *ast.SelectStatement, streamStmts []*streamInfo) {
+	nonAliasFields := make([]ast.Field, 0)
+	aliasDegreeMap := make(map[string]*aliasTopoDegree)
+	for _, field := range s.Fields {
+		if field.AName != "" {
+			aliasDegreeMap[field.AName] = &aliasTopoDegree{
+				alias:  field.AName,
+				degree: -1,
+				field:  field,
+			}
+		} else {
+			nonAliasFields = append(nonAliasFields, field)
+		}
+	}
+	for !isAliasFieldTopoSortFinish(aliasDegreeMap) {
+		for _, field := range s.Fields {
+			if field.AName != "" && aliasDegreeMap[field.AName].degree < 0 {
+				skip := false
+				degree := 0
+				ast.WalkFunc(field.Expr, func(node ast.Node) bool {
+					switch f := node.(type) {
+					case *ast.FieldRef:
+						if fDegree, ok := aliasDegreeMap[f.Name]; ok && fDegree.degree >= 0 {
+							if degree < fDegree.degree+1 {
+								degree = fDegree.degree + 1
+							}
+							return true
+						}
+						if !isFieldRefNameExists(f.Name, streamStmts) {
+							skip = true
+							return false
+						}
+					}
+					return true
+				})
+				if !skip {
+					aliasDegreeMap[field.AName].degree = degree
+				}
+			}
+		}
+	}
+	as := make(aliasTopoDegrees, 0)
+	for _, degree := range aliasDegreeMap {
+		as = append(as, degree)
+	}
+	sort.Sort(as)
+	s.Fields = make([]ast.Field, 0)
+	for _, d := range as {
+		s.Fields = append(s.Fields, d.field)
+	}
+	s.Fields = append(s.Fields, nonAliasFields...)
+}
+
+// isFieldRefNameExists reports whether any of the given streams has a schema
+// column whose name is exactly name.
+func isFieldRefNameExists(name string, streamStmts []*streamInfo) bool {
+	for _, info := range streamStmts {
+		for _, column := range info.schema {
+			if column.Name != name {
+				continue
+			}
+			return true
+		}
+	}
+	return false
+}
+
+// isAliasFieldTopoSortFinish reports whether every alias in aliasDegrees has
+// a computed (non-negative) degree.
+func isAliasFieldTopoSortFinish(aliasDegrees map[string]*aliasTopoDegree) bool {
+	for _, entry := range aliasDegrees {
+		if entry.degree >= 0 {
+			continue
+		}
+		// A negative degree marks an alias that is not resolved yet.
+		return false
+	}
+	return true
+}
+
 func validate(s *ast.SelectStatement) (err error) {
 	isAggStmt := false
 	if xsql.IsAggregate(s.Condition) {
@@ -265,12 +375,13 @@ func convertStreamInfo(streamStmt *ast.StreamStmt) (*streamInfo, error) {
 
 type fieldsMap struct {
 	content       map[string]streamFieldStore
+	aliasNames    map[string]struct{}
 	isSchemaless  bool
 	defaultStream ast.StreamName
 }
 
 func newFieldsMap(isSchemaless bool, defaultStream ast.StreamName) *fieldsMap {
-	return &fieldsMap{content: make(map[string]streamFieldStore), isSchemaless: isSchemaless, defaultStream: defaultStream}
+	return &fieldsMap{content: make(map[string]streamFieldStore), aliasNames: map[string]struct{}{}, isSchemaless: isSchemaless, defaultStream: defaultStream}
 }
 
 func (f *fieldsMap) reserve(fieldName string, streamName ast.StreamName) {
@@ -302,10 +413,15 @@ func (f *fieldsMap) save(fieldName string, streamName ast.StreamName, field *ast
 	return nil
 }
 
+// bindAlias records aliasName as an alias declared by the select statement.
+// The name is stored lower-cased because bind looks aliases up with
+// strings.ToLower(fr.Name); storing it verbatim would make an alias that
+// contains upper-case letters invisible to that lookup.
+func (f *fieldsMap) bindAlias(aliasName string) {
+	f.aliasNames[strings.ToLower(aliasName)] = struct{}{}
+}
+
 func (f *fieldsMap) bind(fr *ast.FieldRef) error {
 	lname := strings.ToLower(fr.Name)
-	fm, ok := f.content[lname]
-	if !ok {
+	fm, ok1 := f.content[lname]
+	_, ok2 := f.aliasNames[lname]
+	if !ok1 && !ok2 {
 		if f.isSchemaless && fr.Name != "" {
 			fm = newStreamFieldStore(f.isSchemaless, f.defaultStream)
 			f.content[lname] = fm
@@ -313,9 +429,11 @@ func (f *fieldsMap) bind(fr *ast.FieldRef) error {
 			return fmt.Errorf("unknown field %s", fr.Name)
 		}
 	}
-	err := fm.bindRef(fr)
-	if err != nil {
-		return fmt.Errorf("%s%s", err, fr.Name)
+	if fm != nil {
+		err := fm.bindRef(fr)
+		if err != nil {
+			return fmt.Errorf("%s%s", err, fr.Name)
+		}
 	}
 	return nil
 }

+ 4 - 4
internal/topo/planner/analyzer_test.go

@@ -103,11 +103,11 @@ var tests = []struct {
 		r:   newErrorStruct(""),
 	},
 	{ // 10
-		sql: `SELECT sum(temp) as temp, count(temp) as temp FROM src1`,
-		r:   newErrorStruct("duplicate alias temp"),
+		sql: `SELECT sum(temp) as temp1, count(temp) as temp FROM src1`,
+		r:   newErrorStruct("invalid argument for func count: aggregate argument is not allowed"),
 	},
 	{ // 11
-		sql: `SELECT sum(temp) as temp, count(temp) as ct FROM src1`,
+		sql: `SELECT sum(temp) as temp1, count(temp) as ct FROM src1`,
 		r:   newErrorStruct(""),
 	},
 	{ // 12
@@ -116,7 +116,7 @@ var tests = []struct {
 	},
 	{ // 13
 		sql: `SELECT sin(temp) as temp1, cos(temp1) FROM src1`,
-		r:   newErrorStructWithS("unknown field temp1", ""),
+		r:   newErrorStruct(""),
 	},
 	{ // 14
 		sql: `SELECT collect(*)[-1] as current FROM src1 GROUP BY COUNTWINDOW(2, 1) HAVING isNull(current->name) = false`,

+ 172 - 0
internal/topo/planner/planner_alias_test.go

@@ -0,0 +1,172 @@
+// Copyright 2023 EMQ Technologies Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package planner
+
+import (
+	"encoding/json"
+	"reflect"
+	"strings"
+	"testing"
+
+	"github.com/gdexlab/go-render/render"
+
+	"github.com/lf-edge/ekuiper/internal/pkg/store"
+	"github.com/lf-edge/ekuiper/internal/xsql"
+	"github.com/lf-edge/ekuiper/pkg/api"
+	"github.com/lf-edge/ekuiper/pkg/ast"
+)
+
+// TestPlannerAlias verifies the logical plan created for a select statement
+// in which one alias (sum2) is defined in terms of another alias (sum).
+func TestPlannerAlias(t *testing.T) {
+	kv, err := store.GetKV("stream")
+	if err != nil {
+		t.Fatal(err)
+	}
+	streamSqls := map[string]string{
+		"src1": `CREATE STREAM src1 (
+				) WITH (DATASOURCE="src1", FORMAT="json", KEY="ts");`,
+		"src2": `CREATE STREAM src2 (
+				) WITH (DATASOURCE="src2", FORMAT="json", KEY="ts");`,
+		"tableInPlanner": `CREATE TABLE tableInPlanner (
+					id BIGINT,
+					name STRING,
+					value STRING,
+					hum BIGINT
+				) WITH (TYPE="file");`,
+	}
+	types := map[string]ast.StreamType{
+		"src1":           ast.TypeStream,
+		"src2":           ast.TypeStream,
+		"tableInPlanner": ast.TypeTable,
+	}
+	// Setup failures abort the test: continuing would store or read invalid
+	// stream definitions and only produce misleading follow-up errors.
+	for name, sql := range streamSqls {
+		s, err := json.Marshal(&xsql.StreamInfo{
+			StreamType: types[name],
+			Statement:  sql,
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		if err = kv.Set(name, string(s)); err != nil {
+			t.Fatal(err)
+		}
+	}
+	streams := make(map[string]*ast.StreamStmt, len(streamSqls))
+	for n := range streamSqls {
+		streamStmt, err := xsql.GetDataSource(kv, n)
+		if err != nil {
+			t.Fatalf("fail to get stream %s, please check if stream is created: %v", n, err)
+		}
+		streams[n] = streamStmt
+	}
+	// aliasRef1 is the expected alias reference for "a + b as sum".
+	aliasRef1 := &ast.AliasRef{
+		Expression: &ast.BinaryExpr{
+			OP: ast.ADD,
+			LHS: &ast.FieldRef{
+				StreamName: "src1",
+				Name:       "a",
+			},
+			RHS: &ast.FieldRef{
+				StreamName: "src1",
+				Name:       "b",
+			},
+		},
+	}
+	aliasRef1.SetRefSource([]string{"src1"})
+	// aliasRef2 is the expected alias reference for "sum + 1 as sum2"; its
+	// left operand refers to the alias sum.
+	aliasRef2 := &ast.AliasRef{
+		Expression: &ast.BinaryExpr{
+			OP: ast.ADD,
+			LHS: &ast.FieldRef{
+				StreamName: ast.AliasStream,
+				Name:       "sum",
+				AliasRef:   aliasRef1,
+			},
+			RHS: &ast.IntegerLiteral{
+				Val: 1,
+			},
+		},
+	}
+	aliasRef2.SetRefSource([]string{"src1"})
+
+	testcases := []struct {
+		sql string
+		p   LogicalPlan
+		err string
+	}{
+		{
+			sql: "select a + b as sum, sum + 1 as sum2 from src1",
+			p: ProjectPlan{
+				baseLogicalPlan: baseLogicalPlan{
+					children: []LogicalPlan{
+						DataSourcePlan{
+							baseLogicalPlan: baseLogicalPlan{},
+							name:            "src1",
+							streamFields: map[string]*ast.JsonStreamField{
+								"a": nil,
+								"b": nil,
+							},
+							streamStmt:   streams["src1"],
+							pruneFields:  []string{},
+							isSchemaless: true,
+							metaFields:   []string{},
+						}.Init(),
+					},
+				},
+				fields: []ast.Field{
+					{
+						AName: "sum",
+						Expr: &ast.FieldRef{
+							StreamName: ast.AliasStream,
+							Name:       "sum",
+							AliasRef:   aliasRef1,
+						},
+					},
+					{
+						AName: "sum2",
+						Expr: &ast.FieldRef{
+							StreamName: ast.AliasStream,
+							Name:       "sum2",
+							AliasRef:   aliasRef2,
+						},
+					},
+				},
+			}.Init(),
+		},
+	}
+	for i, tt := range testcases {
+		stmt, err := xsql.NewParser(strings.NewReader(tt.sql)).Parse()
+		if err != nil {
+			t.Errorf("%d. %q: error compile sql: %s\n", i, tt.sql, err)
+			continue
+		}
+		// All options other than SendError are left at their zero value.
+		p, err := createLogicalPlan(stmt, &api.RuleOption{
+			SendError: true,
+		}, kv)
+		if err != nil {
+			// A planning error is only acceptable when the case expects it.
+			if err.Error() != tt.err {
+				t.Errorf("%d. %q: error mismatch:\n  exp=%s\n  got=%s\n\n", i, tt.sql, tt.err, err)
+			}
+			continue
+		}
+		if tt.err != "" {
+			t.Errorf("%d. %q: expected error %q but got none\n", i, tt.sql, tt.err)
+			continue
+		}
+		if !reflect.DeepEqual(tt.p, p) {
+			t.Errorf("%d. %q\n\nstmt mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, tt.sql, render.AsCode(tt.p), render.AsCode(p))
+		}
+	}
+}

+ 12 - 9
internal/topo/planner/planner_test.go

@@ -1097,10 +1097,6 @@ func Test_createLogicalPlan(t *testing.T) {
 				},
 				fields: []ast.Field{
 					{
-						Expr:  &ast.FieldRef{Name: "temp", StreamName: "src1"},
-						Name:  "temp",
-						AName: "",
-					}, {
 						Expr: &ast.FieldRef{Name: "eid", StreamName: ast.AliasStream, AliasRef: ast.MockAliasRef(
 							&ast.Call{Name: "meta", FuncId: 0, Args: []ast.Expr{&ast.MetaRef{
 								Name:       "id",
@@ -1111,7 +1107,8 @@ func Test_createLogicalPlan(t *testing.T) {
 						)},
 						Name:  "meta",
 						AName: "eid",
-					}, {
+					},
+					{
 						Expr: &ast.FieldRef{Name: "hdevice", StreamName: ast.AliasStream, AliasRef: ast.MockAliasRef(
 							&ast.Call{Name: "meta", FuncId: 1, Args: []ast.Expr{
 								&ast.BinaryExpr{
@@ -1126,6 +1123,11 @@ func Test_createLogicalPlan(t *testing.T) {
 						Name:  "meta",
 						AName: "hdevice",
 					},
+					{
+						Expr:  &ast.FieldRef{Name: "temp", StreamName: "src1"},
+						Name:  "temp",
+						AName: "",
+					},
 				},
 				isAggregate: false,
 				sendMeta:    false,
@@ -1287,10 +1289,6 @@ func Test_createLogicalPlan(t *testing.T) {
 				},
 				fields: []ast.Field{
 					{
-						Expr:  &ast.FieldRef{Name: "temp", StreamName: "src1"},
-						Name:  "temp",
-						AName: "",
-					}, {
 						Expr: &ast.FieldRef{Name: "m", StreamName: ast.AliasStream, AliasRef: ast.MockAliasRef(
 							&ast.Call{Name: "meta", FuncId: 0, Args: []ast.Expr{&ast.MetaRef{
 								Name:       "*",
@@ -1302,6 +1300,11 @@ func Test_createLogicalPlan(t *testing.T) {
 						Name:  "meta",
 						AName: "m",
 					},
+					{
+						Expr:  &ast.FieldRef{Name: "temp", StreamName: "src1"},
+						Name:  "temp",
+						AName: "",
+					},
 				},
 				isAggregate: false,
 				sendMeta:    false,

+ 98 - 0
internal/topo/topotest/rule_test.go

@@ -1558,3 +1558,101 @@ func TestWindowSQL(t *testing.T) {
 		DoRuleTest(t, tests, j, opt, 0)
 	}
 }
+
+// TestAliasSQL runs two rules in which alias b is defined as alias a plus one;
+// the rules differ only in whether a is declared before or after b, and both
+// are expected to emit the same results.
+func TestAliasSQL(t *testing.T) {
+	streamList := []string{"demo"}
+	HandleStream(false, streamList, t)
+	// expected builds a fresh copy of the expected output on every call: one
+	// single-row batch per input value, with a as the value and b as a + 1.
+	expected := func() [][]map[string]interface{} {
+		sizes := []int{3, 6, 2, 4, 1}
+		batches := make([][]map[string]interface{}, 0, len(sizes))
+		for _, size := range sizes {
+			batches = append(batches, []map[string]interface{}{
+				{
+					"a": float64(size),
+					"b": float64(size + 1),
+				},
+			})
+		}
+		return batches
+	}
+	tests := []RuleTest{
+		{
+			Name: "TestAliasSQL1",
+			Sql:  `select size as a, a + 1 as b from demo`,
+			R:    expected(),
+		},
+		{
+			Name: "TestAliasSQL2",
+			Sql:  `select a + 1 as b, size as a from demo`,
+			R:    expected(),
+		},
+	}
+	// Data setup
+	HandleStream(true, streamList, t)
+	// The same rules are run once per QoS level.
+	options := []*api.RuleOption{
+		{
+			BufferLength:       100,
+			SendError:          true,
+			Qos:                api.AtLeastOnce,
+			CheckpointInterval: 5000,
+		},
+		{
+			BufferLength:       100,
+			SendError:          true,
+			Qos:                api.ExactlyOnce,
+			CheckpointInterval: 5000,
+		},
+	}
+	for idx := range options {
+		DoRuleTest(t, tests, idx, options[idx], 0)
+	}
+}

+ 9 - 9
internal/topo/topotest/window_rule_test.go

@@ -340,52 +340,52 @@ func TestWindow(t *testing.T) {
 		},
 		{
 			Name: `TestWindowRule6`,
-			Sql:  `SELECT window_end(), event_time(), sum(temp) as temp, count(color) as c, window_start() FROM demo INNER JOIN demo1 ON demo.ts = demo1.ts GROUP BY SlidingWindow(ss, 1)`,
+			Sql:  `SELECT window_end(), event_time(), sum(temp) as temp1, count(color) as c, window_start() FROM demo INNER JOIN demo1 ON demo.ts = demo1.ts GROUP BY SlidingWindow(ss, 1)`,
 			R: [][]map[string]interface{}{
 				{{
-					"temp":         25.5,
+					"temp1":        25.5,
 					"c":            float64(1),
 					"window_start": float64(1541152485115),
 					"window_end":   float64(1541152486115),
 					"event_time":   float64(1541152486115),
 				}}, {{
-					"temp":         25.5,
+					"temp1":        25.5,
 					"c":            float64(1),
 					"window_start": float64(1541152485822),
 					"window_end":   float64(1541152486822),
 					"event_time":   float64(1541152486822),
 				}}, {{
-					"temp":         25.5,
+					"temp1":        25.5,
 					"c":            float64(1),
 					"window_start": float64(1541152485903),
 					"window_end":   float64(1541152486903),
 					"event_time":   float64(1541152486903),
 				}}, {{
-					"temp":         28.1,
+					"temp1":        28.1,
 					"c":            float64(1),
 					"window_start": float64(1541152486702),
 					"window_end":   float64(1541152487702),
 					"event_time":   float64(1541152487702),
 				}}, {{
-					"temp":         28.1,
+					"temp1":        28.1,
 					"c":            float64(1),
 					"window_start": float64(1541152487442),
 					"window_end":   float64(1541152488442),
 					"event_time":   float64(1541152488442),
 				}}, {{
-					"temp":         55.5,
+					"temp1":        55.5,
 					"c":            float64(2),
 					"window_start": float64(1541152487605),
 					"window_end":   float64(1541152488605),
 					"event_time":   float64(1541152488605),
 				}}, {{
-					"temp":         27.4,
+					"temp1":        27.4,
 					"c":            float64(1),
 					"window_start": float64(1541152488252),
 					"window_end":   float64(1541152489252),
 					"event_time":   float64(1541152489252),
 				}}, {{
-					"temp":         52.9,
+					"temp1":        52.9,
 					"c":            float64(2),
 					"window_start": float64(1541152488305),
 					"window_end":   float64(1541152489305),

+ 11 - 3
pkg/ast/expr_ref.go

@@ -15,7 +15,6 @@
 package ast
 
 import (
-	"fmt"
 	"regexp"
 	"strings"
 )
@@ -115,6 +114,14 @@ type AliasRef struct {
 	IsAggregate *bool
 }
 
+// SetRefSource replaces the sources referenced by the alias with the given
+// stream names. It is only meant to be used by unit tests.
+func (a *AliasRef) SetRefSource(names []string) {
+	sources := make([]StreamName, len(names))
+	for i := range names {
+		sources[i] = StreamName(names[i])
+	}
+	a.refSources = sources
+}
+
 func NewAliasRef(e Expr) (*AliasRef, error) {
 	r := make(map[StreamName]bool)
 	var walkErr error
@@ -123,8 +130,9 @@ func NewAliasRef(e Expr) (*AliasRef, error) {
 		case *FieldRef:
 			switch f.StreamName {
 			case AliasStream:
-				walkErr = fmt.Errorf("cannot use alias %s inside another alias %v", f.Name, e)
-				return false
+				for _, name := range f.AliasRef.refSources {
+					r[name] = true
+				}
 			default:
 				r[f.StreamName] = true
 			}