diff --git a/go.mod b/go.mod index 374dfd4..53452ac 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/gin-contrib/sessions v0.0.5 github.com/gin-gonic/gin v1.9.1 github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/stretchr/testify v1.10.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/datatypes v1.2.5 gorm.io/driver/sqlite v1.5.7 @@ -16,12 +17,14 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect gorm.io/driver/mysql v1.5.6 // indirect ) diff --git a/go.sum b/go.sum index e897cf2..a36d3f3 100644 --- a/go.sum +++ b/go.sum @@ -96,8 +96,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= diff --git a/handlers/oidc.go b/handlers/oidc.go index 9dccd11..d35f719 100644 --- a/handlers/oidc.go +++ b/handlers/oidc.go @@ -2,7 +2,6 @@ package handlers import ( "crypto/rsa" - "fmt" "net/http" "net/url" "strings" @@ -154,7 +153,7 @@ func (h *OIDCHandler) Userinfo(c *gin.Context) { // 准备返回的声明 claims := gin.H{ - "sub": fmt.Sprintf("%d", user.ID), + "sub": user.Username, } // 根据授权范围添加相应的声明 @@ -164,6 +163,7 @@ func (h *OIDCHandler) Userinfo(c *gin.Context) { claims["name"] = user.Username case "email": claims["email"] = user.Email + claims["email_verified"] = true } } diff --git a/main.go b/main.go index 47cbd33..1c10a4f 100644 --- a/main.go +++ b/main.go @@ -87,7 +87,7 @@ func main() { r.GET("/.well-known/openid-configuration", oidcHandler.OpenIDConfiguration) r.GET("/authorize", oidcHandler.Authorize) r.POST("/token", oidcHandler.Token) - r.GET("/userinfo", oidcHandler.Userinfo) + r.GET("/userinfo", middleware.BearerAuth(oauthService.GetKeyManager(), db), oidcHandler.Userinfo) r.GET("/jwks", oidcHandler.JWKS) // 客户端注册端点 diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..cafbeb3 --- /dev/null +++ b/main_test.go @@ -0,0 +1,425 @@ +package main + +import ( + "bytes" + "encoding/json" + "html/template" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/gin-contrib/sessions" + "github.com/gin-contrib/sessions/cookie" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" + "gorm.io/datatypes" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + + "oidc-oauth2-server/handlers" + "oidc-oauth2-server/middleware" + "oidc-oauth2-server/models" + "oidc-oauth2-server/services" +) + +func setupTestRouter() (*gin.Engine, *gorm.DB, error) { + // 设置测试模式 + gin.SetMode(gin.TestMode) + + // 初始化测试数据库 + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + return nil, nil, err + } + + // 运行数据库迁移 + if err := models.AutoMigrate(db); err != nil { + return nil, nil, err + } + + // 初始化服务 + authService := services.NewAuthService(db) + oauthService, err := services.NewOAuthService(db) + if err != nil { + return nil, nil, err + } + clientService := services.NewClientService(db) + tokenService := services.NewTokenService(db, oauthService.GetKeyManager()) + + // 设置路由 + r := gin.New() + r.Use(gin.Recovery()) + + // 设置模板 + r.SetFuncMap(template.FuncMap{ + "subtract": func(a, b int) int { return a - b }, + "add": func(a, b int) int { return a + b }, + }) + r.LoadHTMLGlob("templates/*") + + // 设置会话存储 + store := cookie.NewStore([]byte("test_secret")) + r.Use(sessions.Sessions("oidc_session", store)) + + // 设置路由处理器 + r.GET("/health", func(c *gin.Context) { + c.JSON(200, gin.H{"status": "ok"}) + }) + + // 设置认证相关路由 + authHandler := handlers.NewAuthHandler(authService) + r.GET("/login", authHandler.ShowLogin) + r.POST("/login", authHandler.HandleLogin) + r.GET("/signup", authHandler.ShowSignup) + r.POST("/signup", authHandler.HandleSignup) + + // 设置 OIDC 相关路由 + oidcHandler := handlers.NewOIDCHandler("http://localhost:8080", oauthService, authService) + r.GET("/.well-known/openid-configuration", oidcHandler.OpenIDConfiguration) + r.GET("/authorize", oidcHandler.Authorize) + r.POST("/token", oidcHandler.Token) + r.GET("/userinfo", middleware.BearerAuth(oauthService.GetKeyManager(), db), oidcHandler.Userinfo) + r.GET("/jwks", oidcHandler.JWKS) + + // 设置客户端注册路由 + registrationHandler := handlers.NewRegistrationHandler(clientService) + r.POST("/register", registrationHandler.Register) + + // 设置令牌管理路由 + tokenHandler := handlers.NewTokenHandler(tokenService) + r.POST("/revoke", tokenHandler.Revoke) + r.POST("/introspect", tokenHandler.Introspect) + + return r, db, nil +} + +func TestHealthCheck(t *testing.T) { + r, _, err := setupTestRouter() + if err != nil { + t.Fatalf("Failed to setup test router: %v", err) + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/health", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + + var response map[string]string + err = json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "ok", response["status"]) +} + +func TestLoginPage(t *testing.T) { + r, _, err := setupTestRouter() + if err != nil { + t.Fatalf("Failed to setup test router: %v", err) + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/login", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) +} + +func TestLoginAuthentication(t *testing.T) { + r, db, err := setupTestRouter() + if err != nil { + t.Fatalf("Failed to setup test router: %v", err) + } + + // 创建测试用户 + password := "testpass" + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + assert.NoError(t, err) + + testUser := &models.User{ + Username: "testuser", + Password: string(hashedPassword), + Email: "test@example.com", + } + result := db.Create(testUser) + assert.NoError(t, result.Error) + + // 测试登录请求 - 使用表单数据 + form := url.Values{} + form.Add("username", "testuser") + form.Add("password", "wrongpass") + form.Add("client_id", "test_client") + form.Add("response_type", "code") + form.Add("scope", "openid profile") + form.Add("redirect_uri", "http://localhost:8080/callback") + form.Add("state", "test_state") + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r.ServeHTTP(w, req) + + // 检查是否返回错误页面 + assert.Equal(t, 200, w.Code) + assert.Contains(t, w.Body.String(), "用户名或密码错误") + + // 测试成功登录 + form.Set("password", password) // 使用正确的密码 + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/login", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r.ServeHTTP(w, req) + + // 检查重定向 + assert.Equal(t, 302, w.Code) + location := w.Header().Get("Location") + assert.Contains(t, location, "/authorize") +} + +func TestOIDCConfiguration(t *testing.T) { + r, _, err := setupTestRouter() + if err != nil { + t.Fatalf("Failed to setup test router: %v", err) + } + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/.well-known/openid-configuration", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + + var response map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Contains(t, response, "issuer") +} + +func TestUserRegistration(t *testing.T) { + r, db, err := setupTestRouter() + if err != nil { + t.Fatalf("Failed to setup test router: %v", err) + } + + // 测试注册请求 + form := url.Values{} + form.Add("username", "newuser") + form.Add("password", "newpass") + form.Add("email", "newuser@example.com") + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/signup", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r.ServeHTTP(w, req) + + // 检查重定向到登录页面 + assert.Equal(t, 302, w.Code) + assert.Contains(t, w.Header().Get("Location"), "/login") + + // 验证用户是否已创建 + var user models.User + err = db.Where("username = ?", "newuser").First(&user).Error + assert.NoError(t, err) + assert.Equal(t, "newuser@example.com", user.Email) +} + +func TestTokenEndpoint(t *testing.T) { + r, db, err := setupTestRouter() + if err != nil { + t.Fatalf("Failed to setup test router: %v", err) + } + + // 创建测试客户端 + redirectURIs, _ := json.Marshal([]string{"http://localhost:8080/callback"}) + grantTypes, _ := json.Marshal([]string{"authorization_code"}) + responseTypes, _ := json.Marshal([]string{"code"}) + client := &models.Client{ + ClientID: "test_client", + ClientSecret: "test_secret", + RedirectURIs: datatypes.JSON(redirectURIs), + GrantTypes: datatypes.JSON(grantTypes), + ResponseTypes: datatypes.JSON(responseTypes), + TokenEndpointAuthMethod: "client_secret_post", + } + result := db.Create(client) + assert.NoError(t, result.Error) + + // 创建测试用户 + user := &models.User{ + Username: "testuser", + Email: "test@example.com", + } + result = db.Create(user) + assert.NoError(t, result.Error) + + // 创建测试授权码 + code := &models.AuthorizationCode{ + Code: "test_code", + ClientID: client.ClientID, + UserID: user.ID, + Scope: "openid profile email", + RedirectURI: "http://localhost:8080/callback", + ExpiresAt: time.Now().Add(10 * time.Minute), + Used: false, + } + result = db.Create(code) + assert.NoError(t, result.Error) + + // 测试令牌请求 + form := url.Values{} + form.Add("grant_type", "authorization_code") + form.Add("code", code.Code) + form.Add("client_id", client.ClientID) + form.Add("client_secret", client.ClientSecret) + form.Add("redirect_uri", code.RedirectURI) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Logf("Response body: %s", w.Body.String()) + } + assert.Equal(t, 200, w.Code) + + var response map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Contains(t, response, "access_token") + assert.Contains(t, response, "id_token") + + // 使用获取到的访问令牌测试用户信息端点 + accessToken := response["access_token"].(string) + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/userinfo", nil) + req.Header.Set("Authorization", "Bearer "+accessToken) + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Logf("Userinfo response body: %s", w.Body.String()) + } + assert.Equal(t, 200, w.Code) + + var userinfoResponse map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &userinfoResponse) + assert.NoError(t, err) + assert.Equal(t, "testuser", userinfoResponse["sub"]) + assert.Equal(t, "test@example.com", userinfoResponse["email"]) +} + +func TestUserinfoEndpoint(t *testing.T) { + r, db, err := setupTestRouter() + if err != nil { + t.Fatalf("Failed to setup test router: %v", err) + } + + // 创建测试用户 + user := &models.User{ + Username: "testuser", + Email: "test@example.com", + } + result := db.Create(user) + assert.NoError(t, result.Error) + + // 创建测试客户端 + redirectURIs, _ := json.Marshal([]string{"http://localhost:8080/callback"}) + grantTypes, _ := json.Marshal([]string{"authorization_code"}) + responseTypes, _ := json.Marshal([]string{"code"}) + client := &models.Client{ + ClientID: "test_client", + ClientSecret: "test_secret", + RedirectURIs: datatypes.JSON(redirectURIs), + GrantTypes: datatypes.JSON(grantTypes), + ResponseTypes: datatypes.JSON(responseTypes), + TokenEndpointAuthMethod: "client_secret_post", + } + result = db.Create(client) + assert.NoError(t, result.Error) + + // 创建测试授权码 + code := &models.AuthorizationCode{ + Code: "test_code", + ClientID: client.ClientID, + UserID: user.ID, + Scope: "openid profile email", + RedirectURI: "http://localhost:8080/callback", + ExpiresAt: time.Now().Add(10 * time.Minute), + Used: false, + } + result = db.Create(code) + assert.NoError(t, result.Error) + + // 测试令牌请求 + form := url.Values{} + form.Add("grant_type", "authorization_code") + form.Add("code", code.Code) + form.Add("client_id", client.ClientID) + form.Add("client_secret", client.ClientSecret) + form.Add("redirect_uri", code.RedirectURI) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Logf("Response body: %s", w.Body.String()) + } + assert.Equal(t, 200, w.Code) + + var tokenResponse map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &tokenResponse) + assert.NoError(t, err) + assert.Contains(t, tokenResponse, "access_token") + + // 测试用户信息请求 + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/userinfo", nil) + req.Header.Set("Authorization", "Bearer "+tokenResponse["access_token"].(string)) + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Logf("Response body: %s", w.Body.String()) + } + assert.Equal(t, 200, w.Code) + + var response map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "testuser", response["sub"]) + assert.Equal(t, "test@example.com", response["email"]) +} + +func TestClientRegistration(t *testing.T) { + r, _, err := setupTestRouter() + if err != nil { + t.Fatalf("Failed to setup test router: %v", err) + } + + // 测试客户端注册请求 + registrationRequest := map[string]interface{}{ + "client_name": "Test Client", + "redirect_uris": []string{"http://localhost:8080/callback"}, + "response_types": []string{"code"}, + "grant_types": []string{"authorization_code"}, + "token_endpoint_auth_method": "client_secret_basic", + } + + jsonData, _ := json.Marshal(registrationRequest) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/register", bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + + assert.Equal(t, 201, w.Code) + + var response map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Contains(t, response, "client_id") + assert.Contains(t, response, "client_secret") +} diff --git a/middleware/auth.go b/middleware/auth.go index 6eddd37..10adf6f 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -2,13 +2,18 @@ package middleware import ( "net/http" + "strconv" "strings" + "oidc-oauth2-server/models" + "oidc-oauth2-server/services" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v4" + "gorm.io/gorm" ) -func BearerAuth(jwtSecret []byte) gin.HandlerFunc { +func BearerAuth(keyManager *services.KeyManager, db *gorm.DB) gin.HandlerFunc { return func(c *gin.Context) { authHeader := c.GetHeader("Authorization") if authHeader == "" { @@ -30,10 +35,10 @@ func BearerAuth(jwtSecret []byte) gin.HandlerFunc { // 解析和验证令牌 token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { // 验证签名算法 - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, jwt.ErrSignatureInvalid } - return jwtSecret, nil + return keyManager.GetPrivateKey().Public(), nil }) if err != nil { @@ -50,8 +55,18 @@ func BearerAuth(jwtSecret []byte) gin.HandlerFunc { // 将令牌中的声明存储在上下文中 if claims, ok := token.Claims.(jwt.MapClaims); ok { - c.Set("user_id", uint(claims["sub"].(float64))) - c.Set("scope", claims["scope"].(string)) + if sub, ok := claims["sub"].(string); ok { + if userID, err := strconv.ParseUint(sub, 10, 32); err == nil { + var user models.User + if err := db.Where("id = ?", uint(userID)).First(&user).Error; err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"}) + c.Abort() + return + } + c.Set("user_id", user.ID) + c.Set("scope", claims["scope"].(string)) + } + } } c.Next() diff --git a/models/access_token.go b/models/access_token.go new file mode 100644 index 0000000..0163879 --- /dev/null +++ b/models/access_token.go @@ -0,0 +1,22 @@ +package models + +import ( + "time" + + "gorm.io/gorm" +) + +// AccessToken 表示 OAuth2 访问令牌 +type AccessToken struct { + gorm.Model + Token string `gorm:"uniqueIndex;not null"` + UserID uint `gorm:"not null"` + ClientID string `gorm:"not null"` + Scope string `gorm:"type:text"` + ExpiresAt time.Time `gorm:"not null"` + IsRevoked bool `gorm:"default:false"` +} + +func (t *AccessToken) TableName() string { + return "oauth_access_tokens" +} diff --git a/models/migration.go b/models/migration.go index 9e14ca9..08a7515 100644 --- a/models/migration.go +++ b/models/migration.go @@ -11,5 +11,6 @@ func AutoMigrate(db *gorm.DB) error { &Client{}, &AuthorizationCode{}, &Admin{}, + &AccessToken{}, ) } diff --git a/oauth.db b/oauth.db deleted file mode 100644 index e69de29..0000000 diff --git a/services/oauth.go b/services/oauth.go index 1964c3c..8c19dfa 100644 --- a/services/oauth.go +++ b/services/oauth.go @@ -174,7 +174,7 @@ func (s *OAuthService) ExchangeToken(req *TokenRequest) (*TokenResponse, error) } // 生成访问令牌 - accessToken, err := s.generateAccessToken(user, client, authCode.Scope) + accessToken, err := s.GenerateAccessToken(user, client, authCode.Scope) if err != nil { return nil, err } @@ -197,17 +197,18 @@ func (s *OAuthService) ExchangeToken(req *TokenRequest) (*TokenResponse, error) }, nil } -func (s *OAuthService) generateAccessToken(user *models.User, client *models.Client, scope string) (string, error) { +func (s *OAuthService) GenerateAccessToken(user *models.User, client *models.Client, scope string) (string, error) { now := time.Now() claims := jwt.MapClaims{ - "sub": user.ID, + "sub": fmt.Sprintf("%d", user.ID), + "iss": client.ClientID, + "aud": client.ClientID, "exp": now.Add(s.tokenTTL).Unix(), "iat": now.Unix(), - "iss": client.ClientID, "scope": scope, } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) return token.SignedString(s.keyManager.GetPrivateKey()) }