From 8713c6773120002c0242aced3b237615a3dec863 Mon Sep 17 00:00:00 2001
From: Nicolas Gasull <nicolas@gasull.net>
Date: Tue, 25 Apr 2023 10:55:02 +0200
Subject: [PATCH] Safely return null in functions where applicable. Add support
 for domains and take their nullability into account. Centralize logic in
 `pgTypeToTsType`.

---
 src/lib/sql/types.sql                      |   2 +
 src/lib/types.ts                           |   2 +
 src/server/routes/generators/typescript.ts |   3 +-
 src/server/server.ts                       |   3 +-
 src/server/templates/typescript.ts         | 197 ++++++++++++---------
 test/db/00-init.sql                        |  66 +++++++
 test/lib/types.ts                          |   4 +
 test/server/typegen.ts                     |  88 ++++++++-
 8 files changed, 266 insertions(+), 99 deletions(-)

diff --git a/src/lib/sql/types.sql b/src/lib/sql/types.sql
index d0974012..e1af9dce 100644
--- a/src/lib/sql/types.sql
+++ b/src/lib/sql/types.sql
@@ -3,6 +3,8 @@ select
   t.typname as name,
   n.nspname as schema,
   format_type (t.oid, null) as format,
+  nullif(t.typbasetype, 0) as base_type_id,
+  not (t.typnotnull) as is_nullable,
   coalesce(t_enums.enums, '[]') as enums,
   coalesce(t_attributes.attributes, '[]') as attributes,
   obj_description (t.oid, 'pg_type') as comment
diff --git a/src/lib/types.ts b/src/lib/types.ts
index 9849064d..80f39265 100644
--- a/src/lib/types.ts
+++ b/src/lib/types.ts
@@ -357,6 +357,8 @@ export const postgresTypeSchema = Type.Object({
   name: Type.String(),
   schema: Type.String(),
   format: Type.String(),
+  base_type_id: Type.Optional(Type.Integer()),
+  is_nullable: Type.Boolean(),
   enums: Type.Array(Type.String()),
   attributes: Type.Array(
     Type.Object({
diff --git a/src/server/routes/generators/typescript.ts b/src/server/routes/generators/typescript.ts
index 110b4bf4..b857360e 100644
--- a/src/server/routes/generators/typescript.ts
+++ b/src/server/routes/generators/typescript.ts
@@ -77,8 +77,7 @@ export default async (fastify: FastifyInstance) => {
       functions: functions.filter(
         ({ return_type }) => !['trigger', 'event_trigger'].includes(return_type)
       ),
-      types: types.filter(({ name }) => name[0] !== '_'),
-      arrayTypes: types.filter(({ name }) => name[0] === '_'),
+      types,
     })
   })
 }
diff --git a/src/server/server.ts b/src/server/server.ts
index e0ef5853..04af3d2b 100644
--- a/src/server/server.ts
+++ b/src/server/server.ts
@@ -80,8 +80,7 @@ if (EXPORT_DOCS) {
       functions: functions.filter(
         ({ return_type }) => !['trigger', 'event_trigger'].includes(return_type)
       ),
-      types: types.filter(({ name }) => name[0] !== '_'),
-      arrayTypes: types.filter(({ name }) => name[0] === '_'),
+      types,
     })
   )
 } else {
diff --git a/src/server/templates/typescript.ts b/src/server/templates/typescript.ts
index f18934c4..6121a566 100644
--- a/src/server/templates/typescript.ts
+++ b/src/server/templates/typescript.ts
@@ -15,7 +15,6 @@ export const apply = ({
   materializedViews,
   functions,
   types,
-  arrayTypes,
 }: {
   schemas: PostgresSchema[]
   tables: (PostgresTable & { columns: unknown[] })[]
@@ -23,7 +22,6 @@ export const apply = ({
   materializedViews: (PostgresMaterializedView & { columns: unknown[] })[]
   functions: PostgresFunction[]
   types: PostgresType[]
-  arrayTypes: PostgresType[]
 }): string => {
   let output = `
 export type Json = string | number | boolean | null | { [key: string]: Json } | Json[]
@@ -63,6 +61,15 @@ export interface Database {
       const schemaEnums = types
         .filter((type) => type.schema === schema.name && type.enums.length > 0)
         .sort(({ name: a }, { name: b }) => a.localeCompare(b))
+      const schemaDomainTypes = types
+        .flatMap((type) => {
+          const baseType =
+            type.schema === schema.name &&
+            type.base_type_id &&
+            types.find(({ id }) => id === type.base_type_id)
+          return baseType ? [{ type, baseType }] : []
+        })
+        .sort(({ type: { name: a } }, { type: { name: b } }) => a.localeCompare(b))
       const schemaCompositeTypes = types
         .filter((type) => type.schema === schema.name && type.attributes.length > 0)
         .sort(({ name: a }, { name: b }) => a.localeCompare(b))
@@ -82,8 +89,9 @@ export interface Database {
                             `${JSON.stringify(column.name)}: ${pgTypeToTsType(
                               column.format,
                               types,
-                              schemas
-                            )} ${column.is_nullable ? '| null' : ''}`
+                              schemas,
+                              { nullable: column.is_nullable }
+                            )}`
                         ),
                       ...schemaFunctions
                         .filter((fn) => fn.argument_types === table.name)
@@ -93,7 +101,7 @@ export interface Database {
                               fn.return_type,
                               types,
                               schemas
-                            )} | null`
+                            )}`
                         ),
                     ]}
                   }
