瀏覽代碼

feat(sink): updatable SQL sink

1. Support updatable writes (insert, update and delete) in the SQL sink
2. Add unit test
3. Fix a bug when writing []map data: the generated SQL was missing the table name

Signed-off-by: Jiyong Huang <huangjy@emqx.io>
Jiyong Huang 2 年之前
父節點
當前提交
fac7fa714e
共有 2 個文件被更改,包括 481 次插入48 次删除
  1. 171 48
      extensions/sinks/sql/sql.go
  2. 310 0
      extensions/sinks/sql/sql_test.go

+ 171 - 48
extensions/sinks/sql/sql.go

@@ -20,6 +20,7 @@ import (
 	"fmt"
 	"fmt"
 	"github.com/lf-edge/ekuiper/extensions/sqldatabase/driver"
 	"github.com/lf-edge/ekuiper/extensions/sqldatabase/driver"
 	"github.com/lf-edge/ekuiper/pkg/api"
 	"github.com/lf-edge/ekuiper/pkg/api"
+	"github.com/lf-edge/ekuiper/pkg/ast"
 	"github.com/lf-edge/ekuiper/pkg/cast"
 	"github.com/lf-edge/ekuiper/pkg/cast"
 	"github.com/lf-edge/ekuiper/pkg/errorx"
 	"github.com/lf-edge/ekuiper/pkg/errorx"
 	"github.com/xo/dburl"
 	"github.com/xo/dburl"
@@ -33,11 +34,22 @@ type sqlConfig struct {
 	Fields         []string `json:"fields"`
 	Fields         []string `json:"fields"`
 	DataTemplate   string   `json:"dataTemplate"`
 	DataTemplate   string   `json:"dataTemplate"`
 	TableDataField string   `json:"tableDataField"`
 	TableDataField string   `json:"tableDataField"`
+	RowkindField   string   `json:"rowkindField"`
+	KeyField       string   `json:"keyField"`
 }
 }
 
 
-func (t *sqlConfig) buildSql(ctx api.StreamContext, mapData map[string]interface{}) ([]string, string, error) {
+// buildInsertSql turns one record into the pieces of an INSERT statement.
+// It returns the column names produced by getKeyValues together with the
+// matching value tuple rendered as "(v1,v2,...)". When getKeyValues fails,
+// its keys are passed through unchanged alongside an empty tuple and the error.
+func (t *sqlConfig) buildInsertSql(ctx api.StreamContext, mapData map[string]interface{}) ([]string, string, error) {
+	columns, literals, err := t.getKeyValues(ctx, mapData)
+	if err == nil {
+		tuple := "(" + strings.Join(literals, ",") + ")"
+		return columns, tuple, nil
+	}
+	return columns, "", err
+}
+
+func (t *sqlConfig) getKeyValues(ctx api.StreamContext, mapData map[string]interface{}) ([]string, []string, error) {
 	if 0 == len(mapData) {
 	if 0 == len(mapData) {
-		return nil, "", fmt.Errorf("data is empty.")
+		return nil, nil, fmt.Errorf("data is empty.")
 	}
 	}
 	logger := ctx.GetLogger()
 	logger := ctx.GetLogger()
 	var keys, vals []string
 	var keys, vals []string
@@ -66,9 +78,7 @@ func (t *sqlConfig) buildSql(ctx api.StreamContext, mapData map[string]interface
 			}
 			}
 		}
 		}
 	}
 	}
