package services

import (
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/golang-jwt/jwt/v4"
	"gorm.io/gorm"

	"oidc-oauth2-server/models"
)

// OAuthService implements the authorization-code flow (with optional PKCE)
// backed by a GORM database and a KeyManager that provides the signing key.
type OAuthService struct {
	db         *gorm.DB
	keyManager *KeyManager
	tokenTTL   time.Duration
}

// AuthorizeRequest carries the parameters of an authorization request.
type AuthorizeRequest struct {
	ResponseType        string
	ClientID            string
	RedirectURI         string
	Scope               string
	State               string
	CodeChallenge       string
	CodeChallengeMethod string
	Nonce               string
}

// TokenRequest carries the parameters of a token request.
type TokenRequest struct {
	GrantType    string
	Code         string
	RedirectURI  string
	ClientID     string
	ClientSecret string
	CodeVerifier string
}

// TokenResponse is the JSON body returned by the token endpoint.
type TokenResponse struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"`
	IDToken      string `json:"id_token,omitempty"`
	RefreshToken string `json:"refresh_token,omitempty"`
}

// NewOAuthService creates an OAuthService using db for persistence and a
// freshly constructed KeyManager for token signing. Issued tokens live for
// one hour.
func NewOAuthService(db *gorm.DB) (*OAuthService, error) {
	keyManager, err := NewKeyManager()
	if err != nil {
		return nil, err
	}
	return &OAuthService{
		db:         db,
		keyManager: keyManager,
		tokenTTL:   time.Hour,
	}, nil
}

// ValidateAuthorizeRequest checks that req uses the "code" response type,
// names a registered client, uses one of that client's registered redirect
// URIs (exact string match), and — when a PKCE code challenge is present —
// uses method "S256" or "plain" with a 43..128 character challenge.
func (s *OAuthService) ValidateAuthorizeRequest(req *AuthorizeRequest) error {
	if req.ResponseType != "code" {
		return errors.New("unsupported response type")
	}

	client := &models.Client{}
	if err := s.db.First(client, "client_id = ?", req.ClientID).Error; err != nil {
		return errors.New("invalid client")
	}

	// Validate the redirect URI against the client's registered list,
	// which is stored as a JSON array.
	var redirectURIs []string
	if err := json.Unmarshal(client.RedirectURIs, &redirectURIs); err != nil {
		return errors.New("invalid redirect URIs format")
	}
	validRedirect := false
	for _, uri := range redirectURIs {
		if uri == req.RedirectURI {
			validRedirect = true
			break
		}
	}
	if !validRedirect {
		return errors.New("invalid redirect URI")
	}

	// Validate PKCE parameters when the client supplied a challenge.
	if req.CodeChallenge != "" {
		if req.CodeChallengeMethod != "S256" && req.CodeChallengeMethod != "plain" {
			return errors.New("invalid code challenge method")
		}
		if len(req.CodeChallenge) < 43 || len(req.CodeChallenge) > 128 {
			return errors.New("invalid code challenge length")
		}
	}

	return nil
}

// GenerateAuthorizationCode creates and persists a new authorization code
// for userID bound to req's client, redirect URI, scope, PKCE challenge and
// nonce. The code is 32 random bytes, base64url-encoded without padding, and
// expires 10 minutes after creation.
func (s *OAuthService) GenerateAuthorizationCode(userID uint, req *AuthorizeRequest) (*models.AuthorizationCode, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return nil, err
	}
	code := base64.RawURLEncoding.EncodeToString(b)

	authCode := &models.AuthorizationCode{
		Code:                code,
		ClientID:            req.ClientID,
		RedirectURI:         req.RedirectURI,
		UserID:              userID,
		Scope:               req.Scope,
		ExpiresAt:           time.Now().Add(10 * time.Minute),
		Used:                false,
		CodeChallenge:       req.CodeChallenge,
		CodeChallengeMethod: req.CodeChallengeMethod,
		Nonce:               req.Nonce,
	}

	if err := s.db.Create(authCode).Error; err != nil {
		return nil, err
	}

	return authCode, nil
}

