Skip to content

Commit

Permalink
feat: add basic UPDATE .. RETURNING * for Postgres
Browse files Browse the repository at this point in the history
With bitemyapp#44 as a starting point, and more ideas in mind.
  • Loading branch information
ulidtko committed May 1, 2024
1 parent 30a5e80 commit d9777b4
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 4 deletions.
41 changes: 37 additions & 4 deletions src/Database/Esqueleto/Internal/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,10 @@ locking kind = putLocking $ LegacyLockingClause kind
putLocking :: LockingClause -> SqlQuery ()
putLocking clause = Q $ W.tell mempty { sdLockingClause = clause }

-- | (Internal) Remember a @RETURNING@ clause in a query
tellReturning :: ReturningClause -> SqlQuery ()
tellReturning clause = Q $ W.tell mempty { sdReturningClause = clause }

{-#
DEPRECATED
sub_select
Expand Down Expand Up @@ -1835,14 +1839,15 @@ data SideData = SideData
, sdLimitClause :: !LimitClause
, sdLockingClause :: !LockingClause
, sdCteClause :: ![CommonTableExpressionClause]
, sdReturningClause :: !ReturningClause
}

instance Semigroup SideData where
SideData d f s w g h o l k c <> SideData d' f' s' w' g' h' o' l' k' c' =
SideData (d <> d') (f <> f') (s <> s') (w <> w') (g <> g') (h <> h') (o <> o') (l <> l') (k <> k') (c <> c')
SideData d f s w g h o l k c r <> SideData d' f' s' w' g' h' o' l' k' c' r' =
SideData (d <> d') (f <> f') (s <> s') (w <> w') (g <> g') (h <> h') (o <> o') (l <> l') (k <> k') (c <> c') (r <> r')

instance Monoid SideData where
mempty = SideData mempty mempty mempty mempty mempty mempty mempty mempty mempty mempty
mempty = SideData mempty mempty mempty mempty mempty mempty mempty mempty mempty mempty mempty
mappend = (<>)

-- | The @DISTINCT@ "clause".
Expand Down Expand Up @@ -1879,6 +1884,12 @@ data CommonTableExpressionKind
data CommonTableExpressionClause =
CommonTableExpressionClause CommonTableExpressionKind Ident (IdentInfo -> (TLB.Builder, [PersistValue]))

data ReturningClause
= ReturningNothing -- ^ The default, absent clause.
| ReturningStar -- ^ @RETURNING *@
-- | ReturningExprs (NonEmpty (SqlExpr Returning))
-- ^ @output_expression [ [ AS ] output_name ] [, ...]@

data SubQueryType
= NormalSubQuery
| LateralSubQuery
Expand Down Expand Up @@ -2117,6 +2128,16 @@ instance Monoid LockingClause where
mempty = NoLockingClause
mappend = (<>)

instance Semigroup ReturningClause where
(<>) ReturningNothing x = x
(<>) x ReturningNothing = x
(<>) ReturningStar ReturningStar = ReturningStar
-- (<>) _ _ = error "instance Semigroup FIXME"

instance Monoid ReturningClause where
mempty = ReturningNothing
mappend = (<>)

----------------------------------------------------------------------

-- | Identifier used for table names.
Expand Down Expand Up @@ -2981,7 +3002,8 @@ toRawSql mode (conn, firstIdentState) query =
orderByClauses
limitClause
lockingClause
cteClause = sd
cteClause
returningClause = sd
-- Pass the finalIdentState (containing all identifiers
-- that were used) to the subsequent calls. This ensures
-- that no name clashes will occur on subqueries that may
Expand All @@ -2999,6 +3021,7 @@ toRawSql mode (conn, firstIdentState) query =
, makeOrderBy info orderByClauses
, makeLimit info limitClause
, makeLocking info lockingClause
, makeReturning info returningClause ret
]


Expand Down Expand Up @@ -3073,6 +3096,7 @@ data Mode
| DELETE
| UPDATE
| INSERT_INTO
| UPDATE_RETSTAR

