Skip to content

Commit 9d69335

Browse files
committed
Update TestOAuthRequestTokenEndpoint to handle CSRF
Signed-off-by: Monis Khan <mkhan@redhat.com>
1 parent 3b75f6b commit 9d69335

1 file changed

Lines changed: 103 additions & 10 deletions

File tree

test/integration/oauth_token_endpoint_test.go

Lines changed: 103 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ package integration
22

33
import (
44
"bytes"
5+
"io"
56
"io/ioutil"
67
"net/http"
78
"net/url"
9+
"strings"
810
"testing"
911

1012
"golang.org/x/net/html"
@@ -50,7 +52,7 @@ func TestOAuthRequestTokenEndpoint(t *testing.T) {
5052
}
5153
masterURL.Path = "/oauth/token/request"
5254

53-
_, tokenHeaderLocation := checkNewReqAndRoundTrip(t, transport, masterURL.String(), false, http.StatusFound)
55+
_, tokenHeaderLocation := checkNewReqAndRoundTrip(t, transport, http.MethodGet, masterURL.String(), nil, nil, false, http.StatusFound)
5456

5557
if len(tokenHeaderLocation) == 0 {
5658
t.Fatalf("no Location header")
@@ -61,14 +63,44 @@ func TestOAuthRequestTokenEndpoint(t *testing.T) {
6163
t.Fatalf("unexpected error %v", err)
6264
}
6365

64-
_, authHeaderLocation := checkNewReqAndRoundTrip(t, transport, authRedirect.String(), true, http.StatusFound)
66+
_, authHeaderLocation := checkNewReqAndRoundTrip(t, transport, http.MethodGet, authRedirect.String(), nil, nil, true, http.StatusFound)
6567

6668
if len(authHeaderLocation) == 0 {
6769
t.Fatalf("no Location header")
6870
}
6971

70-
displayResp, _ := checkNewReqAndRoundTrip(t, transport, authHeaderLocation, false, http.StatusOK)
71-
apiToken := getTokenFromDisplay(t, displayResp)
72+
displayRespGet, _ := checkNewReqAndRoundTrip(t, transport, http.MethodGet, authHeaderLocation, nil, nil, false, http.StatusOK)
73+
74+
code, csrf, loc := getCodeCSRFFromDisplay(t, displayRespGet)
75+
masterURL.Path = loc
76+
values := url.Values{}
77+
values.Set("code", code)
78+
values.Set("csrf", csrf)
79+
headers := map[string]string{"content-type": "application/x-www-form-urlencoded"}
80+
81+
// no csrf cookie fails
82+
b1, _ := checkNewReqAndRoundTrip(t, transport, http.MethodPost, masterURL.String(), headers, strings.NewReader(values.Encode()), false, http.StatusBadRequest)
83+
assertBodyContains(t, b1, "Could not check CSRF token. Please try again.")
84+
85+
// empty csrf cookie fails
86+
headers["cookie"] = "csrf="
87+
b2, _ := checkNewReqAndRoundTrip(t, transport, http.MethodPost, masterURL.String(), headers, strings.NewReader(values.Encode()), false, http.StatusBadRequest)
88+
assertBodyContains(t, b2, "Could not check CSRF token. Please try again.")
89+
90+
// wrong csrf cookie fails
91+
headers["cookie"] = "csrf=123"
92+
b3, _ := checkNewReqAndRoundTrip(t, transport, http.MethodPost, masterURL.String(), headers, strings.NewReader(values.Encode()), false, http.StatusBadRequest)
93+
assertBodyContains(t, b3, "Could not check CSRF token. Please try again.")
94+
95+
// technically any matching csrf gets past the check (do not send code as it is one use only)
96+
b4, _ := checkNewReqAndRoundTrip(t, transport, http.MethodPost, masterURL.String(), headers, strings.NewReader(headers["cookie"]), false, http.StatusBadRequest)
97+
assertBodyContains(t, b4, "Error handling auth request: Requested parameter not sent")
98+
99+
// correct csrf cookie works
100+
headers["cookie"] = "csrf=" + csrf
101+
displayRespPost, _ := checkNewReqAndRoundTrip(t, transport, http.MethodPost, masterURL.String(), headers, strings.NewReader(values.Encode()), false, http.StatusOK)
102+
103+
apiToken := getTokenFromDisplay(t, displayRespPost)
72104

73105
// Verify use of the bearer token
74106
userConfig := restclient.AnonymousClientConfig(clientConfig)
@@ -87,6 +119,61 @@ func TestOAuthRequestTokenEndpoint(t *testing.T) {
87119
}
88120
}
89121

122+
func assertBodyContains(t *testing.T, body []byte, sub string) {
123+
t.Helper()
124+
125+
if !bytes.Contains(body, []byte(sub)) {
126+
t.Fatalf("request body %s missing data %s", string(body), sub)
127+
}
128+
}
129+
130+
func getCodeCSRFFromDisplay(t *testing.T, body []byte) (code, csrf, loc string) {
131+
tokenizer := html.NewTokenizer(bytes.NewReader(body))
132+
133+
var codeSeen, csrfSeen, locSeen bool
134+
135+
for tokenType := tokenizer.Next(); tokenType != html.ErrorToken; tokenType = tokenizer.Next() {
136+
if token := tokenizer.Token(); tokenType == html.StartTagToken {
137+
if token.Data == "form" && getVal("method", token.Attr) == "post" {
138+
if locSeen {
139+
t.Fatalf("loc seen more than once in body %s, first loc=%s", string(body), loc)
140+
}
141+
loc = getVal("action", token.Attr)
142+
locSeen = true
143+
}
144+
if token.Data == "input" && getVal("name", token.Attr) == "code" {
145+
if codeSeen {
146+
t.Fatalf("code seen more than once in body %s, first code=%s", string(body), code)
147+
}
148+
code = getVal("value", token.Attr)
149+
codeSeen = true
150+
}
151+
if token.Data == "input" && getVal("name", token.Attr) == "csrf" {
152+
if csrfSeen {
153+
t.Fatalf("csrf seen more than once in body %s, first csrf=%s", string(body), csrf)
154+
}
155+
csrf = getVal("value", token.Attr)
156+
csrfSeen = true
157+
}
158+
}
159+
}
160+
161+
if len(code) == 0 || len(csrf) == 0 || len(loc) == 0 {
162+
t.Fatalf("could not find code or csrf or loc in body %s, saw: code=%s, csrf=%s, loc=%s", string(body), code, csrf, loc)
163+
}
164+
165+
return code, csrf, loc
166+
}
167+
168+
func getVal(key string, attrs []html.Attribute) string {
169+
for _, attr := range attrs {
170+
if attr.Key == key {
171+
return attr.Val
172+
}
173+
}
174+
return ""
175+
}
176+
90177
// Parse the HTML body for the API token contained in the first <code></code> block
91178
func getTokenFromDisplay(t *testing.T, body []byte) string {
92179
tokenizer := html.NewTokenizer(bytes.NewReader(body))
@@ -102,17 +189,22 @@ func getTokenFromDisplay(t *testing.T, body []byte) string {
102189
}
103190
}
104191

105-
t.Fatalf("API Token not found in display")
192+
t.Fatalf("API Token not found in display, body %s", string(body))
106193
return ""
107194
}
108195

