Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
gateways: allow authorization via http header (#93)
Browse files Browse the repository at this point in the history
Added flag to the gateway inside `authentication` named
`http-credentials-source` defaults to `query`. Possible values are
`query` and `header`.

If using an http endpoint, the api gateway will check this gateway flag
to gather the credentials. This apply both to `produce` and `service`
endpoints.

CLI has been updated to decide the auth mode based on the gateway config
(only for produce with --protocol http)
  • Loading branch information
nicoloboschi authored Jul 3, 2024
1 parent 99de449 commit af4b4c6
Show file tree
Hide file tree
Showing 12 changed files with 551 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
@Slf4j
public class GatewayRequestHandler {

public static final String AUTH_HTTP_CREDENTIALS_HEADER = "Authorization";
public static final String AUTH_HTTP_TEST_CREDENTIALS_HEADER =
"X-LangStream-Test-Authorization";

public static class AuthFailedException extends Exception {
public AuthFailedException(String message) {
super(message);
Expand Down Expand Up @@ -74,23 +78,92 @@ public GatewayRequestHandler(
}
}

public GatewayRequestContext validateRequest(
public GatewayRequestContext validateHttpRequest(
String tenant,
String applicationId,
String gatewayId,
Gateway.GatewayType expectedGatewayType,
Map<String, String> queryString,
Map<String, String> httpHeaders,
GatewayRequestValidator validator) {
return validateRequest(
tenant,
applicationId,
gatewayId,
expectedGatewayType,
queryString,
httpHeaders,
validator,
true);
}

public GatewayRequestContext validateWebSocketRequest(
String tenant,
String applicationId,
String gatewayId,
Gateway.GatewayType expectedGatewayType,
Map<String, String> queryString,
Map<String, String> httpHeaders,
GatewayRequestValidator validator) {
return validateRequest(
tenant,
applicationId,
gatewayId,
expectedGatewayType,
queryString,
httpHeaders,
validator,
false);
}

private GatewayRequestContext validateRequest(
String tenant,
String applicationId,
String gatewayId,
Gateway.GatewayType expectedGatewayType,
Map<String, String> queryString,
Map<String, String> httpHeaders,
GatewayRequestValidator validator,
boolean isHttp) {

final Application application = getResolvedApplication(tenant, applicationId);
final Gateway gateway = extractGateway(gatewayId, application, expectedGatewayType);

final Map<String, String> options = new HashMap<>();
final Map<String, String> userParameters = new HashMap<>();

final String credentials = queryString.remove("credentials");
final String testCredentials = queryString.remove("test-credentials");
final String credentials;
final String testCredentials;
if (isHttp
&& gateway.getAuthentication() != null
&& gateway.getAuthentication().getHttpAuthenticationSource()
== Gateway.Authentication.HttpCredentialsSource.header) {
if (queryString.containsKey("credentials")) {
throw new IllegalArgumentException(
"credentials must be passed in the HTTP '%s' header for this gateway"
.formatted(AUTH_HTTP_CREDENTIALS_HEADER));
}
if (queryString.containsKey("test-credentials")) {
throw new IllegalArgumentException(
"test-credentials must be passed in the HTTP '%s' header for this gateway"
.formatted(AUTH_HTTP_TEST_CREDENTIALS_HEADER));
}
credentials = httpHeaders.get("Authorization");
testCredentials = httpHeaders.get("X-LangStream-Test-Authorization");
} else {
if (httpHeaders.containsKey(AUTH_HTTP_CREDENTIALS_HEADER)) {
throw new IllegalArgumentException(
AUTH_HTTP_CREDENTIALS_HEADER + " header is not allowed for this gateway");
}
if (httpHeaders.containsKey(AUTH_HTTP_TEST_CREDENTIALS_HEADER)) {
throw new IllegalArgumentException(
AUTH_HTTP_TEST_CREDENTIALS_HEADER
+ " header is not allowed for this gateway");
}
credentials = queryString.remove("credentials");
testCredentials = queryString.remove("test-credentials");
}

final boolean checkOptions;

if (expectedGatewayType == Gateway.GatewayType.service
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,7 @@
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -102,7 +98,7 @@ ProduceResponse produce(
final Map<String, String> queryString = computeQueryString(request);
final Map<String, String> headers = computeHeaders(request);
final GatewayRequestContext context =
gatewayRequestHandler.validateRequest(
gatewayRequestHandler.validateHttpRequest(
tenant,
application,
gateway,
Expand Down Expand Up @@ -149,10 +145,28 @@ private ProduceRequest parseProduceRequest(WebRequest request, String payload)
}

private Map<String, String> computeHeaders(WebRequest request) {
final Map<String, String> headers = new HashMap<>();
final Map<String, String> headers =
new HashMap<>() {
@Override
public String get(Object key) {
return super.get(key != null ? key.toString().toLowerCase() : key);
}

@Override
public String getOrDefault(Object key, String defaultValue) {
return super.getOrDefault(
key != null ? key.toString().toLowerCase() : key, defaultValue);
}

@Override
public boolean containsKey(Object key) {
return super.containsKey(key != null ? key.toString().toLowerCase() : key);
}
};

request.getHeaderNames()
.forEachRemaining(name -> headers.put(name, request.getHeader(name)));
return headers;
return Collections.unmodifiableMap(headers);
}

@PostMapping(value = GATEWAY_SERVICE_PATH)
Expand Down Expand Up @@ -209,7 +223,7 @@ private CompletableFuture<ResponseEntity> handleServiceCall(
final Map<String, String> queryString = computeQueryString(request);
final Map<String, String> headers = computeHeaders(request);
final GatewayRequestContext context =
gatewayRequestHandler.validateRequest(
gatewayRequestHandler.validateHttpRequest(
tenant,
application,
gateway,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public boolean beforeHandshake(
final Map<String, String> vars =
antPathMatcher.extractUriTemplateVariables(handler.path(), path);
final GatewayRequestContext gatewayRequestContext =
gatewayRequestHandler.validateRequest(
gatewayRequestHandler.validateWebSocketRequest(
handler.tenantFromPath(vars, querystring),
handler.applicationIdFromPath(vars, querystring),
handler.gatewayFromPath(vars, querystring),
Expand Down
Loading

0 comments on commit af4b4c6

Please sign in to comment.