Skip to content

Commit

Permalink
Add webhook based auth provider for demo
Browse files Browse the repository at this point in the history
  • Loading branch information
popduke committed Apr 24, 2024
1 parent 76e5fdf commit 0df16fa
Show file tree
Hide file tree
Showing 6 changed files with 446 additions and 0 deletions.
5 changes: 5 additions & 0 deletions build/build-bifromq-starters/conf/standalone.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
# if the node is responsible for cluster bootstrap. NOTE: there must be EXACTLY ONE bootstrap node in cluster deployment
bootstrap: true


# Enabling custom auth provider, specify the Fully Qualified Name of the auth provider class from your plugin.
# or experiment with the built-in auth provider which is a simple webhook based implementation.
# authProviderFQN: "com.baidu.demo.plugin.DemoAuthProvider"

# Enabling runtime throttling at tenant-level, specify the Fully Qualified Name of the resource throttler class from your plugin.
# or experiment with the built-in resource throttler which is a simple webhook based implementation.
# resourceThrottlerFQN: "com.baidu.demo.plugin.DemoResourceThrottler"
Expand Down
9 changes: 9 additions & 0 deletions build/build-plugin-demo/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
<plugin.class>com.baidu.demo.plugin.DemoPlugin</plugin.class>
</properties>
<dependencies>
<dependency>
<groupId>com.baidu.bifromq</groupId>
<artifactId>bifromq-plugin-auth-provider</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.baidu.bifromq</groupId>
<artifactId>bifromq-plugin-setting-provider</artifactId>
Expand All @@ -48,6 +53,10 @@
<groupId>com.github.ben-manes.caffeine</groupId>
<artifactId>caffeine</artifactId>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
</dependency>
<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-registry-prometheus</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/

package com.baidu.demo.plugin;

import com.baidu.bifromq.plugin.authprovider.IAuthProvider;
import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthData;
import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthResult;
import com.baidu.bifromq.plugin.authprovider.type.MQTTAction;
import com.baidu.bifromq.plugin.authprovider.type.Reject;
import com.baidu.bifromq.type.ClientInfo;
import java.net.URI;
import java.util.concurrent.CompletableFuture;
import lombok.extern.slf4j.Slf4j;
import org.pf4j.Extension;

