auth_test.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. package middleware
  2. import (
  3. "io"
  4. "net/http"
  5. "net/http/httptest"
  6. "reflect"
  7. "testing"
  8. "github.com/lf-edge/ekuiper/internal/pkg/jwt"
  9. )
  10. func genToken(signKeyName, issuer, aud string) string {
  11. tkStr, _ := jwt.CreateToken(signKeyName, issuer, aud)
  12. return tkStr
  13. }
  14. func Test_AUTH(t *testing.T) {
  15. nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  16. w.WriteHeader(http.StatusOK)
  17. })
  18. handler := Auth(nextHandler)
  19. type args struct {
  20. th string
  21. }
  22. tests := []struct {
  23. name string
  24. args args
  25. req *http.Request
  26. res *httptest.ResponseRecorder
  27. wantCode int
  28. }{
  29. {
  30. name: "token right",
  31. args: args{th: genToken("sample_key", "sample_key.pub", "eKuiper")},
  32. req: httptest.NewRequest(http.MethodGet, "http://127.0.0.1:9081/streams", nil),
  33. res: httptest.NewRecorder(),
  34. wantCode: 200,
  35. },
  36. {
  37. name: "audience not right",
  38. args: args{th: genToken("sample_key", "sample_key.pub", "Neuron")},
  39. req: httptest.NewRequest(http.MethodGet, "http://127.0.0.1:9081/streams", nil),
  40. res: httptest.NewRecorder(),
  41. wantCode: 401,
  42. },
  43. {
  44. name: "no token",
  45. args: args{th: ""},
  46. req: httptest.NewRequest(http.MethodGet, "http://127.0.0.1:9081/streams", nil),
  47. res: httptest.NewRecorder(),
  48. wantCode: 401,
  49. },
  50. {
  51. name: "no need token path",
  52. args: args{th: ""},
  53. req: httptest.NewRequest(http.MethodGet, "http://127.0.0.1:9081/ping", nil),
  54. res: httptest.NewRecorder(),
  55. wantCode: 200,
  56. },
  57. }
  58. for _, tt := range tests {
  59. t.Run(tt.name, func(t *testing.T) {
  60. tt.req.Header.Set("Authorization", tt.args.th)
  61. handler.ServeHTTP(tt.res, tt.req)
  62. res := tt.res.Result()
  63. data, err := io.ReadAll(res.Body)
  64. if err != nil {
  65. t.Errorf("expected error to be nil got %v", err)
  66. }
  67. if !reflect.DeepEqual(tt.wantCode, tt.res.Code) {
  68. t.Errorf("expect %d, actual %d, result %s", tt.wantCode, tt.res.Code, string(data))
  69. }
  70. _ = res.Body.Close()
  71. })
  72. }
  73. }