Skip to content

Commit

Permalink
Merge pull request #34 from imandra-ai/matt/consume-body
Browse files Browse the repository at this point in the history
fix(cohttp): ensure response body is always consumed
  • Loading branch information
mattjbray authored Aug 30, 2024
2 parents ff26d20 + a0baf3e commit 5abc06a
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 54 deletions.
45 changes: 29 additions & 16 deletions src/auth.ml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ module Compute_engine = struct

let ping () =
let uri = Uri.of_string metadata_ip_root in
let open Lwt.Infix in
Cohttp_lwt_unix.Client.get uri ~headers:metadata_headers
>>= Util.drain_body

let get_project_id () :
( string,
Expand All @@ -100,9 +102,10 @@ module Compute_engine = struct
Uri.of_string (Printf.sprintf "%s/project/project-id" metadata_root)
in
Cohttp_lwt_unix.Client.get uri ~headers:metadata_headers
>>= Util.consume_body
>>= fun (resp, body) ->
match Cohttp.Response.status resp with
| `OK -> Cohttp_lwt.Body.to_string body |> ok
| `OK -> Lwt_result.return body
| status -> `Bad_GCE_metadata_response status |> Lwt_result.fail
end
end
Expand Down Expand Up @@ -230,22 +233,20 @@ module External_account_credentials = struct
|> to_string

let subject_token_of_response (t : t)
((resp, body) : Cohttp.Response.t * Cohttp_lwt.Body.t) :
((resp, body_str) : Cohttp.Response.t * string) :
(string, [> `Bad_token_response of string ]) result Lwt.t =
let open Lwt.Syntax in
match Cohttp.Response.status resp with
| `OK -> (
match t.credential_source.format.type_ with
| `Json -> (
let* body_str = Cohttp_lwt.Body.to_string body in
try
body_str |> Yojson.Basic.from_string |> subject_token_of_json t
|> Lwt.return_ok
with Yojson.Basic.Util.Type_error (msg, _) ->
let* () = L.debug (fun m -> m "Type_error: %s" msg) in
Lwt.return_error (`Bad_subject_token_response (resp, body_str))))
| _ ->
let* body_str = Cohttp_lwt.Body.to_string body in
let* () = L.err (fun m -> m "response: %s" body_str) in
Lwt.return_error (`Bad_subject_token_response (resp, body_str))
end
Expand Down Expand Up @@ -367,15 +368,12 @@ let credentials_of_file (credentials_file : string) :
lines |> String.concat "\n" |> credentials_of_string |> Lwt.return)