@@ -117,11 +125,9 @@ export interface Database {
                           output += ':'
                         }
 
-                        output += pgTypeToTsType(column.format, types, schemas)
-
-                        if (column.is_nullable) {
-                          output += '| null'
-                        }
+                        output += pgTypeToTsType(column.format, types, schemas, {
+                          nullable: column.is_nullable,
+                        })
 
                         return output
                       })}
@@ -136,11 +142,9 @@ export interface Database {
                           return `${output}?: never`
                         }
 
-                        output += `?: ${pgTypeToTsType(column.format, types, schemas)}`
-
-                        if (column.is_nullable) {
-                          output += '| null'
-                        }
+                        output += `?: ${pgTypeToTsType(column.format, types, schemas, {
+                          nullable: column.is_nullable,
+                        })}`
 
                         return output
                       })}
@@ -163,8 +167,9 @@ export interface Database {
                           `${JSON.stringify(column.name)}: ${pgTypeToTsType(
                             column.format,
                             types,
-                            schemas
-                          )} ${column.is_nullable ? '| null' : ''}`
+                            schemas,
+                            { nullable: column.is_nullable }
+                          )}`
                       )}
                   }
                   ${
@@ -179,7 +184,9 @@ export interface Database {
                           return `${output}?: never`
                         }
 
-                        output += `?: ${pgTypeToTsType(column.format, types, schemas)} | null`
+                        output += `?: ${pgTypeToTsType(column.format, types, schemas, {
+                          nullable: true,
+                        })}`
 
                         return output
                       })}
@@ -198,7 +205,9 @@ export interface Database {
                           return `${output}?: never`
                         }
 
-                        output += `?: ${pgTypeToTsType(column.format, types, schemas)} | null`
+                        output += `?: ${pgTypeToTsType(column.format, types, schemas, {
+                          nullable: true,
+                        })}`
 
                         return output
                       })}
@@ -239,17 +248,7 @@ export interface Database {
                     }
 
                     const argsNameAndType = inArgs.map(({ name, type_id, has_default }) => {
-                      let type = arrayTypes.find(({ id }) => id === type_id)
-                      if (type) {
-                        // If it's an array type, the name looks like `_int8`.
-                        const elementTypeName = type.name.substring(1)
-                        return {
-                          name,
-                          type: `(${pgTypeToTsType(elementTypeName, types, schemas)})[]`,
-                          has_default,
-                        }
-                      }
-                      type = types.find(({ id }) => id === type_id)
+                      const type = types.find(({ id }) => id === type_id)
                       if (type) {
                         return {
                           name,
@@ -272,19 +271,13 @@ export interface Database {
                     const tableArgs = args.filter(({ mode }) => mode === 'table')
                     if (tableArgs.length > 0) {
                       const argsNameAndType = tableArgs.map(({ name, type_id }) => {
-                        let type = arrayTypes.find(({ id }) => id === type_id)
+                        const type = types.find(({ id }) => id === type_id)
                         if (type) {
-                          // If it's an array type, the name looks like `_int8`.
-                          const elementTypeName = type.name.substring(1)
                           return {
                             name,
-                            type: `(${pgTypeToTsType(elementTypeName, types, schemas)})[]`,
+                            type: pgTypeToTsType(type.name, types, schemas),
                           }
                         }
-                        type = types.find(({ id }) => id === type_id)
-                        if (type) {
-                          return { name, type: pgTypeToTsType(type.name, types, schemas) }
-                        }
                         return { name, type: 'unknown' }
                       })
 
@@ -308,8 +301,9 @@ export interface Database {
                               `${JSON.stringify(column.name)}: ${pgTypeToTsType(
                                 column.format,
                                 types,
-                                schemas
-                              )} ${column.is_nullable ? '| null' : ''}`
+                                schemas,
+                                { nullable: column.is_nullable }
+                              )}`
                           )}
                       }`
                     }
@@ -340,6 +334,21 @@ export interface Database {
                   )
             }
           }
+          DomainTypes: {
+            ${
+              schemaDomainTypes.length === 0
+                ? '[_ in never]: never'
+                : schemaDomainTypes.map(
+                    ({ type: domain_, baseType }) =>
+                      `${JSON.stringify(domain_.name)}: ${pgTypeToTsType(
+                        baseType.name,
+                        types,
+                        schemas,
+                        { nullable: domain_.is_nullable }
+                      )}`
+                  )
+            }
+          }
           CompositeTypes: {
             ${
               schemaCompositeTypes.length === 0
@@ -377,58 +386,72 @@ export interface Database {
 const pgTypeToTsType = (
   pgType: string,
   types: PostgresType[],
-  schemas: PostgresSchema[]
+  schemas: PostgresSchema[],
+  opts: { nullable?: boolean } = {}
 ): string => {
-  if (pgType === 'bool') {
-    return 'boolean'
-  } else if (['int2', 'int4', 'int8', 'float4', 'float8', 'numeric'].includes(pgType)) {
-    return 'number'
-  } else if (
-    [
-      'bytea',
-      'bpchar',
-      'varchar',
-      'date',
-      'text',
-      'citext',
-      'time',
-      'timetz',
-      'timestamp',
-      'timestamptz',
-      'uuid',
-      'vector',
-    ].includes(pgType)
-  ) {
-    return 'string'
-  } else if (['json', 'jsonb'].includes(pgType)) {
-    return 'Json'
-  } else if (pgType === 'void') {
-    return 'undefined'
-  } else if (pgType === 'record') {
-    return 'Record<string, unknown>'
-  } else if (pgType.startsWith('_')) {
-    return `(${pgTypeToTsType(pgType.substring(1), types, schemas)})[]`
-  } else {
-    const enumType = types.find((type) => type.name === pgType && type.enums.length > 0)
-    if (enumType) {
-      if (schemas.some(({ name }) => name === enumType.schema)) {
-        return `Database[${JSON.stringify(enumType.schema)}]['Enums'][${JSON.stringify(
-          enumType.name
-        )}]`
+  const type = types.find((type) => type.name === pgType)
+  const strictTsType = pgTypeToStrictTsType()
+  return strictTsType
+    ? `${strictTsType}${opts.nullable ?? type?.is_nullable ? ' | null' : ''}`
+    : 'unknown'
+
+  function pgTypeToStrictTsType() {
+    if (pgType === 'bool') {
+      return 'boolean'
+    } else if (['int2', 'int4', 'int8', 'float4', 'float8', 'numeric'].includes(pgType)) {
+      return 'number'
+    } else if (
+      [
+        'bytea',
+        'bpchar',
+        'varchar',
+        'date',
+        'text',
+        'citext',
+        'time',
+        'timetz',
+        'timestamp',
+        'timestamptz',
+        'uuid',
+        'vector',
+      ].includes(pgType)
+    ) {
+      return 'string'
+    } else if (['json', 'jsonb'].includes(pgType)) {
+      return 'Json'
+    } else if (pgType === 'void') {
+      return 'undefined'
+    } else if (pgType === 'record') {
+      return 'Record<string, unknown>'
+    } else if (pgType.startsWith('_')) {
+      return `(${pgTypeToTsType(pgType.substring(1), types, schemas)})[]`
+    } else if (type != null) {
+      if (type.base_type_id != null) {
+        if (schemas.some(({ name }) => name === type.schema)) {
+          return `Database[${JSON.stringify(type.schema)}]['DomainTypes'][${JSON.stringify(
+            type.name
+          )}]`
+        }
+        return undefined
+      }
+
+      if (type.enums.length > 0) {
+        if (schemas.some(({ name }) => name === type.schema)) {
+          return `Database[${JSON.stringify(type.schema)}]['Enums'][${JSON.stringify(type.name)}]`
+        }
+        return type.enums.map((variant) => JSON.stringify(variant)).join('|')
       }
-      return enumType.enums.map((variant) => JSON.stringify(variant)).join('|')
-    }
 
-    const compositeType = types.find((type) => type.name === pgType && type.attributes.length > 0)
-    if (compositeType) {
-      if (schemas.some(({ name }) => name === compositeType.schema)) {
-        return `Database[${JSON.stringify(
-          compositeType.schema
-        )}]['CompositeTypes'][${JSON.stringify(compositeType.name)}]`
+      if (type.attributes.length > 0) {
+        if (schemas.some(({ name }) => name === type.schema)) {
+          return `Database[${JSON.stringify(type.schema)}]['CompositeTypes'][${JSON.stringify(
+            type.name
+          )}]`
+        }
+        return undefined
       }
-      return 'unknown'
     }
 
-    return 'unknown'
+    return undefined
   }
 }
diff --git a/test/db/00-init.sql b/test/db/00-init.sql
index 32a3a256..63ce604f 100644
--- a/test/db/00-init.sql
+++ b/test/db/00-init.sql
@@ -91,3 +91,69 @@ stable
 as $$
   select id, name from public.users;
 $$;
+
+
+create domain text_not_null as text not null;
+create domain int_not_null as int not null;
+
+create type composite_with_strict as (
+  a text_not_null,
+  b int_not_null
+);
+create domain strict_composite_with_strict as composite_with_strict not null;
+
+create or replace function public.function_returning_table_of_strict(id int_not_null, name text_not_null)
+returns table (id int_not_null, name text_not_null)
+language sql
+immutable
+as $$
+  select id, name;
+$$;
+
+create or replace function public.function_with_array_of_strict(id int_not_null[], name text_not_null[])
+returns table (id int_not_null[], name text_not_null[])
+language sql
+immutable
+as $$
+  select id, name;
+$$;
+
+create or replace function public.function_with_composite_with_strict(obj composite_with_strict)
+returns composite_with_strict
+language sql
+immutable
+as $$
+  select obj;
+$$;
+
+create or replace function public.function_with_strict_composite_with_strict(obj strict_composite_with_strict)
+returns strict_composite_with_strict
+language sql
+immutable
+as $$
+  select obj;
+$$;
+
+create domain text_array as text[];
+create domain text_array_strict as text_not_null[] not null;
+
+create or replace function public.function_with_domain_array(arr text_array)
+returns text_array
+language sql
+immutable
+as $$
+  select arr;
+$$;
+
+create or replace function public.function_with_domain_array_strict(arr text_array_strict)
+returns text_array_strict
+language sql
+immutable
+as $$
+  select arr;
+$$;
+
+create table public.table_with_domain (
+  name text_not_null,
+  status_code int_not_null
+);
diff --git a/test/lib/types.ts b/test/lib/types.ts
index de547f5b..36f1d1ba 100644
--- a/test/lib/types.ts
+++ b/test/lib/types.ts
@@ -7,6 +7,7 @@ test('list', async () => {
     `
     {
       "attributes": [],
+      "base_type_id": null,
       "comment": null,
       "enums": [
         "ACTIVE",
@@ -14,6 +15,7 @@ test('list', async () => {
       ],
       "format": "user_status",
       "id": Any<Number>,
+      "is_nullable": true,
       "name": "user_status",
       "schema": "public",
     }
@@ -74,10 +76,12 @@ test('composite type attributes', async () => {
           "type_id": 25,
         },
       ],
+      "base_type_id": null,
       "comment": null,
       "enums": [],
       "format": "test_composite",
       "id": Any<Number>,
+      "is_nullable": true,
       "name": "test_composite",
       "schema": "public",
     }
diff --git a/test/server/typegen.ts b/test/server/typegen.ts
index 9e8988e0..10861152 100644
--- a/test/server/typegen.ts
+++ b/test/server/typegen.ts
@@ -57,6 +57,20 @@ test('typegen', async () => {
               status?: Database["public"]["Enums"]["meme_status"] | null
             }
           }
+          table_with_domain: {
+            Row: {
+              name: string
+              status_code: number
+            }
+            Insert: {
+              name: string
+              status_code: number
+            }
+            Update: {
+              name?: string
+              status_code?: number
+            }
+          }
           todos: {
             Row: {
               details: string | null
@@ -144,7 +158,7 @@ test('typegen', async () => {
             Args: {
               "": unknown
             }
-            Returns: string
+            Returns: string | null
           }
           function_returning_row: {
             Args: Record<PropertyKey, never>
@@ -165,23 +179,71 @@ test('typegen', async () => {
           function_returning_table: {
             Args: Record<PropertyKey, never>
             Returns: {
-              id: number
-              name: string
+              id: number | null
+              name: string | null
+            }[]
+          }
+          function_returning_table_of_strict: {
+            Args: {
+              id: Database["public"]["DomainTypes"]["int_not_null"]
+              name: Database["public"]["DomainTypes"]["text_not_null"]
+            }
+            Returns: {
+              id: Database["public"]["DomainTypes"]["int_not_null"]
+              name: Database["public"]["DomainTypes"]["text_not_null"]
             }[]
           }
+          function_with_array_of_strict: {
+            Args: {
+              id: Database["public"]["DomainTypes"]["int_not_null"][] | null
+              name: Database["public"]["DomainTypes"]["text_not_null"][] | null
+            }
+            Returns: {
+              id: Database["public"]["DomainTypes"]["int_not_null"][] | null
+              name: Database["public"]["DomainTypes"]["text_not_null"][] | null
+            }[]
+          }
+          function_with_composite_with_strict: {
+            Args: {
+              obj:
+                | Database["public"]["CompositeTypes"]["composite_with_strict"]
+                | null
+            }
+            Returns:
+              | Database["public"]["CompositeTypes"]["composite_with_strict"]
+              | null
+          }
+          function_with_domain_array: {
+            Args: {
+              arr: Database["public"]["DomainTypes"]["text_array"] | null
+            }
+            Returns: Database["public"]["DomainTypes"]["text_array"] | null
+          }
+          function_with_domain_array_strict: {
+            Args: {
+              arr: Database["public"]["DomainTypes"]["text_array_strict"]
+            }
+            Returns: Database["public"]["DomainTypes"]["text_array_strict"]
+          }
+          function_with_strict_composite_with_strict: {
+            Args: {
+              obj: Database["public"]["DomainTypes"]["strict_composite_with_strict"]
+            }
+            Returns: Database["public"]["DomainTypes"]["strict_composite_with_strict"]
+          }
           postgres_fdw_disconnect: {
             Args: {
-              "": string
+              "": string | null
             }
-            Returns: boolean
+            Returns: boolean | null
           }
           postgres_fdw_disconnect_all: {
             Args: Record<PropertyKey, never>
-            Returns: boolean
+            Returns: boolean | null
           }
           postgres_fdw_get_connections: {
             Args: Record<PropertyKey, never>
-            Returns: Record<string, unknown>[]
+            Returns: (Record<string, unknown> | null)[]
           }
           postgres_fdw_handler: {
             Args: Record<PropertyKey, never>
@@ -192,8 +254,18 @@ test('typegen', async () => {
           meme_status: "new" | "old" | "retired"
           user_status: "ACTIVE" | "INACTIVE"
         }
+        DomainTypes: {
+          int_not_null: number
+          strict_composite_with_strict: Database["public"]["CompositeTypes"]["composite_with_strict"]
+          text_array: (string | null)[] | null
+          text_array_strict: Database["public"]["DomainTypes"]["text_not_null"][]
+          text_not_null: string
+        }
         CompositeTypes: {
-          [_ in never]: never
+          composite_with_strict: {
+            a: Database["public"]["DomainTypes"]["text_not_null"]
+            b: Database["public"]["DomainTypes"]["int_not_null"]
+          }
         }
       }
     }