Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions example/client/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"log/slog"
"net/http"
"os"
"strconv"
"strings"
"sync/atomic"
"time"
Expand Down Expand Up @@ -34,6 +35,14 @@ func main() {
scopes := strings.Split(os.Getenv("SCOPES"), " ")
responseMode := os.Getenv("RESPONSE_MODE")

var pkce bool
if pkceEnv, ok := os.LookupEnv("PKCE"); ok {
var err error
pkce, err = strconv.ParseBool(pkceEnv)
if err != nil {
logrus.Fatalf("error parsing PKCE %s", err.Error())
}
}
redirectURI := fmt.Sprintf("http://localhost:%v%v", port, callbackPath)
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())

Expand Down Expand Up @@ -64,6 +73,9 @@ func main() {
if keyPath != "" {
options = append(options, rp.WithJWTProfile(rp.SignerFromKeyPath(keyPath)))
}
if pkce {
options = append(options, rp.WithPKCE(cookieHandler))
}

// One can add a logger to the context,
// pre-defining log attributes as required.
Expand Down
4 changes: 2 additions & 2 deletions example/server/exampleop/templates/login.html
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
<button type="submit">Login</button>
</form>
</body>
</html>`
{{- end }}
</html>
{{- end }}
15 changes: 10 additions & 5 deletions example/server/storage/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ const (
// CustomClaim is an example for how to return custom claims with this library
CustomClaim = "custom_claim"

// CustomScopeImpersonatePrefix is an example scope prefix for passing user id to impersonate using token exchage
// CustomScopeImpersonatePrefix is an example scope prefix for passing user id to impersonate using token exchange
CustomScopeImpersonatePrefix = "custom_scope:impersonate:"
)

Expand Down Expand Up @@ -143,6 +143,14 @@ func MaxAgeToInternal(maxAge *uint) *time.Duration {
}

func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthRequest {
var codeChallenge *OIDCCodeChallenge
if authReq.CodeChallenge != "" {
codeChallenge = &OIDCCodeChallenge{
Challenge: authReq.CodeChallenge,
Method: string(authReq.CodeChallengeMethod),
}
}

return &AuthRequest{
CreationDate: time.Now(),
ApplicationID: authReq.ClientID,
Expand All @@ -157,10 +165,7 @@ func authRequestToInternal(authReq *oidc.AuthRequest, userID string) *AuthReques
ResponseType: authReq.ResponseType,
ResponseMode: authReq.ResponseMode,
Nonce: authReq.Nonce,
CodeChallenge: &OIDCCodeChallenge{
Challenge: authReq.CodeChallenge,
Method: string(authReq.CodeChallengeMethod),
},
CodeChallenge: codeChallenge,
}
}

Expand Down
1 change: 1 addition & 0 deletions pkg/op/op_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ func TestRoutes(t *testing.T) {
authReq, err := storage.CreateAuthRequest(ctx, oidcAuthReq, "id1")
require.NoError(t, err)
storage.AuthRequestDone(authReq.GetID())
storage.SaveAuthCode(ctx, authReq.GetID(), "123")

accessToken, refreshToken, _, err := op.CreateAccessToken(ctx, authReq, op.AccessTokenTypeBearer, testProvider, client, "")
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/op/server_http_routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func TestServerRoutes(t *testing.T) {
"client_id": client.GetID(),
"client_secret": "secret",
"redirect_uri": "https://example.com",
"code": "123",
"code": "abc",
},
wantCode: http.StatusBadRequest,
json: `{"error":"invalid_grant", "error_description":"invalid code"}`,
Expand Down
26 changes: 19 additions & 7 deletions pkg/op/token_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
ctx, span := tracer.Start(ctx, "AuthorizeCodeClient")
defer span.End()

request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
if err != nil {
return nil, nil, err
}

codeChallenge := request.GetCodeChallenge()
if codeChallenge != nil {
err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, codeChallenge)

if err != nil {
return nil, nil, err
}
}

if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
jwtExchanger, ok := exchanger.(JWTAuthorizationGrantExchanger)
if !ok || !exchanger.AuthMethodPrivateKeyJWTSupported() {
Expand All @@ -83,9 +97,9 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
if err != nil {
return nil, nil, err
}
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
return request, client, err
}

client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID)
if err != nil {
return nil, nil, oidc.ErrInvalidClient().WithParent(err)
Expand All @@ -94,12 +108,10 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
return nil, nil, oidc.ErrInvalidClient().WithDescription("private_key_jwt not allowed for this client")
}
if client.AuthMethod() == oidc.AuthMethodNone {
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)
if err != nil {
return nil, nil, err
if codeChallenge == nil {
return nil, nil, oidc.ErrInvalidRequest().WithDescription("PKCE required")
}
err = AuthorizeCodeChallenge(tokenReq.CodeVerifier, request.GetCodeChallenge())
return request, client, err
return request, client, nil
}
if client.AuthMethod() == oidc.AuthMethodPost && !exchanger.AuthMethodPostSupported() {
return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
Expand All @@ -108,7 +120,7 @@ func AuthorizeCodeClient(ctx context.Context, tokenReq *oidc.AccessTokenRequest,
if err != nil {
return nil, nil, err
}
request, err = AuthRequestByCode(ctx, exchanger.Storage(), tokenReq.Code)

return request, client, err
}

Expand Down
Loading