Skip to content

Commit

Permalink
⚡ (sql) avoid executing sql.schema on crud
Browse files Browse the repository at this point in the history
follow up to d791f87
  • Loading branch information
kkharji committed Aug 26, 2021
1 parent b688269 commit e977540
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 37 deletions.
53 changes: 28 additions & 25 deletions lua/sql.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,16 @@ DB.__index = DB
---@field where table: key and value
---@field values table: key and value to updated.

---return now date
---@todo: decide whether using os.time and epoch time would be better.
---@return string osdate
local created = function()
return os.date "%Y-%m-%d %H:%M:%S"
---Get a table schema, or execute a given function to get it
---@param schema table|nil
---@param self SQLDatabase
local get_schema = function(tbl_name, self)
local schema = self.tbl_schemas[tbl_name]
if schema then
return schema
end
self.tbl_schemas[tbl_name] = self:schema(tbl_name)
return self.tbl_schemas[tbl_name]
end

---Creates a new sql.nvim object, without creating a connection to uri.
Expand Down Expand Up @@ -66,12 +71,13 @@ function DB:open(uri, opts, noconn)
closed = noconn and true or false,
sqlite_opts = opts,
modified = false,
created = not noconn and created() or nil,
created = not noconn and os.date "%Y-%m-%d %H:%M:%S" or nil,
tbl_schemas = {},
}, self)
else
if self.closed or self.closed == nil then
self.conn = clib.connect(self.uri, self.sqlite_opts)
self.created = created()
self.created = os.date "%Y-%m-%d %H:%M:%S"
self.closed = false
end
return self
Expand Down Expand Up @@ -261,6 +267,7 @@ end
---@usage `db:drop("todos")` drop table.
---@return boolean
function DB:drop(tbl_name)
self.tbl_schemas[tbl_name] = nil
return self:eval(P.drop(tbl_name))
end

