diff options
-rw-r--r-- | script/parser/newparser.lua | 1 | ||||
-rw-r--r-- | script/vm/runner.lua | 49 | ||||
-rw-r--r-- | test/type_inference/init.lua | 26 |
3 files changed, 71 insertions, 5 deletions
diff --git a/script/parser/newparser.lua b/script/parser/newparser.lua index 864ffcff..4bddd7e5 100644 --- a/script/parser/newparser.lua +++ b/script/parser/newparser.lua @@ -117,6 +117,7 @@ local Specials = { ['xpcall'] = true, ['pairs'] = true, ['ipairs'] = true, + ['assert'] = true, } local UnarySymbol = { diff --git a/script/vm/runner.lua b/script/vm/runner.lua index 5c92dcbe..721e8c7f 100644 --- a/script/vm/runner.lua +++ b/script/vm/runner.lua @@ -81,17 +81,17 @@ function mt:_compileNarrowByFilter(filter, pos) if not loc or not exp then return end - if exp.type == 'nil' then + if guide.isLiteral(exp) then if filter.op.type == '==' then self.steps[#self.steps+1] = { type = 'remove', - name = 'nil', + name = exp.type, pos = pos, order = 2, } self.steps[#self.steps+1] = { type = 'as', - name = 'nil', + name = exp.type, pos = pos, order = 4, } @@ -99,13 +99,13 @@ function mt:_compileNarrowByFilter(filter, pos) if filter.op.type == '~=' then self.steps[#self.steps+1] = { type = 'as', - name = 'nil', + name = exp.type, pos = pos, order = 2, } self.steps[#self.steps+1] = { type = 'remove', - name = 'nil', + name = exp.type, pos = pos, order = 4, } @@ -248,6 +248,42 @@ function mt:_preCompile() end) end +---@param loc parser.object +---@param node vm.node +---@return vm.node +local function checkAssert(loc, node) + local parent = loc.parent + if parent.type == 'binary' then + if parent.op and (parent.op.type == '~=' or parent.op.type == '==') then + local exp + for i = 1, 2 do + if parent[i] == loc then + exp = parent[i % 2 + 1] + end + end + if exp and guide.isLiteral(exp) then + local callargs = parent.parent + if callargs.type == 'callargs' + and callargs.parent.node.special == 'assert' + and callargs[1] == parent then + if parent.op.type == '~=' then + node:remove(exp.type) + end + if parent.op.type == '==' then + node = vm.compileNode(exp) + end + end + end + end + end + if parent.type == 'callargs' + and parent.parent.node.special == 'assert' + and parent[1] == loc then + node:setTruly() + end + return node +end + ---@param callback fun(src: parser.object, node: vm.node) function mt:launch(callback) local node = vm.getNode(self.loc):copy() @@ -267,6 +303,9 @@ function mt:launch(callback) node:remove(step.name) elseif step.type == 'object' then node = callback(step.object, node) or node + if step.object.type == 'getlocal' then + node = checkAssert(step.object, node) + end elseif step.type == 'save' then -- nothing to do elseif step.type == 'load' then diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index a9ea81a7..357dd04b 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -1919,3 +1919,29 @@ local <?x?> = t[1] TEST 'unknown' [[ local <?x?> = y and z ]] + +TEST 'integer' [[ +---@type integer? +local x + +assert(x) + +print(<?x?>) +]] + +TEST 'integer' [[ +---@type integer? +local x + +assert(x ~= nil) + +print(<?x?>) +]] + +TEST 'integer' [[ +local x + +assert(x == 1) + +print(<?x?>) +]] |