Skip to content

Commit

Permalink
✨ tbl:last_id and db:last_insert_id + more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kkharji committed Sep 2, 2021
1 parent a322be9 commit 1859c54
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 12 deletions.
5 changes: 4 additions & 1 deletion lua/sqlite/assert.lua
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ M.should_match_pk_type = function(name, kt, pk, key)
return error(errors.no_primary_key:format(name))
end

if knotstr and (pt == "string" or pt == "text") or knotnum and (pt == "number" or pt == "integer") then
if
kt ~= "boolean"
and (knotstr and (pt == "string" or pt == "text") or knotnum and (pt == "number" or pt == "integer"))
then
return error(errors.miss_match_pk_type:format(pk.name, pk.type, kt, name, key))
end

Expand Down
24 changes: 19 additions & 5 deletions lua/sqlite/db.lua
Original file line number Diff line number Diff line change
Expand Up @@ -560,15 +560,25 @@ function sqlite.db:select(tbl_name, spec, schema)

spec = spec or {}
spec.select = spec.keys and spec.keys or spec.select
local select = p.select(tbl_name, spec)
local st = ""

local stmt = s:parse(self.conn, p.select(tbl_name, spec))
s.each(stmt, function()
table.insert(ret, s.kv(stmt))
local stmt = s:parse(self.conn, select, tbl_name)
stmt:each(function()
table.insert(ret, stmt:kv())
end)
s.reset(stmt)
if s.finalize(stmt) then
if tbl_name == "todos_indexer" then
st = stmt:expand()
end

stmt:reset()
if stmt:finalize() then
self.modified = false
end
if tbl_name == "todos_indexer" and spec.id == 3 then
error(st)
end

return p.post_select(ret, schema)
end)
end
Expand Down Expand Up @@ -611,6 +621,10 @@ function sqlite.db:table(tbl_name, opts)
return self:tbl(tbl_name, opts)
end

function sqlite.db:last_insert_rowid()
return tonumber(clib.last_insert_rowid(self.conn))
end

---Sqlite functions sugar wrappers. See `sql/strfun`
sqlite.db.lib = require "sqlite.strfun"

Expand Down
1 change: 1 addition & 0 deletions lua/sqlite/helpers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ M.run = function(func, o)
o.db_schema = o.db:schema(o.name)
end

rawset(o, "last_id", o.db:last_insert_rowid())
--- Run wrapped function
return func()
end
Expand Down
8 changes: 5 additions & 3 deletions lua/sqlite/stmt.lua
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ function sqlstmt:__parse()
assert(
code == flags.ok,
string.format(
"sqlite.lua: sql statement parse, , stmt: `%s`, err: `(`%s`)`",
self.str,
clib.to_str(clib.errmsg(self.conn))
"sqlite.lua\n(parse error): `%s` code == %d\nstatement == '%s'",
clib.to_str(clib.errmsg(self.conn)),
code,
self.str
)
)

self.pstmt = pstmt[0]
end

Expand Down
6 changes: 6 additions & 0 deletions lua/sqlite/tbl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,12 @@ function sqlite.tbl:set_db(db)
self.db = db
end

function sqlite.tbl:last_id()
h.run(function()
self.last_id = self.db:last_insert_rowid()
end, self)
end

sqlite.tbl = setmetatable(sqlite.tbl, {
__call = function(_, ...)
return sqlite.tbl.new(...)
Expand Down
23 changes: 21 additions & 2 deletions lua/sqlite/tbl/indexer.lua
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,16 @@ local sep_query_and_where = function(q, keys)
end
return kv
end

---Print errors to the user
---@param func function
local sc = function(func)
local ok, val = xpcall(func, function(msg)
print(msg)
end)
return ok and val
end

return function(tbl)
local pk = get_primary_key(tbl.tbl_schema)
local extend = tbl_row_extender(tbl, pk)
Expand All @@ -109,7 +119,12 @@ return function(tbl)

if kt == "string" or kt == "number" and pk then
a.should_match_pk_type(tbl.name, kt, pk, arg)
return extend(tbl:where { [pk.name] = arg }, arg)
return extend(
tbl:where {
[pk.name] = arg,
},
arg
)
end

return kt == "table" and tbl:get(sep_query_and_where(arg, tbl_keys))
Expand Down Expand Up @@ -141,7 +156,11 @@ return function(tbl)

if vt == "table" and pk then
a.should_match_pk_type(tbl.name, kt, pk, arg)
return tbl:update { where = { [pk.name] = arg }, set = val }
if arg == 0 or arg == true or arg == "" then
return tbl:insert(val)
else
return tbl:update { where = { [pk.name] = arg }, set = val }
end
end
end

Expand Down
46 changes: 45 additions & 1 deletion test/auto/tbl_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ describe("sqlite.tbl", function()

describe("string_index:", function()
local kv = tbl("kvpair", {
key = { "text", primary = true, required = true, unique = true },
key = { "text", primary = true, required = true, default = "none" },
len = "integer",
}, db)

Expand Down Expand Up @@ -1038,6 +1038,13 @@ describe("sqlite.tbl", function()
}]
)
end)

it("insert with 0 or true to skip the primary key value.", function()
kv[true] = { len = 5 }
eq(5, kv.none.len)
kv[""] = { len = 6 }
eq({ key = "none", len = 6 }, kv:where { len = 6 })
end)
end)

describe("number_index", function()
Expand Down Expand Up @@ -1080,8 +1087,45 @@ describe("sqlite.tbl", function()
limit = 2,
}]
)
t[0] = { name = "x" }
eq("x", t[t.last_id].name)
end)
end)

describe("Relationships", function()
local todos = tbl("todos_indexer", {
id = true,
title = "text",
project = {
reference = "projects.title",
required = true,
on_delete = "cascade",
on_update = "cascade",
},
}, db)

local projects = tbl("projects", {
title = { type = "text", primary = true, required = true, unique = true },
deadline = { "date", default = db.lib.date "now" },
}, db)

it("create new table with default values", function()
projects.neovim = {}
eq("string", type(projects.neovim.deadline))
projects["sqlite"] = {}
--- TODO: if you have sqilte.lua todos[2] return empty table
end)

it("fails if foregin key doesn't exists", function()
eq(
false,
pcall(function()
todos[2].project = "ram"
end)
)
end)
end)

-- vim.loop.fs_unlink(db_path)
end)

Expand Down

0 comments on commit 1859c54

Please sign in to comment.