summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/vm/compiler.lua64
-rw-r--r--script/vm/runner.lua112
-rw-r--r--test/crossfile/references.lua8
-rw-r--r--test/tclient/tests/recursive-runner.lua42
4 files changed, 169 insertions, 57 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index ef0f7157..70a46677 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -976,7 +976,7 @@ end
---@param source parser.object
local function compileLocal(source)
- vm.setNode(source, source)
+ local myNode = vm.setNode(source, source)
local hasMarkDoc
if source.bindDocs then
@@ -988,7 +988,7 @@ local function compileLocal(source)
if selfNode then
hasMarkParam = true
vm.setNode(source, vm.compileNode(selfNode))
- vm.getNode(source):remove 'function'
+ myNode:remove 'function'
end
end
local hasMarkValue
@@ -1062,7 +1062,7 @@ local function compileLocal(source)
end
end
- vm.getNode(source):setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue)
+ myNode:setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue)
end
---@param source parser.object
@@ -1192,15 +1192,40 @@ local compilerSwitch = util.switch()
---@async
---@param source parser.object
: call(function (source)
- compileLocal(source)
- local refs = source.ref
- if not refs then
- return
- end
-
- local hasMark = vm.getNode(source):getData 'hasDefined'
-
- vm.launchRunner(source, function (src, node)
+ vm.launchRunner(source, function ()
+ local myNode = vm.getNode(source)
+ ---@cast myNode -?
+ myNode:setData('resolving', true)
+
+ if source.ref then
+ for _, ref in ipairs(source.ref) do
+ if ref.type == 'getlocal'
+ or ref.type == 'setlocal' then
+ vm.setNode(ref, myNode, true)
+ end
+ end
+ end
+ compileLocal(source)
+
+ myNode:setData('hasResolved', true)
+ end, function ()
+ local myNode = vm.getNode(source)
+ ---@cast myNode -?
+ myNode:setData('resolving', nil)
+ local hasMark = vm.getNode(source):getData 'hasDefined'
+ if source.ref and not hasMark then
+ local parentFunc = guide.getParentFunction(source)
+ for _, ref in ipairs(source.ref) do
+ if ref.type == 'setlocal'
+ and guide.getParentFunction(ref) == parentFunc then
+ local refNode = vm.getNode(ref)
+ if refNode then
+ vm.setNode(source, refNode)
+ end
+ end
+ end
+ end
+ end, function (src, node)
if src.type == 'setlocal' then
if src.bindDocs then
for _, doc in ipairs(src.bindDocs) do
@@ -1212,7 +1237,7 @@ local compilerSwitch = util.switch()
end
if src.value then
if src.value.type == 'table' then
- vm.setNode(src, vm.createNode(src.value))
+ vm.setNode(src, vm.createNode(src.value), true)
vm.setNode(src, node:copy():asTable())
else
vm.setNode(src, vm.compileNode(src.value), true)
@@ -1231,18 +1256,7 @@ local compilerSwitch = util.switch()
end
end)
- if not hasMark then
- local parentFunc = guide.getParentFunction(source)
- for _, ref in ipairs(source.ref) do
- if ref.type == 'setlocal'
- and guide.getParentFunction(ref) == parentFunc then
- local refNode = vm.getNode(ref)
- if refNode then
- vm.setNode(source, refNode)
- end
- end
- end
- end
+ vm.waitResolveRunner(source)
end)
: case 'setlocal'
: call(function (source)
diff --git a/script/vm/runner.lua b/script/vm/runner.lua
index 341126c6..13370ac7 100644
--- a/script/vm/runner.lua
+++ b/script/vm/runner.lua
@@ -345,7 +345,7 @@ function mt:lookIntoBlock(block, topNode)
return topNode
end
----@alias runner.info { source?: parser.object, loc: parser.object }
+---@alias runner.info { target?: parser.object, loc: parser.object }
---@type thread?
local masterRunner = nil
@@ -364,25 +364,55 @@ local runnerList = nil
---@param info runner.info
local function waitResolve(info)
while true do
- if not info.source then
+ if not info.target then
break
end
- if info.source.node == info.loc then
+ if info.target.node == info.loc then
break
end
- local node = vm.getNode(info.source)
+ local node = vm.getNode(info.target)
if node and node:getData('hasResolved') then
break
end
coroutine.yield()
end
- info.source = nil
+ info.target = nil
+end
+
+local function resolveDeadLock()
+ if not runnerList then
+ return
+ end
+
+ ---@type runner.info[]
+ local infos = {}
+ for runner in runnerList:pairs() do
+ local info = runnerInfo[runner]
+ infos[#infos+1] = info
+ end
+
+ table.sort(infos, function (a, b)
+ local uriA = guide.getUri(a.loc)
+ local uriB = guide.getUri(b.loc)
+ if uriA ~= uriB then
+ return uriA < uriB
+ end
+ return a.loc.start < b.loc.start
+ end)
+
+ local firstTarget = infos[1].target
+ ---@cast firstTarget -?
+ local firstNode = vm.setNode(firstTarget, vm.getNode(firstTarget):copy(), true)
+ firstNode:setData('hasResolved', true)
+ firstNode:setData('resolvedByDeadLock', true)
end
---@async
---@param loc parser.object
+---@param start fun()
+---@param finish fun()
---@param callback vm.runner.callback
-function vm.launchRunner(loc, callback)
+function vm.launchRunner(loc, start, finish, callback)
local locNode = vm.getNode(loc)
if not locNode then
return
@@ -393,10 +423,10 @@ function vm.launchRunner(loc, callback)
if not runnerList or runnerList:getSize() == 0 then
return
end
- local allWaiting = true
+ local deadLock = true
for runner in runnerList:pairs() do
local info = runnerInfo[runner]
- local waitingSource = info.source
+ local waitingSource = info.target
if coroutine.status(runner) == 'suspended' then
local suc, err = coroutine.resume(runner)
if not suc then
@@ -404,25 +434,33 @@ function vm.launchRunner(loc, callback)
end
else
runnerList:pop(runner)
+ deadLock = false
end
- if not waitingSource or waitingSource ~= info.source then
- allWaiting = false
+ if not waitingSource or waitingSource ~= info.target then
+ deadLock = false
end
end
if runnerList:getSize() == 0 then
return
end
- if allWaiting or i == 10000 then
+ if deadLock then
+ resolveDeadLock()
+ end
+ if i == 10000 then
local lines = {}
lines[#lines+1] = 'Dead lock:'
- lines[#lines+1] = guide.getUri(loc)
for runner in runnerList:pairs() do
local info = runnerInfo[runner]
- lines[#lines+1] = string.format('Runner `%s` at %d waiting for `%s` at %d'
- , loc[1]
- , loc.start
- , info.source and info.source[1] or ''
- , info.source and info.source.start or 0
+ lines[#lines+1] = '==============='
+ lines[#lines+1] = string.format('Runner `%s` at %d(%s)'
+ , info.loc[1]
+ , info.loc.start
+ , guide.getUri(info.loc)
+ )
+ lines[#lines+1] = string.format('Waiting `%s` at %d(%s)'
+ , info.target[1]
+ , info.target.start
+ , guide.getUri(info.target)
)
end
local msg = table.concat(lines, '\n')
@@ -432,8 +470,14 @@ function vm.launchRunner(loc, callback)
end
local function launch()
+ start()
+ if not loc.ref then
+ finish()
+ return
+ end
local main = guide.getParentBlock(loc)
if not main then
+ finish()
return
end
local self = setmetatable({
@@ -451,6 +495,8 @@ function vm.launchRunner(loc, callback)
self:lookIntoBlock(main, locNode:copy())
locNode:setData('runner', nil)
+
+ finish()
end
local co = coroutine.create(launch)
@@ -474,25 +520,35 @@ end
---@async
---@param source parser.object
function vm.waitResolveRunner(source)
- local running = coroutine.running()
- if not masterRunner or running == masterRunner then
+ local myNode = vm.getNode(source)
+ if myNode and myNode:getData('hasResolved') then
return
end
- local loc = source.node
- local locNode = vm.getNode(loc)
- if not locNode then
- return
- end
- local runner = locNode:getData('runner')
- if not runner or runner == running then
+ local running = coroutine.running()
+ if not masterRunner or running == masterRunner then
return
end
local info = runnerInfo[running]
- if info.loc == loc then
+
+ local targetLoc
+ if source.type == 'getlocal' then
+ targetLoc = source.node
+ elseif source.type == 'local'
+ or source.type == 'self' then
+ targetLoc = source
+ info.target = info.target or source
+ else
+ error('Unknown source type: ' .. source.type)
+ end
+
+ local targetNode = vm.getNode(targetLoc)
+ if not targetNode then
+ -- Wait for compiling local by `compiler`
return
end
+
waitResolve(info)
end
@@ -505,5 +561,5 @@ function vm.storeWaitingRunner(source)
local running = coroutine.running()
local info = runnerInfo[running]
- info.source = source
+ info.target = source
end
diff --git a/test/crossfile/references.lua b/test/crossfile/references.lua
index 22b13f42..0b7beb82 100644
--- a/test/crossfile/references.lua
+++ b/test/crossfile/references.lua
@@ -361,15 +361,15 @@ TEST {
{
path = 'a.lua',
content = [[
- local <~t~> = require 'b'
- return <!t!>
+ local <~x~> = require 'b'
+ return <!x!>
]]
},
{
path = 'b.lua',
content = [[
- local t = require 'a'
- return t
+ local y = require 'a'
+ return y
]]
},
}
diff --git a/test/tclient/tests/recursive-runner.lua b/test/tclient/tests/recursive-runner.lua
index 3e3b5bba..5ecae705 100644
--- a/test/tclient/tests/recursive-runner.lua
+++ b/test/tclient/tests/recursive-runner.lua
@@ -118,5 +118,47 @@ y = x
assert(hover6.contents.value:find 'number')
assert(hover7.contents.value:find 'number')
+ client:notify('textDocument/didOpen', {
+ textDocument = {
+ uri = 'file://test.lua',
+ languageId = 'lua',
+ version = 2,
+ text = [[
+---@meta
+
+---@class vector3
+---@operator add(vector3): vector3
+---@operator sub(vector3): vector3
+---@operator mul(vector3): number
+---@operator mul(number): vector3
+---@operator div(number): vector3
+local mt
+
+---@return vector3
+function mt:normalize() end
+
+---@param target vector3
+function Walk(target)
+ local moveSpeed = 1.0
+ local deltalTime = 2.0
+
+ ---@type vector3
+ local curPos
+ local targetDirVec = (target - curPos):normalize()
+ local stepMove = targetDirVec * (moveSpeed * deltalTime)
+ local nextPos = curPos + stepMove
+
+ curPos = nextPos
+end
+]]
+ }
+ })
+
+ local hover1 = client:awaitRequest('textDocument/hover', {
+ textDocument = { uri = 'file://test.lua' },
+ position = { line = 20, character = 11 },
+ })
+ assert(hover1.contents.value:find 'vector3')
+
config.set(nil, 'Lua.diagnostics.enable', true)
end)