From a3f3cc17cf80f916302d91c160a5b1d187188eaf Mon Sep 17 00:00:00 2001 From: chang Date: Thu, 17 Apr 2025 01:25:46 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E4=BE=9D=E8=B5=96=E9=A1=B9?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96=20OAuth2=20=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=EF=BC=8C=E6=B7=BB=E5=8A=A0=20PKCE=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=EF=BC=8C=E5=A2=9E=E5=BC=BA=20OIDC=20=E5=A4=84=E7=90=86?= =?UTF-8?q?=E5=99=A8=EF=BC=8C=E6=96=B0=E5=A2=9E=E5=AE=A2=E6=88=B7=E7=AB=AF?= =?UTF-8?q?=E6=B3=A8=E5=86=8C=E5=92=8C=E4=BB=A4=E7=89=8C=E7=AE=A1=E7=90=86?= =?UTF-8?q?=E7=AB=AF=E7=82=B9=EF=BC=8C=E6=94=B9=E8=BF=9B=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93=E6=A8=A1=E5=9E=8B=E4=BB=A5=E6=94=AF=E6=8C=81=E6=96=B0?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 15 ++- go.sum | 22 ++++ handlers/oidc.go | 50 +++++---- handlers/registration.go | 77 +++++++++++++ handlers/token.go | 46 ++++++++ main.go | 21 +++- models/authorization_code.go | 17 +-- models/client.go | 27 +++-- oauth.db | 0 services/client.go | 211 +++++++++++++++++++++++++++++++++++ services/keys.go | 75 +++++++++++++ services/oauth.go | 133 +++++++++++++++++----- services/token.go | 114 +++++++++++++++++++ 13 files changed, 738 insertions(+), 70 deletions(-) create mode 100644 handlers/registration.go create mode 100644 handlers/token.go create mode 100644 oauth.db create mode 100644 services/client.go create mode 100644 services/keys.go create mode 100644 services/token.go diff --git a/go.mod b/go.mod index 29b328e..8fd44f0 100644 --- a/go.mod +++ b/go.mod @@ -8,14 +8,19 @@ require ( github.com/golang-jwt/jwt/v4 v4.5.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/sqlite v1.5.7 - gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde + gorm.io/gorm v1.25.11 ) require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect + gorm.io/datatypes v1.2.5 // indirect + gorm.io/driver/mysql v1.5.6 // indirect ) require ( @@ -40,10 +45,10 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect - golang.org/x/crypto v0.9.0 - golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.8.0 // indirect - golang.org/x/text v0.9.0 // indirect + golang.org/x/crypto v0.22.0 + golang.org/x/net v0.21.0 // indirect + golang.org/x/sys v0.19.0 // indirect + golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect ) diff --git a/go.sum b/go.sum index 764194a..1d76709 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= @@ -23,6 +25,9 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= @@ -31,6 +36,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= @@ -86,14 +93,22 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= @@ -105,8 +120,15 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/datatypes v1.2.5 h1:9UogU3jkydFVW1bIVVeoYsTpLRgwDVW3rHfJG6/Ek9I= +gorm.io/datatypes v1.2.5/go.mod h1:I5FUdlKpLb5PMqeMQhm30CQ6jXP8Rj89xkTeCSAaAD4= +gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= +gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I= gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4cOO2PZra2PFD7Mfeg= gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.11 h1:/Wfyg1B/je1hnDx3sMkX+gAlxrlZpn6X0BXRlwXlvHg= +gorm.io/gorm v1.25.11/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/handlers/oidc.go b/handlers/oidc.go index f72817c..9dccd11 100644 --- a/handlers/oidc.go +++ b/handlers/oidc.go @@ -23,28 +23,34 @@ type OIDCHandler struct { } type OIDCConfig struct { - Issuer string `json:"issuer"` - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` - UserinfoEndpoint string `json:"userinfo_endpoint"` - JwksURI string `json:"jwks_uri"` - ResponseTypesSupported []string `json:"response_types_supported"` - SubjectTypesSupported []string `json:"subject_types_supported"` - ScopesSupported []string `json:"scopes_supported"` - ClaimsSupported []string `json:"claims_supported"` + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint"` + JwksURI string `json:"jwks_uri"` + ResponseTypesSupported []string `json:"response_types_supported"` + SubjectTypesSupported []string `json:"subject_types_supported"` + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` + ScopesSupported []string `json:"scopes_supported"` + ClaimsSupported []string `json:"claims_supported"` + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` } func NewOIDCHandler(issuerURL string, oauthService *services.OAuthService, authService *services.AuthService) *OIDCHandler { config := &OIDCConfig{ - Issuer: issuerURL, - AuthorizationEndpoint: issuerURL + "/authorize", - TokenEndpoint: issuerURL + "/token", - UserinfoEndpoint: issuerURL + "/userinfo", - JwksURI: issuerURL + "/jwks", - ResponseTypesSupported: []string{"code"}, - SubjectTypesSupported: []string{"public"}, - ScopesSupported: []string{"openid", "profile", "email"}, - ClaimsSupported: []string{"sub", "name", "email", "email_verified"}, + Issuer: issuerURL, + AuthorizationEndpoint: issuerURL + "/authorize", + TokenEndpoint: issuerURL + "/token", + UserinfoEndpoint: issuerURL + "/userinfo", + JwksURI: issuerURL + "/jwks", + ResponseTypesSupported: []string{"code"}, + SubjectTypesSupported: []string{"public"}, + IDTokenSigningAlgValuesSupported: []string{"RS256"}, + ScopesSupported: []string{"openid", "profile", "email"}, + ClaimsSupported: []string{"sub", "iss", "aud", "exp", "iat", "auth_time", "nonce", "acr", "name", "email", "email_verified"}, + CodeChallengeMethodsSupported: []string{"plain", "S256"}, + TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"}, } return &OIDCHandler{ @@ -166,6 +172,10 @@ func (h *OIDCHandler) Userinfo(c *gin.Context) { // JWKS handles /jwks endpoint func (h *OIDCHandler) JWKS(c *gin.Context) { - // TODO: 实现 JWKS 密钥集获取逻辑 - c.JSON(http.StatusNotImplemented, gin.H{"message": "Not implemented yet"}) + jwks, err := h.oauthService.GetJWKS() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get JWKS"}) + return + } + c.JSON(http.StatusOK, jwks) } diff --git a/handlers/registration.go b/handlers/registration.go new file mode 100644 index 0000000..16abc8e --- /dev/null +++ b/handlers/registration.go @@ -0,0 +1,77 @@ +package handlers + +import ( + "net/http" + + "oidc-oauth2-server/services" + + "github.com/gin-gonic/gin" +) + +type RegistrationHandler struct { + clientService *services.ClientService +} + +func NewRegistrationHandler(clientService *services.ClientService) *RegistrationHandler { + return &RegistrationHandler{ + clientService: clientService, + } +} + +// Register 处理动态客户端注册 +func (h *RegistrationHandler) Register(c *gin.Context) { + var req services.ClientRegistrationRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request", "error_description": err.Error()}) + return + } + + client, err := h.clientService.RegisterClient(&req) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request", "error_description": err.Error()}) + return + } + + c.JSON(http.StatusCreated, client) +} + +// GetClient 获取客户端信息 +func (h *RegistrationHandler) GetClient(c *gin.Context) { + clientID := c.Param("client_id") + client, err := h.clientService.GetClient(clientID) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "not_found"}) + return + } + + c.JSON(http.StatusOK, client) +} + +// UpdateClient 更新客户端信息 +func (h *RegistrationHandler) UpdateClient(c *gin.Context) { + clientID := c.Param("client_id") + var req services.ClientRegistrationRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"}) + return + } + + client, err := h.clientService.UpdateClient(clientID, &req) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"}) + return + } + + c.JSON(http.StatusOK, client) +} + +// DeleteClient 删除客户端 +func (h *RegistrationHandler) DeleteClient(c *gin.Context) { + clientID := c.Param("client_id") + if err := h.clientService.DeleteClient(clientID); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "not_found"}) + return + } + + c.Status(http.StatusNoContent) +} diff --git a/handlers/token.go b/handlers/token.go new file mode 100644 index 0000000..78ac5cb --- /dev/null +++ b/handlers/token.go @@ -0,0 +1,46 @@ +package handlers + +import ( + "net/http" + + "oidc-oauth2-server/services" + + "github.com/gin-gonic/gin" +) + +type TokenHandler struct { + tokenService *services.TokenService +} + +func NewTokenHandler(tokenService *services.TokenService) *TokenHandler { + return &TokenHandler{ + tokenService: tokenService, + } +} + +// Revoke 处理令牌撤销请求 +func (h *TokenHandler) Revoke(c *gin.Context) { + token := c.PostForm("token") + tokenTypeHint := c.PostForm("token_type_hint") + + if err := h.tokenService.RevokeToken(token, tokenTypeHint); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"}) + return + } + + c.Status(http.StatusOK) +} + +// Introspect 处理令牌自省请求 +func (h *TokenHandler) Introspect(c *gin.Context) { + token := c.PostForm("token") + tokenTypeHint := c.PostForm("token_type_hint") + + result, err := h.tokenService.IntrospectToken(token, tokenTypeHint) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"}) + return + } + + c.JSON(http.StatusOK, result) +} diff --git a/main.go b/main.go index 995b835..10543cd 100644 --- a/main.go +++ b/main.go @@ -35,7 +35,12 @@ func main() { // 初始化服务 authService := services.NewAuthService(db) - oauthService := services.NewOAuthService(db, []byte(config.GlobalConfig.JWT.SigningKey)) + oauthService, err := services.NewOAuthService(db) + if err != nil { + log.Fatalf("Failed to initialize OAuth service: %v", err) + } + clientService := services.NewClientService(db) + tokenService := services.NewTokenService(db, oauthService.GetKeyManager()) // 设置 Gin 路由 r := gin.Default() @@ -44,7 +49,7 @@ func main() { r.LoadHTMLGlob("templates/*") // 设置 session 中间件 - store := cookie.NewStore([]byte("secret")) + store := cookie.NewStore([]byte(config.GlobalConfig.JWT.SigningKey)) r.Use(sessions.Sessions("oidc_session", store)) // 健康检查 @@ -57,6 +62,8 @@ func main() { // 创建处理器 authHandler := handlers.NewAuthHandler(authService) oidcHandler := handlers.NewOIDCHandler(config.GlobalConfig.OAuth.IssuerURL, oauthService, authService) + registrationHandler := handlers.NewRegistrationHandler(clientService) + tokenHandler := handlers.NewTokenHandler(tokenService) // 认证路由 r.GET("/login", authHandler.ShowLogin) @@ -69,6 +76,16 @@ func main() { r.GET("/userinfo", oidcHandler.Userinfo) r.GET("/jwks", oidcHandler.JWKS) + // 客户端注册端点 + r.POST("/register", registrationHandler.Register) + r.GET("/register/:client_id", registrationHandler.GetClient) + r.PUT("/register/:client_id", registrationHandler.UpdateClient) + r.DELETE("/register/:client_id", registrationHandler.DeleteClient) + + // 令牌管理端点 + r.POST("/revoke", tokenHandler.Revoke) + r.POST("/introspect", tokenHandler.Introspect) + // 启动服务器 addr := fmt.Sprintf("%s:%d", config.GlobalConfig.Server.Host, config.GlobalConfig.Server.Port) log.Printf("Starting server on %s", addr) diff --git a/models/authorization_code.go b/models/authorization_code.go index 4ed7997..5780e68 100644 --- a/models/authorization_code.go +++ b/models/authorization_code.go @@ -8,13 +8,16 @@ import ( type AuthorizationCode struct { gorm.Model - Code string `gorm:"uniqueIndex;not null"` - ClientID string `gorm:"not null"` - UserID uint `gorm:"not null"` - RedirectURI string `gorm:"not null"` - Scope string `gorm:"not null"` - ExpiresAt time.Time `gorm:"not null"` - Used bool `gorm:"default:false"` + Code string `gorm:"uniqueIndex;not null"` + ClientID string `gorm:"not null"` + UserID uint `gorm:"not null"` + RedirectURI string `gorm:"not null"` + Scope string `gorm:"not null"` + ExpiresAt time.Time `gorm:"not null"` + Used bool `gorm:"default:false"` + CodeChallenge string `gorm:"type:varchar(128)"` + CodeChallengeMethod string `gorm:"type:varchar(20)"` + Nonce string `gorm:"type:varchar(255)"` } func (ac *AuthorizationCode) TableName() string { diff --git a/models/client.go b/models/client.go index abea721..f6cf75c 100644 --- a/models/client.go +++ b/models/client.go @@ -3,19 +3,30 @@ package models import ( "time" + "gorm.io/datatypes" "gorm.io/gorm" ) type Client struct { gorm.Model - ClientID string `gorm:"uniqueIndex;not null"` - ClientSecret string `gorm:"not null"` - RedirectURIs []string `gorm:"type:json"` - GrantTypes []string `gorm:"type:json"` - Scopes []string `gorm:"type:json"` - IsActive bool `gorm:"default:true"` - CreatedAt time.Time - UpdatedAt time.Time + ClientID string `gorm:"uniqueIndex;not null"` + ClientSecret string `gorm:"not null"` + RedirectURIs datatypes.JSON `gorm:"type:json"` + TokenEndpointAuthMethod string `gorm:"not null"` + GrantTypes datatypes.JSON `gorm:"type:json"` + ResponseTypes datatypes.JSON `gorm:"type:json"` + ClientName string `gorm:"type:varchar(255)"` + ClientURI string `gorm:"type:varchar(255)"` + LogoURI string `gorm:"type:varchar(255)"` + Scopes datatypes.JSON `gorm:"type:json"` + Contacts datatypes.JSON `gorm:"type:json"` + TosURI string `gorm:"type:varchar(255)"` + PolicyURI string `gorm:"type:varchar(255)"` + SoftwareID string `gorm:"type:varchar(255)"` + SoftwareVersion string `gorm:"type:varchar(255)"` + IsActive bool `gorm:"default:true"` + CreatedAt time.Time + UpdatedAt time.Time } func (c *Client) TableName() string { diff --git a/oauth.db b/oauth.db new file mode 100644 index 0000000..e69de29 diff --git a/services/client.go b/services/client.go new file mode 100644 index 0000000..a8ff806 --- /dev/null +++ b/services/client.go @@ -0,0 +1,211 @@ +package services + +import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "time" + + "gorm.io/datatypes" + "gorm.io/gorm" + + "oidc-oauth2-server/models" +) + +type ClientService struct { + db *gorm.DB +} + +type ClientResponse struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret,omitempty"` + RegistrationAccessToken string `json:"registration_access_token,omitempty"` + RegistrationClientURI string `json:"registration_client_uri,omitempty"` + ClientIDIssuedAt time.Time `json:"client_id_issued_at"` + ClientSecretExpiresAt time.Time `json:"client_secret_expires_at"` + RedirectURIs []string `json:"redirect_uris"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + GrantTypes []string `json:"grant_types"` + ResponseTypes []string `json:"response_types"` + ClientName string `json:"client_name"` + ClientURI string `json:"client_uri"` + LogoURI string `json:"logo_uri"` + Scope string `json:"scope"` + Contacts []string `json:"contacts"` + TosURI string `json:"tos_uri"` + PolicyURI string `json:"policy_uri"` + SoftwareID string `json:"software_id"` + SoftwareVersion string `json:"software_version"` +} + +type ClientRegistrationRequest struct { + RedirectURIs []string `json:"redirect_uris" binding:"required"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + GrantTypes []string `json:"grant_types"` + ResponseTypes []string `json:"response_types"` + ClientName string `json:"client_name"` + ClientURI string `json:"client_uri"` + LogoURI string `json:"logo_uri"` + Scope string `json:"scope"` + Contacts []string `json:"contacts"` + TosURI string `json:"tos_uri"` + PolicyURI string `json:"policy_uri"` + SoftwareID string `json:"software_id"` + SoftwareVersion string `json:"software_version"` +} + +func NewClientService(db *gorm.DB) *ClientService { + return &ClientService{db: db} +} + +func (s *ClientService) RegisterClient(req *ClientRegistrationRequest) (*ClientResponse, error) { + // 生成客户端凭证 + clientID := generateSecureToken(32) + clientSecret := generateSecureToken(48) + + // 设置默认值 + if len(req.GrantTypes) == 0 { + req.GrantTypes = []string{"authorization_code"} + } + if len(req.ResponseTypes) == 0 { + req.ResponseTypes = []string{"code"} + } + if req.TokenEndpointAuthMethod == "" { + req.TokenEndpointAuthMethod = "client_secret_basic" + } + + // 将字符串数组转换为 JSON + redirectURIsJSON, _ := json.Marshal(req.RedirectURIs) + grantTypesJSON, _ := json.Marshal(req.GrantTypes) + responseTypesJSON, _ := json.Marshal(req.ResponseTypes) + scopesJSON, _ := json.Marshal([]string{req.Scope}) + contactsJSON, _ := json.Marshal(req.Contacts) + + // 创建客户端记录 + client := &models.Client{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURIs: datatypes.JSON(redirectURIsJSON), + TokenEndpointAuthMethod: req.TokenEndpointAuthMethod, + GrantTypes: datatypes.JSON(grantTypesJSON), + ResponseTypes: datatypes.JSON(responseTypesJSON), + ClientName: req.ClientName, + ClientURI: req.ClientURI, + LogoURI: req.LogoURI, + Scopes: datatypes.JSON(scopesJSON), + Contacts: datatypes.JSON(contactsJSON), + TosURI: req.TosURI, + PolicyURI: req.PolicyURI, + SoftwareID: req.SoftwareID, + SoftwareVersion: req.SoftwareVersion, + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := s.db.Create(client).Error; err != nil { + return nil, err + } + + // 构建响应 + return &ClientResponse{ + ClientID: clientID, + ClientSecret: clientSecret, + ClientIDIssuedAt: client.CreatedAt, + ClientSecretExpiresAt: time.Time{}, // 永不过期 + RedirectURIs: req.RedirectURIs, + TokenEndpointAuthMethod: req.TokenEndpointAuthMethod, + GrantTypes: req.GrantTypes, + ResponseTypes: req.ResponseTypes, + ClientName: req.ClientName, + ClientURI: req.ClientURI, + LogoURI: req.LogoURI, + Scope: req.Scope, + Contacts: req.Contacts, + TosURI: req.TosURI, + PolicyURI: req.PolicyURI, + SoftwareID: req.SoftwareID, + SoftwareVersion: req.SoftwareVersion, + }, nil +} + +func (s *ClientService) GetClient(clientID string) (*ClientResponse, error) { + var client models.Client + if err := s.db.Where("client_id = ? AND is_active = ?", clientID, true).First(&client).Error; err != nil { + return nil, err + } + + // 解析 JSON 字段 + var redirectURIs, grantTypes, responseTypes, scopes, contacts []string + json.Unmarshal(client.RedirectURIs, &redirectURIs) + json.Unmarshal(client.GrantTypes, &grantTypes) + json.Unmarshal(client.ResponseTypes, &responseTypes) + json.Unmarshal(client.Scopes, &scopes) + json.Unmarshal(client.Contacts, &contacts) + + return &ClientResponse{ + ClientID: client.ClientID, + ClientIDIssuedAt: client.CreatedAt, + RedirectURIs: redirectURIs, + TokenEndpointAuthMethod: client.TokenEndpointAuthMethod, + GrantTypes: grantTypes, + ResponseTypes: responseTypes, + ClientName: client.ClientName, + ClientURI: client.ClientURI, + LogoURI: client.LogoURI, + Scope: scopes[0], + Contacts: contacts, + TosURI: client.TosURI, + PolicyURI: client.PolicyURI, + SoftwareID: client.SoftwareID, + SoftwareVersion: client.SoftwareVersion, + }, nil +} + +func (s *ClientService) UpdateClient(clientID string, req *ClientRegistrationRequest) (*ClientResponse, error) { + var client models.Client + if err := s.db.Where("client_id = ? AND is_active = ?", clientID, true).First(&client).Error; err != nil { + return nil, err + } + + // 将字符串数组转换为 JSON + redirectURIsJSON, _ := json.Marshal(req.RedirectURIs) + grantTypesJSON, _ := json.Marshal(req.GrantTypes) + responseTypesJSON, _ := json.Marshal(req.ResponseTypes) + scopesJSON, _ := json.Marshal([]string{req.Scope}) + contactsJSON, _ := json.Marshal(req.Contacts) + + // 更新客户端信息 + client.RedirectURIs = datatypes.JSON(redirectURIsJSON) + client.GrantTypes = datatypes.JSON(grantTypesJSON) + client.ResponseTypes = datatypes.JSON(responseTypesJSON) + client.Scopes = datatypes.JSON(scopesJSON) + client.Contacts = datatypes.JSON(contactsJSON) + client.UpdatedAt = time.Now() + + if err := s.db.Save(&client).Error; err != nil { + return nil, err + } + + return s.GetClient(clientID) +} + +func (s *ClientService) DeleteClient(clientID string) error { + result := s.db.Model(&models.Client{}).Where("client_id = ?", clientID).Update("is_active", false) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return errors.New("client not found") + } + return nil +} + +func generateSecureToken(length int) string { + b := make([]byte, length) + if _, err := rand.Read(b); err != nil { + return "" + } + return base64.RawURLEncoding.EncodeToString(b) +} diff --git a/services/keys.go b/services/keys.go new file mode 100644 index 0000000..a353f63 --- /dev/null +++ b/services/keys.go @@ -0,0 +1,75 @@ +package services + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "fmt" + "math/big" +) + +type KeyManager struct { + privateKey *rsa.PrivateKey + publicKey *rsa.PublicKey + kid string +} + +type JSONWebKey struct { + Kty string `json:"kty"` + Kid string `json:"kid"` + Use string `json:"use"` + N string `json:"n"` + E string `json:"e"` + Alg string `json:"alg"` +} + +type JSONWebKeySet struct { + Keys []JSONWebKey `json:"keys"` +} + +func NewKeyManager() (*KeyManager, error) { + // 生成 RSA 密钥对 + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("failed to generate RSA key: %v", err) + } + + // 生成密钥 ID + kid := generateKeyID() + + return &KeyManager{ + privateKey: privateKey, + publicKey: &privateKey.PublicKey, + kid: kid, + }, nil +} + +func (km *KeyManager) GetJWKS() (*JSONWebKeySet, error) { + // 将公钥转换为 JWK 格式 + jwk := JSONWebKey{ + Kty: "RSA", + Kid: km.kid, + Use: "sig", + Alg: "RS256", + N: base64.RawURLEncoding.EncodeToString(km.publicKey.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(km.publicKey.E)).Bytes()), + } + + return &JSONWebKeySet{ + Keys: []JSONWebKey{jwk}, + }, nil +} + +func (km *KeyManager) GetPrivateKey() *rsa.PrivateKey { + return km.privateKey +} + +func (km *KeyManager) GetKID() string { + return km.kid +} + +func generateKeyID() string { + b := make([]byte, 16) + rand.Read(b) + return base64.RawURLEncoding.EncodeToString(b) +} diff --git a/services/oauth.go b/services/oauth.go index 617abfa..1964c3c 100644 --- a/services/oauth.go +++ b/services/oauth.go @@ -2,8 +2,11 @@ package services import ( "crypto/rand" + "crypto/sha256" "encoding/base64" + "encoding/json" "errors" + "fmt" "time" "github.com/golang-jwt/jwt/v4" @@ -13,17 +16,20 @@ import ( ) type OAuthService struct { - db *gorm.DB - jwtSecret []byte - tokenTTL time.Duration + db *gorm.DB + keyManager *KeyManager + tokenTTL time.Duration } type AuthorizeRequest struct { - ResponseType string - ClientID string - RedirectURI string - Scope string - State string + ResponseType string + ClientID string + RedirectURI string + Scope string + State string + CodeChallenge string + CodeChallengeMethod string + Nonce string } type TokenRequest struct { @@ -32,6 +38,7 @@ type TokenRequest struct { RedirectURI string ClientID string ClientSecret string + CodeVerifier string } type TokenResponse struct { @@ -42,12 +49,17 @@ type TokenResponse struct { RefreshToken string `json:"refresh_token,omitempty"` } -func NewOAuthService(db *gorm.DB, jwtSecret []byte) *OAuthService { - return &OAuthService{ - db: db, - jwtSecret: jwtSecret, - tokenTTL: time.Hour, +func NewOAuthService(db *gorm.DB) (*OAuthService, error) { + keyManager, err := NewKeyManager() + if err != nil { + return nil, err } + + return &OAuthService{ + db: db, + keyManager: keyManager, + tokenTTL: time.Hour, + }, nil } func (s *OAuthService) ValidateAuthorizeRequest(req *AuthorizeRequest) error { @@ -61,8 +73,13 @@ func (s *OAuthService) ValidateAuthorizeRequest(req *AuthorizeRequest) error { } // 验证重定向 URI + var redirectURIs []string + if err := json.Unmarshal(client.RedirectURIs, &redirectURIs); err != nil { + return errors.New("invalid redirect URIs format") + } + validRedirect := false - for _, uri := range client.RedirectURIs { + for _, uri := range redirectURIs { if uri == req.RedirectURI { validRedirect = true break @@ -72,6 +89,16 @@ func (s *OAuthService) ValidateAuthorizeRequest(req *AuthorizeRequest) error { return errors.New("invalid redirect URI") } + // 验证 PKCE + if req.CodeChallenge != "" { + if req.CodeChallengeMethod != "S256" && req.CodeChallengeMethod != "plain" { + return errors.New("invalid code challenge method") + } + if len(req.CodeChallenge) < 43 || len(req.CodeChallenge) > 128 { + return errors.New("invalid code challenge length") + } + } + return nil } @@ -84,13 +111,16 @@ func (s *OAuthService) GenerateAuthorizationCode(userID uint, req *AuthorizeRequ code := base64.RawURLEncoding.EncodeToString(b) authCode := &models.AuthorizationCode{ - Code: code, - ClientID: req.ClientID, - RedirectURI: req.RedirectURI, - UserID: userID, - Scope: req.Scope, - ExpiresAt: time.Now().Add(10 * time.Minute), - Used: false, + Code: code, + ClientID: req.ClientID, + RedirectURI: req.RedirectURI, + UserID: userID, + Scope: req.Scope, + ExpiresAt: time.Now().Add(10 * time.Minute), + Used: false, + CodeChallenge: req.CodeChallenge, + CodeChallengeMethod: req.CodeChallengeMethod, + Nonce: req.Nonce, } // 保存授权码到数据库 @@ -119,6 +149,17 @@ func (s *OAuthService) ExchangeToken(req *TokenRequest) (*TokenResponse, error) return nil, errors.New("redirect URI mismatch") } + // 验证 PKCE + if authCode.CodeChallenge != "" { + if req.CodeVerifier == "" { + return nil, errors.New("code verifier required") + } + + if err := validatePKCE(authCode.CodeChallenge, authCode.CodeChallengeMethod, req.CodeVerifier); err != nil { + return nil, err + } + } + // 验证客户端 client := &models.Client{} if err := s.db.Where("client_id = ? AND client_secret = ?", @@ -139,7 +180,7 @@ func (s *OAuthService) ExchangeToken(req *TokenRequest) (*TokenResponse, error) } // 生成 ID 令牌 - idToken, err := s.generateIDToken(user, client, authCode.Scope) + idToken, err := s.generateIDToken(user, client, authCode.Scope, "") if err != nil { return nil, err } @@ -167,20 +208,56 @@ func (s *OAuthService) generateAccessToken(user *models.User, client *models.Cli } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString(s.jwtSecret) + return token.SignedString(s.keyManager.GetPrivateKey()) } -func (s *OAuthService) generateIDToken(user *models.User, client *models.Client, scope string) (string, error) { +func (s *OAuthService) generateIDToken(user *models.User, client *models.Client, scope string, nonce string) (string, error) { now := time.Now() claims := jwt.MapClaims{ - "sub": user.ID, + "iss": client.ClientID, + "sub": fmt.Sprintf("%d", user.ID), + "aud": client.ClientID, "exp": now.Add(s.tokenTTL).Unix(), "iat": now.Unix(), - "iss": client.ClientID, + "auth_time": now.Unix(), + "nonce": nonce, + "acr": "1", "email": user.Email, "email_verified": true, + "name": user.Username, } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString(s.jwtSecret) + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = s.keyManager.GetKID() + + return token.SignedString(s.keyManager.GetPrivateKey()) +} + +func validatePKCE(challenge, method, verifier string) error { + if len(verifier) < 43 || len(verifier) > 128 { + return errors.New("invalid code verifier length") + } + + var computedChallenge string + if method == "S256" { + h := sha256.New() + h.Write([]byte(verifier)) + computedChallenge = base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + } else { + computedChallenge = verifier + } + + if computedChallenge != challenge { + return errors.New("code verifier does not match challenge") + } + + return nil +} + +func (s *OAuthService) GetJWKS() (*JSONWebKeySet, error) { + return s.keyManager.GetJWKS() +} + +func (s *OAuthService) GetKeyManager() *KeyManager { + return s.keyManager } diff --git a/services/token.go b/services/token.go new file mode 100644 index 0000000..2e5b853 --- /dev/null +++ b/services/token.go @@ -0,0 +1,114 @@ +package services + +import ( + "errors" + "time" + + "github.com/golang-jwt/jwt/v4" + "gorm.io/gorm" +) + +type TokenService struct { + db *gorm.DB + keyManager *KeyManager +} + +type TokenInfo struct { + Active bool `json:"active"` + Scope string `json:"scope,omitempty"` + ClientID string `json:"client_id,omitempty"` + Username string `json:"username,omitempty"` + TokenType string `json:"token_type,omitempty"` + Exp int64 `json:"exp,omitempty"` + Iat int64 `json:"iat,omitempty"` + Nbf int64 `json:"nbf,omitempty"` + Sub string `json:"sub,omitempty"` + Aud string `json:"aud,omitempty"` + Iss string `json:"iss,omitempty"` + Jti string `json:"jti,omitempty"` +} + +type RevokedToken struct { + gorm.Model + Token string `gorm:"uniqueIndex;not null"` + ExpiresAt time.Time `gorm:"not null"` +} + +func NewTokenService(db *gorm.DB, keyManager *KeyManager) *TokenService { + return &TokenService{ + db: db, + keyManager: keyManager, + } +} + +func (s *TokenService) RevokeToken(token, tokenTypeHint string) error { + // 验证令牌 + claims, err := s.parseToken(token) + if err != nil { + return err + } + + // 保存到撤销列表 + revokedToken := &RevokedToken{ + Token: token, + ExpiresAt: time.Unix(claims["exp"].(int64), 0), + } + + return s.db.Create(revokedToken).Error +} + +func (s *TokenService) IntrospectToken(token, tokenTypeHint string) (*TokenInfo, error) { + // 检查令牌是否被撤销 + var revokedToken RevokedToken + if err := s.db.Where("token = ?", token).First(&revokedToken).Error; err == nil { + return &TokenInfo{Active: false}, nil + } + + // 解析令牌 + claims, err := s.parseToken(token) + if err != nil { + return &TokenInfo{Active: false}, nil + } + + // 检查令牌是否过期 + exp := time.Unix(claims["exp"].(int64), 0) + if time.Now().After(exp) { + return &TokenInfo{Active: false}, nil + } + + // 构建令牌信息 + info := &TokenInfo{ + Active: true, + Scope: claims["scope"].(string), + ClientID: claims["iss"].(string), + TokenType: "Bearer", + Exp: claims["exp"].(int64), + Iat: claims["iat"].(int64), + Sub: claims["sub"].(string), + } + + return info, nil +} + +func (s *TokenService) parseToken(token string) (jwt.MapClaims, error) { + parsedToken, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, errors.New("unexpected signing method") + } + return s.keyManager.GetPrivateKey().Public(), nil + }) + + if err != nil { + return nil, err + } + + if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok && parsedToken.Valid { + return claims, nil + } + + return nil, errors.New("invalid token") +} + +func (s *TokenService) AutoMigrate() error { + return s.db.AutoMigrate(&RevokedToken{}) +}