Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions geonode/base/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,39 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#########################################################################
import logging

from django.apps import AppConfig
from django.utils.translation import gettext_noop as _

from geonode.notifications_helper import NotificationsAppConfigBase

logger = logging.getLogger(__name__)


def create_geonode_db_schema(sender, using, **kwargs):
    """Create the configured PostgreSQL schema before migrations run.

    Connected to Django's ``pre_migrate`` signal; ``using`` is the alias of
    the database about to be migrated.  No-op for the default ``public``
    schema, and never raises: any failure is logged and migration proceeds.
    """
    import re
    from django.conf import settings
    from django.db import connections
    from geonode.utils import get_db_schema

    schema = get_db_schema(settings.DATABASES.get(using, {}))

    # Nothing to do when no custom schema is configured.
    if not schema or schema == "public":
        return

    # Only plain PostgreSQL identifiers are allowed before we interpolate
    # the name into DDL below.
    identifier = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*$")
    if identifier.match(schema) is None:
        logger.warning(f"Skipping schema creation for '{schema}' on database '{using}': invalid schema name")
        return

    # Best effort: a failure here must not abort the migration run.
    try:
        with connections[using].cursor() as cursor:
            cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{schema}"')
    except Exception as e:
        logger.warning(f"Could not create schema '{schema}' on database '{using}': {e}")


class BaseAppConfig(NotificationsAppConfigBase, AppConfig):
name = "geonode.base"
Expand All @@ -45,8 +73,10 @@ class BaseAppConfig(NotificationsAppConfigBase, AppConfig):

def ready(self):
"""Finalize setup"""
from django.db.models import signals
from geonode.base.signals import connect_signals

connect_signals()
signals.pre_migrate.connect(create_geonode_db_schema)

super(BaseAppConfig, self).ready()
1 change: 1 addition & 0 deletions geonode/br/management/commands/backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def dump_geoserver_vector_data(self, config, settings, target_folder):
datastore["HOST"],
datastore["PASSWORD"],
gs_data_folder,
utils.get_db_schema(datastore),
)

def dump_geoserver_externals(self, config, settings, target_folder):
Expand Down
10 changes: 8 additions & 2 deletions geonode/br/management/commands/restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,10 @@ def execute_restore(self, **options):
db_host = settings.DATABASES["default"]["HOST"]
db_passwd = settings.DATABASES["default"]["PASSWORD"]

utils.truncate_tables(db_name, db_user, db_port, db_host, db_passwd)
utils.truncate_tables(
db_name, db_user, db_port, db_host, db_passwd,
utils.get_db_schema(settings.DATABASES["default"]),
)
except Exception:
logger.info("Error while truncating tables, trying external task")

