Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: permission denied for new roles and remove security definer functions #158

Merged
merged 2 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# PG_NET
*A PostgreSQL extension that enables asynchronous (non-blocking) HTTP/HTTPS requests with SQL*.

Requires libcurl >= 7.83.
Requires libcurl >= 7.83. Compatible with PostgreSQL > = 12.

![PostgreSQL version](https://img.shields.io/badge/postgresql-12+-blue.svg)
[![License](https://img.shields.io/pypi/l/markdown-subtemplate.svg)](https://github.com/supabase/pg_net/blob/master/LICENSE)
Expand Down
18 changes: 18 additions & 0 deletions sql/pg_net--0.11.0--0.11.1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
alter function net.http_get(text, jsonb, jsonb, integer) security invoker;

alter function net.http_post(text, jsonb, jsonb, jsonb, integer) security invoker;

alter function net.http_delete ( text, jsonb, jsonb, integer) security invoker;

alter function net._http_collect_response ( bigint, boolean) security invoker;

alter function net.http_collect_response ( bigint, boolean) security invoker;

create or replace function net.worker_restart()
returns bool
language 'c'
as 'pg_net';

grant usage on schema net to PUBLIC;
grant all on all sequences in schema net to PUBLIC;
grant all on all tables in schema net to PUBLIC;
23 changes: 7 additions & 16 deletions sql/pg_net.sql
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ create or replace function net._encode_url_with_params_array(url text, params_ar
immutable
as 'pg_net';

create or replace function net.worker_restart()
returns bool
language 'c'
as 'pg_net';

-- Interface to make an async request
-- API: Public
Expand All @@ -115,7 +119,6 @@ create or replace function net.http_get(
volatile
parallel safe
language plpgsql
security definer
as $$
declare
request_id bigint;
Expand Down Expand Up @@ -159,7 +162,6 @@ create or replace function net.http_post(
volatile
parallel safe
language plpgsql
security definer
as $$
declare
request_id bigint;
Expand Down Expand Up @@ -229,7 +231,6 @@ create or replace function net.http_delete(
volatile
parallel safe
language plpgsql
security definer
as $$
declare
request_id bigint;
Expand Down Expand Up @@ -290,7 +291,6 @@ create or replace function net._http_collect_response(
volatile
parallel safe
language plpgsql
security definer
as $$
declare
rec net._http_response;
Expand Down Expand Up @@ -345,22 +345,13 @@ create or replace function net.http_collect_response(
volatile
parallel safe
language plpgsql
security definer
as $$
begin
raise notice 'The net.http_collect_response function is deprecated.';
select net._http_collect_response(request_id, async);
end;
$$;

create or replace function net.worker_restart() returns bool as $$
select pg_reload_conf();
select pg_terminate_backend(pid)
from pg_stat_activity
where backend_type ilike '%pg_net%';
$$
security definer
language sql;

grant all on schema net to postgres;
grant all on all tables in schema net to postgres;
grant usage on schema net to PUBLIC;
grant all on all sequences in schema net to PUBLIC;
grant all on all tables in schema net to PUBLIC;
41 changes: 34 additions & 7 deletions src/worker.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <string.h>
#include <inttypes.h>

#include <storage/shmem.h>

#include "util.h"
#include "core.h"

Expand All @@ -40,17 +42,25 @@ _Static_assert(LIBCURL_VERSION_NUM >= MIN_LIBCURL_VERSION_NUM, REQUIRED_LIBCURL_

PG_MODULE_MAGIC;

static char *guc_ttl;
static int guc_batch_size;
static char* guc_database_name;
static MemoryContext CurlMemContext = NULL;
static char* guc_ttl;
static int guc_batch_size;
static char* guc_database_name;
static MemoryContext CurlMemContext = NULL;
static shmem_startup_hook_type prev_shmem_startup_hook = NULL;
static long latch_timeout = 1000;
static volatile sig_atomic_t got_sigterm = false;
static volatile sig_atomic_t got_sighup = false;
static bool* restart_worker = NULL;

void _PG_init(void);
PGDLLEXPORT void pg_net_worker(Datum main_arg) pg_attribute_noreturn();

static long latch_timeout = 1000;
static volatile sig_atomic_t got_sigterm = false;
static volatile sig_atomic_t got_sighup = false;
PG_FUNCTION_INFO_V1(worker_restart);
Datum worker_restart(PG_FUNCTION_ARGS) {
bool result = DatumGetBool(DirectFunctionCall1(pg_reload_conf, (Datum) NULL)); // reload the config
*restart_worker = true;
PG_RETURN_BOOL(result && *restart_worker); // TODO is not necessary to return a bool here, but we do it to maintain backward compatibility
}

static void
handle_sigterm(SIGNAL_ARGS)
Expand Down Expand Up @@ -141,6 +151,12 @@ void pg_net_worker(Datum main_arg) {
ProcessConfigFile(PGC_SIGHUP);
}

if (restart_worker && *restart_worker) {
*restart_worker = false;
elog(INFO, "Restarting pg_net worker");
break;
}

delete_expired_responses(guc_ttl, guc_batch_size);

consume_request_queue(lstate.curl_mhandle, guc_batch_size, CurlMemContext);
Expand Down Expand Up @@ -206,6 +222,14 @@ void pg_net_worker(Datum main_arg) {
proc_exit(EXIT_FAILURE);
}

static void net_shmem_startup(void) {
if (prev_shmem_startup_hook)
prev_shmem_startup_hook();

restart_worker = ShmemAlloc(sizeof(bool));
*restart_worker = false;
}

void _PG_init(void) {
if (IsBinaryUpgrade) {
return;
Expand All @@ -226,6 +250,9 @@ void _PG_init(void) {
.bgw_restart_time = 1,
});

prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = net_shmem_startup;

CurlMemContext = AllocSetContextCreate(TopMemoryContext,
"pg_net curl context",
ALLOCSET_DEFAULT_MINSIZE,
Expand Down
80 changes: 80 additions & 0 deletions test/test_privileges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import pytest
from sqlalchemy import text

def test_net_on_postgres_role(sess):
"""Check that the postgres role can use the net schema by default"""

role = sess.execute(text("select current_user;")).fetchone()

assert role[0] == "postgres"

# Create a request
(request_id,) = sess.execute(text(
"""
select net.http_get(
'http://localhost:8080/anything'
);
"""
)).fetchone()

# Commit so background worker can start
sess.commit()

# Confirm that the request was retrievable
response = sess.execute(
text(
"""
select * from net._http_collect_response(:request_id, async:=false);
"""
),
{"request_id": request_id},
).fetchone()
assert response[0] == "SUCCESS"

def test_net_on_another_role(sess):
"""Check that a newly created role can use the net schema"""

sess.execute(text("""
create role another;
"""))

# Create a request
(request_id,) = sess.execute(text(
"""
set local role to another;
select net.http_get(
'http://localhost:8080/anything'
);
"""
)).fetchone()

# Commit so background worker can start
sess.commit()

# Confirm that the request was retrievable
response = sess.execute(
text(
"""
set local role to another;
select * from net._http_collect_response(:request_id, async:=false);
"""
),
{"request_id": request_id},
).fetchone()
assert response[0] == "SUCCESS"

## can use the net.worker_restart function
response = sess.execute(
text(
"""
set local role to another;
select net.worker_restart();
"""
)
).fetchone()
assert response[0] == True

sess.execute(text("""
set local role postgres;
drop role another;
"""))
Comment on lines +34 to +80
Copy link
Member Author

@steve-chavez steve-chavez Oct 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that these changes are tested

2 changes: 2 additions & 0 deletions test/test_worker_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
def test_success_when_worker_is_up(sess):
"""net.check_worker_is_up should not return anything when the worker is running"""

time.sleep(1) # wait if another test did a net.worker_restart()

(result,) = sess.execute(text("""
select net.check_worker_is_up();
""")).fetchone()
Expand Down