From 6e2d3d42dba22ee9f179546b1e92d0b3313ec1f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=80=E8=90=8C=E5=B0=8F=E6=B1=90?= Date: Thu, 15 Dec 2022 21:09:14 +0800 Subject: stash --- script/vm/compiler.lua | 28 ++++++++++++---- script/vm/infer.lua | 10 ++++-- script/vm/node.lua | 7 ++-- script/vm/tracer.lua | 78 ++++++++++++++++++++++++++++++++++++++++---- test/type_inference/init.lua | 28 +++++++++++++--- 5 files changed, 127 insertions(+), 24 deletions(-) diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index a7ad4464..eefed5a8 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -972,12 +972,25 @@ local function compileLocal(source) vm.setNode(source, vm.compileNode(source.value)) end end - if not hasMarkValue and not hasMarkValue then - local firstRef = source.ref and source.ref[1] - if firstRef - and guide.isSet(firstRef) - and guide.getBlock(firstRef) == guide.getBlock(source) then - vm.setNode(source, vm.compileNode(firstRef)) + if not hasMarkValue + and not hasMarkValue + and source.ref then + local firstSet + local myFunction = guide.getParentFunction(source) + for _, ref in ipairs(source.ref) do + if ref.type == 'setlocal' then + firstSet = ref + break + end + if ref.type == 'getlocal' then + if guide.getParentFunction(ref) == myFunction then + break + end + end + end + if firstSet + and guide.getBlock(firstSet) == guide.getBlock(source) then + vm.setNode(source, vm.compileNode(firstSet)) end end -- function x.y(self, ...) --> function x:y(...) @@ -1164,6 +1177,9 @@ local compilerSwitch = util.switch() end) : case 'setlocal' : call(function (source) + if bindDocs(source) then + return + end local valueNode = vm.compileNode(source.value) vm.setNode(source, valueNode) end) diff --git a/script/vm/infer.lua b/script/vm/infer.lua index b9dfb29a..99cf622e 100644 --- a/script/vm/infer.lua +++ b/script/vm/infer.lua @@ -432,10 +432,14 @@ function mt:view(uri, default) end if self.node:isOptional() then - if max > 1 then - view = '(' .. view .. ')?' + if #array == 0 then + view = 'nil' else - view = view .. '?' + if max > 1 then + view = '(' .. view .. ')?' + else + view = view .. '?' + end end end diff --git a/script/vm/node.lua b/script/vm/node.lua index d0fd5ffb..2e408128 100644 --- a/script/vm/node.lua +++ b/script/vm/node.lua @@ -188,6 +188,9 @@ end ---@return vm.node function mt:setFalsy() + if self.optional == false then + self.optional = nil + end local hasBoolean for index = #self, 1, -1 do local c = self[index] @@ -226,10 +229,6 @@ function mt:setFalsy() if hasBoolean then self:merge(vm.declareGlobal('type', 'false')) end - if self.optional then - self.optional = nil - self:merge(vm.declareGlobal('type', 'nil')) - end return self end diff --git a/script/vm/tracer.lua b/script/vm/tracer.lua index c2b7ae39..5b3b1b54 100644 --- a/script/vm/tracer.lua +++ b/script/vm/tracer.lua @@ -13,9 +13,11 @@ local util = require 'utility' ---@field assignMap table ---@field careMap table ---@field mark table +---@field casts parser.object[] ---@field nodes table ---@field main parser.object ---@field uri uri +---@field castIndex integer? local mt = {} mt.__index = mt @@ -79,25 +81,27 @@ function mt:collectLocal() self.assigns[#self.assigns+1] = obj self.assignMap[obj] = true self:collectCare(obj) + if obj.finish > finishPos then + finishPos = obj.finish + end end if obj.type == 'getlocal' then self:collectCare(obj) + if obj.finish > finishPos then + finishPos = obj.finish + end end end local casts = self:getCasts() for _, cast in ipairs(casts) do if cast.loc[1] == self.source[1] - and cast.start > startPos + and cast.start > startPos and cast.finish < finishPos and guide.getLocal(self.source, self.source[1], cast.start) == self.source then - self.assigns[#self.assigns+1] = cast + self.casts[#self.casts+1] = cast end end - - table.sort(self.assigns, function (a, b) - return a.start < b.start - end) end ---@param start integer @@ -109,7 +113,7 @@ function mt:getLastAssign(start, finish) if obj.start < start then goto CONTINUE end - if obj.start >= finish then + if (obj.range or obj.start) >= finish then break end local objBlock = guide.getParentBlock(obj) @@ -125,6 +129,58 @@ function mt:getLastAssign(start, finish) return assign end +---@param pos integer +function mt:resetCastsIndex(pos) + for i = 1, #self.casts do + local cast = self.casts[i] + if cast.start > pos then + self.castIndex = i + return + end + end + self.castIndex = nil +end + +---@param pos integer +---@param node vm.node +---@return vm.node +function mt:fastWardCasts(pos, node) + if not self.castIndex then + return node + end + for i = self.castIndex, #self.casts do + local action = self.casts[i] + if action.start > pos then + return node + end + node = node:copy() + for _, cast in ipairs(action.casts) do + if cast.mode == '+' then + if cast.optional then + node:addOptional() + end + if cast.extends then + node:merge(vm.compileNode(cast.extends)) + end + elseif cast.mode == '-' then + if cast.optional then + node:removeOptional() + end + if cast.extends then + node:removeNode(vm.compileNode(cast.extends)) + end + else + if cast.extends then + node:clear() + node:merge(vm.compileNode(cast.extends)) + end + end + end + end + self.castIndex = self.castIndex + 1 + return node +end + ---@param action parser.object ---@param topNode vm.node ---@param outNode? vm.node @@ -136,6 +192,7 @@ function mt:lookIntoChild(action, topNode, outNode) return topNode, outNode or topNode end self.mark[action] = true + topNode = self:fastWardCasts(action.start, topNode) if action.type == 'getlocal' then if action.node == self.source then self.nodes[action] = topNode @@ -306,13 +363,18 @@ function mt:lookIntoChild(action, topNode, outNode) or subBlock.hasBreak or subBlock.hasError if not neverReturn then + local ok local lastAssign = self:getLastAssign(subBlock.start, subBlock.finish) if lastAssign then local node = self:getNode(lastAssign) if node then blockNodes[#blockNodes+1] = node + ok = true end end + if not ok then + blockNodes[#blockNodes+1] = blockNode + end end end if not hasElse and not topNode:hasKnownType() then @@ -367,6 +429,7 @@ end ---@param start integer ---@param node vm.node function mt:lookIntoBlock(block, start, node) + self:resetCastsIndex(start) for _, action in ipairs(block) do if action.start < start then goto CONTINUE @@ -436,6 +499,7 @@ local function createTracer(source) assignMap = {}, careMap = {}, mark = {}, + casts = {}, nodes = {}, main = main, uri = guide.getUri(source), diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 8ae65e48..0eb43ed3 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -1637,7 +1637,7 @@ function A() end ]] -TEST 'unknown' [[ +TEST 'string' [[ local x function A() @@ -2317,7 +2317,7 @@ local x print() ]] -TEST 'unknown?' [[ +TEST 'nil' [[ ---@type string? local x @@ -2391,7 +2391,7 @@ end print() ]] -TEST 'unknown?' [[ +TEST 'nil' [[ ---@type integer? local t @@ -3200,7 +3200,7 @@ local function f() end local x, y, = 1, 2, f() ]] -TEST 'function' [[ +TEST 'unknown' [[ local f print() @@ -3208,6 +3208,26 @@ print() function f() end ]] +TEST 'unknown' [[ +local f + +do + print() +end + +function f() end +]] + +TEST 'function' [[ +local f + +function A() + print() +end + +function f() end +]] + TEST 'number' [[ ---@type number|nil local n -- cgit v1.2.3