Browse Source

refactor: refactor dbPool

Signed-off-by: yisaer <disxiaofei@163.com>
Song Gao 2 years ago
parent
commit
dd3d07d85e

+ 2 - 11
extensions/sinks/sql/sql.go

@@ -86,9 +86,6 @@ type sqlSink struct {
 	conf *sqlConfig
 	//The db connection instance
 	db *sql.DB
-
-	driver string
-	dsn    string
 }
 
 func (m *sqlSink) Configure(props map[string]interface{}) error {
@@ -107,19 +104,13 @@ func (m *sqlSink) Configure(props map[string]interface{}) error {
 		return fmt.Errorf("keyField is required when rowkindField is set")
 	}
 	m.conf = cfg
-	sqlDriver, dsn, err := util.ParseDBUrl(m.conf.Url)
-	if err != nil {
-		return err
-	}
-	m.driver = sqlDriver
-	m.dsn = dsn
 	return nil
 }
 
 func (m *sqlSink) Open(ctx api.StreamContext) (err error) {
 	logger := ctx.GetLogger()
 	logger.Debugf("Opening sql sink")
-	db, err := util.FetchDBToOneNode(util.GlobalPool, m.driver, m.dsn)
+	db, err := util.FetchDBToOneNode(util.GlobalPool, m.conf.Url)
 	if err != nil {
 		logger.Errorf("support build tags are %v", driver.KnownBuildTags())
 		return err
@@ -272,7 +263,7 @@ func (m *sqlSink) Collect(ctx api.StreamContext, item interface{}) error {
 
 func (m *sqlSink) Close(_ api.StreamContext) error {
 	if m.db != nil {
-		return util.ReturnDBFromOneNode(util.GlobalPool, m.driver, m.dsn)
+		return util.ReturnDBFromOneNode(util.GlobalPool, m.conf.Url)
 	}
 	return nil
 }

+ 2 - 12
extensions/sources/sql/sql.go

@@ -37,9 +37,6 @@ type sqlsource struct {
 	Query sqlgen.SqlQueryGenerator
 	//The db connection instance
 	db *sql.DB
-
-	driver string
-	dsn    string
 }
 
 func (m *sqlsource) Configure(_ string, props map[string]interface{}) error {
@@ -68,14 +65,7 @@ func (m *sqlsource) Configure(_ string, props map[string]interface{}) error {
 
 	m.Query = generator
 	m.conf = cfg
-
-	driver, dsn, err := util.ParseDBUrl(m.conf.Url)
-	if err != nil {
-		return err
-	}
-	m.driver = driver
-	m.dsn = dsn
-	db, err := util.FetchDBToOneNode(util.GlobalPool, driver, dsn)
+	db, err := util.FetchDBToOneNode(util.GlobalPool, m.conf.Url)
 	if err != nil {
 		return fmt.Errorf("connection to %s Open with error %v, support build tags are %v", m.conf.Url, err, driver2.KnownBuildTags())
 	}
@@ -146,7 +136,7 @@ func (m *sqlsource) Close(ctx api.StreamContext) error {
 	logger := ctx.GetLogger()
 	logger.Debugf("Closing sql stream to %v", m.conf)
 	if m.db != nil {
-		return util.ReturnDBFromOneNode(util.GlobalPool, m.driver, m.dsn)
+		return util.ReturnDBFromOneNode(util.GlobalPool, m.conf.Url)
 	}
 
 	return nil

+ 2 - 11
extensions/sources/sql/sqlLookup.go

@@ -32,15 +32,12 @@ type sqlLookupSource struct {
 	url   string
 	table string
 	db    *sql.DB
-
-	driver string
-	dsn    string
 }
 
 // Open establish a connection to the database
 func (s *sqlLookupSource) Open(ctx api.StreamContext) error {
 	ctx.GetLogger().Debugf("Opening sql lookup source")
-	db, err := util.FetchDBToOneNode(util.GlobalPool, s.driver, s.dsn)
+	db, err := util.FetchDBToOneNode(util.GlobalPool, s.url)
 	if err != nil {
 		return fmt.Errorf("connection to %s Open with error %v", s.url, err)
 	}
@@ -63,12 +60,6 @@ func (s *sqlLookupSource) Configure(datasource string, props map[string]interfac
 	}
 	s.url = cfg.Url
 	s.table = datasource
-	driver, dsn, err := util.ParseDBUrl(s.url)
-	if err != nil {
-		return err
-	}
-	s.driver = driver
-	s.dsn = dsn
 	return nil
 }
 
@@ -128,7 +119,7 @@ func (s *sqlLookupSource) Close(ctx api.StreamContext) error {
 	ctx.GetLogger().Debugf("Closing sql lookup source")
 	defer func() { s.db = nil }()
 	if s.db != nil {
-		return util.ReturnDBFromOneNode(util.GlobalPool, s.driver, s.dsn)
+		return util.ReturnDBFromOneNode(util.GlobalPool, s.url)
 	}
 	return nil
 }

+ 0 - 35
extensions/util/dburl.go

@@ -1,35 +0,0 @@
-// Copyright 2023 EMQ Technologies Co., Ltd.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package util
-
-import (
-	"database/sql"
-	"strings"
-
-	"github.com/xo/dburl"
-)
-
-// Open returns *sql.DB from urlstr
-// As we use modernc.org/sqlite with `sqlite` as driver name and dburl use `sqlite3` as driver name, we need to fix it before open sql.DB
-func Open(urlstr string) (*sql.DB, error) {
-	u, err := dburl.Parse(urlstr)
-	if err != nil {
-		return nil, err
-	}
-	if strings.ToLower(u.Driver) == "sqlite3" {
-		u.Driver = "sqlite"
-	}
-	return sql.Open(u.Driver, u.DSN)
-}

+ 38 - 69
extensions/util/pool.go

@@ -23,69 +23,68 @@ import (
 	"github.com/xo/dburl"
 )
 
-var GlobalPool *driverPool
+var GlobalPool *dbPool
 
 func init() {
 	// GlobalPool maintained the *sql.DB group by the driver and DSN.
 	// Multiple sql sources/sinks can directly fetch the `*sql.DB` from the GlobalPool and return it back when they don't need it anymore.
 	// As multiple sql sources/sinks share the same `*sql.DB`, we can directly control the total count of connections by using `SetMaxOpenConns`
-	GlobalPool = newDriverPool()
+	GlobalPool = newDBPool()
 }
 
-type driverPool struct {
+type dbPool struct {
 	isTesting bool
 
 	sync.RWMutex
-	pool map[string]*dbPool
+	// url -> *sql.DB
+	pool map[string]*sql.DB
+	// url -> connection count
+	connections map[string]int
 }
 
-func newDriverPool() *driverPool {
-	return &driverPool{
-		pool: map[string]*dbPool{},
+func newDBPool() *dbPool {
+	return &dbPool{
+		pool:        map[string]*sql.DB{},
+		connections: map[string]int{},
 	}
 }
 
-type dbPool struct {
-	isTesting bool
-	driver    string
-
-	sync.RWMutex
-	pool        map[string]*sql.DB
-	connections map[string]int
-}
-
-func (dp *dbPool) getDBConnCount(dsn string) int {
+func (dp *dbPool) getDBConnCount(url string) int {
 	dp.RLock()
 	defer dp.RUnlock()
-	count, ok := dp.connections[dsn]
+	count, ok := dp.connections[url]
 	if ok {
 		return count
 	}
 	return 0
 }
 
-func (dp *dbPool) getOrCreate(dsn string) (*sql.DB, error) {
+func (dp *dbPool) getOrCreate(url string) (*sql.DB, error) {
 	dp.Lock()
 	defer dp.Unlock()
-	db, ok := dp.pool[dsn]
+	db, ok := dp.pool[url]
 	if ok {
-		dp.connections[dsn] = dp.connections[dsn] + 1
+		dp.connections[url] = dp.connections[url] + 1
 		return db, nil
 	}
-	newDb, err := openDB(dp.driver, dsn, dp.isTesting)
+	newDb, err := openDB(url, dp.isTesting)
 	if err != nil {
 		return nil, err
 	}
-	conf.Log.Debugf("create new database instance: %v", dsn)
-	dp.pool[dsn] = newDb
-	dp.connections[dsn] = 1
+	conf.Log.Debugf("create new database instance: %v", url)
+	dp.pool[url] = newDb
+	dp.connections[url] = 1
 	return newDb, nil
 }
 
-func openDB(driver, dsn string, isTesting bool) (*sql.DB, error) {
+func openDB(url string, isTesting bool) (*sql.DB, error) {
 	if isTesting {
 		return nil, nil
 	}
+	driver, dsn, err := ParseDBUrl(url)
+	if err != nil {
+		return nil, err
+	}
 	db, err := sql.Open(driver, dsn)
 	if err != nil {
 		return nil, err
@@ -97,53 +96,29 @@ func openDB(driver, dsn string, isTesting bool) (*sql.DB, error) {
 	return db, nil
 }
 
-func (dp *dbPool) closeOneConn(dsn string) error {
+func (dp *dbPool) closeOneConn(url string) error {
 	dp.Lock()
 	defer dp.Unlock()
-	connCount, ok := dp.connections[dsn]
+	connCount, ok := dp.connections[url]
 	if !ok {
 		return nil
 	}
 	connCount--
 	if connCount > 0 {
-		dp.connections[dsn] = connCount
+		dp.connections[url] = connCount
 		return nil
 	}
-	conf.Log.Debugf("drop database instance: %v", dsn)
-	db := dp.pool[dsn]
+	conf.Log.Debugf("drop database instance: %v", url)
+	db := dp.pool[url]
 	// remove db instance from map in order to avoid memory leak
-	delete(dp.pool, dsn)
-	delete(dp.connections, dsn)
+	delete(dp.pool, url)
+	delete(dp.connections, url)
 	if dp.isTesting {
 		return nil
 	}
 	return db.Close()
 }
 
-func (dp *driverPool) getOrCreate(driver string) *dbPool {
-	dp.Lock()
-	defer dp.Unlock()
-	db, ok := dp.pool[driver]
-	if ok {
-		return db
-	}
-	newDB := &dbPool{
-		isTesting:   dp.isTesting,
-		driver:      driver,
-		pool:        map[string]*sql.DB{},
-		connections: map[string]int{},
-	}
-	dp.pool[driver] = newDB
-	return newDB
-}
-
-func (dp *driverPool) get(driver string) (*dbPool, bool) {
-	dp.RLock()
-	defer dp.RUnlock()
-	dbPool, ok := dp.pool[driver]
-	return dbPool, ok
-}
-
 func ParseDBUrl(urlstr string) (string, string, error) {
 	u, err := dburl.Parse(urlstr)
 	if err != nil {
@@ -157,20 +132,14 @@ func ParseDBUrl(urlstr string) (string, string, error) {
 	return u.Driver, u.DSN, nil
 }
 
-func FetchDBToOneNode(driverPool *driverPool, driver, dsn string) (*sql.DB, error) {
-	dbPool := driverPool.getOrCreate(driver)
-	return dbPool.getOrCreate(dsn)
+func FetchDBToOneNode(pool *dbPool, url string) (*sql.DB, error) {
+	return pool.getOrCreate(url)
 }
 
-func ReturnDBFromOneNode(driverPool *driverPool, driver, dsn string) error {
-	dbPool, ok := driverPool.get(driver)
-	if !ok {
-		return nil
-	}
-	return dbPool.closeOneConn(dsn)
+func ReturnDBFromOneNode(pool *dbPool, url string) error {
+	return pool.closeOneConn(url)
 }
 
-func getDBConnCount(driverPool *driverPool, driver, dsn string) int {
-	dbPool := driverPool.getOrCreate(driver)
-	return dbPool.getDBConnCount(dsn)
+func getDBConnCount(pool *dbPool, url string) int {
+	return pool.getDBConnCount(url)
 }

+ 6 - 7
extensions/util/pool_test.go

@@ -20,9 +20,8 @@ import (
 )
 
 func TestDriverPool(t *testing.T) {
-	driver := "mysql"
-	dsn := "root@127.0.0.1:4000/mock"
-	testPool := newDriverPool()
+	url := "mock"
+	testPool := newDBPool()
 	testPool.isTesting = true
 
 	expCount := 3
@@ -33,14 +32,14 @@ func TestDriverPool(t *testing.T) {
 			defer func() {
 				wg.Done()
 			}()
-			_, err := FetchDBToOneNode(testPool, driver, dsn)
+			_, err := FetchDBToOneNode(testPool, url)
 			if err != nil {
 				t.Errorf("meet unexpected err:%v", err)
 			}
 		}()
 	}
 	wg.Wait()
-	count := getDBConnCount(testPool, driver, dsn)
+	count := getDBConnCount(testPool, url)
 	if expCount != count {
 		t.Errorf("expect conn count:%v, got:%v", expCount, count)
 	}
@@ -51,14 +50,14 @@ func TestDriverPool(t *testing.T) {
 			defer func() {
 				wg.Done()
 			}()
-			err := ReturnDBFromOneNode(testPool, driver, dsn)
+			err := ReturnDBFromOneNode(testPool, url)
 			if err != nil {
 				t.Errorf("meet unexpected err:%v", err)
 			}
 		}()
 	}
 	wg.Wait()
-	count = getDBConnCount(testPool, driver, dsn)
+	count = getDBConnCount(testPool, url)
 	if count != 0 {
 		t.Errorf("expect conn count:%v, got:%v", 0, count)
 	}