let access_token_of_response ?(of_json = access_token_of_json)
((resp, body) : Cohttp.Response.t * Cohttp_lwt.Body.t) :
((resp, body_str) : Cohttp.Response.t * string) :
(Access_token.t, [> `Bad_token_response of string ]) result Lwt.t =
let open Lwt.Syntax in
match Cohttp.Response.status resp with
| `OK ->
let* body_str = Cohttp_lwt.Body.to_string body in
body_str |> Yojson.Basic.from_string |> of_json |> Lwt.return
| `OK -> body_str |> Yojson.Basic.from_string |> of_json |> Lwt.return
| _ ->
let* body_str = Cohttp_lwt.Body.to_string body in
let* () = L.err (fun m -> m "response: %s" body_str) in
Lwt.return_error (`Bad_token_response body_str)

Expand All @@ -397,7 +395,11 @@ let access_token_of_credentials (scopes : string list)
("grant_type", [ "refresh_token" ]);
]
in
let* res = Cohttp_lwt_unix.Client.post_form token_uri ~params |> ok in
let* res =
let open Lwt.Infix in
Cohttp_lwt_unix.Client.post_form token_uri ~params
>>= Util.consume_body |> ok
in
access_token_of_response ~of_json:access_token_of_json res
| Service_account c -> (
let now = Unix.time () in
Expand Down Expand Up @@ -432,8 +434,9 @@ let access_token_of_credentials (scopes : string list)
]
in
let* res =
let open Lwt.Infix in
Cohttp_lwt_unix.Client.post_form (Uri.of_string c.token_uri) ~params
|> ok
>>= Util.consume_body |> ok
in
access_token_of_response ~of_json:access_token_of_json res
| _ -> Lwt_result.fail (`Bad_credentials_priv_key "Not RSA key"))
Expand All @@ -444,9 +447,10 @@ let access_token_of_credentials (scopes : string list)
|> Uri.of_string
in
let* res =
let open Lwt.Infix in
Cohttp_lwt_unix.Client.get uri
~headers:Compute_engine.Metadata.metadata_headers
|> ok
>>= Util.consume_body |> ok
in
access_token_of_response ~of_json:access_token_of_json res
| External_account (c : External_account_credentials.t) -> (
Expand All @@ -463,10 +467,11 @@ let access_token_of_credentials (scopes : string list)
let* subject_token =
let subject_token_uri = Uri.of_string c.credential_source.url in
let* resp =
let open Lwt.Infix in
Cohttp_lwt_unix.Client.get
~headers:(Cohttp.Header.of_list c.credential_source.headers)
subject_token_uri
|> ok
>>= Util.consume_body |> ok
in
External_account_credentials.subject_token_of_response c resp
in
Expand All @@ -486,7 +491,11 @@ let access_token_of_credentials (scopes : string list)
]
in
let body = Cohttp_lwt.Body.of_string (Yojson.Basic.to_string params) in
let* res = Cohttp_lwt_unix.Client.post token_uri ~body |> ok in
let* res =
let open Lwt.Infix in
Cohttp_lwt_unix.Client.post token_uri ~body
>>= Util.consume_body |> ok
in
Lwt_result.return res
in
match c.service_account_impersonation_url with
Expand Down Expand Up @@ -514,7 +523,11 @@ let access_token_of_credentials (scopes : string list)
in
let uri = Uri.of_string sac in
let* () = L.debug (fun m -> m "POST %a" Uri.pp_hum uri) |> ok in
let* res = Cohttp_lwt_unix.Client.post uri ~headers ~body |> ok in
let* res =
let open Lwt.Infix in
Cohttp_lwt_unix.Client.post uri ~headers ~body
>>= Util.consume_body |> ok
in
let access_token_of_json (json : Yojson.Basic.t) :
(Access_token.t, [> error ]) result =
(* has a slightly different format from the access token in the other responses:
Expand Down Expand Up @@ -579,7 +592,7 @@ let discover_credentials_with (discovery_mode : discovery_mode) =
Lwt.catch
(fun () ->
let open Lwt_result.Syntax in
let* resp, _body = Compute_engine.Metadata.ping () |> ok in
let* resp = Compute_engine.Metadata.ping () |> ok in
let* () = L.debug (fun m -> m "Got metadata response") |> ok in
let has_metadata_header =
Compute_engine.Metadata.response_has_metadata_header resp
Expand Down
30 changes: 22 additions & 8 deletions src/big_query.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ let src = Logs.Src.create "gcloud.bigquery"

module L = (val Logs_lwt.src_log src)

let ok = Lwt_result.ok

module Scopes = struct
let bigquery = "https://www.googleapis.com/auth/bigquery"
end
Expand Down Expand Up @@ -85,11 +87,13 @@ module Datasets = struct
]
in
L.debug (fun m -> m "GET %a" Uri.pp_hum uri) |> Lwt_result.ok
>>= fun () -> Cohttp_lwt_unix.Client.get uri ~headers |> Lwt_result.ok)
>>= fun () ->
let open Lwt.Infix in
Cohttp_lwt_unix.Client.get uri ~headers >>= Util.consume_body |> ok)
(fun e -> `Network_error e |> Lwt_result.fail)
>>= fun (resp, body) ->
match Cohttp.Response.status resp with
| `OK -> Cohttp_lwt.Body.to_string body |> Lwt_result.ok
| `OK -> Lwt_result.return body
| x -> Error.of_response_status_code_and_body x body

