Browse Source

feat(plugin): support limit connections for sql(#1761)

* support limit connections for sql sink

Signed-off-by: yisaer <disxiaofei@163.com>

* use config

Signed-off-by: yisaer <disxiaofei@163.com>

* use config

Signed-off-by: yisaer <disxiaofei@163.com>

* revise test

Signed-off-by: yisaer <disxiaofei@163.com>

* revise test

Signed-off-by: yisaer <disxiaofei@163.com>

* revise test

Signed-off-by: yisaer <disxiaofei@163.com>

* address the comment

Signed-off-by: yisaer <disxiaofei@163.com>

* add config

Signed-off-by: yisaer <disxiaofei@163.com>

---------

Signed-off-by: yisaer <disxiaofei@163.com>
Song Gao 2 years ago
parent
commit
95a1dd5825

+ 4 - 0
etc/kuiper.yaml

@@ -29,6 +29,10 @@ basic:
   pluginHosts: https://packages.emqx.net
   # Whether to ignore case in SQL processing. Note that, the name of customized function by plugins are case-sensitive.
   ignoreCase: true
+  sql:
+    # maxConnections indicates the max connections for the certain database instance group by driver and dsn sharing between the sources/sinks
+    # 0 indicates unlimited
+    maxConnections: 0
 
 # The default options for all rules. Each rule can override this setting by defining its own option
 rule:

+ 11 - 3
extensions/sinks/sql/sql.go

@@ -86,6 +86,9 @@ type sqlSink struct {
 	conf *sqlConfig
 	//The db connection instance
 	db *sql.DB
+
+	driver string
+	dsn    string
 }
 
 func (m *sqlSink) Configure(props map[string]interface{}) error {
@@ -104,14 +107,19 @@ func (m *sqlSink) Configure(props map[string]interface{}) error {
 		return fmt.Errorf("keyField is required when rowkindField is set")
 	}
 	m.conf = cfg
+	sqlDriver, dsn, err := util.ParseDBUrl(m.conf.Url)
+	if err != nil {
+		return err
+	}
+	m.driver = sqlDriver
+	m.dsn = dsn
 	return nil
 }
 
 func (m *sqlSink) Open(ctx api.StreamContext) (err error) {
 	logger := ctx.GetLogger()
 	logger.Debugf("Opening sql sink")
-
-	db, err := util.Open(m.conf.Url)
+	db, err := util.FetchDBToOneNode(util.GlobalPool, m.driver, m.dsn)
 	if err != nil {
 		logger.Errorf("support build tags are %v", driver.KnownBuildTags())
 		return err
@@ -264,7 +272,7 @@ func (m *sqlSink) Collect(ctx api.StreamContext, item interface{}) error {
 
 func (m *sqlSink) Close(_ api.StreamContext) error {
 	if m.db != nil {
-		return m.db.Close()
+		return util.ReturnDBFromOneNode(util.GlobalPool, m.driver, m.dsn)
 	}
 	return nil
 }

+ 11 - 2
extensions/sources/sql/sql.go

@@ -37,6 +37,9 @@ type sqlsource struct {
 	Query sqlgen.SqlQueryGenerator
 	//The db connection instance
 	db *sql.DB
+
+	driver string
+	dsn    string
 }
 
 func (m *sqlsource) Configure(_ string, props map[string]interface{}) error {
@@ -66,7 +69,13 @@ func (m *sqlsource) Configure(_ string, props map[string]interface{}) error {
 	m.Query = generator
 	m.conf = cfg
 
-	db, err := util.Open(m.conf.Url)
+	driver, dsn, err := util.ParseDBUrl(m.conf.Url)
+	if err != nil {
+		return err
+	}
+	m.driver = driver
+	m.dsn = dsn
+	db, err := util.FetchDBToOneNode(util.GlobalPool, driver, dsn)
 	if err != nil {
 		return fmt.Errorf("connection to %s Open with error %v, support build tags are %v", m.conf.Url, err, driver2.KnownBuildTags())
 	}
@@ -137,7 +146,7 @@ func (m *sqlsource) Close(ctx api.StreamContext) error {
 	logger := ctx.GetLogger()
 	logger.Debugf("Closing sql stream to %v", m.conf)
 	if m.db != nil {
-		_ = m.db.Close()
+		return util.ReturnDBFromOneNode(util.GlobalPool, m.driver, m.dsn)
 	}
 
 	return nil

+ 11 - 2
extensions/sources/sql/sqlLookup.go

@@ -32,12 +32,15 @@ type sqlLookupSource struct {
 	url   string
 	table string
 	db    *sql.DB
+
+	driver string
+	dsn    string
 }
 
 // Open establish a connection to the database
 func (s *sqlLookupSource) Open(ctx api.StreamContext) error {
 	ctx.GetLogger().Debugf("Opening sql lookup source")
-	db, err := util.Open(s.url)
+	db, err := util.FetchDBToOneNode(util.GlobalPool, s.driver, s.dsn)
 	if err != nil {
 		return fmt.Errorf("connection to %s Open with error %v", s.url, err)
 	}
@@ -60,6 +63,12 @@ func (s *sqlLookupSource) Configure(datasource string, props map[string]interfac
 	}
 	s.url = cfg.Url
 	s.table = datasource
+	driver, dsn, err := util.ParseDBUrl(s.url)
+	if err != nil {
+		return err
+	}
+	s.driver = driver
+	s.dsn = dsn
 	return nil
 }
 
@@ -119,7 +128,7 @@ func (s *sqlLookupSource) Close(ctx api.StreamContext) error {
 	ctx.GetLogger().Debugf("Closing sql lookup source")
 	defer func() { s.db = nil }()
 	if s.db != nil {
-		return s.db.Close()
+		return util.ReturnDBFromOneNode(util.GlobalPool, s.driver, s.dsn)
 	}
 	return nil
 }

+ 176 - 0
extensions/util/pool.go

@@ -0,0 +1,176 @@
+// Copyright 2023 EMQ Technologies Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package util
+
+import (
+	"database/sql"
+	"strings"
+	"sync"
+
+	"github.com/lf-edge/ekuiper/internal/conf"
+	"github.com/xo/dburl"
+)
+
+var GlobalPool *driverPool
+
+func init() {
+	// GlobalPool maintained the *sql.DB group by the driver and DSN.
+	// Multiple sql sources/sinks can directly fetch the `*sql.DB` from the GlobalPool and return it back when they don't need it anymore.
+	// As multiple sql sources/sinks share the same `*sql.DB`, we can directly control the total count of connections by using `SetMaxOpenConns`
+	GlobalPool = newDriverPool()
+}
+
+type driverPool struct {
+	isTesting bool
+
+	sync.RWMutex
+	pool map[string]*dbPool
+}
+
+func newDriverPool() *driverPool {
+	return &driverPool{
+		pool: map[string]*dbPool{},
+	}
+}
+
+type dbPool struct {
+	isTesting bool
+	driver    string
+
+	sync.RWMutex
+	pool        map[string]*sql.DB
+	connections map[string]int
+}
+
+func (dp *dbPool) getDBConnCount(dsn string) int {
+	dp.RLock()
+	defer dp.RUnlock()
+	count, ok := dp.connections[dsn]
+	if ok {
+		return count
+	}
+	return 0
+}
+
+func (dp *dbPool) getOrCreate(dsn string) (*sql.DB, error) {
+	dp.Lock()
+	defer dp.Unlock()
+	db, ok := dp.pool[dsn]
+	if ok {
+		dp.connections[dsn] = dp.connections[dsn] + 1
+		return db, nil
+	}
+	newDb, err := openDB(dp.driver, dsn, dp.isTesting)
+	if err != nil {
+		return nil, err
+	}
+	conf.Log.Debugf("create new database instance: %v", dsn)
+	dp.pool[dsn] = newDb
+	dp.connections[dsn] = 1
+	return newDb, nil
+}
+
+func openDB(driver, dsn string, isTesting bool) (*sql.DB, error) {
+	if isTesting {
+		return nil, nil
+	}
+	db, err := sql.Open(driver, dsn)
+	if err != nil {
+		return nil, err
+	}
+	c := conf.Config
+	if c != nil && c.Basic.SQLConf != nil && c.Basic.SQLConf.MaxConnections > 0 {
+		db.SetMaxOpenConns(c.Basic.SQLConf.MaxConnections)
+	}
+	return db, nil
+}
+
+func (dp *dbPool) closeOneConn(dsn string) error {
+	dp.Lock()
+	defer dp.Unlock()
+	connCount, ok := dp.connections[dsn]
+	if !ok {
+		return nil
+	}
+	connCount--
+	if connCount > 0 {
+		dp.connections[dsn] = connCount
+		return nil
+	}
+	conf.Log.Debugf("drop database instance: %v", dsn)
+	db := dp.pool[dsn]
+	// remove db instance from map in order to avoid memory leak
+	delete(dp.pool, dsn)
+	delete(dp.connections, dsn)
+	if dp.isTesting {
+		return nil
+	}
+	return db.Close()
+}
+
+func (dp *driverPool) getOrCreate(driver string) *dbPool {
+	dp.Lock()
+	defer dp.Unlock()
+	db, ok := dp.pool[driver]
+	if ok {
+		return db
+	}
+	newDB := &dbPool{
+		isTesting:   dp.isTesting,
+		driver:      driver,
+		pool:        map[string]*sql.DB{},
+		connections: map[string]int{},
+	}
+	dp.pool[driver] = newDB
+	return newDB
+}
+
+func (dp *driverPool) get(driver string) (*dbPool, bool) {
+	dp.RLock()
+	defer dp.RUnlock()
+	dbPool, ok := dp.pool[driver]
+	return dbPool, ok
+}
+
+func ParseDBUrl(urlstr string) (string, string, error) {
+	u, err := dburl.Parse(urlstr)
+	if err != nil {
+		return "", "", err
+	}
+	// Open returns *sql.DB from urlstr
+	// As we use modernc.org/sqlite with `sqlite` as driver name and dburl use `sqlite3` as driver name, we need to fix it before open sql.DB
+	if strings.ToLower(u.Driver) == "sqlite3" {
+		u.Driver = "sqlite"
+	}
+	return u.Driver, u.DSN, nil
+}
+
+func FetchDBToOneNode(driverPool *driverPool, driver, dsn string) (*sql.DB, error) {
+	dbPool := driverPool.getOrCreate(driver)
+	return dbPool.getOrCreate(dsn)
+}
+
+func ReturnDBFromOneNode(driverPool *driverPool, driver, dsn string) error {
+	dbPool, ok := driverPool.get(driver)
+	if !ok {
+		return nil
+	}
+	return dbPool.closeOneConn(dsn)
+}
+
+func getDBConnCount(driverPool *driverPool, driver, dsn string) int {
+	dbPool := driverPool.getOrCreate(driver)
+	return dbPool.getDBConnCount(dsn)
+}

+ 65 - 0
extensions/util/pool_test.go

@@ -0,0 +1,65 @@
+// Copyright 2023 EMQ Technologies Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package util
+
+import (
+	"sync"
+	"testing"
+)
+
+func TestDriverPool(t *testing.T) {
+	driver := "mysql"
+	dsn := "root@127.0.0.1:4000/mock"
+	testPool := newDriverPool()
+	testPool.isTesting = true
+
+	expCount := 3
+	wg := sync.WaitGroup{}
+	wg.Add(expCount)
+	for i := 0; i < expCount; i++ {
+		go func() {
+			defer func() {
+				wg.Done()
+			}()
+			_, err := FetchDBToOneNode(testPool, driver, dsn)
+			if err != nil {
+				t.Errorf("meet unexpected err:%v", err)
+			}
+		}()
+	}
+	wg.Wait()
+	count := getDBConnCount(testPool, driver, dsn)
+	if expCount != count {
+		t.Errorf("expect conn count:%v, got:%v", expCount, count)
+	}
+
+	wg.Add(expCount)
+	for i := 0; i < expCount; i++ {
+		go func() {
+			defer func() {
+				wg.Done()
+			}()
+			err := ReturnDBFromOneNode(testPool, driver, dsn)
+			if err != nil {
+				t.Errorf("meet unexpected err:%v", err)
+			}
+		}()
+	}
+	wg.Wait()
+	count = getDBConnCount(testPool, driver, dsn)
+	if count != 0 {
+		t.Errorf("expect conn count:%v, got:%v", 0, count)
+	}
+}

+ 5 - 0
internal/conf/conf.go

@@ -112,6 +112,10 @@ func (sc *SourceConf) Validate() error {
 	return e
 }
 
+type SQLConf struct {
+	MaxConnections int `yaml:"maxConnections"`
+}
+
 type KuiperConf struct {
 	Basic struct {
 		Debug          bool     `yaml:"debug"`
@@ -129,6 +133,7 @@ type KuiperConf struct {
 		PluginHosts    string   `yaml:"pluginHosts"`
 		Authentication bool     `yaml:"authentication"`
 		IgnoreCase     bool     `yaml:"ignoreCase"`
+		SQLConf        *SQLConf `yaml:"sql"`
 	}
 	Rule   api.RuleOption
 	Sink   *SinkConf