Skip to content

Commit

Permalink
Ensure profile neg works for RPC
Browse files Browse the repository at this point in the history
  • Loading branch information
steve-chavez committed Mar 26, 2020
1 parent 5773c38 commit a249cff
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 23 deletions.
7 changes: 5 additions & 2 deletions src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ postgrest conf refDbStructure pool getTime worker =
Nothing -> respond . errorResponseFor $ ConnectionLostError
Just dbStructure -> do
response <- do
-- Need to parse ?columns early because findProc needs it to solve overloaded functions
-- Need to parse ?columns early because findProc needs it to solve overloaded functions.
-- TODO: move this logic to the app function
let apiReq = userApiRequest (configSchemas conf) (configRootSpec conf) req body
apiReqCols = (,) <$> apiReq <*> (pRequestColumns =<< iColumns <$> apiReq)
case apiReqCols of
Expand Down Expand Up @@ -305,7 +306,9 @@ app dbStructure proc cols conf apiRequest =
Left _ -> return . errorResponseFor $ GucHeadersError
Right ghdrs -> do
let (status, contentRange) = rangeStatusHeader topLevelRange queryTotal tableTotal
headers = addHeadersIfNotIncluded [toHeader contentType, contentRange] (unwrapGucHeader <$> ghdrs)
headers = addHeadersIfNotIncluded
(catMaybes [Just $ toHeader contentType, Just contentRange, profileH])
(unwrapGucHeader <$> ghdrs)
rBody = if invMethod == InvHead then mempty else toS body
if contentType == CTSingularJSON && queryTotal /= 1
then do
Expand Down
16 changes: 9 additions & 7 deletions src/PostgREST/DbStructure.hs
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,14 @@ sourceColumnFromRow allCols (s1,t1,c1,s2,t2,c2) = (,) <$> col1 <*> col2
col2 = findCol s2 t2 c2
findCol s t c = find (\col -> (tableSchema . colTable) col == s && (tableName . colTable) col == t && colName col == c) allCols

decodeProcs :: HD.Result (M.HashMap Text [ProcDescription])
decodeProcs :: HD.Result ProcsMap
decodeProcs =
-- Duplicate rows for a function means they're overloaded, order these by least args according to ProcDescription Ord instance
map sort . M.fromListWith (++) . map ((\(x,y) -> (x, [y])) . addName) <$> HD.rowList tblRow
map sort . M.fromListWith (++) . map ((\(x,y) -> (x, [y])) . addKey) <$> HD.rowList procRow
where
tblRow = ProcDescription
procRow = ProcDescription
<$> column HD.text
<*> column HD.text
<*> nullableColumn HD.text
<*> (parseArgs <$> column HD.text)
<*> (parseRetType
Expand All @@ -142,8 +143,8 @@ decodeProcs =
<*> column HD.char)
<*> (parseVolatility <$> column HD.char)

addName :: ProcDescription -> (Text, ProcDescription)
addName pd = (pdName pd, pd)
addKey :: ProcDescription -> (QualifiedIdentifier, ProcDescription)
addKey pd = (QualifiedIdentifier (pdSchema pd) (pdName pd), pd)

parseArgs :: Text -> [PgArg]
parseArgs = mapMaybe parseArg . filter (not . isPrefixOf "OUT" . toS) . map strip . split (==',')
Expand Down Expand Up @@ -176,19 +177,20 @@ decodeProcs =
| v == 's' = Stable
| otherwise = Volatile -- only 'v' can happen here

allProcs :: H.Statement [Schema] (M.HashMap Text [ProcDescription])
allProcs :: H.Statement [Schema] ProcsMap
allProcs = H.Statement (toS sql) (arrayParam HE.text) decodeProcs True
where
sql = procsSqlQuery <> " WHERE pn.nspname = ANY($1)"

accessibleProcs :: H.Statement Schema (M.HashMap Text [ProcDescription])
accessibleProcs :: H.Statement Schema ProcsMap
accessibleProcs = H.Statement (toS sql) (param HE.text) decodeProcs True
where
sql = procsSqlQuery <> " WHERE pn.nspname = $1 AND has_function_privilege(p.oid, 'execute')"

procsSqlQuery :: SqlQuery
procsSqlQuery = [q|
SELECT
pn.nspname as "proc_schema",
p.proname as "proc_name",
d.description as "proc_description",
pg_get_function_arguments(p.oid) as "args",
Expand Down
28 changes: 17 additions & 11 deletions src/PostgREST/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Module : PostgREST.Types
Description : PostgREST common types and functions used by the rest of the modules
-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}

module PostgREST.Types where
Expand Down Expand Up @@ -105,8 +106,7 @@ data DbStructure = DbStructure {
, dbColumns :: [Column]
, dbRelations :: [Relation]
, dbPrimaryKeys :: [PrimaryKey]
-- ProcDescription is a list because a function can be overloaded
, dbProcs :: M.HashMap Text [ProcDescription]
, dbProcs :: ProcsMap
, pgVersion :: PgVersion
} deriving (Show, Eq)

Expand All @@ -132,7 +132,8 @@ data ProcVolatility = Volatile | Stable | Immutable
deriving (Eq, Show, Ord)

