labelImage.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. // Copyright 2021 EMQ Technologies Co., Ltd.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //go:build tflite
  15. // +build tflite
  16. package main
  17. import (
  18. "bufio"
  19. "bytes"
  20. "fmt"
  21. "github.com/lf-edge/ekuiper/pkg/api"
  22. tflite "github.com/mattn/go-tflite"
  23. "github.com/nfnt/resize"
  24. "image"
  25. _ "image/jpeg"
  26. _ "image/png"
  27. "os"
  28. "path"
  29. "sort"
  30. "sync"
  31. )
  32. type labelImage struct {
  33. modelPath string
  34. labelPath string
  35. once sync.Once
  36. interpreter *tflite.Interpreter
  37. labels []string
  38. }
  39. func (f *labelImage) Validate(args []interface{}) error {
  40. if len(args) != 1 {
  41. return fmt.Errorf("labelImage function only supports 1 parameter but got %d", len(args))
  42. }
  43. return nil
  44. }
  45. func (f *labelImage) Exec(args []interface{}, ctx api.FunctionContext) (interface{}, bool) {
  46. arg0, ok := args[0].([]byte)
  47. if !ok {
  48. return fmt.Errorf("labelImage function parameter must be a bytea, but got %[1]T(%[1]v)", args[0]), false
  49. }
  50. img, _, err := image.Decode(bytes.NewReader(arg0))
  51. if err != nil {
  52. return err, false
  53. }
  54. var outerErr error
  55. f.once.Do(func() {
  56. ploc := path.Join(ctx.GetRootPath(), "etc", "functions")
  57. f.labels, err = loadLabels(path.Join(ploc, f.labelPath))
  58. if err != nil {
  59. outerErr = fmt.Errorf("fail to load labels: %s", err)
  60. return
  61. }
  62. model := tflite.NewModelFromFile(path.Join(ploc, f.modelPath))
  63. if model == nil {
  64. outerErr = fmt.Errorf("fail to load model: %s", err)
  65. return
  66. }
  67. defer model.Delete()
  68. options := tflite.NewInterpreterOptions()
  69. options.SetNumThread(4)
  70. options.SetErrorReporter(func(msg string, user_data interface{}) {
  71. fmt.Println(msg)
  72. }, nil)
  73. defer options.Delete()
  74. interpreter := tflite.NewInterpreter(model, options)
  75. if interpreter == nil {
  76. outerErr = fmt.Errorf("cannot create interpreter")
  77. return
  78. }
  79. status := interpreter.AllocateTensors()
  80. if status != tflite.OK {
  81. outerErr = fmt.Errorf("allocate failed")
  82. interpreter.Delete()
  83. return
  84. }
  85. f.interpreter = interpreter
  86. // TODO If created, the interpreter will be kept through the whole life of kuiper. Refactor this later.
  87. //defer interpreter.Delete()
  88. })
  89. if f.interpreter == nil {
  90. return fmt.Errorf("fail to load model %s %s", f.modelPath, outerErr), false
  91. }
  92. input := f.interpreter.GetInputTensor(0)
  93. wantedHeight := input.Dim(1)
  94. wantedWidth := input.Dim(2)
  95. wantedChannels := input.Dim(3)
  96. wantedType := input.Type()
  97. resized := resize.Resize(uint(wantedWidth), uint(wantedHeight), img, resize.NearestNeighbor)
  98. bounds := resized.Bounds()
  99. dx, dy := bounds.Dx(), bounds.Dy()
  100. if wantedType == tflite.UInt8 {
  101. bb := make([]byte, dx*dy*wantedChannels)
  102. for y := 0; y < dy; y++ {
  103. for x := 0; x < dx; x++ {
  104. col := resized.At(x, y)
  105. r, g, b, _ := col.RGBA()
  106. bb[(y*dx+x)*3+0] = byte(float64(r) / 255.0)
  107. bb[(y*dx+x)*3+1] = byte(float64(g) / 255.0)
  108. bb[(y*dx+x)*3+2] = byte(float64(b) / 255.0)
  109. }
  110. }
  111. input.CopyFromBuffer(bb)
  112. } else {
  113. return fmt.Errorf("is not wanted type"), false
  114. }
  115. status := f.interpreter.Invoke()
  116. if status != tflite.OK {
  117. return fmt.Errorf("invoke failed"), false
  118. }
  119. output := f.interpreter.GetOutputTensor(0)
  120. outputSize := output.Dim(output.NumDims() - 1)
  121. b := make([]byte, outputSize)
  122. type result struct {
  123. score float64
  124. index int
  125. }
  126. status = output.CopyToBuffer(&b[0])
  127. if status != tflite.OK {
  128. return fmt.Errorf("output failed"), false
  129. }
  130. var results []result
  131. for i := 0; i < outputSize; i++ {
  132. score := float64(b[i]) / 255.0
  133. if score < 0.2 {
  134. continue
  135. }
  136. results = append(results, result{score: score, index: i})
  137. }
  138. sort.Slice(results, func(i, j int) bool {
  139. return results[i].score > results[j].score
  140. })
  141. // output is the biggest score labelImage
  142. if len(results) > 0 {
  143. return f.labels[results[0].index], true
  144. } else {
  145. return "", true
  146. }
  147. }
  148. func (f *labelImage) IsAggregate() bool {
  149. return false
  150. }
  // loadLabels reads the file at filename and returns its lines in order, one
  // label per line. Blank lines are kept, so the position of a label in the
  // result equals its line index in the file.
  //
  // It returns an error when the file cannot be opened or when scanning stops
  // early (read failure, or a line longer than bufio.Scanner accepts), instead
  // of silently returning a truncated list.
  func loadLabels(filename string) ([]string, error) {
  	f, err := os.Open(filename)
  	if err != nil {
  		return nil, err
  	}
  	defer f.Close()

  	// Non-nil even for an empty file, as callers have always received.
  	labels := []string{}
  	scanner := bufio.NewScanner(f)
  	for scanner.Scan() {
  		labels = append(labels, scanner.Text())
  	}
  	// Scan returns false on both EOF and failure; only Err tells them apart.
  	if err := scanner.Err(); err != nil {
  		return nil, fmt.Errorf("read labels from %s: %w", filename, err)
  	}
  	return labels, nil
  }
  164. var LabelImage = labelImage{
  165. modelPath: "labelImage/mobilenet_quant_v1_224.tflite",
  166. labelPath: "labelImage/labels.txt",
  167. }