Преглед изворни кода

feat: add bytea to CAST function (#2128)

* add bytea for cast

Signed-off-by: Rui-Gan <1171530954@qq.com>

* fix ut

Signed-off-by: Rui-Gan <1171530954@qq.com>

* fix ut

Signed-off-by: Rui-Gan <1171530954@qq.com>

---------

Signed-off-by: Rui-Gan <1171530954@qq.com>
Regina пре 1 година
родитељ
комит
b0552e47f4

+ 2 - 2
internal/binder/function/funcs_misc.go

@@ -55,8 +55,8 @@ func registerMiscFunc() {
 				return ProduceErrInfo(0, "string")
 			}
 			if av, ok := a.(*ast.StringLiteral); ok {
-				if !(av.Val == "bigint" || av.Val == "float" || av.Val == "string" || av.Val == "boolean" || av.Val == "datetime") {
-					return fmt.Errorf("Expect one of following value for the 2nd parameter: bigint, float, string, boolean, datetime.")
+				if !(av.Val == "bigint" || av.Val == "float" || av.Val == "string" || av.Val == "boolean" || av.Val == "datetime" || av.Val == "bytea") {
+					return fmt.Errorf("Expect one of following value for the 2nd parameter: bigint, float, string, boolean, datetime, bytea.")
 				}
 			}
 			return nil

+ 112 - 0
internal/binder/function/funcs_misc_test.go

@@ -584,3 +584,115 @@ func TestMiscFuncNil(t *testing.T) {
 		}
 	}
 }
+
+func TestCast(t *testing.T) {
+	f, ok := builtins["cast"]
+	if !ok {
+		t.Fatal("builtin not found")
+	}
+	contextLogger := conf.Log.WithField("rule", "testExec")
+	ctx := kctx.WithValue(kctx.Background(), kctx.LoggerKey, contextLogger)
+	tempStore, _ := state.CreateStore("mockRule0", api.AtMostOnce)
+	fctx := kctx.NewDefaultFuncContext(ctx.WithMeta("mockRule0", "test", tempStore), 2)
+
+	tests := []struct {
+		args   []interface{}
+		result interface{}
+	}{
+		{ // 0
+			args: []interface{}{
+				"Ynl0ZWE=",
+				"bytea",
+			},
+			result: []byte("bytea"),
+		},
+		{ // 1
+			args: []interface{}{
+				[]byte("bytea"),
+				"bytea",
+			},
+			result: []byte("bytea"),
+		},
+		{ // 2
+			args: []interface{}{
+				1,
+				"bytea",
+			},
+			result: fmt.Errorf("cannot convert int(1) to bytea"),
+		},
+		{ // 3
+			args: []interface{}{
+				101.5,
+				"bigint",
+			},
+			result: 101,
+		},
+		{ // 4
+			args: []interface{}{
+				1,
+				"boolean",
+			},
+			result: true,
+		},
+		{ // 5
+			args: []interface{}{
+				1,
+				"float",
+			},
+			result: float64(1),
+		},
+		{ // 6
+			args: []interface{}{
+				1,
+				"string",
+			},
+			result: "1",
+		},
+	}
+	for _, tt := range tests {
+		result, _ := f.exec(fctx, tt.args)
+		assert.Equal(t, tt.result, result)
+	}
+
+	vtests := []struct {
+		args    []ast.Expr
+		wantErr bool
+	}{
+		{
+			[]ast.Expr{&ast.FieldRef{Name: "foo"}, &ast.StringLiteral{Val: "bytea"}},
+			false,
+		},
+		{
+			[]ast.Expr{&ast.FieldRef{Name: "foo"}},
+			true,
+		},
+		{
+			[]ast.Expr{&ast.FieldRef{Name: "foo"}, &ast.StringLiteral{Val: "bigint"}},
+			false,
+		},
+		{
+			[]ast.Expr{&ast.FieldRef{Name: "foo"}, &ast.StringLiteral{Val: "float"}},
+			false,
+		},
+		{
+			[]ast.Expr{&ast.FieldRef{Name: "foo"}, &ast.StringLiteral{Val: "string"}},
+			false,
+		},
+		{
+			[]ast.Expr{&ast.FieldRef{Name: "foo"}, &ast.StringLiteral{Val: "boolean"}},
+			false,
+		},
+		{
+			[]ast.Expr{&ast.FieldRef{Name: "foo"}, &ast.StringLiteral{Val: "test"}},
+			true,
+		},
+	}
+	for _, vtt := range vtests {
+		err := f.val(fctx, vtt.args)
+		if vtt.wantErr {
+			assert.Error(t, err)
+		} else {
+			assert.NoError(t, err)
+		}
+	}
+}

+ 2 - 2
internal/xsql/funcs_validator_test.go

@@ -1,4 +1,4 @@
-// Copyright 2022 EMQ Technologies Co., Ltd.
+// Copyright 2022-2023 EMQ Technologies Co., Ltd.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -361,7 +361,7 @@ func TestFuncValidator(t *testing.T) {
 		{
 			s:    `SELECT cast("12", "bool") FROM tbl`,
 			stmt: nil,
-			err:  "Expect one of following value for the 2nd parameter: bigint, float, string, boolean, datetime.",
+			err:  "Expect one of following value for the 2nd parameter: bigint, float, string, boolean, datetime, bytea.",
 		},
 
 		///

+ 9 - 2
pkg/cast/cast.go

@@ -825,7 +825,7 @@ func ToByteA(input interface{}, _ Strictness) ([]byte, error) {
 		}
 		return r, nil
 	}
-	return nil, fmt.Errorf("cannot convert %[1]T(%[1]v) to bytes", input)
+	return nil, fmt.Errorf("cannot convert %[1]T(%[1]v) to bytea", input)
 }
 
 func ToStringMap(input interface{}) (map[string]interface{}, error) {
@@ -1112,7 +1112,7 @@ func ConvertSlice(v interface{}) []interface{} {
 }
 
 // ToType cast value into newType type
-// newType support bigint, float, string, boolean, datetime
+// newType support bigint, float, string, boolean, datetime, bytea
 func ToType(value interface{}, newType interface{}) (interface{}, bool) {
 	if v, ok := newType.(string); ok {
 		switch v {
@@ -1151,6 +1151,13 @@ func ToType(value interface{}, newType interface{}) (interface{}, bool) {
 			} else {
 				return dt, true
 			}
+		case "bytea":
+			r, e := ToByteA(value, CONVERT_ALL)
+			if e != nil {
+				return e, false
+			} else {
+				return r, true
+			}
 		default:
 			return fmt.Errorf("unknow type, only support bigint, float, string, boolean and datetime"), false
 		}

+ 1 - 1
pkg/cast/cast_test.go

@@ -892,7 +892,7 @@ func TestToByteA(t *testing.T) {
 		}, {
 			input:  1,
 			output: nil,
-			err:    "cannot convert int(1) to bytes",
+			err:    "cannot convert int(1) to bytea",
 		}, {
 			input:  "c29tZSBkYXRhIHdpdGggACBhbmQg77u/",
 			output: bytea,