fieldConverterSingleton.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. // Copyright 2022-2023 EMQ Technologies Co., Ltd.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package protobuf
  15. import (
  16. "fmt"
  17. // TODO: replace with `google.golang.org/protobuf/proto` pkg.
  18. "github.com/golang/protobuf/proto" //nolint:staticcheck
  19. dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
  20. "github.com/jhump/protoreflect/desc"
  21. "github.com/jhump/protoreflect/dynamic"
  22. "github.com/lf-edge/ekuiper/pkg/cast"
  23. )
  24. const (
  25. WrapperBool = "google.protobuf.BoolValue"
  26. WrapperBytes = "google.protobuf.BytesValue"
  27. WrapperDouble = "google.protobuf.DoubleValue"
  28. WrapperFloat = "google.protobuf.FloatValue"
  29. WrapperInt32 = "google.protobuf.Int32Value"
  30. WrapperInt64 = "google.protobuf.Int64Value"
  31. WrapperString = "google.protobuf.StringValue"
  32. WrapperUInt32 = "google.protobuf.UInt32Value"
  33. WrapperUInt64 = "google.protobuf.UInt64Value"
  34. WrapperVoid = "google.protobuf.EMPTY"
  35. )
  36. var WRAPPER_TYPES = map[string]struct{}{
  37. WrapperBool: {},
  38. WrapperBytes: {},
  39. WrapperDouble: {},
  40. WrapperFloat: {},
  41. WrapperInt32: {},
  42. WrapperInt64: {},
  43. WrapperString: {},
  44. WrapperUInt32: {},
  45. WrapperUInt64: {},
  46. }
  47. var (
  48. fieldConverterIns = &FieldConverter{}
  49. mf = dynamic.NewMessageFactoryWithDefaults()
  50. )
  51. type FieldConverter struct{}
  52. func GetFieldConverter() *FieldConverter {
  53. return fieldConverterIns
  54. }
  55. func (fc *FieldConverter) encodeMap(im *desc.MessageDescriptor, i interface{}) (*dynamic.Message, error) {
  56. result := mf.NewDynamicMessage(im)
  57. fields := im.GetFields()
  58. if m, ok := i.(map[string]interface{}); ok {
  59. for _, field := range fields {
  60. v, ok := m[field.GetName()]
  61. if !ok {
  62. if field.IsRequired() {
  63. return nil, fmt.Errorf("field %s not found", field.GetName())
  64. } else {
  65. v = field.GetDefaultValue()
  66. }
  67. }
  68. fv, err := fc.EncodeField(field, v)
  69. if err != nil {
  70. return nil, err
  71. }
  72. result.SetFieldByName(field.GetName(), fv)
  73. }
  74. }
  75. return result, nil
  76. }
  77. func (fc *FieldConverter) EncodeField(field *desc.FieldDescriptor, v interface{}) (interface{}, error) {
  78. fn := field.GetName()
  79. ft := field.GetType()
  80. if field.IsRepeated() {
  81. var (
  82. result interface{}
  83. err error
  84. )
  85. switch ft {
  86. case dpb.FieldDescriptorProto_TYPE_DOUBLE:
  87. result, err = cast.ToFloat64Slice(v, cast.STRICT)
  88. case dpb.FieldDescriptorProto_TYPE_FLOAT:
  89. result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
  90. r, err := cast.ToFloat32(input, sn)
  91. if err != nil {
  92. return 0, nil
  93. } else {
  94. return r, nil
  95. }
  96. }, "float", cast.STRICT)
  97. case dpb.FieldDescriptorProto_TYPE_INT32, dpb.FieldDescriptorProto_TYPE_SFIXED32, dpb.FieldDescriptorProto_TYPE_SINT32:
  98. result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
  99. r, err := cast.ToInt(input, sn)
  100. if err != nil {
  101. return 0, nil
  102. } else {
  103. return int32(r), nil
  104. }
  105. }, "int", cast.STRICT)
  106. case dpb.FieldDescriptorProto_TYPE_INT64, dpb.FieldDescriptorProto_TYPE_SFIXED64, dpb.FieldDescriptorProto_TYPE_SINT64:
  107. result, err = cast.ToInt64Slice(v, cast.STRICT)
  108. case dpb.FieldDescriptorProto_TYPE_FIXED32, dpb.FieldDescriptorProto_TYPE_UINT32:
  109. result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
  110. r, err := cast.ToUint64(input, sn)
  111. if err != nil {
  112. return 0, nil
  113. } else {
  114. return uint32(r), nil
  115. }
  116. }, "uint", cast.STRICT)
  117. case dpb.FieldDescriptorProto_TYPE_FIXED64, dpb.FieldDescriptorProto_TYPE_UINT64:
  118. result, err = cast.ToUint64Slice(v, cast.STRICT)
  119. case dpb.FieldDescriptorProto_TYPE_BOOL:
  120. result, err = cast.ToBoolSlice(v, cast.STRICT)
  121. case dpb.FieldDescriptorProto_TYPE_STRING:
  122. result, err = cast.ToStringSlice(v, cast.STRICT)
  123. case dpb.FieldDescriptorProto_TYPE_BYTES:
  124. result, err = cast.ToBytesSlice(v, cast.STRICT)
  125. case dpb.FieldDescriptorProto_TYPE_MESSAGE:
  126. result, err = cast.ToTypedSlice(v, func(input interface{}, sn cast.Strictness) (interface{}, error) {
  127. r, err := cast.ToStringMap(v)
  128. if err == nil {
  129. return fc.encodeMap(field.GetMessageType(), r)
  130. } else {
  131. return nil, fmt.Errorf("invalid type for map type field '%s': %v", fn, err)
  132. }
  133. }, "map", cast.STRICT)
  134. default:
  135. return nil, fmt.Errorf("invalid type for field '%s'", fn)
  136. }
  137. if err != nil {
  138. err = fmt.Errorf("failed to encode field '%s':%v", fn, err)
  139. }
  140. return result, err
  141. } else {
  142. return fc.encodeSingleField(field, v)
  143. }
  144. }
  145. func (fc *FieldConverter) encodeSingleField(field *desc.FieldDescriptor, v interface{}) (interface{}, error) {
  146. fn := field.GetName()
  147. switch field.GetType() {
  148. case dpb.FieldDescriptorProto_TYPE_DOUBLE:
  149. r, err := cast.ToFloat64(v, cast.STRICT)
  150. if err == nil {
  151. return r, nil
  152. } else {
  153. return nil, fmt.Errorf("invalid type for float type field '%s': %v", fn, err)
  154. }
  155. case dpb.FieldDescriptorProto_TYPE_FLOAT:
  156. r, err := cast.ToFloat32(v, cast.STRICT)
  157. if err == nil {
  158. return r, nil
  159. } else {
  160. return nil, fmt.Errorf("invalid type for float type field '%s': %v", fn, err)
  161. }
  162. case dpb.FieldDescriptorProto_TYPE_INT32, dpb.FieldDescriptorProto_TYPE_SFIXED32, dpb.FieldDescriptorProto_TYPE_SINT32, dpb.FieldDescriptorProto_TYPE_ENUM:
  163. r, err := cast.ToInt(v, cast.STRICT)
  164. if err == nil {
  165. return int32(r), nil
  166. } else {
  167. return nil, fmt.Errorf("invalid type for int type field '%s': %v", fn, err)
  168. }
  169. case dpb.FieldDescriptorProto_TYPE_INT64, dpb.FieldDescriptorProto_TYPE_SFIXED64, dpb.FieldDescriptorProto_TYPE_SINT64:
  170. r, err := cast.ToInt64(v, cast.STRICT)
  171. if err == nil {
  172. return r, nil
  173. } else {
  174. return nil, fmt.Errorf("invalid type for int type field '%s': %v", fn, err)
  175. }
  176. case dpb.FieldDescriptorProto_TYPE_FIXED32, dpb.FieldDescriptorProto_TYPE_UINT32:
  177. r, err := cast.ToUint64(v, cast.STRICT)
  178. if err == nil {
  179. return uint32(r), nil
  180. } else {
  181. return nil, fmt.Errorf("invalid type for uint type field '%s': %v", fn, err)
  182. }
  183. case dpb.FieldDescriptorProto_TYPE_FIXED64, dpb.FieldDescriptorProto_TYPE_UINT64:
  184. r, err := cast.ToUint64(v, cast.STRICT)
  185. if err == nil {
  186. return r, nil
  187. } else {
  188. return nil, fmt.Errorf("invalid type for uint type field '%s': %v", fn, err)
  189. }
  190. case dpb.FieldDescriptorProto_TYPE_BOOL:
  191. r, err := cast.ToBool(v, cast.STRICT)
  192. if err == nil {
  193. return r, nil
  194. } else {
  195. return nil, fmt.Errorf("invalid type for bool type field '%s': %v", fn, err)
  196. }
  197. case dpb.FieldDescriptorProto_TYPE_STRING:
  198. r, err := cast.ToString(v, cast.STRICT)
  199. if err == nil {
  200. return r, nil
  201. } else {
  202. return nil, fmt.Errorf("invalid type for string type field '%s': %v", fn, err)
  203. }
  204. case dpb.FieldDescriptorProto_TYPE_BYTES:
  205. r, err := cast.ToBytes(v, cast.STRICT)
  206. if err == nil {
  207. return r, nil
  208. } else {
  209. return nil, fmt.Errorf("invalid type for bytes type field '%s': %v", fn, err)
  210. }
  211. case dpb.FieldDescriptorProto_TYPE_MESSAGE:
  212. r, err := cast.ToStringMap(v)
  213. if err == nil {
  214. return fc.encodeMap(field.GetMessageType(), r)
  215. } else {
  216. return nil, fmt.Errorf("invalid type for map type field '%s': %v", fn, err)
  217. }
  218. default:
  219. return nil, fmt.Errorf("invalid type for field '%s'", fn)
  220. }
  221. }
  222. func (fc *FieldConverter) DecodeField(src interface{}, field *desc.FieldDescriptor, sn cast.Strictness) (interface{}, error) {
  223. var (
  224. r interface{}
  225. e error
  226. )
  227. fn := field.GetName()
  228. switch field.GetType() {
  229. case dpb.FieldDescriptorProto_TYPE_DOUBLE, dpb.FieldDescriptorProto_TYPE_FLOAT:
  230. if field.IsRepeated() {
  231. r, e = cast.ToFloat64Slice(src, sn)
  232. } else {
  233. r, e = cast.ToFloat64(src, sn)
  234. }
  235. case dpb.FieldDescriptorProto_TYPE_INT32, dpb.FieldDescriptorProto_TYPE_SFIXED32, dpb.FieldDescriptorProto_TYPE_SINT32, dpb.FieldDescriptorProto_TYPE_INT64, dpb.FieldDescriptorProto_TYPE_SFIXED64, dpb.FieldDescriptorProto_TYPE_SINT64, dpb.FieldDescriptorProto_TYPE_FIXED32, dpb.FieldDescriptorProto_TYPE_UINT32, dpb.FieldDescriptorProto_TYPE_FIXED64, dpb.FieldDescriptorProto_TYPE_UINT64, dpb.FieldDescriptorProto_TYPE_ENUM:
  236. if field.IsRepeated() {
  237. r, e = cast.ToInt64Slice(src, sn)
  238. } else {
  239. r, e = cast.ToInt64(src, sn)
  240. }
  241. case dpb.FieldDescriptorProto_TYPE_BOOL:
  242. if field.IsRepeated() {
  243. r, e = cast.ToBoolSlice(src, sn)
  244. } else {
  245. r, e = cast.ToBool(src, sn)
  246. }
  247. case dpb.FieldDescriptorProto_TYPE_STRING:
  248. if field.IsRepeated() {
  249. r, e = cast.ToStringSlice(src, sn)
  250. } else {
  251. r, e = cast.ToString(src, sn)
  252. }
  253. case dpb.FieldDescriptorProto_TYPE_BYTES:
  254. if field.IsRepeated() {
  255. r, e = cast.ToBytesSlice(src, sn)
  256. } else {
  257. r, e = cast.ToBytes(src, sn)
  258. }
  259. case dpb.FieldDescriptorProto_TYPE_MESSAGE:
  260. if field.IsRepeated() {
  261. r, e = cast.ToTypedSlice(src, func(input interface{}, ssn cast.Strictness) (interface{}, error) {
  262. return fc.decodeSubMessage(input, field.GetMessageType(), ssn)
  263. }, "map", sn)
  264. } else {
  265. r, e = fc.decodeSubMessage(src, field.GetMessageType(), sn)
  266. }
  267. default:
  268. return nil, fmt.Errorf("unsupported type for %s", fn)
  269. }
  270. if e != nil {
  271. e = fmt.Errorf("invalid type of return value for '%s': %v", fn, e)
  272. }
  273. return r, e
  274. }
  275. func (fc *FieldConverter) decodeSubMessage(input interface{}, ft *desc.MessageDescriptor, sn cast.Strictness) (interface{}, error) {
  276. m := map[string]interface{}{}
  277. switch v := input.(type) {
  278. case map[interface{}]interface{}:
  279. for k, val := range v {
  280. m[cast.ToStringAlways(k)] = val
  281. }
  282. return fc.DecodeMap(m, ft, sn)
  283. case map[string]interface{}:
  284. return fc.DecodeMap(v, ft, sn)
  285. case proto.Message:
  286. message, err := dynamic.AsDynamicMessage(v)
  287. if err != nil {
  288. return nil, err
  289. }
  290. return fc.DecodeMessage(message, ft), nil
  291. default:
  292. return nil, fmt.Errorf("cannot decode %[1]T(%[1]v) to map", input)
  293. }
  294. }
  295. func (fc *FieldConverter) DecodeMap(src map[string]interface{}, ft *desc.MessageDescriptor, sn cast.Strictness) (map[string]interface{}, error) {
  296. result := make(map[string]interface{})
  297. for _, field := range ft.GetFields() {
  298. val, ok := src[field.GetName()]
  299. if !ok {
  300. continue
  301. }
  302. err := fc.decodeMessageField(val, field, result, sn)
  303. if err != nil {
  304. return nil, err
  305. }
  306. }
  307. return result, nil
  308. }
  309. func (fc *FieldConverter) decodeMessageField(src interface{}, field *desc.FieldDescriptor, result map[string]interface{}, sn cast.Strictness) error {
  310. if f, err := fc.DecodeField(src, field, sn); err != nil {
  311. return err
  312. } else {
  313. result[field.GetName()] = f
  314. return nil
  315. }
  316. }
  317. func (fc *FieldConverter) DecodeMessage(message *dynamic.Message, outputType *desc.MessageDescriptor) interface{} {
  318. if _, ok := WRAPPER_TYPES[outputType.GetFullyQualifiedName()]; ok {
  319. return message.GetFieldByNumber(1)
  320. } else if WrapperVoid == outputType.GetFullyQualifiedName() {
  321. return nil
  322. }
  323. result := make(map[string]interface{})
  324. for _, field := range outputType.GetFields() {
  325. fc.decodeMessageField(message.GetField(field), field, result, cast.STRICT)
  326. }
  327. return result
  328. }