Skip to content

Commit

Permalink
add new catalog and cache APIs
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi committed Feb 10, 2024
1 parent 12842b6 commit c0cb502
Show file tree
Hide file tree
Showing 14 changed files with 1,040 additions and 430 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pgrx-pg-sys/include/pg12.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include "catalog/pg_foreign_data_wrapper.h"
#include "catalog/pg_foreign_server.h"
#include "catalog/pg_foreign_table.h"
#include "catalog/pg_index.h"
#include "catalog/pg_operator.h"
#include "catalog/pg_opclass.h"
#include "catalog/pg_opfamily.h"
Expand Down
1 change: 1 addition & 0 deletions pgrx-pg-sys/include/pg13.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include "catalog/pg_foreign_data_wrapper.h"
#include "catalog/pg_foreign_server.h"
#include "catalog/pg_foreign_table.h"
#include "catalog/pg_index.h"
#include "catalog/pg_operator.h"
#include "catalog/pg_opclass.h"
#include "catalog/pg_opfamily.h"
Expand Down
1 change: 1 addition & 0 deletions pgrx-pg-sys/include/pg14.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include "catalog/pg_foreign_data_wrapper.h"
#include "catalog/pg_foreign_server.h"
#include "catalog/pg_foreign_table.h"
#include "catalog/pg_index.h"
#include "catalog/pg_operator.h"
#include "catalog/pg_opclass.h"
#include "catalog/pg_opfamily.h"
Expand Down
1 change: 1 addition & 0 deletions pgrx-pg-sys/include/pg15.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#include "catalog/pg_foreign_data_wrapper.h"
#include "catalog/pg_foreign_server.h"
#include "catalog/pg_foreign_table.h"
#include "catalog/pg_index.h"
#include "catalog/pg_operator.h"
#include "catalog/pg_opclass.h"
#include "catalog/pg_opfamily.h"
Expand Down
1 change: 1 addition & 0 deletions pgrx-pg-sys/include/pg16.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include "catalog/pg_foreign_data_wrapper.h"
#include "catalog/pg_foreign_server.h"
#include "catalog/pg_foreign_table.h"
#include "catalog/pg_index.h"
#include "catalog/pg_operator.h"
#include "catalog/pg_opclass.h"
#include "catalog/pg_opfamily.h"
Expand Down
1 change: 1 addition & 0 deletions pgrx-tests/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ mod memcxt_tests;
mod name_tests;
mod numeric_tests;
mod pg_cast_tests;
mod pg_catalog_tests;
mod pg_extern_tests;
mod pg_guard_tests;
mod pg_operator_tests;
Expand Down
131 changes: 131 additions & 0 deletions pgrx-tests/src/tests/pg_catalog_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
use pgrx::prelude::*;