@Slf4j
@Extension
public class DemoAuthProvider implements IAuthProvider {
private static final String PLUGIN_AUTHPROVIDER_URL = "plugin.authprovider.url";
private final IAuthProvider delegate;

public DemoAuthProvider() {
IAuthProvider delegate1;
String webhookUrl = System.getProperty(PLUGIN_AUTHPROVIDER_URL);
if (webhookUrl == null) {
log.info("No webhook url specified, fallback to no auth.");
delegate1 = new FallbackAuthProvider();
} else {
try {
URI webhookURI = URI.create(webhookUrl);
delegate1 = new WebHookBasedAuthProvider(webhookURI);
log.info("Resource will be throttled at runtime by consulting: {}", webhookUrl);
} catch (Throwable e) {
delegate1 = new FallbackAuthProvider();
}
}
delegate = delegate1;
}

@Override
public CompletableFuture<MQTT3AuthResult> auth(MQTT3AuthData authData) {
return delegate.auth(authData);
}

@Override
public CompletableFuture<Boolean> check(ClientInfo client, MQTTAction action) {
return delegate.check(client, action);
}

static class FallbackAuthProvider implements IAuthProvider {
@Override
public CompletableFuture<MQTT3AuthResult> auth(MQTT3AuthData authData) {
return CompletableFuture.completedFuture(
MQTT3AuthResult.newBuilder().setReject(Reject.newBuilder()
.setCode(Reject.Code.Error)
.setReason("No webhook url for auth provider configured")
.build())
.build());
}

@Override
public CompletableFuture<Boolean> check(ClientInfo client, MQTTAction action) {
return CompletableFuture.completedFuture(false);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/

package com.baidu.demo.plugin;

import static com.baidu.bifromq.type.MQTTClientInfoConstants.MQTT_USER_ID_KEY;

import com.baidu.bifromq.plugin.authprovider.IAuthProvider;
import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthData;
import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthResult;
import com.baidu.bifromq.plugin.authprovider.type.MQTTAction;
import com.baidu.bifromq.plugin.authprovider.type.Reject;
import com.baidu.bifromq.type.ClientInfo;
import com.google.protobuf.util.JsonFormat;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;

class WebHookBasedAuthProvider implements IAuthProvider {
private final URI webhookURI;
private final HttpClient httpClient;

WebHookBasedAuthProvider(URI webhookURI) {
this.webhookURI = webhookURI;
this.httpClient = HttpClient.newBuilder()
.version(HttpClient.Version.HTTP_1_1)
.followRedirects(HttpClient.Redirect.NORMAL)
.build();
}

@Override
public CompletableFuture<MQTT3AuthResult> auth(MQTT3AuthData authData) {
try {
HttpRequest request = HttpRequest.newBuilder()
.uri(URI.create(webhookURI + "/auth"))
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(JsonFormat.printer().print(authData)))
.timeout(Duration.ofSeconds(5))
.build();
return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())
.thenApply(response -> {
if (response.statusCode() == 200) {
try {
MQTT3AuthResult.Builder resultBuilder = MQTT3AuthResult.newBuilder();
JsonFormat.parser()
.ignoringUnknownFields()
.merge(response.body(), resultBuilder);
return resultBuilder.build();
} catch (Throwable e) {
return MQTT3AuthResult.newBuilder()
.setReject(Reject.newBuilder()
.setCode(Reject.Code.Error)
.setReason(e.getMessage())
.build())
.build();
}
} else {
return MQTT3AuthResult.newBuilder()
.setReject(Reject.newBuilder()
.setCode(Reject.Code.Error)
.setReason("Authenticate failed")
.build())
.build();
}
})
.exceptionally(e -> {
System.out.println("Failed to call webhook: " + e.getMessage());
return null;
});
} catch (Throwable e) {
return CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder()
.setReject(Reject.newBuilder()
.setCode(Reject.Code.Error)
.setReason(e.getMessage())
.build())
.build());
}
}

@Override
public CompletableFuture<Boolean> check(ClientInfo client, MQTTAction action) {
try {
HttpRequest request = HttpRequest.newBuilder()
.uri(URI.create(webhookURI + "/check"))
.header("Content-Type", "application/json")
.header("tenant_id", client.getTenantId())
.header("user_id", client.getMetadataMap().get(MQTT_USER_ID_KEY))
.POST(HttpRequest.BodyPublishers.ofString(JsonFormat.printer().print(action)))
.timeout(Duration.ofSeconds(5))
.build();
return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())
.thenApply(response -> {
if (response.statusCode() == 200) {
try {
return Boolean.parseBoolean(response.body());
} catch (Throwable e) {
return false;
}
} else {
return false;
}
})
.exceptionally(e -> {
System.out.println("Failed to call webhook: " + e.getMessage());
return null;
});
} catch (Throwable e) {
return CompletableFuture.completedFuture(false);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/

package com.baidu.demo.plugin;

import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthData;
import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthResult;
import com.baidu.bifromq.plugin.authprovider.type.MQTTAction;
import com.baidu.bifromq.plugin.authprovider.type.Ok;
import com.baidu.bifromq.plugin.authprovider.type.Reject;
import com.google.protobuf.util.JsonFormat;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpHandler;
import com.sun.net.httpserver.HttpServer;
import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.HashSet;
import java.util.Set;
import lombok.SneakyThrows;

public class TestAuthServer {
private static final Set<String> authedUsers = new HashSet<>();
private static final Set<String> permittedPubTopics = new HashSet<>();
private static final Set<String> permittedSubTopicFilters = new HashSet<>();
private HttpServer server;

@SneakyThrows
public TestAuthServer() {
server = HttpServer.create(new InetSocketAddress("127.0.0.1", 0), 0);
server.createContext("/auth", new AuthHandler());
server.createContext("/check", new CheckHandler());
server.setExecutor(null);
}

public void start() {
server.start();
}

public void stop() {
server.stop(0);
}

public URI getURI() {
return URI.create(
"http://" + server.getAddress().getHostName() + ":" + server.getAddress().getPort());
}

public void addAuthedUser(String username) {
authedUsers.add(username);
}

public void addPermittedPubTopic(String topic) {
permittedPubTopics.add(topic);
}

public void addPermittedSubTopicFilter(String topicFilter) {
permittedSubTopicFilters.add(topicFilter);
}

static class AuthHandler implements HttpHandler {
@Override
public void handle(HttpExchange exchange) throws IOException {
if ("POST".equals(exchange.getRequestMethod())) {
// read request body as JSON string
String requestBody = new String(exchange.getRequestBody().readAllBytes());
MQTT3AuthData.Builder authDataBuilder = MQTT3AuthData.newBuilder();
JsonFormat.parser().merge(requestBody, authDataBuilder);
MQTT3AuthData authData = authDataBuilder.build();
if (authedUsers.contains(authData.getUsername())) {
sendJson(exchange, JsonFormat.printer().print(MQTT3AuthResult.newBuilder()
.setOk(Ok.newBuilder()
.setTenantId("TestTenant")
.setUserId(authData.getUsername())
.build())
.build()));
} else {
sendJson(exchange, JsonFormat.printer().print(MQTT3AuthResult.newBuilder()
.setReject(Reject.newBuilder()
.setCode(Reject.Code.NotAuthorized)
.setTenantId("TestTenant")
.setUserId(authData.getUsername())
.build())
.build()));
}
} else {
sendResponse(exchange, "text/plain", "Method Not Allowed", 405);
}
}
}

static class CheckHandler implements HttpHandler {

@Override
public void handle(HttpExchange exchange) throws IOException {
if ("POST".equals(exchange.getRequestMethod())) {
// read request body as JSON string
String requestBody = new String(exchange.getRequestBody().readAllBytes());
MQTTAction.Builder mqttActionBuilder = MQTTAction.newBuilder();
JsonFormat.parser().merge(requestBody, mqttActionBuilder);
MQTTAction mqttAction = mqttActionBuilder.build();
switch (mqttAction.getTypeCase()) {
case PUB -> sendText(exchange,
Boolean.toString(permittedPubTopics.contains(mqttAction.getPub().getTopic())));
case SUB, UNSUB -> sendText(exchange,
Boolean.toString(permittedSubTopicFilters.contains(mqttAction.getPub().getTopic())));
}
} else {
sendResponse(exchange, "text/plain", "Method Not Allowed", 405);
}
}
}

private static void sendJson(HttpExchange exchange, String response) throws IOException {
sendResponse(exchange, "application/json", response, 200);
}

private static void sendText(HttpExchange exchange, String response) throws IOException {
sendResponse(exchange, "text/plain", response, 200);
}

private static void sendResponse(HttpExchange exchange, String contentType, String response, int statusCode)
throws IOException {
exchange.getResponseHeaders().add("Content-Type", contentType);
exchange.sendResponseHeaders(statusCode, response.getBytes().length);
OutputStream os = exchange.getResponseBody();
os.write(response.getBytes());
os.close();
}
}
Loading

0 comments on commit 0df16fa

Please sign in to comment.