109-
func checkNewReqAndRoundTrip(t *testing.T, rt http.RoundTripper, url string, doBasicAuth bool, expectedCode int) ([]byte, string) {
110-
req, err := http.NewRequest("GET", url, nil)
196+
func checkNewReqAndRoundTrip(t *testing.T, rt http.RoundTripper, method, url string, headers map[string]string, reqBody io.Reader, doBasicAuth bool, expectedCode int) ([]byte, string) {
197+
t.Helper()
198+
199+
req, err := http.NewRequest(method, url, reqBody)
111200
if err != nil {
112201
t.Fatalf("unexpected error %v", err)
113202
}
114203

115204
req.Header.Set("Accept", "text/html; charset=UTF-8")
205+
for k, v := range headers {
206+
req.Header.Set(k, v)
207+
}
116208

117209
if doBasicAuth {
118210
req.SetBasicAuth("foo", "bar")
@@ -123,14 +215,15 @@ func checkNewReqAndRoundTrip(t *testing.T, rt http.RoundTripper, url string, doB
123215
t.Fatalf("unexpected error %v", err)
124216
}
125217
defer resp.Body.Close()
126-
if resp.StatusCode != expectedCode {
127-
t.Fatalf("unexpected response code %v", resp.StatusCode)
128-
}
129218

130219
body, err := ioutil.ReadAll(resp.Body)
131220
if err != nil {
132221
t.Fatalf("unexpected error %v", err)
133222
}
134223

224+
if resp.StatusCode != expectedCode {
225+
t.Fatalf("unexpected response code %v, body=%s", resp.StatusCode, string(body))
226+
}
227+
135228
return body, resp.Header.Get("Location")
136229
}

0 commit comments

Comments
 (0)