Expand Down Expand Up @@ -730,7 +733,10 @@ def restore_geoserver_vector_data(self, config, settings, target_folder, soft_re
ogc_db_port = datastore["PORT"]

if not soft_reset:
utils.remove_existing_tables(ogc_db_name, ogc_db_user, ogc_db_port, ogc_db_host, ogc_db_passwd)
utils.remove_existing_tables(
ogc_db_name, ogc_db_user, ogc_db_port, ogc_db_host, ogc_db_passwd,
utils.get_db_schema(datastore),
)

utils.restore_db(
config,
Expand Down
23 changes: 15 additions & 8 deletions geonode/br/management/commands/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,13 @@ def load_settings(config):
sys.path.append(os.path.join(os.path.dirname(__file__), "lib"))


def get_db_schema(db_settings):
    """Extract the configured schema name from a Django database settings dict.

    Thin delegation to :func:`geonode.utils.get_db_schema`, re-exported here
    so the backup/restore commands can resolve it from this module.
    """
    from geonode.utils import get_db_schema as _impl

    return _impl(db_settings)


def get_db_conn(db_name, db_user, db_port, db_host, db_passwd):
"""Get db conn (GeoNode)"""
db_host = db_host if db_host is not None else "localhost"
Expand All @@ -174,8 +181,8 @@ def get_db_conn(db_name, db_user, db_port, db_host, db_passwd):
return conn


def get_tables(db_user, db_passwd, db_name, db_host="localhost", db_port=5432):
select = f"SELECT tablename FROM pg_tables WHERE tableowner = '{db_user}' and schemaname = 'public'"
def get_tables(db_user, db_passwd, db_name, db_host="localhost", db_port=5432, db_schema="public"):
select = f"SELECT tablename FROM pg_tables WHERE tableowner = '{db_user}' and schemaname = '{db_schema}'"
logger.info(f"Retrieving table list from DB {db_name}@{db_host}: {select}")

try:
Expand All @@ -194,13 +201,13 @@ def get_tables(db_user, db_passwd, db_name, db_host="localhost", db_port=5432):
conn.close()


def truncate_tables(db_name, db_user, db_port, db_host, db_passwd):
def truncate_tables(db_name, db_user, db_port, db_host, db_passwd, db_schema="public"):
"""HARD Truncate all DB Tables"""
db_host = db_host if db_host is not None else "localhost"
db_port = db_port if db_port is not None else 5432

logger.info(f"Truncating the tables in DB {db_name} @{db_host}:{db_port} for user {db_user}")
pg_tables = get_tables(db_user, db_passwd, db_name, db_host, db_port)
pg_tables = get_tables(db_user, db_passwd, db_name, db_host, db_port, db_schema)
logger.info(f"Tables found: {pg_tables}")

conn = get_db_conn(db_name, db_user, db_port, db_host, db_passwd)
Expand All @@ -227,13 +234,13 @@ def truncate_tables(db_name, db_user, db_port, db_host, db_passwd):
conn.close()


def dump_db(config, db_name, db_user, db_port, db_host, db_passwd, target_folder):
def dump_db(config, db_name, db_user, db_port, db_host, db_passwd, target_folder, db_schema="public"):
"""Dump Full DB into target folder"""
db_host = db_host if db_host is not None else "localhost"
db_port = db_port if db_port is not None else 5432

logger.info("Dumping data tables")
pg_tables = get_tables(db_user, db_passwd, db_name, db_host, db_port)
pg_tables = get_tables(db_user, db_passwd, db_name, db_host, db_port, db_schema)
logger.info(f"Tables found: {pg_tables}")

include_filter = config.gs_data_datasetname_filter
Expand Down Expand Up @@ -313,9 +320,9 @@ def restore_db(config, db_name, db_user, db_port, db_host, db_passwd, source_fol
logger.error(f"ERR:: {cproc.stderr}")


def remove_existing_tables(db_name, db_user, db_port, db_host, db_passwd):
def remove_existing_tables(db_name, db_user, db_port, db_host, db_passwd, db_schema="public"):
logger.info("Dropping existing GeoServer vector data from DB")
pg_tables = get_tables(db_user, db_passwd, db_name, db_host, db_port)
pg_tables = get_tables(db_user, db_passwd, db_name, db_host, db_port, db_schema)
bad_tables = []

conn = get_db_conn(db_name, db_user, db_port, db_host, db_passwd)
Expand Down
31 changes: 31 additions & 0 deletions geonode/br/tests/test_restore_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,34 @@ def test_backup_hash_success(self):
finally:
# remove temporary hash file
os.remove(tmp_hash_file)


class GetDbSchemaTests(GeoNodeBaseTestSupport):
    """Tests for the get_db_schema utility function."""

    def setUp(self):
        # Imported lazily so module import errors surface as test failures.
        from geonode.br.management.commands.utils.utils import get_db_schema

        self.get_db_schema = get_db_schema

    def test_returns_public_when_no_options(self):
        # No OPTIONS key at all -> fall back to the default schema.
        self.assertEqual(self.get_db_schema({}), "public")

    def test_returns_public_when_empty_options(self):
        # OPTIONS present but empty -> still the default schema.
        self.assertEqual(self.get_db_schema({"OPTIONS": {}}), "public")

    def test_returns_public_when_no_search_path(self):
        # Unrelated connection options must not affect the result.
        self.assertEqual(self.get_db_schema({"OPTIONS": {"connect_timeout": 5}}), "public")

    def test_returns_schema_from_search_path(self):
        settings_dict = {"OPTIONS": {"options": "-c search_path=my_schema,public"}}
        self.assertEqual(self.get_db_schema(settings_dict), "my_schema")

    def test_returns_first_schema_from_search_path(self):
        settings_dict = {"OPTIONS": {"options": "-c search_path=schema1,schema2,public"}}
        self.assertEqual(self.get_db_schema(settings_dict), "schema1")
10 changes: 9 additions & 1 deletion geonode/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@
GEONODE_DB_CONN_MAX_AGE = int(os.getenv("GEONODE_DB_CONN_MAX_AGE", 0))
GEONODE_DB_CONN_TOUT = int(os.getenv("GEONODE_DB_CONN_TOUT", 5))

GEONODE_DATABASE_SCHEMA = os.getenv("GEONODE_DATABASE_SCHEMA", "public")

_db_conf = dj_database_url.parse(DATABASE_URL, conn_max_age=GEONODE_DB_CONN_MAX_AGE)

if "CONN_TOUT" in _db_conf:
Expand All @@ -143,6 +145,8 @@
"connect_timeout": GEONODE_DB_CONN_TOUT,
}
)
if GEONODE_DATABASE_SCHEMA != "public":
_db_conf["OPTIONS"]["options"] = f"-c search_path={GEONODE_DATABASE_SCHEMA},public"

DATABASES = {"default": _db_conf}

Expand All @@ -152,19 +156,23 @@
"postgis://\
geonode_data:geonode_data@localhost:5432/geonode_data",
)
GEONODE_GEODATABASE_SCHEMA = os.getenv("GEONODE_GEODATABASE_SCHEMA", "public")
DATABASES[os.getenv("DEFAULT_BACKEND_DATASTORE")] = dj_database_url.parse(
GEODATABASE_URL, conn_max_age=GEONODE_DB_CONN_MAX_AGE
)
_geo_db = DATABASES[os.getenv("DEFAULT_BACKEND_DATASTORE")]
if "CONN_TOUT" in DATABASES["default"]:
_geo_db["CONN_TOUT"] = DATABASES["default"]["CONN_TOUT"]
if "postgresql" in GEODATABASE_URL or "postgis" in GEODATABASE_URL:
_geo_db["OPTIONS"] = DATABASES["default"]["OPTIONS"] if "OPTIONS" in DATABASES["default"] else {}
_geo_db["OPTIONS"] = dict(DATABASES["default"]["OPTIONS"]) if "OPTIONS" in DATABASES["default"] else {}
_geo_db["OPTIONS"].pop("options", None)
_geo_db["OPTIONS"].update(
{
"connect_timeout": GEONODE_DB_CONN_TOUT,
}
)
if GEONODE_GEODATABASE_SCHEMA != "public":
_geo_db["OPTIONS"]["options"] = f"-c search_path={GEONODE_GEODATABASE_SCHEMA},public"

DATABASES[os.getenv("DEFAULT_BACKEND_DATASTORE")] = _geo_db

Expand Down
13 changes: 13 additions & 0 deletions geonode/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1810,3 +1810,16 @@ def get_allowed_extensions():
for val in _type["formats"]:
allowed_extention.append(val["required_ext"][0])
return list(set(allowed_extention))


def get_db_schema(db_settings):
    """Extract the configured schema name from a Django database settings dict.

    Looks for a PostgreSQL ``search_path`` entry inside
    ``db_settings["OPTIONS"]["options"]`` (e.g. ``"-c search_path=myschema,public"``)
    and returns the first schema listed, or ``"public"`` if none is configured.

    :param db_settings: a Django ``DATABASES``-style settings dict.
    :return: the first schema of the configured search path, or ``"public"``.
    """
    options_str = db_settings.get("OPTIONS", {}).get("options", "")
    for part in options_str.split():
        if "search_path=" in part:
            # Split on the first '=' only, so a value containing '=' is not
            # truncated; keep the first schema of the comma-separated list.
            schema = part.split("=", 1)[1].split(",")[0].strip()
            # Guard against a bare "search_path=" with no value.
            return schema or "public"
    return "public"