let list ?project_id () : (string, [> Error.t ]) Lwt_result.t =
Expand All @@ -110,11 +114,13 @@ module Datasets = struct
]
in
L.debug (fun m -> m "GET %a" Uri.pp_hum uri) |> Lwt_result.ok
>>= fun () -> Cohttp_lwt_unix.Client.get uri ~headers |> Lwt_result.ok)
>>= fun () ->
let open Lwt.Infix in
Cohttp_lwt_unix.Client.get uri ~headers >>= Util.consume_body |> ok)
(fun e -> `Network_error e |> Lwt_result.fail)
>>= fun (resp, body) ->
match Cohttp.Response.status resp with
| `OK -> Cohttp_lwt.Body.to_string body |> Lwt_result.ok
| `OK -> Lwt_result.return body
| x -> Error.of_response_status_code_and_body x body

module Tables = struct
Expand Down Expand Up @@ -173,11 +179,13 @@ module Datasets = struct
]
in
L.debug (fun m -> m "GET %a" Uri.pp_hum uri) |> Lwt_result.ok
>>= fun () -> Cohttp_lwt_unix.Client.get uri ~headers |> Lwt_result.ok)
>>= fun () ->
let open Lwt.Infix in
Cohttp_lwt_unix.Client.get uri ~headers >>= Util.consume_body |> ok)
(fun e -> `Network_error e |> Lwt_result.fail)
>>= fun (resp, body) ->
match Cohttp.Response.status resp with
| `OK -> Error.parse_body_json resp_of_yojson body
| `OK -> Error.parse_body_json resp_of_yojson body |> Lwt.return
| x -> Error.of_response_status_code_and_body x body
end
end
Expand Down Expand Up @@ -730,12 +738,15 @@ module Jobs = struct
m "Query: %s" q_trimmed)
|> Lwt_result.ok
>>= fun () ->
Cohttp_lwt_unix.Client.post uri ~headers ~body |> Lwt_result.ok)
let open Lwt.Infix in
Cohttp_lwt_unix.Client.post uri ~headers ~body
>>= Util.consume_body |> ok)
(fun e -> `Network_error e |> Lwt_result.fail)
>>= fun (resp, body) ->
match Cohttp.Response.status resp with
| `OK ->
Error.parse_body_json ~gzipped:use_gzip query_response_of_yojson body
|> Lwt.return
>>= fun response ->
L.debug (fun m -> m "%a" pp_query_response response) |> Lwt_result.ok
>>= fun () -> Lwt_result.return response
Expand Down Expand Up @@ -786,12 +797,15 @@ module Jobs = struct
|> add_gzip_headers ~use_gzip
in
Lwt.catch
(fun () -> Cohttp_lwt_unix.Client.get uri ~headers |> Lwt_result.ok)
(fun () ->
let open Lwt.Infix in
Cohttp_lwt_unix.Client.get uri ~headers >>= Util.consume_body |> ok)
(fun e -> Lwt_result.fail (`Network_error e))
>>= fun (resp, body) ->
match Cohttp.Response.status resp with
| `OK ->
Error.parse_body_json ~gzipped:use_gzip query_response_of_yojson body
|> Lwt.return
>>= fun response ->
L.debug (fun m -> m "%a" pp_query_response response) |> Lwt_result.ok
>>= fun () -> Lwt_result.return response
Expand Down
14 changes: 9 additions & 5 deletions src/compute.ml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
let ok = Lwt_result.ok

module Scopes = struct
let cloud_platform = "https://www.googleapis.com/auth/cloud-platform"
let compute = "https://www.googleapis.com/auth/cloud-platform"
Expand Down Expand Up @@ -48,13 +50,14 @@ module FirewallRules = struct
]
in
let body_str = rule |> rule_to_yojson |> Yojson.Safe.to_string in
print_endline body_str;
let body = body_str |> Cohttp_lwt.Body.of_string in
Cohttp_lwt_unix.Client.post uri ~body ~headers |> Lwt_result.ok)
let open Lwt.Infix in
Cohttp_lwt_unix.Client.post uri ~body ~headers
>>= Util.consume_body |> ok)
(fun e -> Lwt_result.fail (`Network_error e))
>>= fun (resp, body) ->
match Cohttp.Response.status resp with
| `OK -> Lwt_result.ok (Cohttp_lwt.Body.to_string body)
| `OK -> Lwt_result.return body
| status_code -> Error.of_response_status_code_and_body status_code body

let delete ?project_id ~(name : string) () :
Expand All @@ -78,10 +81,11 @@ module FirewallRules = struct
Printf.sprintf "Bearer %s" token_info.Auth.token.access_token );
]
in
Cohttp_lwt_unix.Client.delete uri ~headers |> Lwt_result.ok)
let open Lwt.Infix in
Cohttp_lwt_unix.Client.delete uri ~headers >>= Util.consume_body |> ok)
(fun e -> Lwt_result.fail (`Network_error e))
>>= fun (resp, body) ->
match Cohttp.Response.status resp with
| `OK -> Lwt_result.ok (Cohttp_lwt.Body.to_string body)
| `OK -> Lwt_result.return body
| status_code -> Error.of_response_status_code_and_body status_code body
end
7 changes: 5 additions & 2 deletions src/container.ml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
let ok = Lwt_result.ok

