Browse Source

refactor(compression): make compression plugable module

Signed-off-by: Jiyong Huang <huangjy@emqx.io>
Jiyong Huang 1 year ago
parent
commit
d37e78c22e

+ 18 - 17
internal/compressor/compressor.go

@@ -16,28 +16,29 @@ package compressor
 
 import (
 	"fmt"
+	"io"
 
 	"github.com/lf-edge/ekuiper/pkg/message"
 )
 
-const (
-	ZLIB  = "zlib"
-	GZIP  = "gzip"
-	FLATE = "flate"
-	ZSTD  = "zstd"
-)
+type CompressorInstantiator func(name string) (message.Compressor, error)
+
+var compressors = map[string]CompressorInstantiator{}
 
 func GetCompressor(name string) (message.Compressor, error) {
-	switch name {
-	case ZLIB:
-		return newZlibCompressor()
-	case GZIP:
-		return newGzipCompressor()
-	case FLATE:
-		return newFlateCompressor()
-	case ZSTD:
-		return newZstdCompressor()
-	default:
-		return nil, fmt.Errorf("unsupported compressor: %s", name)
+	if instantiator, ok := compressors[name]; ok {
+		return instantiator(name)
+	}
+	return nil, fmt.Errorf("unsupported compressor: %s", name)
+}
+
+type CompressWriterIns func(reader io.Writer) (io.Writer, error)
+
+var compressWriters = map[string]CompressWriterIns{}
+
+func GetCompressWriter(name string, writer io.Writer) (io.Writer, error) {
+	if instantiator, ok := compressWriters[name]; ok {
+		return instantiator(writer)
 	}
+	return nil, fmt.Errorf("unsupported compressor for file: %s", name)
 }

+ 19 - 11
internal/compressor/decompressor.go

@@ -16,21 +16,29 @@ package compressor
 
 import (
 	"fmt"
+	"io"
 
 	"github.com/lf-edge/ekuiper/pkg/message"
 )
 
+type DecompressorInstantiator func(name string) (message.Decompressor, error)
+
+var decompressors = map[string]DecompressorInstantiator{}
+
 func GetDecompressor(name string) (message.Decompressor, error) {
-	switch name {
-	case ZLIB:
-		return newZlibDecompressor()
-	case GZIP:
-		return newGzipDecompressor()
-	case FLATE:
-		return newFlateDecompressor()
-	case ZSTD:
-		return newzstdDecompressor()
-	default:
-		return nil, fmt.Errorf("unsupported decompressor: %s", name)
+	if instantiator, ok := decompressors[name]; ok {
+		return instantiator(name)
+	}
+	return nil, fmt.Errorf("unsupported decompressor: %s", name)
+}
+
+type DecompressReaderIns func(reader io.Reader) (io.ReadCloser, error)
+
+var decompressReaders = map[string]DecompressReaderIns{}
+
+func GetDecompressReader(name string, reader io.Reader) (io.ReadCloser, error) {
+	if instantiator, ok := decompressReaders[name]; ok {
+		return instantiator(reader)
 	}
+	return nil, fmt.Errorf("unsupported decompressor for file: %s", name)
 }

+ 50 - 0
internal/compressor/ext_compressor.go

@@ -0,0 +1,50 @@
+// Copyright 2023 EMQ Technologies Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build compression || !core
+
+package compressor
+
+import (
+	"github.com/lf-edge/ekuiper/internal/compressor/flate"
+	"github.com/lf-edge/ekuiper/internal/compressor/gzip"
+	"github.com/lf-edge/ekuiper/internal/compressor/zlib"
+	"github.com/lf-edge/ekuiper/internal/compressor/zstd"
+	"github.com/lf-edge/ekuiper/pkg/message"
+)
+
+const (
+	ZLIB  = "zlib"
+	GZIP  = "gzip"
+	FLATE = "flate"
+	ZSTD  = "zstd"
+)
+
+func init() {
+	compressors[ZLIB] = func(name string) (message.Compressor, error) {
+		return zlib.NewZlibCompressor()
+	}
+	compressors[GZIP] = func(name string) (message.Compressor, error) {
+		return gzip.NewGzipCompressor()
+	}
+	compressors[FLATE] = func(name string) (message.Compressor, error) {
+		return flate.NewFlateCompressor()
+	}
+	compressors[ZSTD] = func(name string) (message.Compressor, error) {
+		return zstd.NewZstdCompressor()
+	}
+
+	compressWriters[GZIP] = gzip.NewWriter
+	compressWriters[ZSTD] = zstd.NewWriter
+}

+ 43 - 0
internal/compressor/ext_decompressor.go

@@ -0,0 +1,43 @@
+// Copyright 2023 EMQ Technologies Co., Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build compression || !core
+
+package compressor
+
+import (
+	"github.com/lf-edge/ekuiper/internal/compressor/flate"
+	"github.com/lf-edge/ekuiper/internal/compressor/gzip"
+	"github.com/lf-edge/ekuiper/internal/compressor/zlib"
+	"github.com/lf-edge/ekuiper/internal/compressor/zstd"
+	"github.com/lf-edge/ekuiper/pkg/message"
+)
+
+func init() {
+	decompressors[ZLIB] = func(name string) (message.Decompressor, error) {
+		return zlib.NewZlibDecompressor()
+	}
+	decompressors[GZIP] = func(name string) (message.Decompressor, error) {
+		return gzip.NewGzipDecompressor()
+	}
+	decompressors[FLATE] = func(name string) (message.Decompressor, error) {
+		return flate.NewFlateDecompressor()
+	}
+	decompressors[ZSTD] = func(name string) (message.Decompressor, error) {
+		return zstd.NewzstdDecompressor()
+	}
+
+	decompressReaders[GZIP] = gzip.NewReader
+	decompressReaders[ZSTD] = zstd.NewReader
+}

+ 4 - 4
internal/compressor/flate.go

@@ -1,4 +1,4 @@
-// Copyright 2023 carlclone@gmail.com.
+// Copyright 2023-2023 EMQ Technologies Co., Ltd.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package compressor
+package flate
 
 import (
 	"bytes"
@@ -24,7 +24,7 @@ import (
 	"github.com/lf-edge/ekuiper/internal/conf"
 )
 
-func newFlateCompressor() (*flateCompressor, error) {
+func NewFlateCompressor() (*flateCompressor, error) {
 	flateWriter, err := flate.NewWriter(nil, flate.DefaultCompression)
 	if err != nil {
 		return nil, err
@@ -53,7 +53,7 @@ func (g *flateCompressor) Compress(data []byte) ([]byte, error) {
 	return g.buffer.Bytes(), nil
 }
 
-func newFlateDecompressor() (*flateDecompressor, error) {
+func NewFlateDecompressor() (*flateDecompressor, error) {
 	return &flateDecompressor{reader: flate.NewReader(bytes.NewReader(nil))}, nil
 }
 

+ 12 - 4
internal/compressor/gzip.go

@@ -1,4 +1,4 @@
-// Copyright 2023 carlclone@gmail.com.
+// Copyright 2023-2023 EMQ Technologies Co., Ltd.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package compressor
+package gzip
 
 import (
 	"bytes"
@@ -24,7 +24,7 @@ import (
 	"github.com/lf-edge/ekuiper/internal/conf"
 )
 
-func newGzipCompressor() (*gzipCompressor, error) {
+func NewGzipCompressor() (*gzipCompressor, error) {
 	return &gzipCompressor{
 		writer: gzip.NewWriter(nil),
 	}, nil
@@ -49,7 +49,7 @@ func (g *gzipCompressor) Compress(data []byte) ([]byte, error) {
 	return g.buffer.Bytes(), nil
 }
 
-func newGzipDecompressor() (*gzipDecompressor, error) {
+func NewGzipDecompressor() (*gzipDecompressor, error) {
 	return &gzipDecompressor{}, nil
 }
 
@@ -78,3 +78,11 @@ func (z *gzipDecompressor) Decompress(data []byte) ([]byte, error) {
 	}()
 	return io.ReadAll(z.reader)
 }
+
+func NewReader(r io.Reader) (io.ReadCloser, error) {
+	return gzip.NewReader(r)
+}
+
+func NewWriter(w io.Writer) (io.Writer, error) {
+	return gzip.NewWriter(w), nil
+}

+ 4 - 4
internal/compressor/zlib.go

@@ -1,4 +1,4 @@
-// Copyright 2023 carlclone@gmail.com.
+// Copyright 2023-2023 EMQ Technologies Co., Ltd.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package compressor
+package zlib
 
 import (
 	"bytes"
@@ -24,7 +24,7 @@ import (
 	"github.com/lf-edge/ekuiper/internal/conf"
 )
 
-func newZlibCompressor() (*zlibCompressor, error) {
+func NewZlibCompressor() (*zlibCompressor, error) {
 	return &zlibCompressor{
 		writer: zlib.NewWriter(nil),
 	}, nil
@@ -49,7 +49,7 @@ func (z *zlibCompressor) Compress(data []byte) ([]byte, error) {
 	return z.buffer.Bytes(), nil
 }
 
-func newZlibDecompressor() (*zlibDecompressor, error) {
+func NewZlibDecompressor() (*zlibDecompressor, error) {
 	return &zlibDecompressor{}, nil
 }
 

+ 17 - 4
internal/compressor/zstd.go

@@ -1,4 +1,4 @@
-// Copyright 2023 carlclone@gmail.com.
+// Copyright 2023-2023 EMQ Technologies Co., Ltd.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,15 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package compressor
+package zstd
 
 import (
 	"bytes"
+	"io"
 
 	"github.com/klauspost/compress/zstd"
 )
 
-func newZstdCompressor() (*zstdCompressor, error) {
+func NewZstdCompressor() (*zstdCompressor, error) {
 	zstdWriter, err := zstd.NewWriter(nil)
 	if err != nil {
 		return nil, err
@@ -49,7 +50,7 @@ func (g *zstdCompressor) Compress(data []byte) ([]byte, error) {
 	return g.buffer.Bytes(), nil
 }
 
-func newzstdDecompressor() (*zstdDecompressor, error) {
+func NewzstdDecompressor() (*zstdDecompressor, error) {
 	r, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(0))
 	if err != nil {
 		return nil, err
@@ -64,3 +65,15 @@ type zstdDecompressor struct {
 func (z *zstdDecompressor) Decompress(data []byte) ([]byte, error) {
 	return z.decoder.DecodeAll(data, nil)
 }
+
+func NewReader(r io.Reader) (io.ReadCloser, error) {
+	result, err := zstd.NewReader(r)
+	if err != nil {
+		return nil, err
+	}
+	return result.IOReadCloser(), nil
+}
+
+func NewWriter(w io.Writer) (io.Writer, error) {
+	return zstd.NewWriter(w)
+}

+ 5 - 15
internal/io/file/file_source.go

@@ -27,9 +27,7 @@ import (
 	"strings"
 	"time"
 
-	"github.com/klauspost/compress/gzip"
-	"github.com/klauspost/compress/zstd"
-
+	"github.com/lf-edge/ekuiper/internal/compressor"
 	"github.com/lf-edge/ekuiper/internal/conf"
 	"github.com/lf-edge/ekuiper/internal/xsql"
 	"github.com/lf-edge/ekuiper/pkg/api"
@@ -356,22 +354,14 @@ func (fs *FileSource) prepareFile(ctx api.StreamContext, file string) (io.Reader
 		ctx.GetLogger().Error(err)
 		return nil, err
 	}
-	var reader io.ReadCloser
 
-	switch fs.config.Decompression {
-	case GZIP:
-		newReader, err := gzip.NewReader(f)
-		if err != nil {
-			return nil, err
-		}
-		reader = newReader
-	case ZSTD:
-		newReader, err := zstd.NewReader(f)
+	var reader io.ReadCloser
+	if fs.config.Decompression != "" {
+		reader, err = compressor.GetDecompressReader(fs.config.Decompression, f)
 		if err != nil {
 			return nil, err
 		}
-		reader = newReader.IOReadCloser()
-	default:
+	} else {
 		reader = f
 	}
 

+ 6 - 14
internal/io/file/file_writer.go

@@ -21,9 +21,7 @@ import (
 	"os"
 	"time"
 
-	"github.com/klauspost/compress/gzip"
-	"github.com/klauspost/compress/zstd"
-
+	"github.com/lf-edge/ekuiper/internal/compressor"
 	"github.com/lf-edge/ekuiper/internal/conf"
 	"github.com/lf-edge/ekuiper/pkg/api"
 )
@@ -73,21 +71,15 @@ func createFileWriter(ctx api.StreamContext, fn string, ft FileType, headers str
 
 	fws.Compress = compressAlgorithm
 
-	switch compressAlgorithm {
-	case GZIP:
-		fws.fileBuffer = bufio.NewWriter(f)
-		fws.Writer = gzip.NewWriter(fws.fileBuffer)
-	case ZSTD:
+	if compressAlgorithm == "" {
+		fws.Writer = bufio.NewWriter(f)
+	} else {
 		fws.fileBuffer = bufio.NewWriter(f)
-		enc, err := zstd.NewWriter(fws.fileBuffer)
+		fws.Writer, err = compressor.GetCompressWriter(compressAlgorithm, fws.fileBuffer)
 		if err != nil {
-			return nil, err
+			return nil, fmt.Errorf("fail to get compress writer for %s: %v", compressAlgorithm, err)
 		}
-		fws.Writer = enc
-	default:
-		fws.Writer = bufio.NewWriter(f)
 	}
-
 	_, err = fws.Writer.Write(fws.Hook.Header())
 	if err != nil {
 		return nil, err