#[cfg(any(test, feature = "pg_test"))]
#[pg_schema]
mod tests {
use pgrx::pg_sys::Oid;
use std::ffi::CString;

#[allow(unused_imports)]
use crate as pgrx_tests;
use pgrx::prelude::*;

#[pg_test]
fn test_pg_catalog_pg_proc_boolin() {
use pgrx::pg_catalog::*;
let proname = CString::new("boolin").unwrap();
let proargtypes = /* cstring */ [pgrx::wrappers::regtypein("cstring")];
let pronamespace = /* pg_catalog */ Oid::from(11);
// search
let pg_proc = PgProc::search_procnameargsnsp(&proname, &proargtypes, pronamespace).unwrap();
let pg_proc = pg_proc.get().unwrap();
// getstruct, name
assert_eq!(pg_proc.proname(), proname.as_c_str());
// getstruct, primitive types
assert_eq!(pg_proc.pronamespace(), pronamespace);
assert_eq!(pg_proc.procost(), 1.0);
assert_eq!(pg_proc.prorows(), 0.0);
assert_eq!(pg_proc.provariadic(), Oid::INVALID);
// getstruct, regproc
assert_eq!(pg_proc.prosupport(), Oid::INVALID);
// getstruct, char
assert_eq!(pg_proc.prokind(), PgProcProkind::Function);
assert_eq!(pg_proc.prosecdef(), false);
assert_eq!(pg_proc.proleakproof(), false);
assert_eq!(pg_proc.proisstrict(), true);
assert_eq!(pg_proc.proretset(), false);
assert_eq!(pg_proc.provolatile(), PgProcProvolatile::Immutable);
assert_eq!(pg_proc.proparallel(), PgProcProparallel::Safe);
assert_eq!(pg_proc.pronargs(), 1);
assert_eq!(pg_proc.pronargdefaults(), 0);
assert_eq!(pg_proc.prorettype(), pgrx::pg_sys::BOOLOID);
// getstruct, oidvector
assert_eq!(pg_proc.proargtypes(), &proargtypes);
// getattr, null
assert!(pg_proc.proallargtypes().is_none());
assert!(pg_proc.proargmodes().is_none());
assert!(pg_proc.proargnames().is_none());
assert!(pg_proc.protrftypes().is_none());
// getattr, text
assert_eq!(pg_proc.prosrc(), "boolin");
assert!(pg_proc.probin().is_none());
assert!(pg_proc.proconfig().is_none());
}

#[pg_test]
fn test_pg_catalog_pg_proc_num_nulls() {
use pgrx::pg_catalog::*;
let proname = CString::new("num_nulls").unwrap();
let proargtypes = [pgrx::pg_sys::ANYOID];
let pronamespace = /* pg_catalog */ pgrx::pg_sys::Oid::from(11);
let pg_proc = PgProc::search_procnameargsnsp(&proname, &proargtypes, pronamespace).unwrap();
let pg_proc = pg_proc.get().unwrap();
assert_eq!(pg_proc.proname(), proname.as_c_str());
assert_eq!(pg_proc.pronamespace(), pronamespace);
assert_eq!(pg_proc.procost(), 1.0);
assert_eq!(pg_proc.prorows(), 0.0);
assert_eq!(pg_proc.provariadic(), pgrx::pg_sys::ANYOID);
assert_eq!(pg_proc.prosupport(), Oid::INVALID);
assert_eq!(pg_proc.prokind(), PgProcProkind::Function);
assert_eq!(pg_proc.prosecdef(), false);
assert_eq!(pg_proc.proleakproof(), false);
assert_eq!(pg_proc.proisstrict(), false);
assert_eq!(pg_proc.proretset(), false);
assert_eq!(pg_proc.provolatile(), PgProcProvolatile::Immutable);
assert_eq!(pg_proc.proparallel(), PgProcProparallel::Safe);
assert_eq!(pg_proc.pronargs(), 1);
assert_eq!(pg_proc.pronargdefaults(), 0);
assert_eq!(pg_proc.prorettype(), pgrx::pg_sys::INT4OID);
assert_eq!(pg_proc.proargtypes(), &proargtypes);
// getattr, oid[]
assert_eq!(
pg_proc.proallargtypes().map(|v| v.iter().collect()),
Some(vec![Some(pgrx::pg_sys::ANYOID)])
);
// getattr, char[]
assert_eq!(
pg_proc.proargmodes().map(|v| v.iter().collect()),
Some(vec![Some(PgProcProargmodes::Variadic)])
);
assert!(pg_proc.proargnames().is_none());
assert!(pg_proc.protrftypes().is_none());
assert_eq!(pg_proc.prosrc(), "pg_num_nulls");
assert!(pg_proc.probin().is_none());
assert!(pg_proc.proconfig().is_none());
}

#[pg_test]
fn test_pg_catalog_pg_proc_gcd() {
use pgrx::pg_catalog::*;
let proname = CString::new("gcd").unwrap();
// search_list
let pg_proc = PgProc::search_list_procnameargsnsp_1(&proname).unwrap();
let mut int4gcd = false;
let mut int8gcd = false;
for i in 0..pg_proc.len() {
let pg_proc = pg_proc.get(i).unwrap();
if pg_proc.prosrc() == "int4gcd" {
int4gcd = true;
}
if pg_proc.prosrc() == "int8gcd" {
int8gcd = true;
}
}
assert!(int4gcd);
assert!(int8gcd);
}

#[pg_test]
fn test_pg_catalog_pg_class_pg_stats() {
use pgrx::pg_catalog::*;
let relname = CString::new("pg_stats").unwrap();
let relnamespace = /* pg_catalog */ pgrx::pg_sys::Oid::from(11);
let pg_class = PgClass::search_relnamensp(&relname, relnamespace).unwrap();
let pg_class = pg_class.get().unwrap();
// getattr, text[]
assert_eq!(
pg_class.reloptions().map(|v| v.iter().collect()),
Some(vec![Some("security_barrier=true".to_string())])
);
}
}
1 change: 1 addition & 0 deletions pgrx/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,4 @@ seahash = "4.1.0" # derive(PostgresHash)
serde = { version = "1.0", features = [ "derive" ] } # impls on pub types
serde_cbor = "0.11.2" # derive(PostgresType)
serde_json = "1.0" # everything JSON
paste = "1.0.14"
75 changes: 17 additions & 58 deletions pgrx/src/enum_helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,45 +9,26 @@
//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
//! Helper functions for working with Postgres `enum` types
use crate::pg_sys::GETSTRUCT;
use crate::pg_catalog::PgEnum;
use crate::{ereport, pg_sys, PgLogLevel, PgSqlErrorCode};

