diff options
-rw-r--r-- | script/core/diagnostics/assign-type-mismatch.lua | 2 | ||||
-rw-r--r-- | script/core/diagnostics/cast-local-type.lua | 2 | ||||
-rw-r--r-- | script/core/diagnostics/cast-type-mismatch.lua | 2 | ||||
-rw-r--r-- | script/core/signature.lua | 2 | ||||
-rw-r--r-- | script/parser/compile.lua | 27 | ||||
-rw-r--r-- | script/parser/guide.lua | 3 | ||||
-rw-r--r-- | script/vm/compiler.lua | 147 | ||||
-rw-r--r-- | script/vm/infer.lua | 10 | ||||
-rw-r--r-- | script/vm/init.lua | 2 | ||||
-rw-r--r-- | script/vm/node.lua | 18 | ||||
-rw-r--r-- | script/vm/runner.lua | 565 | ||||
-rw-r--r-- | script/vm/tracer.lua | 541 | ||||
-rw-r--r-- | test/tclient/tests/recursive-runner.lua | 3 | ||||
-rw-r--r-- | test/type_inference/init.lua | 158 |
14 files changed, 793 insertions, 689 deletions
diff --git a/script/core/diagnostics/assign-type-mismatch.lua b/script/core/diagnostics/assign-type-mismatch.lua index 566f4a27..8472e87c 100644 --- a/script/core/diagnostics/assign-type-mismatch.lua +++ b/script/core/diagnostics/assign-type-mismatch.lua @@ -61,7 +61,7 @@ return function (uri, callback) await.delay() if source.type == 'setlocal' then local locNode = vm.compileNode(source.node) - if not locNode:getData 'hasDefined' then + if not locNode.hasDefined then return end end diff --git a/script/core/diagnostics/cast-local-type.lua b/script/core/diagnostics/cast-local-type.lua index 1b1c8432..1998b915 100644 --- a/script/core/diagnostics/cast-local-type.lua +++ b/script/core/diagnostics/cast-local-type.lua @@ -18,7 +18,7 @@ return function (uri, callback) end await.delay() local locNode = vm.compileNode(loc) - if not locNode:getData 'hasDefined' then + if not locNode.hasDefined then return end for _, ref in ipairs(loc.ref) do diff --git a/script/core/diagnostics/cast-type-mismatch.lua b/script/core/diagnostics/cast-type-mismatch.lua index c0483459..b2d2bdf3 100644 --- a/script/core/diagnostics/cast-type-mismatch.lua +++ b/script/core/diagnostics/cast-type-mismatch.lua @@ -22,7 +22,7 @@ return function (uri, callback) local loc = defs[1] if loc then local defNode = vm.compileNode(loc) - if defNode:getData 'hasDefined' then + if defNode.hasDefined then for _, cast in ipairs(doc.casts) do if not cast.mode and cast.extends then local refNode = vm.compileNode(cast.extends) diff --git a/script/core/signature.lua b/script/core/signature.lua index 3465fda2..63b0cd0d 100644 --- a/script/core/signature.lua +++ b/script/core/signature.lua @@ -134,7 +134,7 @@ local function makeSignatures(text, call, pos) local signs = {} local node = vm.compileNode(func) ---@type vm.node - node = node:getData 'originNode' or node + node = node.originNode or node local mark = {} for src in node:eachObject() do if (src.type == 'function' and not vm.isVarargFunctionWithOverloads(src)) diff --git a/script/parser/compile.lua b/script/parser/compile.lua index b8040382..17b9b051 100644 --- a/script/parser/compile.lua +++ b/script/parser/compile.lua @@ -2232,6 +2232,7 @@ local function parseFunction(isLocal, isAction) type = 'function', start = funcLeft, finish = funcRight, + bstart = funcRight, keyword = { [1] = funcLeft, [2] = funcRight, @@ -2262,6 +2263,7 @@ local function parseFunction(isLocal, isAction) end func.name = simple func.finish = simple.finish + func.bstart = simple.finish if not isAction then simple.parent = func pushError { @@ -2302,6 +2304,7 @@ local function parseFunction(isLocal, isAction) if Tokens[Index + 1] == ')' then local parenRight = getPosition(Tokens[Index], 'right') func.finish = parenRight + func.bstart = parenRight if params then params.finish = parenRight end @@ -2309,6 +2312,7 @@ local function parseFunction(isLocal, isAction) skipSpace(true) else func.finish = lastRightPosition() + func.bstart = func.finish if params then params.finish = func.finish end @@ -2963,6 +2967,7 @@ local function parseDo() type = 'do', start = doLeft, finish = doRight, + bstart = doRight, keyword = { [1] = doLeft, [2] = doRight, @@ -3145,6 +3150,7 @@ local function parseIfBlock(parent) parent = parent, start = ifLeft, finish = ifRight, + bstart = ifRight, keyword = { [1] = ifLeft, [2] = ifRight, @@ -3155,7 +3161,8 @@ local function parseIfBlock(parent) if filter then ifblock.filter = filter ifblock.finish = filter.finish - filter.parent = ifblock + ifblock.bstart = ifblock.finish + filter.parent = ifblock else missExp() end @@ -3164,6 +3171,7 @@ local function parseIfBlock(parent) if thenToken == 'then' or thenToken == 'do' then ifblock.finish = getPosition(Tokens[Index] + #thenToken - 1, 'right') + ifblock.bstart = ifblock.finish ifblock.keyword[3] = getPosition(Tokens[Index], 'left') ifblock.keyword[4] = ifblock.finish if thenToken == 'do' then @@ -3203,6 +3211,7 @@ local function parseElseIfBlock(parent) parent = parent, start = ifLeft, finish = ifRight, + bstart = ifRight, keyword = { [1] = ifLeft, [2] = ifRight, @@ -3214,6 +3223,7 @@ local function parseElseIfBlock(parent) if filter then elseifblock.filter = filter elseifblock.finish = filter.finish + elseifblock.bstart = elseifblock.finish filter.parent = elseifblock else missExp() @@ -3223,6 +3233,7 @@ local function parseElseIfBlock(parent) if thenToken == 'then' or thenToken == 'do' then elseifblock.finish = getPosition(Tokens[Index] + #thenToken - 1, 'right') + elseifblock.bstart = elseifblock.finish elseifblock.keyword[3] = getPosition(Tokens[Index], 'left') elseifblock.keyword[4] = elseifblock.finish if thenToken == 'do' then @@ -3262,6 +3273,7 @@ local function parseElseBlock(parent) parent = parent, start = ifLeft, finish = ifRight, + bstart = ifRight, keyword = { [1] = ifLeft, [2] = ifRight, @@ -3337,6 +3349,7 @@ local function parseFor() finish = getPosition(Tokens[Index] + 2, 'right'), keyword = {}, } + action.bstart = action.finish action.keyword[1] = action.start action.keyword[2] = action.finish Index = Index + 2 @@ -3366,6 +3379,7 @@ local function parseFor() local loc = createLocal(name) loc.parent = action action.finish = name.finish + action.bstart = action.finish action.loc = loc end if expList then @@ -3375,12 +3389,14 @@ local function parseFor() value.parent = expList action.init = value action.finish = expList[#expList].finish + action.bstart = action.finish end local max = expList[2] if max then max.parent = expList action.max = max action.finish = max.finish + action.bstart = action.finish else pushError { type = 'MISS_LOOP_MAX', @@ -3393,6 +3409,7 @@ local function parseFor() step.parent = expList action.step = step action.finish = step.finish + action.bstart = action.finish end else pushError { @@ -3414,7 +3431,8 @@ local function parseFor() local exps = parseExpList() - action.finish = inRight + action.finish = inRight + action.bstart = action.finish action.keyword[3] = inLeft action.keyword[4] = inRight @@ -3435,6 +3453,7 @@ local function parseFor() local lastExp = exps[#exps] if lastExp then action.finish = lastExp.finish + action.bstart = action.finish end action.exps = exps @@ -3468,6 +3487,7 @@ local function parseFor() local left = getPosition(Tokens[Index], 'left') local right = getPosition(Tokens[Index] + #doToken - 1, 'right') action.finish = left + action.bstart = action.finish action.keyword[#action.keyword+1] = left action.keyword[#action.keyword+1] = right if doToken == 'then' then @@ -3518,6 +3538,7 @@ local function parseWhile() finish = getPosition(Tokens[Index] + 4, 'right'), keyword = {}, } + action.bstart = action.finish action.keyword[1] = action.start action.keyword[2] = action.finish Index = Index + 2 @@ -3542,6 +3563,7 @@ local function parseWhile() local left = getPosition(Tokens[Index], 'left') local right = getPosition(Tokens[Index] + #doToken - 1, 'right') action.finish = left + action.bstart = action.finish action.keyword[#action.keyword+1] = left action.keyword[#action.keyword+1] = right if doToken == 'then' then @@ -3594,6 +3616,7 @@ local function parseRepeat() finish = getPosition(Tokens[Index] + 5, 'right'), keyword = {}, } + action.bstart = action.finish action.keyword[1] = action.start action.keyword[2] = action.finish Index = Index + 2 diff --git a/script/parser/guide.lua b/script/parser/guide.lua index 147e6237..f27a2af7 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -21,6 +21,7 @@ local type = type ---@field finish integer ---@field range integer ---@field effect integer +---@field bstart integer ---@field attrs string[] ---@field specials parser.object[] ---@field labels parser.object[] @@ -139,7 +140,7 @@ local childMap = { ['getfield'] = {'node', 'field'}, ['list'] = {'#'}, ['binary'] = {1, 2}, - ['unary'] = {1}, + ['unary'] = { 1 }, ['doc'] = {'#'}, ['doc.class'] = {'class', '#extends', '#signs', 'comment'}, diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 168ad536..446c357e 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -538,7 +538,7 @@ local function matchCall(source) if needRemove then local newNode = myNode:copy() newNode:removeNode(needRemove) - newNode:setData('originNode', myNode) + newNode.originNode = myNode vm.setNode(source, newNode, true) end end @@ -836,10 +836,13 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex) for i = fixIndex + 1, myIndex - 1 do args[#args+1] = call.args[i] end - fn = generic:resolve(guide.getUri(call), args) + local resolvedNode = generic:resolve(guide.getUri(call), args) + vm.setNode(arg, resolvedNode) + goto CONTINUE end end vm.setNode(arg, fn) + ::CONTINUE:: end end end @@ -907,9 +910,10 @@ end ---@param source parser.object ---@param target parser.object +---@return boolean local function compileForVars(source, target) if not source.exps then - return + return false end -- for k, v in pairs(t) do --> for k, v in iterator, status, initValue do @@ -940,9 +944,11 @@ local function compileForVars(source, target) local node = getReturn(source._iterator, i, source._iterArgs) node:removeOptional() vm.setNode(loc, node) + return true end end end + return false end ---@param source parser.object @@ -972,17 +978,6 @@ local function compileLocal(source) vm.setNode(source, vm.compileNode(source.value)) end end - if not hasMarkValue and not hasMarkValue then - if source.ref then - for _, ref in ipairs(source.ref) do - if ref.type == 'setlocal' - and ref.value - and ref.value.type == 'function' then - vm.setNode(source, vm.compileNode(ref.value)) - end - end - end - end -- function x.y(self, ...) --> function x:y(...) if source[1] == 'self' and not hasMarkDoc @@ -1021,6 +1016,7 @@ local function compileLocal(source) -- for x in ... do if source.parent.type == 'in' then compileForVars(source.parent, source) + hasMarkDoc = true end -- for x = ... do @@ -1030,10 +1026,33 @@ local function compileLocal(source) return end vm.setNode(source, vm.declareGlobal('type', 'integer')) + hasMarkDoc = true end end - myNode:setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue) + if not hasMarkDoc + and not hasMarkValue + and source.ref then + local firstSet + local myFunction = guide.getParentFunction(source) + for _, ref in ipairs(source.ref) do + if ref.type == 'setlocal' then + firstSet = ref + break + end + if ref.type == 'getlocal' then + if guide.getParentFunction(ref) == myFunction then + break + end + end + end + if firstSet + and guide.getBlock(firstSet) == guide.getBlock(source) then + vm.setNode(source, vm.compileNode(firstSet)) + end + end + + myNode.hasDefined = hasMarkDoc or hasMarkParam or hasMarkValue end ---@param source parser.object @@ -1163,75 +1182,27 @@ local compilerSwitch = util.switch() ---@async ---@param source parser.object : call(function (source) - vm.launchRunner(source, function () - local myNode = vm.getNode(source) - ---@cast myNode -? - myNode:setData('resolving', true) - - if source.ref then - for _, ref in ipairs(source.ref) do - if ref.type == 'getlocal' - or ref.type == 'setlocal' then - vm.setNode(ref, myNode, true) - end - end - end - compileLocal(source) - - myNode.resolved = true - end, function () - local myNode = vm.getNode(source) - ---@cast myNode -? - myNode:setData('resolving', nil) - local hasMark = vm.getNode(source):getData 'hasDefined' - if source.ref and not hasMark then - local parentFunc = guide.getParentFunction(source) - for _, ref in ipairs(source.ref) do - if ref.type == 'setlocal' - and guide.getParentFunction(ref) == parentFunc then - local refNode = vm.getNode(ref) - if refNode then - vm.setNode(source, refNode) - end - end - end - end - end, function (src, node) - if src.type == 'setlocal' then - if src.bindDocs then - for _, doc in ipairs(src.bindDocs) do - if doc.type == 'doc.type' then - vm.setNode(src, vm.compileNode(doc), true) - return vm.getNode(src) - end - end - end - if src.value then - if src.value.type == 'table' then - vm.setNode(src, vm.createNode(src.value), true) - vm.setNode(src, node:copy():asTable()) - else - vm.setNode(src, vm.compileNode(src.value), true) - end - else - vm.setNode(src, node, true) - end - return vm.getNode(src) - elseif src.type == 'getlocal' then - if bindAs(src) then - return - end - vm.setNode(src, node, true) - node.resolved = true - matchCall(src) - end - end) - - vm.waitResolveRunner(source) + compileLocal(source) end) : case 'setlocal' : call(function (source) - vm.compileNode(source.node) + if bindDocs(source) then + return + end + local locNode = vm.compileNode(source.node) + if not source.value then + vm.setNode(source, locNode) + return + end + local valueNode = vm.compileNode(source.value) + vm.setNode(source, valueNode) + if locNode.hasDefined + and guide.isLiteral(source.value) then + vm.setNode(source, locNode) + vm.getNode(source):narrow(guide.getUri(source), source.value.type) + else + vm.setNode(source, valueNode) + end end) : case 'getlocal' ---@async @@ -1239,8 +1210,11 @@ local compilerSwitch = util.switch() if bindAs(source) then return end - vm.compileNode(source.node) - vm.waitResolveRunner(source) + local node = vm.traceNode(source) + if not node then + return + end + vm.setNode(source, node, true) end) : case 'setfield' : case 'setmethod' @@ -1921,13 +1895,6 @@ function vm.compileNode(source) end end - if source.type == 'getlocal' then - ---@cast source parser.object - vm.storeWaitingRunner(source) - ---@diagnostic disable-next-line: await-in-sync - vm.waitResolveRunner(source) - end - local cache = vm.getNode(source) if cache ~= nil then return cache diff --git a/script/vm/infer.lua b/script/vm/infer.lua index b9dfb29a..99cf622e 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -432,10 +432,14 @@ function mt:view(uri, default) end if self.node:isOptional() then - if max > 1 then - view = '(' .. view .. ')?' + if #array == 0 then + view = 'nil' else - view = view .. '?' + if max > 1 then + view = '(' .. view .. ')?' + else + view = view .. '?' + end end end diff --git a/script/vm/init.lua b/script/vm/init.lua index 7b69a7eb..9c8ebe55 100644 --- a/script/vm/init.lua +++ b/script/vm/init.lua @@ -11,7 +11,7 @@ require 'vm.field' require 'vm.doc' require 'vm.type' require 'vm.library' -require 'vm.runner' +require 'vm.tracer' require 'vm.infer' require 'vm.generic' require 'vm.sign' diff --git a/script/vm/node.lua b/script/vm/node.lua index 2e408128..65d752df 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -20,7 +20,8 @@ mt.id = 0 mt.type = 'vm.node' mt.optional = nil mt.data = nil -mt.resolved = nil +mt.hasDefined = nil +mt.originNode = nil ---@param node vm.node | vm.node.object ---@return vm.node @@ -70,21 +71,6 @@ function mt:get(n) return self[n] end -function mt:setData(k, v) - if not self.data then - self.data = {} - end - self.data[k] = v -end - ----@return any -function mt:getData(k) - if not self.data then - return nil - end - return self.data[k] -end - function mt:addOptional() self.optional = true end diff --git a/script/vm/runner.lua b/script/vm/runner.lua deleted file mode 100644 index 8e264521..00000000 --- a/script/vm/runner.lua +++ /dev/null @@ -1,565 +0,0 @@ ----@class vm -local vm = require 'vm.vm' -local guide = require 'parser.guide' -local linked = require 'linked-table' - ----@alias vm.runner.callback fun(src: parser.object, node?: vm.node) - ----@class vm.runner ----@field _loc parser.object ----@field _casts parser.object[] ----@field _callback vm.runner.callback ----@field _mark table ----@field _has table<parser.object, true> ----@field _main parser.object ----@field _uri uri -local mt = {} -mt.__index = mt -mt._index = 1 - ----@return parser.object[] -function mt:_getCasts() - local root = guide.getRoot(self._loc) - if not root._casts then - root._casts = {} - local docs = root.docs - for _, doc in ipairs(docs) do - if doc.type == 'doc.cast' and doc.loc then - root._casts[#root._casts+1] = doc - end - end - end - return root._casts -end - ----@param obj parser.object -function mt:_markHas(obj) - while true do - if self._has[obj] then - return - end - self._has[obj] = true - if obj == self._main then - return - end - obj = obj.parent - end -end - -function mt:collect() - local startPos = self._loc.start - local finishPos = 0 - - for _, ref in ipairs(self._loc.ref) do - if ref.type == 'getlocal' - or ref.type == 'setlocal' then - self:_markHas(ref) - if ref.finish > finishPos then - finishPos = ref.finish - end - end - end - - local casts = self:_getCasts() - for _, cast in ipairs(casts) do - if cast.loc[1] == self._loc[1] - and cast.start > startPos - and cast.finish < finishPos - and guide.getLocal(self._loc, self._loc[1], cast.start) == self._loc then - self._casts[#self._casts+1] = cast - end - end -end - ----@param pos integer ----@param topNode vm.node ----@return vm.node -function mt:_fastWardCasts(pos, topNode) - for i = self._index, #self._casts do - local action = self._casts[i] - if action.start > pos then - self._index = i - return topNode - end - topNode = topNode:copy() - for _, cast in ipairs(action.casts) do - if cast.mode == '+' then - if cast.optional then - topNode:addOptional() - end - if cast.extends then - topNode:merge(vm.compileNode(cast.extends)) - end - elseif cast.mode == '-' then - if cast.optional then - topNode:removeOptional() - end - if cast.extends then - topNode:removeNode(vm.compileNode(cast.extends)) - end - else - if cast.extends then - topNode:clear() - topNode:merge(vm.compileNode(cast.extends)) - end - end - end - end - self._index = self._index + 1 - return topNode -end - ----@param action parser.object ----@param topNode vm.node ----@param outNode? vm.node ----@return vm.node topNode ----@return vm.node outNode -function mt:_lookIntoChild(action, topNode, outNode) - if not self._has[action] - or self._mark[action] then - return topNode, topNode or outNode - end - self._mark[action] = true - topNode = self:_fastWardCasts(action.start, topNode) - if action.type == 'getlocal' then - if action.node == self._loc then - self._callback(action, topNode) - if outNode then - topNode = topNode:copy():setTruthy() - outNode = outNode:copy():setFalsy() - end - end - elseif action.type == 'function' then - self:lookIntoBlock(action, topNode:copy()) - elseif action.type == 'unary' then - if not action[1] then - goto RETURN - end - if action.op.type == 'not' then - outNode = outNode or topNode:copy() - outNode, topNode = self:_lookIntoChild(action[1], topNode, outNode) - outNode = outNode:copy() - end - elseif action.type == 'binary' then - if not action[1] or not action[2] then - goto RETURN - end - if action.op.type == 'and' then - topNode = self:_lookIntoChild(action[1], topNode, topNode:copy()) - topNode = self:_lookIntoChild(action[2], topNode, topNode:copy()) - elseif action.op.type == 'or' then - outNode = outNode or topNode:copy() - local topNode1, outNode1 = self:_lookIntoChild(action[1], topNode, outNode) - local topNode2, outNode2 = self:_lookIntoChild(action[2], outNode1, outNode1:copy()) - topNode = vm.createNode(topNode1, topNode2) - outNode = outNode2:copy() - elseif action.op.type == '==' - or action.op.type == '~=' then - local handler, checker - for i = 1, 2 do - if guide.isLiteral(action[i]) then - checker = action[i] - handler = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1` - end - end - if not handler then - goto RETURN - end - if handler.type == 'getlocal' - and handler.node == self._loc then - -- if x == y then - topNode = self:_lookIntoChild(handler, topNode, outNode) - local checkerNode = vm.compileNode(checker) - local checkerName = vm.getNodeName(checker) - if checkerName then - topNode = topNode:copy() - if action.op.type == '==' then - topNode:narrow(self._uri, checkerName) - if outNode then - outNode:removeNode(checkerNode) - end - else - topNode:removeNode(checkerNode) - if outNode then - outNode:narrow(self._uri, checkerName) - end - end - end - elseif handler.type == 'call' - and checker.type == 'string' - and handler.node.special == 'type' - and handler.args - and handler.args[1] - and handler.args[1].type == 'getlocal' - and handler.args[1].node == self._loc then - -- if type(x) == 'string' then - self:_lookIntoChild(handler, topNode:copy()) - if action.op.type == '==' then - topNode:narrow(self._uri, checker[1]) - if outNode then - outNode:remove(checker[1]) - end - else - topNode:remove(checker[1]) - if outNode then - outNode:narrow(self._uri, checker[1]) - end - end - elseif handler.type == 'getlocal' - and checker.type == 'string' then - local nodeValue = vm.getObjectValue(handler.node) - if nodeValue - and nodeValue.type == 'select' - and nodeValue.sindex == 1 then - local call = nodeValue.vararg - if call - and call.type == 'call' - and call.node.special == 'type' - and call.args - and call.args[1] - and call.args[1].type == 'getlocal' - and call.args[1].node == self._loc then - -- `local tp = type(x);if tp == 'string' then` - if action.op.type == '==' then - topNode:narrow(self._uri, checker[1]) - if outNode then - outNode:remove(checker[1]) - end - else - topNode:remove(checker[1]) - if outNode then - outNode:narrow(self._uri, checker[1]) - end - end - end - end - end - end - elseif action.type == 'loop' - or action.type == 'in' - or action.type == 'repeat' - or action.type == 'for' then - topNode = self:lookIntoBlock(action, topNode:copy()) - elseif action.type == 'while' then - local blockNode, mainNode - if action.filter then - blockNode, mainNode = self:_lookIntoChild(action.filter, topNode:copy(), topNode:copy()) - else - blockNode = topNode:copy() - mainNode = topNode:copy() - end - blockNode = self:lookIntoBlock(action, blockNode:copy()) - topNode = mainNode:merge(blockNode) - if action.filter then - -- look into filter again - guide.eachSource(action.filter, function (src) - self._mark[src] = nil - end) - blockNode, topNode = self:_lookIntoChild(action.filter, topNode:copy(), topNode:copy()) - end - elseif action.type == 'if' then - local hasElse - local mainNode = topNode:copy() - local blockNodes = {} - for _, subBlock in ipairs(action) do - local blockNode = mainNode:copy() - if subBlock.filter then - blockNode, mainNode = self:_lookIntoChild(subBlock.filter, blockNode, mainNode) - else - hasElse = true - mainNode:clear() - end - blockNode = self:lookIntoBlock(subBlock, blockNode:copy()) - local neverReturn = subBlock.hasReturn - or subBlock.hasGoTo - or subBlock.hasBreak - or subBlock.hasError - if not neverReturn then - blockNodes[#blockNodes+1] = blockNode - end - end - if not hasElse and not topNode:hasKnownType() then - mainNode:merge(vm.declareGlobal('type', 'unknown')) - end - for _, blockNode in ipairs(blockNodes) do - mainNode:merge(blockNode) - end - topNode = mainNode - elseif action.type == 'call' then - if action.node.special == 'assert' and action.args and action.args[1] then - topNode = self:_lookIntoChild(action.args[1], topNode, topNode:copy()) - end - elseif action.type == 'paren' then - topNode, outNode = self:_lookIntoChild(action.exp, topNode, outNode) - elseif action.type == 'setlocal' then - if action.node == self._loc then - if action.value then - self:_lookIntoChild(action.value, topNode) - end - topNode = self._callback(action, topNode) - end - elseif action.type == 'local' then - if action.value - and action.ref - and action.value.type == 'select' then - local index = action.value.sindex - local call = action.value.vararg - if index == 1 - and call.type == 'call' - and call.node - and call.node.special == 'type' - and call.args then - local getLoc = call.args[1] - if getLoc - and getLoc.type == 'getlocal' - and getLoc.node == self._loc then - for _, ref in ipairs(action.ref) do - self:_markHas(ref) - end - end - end - end - end - ::RETURN:: - guide.eachChild(action, function (src) - if self._has[src] then - self:_lookIntoChild(src, topNode) - end - end) - return topNode, outNode or topNode -end - ----@param block parser.object ----@param topNode vm.node ----@return vm.node topNode -function mt:lookIntoBlock(block, topNode) - if not self._has[block] then - return topNode - end - for _, action in ipairs(block) do - if self._has[action] then - topNode = self:_lookIntoChild(action, topNode) - end - end - topNode = self:_fastWardCasts(block.finish, topNode) - return topNode -end - ----@alias runner.info { target?: parser.object, loc: parser.object } - ----@type thread? -local masterRunner = nil ----@type table<thread, runner.info> -local runnerInfo = setmetatable({}, { - __mode = 'k', - __index = function (self, k) - self[k] = {} - return self[k] - end -}) ----@type linked-table? -local runnerList = nil - ----@async ----@param info runner.info -local function waitResolve(info) - while true do - if not info.target then - break - end - if info.target.node == info.loc then - break - end - local node = vm.getNode(info.target) - if node and node.resolved then - break - end - coroutine.yield() - end - info.target = nil -end - -local function resolveDeadLock() - if not runnerList then - return - end - - ---@type runner.info[] - local infos = {} - for runner in runnerList:pairs() do - local info = runnerInfo[runner] - infos[#infos+1] = info - end - - table.sort(infos, function (a, b) - local uriA = guide.getUri(a.loc) - local uriB = guide.getUri(b.loc) - if uriA ~= uriB then - return uriA < uriB - end - return a.loc.start < b.loc.start - end) - - local firstTarget = infos[1].target - ---@cast firstTarget -? - local firstNode = vm.setNode(firstTarget, vm.getNode(firstTarget):copy(), true) - firstNode.resolved = true - firstNode:setData('resolvedByDeadLock', true) -end - ----@async ----@param loc parser.object ----@param start fun() ----@param finish fun() ----@param callback vm.runner.callback -function vm.launchRunner(loc, start, finish, callback) - local locNode = vm.getNode(loc) - if not locNode then - return - end - - local function resumeMaster() - for i = 1, 10010 do - if not runnerList or runnerList:getSize() == 0 then - return - end - local deadLock = true - for runner in runnerList:pairs() do - local info = runnerInfo[runner] - local waitingSource = info.target - if coroutine.status(runner) == 'suspended' then - local suc, err = coroutine.resume(runner) - if not suc then - log.error(debug.traceback(runner, err)) - end - else - runnerList:pop(runner) - deadLock = false - end - if not waitingSource or waitingSource ~= info.target then - deadLock = false - end - end - if runnerList:getSize() == 0 then - return - end - if deadLock then - resolveDeadLock() - end - if i == 10000 then - local lines = {} - lines[#lines+1] = 'Dead lock:' - for runner in runnerList:pairs() do - local info = runnerInfo[runner] - lines[#lines+1] = '===============' - lines[#lines+1] = string.format('Runner `%s` at %d(%s)' - , info.loc[1] - , info.loc.start - , guide.getUri(info.loc) - ) - lines[#lines+1] = string.format('Waiting `%s` at %d(%s)' - , info.target[1] - , info.target.start - , guide.getUri(info.target) - ) - end - local msg = table.concat(lines, '\n') - log.error(msg) - end - end - end - - local function launch() - start() - if not loc.ref then - finish() - return - end - local main = guide.getParentBlock(loc) - if not main then - finish() - return - end - local self = setmetatable({ - _loc = loc, - _casts = {}, - _mark = {}, - _has = {}, - _main = main, - _uri = guide.getUri(loc), - _callback = callback, - }, mt) - - self:collect() - - self:lookIntoBlock(main, locNode:copy()) - - locNode:setData('runner', nil) - - finish() - end - - local co = coroutine.create(launch) - locNode:setData('runner', co) - local info = runnerInfo[co] - info.loc = loc - - if not runnerList then - runnerList = linked() - end - runnerList:pushTail(co) - - if not masterRunner then - masterRunner = coroutine.running() - resumeMaster() - masterRunner = nil - return - end -end - ----@async ----@param source parser.object -function vm.waitResolveRunner(source) - local myNode = vm.getNode(source) - if myNode and myNode.resolved then - return - end - - local running = coroutine.running() - if not masterRunner or running == masterRunner then - return - end - - local info = runnerInfo[running] - - local targetLoc - if source.type == 'getlocal' then - targetLoc = source.node - elseif source.type == 'local' - or source.type == 'self' then - targetLoc = source - info.target = info.target or source - else - error('Unknown source type: ' .. source.type) - end - - local targetNode = vm.getNode(targetLoc) - if not targetNode then - -- Wait for compiling local by `compiler` - return - end - - waitResolve(info) -end - ----@param source parser.object -function vm.storeWaitingRunner(source) - local sourceNode = vm.getNode(source) - if sourceNode and sourceNode.resolved then - return - end - - local running = coroutine.running() - local info = runnerInfo[running] - info.target = source -end diff --git a/script/vm/tracer.lua b/script/vm/tracer.lua new file mode 100644 index 00000000..21a2619f --- /dev/null +++ b/script/vm/tracer.lua @@ -0,0 +1,541 @@ +---@class vm +local vm = require 'vm.vm' +local guide = require 'parser.guide' +local util = require 'utility' + +---@class parser.object +---@field package _tracer? vm.tracer +---@field package _casts? parser.object[] + +---@class vm.tracer +---@field source parser.object +---@field assigns parser.object[] +---@field assignMap table<parser.object, true> +---@field careMap table<parser.object, true> +---@field mark table<parser.object, true> +---@field casts parser.object[] +---@field nodes table<parser.object, vm.node|false> +---@field main parser.object +---@field uri uri +---@field castIndex integer? +local mt = {} +mt.__index = mt + +---@return parser.object[] +function mt:getCasts() + local root = guide.getRoot(self.source) + if not root._casts then + root._casts = {} + local docs = root.docs + for _, doc in ipairs(docs) do + if doc.type == 'doc.cast' and doc.loc then + root._casts[#root._casts+1] = doc + end + end + end + return root._casts +end + +---@param obj parser.object +function mt:collectAssign(obj) + while true do + local block = guide.getParentBlock(obj) + if not block then + return + end + obj = block + if self.assignMap[obj] then + return + end + if obj == self.main then + return + end + self.assignMap[obj] = true + self.assigns[#self.assigns+1] = obj + end +end + +---@param obj parser.object +function mt:collectCare(obj) + while true do + if self.careMap[obj] then + return + end + if obj == self.main then + return + end + self.careMap[obj] = true + obj = obj.parent + end +end + +function mt:collectLocal() + local startPos = self.source.start + local finishPos = 0 + + self.assigns[#self.assigns+1] = self.source + self.assignMap[self.source] = true + + for _, obj in ipairs(self.source.ref) do + if obj.type == 'setlocal' then + self.assigns[#self.assigns+1] = obj + self.assignMap[obj] = true + self:collectCare(obj) + if obj.finish > finishPos then + finishPos = obj.finish + end + end + if obj.type == 'getlocal' then + self:collectCare(obj) + if obj.finish > finishPos then + finishPos = obj.finish + end + end + end + + local casts = self:getCasts() + for _, cast in ipairs(casts) do + if cast.loc[1] == self.source[1] + and cast.start > startPos + and cast.finish < finishPos + and guide.getLocal(self.source, self.source[1], cast.start) == self.source then + self.casts[#self.casts+1] = cast + end + end +end + +---@param start integer +---@param finish integer +---@return parser.object? +function mt:getLastAssign(start, finish) + local assign + for _, obj in ipairs(self.assigns) do + if obj.start < start then + goto CONTINUE + end + if (obj.range or obj.start) >= finish then + break + end + local objBlock = guide.getParentBlock(obj) + if not objBlock then + break + end + if objBlock.start <= finish + and objBlock.finish >= finish then + assign = obj + end + ::CONTINUE:: + end + return assign +end + +---@param pos integer +function mt:resetCastsIndex(pos) + for i = 1, #self.casts do + local cast = self.casts[i] + if cast.start > pos then + self.castIndex = i + return + end + end + self.castIndex = nil +end + +---@param pos integer +---@param node vm.node +---@return vm.node +function mt:fastWardCasts(pos, node) + if not self.castIndex then + return node + end + for i = self.castIndex, #self.casts do + local action = self.casts[i] + if action.start > pos then + return node + end + node = node:copy() + for _, cast in ipairs(action.casts) do + if cast.mode == '+' then + if cast.optional then + node:addOptional() + end + if cast.extends then + node:merge(vm.compileNode(cast.extends)) + end + elseif cast.mode == '-' then + if cast.optional then + node:removeOptional() + end + if cast.extends then + node:removeNode(vm.compileNode(cast.extends)) + end + else + if cast.extends then + node:clear() + node:merge(vm.compileNode(cast.extends)) + end + end + end + end + self.castIndex = self.castIndex + 1 + return node +end + +---@param action parser.object +---@param topNode vm.node +---@param outNode? vm.node +---@return vm.node topNode +---@return vm.node outNode +function mt:lookIntoChild(action, topNode, outNode) + if not self.careMap[action] + or self.mark[action] then + return topNode, outNode or topNode + end + self.mark[action] = true + topNode = self:fastWardCasts(action.start, topNode) + if action.type == 'getlocal' then + if action.node == self.source then + self.nodes[action] = topNode + if outNode then + topNode = topNode:copy():setTruthy() + outNode = outNode:copy():setFalsy() + end + end + elseif action.type == 'function' then + self:lookIntoBlock(action, action.args.finish, topNode:copy()) + elseif action.type == 'unary' then + if not action[1] then + goto RETURN + end + if action.op.type == 'not' then + outNode = outNode or topNode:copy() + outNode, topNode = self:lookIntoChild(action[1], topNode, outNode) + outNode = outNode:copy() + end + elseif action.type == 'binary' then + if not action[1] or not action[2] then + goto RETURN + end + if action.op.type == 'and' then + topNode = self:lookIntoChild(action[1], topNode, topNode:copy()) + topNode = self:lookIntoChild(action[2], topNode, topNode:copy()) + elseif action.op.type == 'or' then + outNode = outNode or topNode:copy() + local topNode1, outNode1 = self:lookIntoChild(action[1], topNode, outNode) + local topNode2, outNode2 = self:lookIntoChild(action[2], outNode1, outNode1:copy()) + topNode = vm.createNode(topNode1, topNode2) + outNode = outNode2:copy() + elseif action.op.type == '==' + or action.op.type == '~=' then + local handler, checker + for i = 1, 2 do + if guide.isLiteral(action[i]) then + checker = action[i] + handler = action[3-i] -- Copilot tells me use `3-i` instead of `i%2+1` + end + end + if not handler then + goto RETURN + end + if handler.type == 'getlocal' + and handler.node == self.source then + -- if x == y then + topNode = self:lookIntoChild(handler, topNode, outNode) + local checkerNode = vm.compileNode(checker) + local checkerName = vm.getNodeName(checker) + if checkerName then + topNode = topNode:copy() + if action.op.type == '==' then + topNode:narrow(self.uri, checkerName) + if outNode then + outNode:removeNode(checkerNode) + end + else + topNode:removeNode(checkerNode) + if outNode then + outNode:narrow(self.uri, checkerName) + end + end + end + elseif handler.type == 'call' + and checker.type == 'string' + and handler.node.special == 'type' + and handler.args + and handler.args[1] + and handler.args[1].type == 'getlocal' + and handler.args[1].node == self.source then + -- if type(x) == 'string' then + self:lookIntoChild(handler, topNode:copy()) + if action.op.type == '==' then + topNode:narrow(self.uri, checker[1]) + if outNode then + outNode:remove(checker[1]) + end + else + topNode:remove(checker[1]) + if outNode then + outNode:narrow(self.uri, checker[1]) + end + end + elseif handler.type == 'getlocal' + and checker.type == 'string' then + local nodeValue = vm.getObjectValue(handler.node) + if nodeValue + and nodeValue.type == 'select' + and nodeValue.sindex == 1 then + local call = nodeValue.vararg + if call + and call.type == 'call' + and call.node.special == 'type' + and call.args + and call.args[1] + and call.args[1].type == 'getlocal' + and call.args[1].node == self.source then + -- `local tp = type(x);if tp == 'string' then` + if action.op.type == '==' then + topNode:narrow(self.uri, checker[1]) + if outNode then + outNode:remove(checker[1]) + end + else + topNode:remove(checker[1]) + if outNode then + outNode:narrow(self.uri, checker[1]) + end + end + end + end + end + end + elseif action.type == 'loop' + or action.type == 'in' + or action.type == 'repeat' + or action.type == 'for' + or action.type == 'do' then + if action[1] then + self:lookIntoBlock(action, action.bstart, topNode:copy()) + local lastAssign = self:getLastAssign(action.start, action.finish) + if lastAssign then + self:getNode(lastAssign) + end + if self.nodes[action] then + topNode = self.nodes[action]:copy() + end + end + elseif action.type == 'while' then + local blockNode, mainNode + if action.filter then + blockNode, mainNode = self:lookIntoChild(action.filter, topNode:copy(), topNode:copy()) + else + blockNode = topNode:copy() + mainNode = topNode:copy() + end + if action[1] then + self:lookIntoBlock(action, action.bstart, blockNode:copy()) + local lastAssign = self:getLastAssign(action.start, action.finish) + if lastAssign then + self:getNode(lastAssign) + end + if self.nodes[action] then + topNode = mainNode:merge(self.nodes[action]) + end + end + if action.filter then + -- look into filter again + guide.eachSource(action.filter, function (src) + self.mark[src] = nil + end) + blockNode, topNode = self:lookIntoChild(action.filter, topNode:copy(), topNode:copy()) + end + elseif action.type == 'if' then + local hasElse + local mainNode = topNode:copy() + local blockNodes = {} + for _, subBlock in ipairs(action) do + self:resetCastsIndex(subBlock.start) + local blockNode = mainNode:copy() + if subBlock.filter then + blockNode, mainNode = self:lookIntoChild(subBlock.filter, blockNode, mainNode) + else + hasElse = true + mainNode:clear() + end + local mergedNode + if subBlock[1] then + self:lookIntoBlock(subBlock, subBlock.bstart, blockNode:copy()) + local neverReturn = subBlock.hasReturn + or subBlock.hasGoTo + or subBlock.hasBreak + or subBlock.hasError + if neverReturn then + mergedNode = true + else + local lastAssign = self:getLastAssign(subBlock.start, subBlock.finish) + if lastAssign then + self:getNode(lastAssign) + end + if self.nodes[subBlock] then + blockNodes[#blockNodes+1] = self.nodes[subBlock] + mergedNode = true + end + end + end + if not mergedNode then + blockNodes[#blockNodes+1] = blockNode + end + end + if not hasElse and not topNode:hasKnownType() then + mainNode:merge(vm.declareGlobal('type', 'unknown')) + end + for _, blockNode in ipairs(blockNodes) do + mainNode:merge(blockNode) + end + topNode = mainNode + elseif action.type == 'call' then + if action.node.special == 'assert' and action.args and action.args[1] then + topNode = self:lookIntoChild(action.args[1], topNode, topNode:copy()) + end + elseif action.type == 'paren' then + topNode, outNode = self:lookIntoChild(action.exp, topNode, outNode) + elseif action.type == 'setlocal' then + if action.value then + self:lookIntoChild(action.value, topNode) + end + elseif action.type == 'local' then + if action.value + and action.ref + and action.value.type == 'select' then + local index = action.value.sindex + local call = action.value.vararg + if index == 1 + and call.type == 'call' + and call.node + and call.node.special == 'type' + and call.args then + local getLoc = call.args[1] + if getLoc + and getLoc.type == 'getlocal' + and getLoc.node == self.source then + for _, ref in ipairs(action.ref) do + self:collectCare(ref) + end + end + end + end + end + ::RETURN:: + guide.eachChild(action, function (src) + if self.careMap[src] then + self:lookIntoChild(src, topNode) + end + end) + return topNode, outNode or topNode +end + +---@param block parser.object +---@param start integer +---@param node vm.node +function mt:lookIntoBlock(block, start, node) + self:resetCastsIndex(start) + for _, action in ipairs(block) do + if (action.effect or action.start) < start then + goto CONTINUE + end + if self.careMap[action] then + node = self:lookIntoChild(action, node) + end + if action.finish > start and self.assignMap[action] then + return + end + ::CONTINUE:: + end + self.nodes[block] = node +end + +---@param source parser.object +function mt:calcNode(source) + if source.type == 'getlocal' then + local lastAssign = self:getLastAssign(0, source.start) + if not lastAssign then + lastAssign = source.node + end + self:calcNode(lastAssign) + return + end + if source.type == 'local' + or source.type == 'self' + or source.type == 'setlocal' then + local node = vm.compileNode(source) + self.nodes[source] = node + local parentBlock = guide.getParentBlock(source) + if parentBlock then + self:lookIntoBlock(parentBlock, source.finish, node) + end + return + end +end + +---@param source parser.object +---@return vm.node? +function mt:getNode(source) + local cache = self.nodes[source] + if cache ~= nil then + return cache or nil + end + if source == self.main then + self.nodes[source] = false + return nil + end + self.nodes[source] = false + self:calcNode(source) + return self.nodes[source] or nil +end + +---@class vm.node +---@field package _tracer vm.tracer + +---@param source parser.object +---@return vm.tracer? +local function createTracer(source) + local node = vm.compileNode(source) + local tracer = node._tracer + if tracer then + return tracer + end + local main = guide.getParentBlock(source) + if not main then + return nil + end + tracer = setmetatable({ + source = source, + assigns = {}, + assignMap = {}, + careMap = {}, + mark = {}, + casts = {}, + nodes = {}, + main = main, + uri = guide.getUri(source), + }, mt) + node._tracer = tracer + + tracer:collectLocal() + + return tracer +end + +---@param source parser.object +---@return vm.node? +function vm.traceNode(source) + local loc + if source.type == 'getlocal' + or source.type == 'setlocal' then + loc = source.node + end + local tracer = createTracer(loc) + if not tracer then + return nil + end + local node = tracer:getNode(source) + return node +end diff --git a/test/tclient/tests/recursive-runner.lua b/test/tclient/tests/recursive-runner.lua index ddcdb5d6..e824f23a 100644 --- a/test/tclient/tests/recursive-runner.lua +++ b/test/tclient/tests/recursive-runner.lua @@ -174,8 +174,7 @@ end textDocument = { uri = 'file:///test.lua' }, position = { line = 20, character = 11 }, }) - -- TODO - --assert(hover1.contents.value:find 'vector3') + assert(hover1.contents.value:find 'vector3') config.set(nil, 'Lua.diagnostics.enable', true) end) diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 33521a0d..0b69a34c 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -1590,7 +1590,7 @@ AAA = {} local <?x?> = AAA() ]] -TEST 'string|integer' [[ +TEST 'string' [[ local <?x?> x = '1' x = 1 @@ -1637,7 +1637,7 @@ function A() end ]] -TEST 'unknown' [[ +TEST 'string' [[ local x function A() @@ -1758,6 +1758,26 @@ x = '1' x = 1 ]] +TEST 'integer' [[ +local x +x = true +do + x = 1 +end +print(<?x?>) +]] + +TEST 'boolean' [[ +local x +x = true +function XX() + do + x = 1 + end +end +print(<?x?>) +]] + TEST 'integer?' [[ ---@type integer? local <?x?> @@ -1809,6 +1829,17 @@ end print(<?x?>) ]] +TEST 'nil' [[ +---@type integer? +local x + +if not x then + print(<?x?>) +end + +print(x) +]] + TEST 'integer' [[ ---@type integer? local x @@ -1840,6 +1871,15 @@ if xxx and x then end ]] +TEST 'unknown' [[ +---@type integer? +local x + +if not x and x then + print(<?x?>) +end +]] + TEST 'integer' [[ ---@type integer? local x @@ -2277,7 +2317,7 @@ local x print(<?x?>) ]] -TEST 'unknown?' [[ +TEST 'nil' [[ ---@type string? local x @@ -2351,7 +2391,7 @@ end print(<?t?>) ]] -TEST 'unknown?' [[ +TEST 'nil' [[ ---@type integer? local t @@ -3160,7 +3200,7 @@ local function f() end local x, y, <?z?> = 1, 2, f() ]] -TEST 'function' [[ +TEST 'unknown' [[ local f print(<?f?>) @@ -3168,6 +3208,26 @@ print(<?f?>) function f() end ]] +TEST 'unknown' [[ +local f + +do + print(<?f?>) +end + +function f() end +]] + +TEST 'function' [[ +local f + +function A() + print(<?f?>) +end + +function f() end +]] + TEST 'number' [[ ---@type number|nil local n @@ -4000,3 +4060,91 @@ local m, v local <?r?> = m * v ]] + +TEST 'A|B' [[ +---@class A +---@class B + +---@type A|B +local t + +if x then + ---@cast t A +else + print(<?t?>) +end +]] + +TEST 'A|B' [[ +---@class A +---@class B + +---@type A|B +local t + +if x then + ---@cast t A +elseif <?t?> then +end +]] + +TEST 'A|B' [[ +---@class A +---@class B + +---@type A|B +local t + +if x then + ---@cast t A + print(t) +elseif <?t?> then +end +]] + +TEST 'A|B' [[ +---@class A +---@class B + +---@type A|B +local t + +if x then + ---@cast t A + print(t) +elseif <?t?> then + ---@cast t A + print(t) +end +]] + +TEST 'function' [[ +local function x() + print(<?x?>) +end +]] + +TEST 'number' [[ +---@type number? +local x + +do + if not x then + return + end +end + +print(<?x?>) +]] + +TEST 'number' [[ +---@type number[] +local xs + +---@type fun(x): number? +local f + +for _, <?x?> in ipairs(xs) do + x = f(x) +end +]] |