agi.aimodel_test.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. package agi
  2. import (
  3. "encoding/json"
  4. "io"
  5. "net/http"
  6. "net/http/httptest"
  7. "net/url"
  8. "path/filepath"
  9. "strings"
  10. "testing"
  11. "github.com/robertkrimen/otto"
  12. "imuslab.com/arozos/mod/agi/static"
  13. database "imuslab.com/arozos/mod/database"
  14. user "imuslab.com/arozos/mod/user"
  15. )
  16. // dbGateway returns a Gateway backed by a throwaway bolt database so the
  17. // config / pricing / metrics persistence paths can be exercised in tests.
  18. func dbGateway(t *testing.T) *Gateway {
  19. t.Helper()
  20. dbfile := filepath.Join(t.TempDir(), "test.db")
  21. sysdb, err := database.NewDatabase(dbfile, false)
  22. if err != nil {
  23. t.Fatalf("failed to create test database: %v", err)
  24. }
  25. t.Cleanup(func() { sysdb.Close() })
  26. uh, err := user.NewUserHandler(sysdb, nil, nil, nil, nil)
  27. if err != nil {
  28. t.Fatalf("failed to create user handler: %v", err)
  29. }
  30. g := minimalGateway()
  31. g.Option.UserHandler = uh
  32. sysdb.NewTable(aiModelDBTable)
  33. return g
  34. }
  35. // ─── pure helpers ─────────────────────────────────────────────────────────────
  36. func TestParseAIModelOptions(t *testing.T) {
  37. if opt := parseAIModelOptions(""); opt.Model != "" {
  38. t.Errorf("empty string should yield zero options")
  39. }
  40. if opt := parseAIModelOptions("undefined"); opt.Model != "" {
  41. t.Errorf("'undefined' should yield zero options")
  42. }
  43. if opt := parseAIModelOptions("null"); opt.Model != "" {
  44. t.Errorf("'null' should yield zero options")
  45. }
  46. opt := parseAIModelOptions(`{"model":"gpt-4o","system":"be brief","temperature":0.5,"max_tokens":42}`)
  47. if opt.Model != "gpt-4o" || opt.System != "be brief" {
  48. t.Errorf("unexpected parse: %+v", opt)
  49. }
  50. if opt.Temperature == nil || *opt.Temperature != 0.5 {
  51. t.Errorf("temperature not parsed")
  52. }
  53. if opt.MaxTokens == nil || *opt.MaxTokens != 42 {
  54. t.Errorf("max_tokens not parsed")
  55. }
  56. }
  57. func TestAIModelMaskKey(t *testing.T) {
  58. cases := map[string]string{
  59. "": "",
  60. "abc": "•••",
  61. "sk-1234567890": "••••7890",
  62. }
  63. for in, want := range cases {
  64. if got := aiModelMaskKey(in); got != want {
  65. t.Errorf("maskKey(%q) = %q, want %q", in, got, want)
  66. }
  67. }
  68. }
  69. func TestAIModelExtClassification(t *testing.T) {
  70. if !aiModelIsImageExt(".png") || !aiModelIsImageExt(".jpeg") {
  71. t.Error("expected image extensions to be detected")
  72. }
  73. if aiModelIsImageExt(".txt") {
  74. t.Error(".txt should not be an image")
  75. }
  76. if !aiModelIsTextExt(".md") || !aiModelIsTextExt(".go") {
  77. t.Error("expected text extensions to be detected")
  78. }
  79. if aiModelIsTextExt(".png") {
  80. t.Error(".png should not be classified as text")
  81. }
  82. }
  83. // ─── persistence ──────────────────────────────────────────────────────────────
  84. func TestRecordAIModelUsageAccumulatesAndCosts(t *testing.T) {
  85. g := dbGateway(t)
  86. sysdb := g.Option.UserHandler.GetDatabase()
  87. //Pricing: $2.50 / 1M input, $10.00 / 1M output
  88. sysdb.Write(aiModelDBTable, "pricing", map[string]AIModelPricing{
  89. "test-model": {InputPrice: 2.5, OutputPrice: 10.0},
  90. })
  91. g.recordAIModelUsage("test-model", 1000, 500)
  92. g.recordAIModelUsage("test-model", 1000, 500)
  93. m := g.getAIModelMetrics()
  94. if m.TotalRequests != 2 {
  95. t.Errorf("expected 2 requests, got %d", m.TotalRequests)
  96. }
  97. if m.TotalPromptTokens != 2000 || m.TotalCompletionTokens != 1000 || m.TotalTokens != 3000 {
  98. t.Errorf("unexpected token totals: %+v", m)
  99. }
  100. //Each call: 1000/1e6*2.5 + 500/1e6*10 = 0.0075 ; two calls => 0.015
  101. if got := m.TotalCost; got < 0.01499 || got > 0.01501 {
  102. t.Errorf("expected total cost ~0.015, got %v", got)
  103. }
  104. rec := m.PerModel["test-model"]
  105. if rec == nil || rec.Requests != 2 || rec.TotalTokens != 3000 {
  106. t.Errorf("per-model record incorrect: %+v", rec)
  107. }
  108. }
  109. func TestGetAIModelConfigDefaultsCurrency(t *testing.T) {
  110. g := dbGateway(t)
  111. cfg := g.getAIModelConfig()
  112. if cfg.Currency != "USD" {
  113. t.Errorf("expected default currency USD, got %q", cfg.Currency)
  114. }
  115. }
  116. // ─── full request flow against a mock OpenAI-compatible server ──────────────────
  117. func TestAIModelDoRequestFlow(t *testing.T) {
  118. var gotPath, gotAuth, gotModel string
  119. var sawUserMessage bool
  120. srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  121. gotPath = r.URL.Path
  122. gotAuth = r.Header.Get("Authorization")
  123. body, _ := io.ReadAll(r.Body)
  124. var req aiChatRequest
  125. json.Unmarshal(body, &req)
  126. gotModel = req.Model
  127. for _, msg := range req.Messages {
  128. if msg.Role == "user" {
  129. sawUserMessage = true
  130. }
  131. }
  132. w.Header().Set("Content-Type", "application/json")
  133. io.WriteString(w, `{"model":"test-model",
  134. "choices":[{"index":0,"message":{"role":"assistant","content":"Hello from mock"},"finish_reason":"stop"}],
  135. "usage":{"prompt_tokens":1000,"completion_tokens":500,"total_tokens":1500}}`)
  136. }))
  137. defer srv.Close()
  138. g := dbGateway(t)
  139. sysdb := g.Option.UserHandler.GetDatabase()
  140. sysdb.Write(aiModelDBTable, "config", AIModelConfig{
  141. Endpoint: srv.URL,
  142. APIKey: "test-key",
  143. DefaultModel: "test-model",
  144. Currency: "USD",
  145. })
  146. sysdb.Write(aiModelDBTable, "pricing", map[string]AIModelPricing{
  147. "test-model": {InputPrice: 2.5, OutputPrice: 10.0},
  148. })
  149. resp, err := g.aiModelDoRequest("", []aiChatMessage{{Role: "user", Content: "hi"}}, aiChatOptions{})
  150. if err != nil {
  151. t.Fatalf("aiModelDoRequest returned error: %v", err)
  152. }
  153. if content := aiModelExtractContent(resp); content != "Hello from mock" {
  154. t.Errorf("unexpected content: %q", content)
  155. }
  156. if gotPath != "/chat/completions" {
  157. t.Errorf("expected /chat/completions, got %q", gotPath)
  158. }
  159. if gotAuth != "Bearer test-key" {
  160. t.Errorf("expected bearer auth header, got %q", gotAuth)
  161. }
  162. if gotModel != "test-model" {
  163. t.Errorf("expected default model to be used, got %q", gotModel)
  164. }
  165. if !sawUserMessage {
  166. t.Error("server did not receive a user message")
  167. }
  168. //Metrics should have been recorded from the usage block
  169. m := g.getAIModelMetrics()
  170. if m.TotalRequests != 1 || m.TotalTokens != 1500 {
  171. t.Errorf("metrics not recorded after request: %+v", m)
  172. }
  173. }
  174. func TestAIModelDoRequestNoEndpoint(t *testing.T) {
  175. g := dbGateway(t)
  176. _, err := g.aiModelDoRequest("m", []aiChatMessage{{Role: "user", Content: "hi"}}, aiChatOptions{})
  177. if err == nil {
  178. t.Error("expected error when endpoint is not configured")
  179. }
  180. }
  181. // ─── config handler masking ─────────────────────────────────────────────────────
  182. func TestHandleAIModelConfigMaskingAndKeyRetention(t *testing.T) {
  183. g := dbGateway(t)
  184. sysdb := g.Option.UserHandler.GetDatabase()
  185. sysdb.Write(aiModelDBTable, "config", AIModelConfig{
  186. Endpoint: "https://api.example.com/v1", APIKey: "sk-supersecret9999", DefaultModel: "m", Currency: "USD",
  187. })
  188. //GET should mask the key
  189. rec := httptest.NewRecorder()
  190. g.HandleAIModelConfig(rec, httptest.NewRequest("GET", "/system/aimodel/config", nil))
  191. var got map[string]interface{}
  192. json.Unmarshal(rec.Body.Bytes(), &got)
  193. if got["hasKey"] != true {
  194. t.Errorf("expected hasKey true, got %v", got["hasKey"])
  195. }
  196. if hint, _ := got["keyHint"].(string); !strings.HasSuffix(hint, "9999") || strings.Contains(hint, "supersecret") {
  197. t.Errorf("key not properly masked: %v", got["keyHint"])
  198. }
  199. //POST without apikey should retain the saved key, but update endpoint
  200. form := url.Values{}
  201. form.Set("endpoint", "https://new.example.com/v1")
  202. form.Set("defaultModel", "m2")
  203. form.Set("currency", "EUR")
  204. req := httptest.NewRequest("POST", "/system/aimodel/config", strings.NewReader(form.Encode()))
  205. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  206. g.HandleAIModelConfig(httptest.NewRecorder(), req)
  207. cfg := g.getAIModelConfig()
  208. if cfg.APIKey != "sk-supersecret9999" {
  209. t.Errorf("API key should have been retained, got %q", cfg.APIKey)
  210. }
  211. if cfg.Endpoint != "https://new.example.com/v1" || cfg.DefaultModel != "m2" || cfg.Currency != "EUR" {
  212. t.Errorf("config not updated correctly: %+v", cfg)
  213. }
  214. }
  215. // ─── JS object exposure ─────────────────────────────────────────────────────────
  216. func TestInjectAIModelLib_JSObjectExposed(t *testing.T) {
  217. g := minimalGateway()
  218. vm := otto.New()
  219. payload := &static.AgiLibInjectionPayload{VM: vm, User: &user.User{Username: "alice"}}
  220. g.injectAIModelFunctions(payload)
  221. for _, method := range []string{"chat", "chatWithFile", "request", "usage", "models"} {
  222. val, err := vm.Run(`typeof aimodel.` + method)
  223. if err != nil {
  224. t.Fatalf("evaluating aimodel.%s: %v", method, err)
  225. }
  226. s, _ := val.ToString()
  227. if s != "function" {
  228. t.Errorf("aimodel.%s should be a function, got %q", method, s)
  229. }
  230. }
  231. }