summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/parser/newparser.lua1
-rw-r--r--script/vm/runner.lua49
-rw-r--r--test/type_inference/init.lua26
3 files changed, 71 insertions, 5 deletions
diff --git a/script/parser/newparser.lua b/script/parser/newparser.lua
index 864ffcff..4bddd7e5 100644
--- a/script/parser/newparser.lua
+++ b/script/parser/newparser.lua
@@ -117,6 +117,7 @@ local Specials = {
['xpcall'] = true,
['pairs'] = true,
['ipairs'] = true,
+ ['assert'] = true,
}
local UnarySymbol = {
diff --git a/script/vm/runner.lua b/script/vm/runner.lua
index 5c92dcbe..721e8c7f 100644
--- a/script/vm/runner.lua
+++ b/script/vm/runner.lua
@@ -81,17 +81,17 @@ function mt:_compileNarrowByFilter(filter, pos)
if not loc or not exp then
return
end
- if exp.type == 'nil' then
+ if guide.isLiteral(exp) then
if filter.op.type == '==' then
self.steps[#self.steps+1] = {
type = 'remove',
- name = 'nil',
+ name = exp.type,
pos = pos,
order = 2,
}
self.steps[#self.steps+1] = {
type = 'as',
- name = 'nil',
+ name = exp.type,
pos = pos,
order = 4,
}
@@ -99,13 +99,13 @@ function mt:_compileNarrowByFilter(filter, pos)
if filter.op.type == '~=' then
self.steps[#self.steps+1] = {
type = 'as',
- name = 'nil',
+ name = exp.type,
pos = pos,
order = 2,
}
self.steps[#self.steps+1] = {
type = 'remove',
- name = 'nil',
+ name = exp.type,
pos = pos,
order = 4,
}
@@ -248,6 +248,42 @@ function mt:_preCompile()
end)
end
+---@param loc parser.object
+---@param node vm.node
+---@return vm.node
+local function checkAssert(loc, node)
+ local parent = loc.parent
+ if parent.type == 'binary' then
+ if parent.op and (parent.op.type == '~=' or parent.op.type == '==') then
+ local exp
+ for i = 1, 2 do
+ if parent[i] == loc then
+ exp = parent[i % 2 + 1]
+ end
+ end
+ if exp and guide.isLiteral(exp) then
+ local callargs = parent.parent
+ if callargs.type == 'callargs'
+ and callargs.parent.node.special == 'assert'
+ and callargs[1] == parent then
+ if parent.op.type == '~=' then
+ node:remove(exp.type)
+ end
+ if parent.op.type == '==' then
+ node = vm.compileNode(exp)
+ end
+ end
+ end
+ end
+ end
+ if parent.type == 'callargs'
+ and parent.parent.node.special == 'assert'
+ and parent[1] == loc then
+ node:setTruly()
+ end
+ return node
+end
+
---@param callback fun(src: parser.object, node: vm.node)
function mt:launch(callback)
local node = vm.getNode(self.loc):copy()
@@ -267,6 +303,9 @@ function mt:launch(callback)
node:remove(step.name)
elseif step.type == 'object' then
node = callback(step.object, node) or node
+ if step.object.type == 'getlocal' then
+ node = checkAssert(step.object, node)
+ end
elseif step.type == 'save' then
-- nothing to do
elseif step.type == 'load' then
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index a9ea81a7..357dd04b 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -1919,3 +1919,29 @@ local <?x?> = t[1]
TEST 'unknown' [[
local <?x?> = y and z
]]
+
+TEST 'integer' [[
+---@type integer?
+local x
+
+assert(x)
+
+print(<?x?>)
+]]
+
+TEST 'integer' [[
+---@type integer?
+local x
+
+assert(x ~= nil)
+
+print(<?x?>)
+]]
+
+TEST 'integer' [[
+local x
+
+assert(x == 1)
+
+print(<?x?>)
+]]