172 lines
4.6 KiB
Go
172 lines
4.6 KiB
Go
package handlers
|
|
|
|
import (
	"crypto/rsa"
	"encoding/base64"
	"fmt"
	"net/http"
	"net/url"
	"strings"

	"github.com/gin-contrib/sessions"
	"github.com/gin-gonic/gin"

	"oidc-oauth2-server/models"
	"oidc-oauth2-server/services"
)
|
|
|
|
// OIDCHandler serves the OpenID Connect / OAuth 2.0 HTTP endpoints
// (discovery, authorize, token, userinfo and JWKS) on top of the OAuth and
// auth services.
type OIDCHandler struct {
	// privateKey and publicKey hold RSA key material.
	// NOTE(review): NewOIDCHandler never assigns either field, so both are nil
	// unless set elsewhere in the package — confirm.
	privateKey *rsa.PrivateKey
	publicKey  *rsa.PublicKey

	// config is the discovery document returned by OpenIDConfiguration.
	config *OIDCConfig

	// oauthService validates authorize requests, generates authorization
	// codes and exchanges them for tokens.
	oauthService *services.OAuthService

	// authService is used by Userinfo to load the user record.
	authService *services.AuthService
}
|
|
|
|
// OIDCConfig is the OpenID Provider metadata document that
// OpenIDConfiguration serves as JSON. The struct's field order determines the
// key order of the emitted JSON object.
//
// NOTE(review): OpenID Connect Discovery 1.0 §3 marks
// id_token_signing_alg_values_supported as REQUIRED. Relying parties also
// assume client_secret_basic when token_endpoint_auth_methods_supported is
// absent. Neither field exists here — confirm whether they should be added.
type OIDCConfig struct {
	// Issuer identifier and endpoint URLs (filled in by NewOIDCHandler).
	Issuer                string `json:"issuer"`
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	UserinfoEndpoint      string `json:"userinfo_endpoint"`
	JwksURI               string `json:"jwks_uri"`

	// Capability lists advertised to relying parties.
	ResponseTypesSupported []string `json:"response_types_supported"`
	SubjectTypesSupported  []string `json:"subject_types_supported"`
	ScopesSupported        []string `json:"scopes_supported"`
	ClaimsSupported        []string `json:"claims_supported"`
}
|
|
|
|
func NewOIDCHandler(issuerURL string, oauthService *services.OAuthService, authService *services.AuthService) *OIDCHandler {
|
|
config := &OIDCConfig{
|
|
Issuer: issuerURL,
|
|
AuthorizationEndpoint: issuerURL + "/authorize",
|
|
TokenEndpoint: issuerURL + "/token",
|
|
UserinfoEndpoint: issuerURL + "/userinfo",
|
|
JwksURI: issuerURL + "/jwks",
|
|
ResponseTypesSupported: []string{"code"},
|
|
SubjectTypesSupported: []string{"public"},
|
|
ScopesSupported: []string{"openid", "profile", "email"},
|
|
ClaimsSupported: []string{"sub", "name", "email", "email_verified"},
|
|
}
|
|
|
|
return &OIDCHandler{
|
|
config: config,
|
|
oauthService: oauthService,
|
|
authService: authService,
|
|
}
|
|
}
|
|
|
|
// OpenIDConfiguration handles /.well-known/openid-configuration endpoint
|
|
func (h *OIDCHandler) OpenIDConfiguration(c *gin.Context) {
|
|
c.JSON(http.StatusOK, h.config)
|
|
}
|
|
|
|
// Authorize handles /authorize endpoint
|
|
func (h *OIDCHandler) Authorize(c *gin.Context) {
|
|
// 检查用户是否已登录
|
|
session := sessions.Default(c)
|
|
userID := session.Get("user_id")
|
|
if userID == nil {
|
|
// 用户未登录,重定向到登录页面
|
|
query := c.Request.URL.Query()
|
|
c.Redirect(http.StatusFound, "/login?"+query.Encode())
|
|
return
|
|
}
|
|
|
|
req := &services.AuthorizeRequest{
|
|
ResponseType: c.Query("response_type"),
|
|
ClientID: c.Query("client_id"),
|
|
RedirectURI: c.Query("redirect_uri"),
|
|
Scope: c.Query("scope"),
|
|
State: c.Query("state"),
|
|
}
|
|
|
|
if err := h.oauthService.ValidateAuthorizeRequest(req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
authCode, err := h.oauthService.GenerateAuthorizationCode(userID.(uint), req)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization code"})
|
|
return
|
|
}
|
|
|
|
// 构建重定向 URL
|
|
redirectURL, _ := url.Parse(req.RedirectURI)
|
|
q := redirectURL.Query()
|
|
q.Set("code", authCode.Code)
|
|
if req.State != "" {
|
|
q.Set("state", req.State)
|
|
}
|
|
redirectURL.RawQuery = q.Encode()
|
|
|
|
c.Redirect(http.StatusFound, redirectURL.String())
|
|
}
|
|
|
|
// Token handles /token endpoint
|
|
func (h *OIDCHandler) Token(c *gin.Context) {
|
|
req := &services.TokenRequest{
|
|
GrantType: c.PostForm("grant_type"),
|
|
Code: c.PostForm("code"),
|
|
RedirectURI: c.PostForm("redirect_uri"),
|
|
ClientID: c.PostForm("client_id"),
|
|
ClientSecret: c.PostForm("client_secret"),
|
|
}
|
|
|
|
if req.GrantType != "authorization_code" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_grant_type"})
|
|
return
|
|
}
|
|
|
|
tokenResponse, err := h.oauthService.ExchangeToken(req)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, tokenResponse)
|
|
}
|
|
|
|
// Userinfo handles /userinfo endpoint
|
|
func (h *OIDCHandler) Userinfo(c *gin.Context) {
|
|
// 从 token 中获取用户 ID
|
|
userID := c.GetUint("user_id")
|
|
if userID == 0 {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_token"})
|
|
return
|
|
}
|
|
|
|
// 获取用户信息
|
|
var user models.User
|
|
if err := h.authService.GetUserByID(userID, &user); err != nil {
|
|
c.JSON(http.StatusNotFound, gin.H{"error": "user_not_found"})
|
|
return
|
|
}
|
|
|
|
// 获取授权范围
|
|
scope := c.GetString("scope")
|
|
scopes := strings.Split(scope, " ")
|
|
|
|
// 准备返回的声明
|
|
claims := gin.H{
|
|
"sub": fmt.Sprintf("%d", user.ID),
|
|
}
|
|
|
|
// 根据授权范围添加相应的声明
|
|
for _, s := range scopes {
|
|
switch s {
|
|
case "profile":
|
|
claims["name"] = user.Username
|
|
case "email":
|
|
claims["email"] = user.Email
|
|
}
|
|
}
|
|
|
|
c.JSON(http.StatusOK, claims)
|
|
}
|
|
|
|
// JWKS handles /jwks endpoint
|
|
func (h *OIDCHandler) JWKS(c *gin.Context) {
|
|
// TODO: 实现 JWKS 密钥集获取逻辑
|
|
c.JSON(http.StatusNotImplemented, gin.H{"message": "Not implemented yet"})
|
|
}
|