Expand Down Expand Up @@ -294,7 +301,8 @@ end
function DB:insert(tbl_name, rows, schema)
a.is_sqltbl(self, tbl_name, "insert")
local ret_vals = {}
local items = P.pre_insert(rows, schema and schema or self:schema(tbl_name))
schema = schema and schema or get_schema(tbl_name, self)
local items = P.pre_insert(rows, schema)
local last_rowid
clib.wrap_stmts(self.conn, function()
for _, v in ipairs(items) do
Expand Down Expand Up @@ -326,18 +334,11 @@ end
---@usage `db:update("todos", { where = { project = "sql.nvim" }, values = { status = "later" } )` update multiple rows
function DB:update(tbl_name, specs, schema)
a.is_sqltbl(self, tbl_name, "update")
return not specs and false or clib.wrap_stmts(self.conn, function()
specs = u.is_nested(specs) and specs or { specs }
schema = schema and schema or get_schema(tbl_name, self)

local ret_vals = {}
if not specs then
return false
end

specs = u.is_nested(specs) and specs or { specs }
schema = schema and schema or self:schema(tbl_name)

local ret_val = nil

clib.wrap_stmts(self.conn, function()
local ret_val = nil
for _, v in ipairs(specs) do
v.set = v.set and v.set or v.values
if self:select(tbl_name, { where = v.where })[1] then
Expand All @@ -348,15 +349,15 @@ function DB:update(tbl_name, specs, schema)
s:bind_clear()
s:finalize()
a.should_modify(self:status())
ret_val = true
else
local res = self:insert(tbl_name, u.tbl_extend("keep", v.set, v.where))
table.insert(ret_vals, res)
ret_val = self:insert(tbl_name, u.tbl_extend("keep", v.set, v.where))
a.should_modify(self:status())
end
end
self.modified = true
return ret_val
end)

self.modified = true
return true
end

---Delete a {tbl_name} row/rows based on the {specs} given. if no spec was given,
Expand Down Expand Up @@ -404,6 +405,8 @@ function DB:select(tbl_name, spec, schema)
a.is_sqltbl(self, tbl_name, "select")
return clib.wrap_stmts(self.conn, function()
local ret = {}
schema = schema and schema or get_schema(tbl_name, self)

spec = spec or {}
spec.select = spec.keys and spec.keys or spec.select

Expand All @@ -415,7 +418,7 @@ function DB:select(tbl_name, spec, schema)
if stmt.finalize(s) then
self.modified = false
end
return P.post_select(ret, schema and schema or self:schema(tbl_name))
return P.post_select(ret, schema)
end)
end

Expand Down
8 changes: 8 additions & 0 deletions lua/sql/assert.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
local M = {}
local u = require "sql.utils"
local clib = require "sql.defs"

--- Functions for asseting and erroring out :D
Expand All @@ -10,6 +11,7 @@ local errors = {
failed_ops = "operation failed, ERRMSG: %s",
missing_req_key = "(insert) missing a required key: %s",
missing_db_object = "'%s' db object is not set. please set it with `tbl.set_db(db)` and try again.",
outdated_schema = "`%s` does not exists in {`%s`}, schema is outdateset `self.db.tbl_schemas[table_name]` or reload",
}

for key, value in pairs(errors) do
Expand Down Expand Up @@ -60,6 +62,12 @@ M.missing_req_key = function(val, key)
return false
end

M.should_have_column_def = function(column_def, k, schema)
if not column_def then
error(errors.outdated_schema:format(k, u.join(u.keys(schema), ", ")))
end
end

M.should_have_db_object = function(db, name)
assert(db ~= nil, errors.missing_db_object:format(name))
return true
Expand Down
6 changes: 4 additions & 2 deletions lua/sql/parser.lua
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,10 @@ M.pre_insert = function(rows, schema)
rows = u.is_nested(rows) and rows or { rows }
for i, row in ipairs(rows) do
res[i] = u.map(row, function(v, k)
a.missing_req_key(v, schema[k].required)
local is_json = schema[k].type == "luatable" or schema[k].type == "json"
local column_def = schema[k]
a.should_have_column_def(column_def, k, schema)
a.missing_req_key(v, column_def)
local is_json = column_def.type == "luatable" or column_def.type == "json"
return is_json and json.encode(v) or M.sqlvalue(v)
end)
end
Expand Down
22 changes: 12 additions & 10 deletions test/auto/sql_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ local curl = require "plenary.curl"
local eq = assert.are.same
local sql = require "sql"
local u = require "sql.utils"
local luv = require "luv"

describe("sql", function()
local path = "/tmp/db.sqlite3"
vim.loop.fs_unlink(path)
luv.fs_unlink(path)

describe("sqlfunctions:", function()
local s = sql.lib
Expand Down Expand Up @@ -53,7 +54,7 @@ describe("sql", function()
eq(true, db:close(), "should close")
eq(true, db:isclose(), "should close")
eq(true, P.exists(P.new(tmp)), "It should created the file")
vim.loop.fs_unlink(tmp)
luv.fs_unlink(tmp)
end)
it("should accept pargma options", function()
local tmp = "/tmp/db5.db"
Expand All @@ -63,7 +64,7 @@ describe("sql", function()
db:open()
eq("persist", db:eval("pragma journal_mode")[1].journal_mode)
db:close()
vim.loop.fs_unlink(tmp)
luv.fs_unlink(tmp)
end)
end)

Expand Down Expand Up @@ -99,7 +100,7 @@ describe("sql", function()

eq(true, db:close(), "It should close connection successfully.")
eq(true, P.exists(P.new(path)), "File should still exists")
vim.loop.fs_unlink(path)
luv.fs_unlink(path)
end)

it("returns data and time of creation", function()
Expand All @@ -115,7 +116,7 @@ describe("sql", function()
})
eq("persist", db:eval("pragma journal_mode")[1].journal_mode)
db:close()
vim.loop.fs_unlink(tmp)
luv.fs_unlink(tmp)
end)

it("reopen db object.", function()
Expand All @@ -132,7 +133,7 @@ describe("sql", function()

eq(path, db.uri, "uri should be identical to db.uri")
local res = db:eval "select * from todo"
eq(row, res[1], vim.loop.fs_unlink(path), "local row should equal db:eval result.")
eq(row, res[1], luv.fs_unlink(path), "local row should equal db:eval result.")
end)
end)

Expand All @@ -151,7 +152,7 @@ describe("sql", function()
eq(true, db.closed, "should be closed.")
eq("1", res[1].title, "should pass.")

vim.loop.fs_unlink(path)
luv.fs_unlink(path)
end)

it("works without initalizing sql objects. (via uri)", function()
Expand Down Expand Up @@ -373,6 +374,7 @@ describe("sql", function()
it("serialize lua table in sql column", function()
db:eval "drop table test"
db:eval "create table test(id integer, data luatable)"
db.tbl_schemas.test = nil
db:insert("test", { id = 1, data = { "list", "of", "lines" } })

local res = db:eval [[select * from test]]
Expand Down Expand Up @@ -548,7 +550,7 @@ describe("sql", function()
local posts, users

it(".... pre", function()
if vim.loop.fs_stat "/tmp/posts" == nil then
if luv.fs_stat "/tmp/posts" == nil then
curl.get("https://jsonplaceholder.typicode.com/posts", { output = "/tmp/posts" })
curl.get("https://jsonplaceholder.typicode.com/users", { output = "/tmp/users" })
end
Expand Down Expand Up @@ -911,7 +913,7 @@ describe("sql", function()
eq(3, db.st.count(), "should have inserted.")
end)

vim.loop.fs_unlink(testrui)
vim.loop.fs_unlink(testrui2)
luv.fs_unlink(testrui)
luv.fs_unlink(testrui2)
end)
end)

0 comments on commit e977540

Please sign in to comment.