-
-	sqlStr := "(" + strings.Join(vals, ",") + ")"
-	return keys, sqlStr, nil
+	return keys, vals, nil
 }
 }
 
 
 type sqlSink struct {
 type sqlSink struct {
@@ -89,6 +99,9 @@ func (m *sqlSink) Configure(props map[string]interface{}) error {
 	if cfg.Table == "" {
 	if cfg.Table == "" {
 		return fmt.Errorf("property Table is required")
 		return fmt.Errorf("property Table is required")
 	}
 	}
+	if cfg.RowkindField != "" && cfg.KeyField == "" {
+		return fmt.Errorf("keyField is required when rowkindField is set")
+	}
 	m.conf = cfg
 	m.conf = cfg
 	return nil
 	return nil
 }
 }
@@ -108,11 +121,16 @@ func (m *sqlSink) Open(ctx api.StreamContext) (err error) {
 
 
 func (m *sqlSink) writeToDB(ctx api.StreamContext, sqlStr *string) error {
 func (m *sqlSink) writeToDB(ctx api.StreamContext, sqlStr *string) error {
 	ctx.GetLogger().Debugf(*sqlStr)
 	ctx.GetLogger().Debugf(*sqlStr)
-	rows, err := m.db.Query(*sqlStr)
+	r, err := m.db.Exec(*sqlStr)
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("%s: %s", errorx.IOErr, err.Error())
 		return fmt.Errorf("%s: %s", errorx.IOErr, err.Error())
 	}
 	}
-	return rows.Close()
+	d, err := r.RowsAffected()
+	if err != nil {
+		ctx.GetLogger().Errorf("get rows affected error: %s", err.Error())
+	}
+	ctx.GetLogger().Debugf("Rows affected: %d", d)
+	return nil
 }
 }
 
 
 func (m *sqlSink) Collect(ctx api.StreamContext, item interface{}) error {
 func (m *sqlSink) Collect(ctx api.StreamContext, item interface{}) error {
@@ -130,10 +148,12 @@ func (m *sqlSink) Collect(ctx api.StreamContext, item interface{}) error {
 		item = tm
 		item = tm
 	}
 	}
 
 
-	var table string
-	var err error
-	v, ok := item.(map[string]interface{})
-	if ok {
+	var (
+		table string
+		err   error
+	)
+	switch v := item.(type) {
+	case map[string]interface{}:
 		table, err = ctx.ParseTemplate(m.conf.Table, v)
 		table, err = ctx.ParseTemplate(m.conf.Table, v)
 		if err != nil {
 		if err != nil {
 			ctx.GetLogger().Errorf("parse template for table %s error: %v", m.conf.Table, err)
 			ctx.GetLogger().Errorf("parse template for table %s error: %v", m.conf.Table, err)
@@ -142,60 +162,102 @@ func (m *sqlSink) Collect(ctx api.StreamContext, item interface{}) error {
 		if m.conf.TableDataField != "" {
 		if m.conf.TableDataField != "" {
 			item = v[m.conf.TableDataField]
 			item = v[m.conf.TableDataField]
 		}
 		}
+	case []map[string]interface{}:
+		if len(v) == 0 {
+			ctx.GetLogger().Warnf("empty data array")
+			return nil
+		}
+		table, err = ctx.ParseTemplate(m.conf.Table, v[0])
+		if err != nil {
+			ctx.GetLogger().Errorf("parse template for table %s error: %v", m.conf.Table, err)
+			return err
+		}
 	}
 	}
 
 
 	var keys []string = nil
 	var keys []string = nil
 	var values []string = nil
 	var values []string = nil
 	var vars string
 	var vars string
 
 
-	switch v := item.(type) {
-	case []map[string]interface{}:
-		for _, mapData := range v {
-			keys, vars, err = m.conf.buildSql(ctx, mapData)
+	if m.conf.RowkindField == "" {
+		switch v := item.(type) {
+		case []map[string]interface{}:
+			for _, mapData := range v {
+				keys, vars, err = m.conf.buildInsertSql(ctx, mapData)
+				if err != nil {
+					return err
+				}
+				values = append(values, vars)
+			}
+			if keys != nil {
+				sqlStr := fmt.Sprintf("INSERT INTO %s (%s) values ", table, strings.Join(keys, ",")) + strings.Join(values, ",") + ";"
+				return m.writeToDB(ctx, &sqlStr)
+			}
+			return nil
+		case map[string]interface{}:
+			keys, vars, err = m.conf.buildInsertSql(ctx, v)
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
 			values = append(values, vars)
 			values = append(values, vars)
-		}
-		if keys != nil {
-			sqlStr := fmt.Sprintf("INSERT INTO %s (%s) values ", table, strings.Join(keys, ",")) + strings.Join(values, ",") + ";"
-			return m.writeToDB(ctx, &sqlStr)
-		}
-		return nil
-	case map[string]interface{}:
-		keys, vars, err = m.conf.buildSql(ctx, v)
-		if err != nil {
-			return err
-		}
-		values = append(values, vars)
-		if keys != nil {
-			sqlStr := fmt.Sprintf("INSERT INTO %s (%s) values ", table, strings.Join(keys, ",")) + strings.Join(values, ",") + ";"
-			return m.writeToDB(ctx, &sqlStr)
-		}
-		return nil
-	case []interface{}:
-		for _, data := range v {
-			mapData, ok := data.(map[string]interface{})
-			if !ok {
-				ctx.GetLogger().Errorf("unsupported type: %T", data)
-				return fmt.Errorf("unsupported type: %T", data)
+			if keys != nil {
+				sqlStr := fmt.Sprintf("INSERT INTO %s (%s) values ", table, strings.Join(keys, ",")) + strings.Join(values, ",") + ";"
+				return m.writeToDB(ctx, &sqlStr)
 			}
 			}
+			return nil
+		case []interface{}:
+			for _, data := range v {
+				mapData, ok := data.(map[string]interface{})
+				if !ok {
+					ctx.GetLogger().Errorf("unsupported type: %T", data)
+					return fmt.Errorf("unsupported type: %T", data)
+				}
 
 
-			keys, vars, err = m.conf.buildSql(ctx, mapData)
+				keys, vars, err = m.conf.buildInsertSql(ctx, mapData)
+				if err != nil {
+					ctx.GetLogger().Errorf("sql sink build sql error %v for data", err, mapData)
+					return err
+				}
+				values = append(values, vars)
+			}
+
+			if keys != nil {
+				sqlStr := fmt.Sprintf("INSERT INTO %s (%s) values ", table, strings.Join(keys, ",")) + strings.Join(values, ",") + ";"
+				return m.writeToDB(ctx, &sqlStr)
+			}
+			return nil
+		default: // never happen
+			return fmt.Errorf("unsupported type: %T", item)
+		}
+	} else {
+		switch d := item.(type) {
+		case []map[string]interface{}:
+			for _, el := range d {
+				err := m.save(ctx, table, el)
+				if err != nil {
+					ctx.GetLogger().Error(err)
+				}
+			}
+		case map[string]interface{}:
+			err := m.save(ctx, table, d)
 			if err != nil {
 			if err != nil {
-				ctx.GetLogger().Errorf("sql sink build sql error %v for data", err, mapData)
 				return err
 				return err
 			}
 			}
-			values = append(values, vars)
-		}
-
-		if keys != nil {
-			sqlStr := fmt.Sprintf("INSERT INTO %s (%s) values ", table, strings.Join(keys, ",")) + strings.Join(values, ",") + ";"
-			return m.writeToDB(ctx, &sqlStr)
+		case []interface{}:
+			for _, vv := range d {
+				el, ok := vv.(map[string]interface{})
+				if !ok {
+					ctx.GetLogger().Errorf("unsupported type: %T", vv)
+					return fmt.Errorf("unsupported type: %T", vv)
+				}
+				err := m.save(ctx, table, el)
+				if err != nil {
+					ctx.GetLogger().Error(err)
+				}
+			}
+		default:
+			return fmt.Errorf("unrecognized format of %s", item)
 		}
 		}
 		return nil
 		return nil
-	default: // never happen
-		return fmt.Errorf("unsupported type: %T", item)
 	}
 	}
 }
 }
 
 
@@ -206,6 +268,67 @@ func (m *sqlSink) Close(_ api.StreamContext) error {
 	return nil
 	return nil
 }
 }
 
 
+// save writes a single updatable record to the given table.
+//
+// The action is read from the configured rowkind field of data and defaults
+// to an insert when that field is absent. Inserts write the record as a new
+// row; updates and deletes address the row whose key column (the configured
+// key field) equals the record's value for that field.
+//
+// An error is returned when the rowkind is not a string or not one of the
+// insert/update/delete kinds, when the key field is missing for an update or
+// delete, when the statement cannot be built, or when the database rejects it.
+func (m *sqlSink) save(ctx api.StreamContext, table string, data map[string]interface{}) error {
+	rowkind := ast.RowkindInsert
+	if c, ok := data[m.conf.RowkindField]; ok {
+		rowkind, ok = c.(string)
+		if !ok {
+			return fmt.Errorf("rowkind field %s is not a string in data %v", m.conf.RowkindField, data)
+		}
+	}
+	var sqlStr string
+	switch rowkind {
+	case ast.RowkindInsert:
+		keys, vars, err := m.conf.buildInsertSql(ctx, data)
+		if err != nil {
+			return err
+		}
+		if keys == nil {
+			// No column to write: skip instead of sending an empty
+			// statement to the database, as Collect does for plain inserts.
+			return nil
+		}
+		sqlStr = fmt.Sprintf("INSERT INTO %s (%s) values ", table, strings.Join(keys, ",")) + vars + ";"
+	case ast.RowkindUpdate:
+		keyval, ok := data[m.conf.KeyField]
+		if !ok {
+			return fmt.Errorf("field %s does not exist in data %v", m.conf.KeyField, data)
+		}
+		keys, vals, err := m.conf.getKeyValues(ctx, data)
+		if err != nil {
+			return err
+		}
+		assignments := make([]string, 0, len(keys))
+		for i, key := range keys {
+			assignments = append(assignments, key+"="+vals[i])
+		}
+		sqlStr = fmt.Sprintf("UPDATE %s SET %s WHERE %s = %s;", table, strings.Join(assignments, ","), m.conf.KeyField, sqlKeyLiteral(keyval))
+	case ast.RowkindDelete:
+		keyval, ok := data[m.conf.KeyField]
+		if !ok {
+			return fmt.Errorf("field %s does not exist in data %v", m.conf.KeyField, data)
+		}
+		sqlStr = fmt.Sprintf("DELETE FROM %s WHERE %s = %s;", table, m.conf.KeyField, sqlKeyLiteral(keyval))
+	default:
+		return fmt.Errorf("invalid rowkind %s", rowkind)
+	}
+	return m.writeToDB(ctx, &sqlStr)
+}
+
+// sqlKeyLiteral renders a key value as a SQL literal for a WHERE clause.
+// Strings are wrapped in single quotes with embedded single quotes doubled:
+// standard SQL reserves double quotes for identifiers, so a double-quoted
+// value only works on lenient engines, and an unescaped quote would break
+// (or alter) the statement. Any other value is rendered with %v.
+//
+// NOTE(review): bound parameters would be the robust fix, but writeToDB only
+// accepts a finished statement string.
+func sqlKeyLiteral(v interface{}) string {
+	if s, ok := v.(string); ok {
+		return "'" + strings.ReplaceAll(s, "'", "''") + "'"
+	}
+	return fmt.Sprintf("%v", v)
+}
+
 func Sql() api.Sink {
 func Sql() api.Sink {
 	return &sqlSink{}
 	return &sqlSink{}
 }
 }

+ 310 - 0
extensions/sinks/sql/sql_test.go

@@ -0,0 +1,310 @@
+// Copyright 2022 EMQ Technologies Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+	"database/sql"
+	"database/sql/driver"
+	"fmt"
+	econf "github.com/lf-edge/ekuiper/internal/conf"
+	"github.com/lf-edge/ekuiper/internal/topo/context"
+	"os"
+	"reflect"
+	"testing"
+)
+
+// TestSingle feeds the sink one map per Collect call, with no field list
+// configured, and checks that every record ends up as a row of table "single".
+func TestSingle(t *testing.T) {
+	db, err := sql.Open("sqlite3", "file:test.db")
+	if err != nil {
+		t.Fatal(err)
+	}
+	contextLogger := econf.Log.WithField("rule", "test")
+	ctx := context.WithValue(context.Background(), context.LoggerKey, contextLogger)
+	s := &sqlSink{}
+	// Cleanup runs on every exit path below, including t.Fatal.
+	defer func() {
+		db.Close()
+		s.Close(ctx)
+		err := os.Remove("test.db")
+		if err != nil {
+			fmt.Println(err)
+		}
+	}()
+	_, err = db.Exec("CREATE TABLE IF NOT EXISTS single (id BIGINT PRIMARY KEY, name TEXT NOT NULL, address varchar(20), mobile varchar(20))")
+	if err != nil {
+		// Fail this test only; a panic would abort the whole test binary.
+		t.Fatal(err)
+	}
+	err = s.Configure(map[string]interface{}{
+		"url":   "sqlite://test.db",
+		"table": "single",
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = s.Open(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	var data = []map[string]interface{}{
+		{"id": 1, "name": "John", "address": "343", "mobile": "334433"},
+		{"id": 2, "name": "Susan", "address": "34", "mobile": "334433"},
+		{"id": 3, "name": "Susan", "address": "34", "mobile": "334433"},
+	}
+	for _, d := range data {
+		err = s.Collect(ctx, d)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	s.Close(ctx)
+	rows, err := db.Query("SELECT * FROM single")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer rows.Close()
+	// Surface read failures directly instead of as a confusing mismatch below.
+	act, err := rowsToMap(rows)
+	if err != nil {
+		t.Fatal(err)
+	}
+	exp := []map[string]interface{}{
+		{"id": int64(1), "name": "John", "address": "343", "mobile": "334433"},
+		{"id": int64(2), "name": "Susan", "address": "34", "mobile": "334433"},
+		{"id": int64(3), "name": "Susan", "address": "34", "mobile": "334433"},
+	}
+	if !reflect.DeepEqual(act, exp) {
+		t.Errorf("Expect %v but got %v", exp, act)
+	}
+}
+
+// TestBatch hands the sink a whole slice of maps in one Collect call, with
+// the field list restricted to id and name, and checks that table "batch"
+// receives one row per record holding only those two columns.
+func TestBatch(t *testing.T) {
+	db, err := sql.Open("sqlite3", "file:test.db")
+	if err != nil {
+		t.Fatal(err)
+	}
+	contextLogger := econf.Log.WithField("rule", "test")
+	ctx := context.WithValue(context.Background(), context.LoggerKey, contextLogger)
+	s := &sqlSink{}
+	// Cleanup runs on every exit path below, including t.Fatal.
+	defer func() {
+		db.Close()
+		s.Close(ctx)
+		err := os.Remove("test.db")
+		if err != nil {
+			fmt.Println(err)
+		}
+	}()
+	_, err = db.Exec("CREATE TABLE IF NOT EXISTS batch (id BIGINT PRIMARY KEY, name TEXT NOT NULL)")
+	if err != nil {
+		// Fail this test only; a panic would abort the whole test binary.
+		t.Fatal(err)
+	}
+	err = s.Configure(map[string]interface{}{
+		"url":    "sqlite://test.db",
+		"table":  "batch",
+		"fields": []string{"id", "name"},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = s.Open(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	var data = []map[string]interface{}{
+		{"id": 1, "name": "John", "address": "343", "mobile": "334433"},
+		{"id": 2, "name": "Susan", "address": "34", "mobile": "334433"},
+		{"id": 3, "name": "Susan", "address": "34", "mobile": "334433"},
+	}
+	err = s.Collect(ctx, data)
+	if err != nil {
+		t.Fatal(err)
+	}
+	s.Close(ctx)
+	rows, err := db.Query("SELECT * FROM batch")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer rows.Close()
+	// Surface read failures directly instead of as a confusing mismatch below.
+	act, err := rowsToMap(rows)
+	if err != nil {
+		t.Fatal(err)
+	}
+	exp := []map[string]interface{}{
+		{"id": int64(1), "name": "John"},
+		{"id": int64(2), "name": "Susan"},
+		{"id": int64(3), "name": "Susan"},
+	}
+	if !reflect.DeepEqual(act, exp) {
+		t.Errorf("Expect %v but got %v", exp, act)
+	}
+}
+
+// TestUpdate configures the sink with rowkindField "action" and keyField "id"
+// and replays three groups of records against table "updateTable". Each group
+// is sent either as one batch or record by record, and the full table content
+// is compared with the expected rows after every group.
+func TestUpdate(t *testing.T) {
+	db, err := sql.Open("sqlite3", "file:test.db")
+	if err != nil {
+		t.Fatal(err)
+	}
+	contextLogger := econf.Log.WithField("rule", "test")
+	ctx := context.WithValue(context.Background(), context.LoggerKey, contextLogger)
+	s := &sqlSink{}
+	// Cleanup runs on every exit path below, including t.Fatal.
+	defer func() {
+		db.Close()
+		s.Close(ctx)
+		err := os.Remove("test.db")
+		if err != nil {
+			fmt.Println(err)
+		}
+	}()
+	_, err = db.Exec("CREATE TABLE IF NOT EXISTS updateTable (id BIGINT PRIMARY KEY, name TEXT NOT NULL)")
+	if err != nil {
+		// Fail this test only; a panic would abort the whole test binary.
+		t.Fatal(err)
+	}
+	err = s.Configure(map[string]interface{}{
+		"url":          "sqlite://test.db",
+		"table":        "updateTable",
+		"rowkindField": "action",
+		"keyField":     "id",
+		"fields":       []string{"id", "name"},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = s.Open(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+	var test = []struct {
+		d []map[string]interface{} // records to send
+		b bool                     // true: one batch Collect; false: one Collect per record
+		r []map[string]interface{} // expected table content afterwards
+	}{
+		{
+			d: []map[string]interface{}{
+				{"id": 1, "name": "John", "address": "343", "mobile": "334433"},
+				{"action": "insert", "id": 2, "name": "Susan", "address": "34", "mobile": "334433"},
+				{"action": "update", "id": 2, "name": "Diana"},
+			},
+			b: true,
+			r: []map[string]interface{}{
+				{"id": int64(1), "name": "John"},
+				{"id": int64(2), "name": "Diana"},
+			},
+		}, {
+			d: []map[string]interface{}{
+				{"id": 4, "name": "Charles", "address": "343", "mobile": "334433"},
+				{"action": "delete", "id": 2},
+				{"action": "update", "id": 1, "name": "Lizz"},
+			},
+			b: false,
+			r: []map[string]interface{}{
+				{"id": int64(1), "name": "Lizz"},
+				{"id": int64(4), "name": "Charles"},
+			},
+		}, {
+			d: []map[string]interface{}{
+				{"action": "upsert", "id": 4, "name": "Charles", "address": "343", "mobile": "334433"},
+				{"action": "update", "id": 3, "name": "Lizz"},
+				{"action": "update", "id": 1, "name": "Philips"},
+			},
+			b: true,
+			r: []map[string]interface{}{
+				{"id": int64(1), "name": "Philips"},
+				{"id": int64(4), "name": "Charles"},
+			},
+		},
+	}
+	for i, tt := range test {
+		// Collect errors are printed rather than asserted: the cases include
+		// records (e.g. the "upsert" action) that the sink is not expected to accept.
+		if tt.b {
+			err = s.Collect(ctx, tt.d)
+			if err != nil {
+				fmt.Println(err)
+			}
+		} else {
+			for _, d := range tt.d {
+				err = s.Collect(ctx, d)
+				if err != nil {
+					fmt.Println(err)
+				}
+			}
+		}
+		rows, err := db.Query("SELECT * FROM updateTable")
+		if err != nil {
+			t.Fatal(err)
+		}
+		act, err := rowsToMap(rows)
+		// Close explicitly: a defer here would pile up until the test returns.
+		rows.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+		if !reflect.DeepEqual(act, tt.r) {
+			t.Errorf("Case %d Expect %v but got %v", i, tt.r, act)
+		}
+	}
+}
+
+// rowsToMap drains rows into a slice holding one map per row, keyed by column
+// name. The result set is always closed before returning. Any failure while
+// reading column metadata, scanning a row, or iterating is returned as the
+// error, with a nil slice.
+func rowsToMap(rows *sql.Rows) ([]map[string]interface{}, error) {
+	// Release the result set on early error returns as well; closing rows
+	// that are already closed is harmless.
+	defer rows.Close()
+
+	cols, err := rows.Columns()
+	if err != nil {
+		return nil, err
+	}
+	types, err := rows.ColumnTypes()
+	if err != nil {
+		return nil, err
+	}
+	var result []map[string]interface{}
+	for rows.Next() {
+		columns := make([]interface{}, len(cols))
+		prepareValues(columns, types, cols)
+
+		if err := rows.Scan(columns...); err != nil {
+			return nil, err
+		}
+		data := make(map[string]interface{}, len(cols))
+		scanIntoMap(data, columns, cols)
+		result = append(result, data)
+	}
+	// Next returns false both at the end of the data and on failure; only
+	// Err distinguishes the two.
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return result, nil
+}
+
+func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
+	for idx, column := range columns {
+		if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
+			mapValue[column] = reflectValue.Interface()
+			if valuer, ok := mapValue[column].(driver.Valuer); ok {
+				mapValue[column], _ = valuer.Value()
+			} else if b, ok := mapValue[column].(sql.RawBytes); ok {
+				mapValue[column] = string(b)
+			}
+		} else {
+			mapValue[column] = nil
+		}
+	}
+}
+
+// prepareValues fills values with scan destinations, one per column.
+// Without any column type information every destination is a *interface{}.
+// Otherwise each column gets a pointer to a pointer of its reported scan
+// type, falling back to *interface{} when the scan type is nil.
+func prepareValues(values []interface{}, columnTypes []*sql.ColumnType, columns []string) {
+	if len(columnTypes) == 0 {
+		for idx := range columns {
+			values[idx] = new(interface{})
+		}
+		return
+	}
+	for idx, columnType := range columnTypes {
+		scanType := columnType.ScanType()
+		if scanType == nil {
+			values[idx] = new(interface{})
+			continue
+		}
+		values[idx] = reflect.New(reflect.PtrTo(scanType)).Interface()
+	}
+}