pub fn lookup_enum_by_oid(enumval: pg_sys::Oid) -> (String, pg_sys::Oid, f32) {
let tup = unsafe {
pg_sys::SearchSysCache(
pg_sys::SysCacheIdentifier_ENUMOID as i32,
pg_sys::Datum::from(enumval),
pg_sys::Datum::from(0),
pg_sys::Datum::from(0),
pg_sys::Datum::from(0),
)
};
if tup.is_null() {
let pg_enum = PgEnum::search_enumoid(enumval).unwrap();

let Some(pg_enum) = pg_enum.get() else {
ereport!(
PgLogLevel::ERROR,
PgSqlErrorCode::ERRCODE_INVALID_BINARY_REPRESENTATION,
format!("invalid internal value for enum: {enumval:?}")
);
}

let en = unsafe { GETSTRUCT(tup) } as pg_sys::Form_pg_enum;
let en = unsafe { en.as_ref() }.unwrap();
let result = (
unsafe {
core::ffi::CStr::from_ptr(en.enumlabel.data.as_ptr() as *const std::os::raw::c_char)
}
.to_str()
.unwrap()
.to_string(),
en.enumtypid,
en.enumsortorder as f32,
);

unsafe {
pg_sys::ReleaseSysCache(tup);
}
unreachable!()
};

result
(
pg_enum.enumlabel().to_str().unwrap().to_string(),
pg_enum.enumtypid(),
pg_enum.enumsortorder() as f32,
)
}

pub fn lookup_enum_by_label(typname: &str, label: &str) -> pg_sys::Datum {
Expand All @@ -57,35 +38,13 @@ pub fn lookup_enum_by_label(typname: &str, label: &str) -> pg_sys::Datum {
panic!("could not locate type oid for type: {typname}");
}

let tup = unsafe {
let label =
alloc::ffi::CString::new(label).expect("failed to convert enum typname to a CString");
pg_sys::SearchSysCache(
pg_sys::SysCacheIdentifier_ENUMTYPOIDNAME as i32,
pg_sys::Datum::from(enumtypoid),
pg_sys::Datum::from(label.as_ptr()),
pg_sys::Datum::from(0usize),
pg_sys::Datum::from(0usize),
)
};

if tup.is_null() {
panic!("could not find heap tuple for enum: {typname}.{label}, typoid={enumtypoid:?}");
}
let label = std::ffi::CString::new(label).expect("failed to convert enum typname to a CString");

// SAFETY: we know that `tup` is valid because we just got it from Postgres above
unsafe {
let oid = extract_enum_oid(tup);
pg_sys::ReleaseSysCache(tup);
pg_sys::Datum::from(oid)
}
}
let pg_enum = PgEnum::search_enumtypoidname(enumtypoid, &label).unwrap();

