diff options
-rw-r--r-- | script/vm/compiler.lua | 64 | ||||
-rw-r--r-- | script/vm/runner.lua | 112 | ||||
-rw-r--r-- | test/crossfile/references.lua | 8 | ||||
-rw-r--r-- | test/tclient/tests/recursive-runner.lua | 42 |
4 files changed, 169 insertions, 57 deletions
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index ef0f7157..70a46677 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -976,7 +976,7 @@ end ---@param source parser.object local function compileLocal(source) - vm.setNode(source, source) + local myNode = vm.setNode(source, source) local hasMarkDoc if source.bindDocs then @@ -988,7 +988,7 @@ local function compileLocal(source) if selfNode then hasMarkParam = true vm.setNode(source, vm.compileNode(selfNode)) - vm.getNode(source):remove 'function' + myNode:remove 'function' end end local hasMarkValue @@ -1062,7 +1062,7 @@ local function compileLocal(source) end end - vm.getNode(source):setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue) + myNode:setData('hasDefined', hasMarkDoc or hasMarkParam or hasMarkValue) end ---@param source parser.object @@ -1192,15 +1192,40 @@ local compilerSwitch = util.switch() ---@async ---@param source parser.object : call(function (source) - compileLocal(source) - local refs = source.ref - if not refs then - return - end - - local hasMark = vm.getNode(source):getData 'hasDefined' - - vm.launchRunner(source, function (src, node) + vm.launchRunner(source, function () + local myNode = vm.getNode(source) + ---@cast myNode -? + myNode:setData('resolving', true) + + if source.ref then + for _, ref in ipairs(source.ref) do + if ref.type == 'getlocal' + or ref.type == 'setlocal' then + vm.setNode(ref, myNode, true) + end + end + end + compileLocal(source) + + myNode:setData('hasResolved', true) + end, function () + local myNode = vm.getNode(source) + ---@cast myNode -? + myNode:setData('resolving', nil) + local hasMark = vm.getNode(source):getData 'hasDefined' + if source.ref and not hasMark then + local parentFunc = guide.getParentFunction(source) + for _, ref in ipairs(source.ref) do + if ref.type == 'setlocal' + and guide.getParentFunction(ref) == parentFunc then + local refNode = vm.getNode(ref) + if refNode then + vm.setNode(source, refNode) + end + end + end + end + end, function (src, node) if src.type == 'setlocal' then if src.bindDocs then for _, doc in ipairs(src.bindDocs) do @@ -1212,7 +1237,7 @@ local compilerSwitch = util.switch() end if src.value then if src.value.type == 'table' then - vm.setNode(src, vm.createNode(src.value)) + vm.setNode(src, vm.createNode(src.value), true) vm.setNode(src, node:copy():asTable()) else vm.setNode(src, vm.compileNode(src.value), true) @@ -1231,18 +1256,7 @@ local compilerSwitch = util.switch() end end) - if not hasMark then - local parentFunc = guide.getParentFunction(source) - for _, ref in ipairs(source.ref) do - if ref.type == 'setlocal' - and guide.getParentFunction(ref) == parentFunc then - local refNode = vm.getNode(ref) - if refNode then - vm.setNode(source, refNode) - end - end - end - end + vm.waitResolveRunner(source) end) : case 'setlocal' : call(function (source) diff --git a/script/vm/runner.lua b/script/vm/runner.lua index 341126c6..13370ac7 100644 --- a/script/vm/runner.lua +++ b/script/vm/runner.lua @@ -345,7 +345,7 @@ function mt:lookIntoBlock(block, topNode) return topNode end ----@alias runner.info { source?: parser.object, loc: parser.object } +---@alias runner.info { target?: parser.object, loc: parser.object } ---@type thread? local masterRunner = nil @@ -364,25 +364,55 @@ local runnerList = nil ---@param info runner.info local function waitResolve(info) while true do - if not info.source then + if not info.target then break end - if info.source.node == info.loc then + if info.target.node == info.loc then break end - local node = vm.getNode(info.source) + local node = vm.getNode(info.target) if node and node:getData('hasResolved') then break end coroutine.yield() end - info.source = nil + info.target = nil +end + +local function resolveDeadLock() + if not runnerList then + return + end + + ---@type runner.info[] + local infos = {} + for runner in runnerList:pairs() do + local info = runnerInfo[runner] + infos[#infos+1] = info + end + + table.sort(infos, function (a, b) + local uriA = guide.getUri(a.loc) + local uriB = guide.getUri(b.loc) + if uriA ~= uriB then + return uriA < uriB + end + return a.loc.start < b.loc.start + end) + + local firstTarget = infos[1].target + ---@cast firstTarget -? + local firstNode = vm.setNode(firstTarget, vm.getNode(firstTarget):copy(), true) + firstNode:setData('hasResolved', true) + firstNode:setData('resolvedByDeadLock', true) end ---@async ---@param loc parser.object +---@param start fun() +---@param finish fun() ---@param callback vm.runner.callback -function vm.launchRunner(loc, callback) +function vm.launchRunner(loc, start, finish, callback) local locNode = vm.getNode(loc) if not locNode then return @@ -393,10 +423,10 @@ function vm.launchRunner(loc, callback) if not runnerList or runnerList:getSize() == 0 then return end - local allWaiting = true + local deadLock = true for runner in runnerList:pairs() do local info = runnerInfo[runner] - local waitingSource = info.source + local waitingSource = info.target if coroutine.status(runner) == 'suspended' then local suc, err = coroutine.resume(runner) if not suc then @@ -404,25 +434,33 @@ function vm.launchRunner(loc, callback) end else runnerList:pop(runner) + deadLock = false end - if not waitingSource or waitingSource ~= info.source then - allWaiting = false + if not waitingSource or waitingSource ~= info.target then + deadLock = false end end if runnerList:getSize() == 0 then return end - if allWaiting or i == 10000 then + if deadLock then + resolveDeadLock() + end + if i == 10000 then local lines = {} lines[#lines+1] = 'Dead lock:' - lines[#lines+1] = guide.getUri(loc) for runner in runnerList:pairs() do local info = runnerInfo[runner] - lines[#lines+1] = string.format('Runner `%s` at %d waiting for `%s` at %d' - , loc[1] - , loc.start - , info.source and info.source[1] or '' - , info.source and info.source.start or 0 + lines[#lines+1] = '===============' + lines[#lines+1] = string.format('Runner `%s` at %d(%s)' + , info.loc[1] + , info.loc.start + , guide.getUri(info.loc) + ) + lines[#lines+1] = string.format('Waiting `%s` at %d(%s)' + , info.target[1] + , info.target.start + , guide.getUri(info.target) ) end local msg = table.concat(lines, '\n') @@ -432,8 +470,14 @@ function vm.launchRunner(loc, callback) end local function launch() + start() + if not loc.ref then + finish() + return + end local main = guide.getParentBlock(loc) if not main then + finish() return end local self = setmetatable({ @@ -451,6 +495,8 @@ function vm.launchRunner(loc, callback) self:lookIntoBlock(main, locNode:copy()) locNode:setData('runner', nil) + + finish() end local co = coroutine.create(launch) @@ -474,25 +520,35 @@ end ---@async ---@param source parser.object function vm.waitResolveRunner(source) - local running = coroutine.running() - if not masterRunner or running == masterRunner then + local myNode = vm.getNode(source) + if myNode and myNode:getData('hasResolved') then return end - local loc = source.node - local locNode = vm.getNode(loc) - if not locNode then - return - end - local runner = locNode:getData('runner') - if not runner or runner == running then + local running = coroutine.running() + if not masterRunner or running == masterRunner then return end local info = runnerInfo[running] - if info.loc == loc then + + local targetLoc + if source.type == 'getlocal' then + targetLoc = source.node + elseif source.type == 'local' + or source.type == 'self' then + targetLoc = source + info.target = info.target or source + else + error('Unknown source type: ' .. source.type) + end + + local targetNode = vm.getNode(targetLoc) + if not targetNode then + -- Wait for compiling local by `compiler` return end + waitResolve(info) end @@ -505,5 +561,5 @@ function vm.storeWaitingRunner(source) local running = coroutine.running() local info = runnerInfo[running] - info.source = source + info.target = source end diff --git a/test/crossfile/references.lua b/test/crossfile/references.lua index 22b13f42..0b7beb82 100644 --- a/test/crossfile/references.lua +++ b/test/crossfile/references.lua @@ -361,15 +361,15 @@ TEST { { path = 'a.lua', content = [[ - local <~t~> = require 'b' - return <!t!> + local <~x~> = require 'b' + return <!x!> ]] }, { path = 'b.lua', content = [[ - local t = require 'a' - return t + local y = require 'a' + return y ]] }, } diff --git a/test/tclient/tests/recursive-runner.lua b/test/tclient/tests/recursive-runner.lua index 3e3b5bba..5ecae705 100644 --- a/test/tclient/tests/recursive-runner.lua +++ b/test/tclient/tests/recursive-runner.lua @@ -118,5 +118,47 @@ y = x assert(hover6.contents.value:find 'number') assert(hover7.contents.value:find 'number') + client:notify('textDocument/didOpen', { + textDocument = { + uri = 'file://test.lua', + languageId = 'lua', + version = 2, + text = [[ +---@meta + +---@class vector3 +---@operator add(vector3): vector3 +---@operator sub(vector3): vector3 +---@operator mul(vector3): number +---@operator mul(number): vector3 +---@operator div(number): vector3 +local mt + +---@return vector3 +function mt:normalize() end + +---@param target vector3 +function Walk(target) + local moveSpeed = 1.0 + local deltalTime = 2.0 + + ---@type vector3 + local curPos + local targetDirVec = (target - curPos):normalize() + local stepMove = targetDirVec * (moveSpeed * deltalTime) + local nextPos = curPos + stepMove + + curPos = nextPos +end +]] + } + }) + + local hover1 = client:awaitRequest('textDocument/hover', { + textDocument = { uri = 'file://test.lua' }, + position = { line = 20, character = 11 }, + }) + assert(hover1.contents.value:find 'vector3') + config.set(nil, 'Lua.diagnostics.enable', true) end) |