Bladeren bron

refactor(parser): parse built-in function name to lowercase

Avoid lowercasing function names for comparison at runtime, which adds a lot of overhead

Signed-off-by: Jiyong Huang <huangjy@emqx.io>
Jiyong Huang 3 jaren geleden
bovenliggende
commit
2bd4b8e344

+ 3 - 0
internal/binder/factory.go

@@ -29,6 +29,9 @@ type FuncFactory interface {
 	// HasFunctionSet Some functions are bundled together into a plugin which shares the same json file.
 	// This function can return if the function set name exists.
 	HasFunctionSet(funcName string) bool
+	// ConvName converts the name of the function, usually to lowercase.
+	// This is only used when parsing the SQL statement.
+	ConvName(funcName string) (string, bool)
 }
 
 type FactoryEntry struct {

+ 10 - 0
internal/binder/function/binder.go

@@ -72,6 +72,16 @@ func HasFunctionSet(name string) bool {
 	return false
 }
 
+func ConvName(name string) (string, bool) {
+	for _, sf := range funcFactories {
+		r, ok := sf.ConvName(name)
+		if ok {
+			return r, ok
+		}
+	}
+	return name, false
+}
+
 type multiAggFunc interface {
 	IsAggregateWithName(name string) bool
 }

+ 1 - 3
internal/binder/function/funcs_agg.go

@@ -16,12 +16,10 @@ package function
 
 import (
 	"fmt"
-	"strings"
 )
 
 func aggCall(name string, args []interface{}) (interface{}, bool) {
-	lowerName := strings.ToLower(name)
-	switch lowerName {
+	switch name {
 	case "avg":
 		arg0 := args[0].([]interface{})
 		c := getCount(arg0)

+ 47 - 48
internal/binder/function/funcs_ast_validator.go

@@ -25,41 +25,40 @@ type AllowTypes struct {
 }
 
 func validateFuncs(funcName string, args []ast.Expr) error {
-	lowerName := strings.ToLower(funcName)
-	switch getFuncType(lowerName) {
+	switch getFuncType(funcName) {
 	case AggFunc:
-		return validateAggFunc(lowerName, args)
+		return validateAggFunc(funcName, args)
 	case MathFunc:
-		return validateMathFunc(lowerName, args)
+		return validateMathFunc(funcName, args)
 	case ConvFunc:
-		return validateConvFunc(lowerName, args)
+		return validateConvFunc(funcName, args)
 	case StrFunc:
-		return validateStrFunc(lowerName, args)
+		return validateStrFunc(funcName, args)
 	case HashFunc:
-		return validateHashFunc(lowerName, args)
+		return validateHashFunc(funcName, args)
 	case JsonFunc:
-		return validateJsonFunc(lowerName, args)
+		return validateJsonFunc(funcName, args)
 	case OtherFunc:
-		return validateOtherFunc(lowerName, args)
+		return validateOtherFunc(funcName, args)
 	default:
 		// should not happen
-		return fmt.Errorf("unkndow function %s", lowerName)
+		return fmt.Errorf("unkndow function %s", funcName)
 	}
 }
 
 func validateMathFunc(name string, args []ast.Expr) error {
-	len := len(args)
+	l := len(args)
 	switch name {
 	case "abs", "acos", "asin", "atan", "ceil", "cos", "cosh", "exp", "ln", "log", "round", "sign", "sin", "sinh",
 		"sqrt", "tan", "tanh":
-		if err := ast.ValidateLen(name, 1, len); err != nil {
+		if err := ast.ValidateLen(name, 1, l); err != nil {
 			return err
 		}
 		if ast.IsStringArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
 			return ast.ProduceErrInfo(name, 0, "number - float or int")
 		}
 	case "bitand", "bitor", "bitxor":
-		if err := ast.ValidateLen(name, 2, len); err != nil {
+		if err := ast.ValidateLen(name, 2, l); err != nil {
 			return err
 		}
 		if ast.IsFloatArg(args[0]) || ast.IsStringArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
@@ -70,7 +69,7 @@ func validateMathFunc(name string, args []ast.Expr) error {
 		}
 
 	case "bitnot":
-		if err := ast.ValidateLen(name, 1, len); err != nil {
+		if err := ast.ValidateLen(name, 1, l); err != nil {
 			return err
 		}
 		if ast.IsFloatArg(args[0]) || ast.IsStringArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
@@ -78,7 +77,7 @@ func validateMathFunc(name string, args []ast.Expr) error {
 		}
 
 	case "atan2", "mod", "power":
-		if err := ast.ValidateLen(name, 2, len); err != nil {
+		if err := ast.ValidateLen(name, 2, l); err != nil {
 			return err
 		}
 		if ast.IsStringArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
@@ -89,7 +88,7 @@ func validateMathFunc(name string, args []ast.Expr) error {
 		}
 
 	case "rand":
-		if err := ast.ValidateLen(name, 0, len); err != nil {
+		if err := ast.ValidateLen(name, 0, l); err != nil {
 			return err
 		}
 	}
@@ -97,10 +96,10 @@ func validateMathFunc(name string, args []ast.Expr) error {
 }
 
 func validateStrFunc(name string, args []ast.Expr) error {
-	len := len(args)
+	l := len(args)
 	switch name {
 	case "concat":
-		if len == 0 {
+		if l == 0 {
 			return fmt.Errorf("The arguments for %s should be at least one.\n", name)
 		}
 		for i, a := range args {
@@ -109,7 +108,7 @@ func validateStrFunc(name string, args []ast.Expr) error {
 			}
 		}
 	case "endswith", "indexof", "regexp_matches", "startswith":
-		if err := ast.ValidateLen(name, 2, len); err != nil {
+		if err := ast.ValidateLen(name, 2, l); err != nil {
 			return err
 		}
 		for i := 0; i < 2; i++ {
@@ -118,7 +117,7 @@ func validateStrFunc(name string, args []ast.Expr) error {
 			}
 		}
 	case "format_time":
-		if err := ast.ValidateLen(name, 2, len); err != nil {
+		if err := ast.ValidateLen(name, 2, l); err != nil {
 			return err
 		}
 
@@ -130,7 +129,7 @@ func validateStrFunc(name string, args []ast.Expr) error {
 		}
 
 	case "regexp_replace":
-		if err := ast.ValidateLen(name, 3, len); err != nil {
+		if err := ast.ValidateLen(name, 3, l); err != nil {
 			return err
 		}
 		for i := 0; i < 3; i++ {
@@ -139,14 +138,14 @@ func validateStrFunc(name string, args []ast.Expr) error {
 			}
 		}
 	case "length", "lower", "ltrim", "numbytes", "rtrim", "trim", "upper":
-		if err := ast.ValidateLen(name, 1, len); err != nil {
+		if err := ast.ValidateLen(name, 1, l); err != nil {
 			return err
 		}
 		if ast.IsNumericArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
 			return ast.ProduceErrInfo(name, 0, "string")
 		}
 	case "lpad", "rpad":
-		if err := ast.ValidateLen(name, 2, len); err != nil {
+		if err := ast.ValidateLen(name, 2, l); err != nil {
 			return err
 		}
 		if ast.IsNumericArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
@@ -156,13 +155,13 @@ func validateStrFunc(name string, args []ast.Expr) error {
 			return ast.ProduceErrInfo(name, 1, "int")
 		}
 	case "substring":
-		if len != 2 && len != 3 {
+		if l != 2 && l != 3 {
 			return fmt.Errorf("the arguments for substring should be 2 or 3")
 		}
 		if ast.IsNumericArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
 			return ast.ProduceErrInfo(name, 0, "string")
 		}
-		for i := 1; i < len; i++ {
+		for i := 1; i < l; i++ {
 			if ast.IsFloatArg(args[i]) || ast.IsTimeArg(args[i]) || ast.IsBooleanArg(args[i]) || ast.IsStringArg(args[i]) {
 				return ast.ProduceErrInfo(name, i, "int")
 			}
@@ -173,7 +172,7 @@ func validateStrFunc(name string, args []ast.Expr) error {
 			if sv < 0 {
 				return fmt.Errorf("The start index should not be a nagtive integer.")
 			}
-			if len == 3 {
+			if l == 3 {
 				if e, ok1 := args[2].(*ast.IntegerLiteral); ok1 {
 					ev := e.Val
 					if ev < sv {
@@ -183,7 +182,7 @@ func validateStrFunc(name string, args []ast.Expr) error {
 			}
 		}
 	case "split_value":
-		if len != 3 {
+		if l != 3 {
 			return fmt.Errorf("the arguments for split_value should be 3")
 		}
 		if ast.IsNumericArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
@@ -205,10 +204,10 @@ func validateStrFunc(name string, args []ast.Expr) error {
 }
 
 func validateConvFunc(name string, args []ast.Expr) error {
-	len := len(args)
+	l := len(args)
 	switch name {
 	case "cast":
-		if err := ast.ValidateLen(name, 2, len); err != nil {
+		if err := ast.ValidateLen(name, 2, l); err != nil {
 			return err
 		}
 		a := args[1]
@@ -221,14 +220,14 @@ func validateConvFunc(name string, args []ast.Expr) error {
 			}
 		}
 	case "chr":
-		if err := ast.ValidateLen(name, 1, len); err != nil {
+		if err := ast.ValidateLen(name, 1, l); err != nil {
 			return err
 		}
 		if ast.IsFloatArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
 			return ast.ProduceErrInfo(name, 0, "int")
 		}
 	case "encode":
-		if err := ast.ValidateLen(name, 2, len); err != nil {
+		if err := ast.ValidateLen(name, 2, l); err != nil {
 			return err
 		}
 
@@ -246,7 +245,7 @@ func validateConvFunc(name string, args []ast.Expr) error {
 			}
 		}
 	case "trunc":
-		if err := ast.ValidateLen(name, 2, len); err != nil {
+		if err := ast.ValidateLen(name, 2, l); err != nil {
 			return err
 		}
 
@@ -262,10 +261,10 @@ func validateConvFunc(name string, args []ast.Expr) error {
 }
 
 func validateHashFunc(name string, args []ast.Expr) error {
-	len := len(args)
+	l := len(args)
 	switch name {
 	case "md5", "sha1", "sha224", "sha256", "sha384", "sha512":
-		if err := ast.ValidateLen(name, 1, len); err != nil {
+		if err := ast.ValidateLen(name, 1, l); err != nil {
 			return err
 		}
 
@@ -277,29 +276,29 @@ func validateHashFunc(name string, args []ast.Expr) error {
 }
 
 func validateOtherFunc(name string, args []ast.Expr) error {
-	len := len(args)
+	l := len(args)
 	switch name {
 	case "isNull":
-		if err := ast.ValidateLen(name, 1, len); err != nil {
+		if err := ast.ValidateLen(name, 1, l); err != nil {
 			return err
 		}
 	case "cardinality":
-		if err := ast.ValidateLen(name, 1, len); err != nil {
+		if err := ast.ValidateLen(name, 1, l); err != nil {
 			return err
 		}
 	case "nanvl":
-		if err := ast.ValidateLen(name, 2, len); err != nil {
+		if err := ast.ValidateLen(name, 2, l); err != nil {
 			return err
 		}
 		if ast.IsIntegerArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) || ast.IsStringArg(args[0]) {
 			return ast.ProduceErrInfo(name, 1, "float")
 		}
 	case "newuuid":
-		if err := ast.ValidateLen(name, 0, len); err != nil {
+		if err := ast.ValidateLen(name, 0, l); err != nil {
 			return err
 		}
 	case "mqtt":
-		if err := ast.ValidateLen(name, 1, len); err != nil {
+		if err := ast.ValidateLen(name, 1, l); err != nil {
 			return err
 		}
 		if ast.IsIntegerArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) || ast.IsStringArg(args[0]) || ast.IsFloatArg(args[0]) {
@@ -312,7 +311,7 @@ func validateOtherFunc(name string, args []ast.Expr) error {
 			}
 		}
 	case "meta":
-		if err := ast.ValidateLen(name, 1, len); err != nil {
+		if err := ast.ValidateLen(name, 1, l); err != nil {
 			return err
 		}
 		if _, ok := args[0].(*ast.MetaRef); ok {
@@ -335,8 +334,8 @@ func validateOtherFunc(name string, args []ast.Expr) error {
 }
 
 func validateJsonFunc(name string, args []ast.Expr) error {
-	len := len(args)
-	if err := ast.ValidateLen(name, 2, len); err != nil {
+	l := len(args)
+	if err := ast.ValidateLen(name, 2, l); err != nil {
 		return err
 	}
 	if !ast.IsStringArg(args[1]) {
@@ -346,25 +345,25 @@ func validateJsonFunc(name string, args []ast.Expr) error {
 }
 
 func validateAggFunc(name string, args []ast.Expr) error {
-	len := len(args)
+	l := len(args)
 	switch name {
 	case "avg", "max", "min", "sum":
-		if err := ast.ValidateLen(name, 1, len); err != nil {
+		if err := ast.ValidateLen(name, 1, l); err != nil {
 			return err
 		}
 		if ast.IsStringArg(args[0]) || ast.IsTimeArg(args[0]) || ast.IsBooleanArg(args[0]) {
 			return ast.ProduceErrInfo(name, 0, "number - float or int")
 		}
 	case "count":
-		if err := ast.ValidateLen(name, 1, len); err != nil {
+		if err := ast.ValidateLen(name, 1, l); err != nil {
 			return err
 		}
 	case "collect":
-		if err := ast.ValidateLen(name, 1, len); err != nil {
+		if err := ast.ValidateLen(name, 1, l); err != nil {
 			return err
 		}
 	case "deduplicate":
-		if err := ast.ValidateLen(name, 2, len); err != nil {
+		if err := ast.ValidateLen(name, 2, l); err != nil {
 			return err
 		}
 		if !ast.IsBooleanArg(args[1]) {

+ 1 - 3
internal/binder/function/funcs_misc.go

@@ -37,7 +37,6 @@ func convCall(name string, args []interface{}) (interface{}, bool) {
 	switch name {
 	case "cast":
 		if v, ok := args[1].(string); ok {
-			v = strings.ToLower(v)
 			switch v {
 			case "bigint":
 				if v1, ok1 := args[0].(int); ok1 {
@@ -149,8 +148,7 @@ func convCall(name string, args []interface{}) (interface{}, bool) {
 		}
 	case "encode":
 		if v, ok := args[1].(string); ok {
-			v = strings.ToLower(v)
-			if v == "base64" {
+			if strings.EqualFold(v, "base64") {
 				if v1, ok1 := args[0].(string); ok1 {
 					return b64.StdEncoding.EncodeToString([]byte(v1)), true
 				} else {

+ 19 - 12
internal/binder/function/function.go

@@ -96,7 +96,7 @@ var otherFuncMap = map[string]string{"isnull": "",
 
 func getFuncType(name string) funcType {
 	for i, m := range maps {
-		if _, ok := m[strings.ToLower(name)]; ok {
+		if _, ok := m[name]; ok {
 			return funcType(i)
 		}
 	}
@@ -127,22 +127,21 @@ func (f *funcExecutor) Exec(_ []interface{}, _ api.FunctionContext) (interface{}
 }
 
 func (f *funcExecutor) ExecWithName(args []interface{}, ctx api.FunctionContext, name string) (interface{}, bool) {
-	lowerName := strings.ToLower(name)
-	switch getFuncType(lowerName) {
+	switch getFuncType(name) {
 	case AggFunc:
-		return aggCall(lowerName, args)
+		return aggCall(name, args)
 	case MathFunc:
-		return mathCall(lowerName, args)
+		return mathCall(name, args)
 	case ConvFunc:
-		return convCall(lowerName, args)
+		return convCall(name, args)
 	case StrFunc:
-		return strCall(lowerName, args)
+		return strCall(name, args)
 	case HashFunc:
-		return hashCall(lowerName, args)
+		return hashCall(name, args)
 	case JsonFunc:
-		return jsonCall(ctx, lowerName, args)
+		return jsonCall(ctx, name, args)
 	case OtherFunc:
-		return otherCall(lowerName, args)
+		return otherCall(name, args)
 	}
 	return fmt.Errorf("unknow name"), false
 }
@@ -152,8 +151,7 @@ func (f *funcExecutor) IsAggregate() bool {
 }
 
 func (f *funcExecutor) IsAggregateWithName(name string) bool {
-	lowerName := strings.ToLower(name)
-	return getFuncType(lowerName) == AggFunc
+	return getFuncType(name) == AggFunc
 }
 
 var staticFuncExecutor = &funcExecutor{}
@@ -172,6 +170,15 @@ func (m *Manager) HasFunctionSet(name string) bool {
 	return name == "internal"
 }
 
+func (m *Manager) ConvName(n string) (string, bool) {
+	name := strings.ToLower(n)
+	ft := getFuncType(name)
+	if ft != NotFoundFunc {
+		return name, true
+	}
+	return name, false
+}
+
 var m = &Manager{}
 
 func GetManager() *Manager {

+ 4 - 0
internal/binder/mock/mock_factory.go

@@ -51,6 +51,10 @@ func (f *MockFactory) Function(name string) (api.Function, error) {
 	}
 }
 
+func (f *MockFactory) ConvName(name string) (string, bool) {
+	return name, true
+}
+
 func (f *MockFactory) HasFunctionSet(funcName string) bool {
 	if strings.HasPrefix(funcName, "mock") {
 		return true

+ 8 - 0
internal/plugin/native/manager.go

@@ -562,6 +562,14 @@ func (rr *Manager) HasFunctionSet(name string) bool {
 	return ok
 }
 
+func (rr *Manager) ConvName(name string) (string, bool) {
+	_, err := rr.Function(name)
+	if err == nil {
+		return name, true
+	}
+	return name, false
+}
+
 // If not found, return nil,nil; Other errors return nil, err
 func (rr *Manager) loadRuntime(t plugin2.PluginType, name string) (plugin.Symbol, error) {
 	ut := ucFirst(name)

+ 5 - 0
internal/plugin/portable/factory.go

@@ -63,6 +63,11 @@ func (m *Manager) HasFunctionSet(funcName string) bool {
 	return ok
 }
 
+func (m *Manager) ConvName(funcName string) (string, bool) {
+	_, ok := m.GetPluginMeta(plugin.FUNCTION, funcName)
+	return funcName, ok
+}
+
 // Clean up function map
 func (m *Manager) Clean() {
 	funcInsMap.Range(func(_, ins interface{}) bool {

+ 6 - 1
internal/service/manager.go

@@ -182,7 +182,7 @@ func (m *Manager) initFile(baseName string) error {
 
 // Start Implement FunctionFactory
 
-func (m *Manager) HasFunctionSet(name string) bool {
+func (m *Manager) HasFunctionSet(_ string) bool {
 	return false
 }
 
@@ -207,6 +207,11 @@ func (m *Manager) Function(name string) (api.Function, error) {
 	return &ExternalFunc{exe: e, methodName: f.MethodName}, nil
 }
 
+func (m *Manager) ConvName(funcName string) (string, bool) {
+	_, ok := m.getFunction(funcName)
+	return funcName, ok
+}
+
 // End Implement FunctionFactory
 
 func (m *Manager) HasService(name string) bool {

+ 2 - 2
internal/topo/planner/planner_test.go

@@ -404,7 +404,7 @@ func Test_createLogicalPlan(t *testing.T) {
 								},
 							},
 							condition: &ast.BinaryExpr{
-								LHS: &ast.Call{Name: "COUNT", Args: []ast.Expr{&ast.Wildcard{
+								LHS: &ast.Call{Name: "count", Args: []ast.Expr{&ast.Wildcard{
 									Token: ast.ASTERISK,
 								}}},
 								OP:  ast.GT,
@@ -1434,7 +1434,7 @@ func Test_createLogicalPlanSchemaless(t *testing.T) {
 								},
 							},
 							condition: &ast.BinaryExpr{
-								LHS: &ast.Call{Name: "COUNT", Args: []ast.Expr{&ast.Wildcard{
+								LHS: &ast.Call{Name: "count", Args: []ast.Expr{&ast.Wildcard{
 									Token: ast.ASTERISK,
 								}}},
 								OP:  ast.GT,

+ 26 - 4
internal/xsql/parser.go

@@ -17,6 +17,7 @@ package xsql
 import (
 	"fmt"
 	"github.com/golang-collections/collections/stack"
+	"github.com/lf-edge/ekuiper/internal/binder/function"
 	"github.com/lf-edge/ekuiper/pkg/ast"
 	"github.com/lf-edge/ekuiper/pkg/message"
 	"io"
@@ -643,8 +644,30 @@ func (p *Parser) parseAs(f *ast.Field) (*ast.Field, error) {
 	return f, nil
 }
 
-func (p *Parser) parseCall(name string) (ast.Expr, error) {
-	if strings.ToLower(name) == "meta" || strings.ToLower(name) == "mqtt" {
+var WindowFuncs = map[string]struct{}{
+	"tumblingwindow": {},
+	"hoppingwindow":  {},
+	"sessionwindow":  {},
+	"slidingwindow":  {},
+	"countwindow":    {},
+}
+
+func convFuncName(n string) (string, bool) {
+	lname := strings.ToLower(n)
+	if _, ok := WindowFuncs[lname]; ok {
+		return lname, ok
+	} else {
+		return function.ConvName(n)
+	}
+}
+
+func (p *Parser) parseCall(n string) (ast.Expr, error) {
+	// Check that function n exists, and convert its name to lowercase if it is a built-in function
+	name, ok := convFuncName(n)
+	if !ok {
+		return nil, fmt.Errorf("function %s not found", n)
+	}
+	if name == "meta" || name == "mqtt" {
 		p.inmeta = true
 		defer func() {
 			p.inmeta = false
@@ -779,8 +802,7 @@ loop:
 	return c, nil
 }
 
-func validateWindows(name string, args []ast.Expr) (ast.WindowType, error) {
-	fname := strings.ToLower(name)
+func validateWindows(fname string, args []ast.Expr) (ast.WindowType, error) {
 	switch fname {
 	case "tumblingwindow":
 		if err := validateWindow(fname, 2, args); err != nil {

+ 1 - 17
internal/xsql/parser_test.go

@@ -238,7 +238,7 @@ func TestParser_ParseStatement(t *testing.T) {
 		},
 
 		{
-			s: `SELECT length("test") FROM tbl`,
+			s: `SELECT LenGth("test") FROM tbl`,
 			stmt: &ast.SelectStatement{
 				Fields: []ast.Field{
 					{
@@ -1389,22 +1389,6 @@ func TestParser_ParseStatement(t *testing.T) {
 			stmt: nil,
 			err:  "invalid CASE expression, WHEN expression must be a bool condition",
 		}, {
-			s: `SELECT echo(*) FROM tbl`,
-			stmt: &ast.SelectStatement{
-				Fields: []ast.Field{
-					{
-						AName: "",
-						Name:  "echo",
-						Expr: &ast.Call{
-							Name: "echo",
-							Args: []ast.Expr{&ast.Wildcard{Token: ast.ASTERISK}},
-						},
-					},
-				},
-				Sources: []ast.Source{&ast.Table{Name: "tbl"}},
-			},
-		},
-		{
 			s: `SELECT count(*)-10 FROM demo`,
 			stmt: &ast.SelectStatement{
 				Fields: []ast.Field{