package main import ( "bytes" "encoding/json" "html/template" "net/http" "net/http/httptest" "net/url" "strings" "testing" "time" "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "golang.org/x/crypto/bcrypt" "gorm.io/datatypes" "gorm.io/driver/sqlite" "gorm.io/gorm" "oidc-oauth2-server/handlers" "oidc-oauth2-server/middleware" "oidc-oauth2-server/models" "oidc-oauth2-server/services" ) func setupTestRouter() (*gin.Engine, *gorm.DB, error) { // 设置测试模式 gin.SetMode(gin.TestMode) // 初始化测试数据库 db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) if err != nil { return nil, nil, err } // 运行数据库迁移 if err := models.AutoMigrate(db); err != nil { return nil, nil, err } // 初始化服务 authService := services.NewAuthService(db) oauthService, err := services.NewOAuthService(db) if err != nil { return nil, nil, err } clientService := services.NewClientService(db) tokenService := services.NewTokenService(db, oauthService.GetKeyManager()) // 设置路由 r := gin.New() r.Use(gin.Recovery()) // 设置模板 r.SetFuncMap(template.FuncMap{ "subtract": func(a, b int) int { return a - b }, "add": func(a, b int) int { return a + b }, }) r.LoadHTMLGlob("templates/*") // 设置会话存储 store := cookie.NewStore([]byte("test_secret")) r.Use(sessions.Sessions("oidc_session", store)) // 设置路由处理器 r.GET("/health", func(c *gin.Context) { c.JSON(200, gin.H{"status": "ok"}) }) // 设置认证相关路由 authHandler := handlers.NewAuthHandler(authService) r.GET("/login", authHandler.ShowLogin) r.POST("/login", authHandler.HandleLogin) r.GET("/signup", authHandler.ShowSignup) r.POST("/signup", authHandler.HandleSignup) // 设置 OIDC 相关路由 oidcHandler := handlers.NewOIDCHandler("http://localhost:8080", oauthService, authService) r.GET("/.well-known/openid-configuration", oidcHandler.OpenIDConfiguration) r.GET("/authorize", oidcHandler.Authorize) r.POST("/token", oidcHandler.Token) r.GET("/userinfo", middleware.BearerAuth(oauthService.GetKeyManager(), db), oidcHandler.Userinfo) r.GET("/jwks", oidcHandler.JWKS) // 设置客户端注册路由 registrationHandler := handlers.NewRegistrationHandler(clientService) r.POST("/register", registrationHandler.Register) // 设置令牌管理路由 tokenHandler := handlers.NewTokenHandler(tokenService) r.POST("/revoke", tokenHandler.Revoke) r.POST("/introspect", tokenHandler.Introspect) return r, db, nil } func TestHealthCheck(t *testing.T) { r, _, err := setupTestRouter() if err != nil { t.Fatalf("Failed to setup test router: %v", err) } w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/health", nil) r.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) var response map[string]string err = json.Unmarshal(w.Body.Bytes(), &response) assert.NoError(t, err) assert.Equal(t, "ok", response["status"]) } func TestLoginPage(t *testing.T) { r, _, err := setupTestRouter() if err != nil { t.Fatalf("Failed to setup test router: %v", err) } w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/login", nil) r.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) } func TestLoginAuthentication(t *testing.T) { r, db, err := setupTestRouter() if err != nil { t.Fatalf("Failed to setup test router: %v", err) } // 创建测试用户 password := "testpass" hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) assert.NoError(t, err) testUser := &models.User{ Username: "testuser", Password: string(hashedPassword), Email: "test@example.com", } result := db.Create(testUser) assert.NoError(t, result.Error) // 测试登录请求 - 使用表单数据 form := url.Values{} form.Add("username", "testuser") form.Add("password", "wrongpass") form.Add("client_id", "test_client") form.Add("response_type", "code") form.Add("scope", "openid profile") form.Add("redirect_uri", "http://localhost:8080/callback") form.Add("state", "test_state") w := httptest.NewRecorder() req, _ := http.NewRequest("POST", "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.ServeHTTP(w, req) // 检查是否返回错误页面 assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "用户名或密码错误") // 测试成功登录 form.Set("password", password) // 使用正确的密码 w = httptest.NewRecorder() req, _ = http.NewRequest("POST", "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.ServeHTTP(w, req) // 检查重定向 assert.Equal(t, 302, w.Code) location := w.Header().Get("Location") assert.Contains(t, location, "/authorize") } func TestOIDCConfiguration(t *testing.T) { r, _, err := setupTestRouter() if err != nil { t.Fatalf("Failed to setup test router: %v", err) } w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/.well-known/openid-configuration", nil) r.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) var response map[string]interface{} err = json.Unmarshal(w.Body.Bytes(), &response) assert.NoError(t, err) assert.Contains(t, response, "issuer") } func TestUserRegistration(t *testing.T) { r, db, err := setupTestRouter() if err != nil { t.Fatalf("Failed to setup test router: %v", err) } // 测试注册请求 form := url.Values{} form.Add("username", "newuser") form.Add("password", "newpass") form.Add("email", "newuser@example.com") w := httptest.NewRecorder() req, _ := http.NewRequest("POST", "/signup", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.ServeHTTP(w, req) // 检查重定向到登录页面 assert.Equal(t, 302, w.Code) assert.Contains(t, w.Header().Get("Location"), "/login") // 验证用户是否已创建 var user models.User err = db.Where("username = ?", "newuser").First(&user).Error assert.NoError(t, err) assert.Equal(t, "newuser@example.com", user.Email) } func TestTokenEndpoint(t *testing.T) { r, db, err := setupTestRouter() if err != nil { t.Fatalf("Failed to setup test router: %v", err) } // 创建测试客户端 redirectURIs, _ := json.Marshal([]string{"http://localhost:8080/callback"}) grantTypes, _ := json.Marshal([]string{"authorization_code"}) responseTypes, _ := json.Marshal([]string{"code"}) client := &models.Client{ ClientID: "test_client", ClientSecret: "test_secret", RedirectURIs: datatypes.JSON(redirectURIs), GrantTypes: datatypes.JSON(grantTypes), ResponseTypes: datatypes.JSON(responseTypes), TokenEndpointAuthMethod: "client_secret_post", } result := db.Create(client) assert.NoError(t, result.Error) // 创建测试用户 user := &models.User{ Username: "testuser", Email: "test@example.com", } result = db.Create(user) assert.NoError(t, result.Error) // 创建测试授权码 code := &models.AuthorizationCode{ Code: "test_code", ClientID: client.ClientID, UserID: user.ID, Scope: "openid profile email", RedirectURI: "http://localhost:8080/callback", ExpiresAt: time.Now().Add(10 * time.Minute), Used: false, } result = db.Create(code) assert.NoError(t, result.Error) // 测试令牌请求 form := url.Values{} form.Add("grant_type", "authorization_code") form.Add("code", code.Code) form.Add("client_id", client.ClientID) form.Add("client_secret", client.ClientSecret) form.Add("redirect_uri", code.RedirectURI) w := httptest.NewRecorder() req, _ := http.NewRequest("POST", "/token", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.ServeHTTP(w, req) if w.Code != 200 { t.Logf("Response body: %s", w.Body.String()) } assert.Equal(t, 200, w.Code) var response map[string]interface{} err = json.Unmarshal(w.Body.Bytes(), &response) assert.NoError(t, err) assert.Contains(t, response, "access_token") assert.Contains(t, response, "id_token") // 使用获取到的访问令牌测试用户信息端点 accessToken := response["access_token"].(string) w = httptest.NewRecorder() req, _ = http.NewRequest("GET", "/userinfo", nil) req.Header.Set("Authorization", "Bearer "+accessToken) r.ServeHTTP(w, req) if w.Code != 200 { t.Logf("Userinfo response body: %s", w.Body.String()) } assert.Equal(t, 200, w.Code) var userinfoResponse map[string]interface{} err = json.Unmarshal(w.Body.Bytes(), &userinfoResponse) assert.NoError(t, err) assert.Equal(t, "testuser", userinfoResponse["sub"]) assert.Equal(t, "test@example.com", userinfoResponse["email"]) } func TestUserinfoEndpoint(t *testing.T) { r, db, err := setupTestRouter() if err != nil { t.Fatalf("Failed to setup test router: %v", err) } // 创建测试用户 user := &models.User{ Username: "testuser", Email: "test@example.com", } result := db.Create(user) assert.NoError(t, result.Error) // 创建测试客户端 redirectURIs, _ := json.Marshal([]string{"http://localhost:8080/callback"}) grantTypes, _ := json.Marshal([]string{"authorization_code"}) responseTypes, _ := json.Marshal([]string{"code"}) client := &models.Client{ ClientID: "test_client", ClientSecret: "test_secret", RedirectURIs: datatypes.JSON(redirectURIs), GrantTypes: datatypes.JSON(grantTypes), ResponseTypes: datatypes.JSON(responseTypes), TokenEndpointAuthMethod: "client_secret_post", } result = db.Create(client) assert.NoError(t, result.Error) // 创建测试授权码 code := &models.AuthorizationCode{ Code: "test_code", ClientID: client.ClientID, UserID: user.ID, Scope: "openid profile email", RedirectURI: "http://localhost:8080/callback", ExpiresAt: time.Now().Add(10 * time.Minute), Used: false, } result = db.Create(code) assert.NoError(t, result.Error) // 测试令牌请求 form := url.Values{} form.Add("grant_type", "authorization_code") form.Add("code", code.Code) form.Add("client_id", client.ClientID) form.Add("client_secret", client.ClientSecret) form.Add("redirect_uri", code.RedirectURI) w := httptest.NewRecorder() req, _ := http.NewRequest("POST", "/token", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.ServeHTTP(w, req) if w.Code != 200 { t.Logf("Response body: %s", w.Body.String()) } assert.Equal(t, 200, w.Code) var tokenResponse map[string]interface{} err = json.Unmarshal(w.Body.Bytes(), &tokenResponse) assert.NoError(t, err) assert.Contains(t, tokenResponse, "access_token") // 测试用户信息请求 w = httptest.NewRecorder() req, _ = http.NewRequest("GET", "/userinfo", nil) req.Header.Set("Authorization", "Bearer "+tokenResponse["access_token"].(string)) r.ServeHTTP(w, req) if w.Code != 200 { t.Logf("Response body: %s", w.Body.String()) } assert.Equal(t, 200, w.Code) var response map[string]interface{} err = json.Unmarshal(w.Body.Bytes(), &response) assert.NoError(t, err) assert.Equal(t, "testuser", response["sub"]) assert.Equal(t, "test@example.com", response["email"]) } func TestClientRegistration(t *testing.T) { r, _, err := setupTestRouter() if err != nil { t.Fatalf("Failed to setup test router: %v", err) } // 测试客户端注册请求 registrationRequest := map[string]interface{}{ "client_name": "Test Client", "redirect_uris": []string{"http://localhost:8080/callback"}, "response_types": []string{"code"}, "grant_types": []string{"authorization_code"}, "token_endpoint_auth_method": "client_secret_basic", } jsonData, _ := json.Marshal(registrationRequest) w := httptest.NewRecorder() req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(jsonData)) req.Header.Set("Content-Type", "application/json") r.ServeHTTP(w, req) assert.Equal(t, 201, w.Code) var response map[string]interface{} err = json.Unmarshal(w.Body.Bytes(), &response) assert.NoError(t, err) assert.Contains(t, response, "client_id") assert.Contains(t, response, "client_secret") }