uncommas :: [TLB.Builder] -> TLB.Builder
uncommas = intersperseB ", "
Expand Down Expand Up @@ -3124,6 +3148,7 @@ makeSelect info mode_ distinctClause ret = process mode_
DELETE -> plain "DELETE "
UPDATE -> plain "UPDATE "
INSERT_INTO -> process SELECT
UPDATE_RETSTAR -> plain "UPDATE "
selectKind =
case distinctClause of
DistinctAll -> ("SELECT ", [])
Expand Down Expand Up @@ -3151,6 +3176,7 @@ makeFrom info mode fs = ret
keyword =
case mode of
UPDATE -> id
UPDATE_RETSTAR -> id
_ -> first ("\nFROM " <>)

mk _ (FromStart i def) = base i def
Expand Down Expand Up @@ -3268,6 +3294,13 @@ makeLocking info (PostgresLockingClauses clauses) =
plain v = (v,[])
makeLocking _ NoLockingClause = mempty

makeReturning :: SqlSelect a r
=> IdentInfo -> ReturningClause -> a -> (TLB.Builder, [PersistValue])
makeReturning _ ReturningNothing _ = mempty
makeReturning info ReturningStar ret = ("RETURNING ", []) <> sqlSelectCols info ret
-- makeReturning info (ReturningExprs _) = undefined -- FIXME


parens :: TLB.Builder -> TLB.Builder
parens b = "(" <> (b <> ")")

Expand Down
10 changes: 10 additions & 0 deletions src/Database/Esqueleto/PostgreSQL.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ module Database.Esqueleto.PostgreSQL
, chr
, now_
, random_
, updateReturningAll
, upsert
, upsertBy
, insertSelectWithConflict
Expand All @@ -41,6 +42,7 @@ module Database.Esqueleto.PostgreSQL
#if __GLASGOW_HASKELL__ < 804
import Data.Semigroup
#endif
import Conduit (withAcquire)
import Control.Arrow (first)
import Control.Exception (throw)
import Control.Monad (void)
Expand Down Expand Up @@ -477,3 +479,11 @@ forUpdateOf lockableEntities onLockedBehavior =
forShareOf :: LockableEntity a => a -> OnLockedBehavior -> SqlQuery ()
forShareOf lockableEntities onLockedBehavior =
putLocking $ PostgresLockingClauses [PostgresLockingKind PostgresForShare (Just $ LockingOfClause lockableEntities) onLockedBehavior]

updateReturningAll :: (MonadIO m, PersistEntity ent, SqlBackendCanWrite backend, backend ~ PersistEntityBackend ent)
=> (SqlExpr (Entity ent) -> SqlQuery (SqlExpr (Entity ent)))
-> R.ReaderT backend m [Entity ent]
updateReturningAll block = do
conn <- R.ask
conduit <- rawSelectSource UPDATE_RETSTAR (tellReturning ReturningStar >> from block)
liftIO . withAcquire conduit $ flip R.runReaderT conn . runSource
12 changes: 12 additions & 0 deletions test/PostgreSQL/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,17 @@ testUpsert =
u3e <- EP.upsert u3 [OneUniqueName =. val "fifth"]
liftIO $ entityVal u3e `shouldBe` u1{oneUniqueName="fifth"}

testUpdateDeleteReturning :: SpecDb
testUpdateDeleteReturning =
describe "UPDATE .. RETURNING *" $ do
itDb "Whole updated entity gets returned" $ do
[p1k, p2k, p3k, p4k, p5k] <- mapM insert [p1, p2, p3, p4, p5]
ret <- EP.updateReturningAll $ \p -> do
set p [ PersonFavNum =. val 42 ]
where_ (p ^. PersonFavNum ==. val 4)
return p
asserting $ ret `shouldBe` [Entity p4k p4{ personFavNum = 42 }]

testInsertSelectWithConflict :: SpecDb
testInsertSelectWithConflict =
describe "insertSelectWithConflict test" $ do
Expand Down Expand Up @@ -1629,6 +1640,7 @@ spec = beforeAll mkConnectionPool $ do
testPostgresqlTextFunctions
testInsertUniqueViolation
testUpsert
testUpdateDeleteReturning
testInsertSelectWithConflict
testFilterWhere
testCommonTableExpressions
Expand Down

0 comments on commit d9777b4

Please sign in to comment.