package services

import (
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"time"

	"github.com/golang-jwt/jwt/v4"
	"gorm.io/gorm"

	"oidc-oauth2-server/models"
)

// OAuthService implements the OAuth 2.0 authorization-code flow and issues
// HS256-signed access tokens and OpenID Connect ID tokens. Clients, users and
// authorization codes are read from and written to the database through GORM.
type OAuthService struct {
	db        *gorm.DB
	jwtSecret []byte        // HMAC-SHA256 signing key for access and ID tokens
	tokenTTL  time.Duration // lifetime of issued access and ID tokens
}

// AuthorizeRequest holds the parameters of an authorization-endpoint request.
type AuthorizeRequest struct {
	ResponseType string
	ClientID     string
	RedirectURI  string
	Scope        string
	State        string
}

// TokenRequest holds the parameters of a token-endpoint request.
type TokenRequest struct {
	GrantType    string
	Code         string
	RedirectURI  string
	ClientID     string
	ClientSecret string
}

// TokenResponse is the JSON body returned by the token endpoint.
type TokenResponse struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"`
	IDToken      string `json:"id_token,omitempty"`
	RefreshToken string `json:"refresh_token,omitempty"`
}

// NewOAuthService returns an OAuthService backed by db that signs tokens with
// jwtSecret. Issued tokens are valid for one hour.
func NewOAuthService(db *gorm.DB, jwtSecret []byte) *OAuthService {
	return &OAuthService{
		db:        db,
		jwtSecret: jwtSecret,
		tokenTTL:  time.Hour,
	}
}

// ValidateAuthorizeRequest checks an authorization request: only the "code"
// response type is accepted, the client must exist, and the redirect URI must
// exactly match one of the client's registered redirect URIs.
func (s *OAuthService) ValidateAuthorizeRequest(req *AuthorizeRequest) error {
	if req.ResponseType != "code" {
		return errors.New("unsupported response type")
	}

	client := &models.Client{}
	if err := s.db.First(client, "client_id = ?", req.ClientID).Error; err != nil {
		return errors.New("invalid client")
	}

	// The redirect URI must match a registered URI exactly (no prefix or
	// wildcard matching), which prevents open-redirect style code theft.
	validRedirect := false
	for _, uri := range client.RedirectURIs {
		if uri == req.RedirectURI {
			validRedirect = true
			break
		}
	}
	if !validRedirect {
		return errors.New("invalid redirect URI")
	}

	return nil
}

// GenerateAuthorizationCode creates and persists a single-use authorization
// code for userID, bound to the request's client, redirect URI and scope.
// The code is 32 bytes from crypto/rand, base64url-encoded without padding,
// and expires 10 minutes after creation.
func (s *OAuthService) GenerateAuthorizationCode(userID uint, req *AuthorizeRequest) (*models.AuthorizationCode, error) {
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return nil, err
	}
	code := base64.RawURLEncoding.EncodeToString(b)

	authCode := &models.AuthorizationCode{
		Code:        code,
		ClientID:    req.ClientID,
		RedirectURI: req.RedirectURI,
		UserID:      userID,
		Scope:       req.Scope,
		ExpiresAt:   time.Now().Add(10 * time.Minute),
		Used:        false,
	}

	if err := s.db.Create(authCode).Error; err != nil {
		return nil, err
	}

	return authCode, nil
}

// ExchangeToken redeems an authorization code for an access token and an ID
// token.
//
// The request must use the "authorization_code" grant type and carry valid
// client credentials. The code must belong to that client, be unused and
// unexpired, and have been issued for the same redirect URI.
//
// Each code can be redeemed at most once, even under concurrent requests.
func (s *OAuthService) ExchangeToken(req *TokenRequest) (*TokenResponse, error) {
	if req.GrantType != "authorization_code" {
		return nil, errors.New("unsupported grant type")
	}

	// Authenticate the client before looking at the code, so an
	// unauthenticated caller cannot probe which codes exist or have expired.
	client := &models.Client{}
	if err := s.db.Where("client_id = ? AND client_secret = ?", req.ClientID, req.ClientSecret).First(client).Error; err != nil {
		return nil, errors.New("invalid client credentials")
	}

	authCode := &models.AuthorizationCode{}
	if err := s.db.Where("code = ? AND client_id = ? AND used = ?", req.Code, req.ClientID, false).First(authCode).Error; err != nil {
		return nil, errors.New("invalid authorization code")
	}

	if time.Now().After(authCode.ExpiresAt) {
		return nil, errors.New("authorization code expired")
	}

	if authCode.RedirectURI != req.RedirectURI {
		return nil, errors.New("redirect URI mismatch")
	}

	user := &models.User{}
	if err := s.db.First(user, authCode.UserID).Error; err != nil {
		return nil, errors.New("user not found")
	}

	accessToken, err := s.generateAccessToken(user, client, authCode.Scope)
	if err != nil {
		return nil, err
	}

	idToken, err := s.generateIDToken(user, client, authCode.Scope)
	if err != nil {
		return nil, err
	}

	// Claim the code atomically. The conditional UPDATE matches only while
	// used = false, so among concurrent redemptions exactly one affects a row.
	// The losers get an error and their freshly minted tokens are discarded.
	// Doing this after minting avoids burning the code if signing fails.
	res := s.db.Model(&models.AuthorizationCode{}).
		Where("code = ? AND used = ?", authCode.Code, false).
		Update("used", true)
	if res.Error != nil {
		return nil, fmt.Errorf("marking authorization code as used: %w", res.Error)
	}
	if res.RowsAffected != 1 {
		return nil, errors.New("invalid authorization code")
	}

	return &TokenResponse{
		AccessToken: accessToken,
		TokenType:   "Bearer",
		ExpiresIn:   int(s.tokenTTL.Seconds()),
		IDToken:     idToken,
	}, nil
}

// generateAccessToken returns an HS256-signed JWT access token for user.
// The token carries the granted scope and expires after s.tokenTTL.
//
// NOTE(review): "iss" is set to the client ID. Conventionally it identifies
// the authorization server instead. Left unchanged because resource servers
// may validate the current value — confirm before changing.
func (s *OAuthService) generateAccessToken(user *models.User, client *models.Client, scope string) (string, error) {
	now := time.Now()
	claims := jwt.MapClaims{
		"sub":   user.ID,
		"exp":   now.Add(s.tokenTTL).Unix(),
		"iat":   now.Unix(),
		"iss":   client.ClientID,
		"scope": scope,
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return token.SignedString(s.jwtSecret)
}

// generateIDToken returns an HS256-signed OpenID Connect ID token for user,
// addressed to client and valid for s.tokenTTL.
//
// OIDC Core requires "sub" to be a string and "aud" to contain the client ID,
// so both are set accordingly. scope is currently unused.
//
// NOTE(review): open questions — confirm before changing:
//   - "iss" should be the server's issuer URL, not the client ID.
//   - "email_verified" is hard-coded to true, regardless of whether the
//     address was actually verified.
//   - No "nonce" is echoed back; AuthorizeRequest does not carry one.
func (s *OAuthService) generateIDToken(user *models.User, client *models.Client, scope string) (string, error) {
	now := time.Now()
	claims := jwt.MapClaims{
		"sub":            fmt.Sprint(user.ID),
		"aud":            client.ClientID,
		"exp":            now.Add(s.tokenTTL).Unix(),
		"iat":            now.Unix(),
		"iss":            client.ClientID,
		"email":          user.Email,
		"email_verified": true,
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return token.SignedString(s.jwtSecret)
}