oauth2_handler_test.go 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766
  1. package oauth2
  2. import (
  3. "context"
  4. "encoding/json"
  5. "net/http"
  6. "net/http/httptest"
  7. "net/url"
  8. "os"
  9. "strings"
  10. "testing"
  11. db "imuslab.com/arozos/mod/database"
  12. syncdb "imuslab.com/arozos/mod/auth/oauth2/syncdb"
  13. )
  14. // ── Test infrastructure ───────────────────────────────────────────────────────
  15. func newTestDB(t *testing.T) (*db.Database, func()) {
  16. t.Helper()
  17. dir, err := os.MkdirTemp("", "arozos-oauth-test-*")
  18. if err != nil {
  19. t.Fatalf("MkdirTemp: %v", err)
  20. }
  21. database, err := db.NewDatabase(dir+"/test.db", false)
  22. if err != nil {
  23. os.RemoveAll(dir)
  24. t.Fatalf("NewDatabase: %v", err)
  25. }
  26. return database, func() { os.RemoveAll(dir) }
  27. }
  28. // minimalOauthHandler returns a handler with only a live database; ag and reg
  29. // are nil because the config/discover handlers under test never touch them.
  30. func minimalOauthHandler(coredb *db.Database) *OauthHandler {
  31. _ = coredb.NewTable("oauth") // ignore "already exists"
  32. return &OauthHandler{coredb: coredb}
  33. }
  34. func postForm(t *testing.T, h http.HandlerFunc, values url.Values) *httptest.ResponseRecorder {
  35. t.Helper()
  36. req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(values.Encode()))
  37. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  38. w := httptest.NewRecorder()
  39. h(w, req)
  40. return w
  41. }
  42. func getReq(t *testing.T, h http.HandlerFunc) *httptest.ResponseRecorder {
  43. t.Helper()
  44. req := httptest.NewRequest(http.MethodGet, "/", nil)
  45. w := httptest.NewRecorder()
  46. h(w, req)
  47. return w
  48. }
  49. func getReqWithParams(t *testing.T, h http.HandlerFunc, params url.Values) *httptest.ResponseRecorder {
  50. t.Helper()
  51. req := httptest.NewRequest(http.MethodGet, "/?"+params.Encode(), nil)
  52. w := httptest.NewRecorder()
  53. h(w, req)
  54. return w
  55. }
  56. // ── ReadConfig ────────────────────────────────────────────────────────────────
  57. func TestReadConfig_DefaultsToDisabled(t *testing.T) {
  58. coredb, cleanup := newTestDB(t)
  59. defer cleanup()
  60. oh := minimalOauthHandler(coredb)
  61. w := getReq(t, oh.ReadConfig)
  62. if w.Code != http.StatusOK {
  63. t.Fatalf("ReadConfig returned %d, want 200", w.Code)
  64. }
  65. var cfg Config
  66. if err := json.Unmarshal(w.Body.Bytes(), &cfg); err != nil {
  67. t.Fatalf("response is not valid JSON: %v; body: %s", err, w.Body)
  68. }
  69. if cfg.Enabled {
  70. t.Error("expected Enabled=false for fresh DB")
  71. }
  72. }
  73. func TestReadConfig_AllFieldsRoundTrip(t *testing.T) {
  74. coredb, cleanup := newTestDB(t)
  75. defer cleanup()
  76. oh := minimalOauthHandler(coredb)
  77. // Seed values
  78. coredb.Write("oauth", "issuerurl", "https://idp.example.com")
  79. coredb.Write("oauth", "authendpoint", "https://idp.example.com/auth")
  80. coredb.Write("oauth", "tokenendpoint", "https://idp.example.com/token")
  81. coredb.Write("oauth", "userinfoendpoint", "https://idp.example.com/userinfo")
  82. coredb.Write("oauth", "usernamefield", "preferred_username")
  83. coredb.Write("oauth", "scope", "openid email")
  84. w := getReq(t, oh.ReadConfig)
  85. var cfg Config
  86. if err := json.Unmarshal(w.Body.Bytes(), &cfg); err != nil {
  87. t.Fatalf("JSON parse: %v", err)
  88. }
  89. checks := []struct{ f, got, want string }{
  90. {"IssuerURL", cfg.IssuerURL, "https://idp.example.com"},
  91. {"AuthEndpoint", cfg.AuthEndpoint, "https://idp.example.com/auth"},
  92. {"TokenEndpoint", cfg.TokenEndpoint, "https://idp.example.com/token"},
  93. {"UserInfoEndpoint", cfg.UserInfoEndpoint, "https://idp.example.com/userinfo"},
  94. {"UsernameField", cfg.UsernameField, "preferred_username"},
  95. {"Scope", cfg.Scope, "openid email"},
  96. }
  97. for _, c := range checks {
  98. if c.got != c.want {
  99. t.Errorf("%s: got %q, want %q", c.f, c.got, c.want)
  100. }
  101. }
  102. }
  103. // ── WriteConfig ───────────────────────────────────────────────────────────────
  104. func TestWriteConfig_MissingEnabledField(t *testing.T) {
  105. coredb, cleanup := newTestDB(t)
  106. defer cleanup()
  107. oh := minimalOauthHandler(coredb)
  108. w := postForm(t, oh.WriteConfig, url.Values{"clientid": {"x"}})
  109. if !strings.Contains(w.Body.String(), "error") {
  110. t.Errorf("expected error without enabled field, got %q", w.Body)
  111. }
  112. }
  113. func TestWriteConfig_DisabledAllowsEmptyFields(t *testing.T) {
  114. coredb, cleanup := newTestDB(t)
  115. defer cleanup()
  116. oh := minimalOauthHandler(coredb)
  117. w := postForm(t, oh.WriteConfig, url.Values{
  118. "enabled": {"false"}, "autoredirect": {"false"},
  119. })
  120. if strings.Contains(w.Body.String(), `"error"`) {
  121. t.Errorf("unexpected error when disabling: %q", w.Body)
  122. }
  123. }
  124. func TestWriteConfig_EnabledRequiresCredentials(t *testing.T) {
  125. coredb, cleanup := newTestDB(t)
  126. defer cleanup()
  127. oh := minimalOauthHandler(coredb)
  128. // enabled=true but clientid missing
  129. w := postForm(t, oh.WriteConfig, url.Values{
  130. "enabled": {"true"},
  131. "autoredirect": {"false"},
  132. "clientsecret": {"s"},
  133. "redirecturl": {"https://aroz.example.com"},
  134. "authendpoint": {"https://idp/auth"},
  135. "tokenendpoint": {"https://idp/token"},
  136. "userinfoendpoint": {"https://idp/userinfo"},
  137. })
  138. if !strings.Contains(w.Body.String(), "error") {
  139. t.Errorf("expected error for missing clientid: %q", w.Body)
  140. }
  141. }
  142. func TestWriteConfig_EnabledRequiresEndpoints(t *testing.T) {
  143. coredb, cleanup := newTestDB(t)
  144. defer cleanup()
  145. oh := minimalOauthHandler(coredb)
  146. // enabled=true but endpoints missing
  147. w := postForm(t, oh.WriteConfig, url.Values{
  148. "enabled": {"true"},
  149. "autoredirect": {"false"},
  150. "clientid": {"id"},
  151. "clientsecret": {"s"},
  152. "redirecturl": {"https://aroz.example.com"},
  153. // authendpoint / tokenendpoint / userinfoendpoint all missing
  154. })
  155. if !strings.Contains(w.Body.String(), "error") {
  156. t.Errorf("expected error for missing endpoints: %q", w.Body)
  157. }
  158. }
  159. func TestWriteConfig_FullRoundTrip(t *testing.T) {
  160. coredb, cleanup := newTestDB(t)
  161. defer cleanup()
  162. oh := minimalOauthHandler(coredb)
  163. in := url.Values{
  164. "enabled": {"false"},
  165. "autoredirect": {"false"},
  166. "issuerurl": {"https://idp.example.com"},
  167. "clientid": {"client-abc"},
  168. "clientsecret": {"secret-xyz"},
  169. "redirecturl": {"https://aroz.example.com"},
  170. "scope": {"openid email profile"},
  171. "usernamefield": {"preferred_username"},
  172. "authendpoint": {"https://idp.example.com/auth"},
  173. "tokenendpoint": {"https://idp.example.com/token"},
  174. "userinfoendpoint": {"https://idp.example.com/userinfo"},
  175. }
  176. wWrite := postForm(t, oh.WriteConfig, in)
  177. if strings.Contains(wWrite.Body.String(), `"error"`) {
  178. t.Fatalf("WriteConfig error: %s", wWrite.Body)
  179. }
  180. wRead := getReq(t, oh.ReadConfig)
  181. var cfg Config
  182. if err := json.Unmarshal(wRead.Body.Bytes(), &cfg); err != nil {
  183. t.Fatalf("ReadConfig JSON parse: %v", err)
  184. }
  185. checks := []struct{ f, got, want string }{
  186. {"IssuerURL", cfg.IssuerURL, "https://idp.example.com"},
  187. {"ClientID", cfg.ClientID, "client-abc"},
  188. {"ClientSecret", cfg.ClientSecret, "secret-xyz"},
  189. {"RedirectURL", cfg.RedirectURL, "https://aroz.example.com"},
  190. {"Scope", cfg.Scope, "openid email profile"},
  191. {"UsernameField", cfg.UsernameField, "preferred_username"},
  192. {"AuthEndpoint", cfg.AuthEndpoint, "https://idp.example.com/auth"},
  193. {"TokenEndpoint", cfg.TokenEndpoint, "https://idp.example.com/token"},
  194. {"UserInfoEndpoint", cfg.UserInfoEndpoint, "https://idp.example.com/userinfo"},
  195. }
  196. for _, c := range checks {
  197. if c.got != c.want {
  198. t.Errorf("%s: got %q, want %q", c.f, c.got, c.want)
  199. }
  200. }
  201. if cfg.Enabled {
  202. t.Error("Enabled: got true, want false")
  203. }
  204. }
  205. func TestWriteConfig_OverwritesPreviousValues(t *testing.T) {
  206. coredb, cleanup := newTestDB(t)
  207. defer cleanup()
  208. oh := minimalOauthHandler(coredb)
  209. postForm(t, oh.WriteConfig, url.Values{
  210. "enabled": {"false"}, "autoredirect": {"false"},
  211. "clientid": {"old-id"},
  212. })
  213. postForm(t, oh.WriteConfig, url.Values{
  214. "enabled": {"false"}, "autoredirect": {"false"},
  215. "clientid": {"new-id"},
  216. })
  217. wRead := getReq(t, oh.ReadConfig)
  218. var cfg Config
  219. json.Unmarshal(wRead.Body.Bytes(), &cfg) //nolint:errcheck
  220. if cfg.ClientID != "new-id" {
  221. t.Errorf("ClientID: got %q, want %q", cfg.ClientID, "new-id")
  222. }
  223. }
  224. // ── HandleDiscover ────────────────────────────────────────────────────────────
  225. func TestHandleDiscover_Success(t *testing.T) {
  226. // Set up a mock OIDC provider. Declare first so the handler closure can
  227. // reference providerSrv.URL by the time it is actually invoked.
  228. var providerSrv *httptest.Server
  229. providerSrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  230. w.Header().Set("Content-Type", "application/json")
  231. w.Write(minimalDiscoveryDoc(providerSrv.URL))
  232. }))
  233. defer providerSrv.Close()
  234. defer withMockClient(providerSrv)()
  235. coredb, cleanup := newTestDB(t)
  236. defer cleanup()
  237. oh := minimalOauthHandler(coredb)
  238. w := getReqWithParams(t, oh.HandleDiscover, url.Values{"issuerurl": {providerSrv.URL}})
  239. if w.Code != http.StatusOK {
  240. t.Fatalf("HandleDiscover returned %d; body: %s", w.Code, w.Body)
  241. }
  242. var result DiscoveryResult
  243. if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
  244. t.Fatalf("response is not valid JSON: %v; body: %s", err, w.Body)
  245. }
  246. if result.AuthEndpoint == "" {
  247. t.Error("AuthEndpoint is empty in discovery result")
  248. }
  249. if result.TokenEndpoint == "" {
  250. t.Error("TokenEndpoint is empty in discovery result")
  251. }
  252. if result.UserInfoEndpoint == "" {
  253. t.Error("UserInfoEndpoint is empty in discovery result")
  254. }
  255. if len(result.ScopesSupported) == 0 {
  256. t.Error("ScopesSupported is empty in discovery result")
  257. }
  258. }
  259. func TestHandleDiscover_MissingIssuerURL(t *testing.T) {
  260. coredb, cleanup := newTestDB(t)
  261. defer cleanup()
  262. oh := minimalOauthHandler(coredb)
  263. w := getReq(t, oh.HandleDiscover)
  264. if w.Code != http.StatusOK {
  265. t.Fatalf("unexpected status %d", w.Code)
  266. }
  267. if !strings.Contains(w.Body.String(), "error") {
  268. t.Errorf("expected error for missing issuerurl, got %q", w.Body)
  269. }
  270. }
  271. func TestHandleDiscover_ProviderReturns404(t *testing.T) {
  272. providerSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  273. http.NotFound(w, r)
  274. }))
  275. defer providerSrv.Close()
  276. defer withMockClient(providerSrv)()
  277. coredb, cleanup := newTestDB(t)
  278. defer cleanup()
  279. oh := minimalOauthHandler(coredb)
  280. w := getReqWithParams(t, oh.HandleDiscover, url.Values{"issuerurl": {providerSrv.URL}})
  281. if !strings.Contains(w.Body.String(), "error") {
  282. t.Errorf("expected error for 404 provider, got %q", w.Body)
  283. }
  284. }
  285. func TestHandleDiscover_ScopesSuggested(t *testing.T) {
  286. var providerSrv *httptest.Server
  287. providerSrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  288. w.Header().Set("Content-Type", "application/json")
  289. w.Write(minimalDiscoveryDoc(providerSrv.URL))
  290. }))
  291. defer providerSrv.Close()
  292. defer withMockClient(providerSrv)()
  293. coredb, cleanup := newTestDB(t)
  294. defer cleanup()
  295. oh := minimalOauthHandler(coredb)
  296. w := getReqWithParams(t, oh.HandleDiscover, url.Values{"issuerurl": {providerSrv.URL}})
  297. var result DiscoveryResult
  298. json.Unmarshal(w.Body.Bytes(), &result) //nolint:errcheck
  299. if len(result.ScopesSupported) == 0 {
  300. t.Error("ScopesSupported should not be empty after discovery")
  301. }
  302. }
  303. func TestHandleDiscover_ClaimsReturned(t *testing.T) {
  304. var providerSrv *httptest.Server
  305. providerSrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  306. w.Header().Set("Content-Type", "application/json")
  307. w.Write(minimalDiscoveryDoc(providerSrv.URL))
  308. }))
  309. defer providerSrv.Close()
  310. defer withMockClient(providerSrv)()
  311. coredb, cleanup := newTestDB(t)
  312. defer cleanup()
  313. oh := minimalOauthHandler(coredb)
  314. w := getReqWithParams(t, oh.HandleDiscover, url.Values{"issuerurl": {providerSrv.URL}})
  315. var result DiscoveryResult
  316. json.Unmarshal(w.Body.Bytes(), &result) //nolint:errcheck
  317. if len(result.ClaimsSupported) == 0 {
  318. t.Error("ClaimsSupported should not be empty after discovery")
  319. }
  320. }
  321. // ── CheckOAuth ────────────────────────────────────────────────────────────────
  322. func TestCheckOAuth_DisabledByDefault(t *testing.T) {
  323. coredb, cleanup := newTestDB(t)
  324. defer cleanup()
  325. oh := minimalOauthHandler(coredb)
  326. w := getReq(t, oh.CheckOAuth)
  327. var result struct {
  328. Enabled bool `json:"enabled"`
  329. AutoRedirect bool `json:"auto_redirect"`
  330. }
  331. if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
  332. t.Fatalf("JSON parse: %v", err)
  333. }
  334. if result.Enabled {
  335. t.Error("expected Enabled=false by default")
  336. }
  337. }
  338. func TestCheckOAuth_ReflectsStoredValues(t *testing.T) {
  339. coredb, cleanup := newTestDB(t)
  340. defer cleanup()
  341. oh := minimalOauthHandler(coredb)
  342. coredb.Write("oauth", "enabled", "true")
  343. coredb.Write("oauth", "autoredirect", "true")
  344. w := getReq(t, oh.CheckOAuth)
  345. var result struct {
  346. Enabled bool `json:"enabled"`
  347. AutoRedirect bool `json:"auto_redirect"`
  348. }
  349. json.Unmarshal(w.Body.Bytes(), &result) //nolint:errcheck
  350. if !result.Enabled {
  351. t.Error("expected Enabled=true")
  352. }
  353. if !result.AutoRedirect {
  354. t.Error("expected AutoRedirect=true")
  355. }
  356. }
  357. // ── HandleLogin guards ────────────────────────────────────────────────────────
  358. func TestHandleLogin_DisabledReturnsText(t *testing.T) {
  359. coredb, cleanup := newTestDB(t)
  360. defer cleanup()
  361. oh := minimalOauthHandler(coredb)
  362. // "enabled" not set → disabled
  363. req := httptest.NewRequest(http.MethodGet, "/", nil)
  364. w := httptest.NewRecorder()
  365. oh.HandleLogin(w, req)
  366. body := w.Body.String()
  367. if !strings.Contains(strings.ToLower(body), "disabled") {
  368. t.Errorf("expected 'disabled' in response, got %q", body)
  369. }
  370. }
  371. func TestHandleLogin_MisconfiguredNoEndpoints(t *testing.T) {
  372. coredb, cleanup := newTestDB(t)
  373. defer cleanup()
  374. oh := minimalOauthHandler(coredb)
  375. coredb.Write("oauth", "enabled", "true")
  376. // no authendpoint / tokenendpoint / clientid
  377. req := httptest.NewRequest(http.MethodGet, "/", nil)
  378. w := httptest.NewRecorder()
  379. oh.HandleLogin(w, req)
  380. body := w.Body.String()
  381. if strings.Contains(body, "302") || w.Code == http.StatusTemporaryRedirect {
  382. t.Errorf("should not redirect when misconfigured; got code %d, body %q", w.Code, body)
  383. }
  384. }
  385. // ── HandleAuthorize guards ────────────────────────────────────────────────────
  386. func TestHandleAuthorize_DisabledReturnsText(t *testing.T) {
  387. coredb, cleanup := newTestDB(t)
  388. defer cleanup()
  389. oh := minimalOauthHandler(coredb)
  390. req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("state=x&code=y"))
  391. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  392. w := httptest.NewRecorder()
  393. oh.HandleAuthorize(w, req)
  394. if !strings.Contains(strings.ToLower(w.Body.String()), "disabled") {
  395. t.Errorf("expected disabled message, got %q", w.Body)
  396. }
  397. }
  398. func TestHandleAuthorize_MissingCookie(t *testing.T) {
  399. coredb, cleanup := newTestDB(t)
  400. defer cleanup()
  401. oh := minimalOauthHandler(coredb)
  402. coredb.Write("oauth", "enabled", "true")
  403. req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("state=x&code=y"))
  404. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  405. w := httptest.NewRecorder()
  406. oh.HandleAuthorize(w, req)
  407. if !strings.Contains(w.Body.String(), "Invalid redirect URI") {
  408. t.Errorf("expected 'Invalid redirect URI', got %q", w.Body)
  409. }
  410. }
  411. func TestHandleAuthorize_StateMismatch(t *testing.T) {
  412. coredb, cleanup := newTestDB(t)
  413. defer cleanup()
  414. oh := minimalOauthHandler(coredb)
  415. coredb.Write("oauth", "enabled", "true")
  416. oh.syncDb = syncdb.NewSyncDB()
  417. uuid := oh.syncDb.Store("/")
  418. req := httptest.NewRequest(http.MethodPost, "/",
  419. strings.NewReader("state=WRONG_STATE&code=x"))
  420. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  421. req.AddCookie(&http.Cookie{Name: "uuid_login", Value: uuid})
  422. w := httptest.NewRecorder()
  423. oh.HandleAuthorize(w, req)
  424. if !strings.Contains(w.Body.String(), "Invalid oauth state") {
  425. t.Errorf("expected 'Invalid oauth state', got %q", w.Body)
  426. }
  427. }
  428. // ── exchangeCodeForUsername (connectivity) ────────────────────────────────────
  429. // buildMockOIDCStack creates:
  430. // - a mock token endpoint server that accepts any code and returns accessToken
  431. // - a mock userinfo server that verifies the Bearer token and returns claims
  432. //
  433. // Both servers are plain HTTP so the default transport can reach them.
  434. // The package-level httpClient is replaced for the userinfo call and is
  435. // restored by the returned closeFn.
  436. func buildMockOIDCStack(
  437. t *testing.T,
  438. accessToken string,
  439. claims map[string]interface{},
  440. ) (tokenURL, userinfoURL string, closeFn func()) {
  441. t.Helper()
  442. tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  443. w.Header().Set("Content-Type", "application/json")
  444. json.NewEncoder(w).Encode(map[string]interface{}{
  445. "access_token": accessToken,
  446. "token_type": "Bearer",
  447. "expires_in": 3600,
  448. })
  449. }))
  450. userinfoSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  451. auth := r.Header.Get("Authorization")
  452. if auth != "Bearer "+accessToken {
  453. w.WriteHeader(http.StatusUnauthorized)
  454. return
  455. }
  456. w.Header().Set("Content-Type", "application/json")
  457. json.NewEncoder(w).Encode(claims)
  458. }))
  459. // Both test servers are plain HTTP; a standard http.Client can reach both.
  460. // We replace httpClient so getUserInfoFromEndpoint uses the same plain transport.
  461. origClient := httpClient
  462. httpClient = &http.Client{}
  463. closeFn = func() {
  464. tokenSrv.Close()
  465. userinfoSrv.Close()
  466. httpClient = origClient
  467. }
  468. return tokenSrv.URL, userinfoSrv.URL, closeFn
  469. }
  470. // TestExchangeCodeForUsername_Success runs the token exchange → userinfo fetch
  471. // pipeline against real mock HTTP servers.
  472. func TestExchangeCodeForUsername_Success(t *testing.T) {
  473. const fakeToken = "exchange-tok-abc123"
  474. tokenURL, userinfoURL, closeFn := buildMockOIDCStack(t, fakeToken, map[string]interface{}{
  475. "sub": "uid-999",
  476. "email": "testuser@example.com",
  477. })
  478. defer closeFn()
  479. coredb, cleanup := newTestDB(t)
  480. defer cleanup()
  481. oh := minimalOauthHandler(coredb)
  482. coredb.Write("oauth", "authendpoint", "https://example.com/auth") // not called
  483. coredb.Write("oauth", "tokenendpoint", tokenURL)
  484. coredb.Write("oauth", "userinfoendpoint", userinfoURL)
  485. coredb.Write("oauth", "clientid", "test-client")
  486. coredb.Write("oauth", "clientsecret", "test-secret")
  487. coredb.Write("oauth", "redirecturl", "https://aroz.example.com")
  488. coredb.Write("oauth", "usernamefield", "email")
  489. username, err := oh.exchangeCodeForUsername(context.Background(), "some-auth-code")
  490. if err != nil {
  491. t.Fatalf("exchangeCodeForUsername returned error: %v", err)
  492. }
  493. if username != "testuser@example.com" {
  494. t.Errorf("username: got %q, want %q", username, "testuser@example.com")
  495. }
  496. }
  497. func TestExchangeCodeForUsername_PreferredUsername(t *testing.T) {
  498. const fakeToken = "pref-tok"
  499. tokenURL, userinfoURL, closeFn := buildMockOIDCStack(t, fakeToken, map[string]interface{}{
  500. "sub": "uid-123",
  501. "preferred_username": "alice",
  502. "email": "alice@corp.example",
  503. })
  504. defer closeFn()
  505. coredb, cleanup := newTestDB(t)
  506. defer cleanup()
  507. oh := minimalOauthHandler(coredb)
  508. coredb.Write("oauth", "authendpoint", "https://x/auth")
  509. coredb.Write("oauth", "tokenendpoint", tokenURL)
  510. coredb.Write("oauth", "userinfoendpoint", userinfoURL)
  511. coredb.Write("oauth", "clientid", "cid")
  512. coredb.Write("oauth", "clientsecret", "cs")
  513. coredb.Write("oauth", "redirecturl", "https://aroz.example.com")
  514. coredb.Write("oauth", "usernamefield", "preferred_username")
  515. username, err := oh.exchangeCodeForUsername(context.Background(), "code")
  516. if err != nil {
  517. t.Fatalf("unexpected error: %v", err)
  518. }
  519. if username != "alice" {
  520. t.Errorf("username: got %q, want %q", username, "alice")
  521. }
  522. }
  523. func TestExchangeCodeForUsername_TokenEndpointError(t *testing.T) {
  524. // Token server that always returns 400.
  525. tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  526. w.WriteHeader(http.StatusBadRequest)
  527. w.Write([]byte(`{"error":"invalid_grant"}`))
  528. }))
  529. defer tokenSrv.Close()
  530. coredb, cleanup := newTestDB(t)
  531. defer cleanup()
  532. oh := minimalOauthHandler(coredb)
  533. coredb.Write("oauth", "authendpoint", "https://x/auth")
  534. coredb.Write("oauth", "tokenendpoint", tokenSrv.URL)
  535. coredb.Write("oauth", "userinfoendpoint", "https://x/userinfo")
  536. coredb.Write("oauth", "clientid", "cid")
  537. coredb.Write("oauth", "clientsecret", "cs")
  538. coredb.Write("oauth", "redirecturl", "https://aroz.example.com")
  539. _, err := oh.exchangeCodeForUsername(context.Background(), "bad-code")
  540. if err == nil {
  541. t.Fatal("expected error from failing token endpoint, got nil")
  542. }
  543. if !strings.Contains(err.Error(), "token exchange failed") {
  544. t.Errorf("expected 'token exchange failed' in error, got: %v", err)
  545. }
  546. }
  547. func TestExchangeCodeForUsername_UserInfoError(t *testing.T) {
  548. const fakeToken = "good-tok"
  549. // Token server succeeds; userinfo server fails.
  550. tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  551. w.Header().Set("Content-Type", "application/json")
  552. json.NewEncoder(w).Encode(map[string]interface{}{
  553. "access_token": fakeToken, "token_type": "Bearer", "expires_in": 3600,
  554. })
  555. }))
  556. defer tokenSrv.Close()
  557. userinfoSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  558. w.WriteHeader(http.StatusInternalServerError)
  559. }))
  560. defer userinfoSrv.Close()
  561. // Replace httpClient so getUserInfoFromEndpoint uses the same plain transport.
  562. origClient := httpClient
  563. httpClient = &http.Client{}
  564. defer func() { httpClient = origClient }()
  565. coredb, cleanup := newTestDB(t)
  566. defer cleanup()
  567. oh := minimalOauthHandler(coredb)
  568. coredb.Write("oauth", "authendpoint", "https://x/auth")
  569. coredb.Write("oauth", "tokenendpoint", tokenSrv.URL)
  570. coredb.Write("oauth", "userinfoendpoint", userinfoSrv.URL)
  571. coredb.Write("oauth", "clientid", "cid")
  572. coredb.Write("oauth", "clientsecret", "cs")
  573. coredb.Write("oauth", "redirecturl", "https://aroz.example.com")
  574. coredb.Write("oauth", "usernamefield", "email")
  575. _, err := oh.exchangeCodeForUsername(context.Background(), "code")
  576. if err == nil {
  577. t.Fatal("expected error from failing userinfo endpoint, got nil")
  578. }
  579. }
  580. func TestExchangeCodeForUsername_MisconfiguredNoEndpoints(t *testing.T) {
  581. coredb, cleanup := newTestDB(t)
  582. defer cleanup()
  583. oh := minimalOauthHandler(coredb)
  584. // No endpoints configured
  585. _, err := oh.exchangeCodeForUsername(context.Background(), "code")
  586. if err == nil {
  587. t.Fatal("expected error for unconfigured handler, got nil")
  588. }
  589. }
  590. // ── buildOAuthConfig ─────────────────────────────────────────────────────────
  591. func TestBuildOAuthConfig_NilWhenMissing(t *testing.T) {
  592. coredb, cleanup := newTestDB(t)
  593. defer cleanup()
  594. oh := minimalOauthHandler(coredb)
  595. if oh.buildOAuthConfig() != nil {
  596. t.Error("expected nil config when no endpoints are set")
  597. }
  598. }
  599. func TestBuildOAuthConfig_ScopeDefaults(t *testing.T) {
  600. coredb, cleanup := newTestDB(t)
  601. defer cleanup()
  602. oh := minimalOauthHandler(coredb)
  603. coredb.Write("oauth", "authendpoint", "https://x/auth")
  604. coredb.Write("oauth", "tokenendpoint", "https://x/token")
  605. coredb.Write("oauth", "clientid", "cid")
  606. // scope intentionally not set
  607. cfg := oh.buildOAuthConfig()
  608. if cfg == nil {
  609. t.Fatal("buildOAuthConfig returned nil")
  610. }
  611. if len(cfg.Scopes) == 0 {
  612. t.Fatal("Scopes should not be empty when scope is not set (should use default)")
  613. }
  614. defaultScopes := strings.Join(cfg.Scopes, " ")
  615. if !strings.Contains(defaultScopes, "openid") {
  616. t.Errorf("default scope should contain 'openid', got: %q", defaultScopes)
  617. }
  618. }
  619. func TestBuildOAuthConfig_ScopeFromDB(t *testing.T) {
  620. coredb, cleanup := newTestDB(t)
  621. defer cleanup()
  622. oh := minimalOauthHandler(coredb)
  623. coredb.Write("oauth", "authendpoint", "https://x/auth")
  624. coredb.Write("oauth", "tokenendpoint", "https://x/token")
  625. coredb.Write("oauth", "clientid", "cid")
  626. coredb.Write("oauth", "scope", "openid email custom-scope")
  627. cfg := oh.buildOAuthConfig()
  628. if cfg == nil {
  629. t.Fatal("buildOAuthConfig returned nil")
  630. }
  631. if len(cfg.Scopes) != 3 {
  632. t.Errorf("expected 3 scopes, got %d: %v", len(cfg.Scopes), cfg.Scopes)
  633. }
  634. }
  635. func TestBuildOAuthConfig_CallbackURL(t *testing.T) {
  636. coredb, cleanup := newTestDB(t)
  637. defer cleanup()
  638. oh := minimalOauthHandler(coredb)
  639. coredb.Write("oauth", "authendpoint", "https://x/auth")
  640. coredb.Write("oauth", "tokenendpoint", "https://x/token")
  641. coredb.Write("oauth", "clientid", "cid")
  642. coredb.Write("oauth", "redirecturl", "https://aroz.my.domain")
  643. cfg := oh.buildOAuthConfig()
  644. if cfg == nil {
  645. t.Fatal("buildOAuthConfig returned nil")
  646. }
  647. if !strings.HasSuffix(cfg.RedirectURL, "/system/auth/oauth/authorize") {
  648. t.Errorf("RedirectURL should end with /system/auth/oauth/authorize, got: %q", cfg.RedirectURL)
  649. }
  650. if !strings.HasPrefix(cfg.RedirectURL, "https://aroz.my.domain") {
  651. t.Errorf("RedirectURL should start with stored base URL, got: %q", cfg.RedirectURL)
  652. }
  653. }