data ProcDescription = ProcDescription {
pdName :: Text
pdSchema :: Schema
, pdName :: Text
, pdDescription :: Maybe Text
, pdArgs :: [PgArg]
, pdReturnType :: RetType
Expand All @@ -141,18 +142,23 @@ data ProcDescription = ProcDescription {

-- Order by least number of args in the case of overloaded functions
instance Ord ProcDescription where
ProcDescription name1 des1 args1 rt1 vol1 `compare` ProcDescription name2 des2 args2 rt2 vol2
| name1 == name2 && length args1 < length args2 = LT
| name1 == name2 && length args1 > length args2 = GT
| otherwise = (name1, des1, args1, rt1, vol1) `compare` (name2, des2, args2, rt2, vol2)
ProcDescription schema1 name1 des1 args1 rt1 vol1 `compare` ProcDescription schema2 name2 des2 args2 rt2 vol2
| schema1 == schema2 && name1 == name2 && length args1 < length args2 = LT
| schema2 == schema2 && name1 == name2 && length args1 > length args2 = GT
| otherwise = (schema1, name1, des1, args1, rt1, vol1) `compare` (schema2, name2, des2, args2, rt2, vol2)

-- | A map of all procs, all of which can be overloaded(one entry will have more than one ProcDescription).
-- | It uses a HashMap for a faster lookup.
type ProcsMap = M.HashMap QualifiedIdentifier [ProcDescription]

{-|
Search a pg procedure by its parameters. Since a function can be overloaded, the name is not enough to find it.
An overloaded function can have a different volatility or even a different return type.
Ideally, handling overloaded functions should be left to pg itself. But we need to know certain proc attributes in advance.
-}
findProc :: QualifiedIdentifier -> S.Set Text -> Bool -> M.HashMap Text [ProcDescription] -> Maybe ProcDescription
findProc :: QualifiedIdentifier -> S.Set Text -> Bool -> ProcsMap -> Maybe ProcDescription
findProc qi payloadKeys paramsAsSingleObject allProcs =
case M.lookup (qiName qi) allProcs of
case M.lookup qi allProcs of
Nothing -> Nothing
Just [proc] -> Just proc -- if it's not an overloaded function then immediately get the ProcDescription
Just procs -> find matches procs -- Handle overloaded functions case
Expand Down Expand Up @@ -254,8 +260,8 @@ data OrderTerm = OrderTerm {
data QualifiedIdentifier = QualifiedIdentifier {
qiSchema :: Schema
, qiName :: TableName
} deriving (Show, Eq, Ord)

} deriving (Show, Eq, Ord, Generic)
instance Hashable QualifiedIdentifier

-- | The relationship [cardinality](https://en.wikipedia.org/wiki/Cardinality_(data_modeling)).
-- | TODO: missing one-to-one
Expand Down
38 changes: 35 additions & 3 deletions test/Feature/MultipleSchemaSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,46 @@ spec =
matchStatus = 406
}

context "calling procs on different schemas" $ do
it "succeeds in calling the default schema proc" $
request methodGet "/rpc/get_parents_below?id=6" [] ""
`shouldRespondWith`
[json|[{"id":1,"name":"parent v1-1"}, {"id":2,"name":"parent v1-2"}]|]
{
matchStatus = 200
, matchHeaders = [matchContentTypeJson, "Content-Profile" <:> "v1"]
}

it "succeeds in calling the v1 schema proc and embedding" $
request methodGet "/rpc/get_parents_below?id=6&select=id,name,childs(id,name)" [("Accept-Profile", "v1")] ""
`shouldRespondWith`
[json| [
{"id":1,"name":"parent v1-1","childs":[{"id":1,"name":"child 1"}]},
{"id":2,"name":"parent v1-2","childs":[{"id":2,"name":"child 2"}]}] |]
{
matchStatus = 200
, matchHeaders = [matchContentTypeJson, "Content-Profile" <:> "v1"]
}

it "succeeds in calling the v2 schema proc and embedding" $
request methodGet "/rpc/get_parents_below?id=6&select=id,name,childs(id,name)" [("Accept-Profile", "v2")] ""
`shouldRespondWith`
[json| [
{"id":3,"name":"parent v2-3","childs":[{"id":1,"name":"child 3"}]},
{"id":4,"name":"parent v2-4","childs":[]}] |]
{
matchStatus = 200
, matchHeaders = [matchContentTypeJson, "Content-Profile" <:> "v2"]
}

context "OpenAPI output" $ do
it "succeeds in reading table definition from default schema v1 if no schema is selected via header" $ do
r <- request methodGet "/" [] ""
req <- request methodGet "/" [] ""

liftIO $ do
simpleHeaders r `shouldSatisfy` matchHeader "Content-Profile" "v1"
simpleHeaders req `shouldSatisfy` matchHeader "Content-Profile" "v1"

let def = simpleBody r ^? key "definitions" . key "parents"
let def = simpleBody req ^? key "definitions" . key "parents"

def `shouldBe` Just
[aesonQQ|
Expand Down
10 changes: 10 additions & 0 deletions test/fixtures/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -1708,6 +1708,11 @@ create table v1.childs (
, table_id int references v1.parents(id)
);

create function v1.get_parents_below(id int)
returns setof v1.parents as $$
select * from v1.parents where id < $1;
$$ language sql;

create table v2.parents (
id int primary key
, name text
Expand All @@ -1723,3 +1728,8 @@ create table v2.another_table (
id int primary key
, another_value text
);

create function v2.get_parents_below(id int)
returns setof v2.parents as $$
select * from v2.parents where id < $1;
$$ language sql;

0 comments on commit a249cff

Please sign in to comment.