diff options
-rw-r--r-- | script/vm/compiler.lua | 39 | ||||
-rw-r--r-- | script/vm/global.lua | 39 | ||||
-rw-r--r-- | script/vm/variable.lua | 4 | ||||
-rw-r--r-- | test/type_inference/init.lua | 14 |
4 files changed, 66 insertions, 30 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index f3366958..18b2905e 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -9,25 +9,32 @@ local vm = require 'vm.vm' ---@class parser.object ---@field _compiledNodes boolean ---@field _node vm.node +---@field package _hasBindedDocs? boolean ---@field cindex integer ---@field func parser.object -- 该函数有副作用,会给source绑定node! ---@param source parser.object ---@return boolean -local function bindDocs(source) +function vm.bindDocs(source) local docs = source.bindDocs if not docs then return false end + if source._hasBindedDocs ~= nil then + return source._hasBindedDocs + end + source._hasBindedDocs = false for i = #docs, 1, -1 do local doc = docs[i] if doc.type == 'doc.type' then vm.setNode(source, vm.compileNode(doc)) + source._hasBindedDocs = true return true end if doc.type == 'doc.class' then vm.setNode(source, vm.compileNode(doc)) + source._hasBindedDocs = true return true end if doc.type == 'doc.param' then @@ -36,23 +43,28 @@ local function bindDocs(source) node:addOptional() end vm.setNode(source, node) + source._hasBindedDocs = true return true end if doc.type == 'doc.module' then local name = doc.module if not name then + source._hasBindedDocs = true return true end local uri = rpath.findUrisByRequireName(guide.getUri(source), name)[1] if not uri then + source._hasBindedDocs = true return true end local state = files.getState(uri) local ast = state and state.ast if not ast then + source._hasBindedDocs = true return true end vm.setNode(source, vm.compileNode(ast)) + source._hasBindedDocs = true return true end end @@ -90,7 +102,7 @@ local function searchFieldByLocalID(source, key, pushResult) local hasMarkDoc = {} for _, src in ipairs(fields) do if src.bindDocs then - if bindDocs(src) then + if vm.bindDocs(src) then local skey = guide.getKeyName(src) if skey then hasMarkDoc[skey] = true @@ -170,7 +182,7 @@ local searchFieldSwitch = util.switch() hasFiled = true pushResult(field) end - if key == nil then + if key == vm.ANY then pushResult(field) end end @@ -695,7 +707,7 @@ function vm.compileByParentNode(source, key, pushResult) pushResult(res) end end - if #docedResults == 0 or key == nil then + if #docedResults == 0 or key == vm.ANY then for _, res in ipairs(commonResults) do pushResult(res) end @@ -970,7 +982,7 @@ local function compileLocal(source) local hasMarkDoc if source.bindDocs then - hasMarkDoc = bindDocs(source) + hasMarkDoc = vm.bindDocs(source) end local hasMarkParam if not hasMarkDoc then @@ -1035,7 +1047,7 @@ local function compileLocal(source) -- for x = ... do if source.parent.type == 'loop' then if source.parent.loc == source then - if bindDocs(source) then + if vm.bindDocs(source) then return end vm.setNode(source, vm.declareGlobal('type', 'integer')) @@ -1199,7 +1211,7 @@ local compilerSwitch = util.switch() end) : case 'setlocal' : call(function (source) - if bindDocs(source) then + if vm.bindDocs(source) then return end local locNode = vm.compileNode(source.node) @@ -1236,7 +1248,7 @@ local compilerSwitch = util.switch() : case 'getmethod' : case 'getindex' : call(function (source) - if bindDocs(source) then + if vm.bindDocs(source) then return end if guide.isGet(source) and bindAs(source) then @@ -1278,7 +1290,7 @@ local compilerSwitch = util.switch() end) : case 'setglobal' : call(function (source) - if bindDocs(source) then + if vm.bindDocs(source) then return end if source.node[1] ~= '_ENV' then @@ -1307,7 +1319,7 @@ local compilerSwitch = util.switch() : call(function (source) local hasMarkDoc if source.bindDocs then - hasMarkDoc = bindDocs(source) + hasMarkDoc = vm.bindDocs(source) end if not hasMarkDoc then @@ -1875,9 +1887,10 @@ function vm.compileNode(source) ---@cast source parser.object vm.setNode(source, vm.createNode(), true) - vm.compileByGlobal(source) - vm.compileByVariable(source) - compileByNode(source) + if not vm.compileByGlobal(source) then + vm.compileByVariable(source) + compileByNode(source) + end compileByParentNode(source) matchCall(source) diff --git a/script/vm/global.lua b/script/vm/global.lua index 5ffcdb34..0a90829a 100644 --- a/script/vm/global.lua +++ b/script/vm/global.lua @@ -540,31 +540,38 @@ function vm.getEnums(source) end ---@param source parser.object +---@return boolean function vm.compileByGlobal(source) local global = vm.getGlobalNode(source) if not global then - return + return false end - if global.cate == 'variable' then + if global.cate == 'type' then vm.setNode(source, global) - if guide.isAssign(source) then - if source.value then - vm.setNode(source, vm.compileNode(source.value)) - end - return + return false + end + vm.setNode(source, global) + if guide.isAssign(source) then + if vm.bindDocs(source) then + return true end - local node = vm.traceNode(source) - if node then - vm.setNode(source, node, true) + if source.value then + vm.setNode(source, vm.compileNode(source.value)) end - return + return true end - local globalBase = vm.getGlobalBase(source) - if not globalBase then - return + local node = vm.traceNode(source) + if node then + vm.setNode(source, node, true) + else + local globalBase = vm.getGlobalBase(source) + if not globalBase then + return false + end + local globalNode = vm.compileNode(globalBase) + vm.setNode(source, globalNode, true) end - local globalNode = vm.compileNode(globalBase) - vm.setNode(source, globalNode, true) + return true end ---@param source parser.object diff --git a/script/vm/variable.lua b/script/vm/variable.lua index 539d8507..150ad18b 100644 --- a/script/vm/variable.lua +++ b/script/vm/variable.lua @@ -344,12 +344,14 @@ function vm.getVariableFields(source, includeGets) end ---@param source parser.object +---@return boolean function vm.compileByVariable(source) local variable = vm.getVariableNode(source) if not variable then - return + return false end vm.setNode(source, variable) + return true end ---@param source parser.object diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index f4a4449d..ed29d3d1 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -4175,3 +4175,17 @@ if xxx == <?t?> then print(t) end ]] + +TEST 'V' [[ +---@class V +X = 1 + +print(<?X?>) +]] + +TEST 'V' [[ +---@class V +X.Y = 1 + +print(X.<?Y?>) +]] |