def create_geonode_db_schema(sender, using, **kwargs):
    """Create the configured PostgreSQL schema before migrations run.

    Intended as a ``pre_migrate`` signal receiver. The schema name is read
    from the settings of the database alias being migrated; nothing is done
    when no schema is configured, when it is ``"public"``, when the name is
    not a plain unquoted identifier, or when the backend is not PostgreSQL.

    Schema creation is best-effort: any failure is logged as a warning and
    never propagated, so a missing privilege does not abort ``migrate``.

    :param sender: the object that sent the signal (unused).
    :param using: alias of the database about to be migrated.
    :param kwargs: remaining signal arguments (unused).
    """
    import re
    from django.conf import settings
    from django.db import connections
    from geonode.utils import get_db_schema

    schema = get_db_schema(settings.DATABASES.get(using, {}))
    if not schema or schema == "public":
        return

    # fullmatch (rather than match + "$") so a trailing newline cannot slip
    # through into the identifier that is interpolated into the DDL below.
    if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_$]*", schema) is None:
        logger.warning(
            "Skipping schema creation for '%s' on database '%s': invalid schema name",
            schema,
            using,
        )
        return

    try:
        connection = connections[using]
        # CREATE SCHEMA IF NOT EXISTS is PostgreSQL syntax; skip other backends
        # instead of provoking (and logging) an error on every signal dispatch.
        if connection.vendor != "postgresql":
            return
        with connection.cursor() as cursor:
            # The name was validated above, so quoting it inline is safe;
            # identifiers cannot be passed as query parameters.
            cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{schema}"')
    except Exception as e:
        # Deliberately broad: this hook must never break the migration run.
        logger.warning("Could not create schema '%s' on database '%s': %s", schema, using, e)
def get_db_schema(db_settings):
    """Return the schema name configured in a Django database settings dict.

    Thin pass-through to ``geonode.utils.get_db_schema``. The import happens
    at call time, inside the function, rather than at module import time.

    :param db_settings: a single entry of Django's ``DATABASES`` setting.
    :return: whatever ``geonode.utils.get_db_schema`` returns for it.
    """
    import geonode.utils as geonode_utils

    return geonode_utils.get_db_schema(db_settings)
"localhost" @@ -174,8 +181,8 @@ def get_db_conn(db_name, db_user, db_port, db_host, db_passwd): return conn -def get_tables(db_user, db_passwd, db_name, db_host="localhost", db_port=5432): - select = f"SELECT tablename FROM pg_tables WHERE tableowner = '{db_user}' and schemaname = 'public'" +def get_tables(db_user, db_passwd, db_name, db_host="localhost", db_port=5432, db_schema="public"): + select = f"SELECT tablename FROM pg_tables WHERE tableowner = '{db_user}' and schemaname = '{db_schema}'" logger.info(f"Retrieving table list from DB {db_name}@{db_host}: {select}") try: @@ -194,13 +201,13 @@ def get_tables(db_user, db_passwd, db_name, db_host="localhost", db_port=5432): conn.close() -def truncate_tables(db_name, db_user, db_port, db_host, db_passwd): +def truncate_tables(db_name, db_user, db_port, db_host, db_passwd, db_schema="public"): """HARD Truncate all DB Tables""" db_host = db_host if db_host is not None else "localhost" db_port = db_port if db_port is not None else 5432 logger.info(f"Truncating the tables in DB {db_name} @{db_host}:{db_port} for user {db_user}") - pg_tables = get_tables(db_user, db_passwd, db_name, db_host, db_port) + pg_tables = get_tables(db_user, db_passwd, db_name, db_host, db_port, db_schema) logger.info(f"Tables found: {pg_tables}") conn = get_db_conn(db_name, db_user, db_port, db_host, db_passwd) @@ -227,13 +234,13 @@ def truncate_tables(db_name, db_user, db_port, db_host, db_passwd): conn.close() -def dump_db(config, db_name, db_user, db_port, db_host, db_passwd, target_folder): +def dump_db(config, db_name, db_user, db_port, db_host, db_passwd, target_folder, db_schema="public"): """Dump Full DB into target folder""" db_host = db_host if db_host is not None else "localhost" db_port = db_port if db_port is not None else 5432 logger.info("Dumping data tables") - pg_tables = get_tables(db_user, db_passwd, db_name, db_host, db_port) + pg_tables = get_tables(db_user, db_passwd, db_name, db_host, db_port, db_schema) 
class GetDbSchemaTests(GeoNodeBaseTestSupport):
    """Exercise ``get_db_schema`` as exposed by the backup/restore utils."""

    def setUp(self):
        # Imported here so the helper is resolved when each test starts.
        from geonode.br.management.commands.utils import utils as br_utils

        self.get_db_schema = br_utils.get_db_schema

    def test_returns_public_when_no_options(self):
        # No "OPTIONS" key at all.
        self.assertEqual(self.get_db_schema({}), "public")

    def test_returns_public_when_empty_options(self):
        # "OPTIONS" present but empty.
        settings_dict = {"OPTIONS": {}}
        self.assertEqual(self.get_db_schema(settings_dict), "public")

    def test_returns_public_when_no_search_path(self):
        # "OPTIONS" carries unrelated keys only.
        settings_dict = {"OPTIONS": {"connect_timeout": 5}}
        self.assertEqual(self.get_db_schema(settings_dict), "public")

    def test_returns_schema_from_search_path(self):
        # A custom schema listed ahead of "public".
        settings_dict = {"OPTIONS": {"options": "-c search_path=my_schema,public"}}
        self.assertEqual(self.get_db_schema(settings_dict), "my_schema")

    def test_returns_first_schema_from_search_path(self):
        # With several schemas only the first one is reported.
        settings_dict = {"OPTIONS": {"options": "-c search_path=schema1,schema2,public"}}
        self.assertEqual(self.get_db_schema(settings_dict), "schema1")
def get_db_schema(db_settings):
    """Extract the configured schema name from a Django database settings dict.

    Looks at ``db_settings["OPTIONS"]["options"]`` (the libpq command-line
    options string, e.g. ``"-c search_path=my_schema,public"``) and returns
    the first non-empty schema listed in ``search_path``.

    Names are normalised to what PostgreSQL will actually resolve: a name
    wrapped in double quotes is returned without the quotes and with its case
    preserved, while an unquoted name is folded to lower case.

    :param db_settings: a single entry of Django's ``DATABASES`` setting;
        ``None`` and missing/``None`` ``OPTIONS``/``options`` are tolerated.
    :return: the schema name, or ``"public"`` when no usable ``search_path``
        entry is configured.
    """
    options_str = ((db_settings or {}).get("OPTIONS") or {}).get("options") or ""
    for part in options_str.split():
        # Matches "search_path=..." as well as "-csearch_path=..." and
        # "--search_path=..."; everything after the marker is the value.
        _, marker, raw_value = part.partition("search_path=")
        if not marker:
            continue
        for candidate in raw_value.split(","):
            candidate = candidate.strip()
            if len(candidate) >= 2 and candidate[0] == '"' and candidate[-1] == '"':
                # Quoted identifier: case-sensitive, quotes are not part of it.
                name = candidate[1:-1]
            else:
                # Unquoted identifiers are folded to lower case by PostgreSQL.
                name = candidate.lower()
            if name:
                return name
    return "public"