123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- // Copyright 2022 EMQ Technologies Co., Ltd.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package main
- import (
- "fmt"
- "path/filepath"
- "sync"
- "github.com/mattn/go-tflite"
- "github.com/lf-edge/ekuiper/internal/conf"
- )
- var ipManager *interpreterManager
- func init() {
- path, err := conf.GetDataLoc()
- if err != nil {
- panic(err)
- }
- ipManager = &interpreterManager{
- registry: make(map[string]*tflite.Interpreter),
- path: filepath.Join(path, "uploads"),
- }
- }
- type interpreterManager struct {
- sync.Mutex
- registry map[string]*tflite.Interpreter
- path string
- }
- func (m *interpreterManager) GetOrCreate(name string) (*tflite.Interpreter, error) {
- m.Lock()
- defer m.Unlock()
- ip, ok := m.registry[name]
- if !ok {
- mf := filepath.Join(m.path, name+".tflite")
- model := tflite.NewModelFromFile(mf)
- if model == nil {
- return nil, fmt.Errorf("fail to load model: %s", mf)
- }
- defer model.Delete()
- options := tflite.NewInterpreterOptions()
- options.SetNumThread(4)
- options.SetErrorReporter(func(msg string, user_data interface{}) {
- fmt.Println(msg)
- }, nil)
- defer options.Delete()
- ip = tflite.NewInterpreter(model, options)
- status := ip.AllocateTensors()
- if status != tflite.OK {
- ip.Delete()
- return nil, fmt.Errorf("allocate failed: %v", status)
- }
- m.registry[name] = ip
- }
- return ip, nil
- }
|