Browse Source

fix(sql): sql string use single quote literal (#1921)

Because double quote string is recognized as column by pgsql and many other db

Signed-off-by: Jiyong Huang <huangjy@emqx.io>
ngjaying 1 year ago
parent
commit
ebc9078749

+ 4 - 4
extensions/sinks/sql/sql.go

@@ -15,13 +15,13 @@
 package main
 
 import (
-	"database/sql"
 	"encoding/json"
 	"fmt"
 	"github.com/lf-edge/ekuiper/internal/topo/transform"
 	"reflect"
 	"strings"
 
+	"github.com/lf-edge/ekuiper/extensions/sqldatabase"
 	"github.com/lf-edge/ekuiper/extensions/sqldatabase/driver"
 	"github.com/lf-edge/ekuiper/extensions/util"
 	"github.com/lf-edge/ekuiper/pkg/api"
@@ -92,7 +92,7 @@ func (t *sqlConfig) getKeyValues(ctx api.StreamContext, mapData map[string]inter
 type sqlSink struct {
 	conf *sqlConfig
 	// The db connection instance
-	db *sql.DB
+	db sqldatabase.DB
 }
 
 func (m *sqlSink) Configure(props map[string]interface{}) error {
@@ -322,7 +322,7 @@ func (m *sqlSink) save(ctx api.StreamContext, table string, data map[string]inte
 			sqlStr += fmt.Sprintf("%s=%s", key, vals[i])
 		}
 		if _, ok := keyval.(string); ok {
-			sqlStr += fmt.Sprintf(" WHERE %s = \"%s\";", m.conf.KeyField, keyval)
+			sqlStr += fmt.Sprintf(" WHERE %s = '%s';", m.conf.KeyField, keyval)
 		} else {
 			sqlStr += fmt.Sprintf(" WHERE %s = %v;", m.conf.KeyField, keyval)
 		}
@@ -332,7 +332,7 @@ func (m *sqlSink) save(ctx api.StreamContext, table string, data map[string]inte
 			return fmt.Errorf("field %s does not exist in data %v", m.conf.KeyField, data)
 		}
 		if _, ok := keyval.(string); ok {
-			sqlStr = fmt.Sprintf("DELETE FROM %s WHERE %s = \"%s\";", table, m.conf.KeyField, keyval)
+			sqlStr = fmt.Sprintf("DELETE FROM %s WHERE %s = '%s';", table, m.conf.KeyField, keyval)
 		} else {
 			sqlStr = fmt.Sprintf("DELETE FROM %s WHERE %s = %v;", table, m.conf.KeyField, keyval)
 		}

+ 41 - 0
extensions/sinks/sql/sql_test.go

@@ -18,10 +18,12 @@ import (
 	"database/sql"
 	"database/sql/driver"
 	"fmt"
+	"github.com/stretchr/testify/assert"
 	"os"
 	"reflect"
 	"testing"
 
+	"github.com/lf-edge/ekuiper/extensions/sqldatabase"
 	econf "github.com/lf-edge/ekuiper/internal/conf"
 	"github.com/lf-edge/ekuiper/internal/topo/context"
 )
@@ -256,6 +258,45 @@ func TestUpdate(t *testing.T) {
 	}
 }
 
+func TestSaveSql(t *testing.T) {
+	contextLogger := econf.Log.WithField("rule", "test")
+	ctx := context.WithValue(context.Background(), context.LoggerKey, contextLogger)
+	s := &sqlSink{}
+	mdb := &sqldatabase.MockDB{}
+	s.db = mdb
+	s.conf = &sqlConfig{
+		Fields: []string{ // set fields to make sure order is always testable
+			"id", "name", "address", "mobile",
+		},
+		KeyField:     "id",
+		RowkindField: "action",
+	}
+
+	test := []struct {
+		name string
+		d    map[string]interface{}
+		s    string
+	}{
+		{
+			name: "insert",
+			d:    map[string]interface{}{"id": 1, "name": "John", "address": "343", "mobile": "334433"},
+			s:    "INSERT INTO test (id,name,address,mobile) values (1,'John','343','334433');",
+		},
+		{
+			name: "update",
+			d:    map[string]interface{}{"action": "update", "id": 1, "name": "John", "address": "343", "mobile": "334433"},
+			s:    "UPDATE test SET id=1,name='John',address='343',mobile='334433' WHERE id = 1;",
+		},
+	}
+	for _, tt := range test {
+		t.Run(tt.name, func(t *testing.T) {
+			err := s.save(ctx, "test", tt.d)
+			assert.NoError(t, err)
+			assert.Equal(t, tt.s, mdb.LastSql())
+		})
+	}
+}
+
 func rowsToMap(rows *sql.Rows) ([]map[string]interface{}, error) {
 	cols, _ := rows.Columns()
 

+ 21 - 0
extensions/sqldatabase/db.go

@@ -0,0 +1,21 @@
+// Copyright 2023 EMQ Technologies Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqldatabase
+
+import "database/sql"
+
+type DB interface {
+	Exec(query string, args ...interface{}) (sql.Result, error)
+}

+ 46 - 0
extensions/sqldatabase/mock.go

@@ -0,0 +1,46 @@
+// Copyright 2023 EMQ Technologies Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqldatabase
+
+import "database/sql"
+
+type MockDB struct {
+	sqls []string
+}
+
+func (m *MockDB) Exec(query string, _ ...interface{}) (sql.Result, error) {
+	m.sqls = append(m.sqls, query)
+	return &MockResult{rowsAffected: 1}, nil
+}
+
+func (m *MockDB) LastSql() string {
+	if len(m.sqls) == 0 {
+		return ""
+	} else {
+		return m.sqls[len(m.sqls)-1]
+	}
+}
+
+type MockResult struct {
+	rowsAffected int64
+}
+
+func (m *MockResult) LastInsertId() (int64, error) {
+	return 1, nil
+}
+
+func (m *MockResult) RowsAffected() (int64, error) {
+	return m.rowsAffected, nil
+}