Browse Source

feat(func): function and aggregate function error message and handling

ngjaying 5 years atrás
parent
commit
03aecba15e
3 changed files with 152 additions and 90 deletions
  1. 3 1
      xsql/ast.go
  2. 122 62
      xsql/funcs_aggregate.go
  3. 27 27
      xsql/plans/project_test.go

+ 3 - 1
xsql/ast.go

@@ -933,7 +933,7 @@ func (a multiValuer) Call(name string, args []interface{}) (interface{}, bool) {
 			if v, ok := valuer.Call(name, args); ok {
 				return v, true
 			} else {
-				return fmt.Errorf("found error \"%s\" when call func %s", v, name), false
+				return fmt.Errorf("call func %s error: %v", name, v), false
 			}
 		}
 	}
@@ -968,6 +968,8 @@ func (a *multiAggregateValuer) Call(name string, args []interface{}) (interface{
 		if a, ok := valuer.(AggregateCallValuer); ok {
 			if v, ok := a.Call(name, args); ok {
 				return v, true
+			} else {
+				return fmt.Errorf("call func %s error: %v", name, v), false
 			}
 		} else if c, ok := valuer.(CallValuer); ok {
 			if singleArgs == nil {

+ 122 - 62
xsql/funcs_aggregate.go

@@ -24,14 +24,20 @@ func (v AggregateFunctionValuer) Call(name string, args []interface{}) (interfac
 		if len(arg0) > 0 {
 			v := getFirstValidArg(arg0)
 			switch v.(type) {
-			case int:
-				return sliceIntTotal(arg0) / len(arg0), true
-			case int64:
-				return sliceIntTotal(arg0) / len(arg0), true
+			case int, int64:
+				if r, err := sliceIntTotal(arg0); err != nil {
+					return err, false
+				} else {
+					return r / len(arg0), true
+				}
 			case float64:
-				return sliceFloatTotal(arg0) / float64(len(arg0)), true
+				if r, err := sliceFloatTotal(arg0); err != nil {
+					return err, false
+				} else {
+					return r / float64(len(arg0)), true
+				}
 			default:
-				return fmt.Errorf("invalid data type for avg function"), false
+				return fmt.Errorf("run avg function error: found invalid arg %[1]T(%[1]v)", v), false
 			}
 		}
 		return 0, true
@@ -44,49 +50,87 @@ func (v AggregateFunctionValuer) Call(name string, args []interface{}) (interfac
 			v := getFirstValidArg(arg0)
 			switch t := v.(type) {
 			case int:
-				return sliceIntMax(arg0, t), true
+				if r, err := sliceIntMax(arg0, t); err != nil {
+					return err, false
+				} else {
+					return r, true
+				}
 			case int64:
-				return sliceIntMax(arg0, int(t)), true
+				if r, err := sliceIntMax(arg0, int(t)); err != nil {
+					return err, false
+				} else {
+					return r, true
+				}
 			case float64:
-				return sliceFloatMax(arg0, t), true
+				if r, err := sliceFloatMax(arg0, t); err != nil {
+					return err, false
+				} else {
+					return r, true
+				}
 			case string:
-				return sliceStringMax(arg0, t), true
+				if r, err := sliceStringMax(arg0, t); err != nil {
+					return err, false
+				} else {
+					return r, true
+				}
 			default:
-				return fmt.Errorf("unsupported data type for avg function"), false
+				return fmt.Errorf("run max function error: found invalid arg %[1]T(%[1]v)", v), false
 			}
 		}
-		return fmt.Errorf("empty data for max function"), false
+		return fmt.Errorf("run max function error: empty data"), false
 	case "min":
 		arg0 := args[0].([]interface{})
 		if len(arg0) > 0 {
 			v := getFirstValidArg(arg0)
 			switch t := v.(type) {
 			case int:
-				return sliceIntMin(arg0, t), true
+				if r, err := sliceIntMin(arg0, t); err != nil {
+					return err, false
+				} else {
+					return r, true
+				}
 			case int64:
-				return sliceIntMin(arg0, int(t)), true
+				if r, err := sliceIntMin(arg0, int(t)); err != nil {
+					return err, false
+				} else {
+					return r, true
+				}
 			case float64:
-				return sliceFloatMin(arg0, t), true
+				if r, err := sliceFloatMin(arg0, t); err != nil {
+					return err, false
+				} else {
+					return r, true
+				}
 			case string:
-				return sliceStringMin(arg0, t), true
+				if r, err := sliceStringMin(arg0, t); err != nil {
+					return err, false
+				} else {
+					return r, true
+				}
 			default:
-				return fmt.Errorf("unsupported data type for avg function"), false
+				return fmt.Errorf("run min function error: found invalid arg %[1]T(%[1]v)", v), false
 			}
 		}
-		return fmt.Errorf("empty data for max function"), false
+		return fmt.Errorf("run min function error: empty data"), false
 	case "sum":
 		arg0 := args[0].([]interface{})
 		if len(arg0) > 0 {
 			v := getFirstValidArg(arg0)
 			switch v.(type) {
-			case int:
-				return sliceIntTotal(arg0), true
-			case int64:
-				return sliceIntTotal(arg0), true
+			case int, int64:
+				if r, err := sliceIntTotal(arg0); err != nil {
+					return err, false
+				} else {
+					return r, true
+				}
 			case float64:
-				return sliceFloatTotal(arg0), true
+				if r, err := sliceFloatTotal(arg0); err != nil {
+					return err, false
+				} else {
+					return r, true
+				}
 			default:
-				return fmt.Errorf("invalid data type for sum function"), false
+				return fmt.Errorf("run sum function error: found invalid arg %[1]T(%[1]v)", v), false
 			}
 		}
 		return 0, true
@@ -122,84 +166,100 @@ func getFirstValidArg(s []interface{}) interface{} {
 	return nil
 }
 
-func sliceIntTotal(s []interface{}) int {
+func sliceIntTotal(s []interface{}) (int, error) {
 	var total int
 	for _, v := range s {
-		if v, ok := v.(int); ok {
-			total += v
+		if vi, ok := v.(int); ok {
+			total += vi
+		} else {
+			return 0, fmt.Errorf("requires int but found %[1]T(%[1]v)", v)
 		}
 	}
-	return total
+	return total, nil
 }
 
-func sliceFloatTotal(s []interface{}) float64 {
+func sliceFloatTotal(s []interface{}) (float64, error) {
 	var total float64
 	for _, v := range s {
-		if v, ok := v.(float64); ok {
-			total += v
+		if vf, ok := v.(float64); ok {
+			total += vf
+		} else {
+			return 0, fmt.Errorf("requires float64 but found %[1]T(%[1]v)", v)
 		}
 	}
-	return total
+	return total, nil
 }
-func sliceIntMax(s []interface{}, max int) int {
+func sliceIntMax(s []interface{}, max int) (int, error) {
 	for _, v := range s {
-		if v, ok := v.(int); ok {
-			if max < v {
-				max = v
+		if vi, ok := v.(int); ok {
+			if max < vi {
+				max = vi
 			}
+		} else {
+			return 0, fmt.Errorf("requires int but found %[1]T(%[1]v)", v)
 		}
 	}
-	return max
+	return max, nil
 }
-func sliceFloatMax(s []interface{}, max float64) float64 {
+func sliceFloatMax(s []interface{}, max float64) (float64, error) {
 	for _, v := range s {
-		if v, ok := v.(float64); ok {
-			if max < v {
-				max = v
+		if vf, ok := v.(float64); ok {
+			if max < vf {
+				max = vf
 			}
+		} else {
+			return 0, fmt.Errorf("requires float64 but found %[1]T(%[1]v)", v)
 		}
 	}
-	return max
+	return max, nil
 }
 
-func sliceStringMax(s []interface{}, max string) string {
+func sliceStringMax(s []interface{}, max string) (string, error) {
 	for _, v := range s {
-		if v, ok := v.(string); ok {
-			if max < v {
-				max = v
+		if vs, ok := v.(string); ok {
+			if max < vs {
+				max = vs
 			}
+		} else {
+			return "", fmt.Errorf("requires string but found %[1]T(%[1]v)", v)
 		}
 	}
-	return max
+	return max, nil
 }
-func sliceIntMin(s []interface{}, min int) int {
+func sliceIntMin(s []interface{}, min int) (int, error) {
 	for _, v := range s {
-		if v, ok := v.(int); ok {
-			if min > v {
-				min = v
+		if vi, ok := v.(int); ok {
+			if min > vi {
+				min = vi
 			}
+		} else {
+			return 0, fmt.Errorf("requires int but found %[1]T(%[1]v)", v)
 		}
 	}
-	return min
+	return min, nil
 }
-func sliceFloatMin(s []interface{}, min float64) float64 {
+func sliceFloatMin(s []interface{}, min float64) (float64, error) {
 	for _, v := range s {
-		if v, ok := v.(float64); ok {
-			if min > v {
-				min = v
+		if vf, ok := v.(float64); ok {
+			if min > vf {
+				min = vf
 			}
+		} else {
+			return 0, fmt.Errorf("requires float64 but found %[1]T(%[1]v)", v)
 		}
 	}
-	return min
+	return min, nil
 }
 
-func sliceStringMin(s []interface{}, min string) string {
+func sliceStringMin(s []interface{}, min string) (string, error) {
 	for _, v := range s {
-		if v, ok := v.(string); ok {
-			if min < v {
-				min = v
+		if vs, ok := v.(string); ok {
+			if min < vs {
+				min = vs
 			}
+		} else {
+			return "", fmt.Errorf("requires string but found %[1]T(%[1]v)", v)
 		}
 	}
-	return min
+	return min, nil
 }

+ 27 - 27
xsql/plans/project_test.go

@@ -1280,7 +1280,7 @@ func TestProjectPlan_AggFuncs(t *testing.T) {
 					Tuples: []xsql.Tuple{
 						{
 							Emitter: "src1",
-							Message: xsql.Message{"b": 53},
+							Message: xsql.Message{"a": 53},
 						}, {
 							Emitter: "src1",
 							Message: xsql.Message{"a": 27},
@@ -1292,28 +1292,8 @@ func TestProjectPlan_AggFuncs(t *testing.T) {
 				},
 			},
 			result: []map[string]interface{}{{
-				"sum": float64(123150),
+				"sum": float64(123203),
 			}},
-		}, {
-			sql: "SELECT sum(a) as sum FROM test GROUP BY TumblingWindow(ss, 10)",
-			data: xsql.WindowTuplesSet{
-				xsql.WindowTuples{
-					Emitter: "test",
-					Tuples: []xsql.Tuple{
-						{
-							Emitter: "src1",
-							Message: xsql.Message{"a": "nan"},
-						}, {
-							Emitter: "src1",
-							Message: xsql.Message{"a": 27},
-						}, {
-							Emitter: "src1",
-							Message: xsql.Message{"a": 123123},
-						},
-					},
-				},
-			},
-			result: []map[string]interface{}{{}},
 		},
 	}
 
@@ -1342,7 +1322,7 @@ func TestProjectPlan_AggFuncs(t *testing.T) {
 				t.Errorf("%d. %q\n\nresult mismatch:\n\nexp=%#v\n\ngot=%#v\n\n", i, tt.sql, tt.result, mapRes)
 			}
 		} else {
-			t.Errorf("The returned result is not type of []byte\n")
+			t.Errorf("%d. %q\n\nThe returned result is not type of []byte: %#v\n", i, tt.sql, result)
 		}
 	}
 }
@@ -1383,7 +1363,7 @@ func TestProjectPlanError(t *testing.T) {
 					"a": "common string",
 				},
 			},
-			result: errors.New("run Select error: found error \"only float64 & int type are supported\" when call func round"),
+			result: errors.New("run Select error: call func round error: only float64 & int type are supported"),
 		}, {
 			sql: `SELECT round(a) as r FROM test`,
 			data: &xsql.Tuple{
@@ -1392,7 +1372,7 @@ func TestProjectPlanError(t *testing.T) {
 					"abc": "common string",
 				},
 			},
-			result: errors.New("run Select error: found error \"only float64 & int type are supported\" when call func round"),
+			result: errors.New("run Select error: call func round error: only float64 & int type are supported"),
 		}, {
 			sql: "SELECT avg(a) as avg FROM test Inner Join test1 on test.id = test1.id GROUP BY TumblingWindow(ss, 10), test1.color",
 			data: xsql.GroupedTuplesSet{
@@ -1437,7 +1417,27 @@ func TestProjectPlanError(t *testing.T) {
 					},
 				},
 			},
-			result: errors.New("run Select error: found error \"%!s(<nil>)\" when call func avg"),
+			result: errors.New("run Select error: call func avg error: requires float64 but found string(dde)"),
+		}, {
+			sql: "SELECT sum(a) as sum FROM test GROUP BY TumblingWindow(ss, 10)",
+			data: xsql.WindowTuplesSet{
+				xsql.WindowTuples{
+					Emitter: "test",
+					Tuples: []xsql.Tuple{
+						{
+							Emitter: "src1",
+							Message: xsql.Message{"a": 53},
+						}, {
+							Emitter: "src1",
+							Message: xsql.Message{"a": "ddd"},
+						}, {
+							Emitter: "src1",
+							Message: xsql.Message{"a": 123123},
+						},
+					},
+				},
+			},
+			result: errors.New("run Select error: call func sum error: requires int but found string(ddd)"),
 		},
 	}
 	fmt.Printf("The test bucket size is %d.\n\n", len(tests))
@@ -1446,7 +1446,7 @@ func TestProjectPlanError(t *testing.T) {
 	for i, tt := range tests {
 		stmt, _ := xsql.NewParser(strings.NewReader(tt.sql)).Parse()
 
-		pp := &ProjectPlan{Fields: stmt.Fields}
+		pp := &ProjectPlan{Fields: stmt.Fields, IsAggregate: xsql.IsAggStatement(stmt)}
 		pp.isTest = true
 		result := pp.Apply(ctx, tt.data)
 		if !reflect.DeepEqual(tt.result, result) {