// ExchangeToken redeems an authorization code for an access token and an ID
// token. It verifies that the code exists, belongs to req.ClientID, is
// unused and unexpired, and that the redirect URI and (if the code carries a
// challenge) the PKCE verifier match. It then authenticates the client,
// loads the user, and consumes the code.
//
// The code is consumed with a conditional UPDATE (used = false -> true)
// after the tokens are minted but before they are returned. Only one
// concurrent redemption can succeed; a losing request gets
// "invalid authorization code" and its tokens are discarded.
func (s *OAuthService) ExchangeToken(req *TokenRequest) (*TokenResponse, error) {
	authCode := &models.AuthorizationCode{}
	if err := s.db.Where("code = ? AND client_id = ? AND used = ?", req.Code, req.ClientID, false).First(authCode).Error; err != nil {
		return nil, errors.New("invalid authorization code")
	}

	if time.Now().After(authCode.ExpiresAt) {
		return nil, errors.New("authorization code expired")
	}

	if authCode.RedirectURI != req.RedirectURI {
		return nil, errors.New("redirect URI mismatch")
	}

	if authCode.CodeChallenge != "" {
		if req.CodeVerifier == "" {
			return nil, errors.New("code verifier required")
		}
		if err := validatePKCE(authCode.CodeChallenge, authCode.CodeChallengeMethod, req.CodeVerifier); err != nil {
			return nil, err
		}
	}

	// NOTE(review): the secret is compared by the database, which implies it
	// is stored in plaintext; consider storing a hash and comparing in Go.
	client := &models.Client{}
	if err := s.db.Where("client_id = ? AND client_secret = ?", req.ClientID, req.ClientSecret).First(client).Error; err != nil {
		return nil, errors.New("invalid client credentials")
	}

	user := &models.User{}
	if err := s.db.First(user, authCode.UserID).Error; err != nil {
		return nil, errors.New("user not found")
	}

	accessToken, err := s.generateAccessToken(user, client, authCode.Scope)
	if err != nil {
		return nil, err
	}

	// Echo the nonce captured at authorization time so the client can
	// detect replayed ID tokens.
	idToken, err := s.generateIDToken(user, client, authCode.Scope, authCode.Nonce)
	if err != nil {
		return nil, err
	}

	// Atomically consume the code. If another request already consumed it
	// (or it was consumed between the lookup above and now), zero rows are
	// affected and the freshly minted tokens are discarded.
	res := s.db.Model(&models.AuthorizationCode{}).
		Where("code = ? AND used = ?", authCode.Code, false).
		Update("used", true)
	if res.Error != nil {
		return nil, fmt.Errorf("marking authorization code as used: %w", res.Error)
	}
	if res.RowsAffected == 0 {
		return nil, errors.New("invalid authorization code")
	}

	return &TokenResponse{
		AccessToken: accessToken,
		TokenType:   "Bearer",
		ExpiresIn:   int(s.tokenTTL.Seconds()),
		IDToken:     idToken,
	}, nil
}

// generateAccessToken returns a signed JWT access token for user carrying
// the granted scope and expiring after s.tokenTTL.
//
// It signs with RS256 and the KeyManager's kid, the same key and algorithm
// used for ID tokens. That key is an RSA private key, which golang-jwt's
// HMAC (HS256) method would reject as an invalid key type.
func (s *OAuthService) generateAccessToken(user *models.User, client *models.Client, scope string) (string, error) {
	now := time.Now()
	claims := jwt.MapClaims{
		// JWT "sub" is a StringOrURI; format it the same way as the ID token.
		"sub": fmt.Sprintf("%d", user.ID),
		"exp": now.Add(s.tokenTTL).Unix(),
		"iat": now.Unix(),
		// NOTE(review): "iss" should identify this server, not the client;
		// kept as-is because the issuer URL is not available here.
		"iss":   client.ClientID,
		"scope": scope,
	}

	token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
	token.Header["kid"] = s.keyManager.GetKID()
	return token.SignedString(s.keyManager.GetPrivateKey())
}

// generateIDToken returns an RS256-signed OIDC ID token for user, audienced
// to client and expiring after s.tokenTTL. The "nonce" claim is included
// only when nonce is non-empty.
func (s *OAuthService) generateIDToken(user *models.User, client *models.Client, scope string, nonce string) (string, error) {
	now := time.Now()
	claims := jwt.MapClaims{
		// NOTE(review): "iss" should be this server's issuer identifier, not
		// the client ID; kept as-is because the issuer URL is not available.
		"iss": client.ClientID,
		"sub": fmt.Sprintf("%d", user.ID),
		"aud": client.ClientID,
		"exp": now.Add(s.tokenTTL).Unix(),
		"iat": now.Unix(),
		// NOTE(review): auth_time is the token issue time here, not the time
		// the user actually authenticated; email_verified is hard-coded.
		"auth_time":      now.Unix(),
		"acr":            "1",
		"email":          user.Email,
		"email_verified": true,
		"name":           user.Username,
	}
	if nonce != "" {
		claims["nonce"] = nonce
	}

	token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
	token.Header["kid"] = s.keyManager.GetKID()
	return token.SignedString(s.keyManager.GetPrivateKey())
}

// validatePKCE checks verifier against the stored challenge using method
// ("S256", or "plain"/empty, which RFC 7636 treats as "plain"). The verifier
// must be 43..128 characters long. Any other method is rejected rather than
// silently treated as "plain". The comparison is constant-time.
func validatePKCE(challenge, method, verifier string) error {
	if len(verifier) < 43 || len(verifier) > 128 {
		return errors.New("invalid code verifier length")
	}

	var computedChallenge string
	switch method {
	case "S256":
		sum := sha256.Sum256([]byte(verifier))
		computedChallenge = base64.RawURLEncoding.EncodeToString(sum[:])
	case "plain", "":
		computedChallenge = verifier
	default:
		return errors.New("unsupported code challenge method")
	}

	if subtle.ConstantTimeCompare([]byte(computedChallenge), []byte(challenge)) != 1 {
		return errors.New("code verifier does not match challenge")
	}
	return nil
}

// GetJWKS returns the public JSON Web Key Set from the KeyManager.
func (s *OAuthService) GetJWKS() (*JSONWebKeySet, error) {
	return s.keyManager.GetJWKS()
}

// GetKeyManager returns the KeyManager used to sign tokens.
func (s *OAuthService) GetKeyManager() *KeyManager {
	return s.keyManager
}