summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/vm/compiler.lua39
-rw-r--r--script/vm/global.lua39
-rw-r--r--script/vm/variable.lua4
-rw-r--r--test/type_inference/init.lua14
4 files changed, 66 insertions, 30 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index f3366958..18b2905e 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -9,25 +9,32 @@ local vm = require 'vm.vm'
---@class parser.object
---@field _compiledNodes boolean
---@field _node vm.node
+---@field package _hasBindedDocs? boolean
---@field cindex integer
---@field func parser.object
-- 该函数有副作用,会给source绑定node!
---@param source parser.object
---@return boolean
-local function bindDocs(source)
+function vm.bindDocs(source)
local docs = source.bindDocs
if not docs then
return false
end
+ if source._hasBindedDocs ~= nil then
+ return source._hasBindedDocs
+ end
+ source._hasBindedDocs = false
for i = #docs, 1, -1 do
local doc = docs[i]
if doc.type == 'doc.type' then
vm.setNode(source, vm.compileNode(doc))
+ source._hasBindedDocs = true
return true
end
if doc.type == 'doc.class' then
vm.setNode(source, vm.compileNode(doc))
+ source._hasBindedDocs = true
return true
end
if doc.type == 'doc.param' then
@@ -36,23 +43,28 @@ local function bindDocs(source)
node:addOptional()
end
vm.setNode(source, node)
+ source._hasBindedDocs = true
return true
end
if doc.type == 'doc.module' then
local name = doc.module
if not name then
+ source._hasBindedDocs = true
return true
end
local uri = rpath.findUrisByRequireName(guide.getUri(source), name)[1]
if not uri then
+ source._hasBindedDocs = true
return true
end
local state = files.getState(uri)
local ast = state and state.ast
if not ast then
+ source._hasBindedDocs = true
return true
end
vm.setNode(source, vm.compileNode(ast))
+ source._hasBindedDocs = true
return true
end
end
@@ -90,7 +102,7 @@ local function searchFieldByLocalID(source, key, pushResult)
local hasMarkDoc = {}
for _, src in ipairs(fields) do
if src.bindDocs then
- if bindDocs(src) then
+ if vm.bindDocs(src) then
local skey = guide.getKeyName(src)
if skey then
hasMarkDoc[skey] = true
@@ -170,7 +182,7 @@ local searchFieldSwitch = util.switch()
hasFiled = true
pushResult(field)
end
- if key == nil then
+ if key == vm.ANY then
pushResult(field)
end
end
@@ -695,7 +707,7 @@ function vm.compileByParentNode(source, key, pushResult)
pushResult(res)
end
end
- if #docedResults == 0 or key == nil then
+ if #docedResults == 0 or key == vm.ANY then
for _, res in ipairs(commonResults) do
pushResult(res)
end
@@ -970,7 +982,7 @@ local function compileLocal(source)
local hasMarkDoc
if source.bindDocs then
- hasMarkDoc = bindDocs(source)
+ hasMarkDoc = vm.bindDocs(source)
end
local hasMarkParam
if not hasMarkDoc then
@@ -1035,7 +1047,7 @@ local function compileLocal(source)
-- for x = ... do
if source.parent.type == 'loop' then
if source.parent.loc == source then
- if bindDocs(source) then
+ if vm.bindDocs(source) then
return
end
vm.setNode(source, vm.declareGlobal('type', 'integer'))
@@ -1199,7 +1211,7 @@ local compilerSwitch = util.switch()
end)
: case 'setlocal'
: call(function (source)
- if bindDocs(source) then
+ if vm.bindDocs(source) then
return
end
local locNode = vm.compileNode(source.node)
@@ -1236,7 +1248,7 @@ local compilerSwitch = util.switch()
: case 'getmethod'
: case 'getindex'
: call(function (source)
- if bindDocs(source) then
+ if vm.bindDocs(source) then
return
end
if guide.isGet(source) and bindAs(source) then
@@ -1278,7 +1290,7 @@ local compilerSwitch = util.switch()
end)
: case 'setglobal'
: call(function (source)
- if bindDocs(source) then
+ if vm.bindDocs(source) then
return
end
if source.node[1] ~= '_ENV' then
@@ -1307,7 +1319,7 @@ local compilerSwitch = util.switch()
: call(function (source)
local hasMarkDoc
if source.bindDocs then
- hasMarkDoc = bindDocs(source)
+ hasMarkDoc = vm.bindDocs(source)
end
if not hasMarkDoc then
@@ -1875,9 +1887,10 @@ function vm.compileNode(source)
---@cast source parser.object
vm.setNode(source, vm.createNode(), true)
- vm.compileByGlobal(source)
- vm.compileByVariable(source)
- compileByNode(source)
+ if not vm.compileByGlobal(source) then
+ vm.compileByVariable(source)
+ compileByNode(source)
+ end
compileByParentNode(source)
matchCall(source)
diff --git a/script/vm/global.lua b/script/vm/global.lua
index 5ffcdb34..0a90829a 100644
--- a/script/vm/global.lua
+++ b/script/vm/global.lua
@@ -540,31 +540,38 @@ function vm.getEnums(source)
end
---@param source parser.object
+---@return boolean
function vm.compileByGlobal(source)
local global = vm.getGlobalNode(source)
if not global then
- return
+ return false
end
- if global.cate == 'variable' then
+ if global.cate == 'type' then
vm.setNode(source, global)
- if guide.isAssign(source) then
- if source.value then
- vm.setNode(source, vm.compileNode(source.value))
- end
- return
+ return false
+ end
+ vm.setNode(source, global)
+ if guide.isAssign(source) then
+ if vm.bindDocs(source) then
+ return true
end
- local node = vm.traceNode(source)
- if node then
- vm.setNode(source, node, true)
+ if source.value then
+ vm.setNode(source, vm.compileNode(source.value))
end
- return
+ return true
end
- local globalBase = vm.getGlobalBase(source)
- if not globalBase then
- return
+ local node = vm.traceNode(source)
+ if node then
+ vm.setNode(source, node, true)
+ else
+ local globalBase = vm.getGlobalBase(source)
+ if not globalBase then
+ return false
+ end
+ local globalNode = vm.compileNode(globalBase)
+ vm.setNode(source, globalNode, true)
end
- local globalNode = vm.compileNode(globalBase)
- vm.setNode(source, globalNode, true)
+ return true
end
---@param source parser.object
diff --git a/script/vm/variable.lua b/script/vm/variable.lua
index 539d8507..150ad18b 100644
--- a/script/vm/variable.lua
+++ b/script/vm/variable.lua
@@ -344,12 +344,14 @@ function vm.getVariableFields(source, includeGets)
end
---@param source parser.object
+---@return boolean
function vm.compileByVariable(source)
local variable = vm.getVariableNode(source)
if not variable then
- return
+ return false
end
vm.setNode(source, variable)
+ return true
end
---@param source parser.object
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index f4a4449d..ed29d3d1 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -4175,3 +4175,17 @@ if xxx == <?t?> then
print(t)
end
]]
+
+TEST 'V' [[
+---@class V
+X = 1
+
+print(<?X?>)
+]]
+
+TEST 'V' [[
+---@class V
+X.Y = 1
+
+print(X.<?Y?>)
+]]