@@ -2,9 +2,11 @@ package integration
22
33import (
44 "bytes"
5+ "io"
56 "io/ioutil"
67 "net/http"
78 "net/url"
9+ "strings"
810 "testing"
911
1012 "golang.org/x/net/html"
@@ -50,7 +52,7 @@ func TestOAuthRequestTokenEndpoint(t *testing.T) {
5052 }
5153 masterURL .Path = "/oauth/token/request"
5254
53- _ , tokenHeaderLocation := checkNewReqAndRoundTrip (t , transport , masterURL .String (), false , http .StatusFound )
55+ _ , tokenHeaderLocation := checkNewReqAndRoundTrip (t , transport , http . MethodGet , masterURL .String (), nil , nil , false , http .StatusFound )
5456
5557 if len (tokenHeaderLocation ) == 0 {
5658 t .Fatalf ("no Location header" )
@@ -61,14 +63,44 @@ func TestOAuthRequestTokenEndpoint(t *testing.T) {
6163 t .Fatalf ("unexpected error %v" , err )
6264 }
6365
64- _ , authHeaderLocation := checkNewReqAndRoundTrip (t , transport , authRedirect .String (), true , http .StatusFound )
66+ _ , authHeaderLocation := checkNewReqAndRoundTrip (t , transport , http . MethodGet , authRedirect .String (), nil , nil , true , http .StatusFound )
6567
6668 if len (authHeaderLocation ) == 0 {
6769 t .Fatalf ("no Location header" )
6870 }
6971
70- displayResp , _ := checkNewReqAndRoundTrip (t , transport , authHeaderLocation , false , http .StatusOK )
71- apiToken := getTokenFromDisplay (t , displayResp )
72+ displayRespGet , _ := checkNewReqAndRoundTrip (t , transport , http .MethodGet , authHeaderLocation , nil , nil , false , http .StatusOK )
73+
74+ code , csrf , loc := getCodeCSRFFromDisplay (t , displayRespGet )
75+ masterURL .Path = loc
76+ values := url.Values {}
77+ values .Set ("code" , code )
78+ values .Set ("csrf" , csrf )
79+ headers := map [string ]string {"content-type" : "application/x-www-form-urlencoded" }
80+
81+ // no csrf cookie fails
82+ b1 , _ := checkNewReqAndRoundTrip (t , transport , http .MethodPost , masterURL .String (), headers , strings .NewReader (values .Encode ()), false , http .StatusBadRequest )
83+ assertBodyContains (t , b1 , "Could not check CSRF token. Please try again." )
84+
85+ // empty csrf cookie fails
86+ headers ["cookie" ] = "csrf="
87+ b2 , _ := checkNewReqAndRoundTrip (t , transport , http .MethodPost , masterURL .String (), headers , strings .NewReader (values .Encode ()), false , http .StatusBadRequest )
88+ assertBodyContains (t , b2 , "Could not check CSRF token. Please try again." )
89+
90+ // wrong csrf cookie fails
91+ headers ["cookie" ] = "csrf=123"
92+ b3 , _ := checkNewReqAndRoundTrip (t , transport , http .MethodPost , masterURL .String (), headers , strings .NewReader (values .Encode ()), false , http .StatusBadRequest )
93+ assertBodyContains (t , b3 , "Could not check CSRF token. Please try again." )
94+
95+ // technically any matching csrf gets past the check (do not send code as it is one use only)
96+ b4 , _ := checkNewReqAndRoundTrip (t , transport , http .MethodPost , masterURL .String (), headers , strings .NewReader (headers ["cookie" ]), false , http .StatusBadRequest )
97+ assertBodyContains (t , b4 , "Error handling auth request: Requested parameter not sent" )
98+
99+ // correct csrf cookie works
100+ headers ["cookie" ] = "csrf=" + csrf
101+ displayRespPost , _ := checkNewReqAndRoundTrip (t , transport , http .MethodPost , masterURL .String (), headers , strings .NewReader (values .Encode ()), false , http .StatusOK )
102+
103+ apiToken := getTokenFromDisplay (t , displayRespPost )
72104
73105 // Verify use of the bearer token
74106 userConfig := restclient .AnonymousClientConfig (clientConfig )
@@ -87,6 +119,61 @@ func TestOAuthRequestTokenEndpoint(t *testing.T) {
87119 }
88120}
89121
122+ func assertBodyContains (t * testing.T , body []byte , sub string ) {
123+ t .Helper ()
124+
125+ if ! bytes .Contains (body , []byte (sub )) {
126+ t .Fatalf ("request body %s missing data %s" , string (body ), sub )
127+ }
128+ }
129+
130+ func getCodeCSRFFromDisplay (t * testing.T , body []byte ) (code , csrf , loc string ) {
131+ tokenizer := html .NewTokenizer (bytes .NewReader (body ))
132+
133+ var codeSeen , csrfSeen , locSeen bool
134+
135+ for tokenType := tokenizer .Next (); tokenType != html .ErrorToken ; tokenType = tokenizer .Next () {
136+ if token := tokenizer .Token (); tokenType == html .StartTagToken {
137+ if token .Data == "form" && getVal ("method" , token .Attr ) == "post" {
138+ if locSeen {
139+ t .Fatalf ("loc seen more than once in body %s, first loc=%s" , string (body ), loc )
140+ }
141+ loc = getVal ("action" , token .Attr )
142+ locSeen = true
143+ }
144+ if token .Data == "input" && getVal ("name" , token .Attr ) == "code" {
145+ if codeSeen {
146+ t .Fatalf ("code seen more than once in body %s, first code=%s" , string (body ), code )
147+ }
148+ code = getVal ("value" , token .Attr )
149+ codeSeen = true
150+ }
151+ if token .Data == "input" && getVal ("name" , token .Attr ) == "csrf" {
152+ if csrfSeen {
153+ t .Fatalf ("csrf seen more than once in body %s, first csrf=%s" , string (body ), csrf )
154+ }
155+ csrf = getVal ("value" , token .Attr )
156+ csrfSeen = true
157+ }
158+ }
159+ }
160+
161+ if len (code ) == 0 || len (csrf ) == 0 || len (loc ) == 0 {
162+ t .Fatalf ("could not find code or csrf or loc in body %s, saw: code=%s, csrf=%s, loc=%s" , string (body ), code , csrf , loc )
163+ }
164+
165+ return code , csrf , loc
166+ }
167+
168+ func getVal (key string , attrs []html.Attribute ) string {
169+ for _ , attr := range attrs {
170+ if attr .Key == key {
171+ return attr .Val
172+ }
173+ }
174+ return ""
175+ }
176+
90177// Parse the HTML body for the API token contained in the first <code></code> block
91178func getTokenFromDisplay (t * testing.T , body []byte ) string {
92179 tokenizer := html .NewTokenizer (bytes .NewReader (body ))
@@ -102,17 +189,22 @@ func getTokenFromDisplay(t *testing.T, body []byte) string {
102189 }
103190 }
104191
105- t .Fatalf ("API Token not found in display" )
192+ t .Fatalf ("API Token not found in display, body %s" , string ( body ) )
106193 return ""
107194}
108195
109- func checkNewReqAndRoundTrip (t * testing.T , rt http.RoundTripper , url string , doBasicAuth bool , expectedCode int ) ([]byte , string ) {
110- req , err := http .NewRequest ("GET" , url , nil )
196+ func checkNewReqAndRoundTrip (t * testing.T , rt http.RoundTripper , method , url string , headers map [string ]string , reqBody io.Reader , doBasicAuth bool , expectedCode int ) ([]byte , string ) {
197+ t .Helper ()
198+
199+ req , err := http .NewRequest (method , url , reqBody )
111200 if err != nil {
112201 t .Fatalf ("unexpected error %v" , err )
113202 }
114203
115204 req .Header .Set ("Accept" , "text/html; charset=UTF-8" )
205+ for k , v := range headers {
206+ req .Header .Set (k , v )
207+ }
116208
117209 if doBasicAuth {
118210 req .SetBasicAuth ("foo" , "bar" )
@@ -123,14 +215,15 @@ func checkNewReqAndRoundTrip(t *testing.T, rt http.RoundTripper, url string, doB
123215 t .Fatalf ("unexpected error %v" , err )
124216 }
125217 defer resp .Body .Close ()
126- if resp .StatusCode != expectedCode {
127- t .Fatalf ("unexpected response code %v" , resp .StatusCode )
128- }
129218
130219 body , err := ioutil .ReadAll (resp .Body )
131220 if err != nil {
132221 t .Fatalf ("unexpected error %v" , err )
133222 }
134223
224+ if resp .StatusCode != expectedCode {
225+ t .Fatalf ("unexpected response code %v, body=%s" , resp .StatusCode , string (body ))
226+ }
227+
135228 return body , resp .Header .Get ("Location" )
136229}
0 commit comments