-
-
Notifications
You must be signed in to change notification settings - Fork 212
/
oci8.go
466 lines (420 loc) · 14.6 KB
/
oci8.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
package oci8
// #include "oci8.go.h"
import "C"
import (
"database/sql/driver"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"strconv"
"strings"
"time"
"unsafe"
)
// ParseDSN parses a DSN used to connect to Oracle
//
// It expects to receive a string in the form:
//
// [username/[password]@]host[:port][/service_name][?param1=value1&...&paramN=valueN]
//
// An optional "oracle://" prefix is stripped before parsing.
//
// Connection timeout can be set in the Oracle files: sqlnet.ora as SQLNET.OUTBOUND_CONNECT_TIMEOUT or tnsnames.ora as CONNECT_TIMEOUT
//
// Supported parameters are:
//
// loc - the time location for reading timestamp (without time zone). Defaults to UTC
// Note that writing a timestamp (without time zone) just truncates the time zone.
//
// isolation - the isolation level that can be set to: READONLY, SERIALIZABLE, or DEFAULT
//
// prefetch_rows - the number of top level rows to be prefetched. Defaults to 0. A 0 means unlimited rows.
//
// prefetch_memory - the max memory for top level rows to be prefetched. Defaults to 4096. A 0 means unlimited memory.
//
// questionph - when true, enables question mark placeholders. Defaults to false. (uses strconv.ParseBool to check for true)
//
// as - the privileged operation mode: SYSDBA, SYSASM, or SYSOPER (all upper or all lower case). Defaults to OCI_DEFAULT.
//
// stmt_cache_size - the statement cache size as an unsigned 32 bit integer. Defaults to 0.
//
// Unknown parameters are ignored. Only the first value of a repeated
// parameter is used. An error is returned for an empty DSN, an authority or
// host that cannot be decoded, a query string that cannot be parsed, or an
// invalid parameter value.
func ParseDSN(dsnString string) (dsn *DSN, err error) {
	if dsnString == "" {
		return nil, errors.New("empty dsn")
	}

	const prefix = "oracle://"
	dsnString = strings.TrimPrefix(dsnString, prefix)

	dsn = &DSN{
		prefetchRows:   0,
		prefetchMemory: 4096,
		stmtCacheSize:  0,
		operationMode:  C.OCI_DEFAULT,
		timeLocation:   time.UTC,
	}

	// everything before the "@" separator is the credentials part
	authority, dsnString := splitRight(dsnString, "@")
	if authority != "" {
		dsn.Username, dsn.Password, err = parseAuthority(authority)
		if err != nil {
			return nil, err
		}
	}

	// everything after the "?" separator is the parameter list
	host, params := splitRight(dsnString, "?")
	if host, err = unescape(host, encodeHost); err != nil {
		return nil, err
	}
	dsn.Connect = host

	// A malformed query string used to be ignored silently, which hid typos
	// in the DSN; report it to the caller instead.
	qp, err := ParseQuery(params)
	if err != nil {
		return nil, fmt.Errorf("invalid dsn parameters: %v", err)
	}

	for k, v := range qp {
		// every case below reads v[0], so skip keys that carry no value
		if len(v) == 0 {
			continue
		}

		switch k {
		case "loc":
			if dsn.timeLocation, err = time.LoadLocation(v[0]); err != nil {
				return nil, fmt.Errorf("Invalid loc: %v: %v", v[0], err)
			}

		case "isolation":
			switch v[0] {
			case "READONLY":
				dsn.transactionMode = C.OCI_TRANS_READONLY
			case "SERIALIZABLE":
				dsn.transactionMode = C.OCI_TRANS_SERIALIZABLE
			case "DEFAULT":
				dsn.transactionMode = C.OCI_TRANS_READWRITE
			default:
				return nil, fmt.Errorf("Invalid isolation: %v", v[0])
			}

		case "questionph":
			dsn.enableQMPlaceholders, err = strconv.ParseBool(v[0])
			if err != nil {
				return nil, fmt.Errorf("Invalid questionph: %v", v[0])
			}

		case "prefetch_rows":
			z, err := strconv.ParseUint(v[0], 10, 32)
			if err != nil {
				return nil, fmt.Errorf("invalid prefetch_rows: %v", v[0])
			}
			dsn.prefetchRows = C.ub4(z)

		case "prefetch_memory":
			z, err := strconv.ParseUint(v[0], 10, 32)
			if err != nil {
				return nil, fmt.Errorf("invalid prefetch_memory: %v", v[0])
			}
			dsn.prefetchMemory = C.ub4(z)

		case "as":
			switch v[0] {
			case "SYSDBA", "sysdba":
				dsn.operationMode = C.OCI_SYSDBA
			case "SYSASM", "sysasm":
				dsn.operationMode = C.OCI_SYSASM
			case "SYSOPER", "sysoper":
				dsn.operationMode = C.OCI_SYSOPER
			default:
				return nil, fmt.Errorf("Invalid as: %v", v[0])
			}

		case "stmt_cache_size":
			z, err := strconv.ParseUint(v[0], 10, 32)
			if err != nil {
				return nil, fmt.Errorf("invalid stmt_cache_size: %v", v[0])
			}
			dsn.stmtCacheSize = C.ub4(z)
		}
	}

	return dsn, nil
}
// Commit commits the transaction through OCITransCommit.
//
// The connection's inTransaction flag is cleared before the OCI call, so it
// is reset whether or not the commit succeeds. Any result other than
// OCI_SUCCESS is converted to an error by the connection's getError.
func (tx *Tx) Commit() error {
	tx.conn.inTransaction = false

	rv := C.OCITransCommit(tx.conn.svc, tx.conn.errHandle, 0)
	if rv == C.OCI_SUCCESS {
		return nil
	}
	return tx.conn.getError(rv)
}
// Rollback rolls the transaction back through OCITransRollback.
//
// The connection's inTransaction flag is cleared before the OCI call, so it
// is reset whether or not the rollback succeeds. Any result other than
// OCI_SUCCESS is converted to an error by the connection's getError.
func (tx *Tx) Rollback() error {
	tx.conn.inTransaction = false

	rv := C.OCITransRollback(tx.conn.svc, tx.conn.errHandle, 0)
	if rv == C.OCI_SUCCESS {
		return nil
	}
	return tx.conn.getError(rv)
}
// Open opens a new database connection
//
// The DSN is parsed with ParseDSN, then the connection is built in stages:
//
//  1. an OCI environment is created with OCIEnvNlsCreate; defaultCharset is
//     used only when neither NLS_LANG nor NLS_NCHAR is set in the process
//     environment, otherwise 0 is passed so OCI uses those settings
//  2. an error handle is allocated
//  3. when useOCISessionBegin is set, the server, service context and user
//     session handles are allocated explicitly, the server is attached and
//     the session is started with OCISessionBegin; otherwise a single
//     OCILogon call is used
//  4. a transaction handle is allocated and attached to the service context
//  5. the remaining DSN settings are copied onto the connection
//
// If any stage after the environment creation fails, the deferred function
// undoes the completed stages and frees every allocated handle before the
// error is returned.
func (drv *DriverStruct) Open(dsnString string) (driver.Conn, error) {
	var err error
	var dsn *DSN
	if dsn, err = ParseDSN(dsnString); err != nil {
		return nil, err
	}

	conn := Conn{
		operationMode: dsn.operationMode,
		stmtCacheSize: dsn.stmtCacheSize,
		logger:        drv.Logger,
	}
	if conn.logger == nil {
		conn.logger = log.New(ioutil.Discard, "", 0)
	}

	// environment handle
	var envP *C.OCIEnv
	envPP := &envP
	var result C.sword
	charset := C.ub2(0)
	if os.Getenv("NLS_LANG") == "" && os.Getenv("NLS_NCHAR") == "" {
		charset = defaultCharset
	}
	result = C.OCIEnvNlsCreate(
		envPP,          // pointer to a handle to the environment
		C.OCI_THREADED, // environment mode: https://docs.oracle.com/cd/B28359_01/appdev.111/b28395/oci16rel001.htm#LNOCI87683
		nil,            // Specifies the user-defined context for the memory callback routines.
		nil,            // Specifies the user-defined memory allocation function. If mode is OCI_THREADED, this memory allocation routine must be thread-safe.
		nil,            // Specifies the user-defined memory re-allocation function. If the mode is OCI_THREADED, this memory allocation routine must be thread safe.
		nil,            // Specifies the user-defined memory free function. If mode is OCI_THREADED, this memory free routine must be thread-safe.
		0,              // Specifies the amount of user memory to be allocated for the duration of the environment.
		nil,            // Returns a pointer to the user memory of size xtramemsz allocated by the call for the user.
		charset,        // The client-side character set for the current environment handle. If it is 0, the NLS_LANG setting is used.
		charset,        // The client-side national character set for the current environment handle. If it is 0, NLS_NCHAR setting is used.
	)
	if result != C.OCI_SUCCESS {
		return nil, errors.New("OCIEnvNlsCreate error")
	}
	conn.env = *envPP

	// defer on error handle free
	//
	// The closure watches the local err through a pointer, so every failure
	// path below must leave err non-nil before returning.
	var doneSessionBegin bool
	var doneServerAttach bool
	var doneLogon bool
	defer func(errP *error) {
		if *errP != nil {
			if doneSessionBegin {
				C.OCISessionEnd(
					conn.svc,
					conn.errHandle,
					conn.usrSession,
					C.OCI_DEFAULT,
				)
			}
			if doneLogon {
				C.OCILogoff(
					conn.svc,
					conn.errHandle,
				)
				// OCILogoff implicitly releases the service context that
				// OCILogon allocated, so it must not be freed again below.
				conn.svc = nil
			}
			if doneServerAttach {
				C.OCIServerDetach(
					conn.srv,
					conn.errHandle,
					C.OCI_DEFAULT,
				)
			}
			if conn.txHandle != nil {
				C.OCIHandleFree(unsafe.Pointer(conn.txHandle), C.OCI_HTYPE_TRANS)
				conn.txHandle = nil
			}
			if conn.usrSession != nil {
				C.OCIHandleFree(unsafe.Pointer(conn.usrSession), C.OCI_HTYPE_SESSION)
				conn.usrSession = nil
			}
			if conn.svc != nil {
				C.OCIHandleFree(unsafe.Pointer(conn.svc), C.OCI_HTYPE_SVCCTX)
				conn.svc = nil
			}
			if conn.srv != nil {
				C.OCIHandleFree(unsafe.Pointer(conn.srv), C.OCI_HTYPE_SERVER)
				conn.srv = nil
			}
			if conn.errHandle != nil {
				C.OCIHandleFree(unsafe.Pointer(conn.errHandle), C.OCI_HTYPE_ERROR)
				conn.errHandle = nil
			}
			C.OCIHandleFree(unsafe.Pointer(conn.env), C.OCI_HTYPE_ENV)
		}
	}(&err)

	// error handle
	var handleTemp unsafe.Pointer
	handle := &handleTemp
	result = C.OCIHandleAlloc(
		unsafe.Pointer(conn.env), // An environment handle
		handle,                   // Returns a handle
		C.OCI_HTYPE_ERROR,        // type of handle: https://docs.oracle.com/cd/B28359_01/appdev.111/b28395/oci02bas.htm#LNOCI87581
		0,                        // amount of user memory to be allocated
		nil,                      // Returns a pointer to the user memory
	)
	if result != C.OCI_SUCCESS {
		// TODO: error handle not yet allocated, how to get string error from oracle?
		err = errors.New("allocate error handle error")
		return nil, err
	}
	conn.errHandle = (*C.OCIError)(*handle)

	// C copies of the DSN strings, released when Open returns
	connectString := cString(dsn.Connect)
	defer C.free(unsafe.Pointer(connectString))
	username := cString(dsn.Username)
	defer C.free(unsafe.Pointer(username))
	password := cString(dsn.Password)
	defer C.free(unsafe.Pointer(password))

	if useOCISessionBegin {
		// server handle
		handle, _, err = conn.ociHandleAlloc(C.OCI_HTYPE_SERVER, 0)
		if err != nil {
			return nil, fmt.Errorf("allocate server handle error: %v", err)
		}
		conn.srv = (*C.OCIServer)(*handle)

		if len(dsn.Connect) < 1 {
			result = C.OCIServerAttach(
				conn.srv,       // uninitialized server handle, which gets initialized by this call. Passing in an initialized server handle causes an error.
				conn.errHandle, // error handle
				nil,            // connect string or a service point
				0,              // length of the database server
				C.OCI_DEFAULT,  // mode of operation: OCI_DEFAULT or OCI_CPOOL
			)
		} else {
			result = C.OCIServerAttach(
				conn.srv,                // uninitialized server handle, which gets initialized by this call. Passing in an initialized server handle causes an error.
				conn.errHandle,          // error handle
				connectString,           // connect string or a service point
				C.sb4(len(dsn.Connect)), // length of the database server
				C.OCI_DEFAULT,           // mode of operation: OCI_DEFAULT or OCI_CPOOL
			)
		}
		if result != C.OCI_SUCCESS {
			// fetch the error once and return that same value
			err = conn.getError(result)
			return nil, err
		}
		doneServerAttach = true

		// service handle
		handle, _, err = conn.ociHandleAlloc(C.OCI_HTYPE_SVCCTX, 0)
		if err != nil {
			return nil, fmt.Errorf("allocate service handle error: %v", err)
		}
		conn.svc = (*C.OCISvcCtx)(*handle)

		// sets the server context attribute of the service context
		err = conn.ociAttrSet(unsafe.Pointer(conn.svc), C.OCI_HTYPE_SVCCTX, unsafe.Pointer(conn.srv), 0, C.OCI_ATTR_SERVER)
		if err != nil {
			return nil, fmt.Errorf("server context attribute set error: %v", err)
		}

		// user session handle
		handle, _, err = conn.ociHandleAlloc(C.OCI_HTYPE_SESSION, 0)
		if err != nil {
			return nil, fmt.Errorf("allocate user session handle error: %v", err)
		}
		conn.usrSession = (*C.OCISession)(*handle)

		// external credentials are used unless a username is supplied
		credentialType := C.ub4(C.OCI_CRED_EXT)
		if len(dsn.Username) > 0 {
			// specifies a username to use for authentication
			err = conn.ociAttrSet(unsafe.Pointer(conn.usrSession), C.OCI_HTYPE_SESSION, unsafe.Pointer(username), C.ub4(len(dsn.Username)), C.OCI_ATTR_USERNAME)
			if err != nil {
				return nil, fmt.Errorf("username attribute set error: %v", err)
			}

			// specifies a password to use for authentication
			err = conn.ociAttrSet(unsafe.Pointer(conn.usrSession), C.OCI_HTYPE_SESSION, unsafe.Pointer(password), C.ub4(len(dsn.Password)), C.OCI_ATTR_PASSWORD)
			if err != nil {
				return nil, fmt.Errorf("password attribute set error: %v", err)
			}

			credentialType = C.OCI_CRED_RDBMS
		}

		result = C.OCISessionBegin(
			conn.svc,           // service context
			conn.errHandle,     // error handle
			conn.usrSession,    // user session context
			credentialType,     // type of credentials to use for establishing the user session: OCI_CRED_RDBMS or OCI_CRED_EXT
			conn.operationMode, // mode of operation. https://docs.oracle.com/cd/B28359_01/appdev.111/b28395/oci16rel001.htm#LNOCI87690
		)
		if result != C.OCI_SUCCESS && result != C.OCI_SUCCESS_WITH_INFO {
			err = conn.getError(result)
			return nil, err
		}
		doneSessionBegin = true

		// sets the authentication context attribute of the service context
		err = conn.ociAttrSet(unsafe.Pointer(conn.svc), C.OCI_HTYPE_SVCCTX, unsafe.Pointer(conn.usrSession), 0, C.OCI_ATTR_SESSION)
		if err != nil {
			return nil, fmt.Errorf("authentication context attribute set error: %v", err)
		}

		if dsn.stmtCacheSize > 0 {
			stmtCacheSize := dsn.stmtCacheSize
			err = conn.ociAttrSet(unsafe.Pointer(conn.svc), C.OCI_HTYPE_SVCCTX, unsafe.Pointer(&stmtCacheSize), 0, C.OCI_ATTR_STMTCACHESIZE)
			if err != nil {
				return nil, fmt.Errorf("stmt cache size attribute set error: %v", err)
			}
		}

	} else {
		var svcCtxP *C.OCISvcCtx
		svcCtxPP := &svcCtxP
		result = C.OCILogon(
			conn.env,                 // environment handle
			conn.errHandle,           // error handle
			svcCtxPP,                 // service context pointer
			username,                 // user name. Must be in the encoding specified by the charset parameter of a previous call to OCIEnvNlsCreate().
			C.ub4(len(dsn.Username)), // length of user name, in number of bytes, regardless of the encoding
			password,                 // user's password. Must be in the encoding specified by the charset parameter of a previous call to OCIEnvNlsCreate().
			C.ub4(len(dsn.Password)), // length of password, in number of bytes, regardless of the encoding.
			connectString,            // name of the database to connect to. Must be in the encoding specified by the charset parameter of a previous call to OCIEnvNlsCreate().
			C.ub4(len(dsn.Connect)),  // length of dbname, in number of bytes, regardless of the encoding.
		)
		if result != C.OCI_SUCCESS && result != C.OCI_SUCCESS_WITH_INFO {
			err = conn.getError(result)
			return nil, err
		}
		conn.svc = *svcCtxPP
		doneLogon = true
	}

	// Create transaction context.
	handle, _, err = conn.ociHandleAlloc(C.OCI_HTYPE_TRANS, 0)
	if err != nil {
		return nil, fmt.Errorf("allocate transaction handle error: %v", err)
	}
	conn.txHandle = (*C.OCITrans)(*handle)

	// Set transaction context attribute of the service context.
	err = conn.ociAttrSet(unsafe.Pointer(conn.svc), C.OCI_HTYPE_SVCCTX, *handle, 0, C.OCI_ATTR_TRANS)
	if err != nil {
		return nil, fmt.Errorf("service context attribute set error: %v", err)
	}

	conn.transactionMode = dsn.transactionMode
	conn.prefetchRows = dsn.prefetchRows
	conn.prefetchMemory = dsn.prefetchMemory
	conn.timeLocation = dsn.timeLocation
	conn.enableQMPlaceholders = dsn.enableQMPlaceholders

	return &conn, nil
}
// GetLastInsertId returns rowid from LastInsertId
func GetLastInsertId(id int64) string {
return *(*string)(unsafe.Pointer(uintptr(id)))
}
// LastInsertId returns last inserted ID
//
// The returned integer is not a database value: it is the memory address of
// the result's rowid field converted to int64. Pass it to GetLastInsertId to
// read the rowid back. The accompanying error is the result's rowidErr.
func (result *Result) LastInsertId() (int64, error) {
	rowidAddr := uintptr(unsafe.Pointer(&result.rowid))
	return int64(rowidAddr), result.rowidErr
}
// RowsAffected returns rows affected
//
// It reports the rowsAffected count stored on the result together with the
// rowsAffectedErr stored alongside it.
func (result *Result) RowsAffected() (int64, error) {
	count, err := result.rowsAffected, result.rowsAffectedErr
	return count, err
}
// placeholders rewrites every match of the package-level phre pattern in sql
// as a numbered bind variable: the first match becomes ":1", the second ":2",
// and so on in order of appearance. Per the original note, the matches are
// the "?" characters.
func placeholders(sql string) string {
	var index int
	nextPlaceholder := func(string) string {
		index++
		return ":" + strconv.Itoa(index)
	}
	return phre.ReplaceAllStringFunc(sql, nextPlaceholder)
}
// timezoneToLocation returns the location for a time zone offset given as
// separate hour and minute parts.
//
// Whole-hour offsets from -12 to +14 are served from the timeLocations cache,
// indexed by 12+hour. Every other offset gets a new time.FixedZone whose
// offset is 3600*hour + 60*minute seconds and whose name has the form
// "+H:MM" or "-H:MM".
//
// The minute part may be negative (for example hour -3, minute -30 for an
// offset of -03:30). The name is built from its absolute value so it reads
// "-3:30", and a negative minute with a zero hour yields "-0:MM".
func timezoneToLocation(hour int64, minute int64) *time.Location {
	if minute == 0 && hour >= -12 && hour <= 14 {
		// use location from timeLocations cache
		return timeLocations[12+hour]
	}

	// create location with FixedZone
	var name string
	switch {
	case hour < 0:
		// FormatInt already emits the minus sign
		name = strconv.FormatInt(hour, 10) + ":"
	case hour == 0 && minute < 0:
		// the sign cannot come from a zero hour, so write it explicitly
		name = "-0:"
	default:
		name = "+" + strconv.FormatInt(hour, 10) + ":"
	}

	absMinute := minute
	if absMinute < 0 {
		absMinute = -absMinute
	}
	if absMinute < 10 {
		name += "0"
	}
	name += strconv.FormatInt(absMinute, 10)

	return time.FixedZone(name, (3600*int(hour))+(60*int(minute)))
}