Browse Source

fix: fix oracle query with limit (#1988)

* support oracle limit

Signed-off-by: yisaer <disxiaofei@163.com>

* support oracle limit

Signed-off-by: yisaer <disxiaofei@163.com>

* support oracle limit

Signed-off-by: yisaer <disxiaofei@163.com>

* fix lint

Signed-off-by: yisaer <disxiaofei@163.com>

---------

Signed-off-by: yisaer <disxiaofei@163.com>
Song Gao 1 year ago
parent
commit
04bc98cbab

+ 29 - 1
extensions/sqldatabase/sqlgen/commonSqlDialect.go

@@ -1,4 +1,4 @@
-// Copyright 2022 EMQ Technologies Co., Ltd.
+// Copyright 2022-2023 EMQ Technologies Co., Ltd.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -93,3 +93,31 @@ func (q *CommonQueryGenerator) UpdateMaxIndexValue(row map[string]interface{}) {
 		q.IndexValue = v
 	}
 }
+
+type OracleQueryGenerate struct {
+	*CommonQueryGenerator
+}
+
+func NewOracleQueryGenerate(cfg *InternalSqlQueryCfg) SqlQueryGenerator {
+	return &OracleQueryGenerate{
+		CommonQueryGenerator: &CommonQueryGenerator{
+			InternalSqlQueryCfg: cfg,
+		},
+	}
+}
+
+func (q *OracleQueryGenerate) SqlQueryStatement() (string, error) {
+	con, err := q.getCondition()
+	if err != nil {
+		return "", err
+	}
+	query := q.getSelect() + con + q.getOrderby()
+	if q.Limit != 0 {
+		return fmt.Sprintf("select * from (%s) where rownum <= %v", query, q.Limit), nil
+	}
+	return query, nil
+}
+
+func (q *OracleQueryGenerate) UpdateMaxIndexValue(row map[string]interface{}) {
+	q.CommonQueryGenerator.UpdateMaxIndexValue(row)
+}

+ 13 - 1
extensions/sqldatabase/sqlgen/sqlServerDialect_test.go

@@ -1,4 +1,4 @@
-// Copyright 2022 EMQ Technologies Co., Ltd.
+// Copyright 2022-2023 EMQ Technologies Co., Ltd.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -17,6 +17,8 @@ package sqlgen
 import (
 	"reflect"
 	"testing"
+
+	"github.com/stretchr/testify/require"
 )
 
 func TestQueryGenerator_SqlQueryStatement(t *testing.T) {
@@ -80,6 +82,16 @@ func TestQueryGenerator_SqlQueryStatement(t *testing.T) {
 	}
 }
 
+func TestOracleQuery(t *testing.T) {
+	s := NewOracleQueryGenerate(&InternalSqlQueryCfg{
+		Table: "t",
+		Limit: 1,
+	})
+	query, err := s.SqlQueryStatement()
+	require.NoError(t, err)
+	require.Equal(t, query, "select * from (select * from t ) where rownum <= 1")
+}
+
 func TestInternalQuery(t *testing.T) {
 	s := NewSqlServerQuery(&InternalSqlQueryCfg{
 		Table:      "table",

+ 2 - 0
extensions/sqldatabase/sqlgen/sqlgen.go

@@ -121,6 +121,8 @@ func GetQueryGenerator(driver string, props map[string]interface{}) (SqlQueryGen
 	switch driver {
 	case "sqlserver":
 		return NewSqlServerQuery(cfg.InternalSqlQueryCfg), nil
+	case "godror", "oracle":
+		return NewOracleQueryGenerate(cfg.InternalSqlQueryCfg), nil
 	default:
 		return NewCommonSqlQuery(cfg.InternalSqlQueryCfg), nil
 	}