module Scopes = struct
let cloud_platform = "https://www.googleapis.com/auth/cloud-platform"
end
Expand Down Expand Up @@ -54,11 +56,12 @@ module Projects = struct
token_info.Auth.token.access_token );
]
in
Cohttp_lwt_unix.Client.get uri ~headers |> Lwt_result.ok)
let open Lwt.Infix in
Cohttp_lwt_unix.Client.get uri ~headers >>= Util.consume_body |> ok)
(fun e -> Lwt_result.fail (`Network_error e))
>>= fun (resp, body) ->
match Cohttp.Response.status resp with
| `OK -> Error.parse_body_json of_yojson body
| `OK -> Error.parse_body_json of_yojson body |> Lwt.return
| status_code -> Error.of_response_status_code_and_body status_code body
end
end
Expand Down
13 changes: 4 additions & 9 deletions src/error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,8 @@ let pp fmt (error : t) =
| `Msg s -> Format.fprintf fmt "Msg: %s" s

let parse_body_json ?(gzipped = false)
(transform : Yojson.Safe.t -> ('a, string) result)
(body : Cohttp_lwt.Body.t) : ('a, [> t ]) Lwt_result.t =
let open Lwt.Infix in
Cohttp_lwt.Body.to_string body >>= fun body_str ->
(transform : Yojson.Safe.t -> ('a, string) result) (body_str : string) :
('a, [> t ]) result =
let body =
if gzipped then
Ezgzip.decompress body_str
Expand All @@ -72,18 +70,15 @@ let parse_body_json ?(gzipped = false)
| Yojson.Json_error msg -> Error (`Json_parse_error (msg, body_str))
| e -> Error (`Json_parse_error (Printexc.to_string e, body_str))
in

parse_result
|> CCResult.flat_map (fun json ->
transform json
|> CCResult.map_err (fun e -> `Json_transform_error (e, json)))
|> Lwt.return

let of_response_status_code_and_body ?gzipped
(status_code : Cohttp.Code.status_code) (body : Cohttp_lwt.Body.t) :
(status_code : Cohttp.Code.status_code) (body_str : string) :
('a, [> t ]) Lwt_result.t =
let open Lwt.Infix in
parse_body_json ?gzipped api_json_error_of_yojson body >>= function
match parse_body_json ?gzipped api_json_error_of_yojson body_str with
| Ok parsed_error ->
Lwt_result.fail (`Gcloud_api_error (status_code, Json parsed_error))
| Error (`Json_parse_error (_, body_str)) ->
Expand Down
7 changes: 6 additions & 1 deletion src/kms.ml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
let ok = Lwt_result.ok

module Scopes = struct
let cloudkms = "https://www.googleapis.com/auth/cloudkms"
end
Expand Down Expand Up @@ -36,7 +38,9 @@ module V1 = struct
token_info.Auth.token.access_token );
]
in
Cohttp_lwt_unix.Client.post uri ~headers ~body |> Lwt_result.ok)
let open Lwt.Infix in
Cohttp_lwt_unix.Client.post uri ~headers ~body
>>= Util.consume_body |> ok)
(fun e -> `Network_error e |> Lwt_result.fail)
>>= fun (resp, body) ->
match Cohttp.Response.status resp with
Expand All @@ -52,6 +56,7 @@ module V1 = struct
Error "Could not base64-decode the plaintext")
| _ -> Error "Expected an object with field 'plaintext'")
body
|> Lwt.return
| x -> Error.of_response_status_code_and_body x body
end
end
Expand Down
12 changes: 9 additions & 3 deletions src/pub_sub.ml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
let ok = Lwt_result.ok

module Scopes = struct
let pubsub = "https://www.googleapis.com/auth/pubsub"
end
Expand Down Expand Up @@ -36,7 +38,9 @@ module Subscriptions = struct
in
Logs_lwt.debug (fun m -> m "POST %a" Uri.pp_hum uri) |> Lwt_result.ok
>>= fun () ->
Cohttp_lwt_unix.Client.post uri ~headers ~body |> Lwt_result.ok)
let open Lwt.Infix in
Cohttp_lwt_unix.Client.post uri ~headers ~body
>>= Util.consume_body |> ok)
(fun e -> `Network_error e |> Lwt_result.fail)
>>= fun (resp, body) ->
match Cohttp.Response.status resp with
Expand Down Expand Up @@ -94,12 +98,14 @@ module Subscriptions = struct
Logs_lwt.debug ~src:log_src_pull (fun m -> m "POST %a" Uri.pp_hum uri)
|> Lwt_result.ok
>>= fun () ->
Cohttp_lwt_unix.Client.post uri ~headers ~body |> Lwt_result.ok)
let open Lwt.Infix in
Cohttp_lwt_unix.Client.post uri ~headers ~body
>>= Util.consume_body |> ok)
(fun e -> `Network_error e |> Lwt_result.fail)
>>= fun (resp, body) ->
match Cohttp.Response.status resp with
| `OK ->
Error.parse_body_json received_messages_of_yojson body
Error.parse_body_json received_messages_of_yojson body |> Lwt.return
>|= fun { received_messages } ->
let received_messages =
received_messages
Expand Down
Loading

0 comments on commit 5abc06a

Please sign in to comment.