package handlers

import (
	"crypto/rsa"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"strings"

	"github.com/gin-contrib/sessions"
	"github.com/gin-gonic/gin"

	"oidc-oauth2-server/models"
	"oidc-oauth2-server/services"
)

// OIDCHandler serves the OpenID Connect / OAuth 2.0 provider endpoints:
// discovery, authorization, token, userinfo and JWKS.
//
// NOTE(review): NewOIDCHandler does not populate privateKey or publicKey,
// and JWKS does not use them yet.
type OIDCHandler struct {
	privateKey   *rsa.PrivateKey
	publicKey    *rsa.PublicKey
	config       *OIDCConfig
	oauthService *services.OAuthService
	authService  *services.AuthService
}

// OIDCConfig is the provider metadata served at
// /.well-known/openid-configuration (OpenID Connect Discovery 1.0).
//
// TODO(review): Discovery 1.0 lists id_token_signing_alg_values_supported as
// REQUIRED; it is not advertised here. Add it once the ID-token signing
// algorithm used by the token service is confirmed.
type OIDCConfig struct {
	Issuer                            string   `json:"issuer"`
	AuthorizationEndpoint             string   `json:"authorization_endpoint"`
	TokenEndpoint                     string   `json:"token_endpoint"`
	UserinfoEndpoint                  string   `json:"userinfo_endpoint"`
	JwksURI                           string   `json:"jwks_uri"`
	ResponseTypesSupported            []string `json:"response_types_supported"`
	SubjectTypesSupported             []string `json:"subject_types_supported"`
	ScopesSupported                   []string `json:"scopes_supported"`
	ClaimsSupported                   []string `json:"claims_supported"`
	GrantTypesSupported               []string `json:"grant_types_supported"`
	TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`
}

// NewOIDCHandler builds an OIDCHandler whose discovery metadata derives every
// endpoint URL from issuerURL.
//
// The issuer is advertised exactly as given; a trailing slash is trimmed only
// when building endpoint URLs so they never contain "//". GrantTypesSupported
// and TokenEndpointAuthMethodsSupported are listed explicitly because their
// Discovery defaults ("implicit", "client_secret_basic"-only) would not match
// what Token accepts.
func NewOIDCHandler(issuerURL string, oauthService *services.OAuthService, authService *services.AuthService) *OIDCHandler {
	base := strings.TrimSuffix(issuerURL, "/")
	config := &OIDCConfig{
		Issuer:                            issuerURL,
		AuthorizationEndpoint:             base + "/authorize",
		TokenEndpoint:                     base + "/token",
		UserinfoEndpoint:                  base + "/userinfo",
		JwksURI:                           base + "/jwks",
		ResponseTypesSupported:            []string{"code"},
		SubjectTypesSupported:             []string{"public"},
		ScopesSupported:                   []string{"openid", "profile", "email"},
		ClaimsSupported:                   []string{"sub", "name", "email", "email_verified"},
		GrantTypesSupported:               []string{"authorization_code"},
		TokenEndpointAuthMethodsSupported: []string{"client_secret_post", "client_secret_basic"},
	}

	return &OIDCHandler{
		config:       config,
		oauthService: oauthService,
		authService:  authService,
	}
}

// OpenIDConfiguration handles the /.well-known/openid-configuration endpoint
// by returning the provider metadata as JSON.
func (h *OIDCHandler) OpenIDConfiguration(c *gin.Context) {
	c.JSON(http.StatusOK, h.config)
}

// Authorize handles the /authorize endpoint (authorization code flow).
//
// Users without a "user_id" in their session are redirected to /login with
// the original query string preserved. For a logged-in user the request is
// validated, an authorization code is issued, and the browser is redirected
// to redirect_uri with "code" (and "state" when the client supplied one).
//
// NOTE(review): the "nonce" parameter is not read here, so it cannot be
// echoed in an ID token — confirm whether services.AuthorizeRequest should
// carry it.
func (h *OIDCHandler) Authorize(c *gin.Context) {
	// Check whether the user is already logged in.
	session := sessions.Default(c)
	rawUserID := session.Get("user_id")
	if rawUserID == nil {
		// Not logged in: send the user to the login page, carrying the
		// authorization request parameters along.
		c.Redirect(http.StatusFound, "/login?"+c.Request.URL.Query().Encode())
		return
	}
	userID, ok := rawUserID.(uint)
	if !ok {
		// A bare type assertion would panic on an unexpected session value.
		c.JSON(http.StatusInternalServerError, gin.H{"error": "invalid session"})
		return
	}

	req := &services.AuthorizeRequest{
		ResponseType: c.Query("response_type"),
		ClientID:     c.Query("client_id"),
		RedirectURI:  c.Query("redirect_uri"),
		Scope:        c.Query("scope"),
		State:        c.Query("state"),
	}

	if err := h.oauthService.ValidateAuthorizeRequest(req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Parse the redirect URI before issuing a code: url.Parse returns a nil
	// URL on error, and failing here avoids minting a code nobody receives.
	redirectURL, err := url.Parse(req.RedirectURI)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid redirect_uri"})
		return
	}

	authCode, err := h.oauthService.GenerateAuthorizationCode(userID, req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization code"})
		return
	}

	// Build the redirect URL, keeping any query parameters it already has.
	q := redirectURL.Query()
	q.Set("code", authCode.Code)
	if req.State != "" {
		q.Set("state", req.State)
	}
	redirectURL.RawQuery = q.Encode()

	c.Redirect(http.StatusFound, redirectURL.String())
}

// Token handles the /token endpoint.
//
// Only grant_type=authorization_code is accepted. Client credentials may be
// sent as form fields (client_secret_post) or via HTTP Basic auth
// (client_secret_basic). The exchange itself is delegated to
// OAuthService.ExchangeToken, and its result is returned as JSON.
func (h *OIDCHandler) Token(c *gin.Context) {
	// RFC 6749 §5.1: responses carrying tokens must not be cached.
	c.Header("Cache-Control", "no-store")
	c.Header("Pragma", "no-cache")

	grantType := c.PostForm("grant_type")
	if grantType != "authorization_code" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_grant_type"})
		return
	}

	clientID, clientSecret, err := clientCredentials(c)
	if err != nil {
		// RFC 6749 §5.2: a failed Authorization-header client authentication
		// is answered with 401 and a WWW-Authenticate challenge.
		c.Header("WWW-Authenticate", `Basic realm="token"`)
		c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_client"})
		return
	}

	req := &services.TokenRequest{
		GrantType:    grantType,
		Code:         c.PostForm("code"),
		RedirectURI:  c.PostForm("redirect_uri"),
		ClientID:     clientID,
		ClientSecret: clientSecret,
	}

	tokenResponse, err := h.oauthService.ExchangeToken(req)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, tokenResponse)
}

// clientCredentials returns the client_id and client_secret presented to the
// token endpoint.
//
// Without an HTTP Basic Authorization header, the form fields client_id and
// client_secret are returned unchanged. With one, the Basic user and password
// are form-urldecoded (RFC 6749 §2.3.1 requires clients to encode them before
// base64) and take precedence. An error is returned if decoding fails, or if
// the body names a different client_id.
func clientCredentials(c *gin.Context) (clientID, clientSecret string, err error) {
	user, pass, ok := c.Request.BasicAuth()
	if !ok {
		return c.PostForm("client_id"), c.PostForm("client_secret"), nil
	}
	if clientID, err = url.QueryUnescape(user); err != nil {
		return "", "", fmt.Errorf("decoding basic auth client_id: %w", err)
	}
	if clientSecret, err = url.QueryUnescape(pass); err != nil {
		return "", "", fmt.Errorf("decoding basic auth client_secret: %w", err)
	}
	if formID := c.PostForm("client_id"); formID != "" && formID != clientID {
		return "", "", errors.New("client_id in body does not match Authorization header")
	}
	return clientID, clientSecret, nil
}

// Userinfo handles the /userinfo endpoint.
//
// It reads "user_id" and "scope" from the gin context. These are presumably
// set by bearer-token middleware outside this file (TODO confirm). It loads
// the user and returns "sub" plus the claims unlocked by the granted scopes:
// "profile" adds "name" (the username) and "email" adds "email".
func (h *OIDCHandler) Userinfo(c *gin.Context) {
	// Get the user ID associated with the access token.
	userID := c.GetUint("user_id")
	if userID == 0 {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_token"})
		return
	}

	// Load the user record.
	var user models.User
	if err := h.authService.GetUserByID(userID, &user); err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "user_not_found"})
		return
	}

	claims := gin.H{
		"sub": fmt.Sprintf("%d", user.ID),
	}

	// Add claims according to the granted scopes. strings.Fields tolerates
	// repeated or surrounding whitespace in the space-delimited scope string.
	for _, s := range strings.Fields(c.GetString("scope")) {
		switch s {
		case "profile":
			claims["name"] = user.Username
		case "email":
			// NOTE(review): "email_verified" is advertised in
			// claims_supported but is never returned here.
			claims["email"] = user.Email
		}
	}

	c.JSON(http.StatusOK, claims)
}

// JWKS handles the /jwks endpoint. It currently responds with
// 501 Not Implemented.
func (h *OIDCHandler) JWKS(c *gin.Context) {
	// TODO: implement retrieval of the JWKS key set.
	c.JSON(http.StatusNotImplemented, gin.H{"message": "Not implemented yet"})
}