@@ -8,8 +8,10 @@ import (
88 "fmt"
99 "io"
1010 "log"
11+ "net"
1112 "net/http"
1213 "net/url"
14+ "time"
1315
1416 "github.com/aws/aws-lambda-go/events"
1517 "github.com/aws/aws-sdk-go/aws"
@@ -23,6 +25,7 @@ import (
2325type lambdaBackend struct {
2426 functionName string
2527 lambda * lambda.Lambda
28+ isPayloadV2 bool // https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html
2629}
2730
2831func New (opts erconfig.BackendOptsAwsLambda , logger * log.Logger ) (http.Handler , error ) {
@@ -36,31 +39,115 @@ func New(opts erconfig.BackendOptsAwsLambda, logger *log.Logger) (http.Handler,
3639 return nil , err
3740 }
3841
42+ isPayloadV2 , err := validatePayloadFormatVersion (opts .PayloadFormatVersion )
43+ if err != nil {
44+ return nil , err
45+ }
46+
3947 handler := & lambdaBackend {
4048 functionName : opts .FunctionName ,
4149 lambda : lambda .New (
4250 awsSession ,
4351 aws .NewConfig ().WithCredentials (creds ).WithRegion (opts .RegionId )),
52+ isPayloadV2 : isPayloadV2 ,
4453 }
4554
4655 return turbocharger .WrapWithMiddlewareIfConfigAvailable (handler , logger )
4756}
4857
4958func (b * lambdaBackend ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
59+ if b .isPayloadV2 {
60+ // https://pkg.go.dev/github.com/aws/aws-lambda-go/events#APIGatewayV2HTTPRequest
61+ b .serveHTTPModel (w , r )
62+ } else {
63+ // https://pkg.go.dev/github.com/aws/aws-lambda-go/events#APIGatewayProxyRequest
64+ // https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#api-gateway-simple-proxy-for-lambda-input-format
65+ b .serveRESTModel (w , r )
66+ }
67+ }
68+
69+ func (b * lambdaBackend ) serveHTTPModel (w http.ResponseWriter , r * http.Request ) {
5070 defer r .Body .Close ()
5171
52- requestBodyBase64 := & bytes.Buffer {}
53- requestBodyBase64Encoder := base64 .NewEncoder (base64 .StdEncoding , requestBodyBase64 )
54- if _ , err := io .Copy (requestBodyBase64Encoder , r .Body ); err != nil {
72+ bodyBase64 , err := encodeToBase64RawStd (r .Body )
73+ if err != nil {
74+ http .Error (w , err .Error (), http .StatusBadRequest )
75+ return
76+ }
77+
78+ headers , err := copyHeaders (r )
79+ if err != nil {
5580 http .Error (w , err .Error (), http .StatusBadRequest )
5681 return
5782 }
5883
59- if err := requestBodyBase64Encoder .Close (); err != nil {
84+ sourceIP , _ , _ := net .SplitHostPort (r .RemoteAddr )
85+
86+ now := time .Now ().UTC ()
87+
88+ const routeKey = "$default"
89+
90+ proxyRequestJson , err := json .Marshal (events.APIGatewayV2HTTPRequest {
91+ Version : "2.0" ,
92+ RouteKey : routeKey ,
93+ RawPath : r .URL .Path ,
94+ RawQueryString : r .URL .RawQuery ,
95+ // Cookies: []string{},
96+ Headers : headers ,
97+ QueryStringParameters : queryParametersToSimpleMap (r .URL .Query ()),
98+ RequestContext : events.APIGatewayV2HTTPRequestContext {
99+ RouteKey : routeKey ,
100+ Stage : "$default" ,
101+ // RequestID: "",
102+ DomainName : r .URL .Host ,
103+ Time : now .Format ("02/Jan/2006:15:04:05 -0700" ), // what a dumbass format
104+ TimeEpoch : now .UnixMilli (),
105+ HTTP : events.APIGatewayV2HTTPRequestContextHTTPDescription {
106+ Method : r .Method ,
107+ Path : r .URL .Path ,
108+ Protocol : r .Proto ,
109+ SourceIP : sourceIP ,
110+ UserAgent : r .UserAgent (),
111+ },
112+ },
113+ Body : bodyBase64 ,
114+ IsBase64Encoded : true ,
115+ })
116+ if err != nil {
117+ http .Error (w , err .Error (), http .StatusInternalServerError )
118+ return
119+ }
120+
121+ lambdaResponse , err := b .lambda .InvokeWithContext (r .Context (), & lambda.InvokeInput {
122+ FunctionName : aws .String (b .functionName ),
123+ Payload : proxyRequestJson ,
124+ })
125+ if err != nil {
60126 http .Error (w , err .Error (), http .StatusInternalServerError )
61127 return
62128 }
63129
130+ payloadResponse := & events.APIGatewayV2HTTPResponse {}
131+ if err := json .Unmarshal (lambdaResponse .Payload , payloadResponse ); err != nil {
132+ http .Error (w , err .Error (), http .StatusBadGateway )
133+ return
134+ }
135+
136+ if err := proxyApiGatewayResponse (payloadResponse , w ); err != nil {
137+ // TODO: if we already wrote headers, this will not succeed
138+ http .Error (w , err .Error (), http .StatusBadGateway )
139+ }
140+ }
141+
142+ func (b * lambdaBackend ) serveRESTModel (w http.ResponseWriter , r * http.Request ) {
143+ defer r .Body .Close ()
144+
145+ requestBodyBase64 , err := encodeToBase64RawStd (r .Body )
146+ if err != nil {
147+ http .Error (w , err .Error (), http .StatusBadRequest )
148+ return
149+ }
150+
64151 headers , err := copyHeaders (r )
65152 if err != nil {
66153 http .Error (w , err .Error (), http .StatusBadRequest )
@@ -78,8 +165,8 @@ func (b *lambdaBackend) ServeHTTP(w http.ResponseWriter, r *http.Request) {
78165 },
79166 }
80167
81- if requestBodyBase64 . Len ( ) > 0 {
82- proxyRequest .Body = requestBodyBase64 . String ()
168+ if len ( requestBodyBase64 ) > 0 {
169+ proxyRequest .Body = requestBodyBase64
83170 proxyRequest .IsBase64Encoded = true
84171 }
85172
@@ -89,7 +176,7 @@ func (b *lambdaBackend) ServeHTTP(w http.ResponseWriter, r *http.Request) {
89176 return
90177 }
91178
92- lambdaResponse , err := b .lambda .Invoke ( & lambda.InvokeInput {
179+ lambdaResponse , err := b .lambda .InvokeWithContext ( r . Context (), & lambda.InvokeInput {
93180 FunctionName : aws .String (b .functionName ),
94181 Payload : proxyRequestJson ,
95182 })
@@ -110,13 +197,20 @@ func (b *lambdaBackend) ServeHTTP(w http.ResponseWriter, r *http.Request) {
110197 return
111198 }
112199
113- if err := proxyApiGatewayResponse (payloadResponse , w ); err != nil {
200+ if err := proxyApiGatewayResponse (& events.APIGatewayV2HTTPResponse {
201+ StatusCode : payloadResponse .StatusCode ,
202+ Headers : payloadResponse .Headers ,
203+ MultiValueHeaders : payloadResponse .MultiValueHeaders ,
204+ Body : payloadResponse .Body ,
205+ IsBase64Encoded : payloadResponse .IsBase64Encoded ,
206+ // only field missing from old struct: `Cookies`
207+ }, w ); err != nil {
114208 // TODO: if we already wrote headers, this will not succeed
115209 http .Error (w , err .Error (), http .StatusBadGateway )
116210 }
117211}
118212
119- func proxyApiGatewayResponse (payloadResponse * events.APIGatewayProxyResponse , w http.ResponseWriter ) error {
213+ func proxyApiGatewayResponse (payloadResponse * events.APIGatewayV2HTTPResponse , w http.ResponseWriter ) error {
120214 responseHeaders := w .Header ()
121215
122216 for key , val := range payloadResponse .Headers {
@@ -165,3 +259,26 @@ func copyHeaders(r *http.Request) (map[string]string, error) {
165259
166260 return headers , nil
167261}
262+
263+ func encodeToBase64RawStd (content io.Reader ) (string , error ) {
264+ bodyBuffered := & bytes.Buffer {}
265+ encodeBase64 := base64 .NewEncoder (base64 .RawStdEncoding , bodyBuffered )
266+ if _ , err := io .Copy (encodeBase64 , content ); err != nil {
267+ return "" , err
268+ }
269+ if err := encodeBase64 .Close (); err != nil {
270+ return "" , err
271+ }
272+ return bodyBuffered .String (), nil
273+ }
274+
275+ func validatePayloadFormatVersion (version string ) (bool , error ) {
276+ switch version {
277+ case "1.0" , "" :
278+ return false , nil
279+ case "2.0" :
280+ return true , nil
281+ default :
282+ return false , fmt .Errorf ("unrecognized PayloadFormatVersion: %s" , version )
283+ }
284+ }
0 commit comments