client_test.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. package cnn
  import (
  	"encoding/json"
  	"errors"
  	"net/http"
  	"net/http/httptest"
  	"testing"
  )
  8. func TestNewClientDefaults(t *testing.T) {
  9. c := NewClient(" http://localhost:8080/ ", "", 0)
  10. if c.Endpoint != "http://localhost:8080" {
  11. t.Errorf("endpoint not trimmed: %q", c.Endpoint)
  12. }
  13. if c.HTTP.Timeout != DefaultTimeout {
  14. t.Errorf("expected default timeout, got %v", c.HTTP.Timeout)
  15. }
  16. }
  17. func TestAuthHeaderSentWhenTokenSet(t *testing.T) {
  18. var gotAuth string
  19. srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  20. gotAuth = r.Header.Get("Authorization")
  21. w.Header().Set("Content-Type", "application/json")
  22. w.Write([]byte(`{"status":"ok","version":"0.1.0"}`))
  23. }))
  24. defer srv.Close()
  25. c := NewClient(srv.URL, "tok123", 0)
  26. if _, err := c.Health(); err != nil {
  27. t.Fatalf("unexpected error: %v", err)
  28. }
  29. if gotAuth != "Bearer tok123" {
  30. t.Errorf("expected Bearer header, got %q", gotAuth)
  31. }
  32. }
  33. func TestAuthHeaderOmittedWhenNoToken(t *testing.T) {
  34. var gotAuth string
  35. called := false
  36. srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  37. called = true
  38. gotAuth = r.Header.Get("Authorization")
  39. w.Write([]byte(`{"status":"ok"}`))
  40. }))
  41. defer srv.Close()
  42. c := NewClient(srv.URL, "", 0)
  43. if _, err := c.Health(); err != nil {
  44. t.Fatalf("unexpected error: %v", err)
  45. }
  46. if !called {
  47. t.Fatal("server was not called")
  48. }
  49. if gotAuth != "" {
  50. t.Errorf("expected no Authorization header, got %q", gotAuth)
  51. }
  52. }
  53. func TestErrorEnvelopeDecoded(t *testing.T) {
  54. srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  55. w.WriteHeader(http.StatusNotFound)
  56. json.NewEncoder(w).Encode(map[string]any{
  57. "error": map[string]any{
  58. "message": `model "x" is not available`,
  59. "type": "not_found_error",
  60. "code": "model_not_found",
  61. },
  62. })
  63. }))
  64. defer srv.Close()
  65. c := NewClient(srv.URL, "", 0)
  66. _, _, err := c.Detect([]byte{1, 2, 3}, "image/png", RequestOptions{Model: "x"})
  67. if err == nil {
  68. t.Fatal("expected an error")
  69. }
  70. apiErr, ok := err.(*APIError)
  71. if !ok {
  72. t.Fatalf("expected *APIError, got %T: %v", err, err)
  73. }
  74. if apiErr.Status != http.StatusNotFound || apiErr.Code != "model_not_found" {
  75. t.Errorf("unexpected error fields: %+v", apiErr)
  76. }
  77. }
  78. func TestAsyncSubmissionReturnsJob(t *testing.T) {
  79. srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  80. w.WriteHeader(http.StatusAccepted)
  81. json.NewEncoder(w).Encode(Job{ID: "job-1", Object: "job", Status: "queued", Created: 1})
  82. }))
  83. defer srv.Close()
  84. c := NewClient(srv.URL, "", 0)
  85. result, job, err := c.Detect([]byte{1, 2, 3}, "image/png", RequestOptions{Async: true})
  86. if err != nil {
  87. t.Fatalf("unexpected error: %v", err)
  88. }
  89. if result != nil {
  90. t.Errorf("expected nil result on async submission, got %+v", result)
  91. }
  92. if job == nil || job.ID != "job-1" || job.Status != "queued" {
  93. t.Fatalf("unexpected job: %+v", job)
  94. }
  95. }
  96. func TestEndpointNotConfigured(t *testing.T) {
  97. c := NewClient("", "", 0)
  98. if _, err := c.Health(); err == nil {
  99. t.Fatal("expected an error when endpoint is empty")
  100. }
  101. }
  102. func TestGetJob(t *testing.T) {
  103. srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  104. if r.URL.Path != "/v1/jobs/job-1" {
  105. t.Errorf("unexpected path: %s", r.URL.Path)
  106. }
  107. json.NewEncoder(w).Encode(Job{ID: "job-1", Status: "succeeded", Result: json.RawMessage(`{"object":"image.detection"}`)})
  108. }))
  109. defer srv.Close()
  110. c := NewClient(srv.URL, "", 0)
  111. job, err := c.GetJob("job-1")
  112. if err != nil {
  113. t.Fatalf("unexpected error: %v", err)
  114. }
  115. if job.Status != "succeeded" {
  116. t.Errorf("unexpected status: %s", job.Status)
  117. }
  118. }
  119. func TestHealthAndListModelsAndGetModel(t *testing.T) {
  120. srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  121. switch r.URL.Path {
  122. case "/v1/health":
  123. json.NewEncoder(w).Encode(Health{Status: "ok", Version: "0.1.0", ModelsLoaded: 12, UptimeS: 100})
  124. case "/v1/models":
  125. json.NewEncoder(w).Encode(ModelList{Object: "list", Data: []ModelInfo{{ID: "yolo11n", Object: "model", Task: "detection"}}})
  126. case "/v1/models/yolo11n":
  127. json.NewEncoder(w).Encode(ModelInfo{ID: "yolo11n", Object: "model", Task: "detection", Classes: 80, Input: 640})
  128. default:
  129. http.NotFound(w, r)
  130. }
  131. }))
  132. defer srv.Close()
  133. c := NewClient(srv.URL, "", 0)
  134. h, err := c.Health()
  135. if err != nil || h.Status != "ok" || h.ModelsLoaded != 12 {
  136. t.Fatalf("unexpected health: %+v, err=%v", h, err)
  137. }
  138. models, err := c.ListModels()
  139. if err != nil || len(models.Data) != 1 || models.Data[0].ID != "yolo11n" {
  140. t.Fatalf("unexpected models: %+v, err=%v", models, err)
  141. }
  142. model, err := c.GetModel("yolo11n")
  143. if err != nil || model.Classes != 80 {
  144. t.Fatalf("unexpected model: %+v, err=%v", model, err)
  145. }
  146. }