diff options
author | 最萌小汐 <sumneko@hotmail.com> | 2022-06-14 21:01:25 +0800 |
---|---|---|
committer | 最萌小汐 <sumneko@hotmail.com> | 2022-06-14 21:01:25 +0800 |
commit | d1567d7e47a7c418d7fa072e41e78b2c8cbf29a5 (patch) | |
tree | 89abe580e7b61aae80b916ee83cfaf0d71f6ade9 | |
parent | 11085b8e4172f054191d3cb9cffac3cbce1803f5 (diff) | |
parent | 329cd44aedec15f1048742d98baacdd71d434a6e (diff) | |
download | lua-language-server-d1567d7e47a7c418d7fa072e41e78b2c8cbf29a5.zip |
Merge branch 'runner'
-rw-r--r-- | script/core/completion/completion.lua | 4 | ||||
-rw-r--r-- | script/core/diagnostics/need-check-nil.lua | 2 | ||||
-rw-r--r-- | script/core/hover/description.lua | 9 | ||||
-rw-r--r-- | script/vm/compiler.lua | 3 | ||||
-rw-r--r-- | script/vm/infer.lua | 5 | ||||
-rw-r--r-- | script/vm/node.lua | 30 | ||||
-rw-r--r-- | script/vm/runner-bak.lua | 444 | ||||
-rw-r--r-- | script/vm/runner.lua | 590 | ||||
-rw-r--r-- | script/vm/sign.lua | 4 | ||||
-rw-r--r-- | test/type_inference/init.lua | 116 |
10 files changed, 818 insertions, 389 deletions
diff --git a/script/core/completion/completion.lua b/script/core/completion/completion.lua index d7c210c6..285f5bf2 100644 --- a/script/core/completion/completion.lua +++ b/script/core/completion/completion.lua @@ -1132,7 +1132,7 @@ local function checkTypingEnum(state, position, defs, str, results) or def.type == 'doc.type.integer' or def.type == 'doc.type.boolean' then enums[#enums+1] = { - label = vm.viewObject(def), + label = vm.viewObject(def, state.uri), description = def.comment and def.comment.text, kind = define.CompletionItemKind.EnumMember, } @@ -1427,7 +1427,7 @@ local function tryCallArg(state, position, results) or src.type == 'doc.type.integer' or src.type == 'doc.type.boolean' then enums[#enums+1] = { - label = vm.viewObject(src), + label = vm.viewObject(src, state.uri), description = src.comment, kind = define.CompletionItemKind.EnumMember, } diff --git a/script/core/diagnostics/need-check-nil.lua b/script/core/diagnostics/need-check-nil.lua index 98fdfd08..56cb1eae 100644 --- a/script/core/diagnostics/need-check-nil.lua +++ b/script/core/diagnostics/need-check-nil.lua @@ -28,7 +28,7 @@ return function (uri, callback) return end local node = vm.compileNode(src) - if node:hasFalsy() then + if node:hasFalsy() and not vm.getInfer(src):hasType(uri, 'any') then callback { start = src.start, finish = src.finish, diff --git a/script/core/hover/description.lua b/script/core/hover/description.lua index 3fef1a21..c96aaae3 100644 --- a/script/core/hover/description.lua +++ b/script/core/hover/description.lua @@ -144,7 +144,7 @@ local function tryDocModule(source) return collectRequire('require', source.module, guide.getUri(source)) end -local function buildEnumChunk(docType, name) +local function buildEnumChunk(docType, name, uri) if not docType then return nil end @@ -175,7 +175,7 @@ local function buildEnumChunk(docType, name) (enum.default and '->') or (enum.additional and '+>') or ' |', - vm.viewObject(enum) + vm.viewObject(enum, uri) ) if enum.comment then local first = true @@ -199,6 +199,7 @@ local function getBindEnums(source, docGroup) return end + local uri = guide.getUri(source) local mark = {} local chunks = {} local returnIndex = 0 @@ -209,7 +210,7 @@ local function getBindEnums(source, docGroup) goto CONTINUE end mark[name] = true - chunks[#chunks+1] = buildEnumChunk(doc.extends, name) + chunks[#chunks+1] = buildEnumChunk(doc.extends, name, uri) elseif doc.type == 'doc.return' then for _, rtn in ipairs(doc.returns) do returnIndex = returnIndex + 1 @@ -218,7 +219,7 @@ local function getBindEnums(source, docGroup) goto CONTINUE end mark[name] = true - chunks[#chunks+1] = buildEnumChunk(rtn, name) + chunks[#chunks+1] = buildEnumChunk(rtn, name, uri) end end ::CONTINUE:: diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 9b871553..0cbe2b7a 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -998,8 +998,7 @@ local compilerSwitch = util.switch() local hasMark = vm.getNode(source):getData 'hasDefined' - local runner = vm.createRunner(source) - runner:launch(function (src, node) + vm.launchRunner(source, function (src, node) if src.type == 'setlocal' then if src.bindDocs then for _, doc in ipairs(src.bindDocs) do diff --git a/script/vm/infer.lua b/script/vm/infer.lua index 712cff10..9bcd3963 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -417,7 +417,8 @@ function mt:viewClass() end ---@param source parser.object +---@param uri uri ---@return string? -function vm.viewObject(source) - return viewNodeSwitch(source.type, source, {}, guide.getUri(source)) +function vm.viewObject(source, uri) + return viewNodeSwitch(source.type, source, {}, uri) end diff --git a/script/vm/node.lua b/script/vm/node.lua index 6a86350a..5086f66d 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -2,6 +2,7 @@ local files = require 'files' ---@class vm local vm = require 'vm.vm' local ws = require 'workspace.workspace' +local guide = require 'parser.guide' ---@type table<vm.object, vm.node> vm.nodeCache = {} @@ -105,6 +106,18 @@ function mt:hasFalsy() return false end +function mt:hasKnownType() + for _, c in ipairs(self) do + if c.type == 'global' and c.cate == 'type' then + return true + end + if guide.isLiteral(c) then + return true + end + end + return false +end + ---@return boolean function mt:isNullable() if self.optional then @@ -152,6 +165,7 @@ function mt:setTruthy() if hasBoolean then self[#self+1] = vm.declareGlobal('type', 'true') end + return self end ---@return vm.node @@ -174,12 +188,19 @@ function mt:setFalsy() hasBoolean = true table.remove(self, index) self[c] = nil + goto CONTINUE + end + if (c.type == 'global' and c.cate == 'type') then + table.remove(self, index) + self[c] = nil + goto CONTINUE end ::CONTINUE:: end if hasBoolean then self[#self+1] = vm.declareGlobal('type', 'false') end + return self end ---@param name string @@ -200,6 +221,7 @@ function mt:remove(name) self[c] = nil end end + return self end ---@param node vm.node @@ -207,6 +229,14 @@ function mt:removeNode(node) for _, c in ipairs(node) do if c.type == 'global' and c.cate == 'type' then self:remove(c.name) + elseif c.type == 'nil' then + self:remove 'nil' + elseif c.type == 'boolean' then + if c[1] == true then + self:remove 'true' + else + self:remove 'false' + end end end end diff --git a/script/vm/runner-bak.lua b/script/vm/runner-bak.lua new file mode 100644 index 00000000..5535f0ca --- /dev/null +++ b/script/vm/runner-bak.lua @@ -0,0 +1,444 @@ +---@class vm +local vm = require 'vm.vm' +local guide = require 'parser.guide' + +---@class vm.runner-bak +---@field loc parser.object +---@field mainBlock parser.object +---@field blocks table<parser.object, true> +---@field steps vm.runner.step[] +local mt = {} +mt.__index = mt +mt.index = 1 + +---@class parser.object +---@field _casts parser.object[] + +---@class vm.runner.step +---@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge' | 'cast' +---@field pos integer +---@field order? integer +---@field node? vm.node +---@field object? parser.object +---@field name? string +---@field cast? parser.object +---@field tag? string +---@field copy? boolean +---@field new? boolean +---@field ref1? vm.runner.step +---@field ref2? vm.runner.step + +---@param filter parser.object +---@param outStep vm.runner.step +---@param blockStep vm.runner.step +function mt:_compileNarrowByFilter(filter, outStep, blockStep) + if not filter then + return + end + if filter.type == 'paren' then + if filter.exp then + self:_compileNarrowByFilter(filter.exp, outStep, blockStep) + end + return + end + if filter.type == 'unary' then + if not filter.op + or not filter[1] then + return + end + if filter.op.type == 'not' then + local exp = filter[1] + if exp.type == 'getlocal' and exp.node == self.loc then + self.steps[#self.steps+1] = { + type = 'falsy', + pos = filter.finish, + new = true, + } + self.steps[#self.steps+1] = { + type = 'truthy', + pos = filter.finish, + ref1 = outStep, + } + end + end + elseif filter.type == 'binary' then + if not filter.op + or not filter[1] + or not filter[2] then + return + end + if filter.op.type == 'and' then + local dummyStep = { + type = 'save', + copy = true, + ref1 = outStep, + pos = filter.start - 1, + } + self.steps[#self.steps+1] = dummyStep + self:_compileNarrowByFilter(filter[1], dummyStep, blockStep) + self:_compileNarrowByFilter(filter[2], dummyStep, blockStep) + end + if filter.op.type == 'or' then + self:_compileNarrowByFilter(filter[1], outStep, blockStep) + local dummyStep = { + type = 'push', + copy = true, + ref1 = outStep, + pos = filter.op.finish, + } + self.steps[#self.steps+1] = dummyStep + self:_compileNarrowByFilter(filter[2], outStep, dummyStep) + self.steps[#self.steps+1] = { + type = 'push', + tag = 'or reset', + ref1 = blockStep, + pos = filter.finish, + } + end + if filter.op.type == '==' + or filter.op.type == '~=' then + local loc, exp + for i = 1, 2 do + loc = filter[i] + if loc.type == 'getlocal' and loc.node == self.loc then + exp = filter[i % 2 + 1] + break + end + end + if not loc or not exp then + return + end + if guide.isLiteral(exp) then + if filter.op.type == '==' then + self.steps[#self.steps+1] = { + type = 'remove', + name = exp.type, + pos = filter.finish, + ref1 = outStep, + } + self.steps[#self.steps+1] = { + type = 'as', + name = exp.type, + pos = filter.finish, + new = true, + } + end + if filter.op.type == '~=' then + self.steps[#self.steps+1] = { + type = 'as', + name = exp.type, + pos = filter.finish, + ref1 = outStep, + } + self.steps[#self.steps+1] = { + type = 'remove', + name = exp.type, + pos = filter.finish, + new = true, + } + end + end + end + else + if filter.type == 'getlocal' and filter.node == self.loc then + self.steps[#self.steps+1] = { + type = 'truthy', + pos = filter.finish, + new = true, + } + self.steps[#self.steps+1] = { + type = 'falsy', + pos = filter.finish, + ref1 = outStep, + } + end + end +end + +---@param block parser.object +function mt:_compileBlock(block) + if self.blocks[block] then + return + end + self.blocks[block] = true + if block == self.mainBlock then + return + end + + local parentBlock = guide.getParentBlock(block) + self:_compileBlock(parentBlock) + + if block.type == 'if' then + ---@type vm.runner.step[] + local finals = {} + for i, childBlock in ipairs(block) do + local blockStep = { + type = 'save', + tag = 'block', + copy = true, + pos = childBlock.start, + } + local outStep = { + type = 'save', + tag = 'out', + copy = true, + pos = childBlock.start, + } + self.steps[#self.steps+1] = blockStep + self.steps[#self.steps+1] = outStep + self.steps[#self.steps+1] = { + type = 'push', + ref1 = blockStep, + pos = childBlock.start, + } + self:_compileNarrowByFilter(childBlock.filter, outStep, blockStep) + if not childBlock.hasReturn + and not childBlock.hasGoTo + and not childBlock.hasBreak then + local finalStep = { + type = 'save', + pos = childBlock.finish, + tag = 'final #' .. i, + } + finals[#finals+1] = finalStep + self.steps[#self.steps+1] = finalStep + end + self.steps[#self.steps+1] = { + type = 'push', + tag = 'reset child', + ref1 = outStep, + pos = childBlock.finish, + } + end + self.steps[#self.steps+1] = { + type = 'push', + tag = 'reset if', + pos = block.finish, + copy = true, + } + for _, final in ipairs(finals) do + self.steps[#self.steps+1] = { + type = 'merge', + ref2 = final, + pos = block.finish, + } + end + end + + if block.type == 'function' + or block.type == 'while' + or block.type == 'loop' + or block.type == 'in' + or block.type == 'repeat' + or block.type == 'for' then + local savePoint = { + type = 'save', + copy = true, + pos = block.start, + } + self.steps[#self.steps+1] = { + type = 'push', + copy = true, + pos = block.start, + } + self.steps[#self.steps+1] = savePoint + self.steps[#self.steps+1] = { + type = 'push', + pos = block.finish, + ref1 = savePoint, + } + end +end + +---@return parser.object[] +function mt:_getCasts() + local root = guide.getRoot(self.loc) + if not root._casts then + root._casts = {} + local docs = root.docs + for _, doc in ipairs(docs) do + if doc.type == 'doc.cast' and doc.loc then + root._casts[#root._casts+1] = doc + end + end + end + return root._casts +end + +function mt:_preCompile() + local startPos = self.loc.start + local finishPos = 0 + + for _, ref in ipairs(self.loc.ref) do + self.steps[#self.steps+1] = { + type = 'object', + object = ref, + pos = ref.range or ref.start, + } + if ref.start > finishPos then + finishPos = ref.start + end + local block = guide.getParentBlock(ref) + self:_compileBlock(block) + end + + for i, step in ipairs(self.steps) do + if step.type ~= 'object' then + step.order = i + end + end + + local casts = self:_getCasts() + for _, cast in ipairs(casts) do + if cast.loc[1] == self.loc[1] + and cast.start > startPos + and cast.finish < finishPos + and guide.getLocal(self.loc, self.loc[1], cast.start) == self.loc then + self.steps[#self.steps+1] = { + type = 'cast', + cast = cast, + pos = cast.start, + } + end + end + + table.sort(self.steps, function (a, b) + if a.pos == b.pos then + return (a.order or 0) < (b.order or 0) + else + return a.pos < b.pos + end + end) +end + +---@param loc parser.object +---@param node vm.node +---@return vm.node +local function checkAssert(loc, node) + local parent = loc.parent + if parent.type == 'binary' then + if parent.op and (parent.op.type == '~=' or parent.op.type == '==') then + local exp + for i = 1, 2 do + if parent[i] == loc then + exp = parent[i % 2 + 1] + end + end + if exp and guide.isLiteral(exp) then + local callargs = parent.parent + if callargs.type == 'callargs' + and callargs.parent.node.special == 'assert' + and callargs[1] == parent then + if parent.op.type == '~=' then + node:remove(exp.type) + end + if parent.op.type == '==' then + node = vm.compileNode(exp) + end + end + end + end + end + if parent.type == 'callargs' + and parent.parent.node.special == 'assert' + and parent[1] == loc then + node:setTruthy() + end + return node +end + +---@param callback fun(src: parser.object, node: vm.node) +function mt:launch(callback) + local topNode = vm.getNode(self.loc):copy() + for _, step in ipairs(self.steps) do + local node = step.ref1 and step.ref1.node or topNode + if step.type == 'truthy' then + if step.new then + node = node:copy() + topNode = node + end + node:setTruthy() + elseif step.type == 'falsy' then + if step.new then + node = node:copy() + topNode = node + end + node:setFalsy() + elseif step.type == 'as' then + if step.new then + topNode = vm.createNode(vm.getGlobal('type', step.name)) + else + node:clear() + node:merge(vm.getGlobal('type', step.name)) + end + elseif step.type == 'add' then + if step.new then + node = node:copy() + topNode = node + end + node:merge(vm.getGlobal('type', step.name)) + elseif step.type == 'remove' then + if step.new then + node = node:copy() + topNode = node + end + node:remove(step.name) + elseif step.type == 'object' then + topNode = callback(step.object, node) or node + if step.object.type == 'getlocal' then + topNode = checkAssert(step.object, node) + end + elseif step.type == 'save' then + if step.copy then + node = node:copy() + end + step.node = node + elseif step.type == 'push' then + if step.copy then + node = node:copy() + end + topNode = node + elseif step.type == 'merge' then + node:merge(step.ref2.node) + elseif step.type == 'cast' then + topNode = node:copy() + for _, cast in ipairs(step.cast.casts) do + if cast.mode == '+' then + if cast.optional then + topNode:addOptional() + end + if cast.extends then + topNode:merge(vm.compileNode(cast.extends)) + end + elseif cast.mode == '-' then + if cast.optional then + topNode:removeOptional() + end + if cast.extends then + topNode:removeNode(vm.compileNode(cast.extends)) + end + else + if cast.extends then + topNode:clear() + topNode:merge(vm.compileNode(cast.extends)) + end + end + end + end + end +end + +---@param loc parser.object +---@return vm.runner +function vm.launchRunner(loc) + local self = setmetatable({ + loc = loc, + mainBlock = guide.getParentBlock(loc), + blocks = {}, + steps = {}, + }, mt) + + self:_preCompile() + + return self +end diff --git a/script/vm/runner.lua b/script/vm/runner.lua index 9fe0f172..a9c38a87 100644 --- a/script/vm/runner.lua +++ b/script/vm/runner.lua @@ -2,257 +2,19 @@ local vm = require 'vm.vm' local guide = require 'parser.guide' +---@alias vm.runner.callback fun(src: parser.object, node: vm.node) + ---@class vm.runner ----@field loc parser.object ----@field mainBlock parser.object ----@field blocks table<parser.object, true> ----@field steps vm.runner.step[] +---@field _loc parser.object +---@field _objs parser.object[] +---@field _callback vm.runner.callback local mt = {} mt.__index = mt -mt.index = 1 - ----@class parser.object ----@field _casts parser.object[] - ----@class vm.runner.step ----@field type 'truthy' | 'falsy' | 'as' | 'add' | 'remove' | 'object' | 'save' | 'push' | 'merge' | 'cast' ----@field pos integer ----@field order? integer ----@field node? vm.node ----@field object? parser.object ----@field name? string ----@field cast? parser.object ----@field tag? string ----@field copy? boolean ----@field new? boolean ----@field ref1? vm.runner.step ----@field ref2? vm.runner.step - ----@param filter parser.object ----@param outStep vm.runner.step ----@param blockStep vm.runner.step -function mt:_compileNarrowByFilter(filter, outStep, blockStep) - if not filter then - return - end - if filter.type == 'paren' then - if filter.exp then - self:_compileNarrowByFilter(filter.exp, outStep, blockStep) - end - return - end - if filter.type == 'unary' then - if not filter.op - or not filter[1] then - return - end - if filter.op.type == 'not' then - local exp = filter[1] - if exp.type == 'getlocal' and exp.node == self.loc then - self.steps[#self.steps+1] = { - type = 'falsy', - pos = filter.finish, - new = true, - } - self.steps[#self.steps+1] = { - type = 'truthy', - pos = filter.finish, - ref1 = outStep, - } - end - end - elseif filter.type == 'binary' then - if not filter.op - or not filter[1] - or not filter[2] then - return - end - if filter.op.type == 'and' then - local dummyStep = { - type = 'save', - copy = true, - ref1 = outStep, - pos = filter.start - 1, - } - self.steps[#self.steps+1] = dummyStep - self:_compileNarrowByFilter(filter[1], dummyStep, blockStep) - self:_compileNarrowByFilter(filter[2], dummyStep, blockStep) - end - if filter.op.type == 'or' then - self:_compileNarrowByFilter(filter[1], outStep, blockStep) - local dummyStep = { - type = 'push', - copy = true, - ref1 = outStep, - pos = filter.op.finish, - } - self.steps[#self.steps+1] = dummyStep - self:_compileNarrowByFilter(filter[2], outStep, dummyStep) - self.steps[#self.steps+1] = { - type = 'push', - tag = 'or reset', - ref1 = blockStep, - pos = filter.finish, - } - end - if filter.op.type == '==' - or filter.op.type == '~=' then - local loc, exp - for i = 1, 2 do - loc = filter[i] - if loc.type == 'getlocal' and loc.node == self.loc then - exp = filter[i % 2 + 1] - break - end - end - if not loc or not exp then - return - end - if guide.isLiteral(exp) then - if filter.op.type == '==' then - self.steps[#self.steps+1] = { - type = 'remove', - name = exp.type, - pos = filter.finish, - ref1 = outStep, - } - self.steps[#self.steps+1] = { - type = 'as', - name = exp.type, - pos = filter.finish, - new = true, - } - end - if filter.op.type == '~=' then - self.steps[#self.steps+1] = { - type = 'as', - name = exp.type, - pos = filter.finish, - ref1 = outStep, - } - self.steps[#self.steps+1] = { - type = 'remove', - name = exp.type, - pos = filter.finish, - new = true, - } - end - end - end - else - if filter.type == 'getlocal' and filter.node == self.loc then - self.steps[#self.steps+1] = { - type = 'truthy', - pos = filter.finish, - new = true, - } - self.steps[#self.steps+1] = { - type = 'falsy', - pos = filter.finish, - ref1 = outStep, - } - end - end -end - ----@param block parser.object -function mt:_compileBlock(block) - if self.blocks[block] then - return - end - self.blocks[block] = true - if block == self.mainBlock then - return - end - - local parentBlock = guide.getParentBlock(block) - self:_compileBlock(parentBlock) - - if block.type == 'if' then - ---@type vm.runner.step[] - local finals = {} - for i, childBlock in ipairs(block) do - local blockStep = { - type = 'save', - tag = 'block', - copy = true, - pos = childBlock.start, - } - local outStep = { - type = 'save', - tag = 'out', - copy = true, - pos = childBlock.start, - } - self.steps[#self.steps+1] = blockStep - self.steps[#self.steps+1] = outStep - self.steps[#self.steps+1] = { - type = 'push', - ref1 = blockStep, - pos = childBlock.start, - } - self:_compileNarrowByFilter(childBlock.filter, outStep, blockStep) - if not childBlock.hasReturn - and not childBlock.hasGoTo - and not childBlock.hasBreak then - local finalStep = { - type = 'save', - pos = childBlock.finish, - tag = 'final #' .. i, - } - finals[#finals+1] = finalStep - self.steps[#self.steps+1] = finalStep - end - self.steps[#self.steps+1] = { - type = 'push', - tag = 'reset child', - ref1 = outStep, - pos = childBlock.finish, - } - end - self.steps[#self.steps+1] = { - type = 'push', - tag = 'reset if', - pos = block.finish, - copy = true, - } - for _, final in ipairs(finals) do - self.steps[#self.steps+1] = { - type = 'merge', - ref2 = final, - pos = block.finish, - } - end - end - - if block.type == 'function' - or block.type == 'while' - or block.type == 'loop' - or block.type == 'in' - or block.type == 'repeat' - or block.type == 'for' then - local savePoint = { - type = 'save', - copy = true, - pos = block.start, - } - self.steps[#self.steps+1] = { - type = 'push', - copy = true, - pos = block.start, - } - self.steps[#self.steps+1] = savePoint - self.steps[#self.steps+1] = { - type = 'push', - pos = block.finish, - ref1 = savePoint, - } - end -end +mt._index = 1 ---@return parser.object[] function mt:_getCasts() - local root = guide.getRoot(self.loc) + local root = guide.getRoot(self._loc) if not root._casts then root._casts = {} local docs = root.docs @@ -265,180 +27,256 @@ function mt:_getCasts() return root._casts end -function mt:_preCompile() - local startPos = self.loc.start +function mt:_collect() + local startPos = self._loc.start local finishPos = 0 - for _, ref in ipairs(self.loc.ref) do - self.steps[#self.steps+1] = { - type = 'object', - object = ref, - pos = ref.range or ref.start, - } - if ref.start > finishPos then - finishPos = ref.start + for _, ref in ipairs(self._loc.ref) do + if ref.type == 'getlocal' + or ref.type == 'setlocal' then + self._objs[#self._objs+1] = ref + if ref.start > finishPos then + finishPos = ref.start + end end - local block = guide.getParentBlock(ref) - self:_compileBlock(block) end - for i, step in ipairs(self.steps) do - if step.type ~= 'object' then - step.order = i - end + if #self._objs == 0 then + return end local casts = self:_getCasts() for _, cast in ipairs(casts) do - if cast.loc[1] == self.loc[1] + if cast.loc[1] == self._loc[1] and cast.start > startPos and cast.finish < finishPos - and guide.getLocal(self.loc, self.loc[1], cast.start) == self.loc then - self.steps[#self.steps+1] = { - type = 'cast', - cast = cast, - pos = cast.start, - } + and guide.getLocal(self._loc, self._loc[1], cast.start) == self._loc then + self._objs[#self._objs+1] = cast end end - table.sort(self.steps, function (a, b) - if a.pos == b.pos then - return (a.order or 0) < (b.order or 0) - else - return a.pos < b.pos - end + table.sort(self._objs, function (a, b) + return (a.range or a.finish) < (b.range or b.start) end) end ----@param loc parser.object + +---@param pos integer ---@param node vm.node ---@return vm.node -local function checkAssert(loc, node) - local parent = loc.parent - if parent.type == 'binary' then - if parent.op and (parent.op.type == '~=' or parent.op.type == '==') then - local exp - for i = 1, 2 do - if parent[i] == loc then - exp = parent[i % 2 + 1] - end - end - if exp and guide.isLiteral(exp) then - local callargs = parent.parent - if callargs.type == 'callargs' - and callargs.parent.node.special == 'assert' - and callargs[1] == parent then - if parent.op.type == '~=' then - node:remove(exp.type) - end - if parent.op.type == '==' then - node = vm.compileNode(exp) - end - end - end +---@return parser.object? +function mt:_fastWard(pos, node) + for i = self._index, #self._objs do + local obj = self._objs[i] + if (obj.range or obj.finish) > pos then + self._index = i + return node, obj end - end - if parent.type == 'callargs' - and parent.parent.node.special == 'assert' - and parent[1] == loc then - node:setTruthy() - end - return node -end - ----@param callback fun(src: parser.object, node: vm.node) -function mt:launch(callback) - local topNode = vm.getNode(self.loc):copy() - for _, step in ipairs(self.steps) do - local node = step.ref1 and step.ref1.node or topNode - if step.type == 'truthy' then - if step.new then - node = node:copy() - topNode = node - end - node:setTruthy() - elseif step.type == 'falsy' then - if step.new then - node = node:copy() - topNode = node - end - node:setFalsy() - elseif step.type == 'as' then - if step.new then - topNode = vm.createNode(vm.getGlobal('type', step.name)) - else - node:clear() - node:merge(vm.getGlobal('type', step.name)) - end - elseif step.type == 'add' then - if step.new then - node = node:copy() - topNode = node - end - node:merge(vm.getGlobal('type', step.name)) - elseif step.type == 'remove' then - if step.new then - node = node:copy() - topNode = node - end - node:remove(step.name) - elseif step.type == 'object' then - topNode = callback(step.object, node) or node - if step.object.type == 'getlocal' then - topNode = checkAssert(step.object, node) - end - elseif step.type == 'save' then - if step.copy then - node = node:copy() - end - step.node = node - elseif step.type == 'push' then - if step.copy then - node = node:copy() + if obj.type == 'getlocal' then + self._callback(obj, node) + elseif obj.type == 'setlocal' then + local newNode = self._callback(obj, node) + if newNode then + node = newNode:copy() end - topNode = node - elseif step.type == 'merge' then - node:merge(step.ref2.node) - elseif step.type == 'cast' then - topNode = node:copy() - for _, cast in ipairs(step.cast.casts) do + elseif obj.type == 'doc.cast' then + node = node:copy() + for _, cast in ipairs(obj.casts) do if cast.mode == '+' then if cast.optional then - topNode:addOptional() + node:addOptional() end if cast.extends then - topNode:merge(vm.compileNode(cast.extends)) + node:merge(vm.compileNode(cast.extends)) end elseif cast.mode == '-' then if cast.optional then - topNode:removeOptional() + node:removeOptional() end if cast.extends then - topNode:removeNode(vm.compileNode(cast.extends)) + node:removeNode(vm.compileNode(cast.extends)) end else if cast.extends then - topNode:clear() - topNode:merge(vm.compileNode(cast.extends)) + node:clear() + node:merge(vm.compileNode(cast.extends)) end end end end end + self._index = #self._objs + 1 + return node, nil +end + +---@param action parser.object +---@param topNode vm.node +---@param outNode? vm.node +---@return vm.node +function mt:_lookInto(action, topNode, outNode) + local set + local value = vm.getObjectValue(action) + if value then + set = action + action = value + end + if action.type == 'function' + or action.type == 'loop' + or action.type == 'in' + or action.type == 'repeat' + or action.type == 'for' then + self:_launchBlock(action, topNode:copy()) + elseif action.type == 'while' then + local blockNode, mainNode = self:_lookInto(action.filter, topNode:copy(), topNode:copy()) + self:_fastWard(action.filter.finish, blockNode) + self:_launchBlock(action, blockNode:copy()) + topNode = mainNode + elseif action.type == 'if' then + local hasElse + local mainNode = topNode:copy() + local blockNodes = {} + for _, subBlock in ipairs(action) do + local blockNode = mainNode:copy() + if subBlock.filter then + blockNode, mainNode = self:_lookInto(subBlock.filter, blockNode, mainNode) + self:_fastWard(subBlock.filter.finish, blockNode) + else + hasElse = true + mainNode:clear() + end + blockNode = self:_launchBlock(subBlock, blockNode:copy()) + local neverReturn = subBlock.hasReturn + or subBlock.hasGoTo + or subBlock.hasBreak + if not neverReturn then + blockNodes[#blockNodes+1] = blockNode + end + end + if not hasElse and not topNode:hasKnownType() then + mainNode:merge(vm.declareGlobal('type', 'unknown')) + end + for _, blockNode in ipairs(blockNodes) do + mainNode:merge(blockNode) + end + topNode = mainNode + elseif action.type == 'getlocal' then + if action.node == self._loc then + topNode = self:_fastWard(action.finish, topNode) + topNode = topNode:copy():setTruthy() + if outNode then + outNode:setFalsy() + end + end + elseif action.type == 'unary' then + if not action[1] then + goto RETURN + end + if action.op.type == 'not' then + outNode = outNode or topNode:copy() + outNode, topNode = self:_lookInto(action[1], topNode, outNode) + end + elseif action.type == 'binary' then + if not action[1] or not action[2] then + goto RETURN + end + if action.op.type == 'and' then + topNode = self:_lookInto(action[1], topNode) + topNode = self:_lookInto(action[2], topNode) + elseif action.op.type == 'or' then + outNode = outNode or topNode:copy() + local topNode1, outNode1 = self:_lookInto(action[1], topNode, outNode) + local topNode2, outNode2 = self:_lookInto(action[2], outNode1, outNode1:copy()) + topNode = vm.createNode(topNode1, topNode2) + outNode = outNode2 + elseif action.op.type == '==' + or action.op.type == '~=' then + local loc, checker + for i = 1, 2 do + if action[i].type == 'getlocal' and action[i].node == self._loc then + loc = action[i] + checker = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1` + elseif action[2].type == 'getlocal' and action[2].node == self._loc then + loc = action[3-i] + checker = action[i] + end + end + if loc then + self:_fastWard(loc.finish, topNode) + if guide.isLiteral(checker) then + local checkerNode = vm.compileNode(checker) + if action.op.type == '==' then + topNode = checkerNode + if outNode then + outNode:removeNode(topNode) + end + else + topNode:removeNode(checkerNode) + if outNode then + outNode = checkerNode + end + end + end + end + end + elseif action.type == 'call' then + if action.node.special == 'assert' and action.args and action.args[1] then + topNode = self:_lookInto(action.args[1], topNode) + elseif action.args then + for _, arg in ipairs(action.args) do + self:_lookInto(arg, topNode) + end + end + elseif action.type == 'return' then + for _, rtn in ipairs(action) do + self:_lookInto(rtn, topNode) + end + end + ::RETURN:: + topNode = self:_fastWard(action.finish, topNode) + if set then + topNode = self:_fastWard(set.range or set.finish, topNode) + end + return topNode, outNode +end + +---@param block parser.object +---@param node vm.node +---@return vm.node +function mt:_launchBlock(block, node) + local topNode, top = self:_fastWard(block.start, node) + if not top then + return topNode + end + for _, action in ipairs(block) do + if (action.range or action.finish) < (top.range or top.finish) then + goto CONTINUE + end + topNode = self:_lookInto(action, topNode) + topNode, top = self:_fastWard(action.range or action.finish, topNode) + if not top then + return topNode + end + ::CONTINUE:: + end + -- `x = function () end`: don't touch `x` in the end of function + topNode = self:_fastWard(block.finish - 1, topNode) + return topNode end ---@param loc parser.object ----@return vm.runner -function vm.createRunner(loc) +---@param callback vm.runner.callback +function vm.launchRunner(loc, callback) local self = setmetatable({ - loc = loc, - mainBlock = guide.getParentBlock(loc), - blocks = {}, - steps = {}, + _loc = loc, + _objs = {}, + _callback = callback, }, mt) - self:_preCompile() + self:_collect() + + if #self._objs == 0 then + return + end - return self + self:_launchBlock(guide.getParentBlock(loc), vm.getNode(loc):copy()) end diff --git a/script/vm/sign.lua b/script/vm/sign.lua index fe112bc2..5c17aa3d 100644 --- a/script/vm/sign.lua +++ b/script/vm/sign.lua @@ -111,7 +111,7 @@ function mt:resolve(uri, args, removeGeneric) goto CONTINUE end end - local view = vm.viewObject(obj) + local view = vm.viewObject(obj, uri) if view then knownTypes[view] = true end @@ -130,7 +130,7 @@ function mt:resolve(uri, args, removeGeneric) if argNode:hasFalsy() then goto CONTINUE end - local view = vm.viewObject(n) + local view = vm.viewObject(n, uri) if knownTypes[view] then goto CONTINUE end diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 566e847a..efcc5fea 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -1813,6 +1813,17 @@ TEST 'integer' [[ ---@type integer? local x +if not x then + return +end + +print(<?x?>) +]] + +TEST 'integer' [[ +---@type integer? +local x + if xxx and x then print(<?x?>) end @@ -2153,6 +2164,15 @@ print(<?x?>) ]] TEST 'integer' [[ +---@type integer? +local x + +while x do + print(<?x?>) +end +]] + +TEST 'integer' [[ ---@type fun():integer? local iter @@ -2237,6 +2257,19 @@ local x print(<?x?>) ]] +TEST 'unknown?' [[ +---@type string? +local x + +if x then + return +else + print(<?x?>) +end + +print(x) +]] + TEST 'string' [[ ---@type string? local x @@ -2298,6 +2331,18 @@ end print(<?t?>) ]] +TEST 'unknown?' [[ +---@type integer? +local t + +if t then +else + print(<?t?>) +end + +print(t) +]] + TEST 'table|unknown' [[ local function f() if x then @@ -2375,3 +2420,74 @@ TEST '`1`|`true`' [[ ---@type `1` | `true` local <?x?> ]] + +TEST 'function' [[ +local x + +function x() end + +print(<?x?>) +]] + +TEST 'unknown' [[ +local x + +if x.field == 'haha' then + print(<?x?>) +end +]] + +TEST 'string' [[ +---@type string? +local t + +if not t or xxx then + return +end + +print(<?t?>) +]] + +TEST 'table' [[ +---@type table|nil +local t + +return function () + if not t then + return + end + + print(<?t?>) +end +]] + +TEST 'table' [[ +---@type table|nil +local t + +f(function () + if not t then + return + end + + print(<?t?>) +end) +]] + +TEST 'table' [[ +---@type table? +local t + +t = t or {} + +print(<?t?>) +]] + +TEST 'unknown|nil' [[ +local x + +if x == nil then +end + +print(<?x?>) +]] |