unsafe fn extract_enum_oid(tup: *mut pg_sys::HeapTupleData) -> pg_sys::Oid {
let en = {
// SAFETY: the caller has assured us that `tup` is a valid HeapTupleData pointer
GETSTRUCT(tup) as pg_sys::Form_pg_enum
let Some(pg_enum) = pg_enum.get() else {
panic!("could not find heap tuple for enum: {typname}.{label:?}, typoid={enumtypoid:?}");
};
let en = en.as_ref().unwrap();
en.oid

pg_sys::Datum::from(pg_enum.oid())
}
19 changes: 12 additions & 7 deletions pgrx/src/fn_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ use pgrx_pg_sys::PgTryBuilder;
use std::panic::AssertUnwindSafe;

use crate::memcx;
use crate::pg_catalog::pg_proc::{PgProc, ProArgMode, ProKind};
use crate::pg_catalog::PgProc;
use crate::pg_catalog::{PgProcProargmodes, PgProcProkind};
use crate::seal::Sealed;
use crate::{
direct_function_call, is_a, list::List, pg_sys, pg_sys::AsPgCStr, Array, FromDatum, IntoDatum,
Expand Down Expand Up @@ -205,18 +206,22 @@ pub fn fn_call_with_collation<R: FromDatum + IntoDatum>(
let func_oid = lookup_fn(fname, args)?;

// lookup the function's pg_proc entry and do some validation
let pg_proc = PgProc::new(func_oid).ok_or(FnCallError::UndefinedFunction)?;
let pg_proc = PgProc::search_procoid(func_oid).ok_or(FnCallError::UndefinedFunction)?;
let pg_proc = pg_proc.get().ok_or(FnCallError::UndefinedFunction)?;
let retoid = pg_proc.prorettype();

//
// do some validation to catch the cases we don't/can't directly call
//

if !matches!(pg_proc.prokind(), ProKind::Function) {
if !matches!(pg_proc.prokind(), PgProcProkind::Function) {
// It only makes sense to directly call regular functions. Calling aggregate or window
// functions is nonsensical
return Err(FnCallError::UnsupportedFunctionType);
} else if pg_proc.proargmodes().iter().any(|mode| *mode != ProArgMode::In) {
} else if pg_proc
.proargmodes()
.map_or(false, |x| x.iter_deny_null().any(|mode| mode != PgProcProargmodes::In))
{
// Right now we only know how to support arguments with the IN mode. Perhaps in the
// future we can support IN_OUT and TABLE return types
return Err(FnCallError::UnsupportedArgumentModes);
Expand All @@ -240,7 +245,7 @@ pub fn fn_call_with_collation<R: FromDatum + IntoDatum>(
.iter()
.enumerate()
.map(|(i, a)| a.as_datum(&pg_proc, i))
.chain((args.len()..pg_proc.pronargs()).map(|i| create_default_value(&pg_proc, i)))
.chain((args.len()..pg_proc.pronargs() as usize).map(|i| create_default_value(&pg_proc, i)))
.map(|datum| {
null |= matches!(datum, Ok(None));
datum
Expand Down Expand Up @@ -276,7 +281,7 @@ pub fn fn_call_with_collation<R: FromDatum + IntoDatum>(
//
// SAFETY: we allocate enough zeroed space for the base FunctionCallInfoBaseData *plus* the number of arguments
// we have, and we've asserted that we have the correct number of arguments
assert_eq!(nargs, pg_proc.pronargs());
assert_eq!(nargs, pg_proc.pronargs() as usize);
let fcinfo = pg_sys::palloc0(
std::mem::size_of::<pg_sys::FunctionCallInfoBaseData>()
+ std::mem::size_of::<pg_sys::NullableDatum>() * nargs,
Expand Down Expand Up @@ -433,7 +438,7 @@ fn parse_sql_ident(ident: &str) -> Result<Array<&str>> {
/// - [`FnCallError::NotDefaultArgument`] if the specified `argnum` does not have a `DEFAULT` clause
/// - [`FnCallError::DefaultNotConstantExpression`] if the `DEFAULT` clause is one we cannot evaluate
fn create_default_value(pg_proc: &PgProc, argnum: usize) -> Result<Option<pg_sys::Datum>> {
let non_default_args_cnt = pg_proc.pronargs() - pg_proc.pronargdefaults();
let non_default_args_cnt = (pg_proc.pronargs() - pg_proc.pronargdefaults()) as usize;
if argnum < non_default_args_cnt {
return Err(FnCallError::NotDefaultArgument(argnum));
}
Expand Down
Loading

0 comments on commit c0cb502

Please sign in to comment.