summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--server/src/matcher/vm.lua204
-rw-r--r--server/src/service.lua56
-rw-r--r--server/test/crossfile/definition.lua10
-rw-r--r--server/test/crossfile/hover.lua28
-rw-r--r--server/test/hover/init.lua18
5 files changed, 146 insertions, 170 deletions
diff --git a/server/src/matcher/vm.lua b/server/src/matcher/vm.lua
index de55c12e..74a86d59 100644
--- a/server/src/matcher/vm.lua
+++ b/server/src/matcher/vm.lua
@@ -455,6 +455,77 @@ function mt:callSetMetaTable(func, values)
self:checkMetaIndex(values[1], values[2])
end
+function mt:getRequire(strValue, destVM)
+ -- 取出对方的主函数
+ local main = destVM.results.main
+ -- 获取主函数返回值,注意不能修改对方的环境
+ local mainValue
+ if main.returns then
+ mainValue = main.returns[1]
+ else
+ mainValue = self:createValue('boolean', {
+ type = 'name',
+ start = 0,
+ finish = 0,
+ [1] = '',
+ uri = destVM.uri,
+ }, true)
+ end
+
+ -- 支持 require 'xxx' 的转到定义
+ local strSource = strValue.source
+ self.results.sources[strSource] = strValue
+ strValue.uri = destVM.uri
+
+ return mainValue
+end
+
+function mt:getLoadFile(strValue, destVM)
+ -- 取出对方的主函数
+ local main = destVM.results.main
+ -- loadfile 的返回值就是对方的主函数
+ local mainValue = main
+
+ -- 支持 loadfile 'xxx.lua' 的转到定义
+ local strSource = strValue.source
+ self.results.sources[strSource] = strValue
+ strValue.uri = destVM.uri
+
+ return mainValue
+end
+
+function mt:tryRequireOne(strValue, mode)
+ if not self.lsp or not self.lsp.workspace then
+ return nil
+ end
+ local str = strValue.value
+ if type(str) == 'string' then
+ local uri
+ if mode == 'require' then
+ uri = self.lsp.workspace:searchPath(self.uri, str)
+ elseif mode == 'loadfile' then
+ uri = self.lsp.workspace:loadPath(self.uri, str)
+ elseif mode == 'dofile' then
+ uri = self.lsp.workspace:loadPath(self.uri, str)
+ end
+ -- 如果取不到VM(不编译),则做个标记,之后再取一次
+ local destVM = self.lsp:getVM(uri)
+ if destVM then
+ if mode == 'require' then
+ return self:getRequire(strValue, destVM)
+ elseif mode == 'loadfile' then
+ return self:getLoadFile(strValue, destVM)
+ elseif mode == 'dofile' then
+ return self:getRequire(strValue, destVM)
+ end
+ else
+ self.lsp:needCompile(uri)
+ self.lsp:needCompile(self.uri)
+ end
+ end
+ return nil
+end
+
function mt:callRequire(func, values)
if not values[1] then
values[1] = self:createValue('any')
@@ -469,12 +540,8 @@ function mt:callRequire(func, values)
self:setFunctionReturn(func, 1, value)
return
else
- local requireValue = self:createValue('boolean', nil, true)
+ local requireValue = self:tryRequireOne(values[1], 'require') or self:createValue('boolean')
self:setFunctionReturn(func, 1, requireValue)
- self.requires[requireValue] = {
- mode = 'require',
- str = values[1],
- }
end
end
@@ -486,12 +553,8 @@ function mt:callLoadFile(func, values)
if type(str) ~= 'string' then
return
end
- local requireValue = self:buildFunction()
+ local requireValue = self:tryRequireOne(values[1], 'loadfile')
self:setFunctionReturn(func, 1, requireValue)
- self.requires[requireValue] = {
- mode = 'loadfile',
- str = values[1],
- }
end
function mt:callDoFile(func, values)
@@ -502,12 +565,8 @@ function mt:callDoFile(func, values)
if type(str) ~= 'string' then
return
end
- local requireValue = self:createValue('any')
+ local requireValue = self:tryRequireOne(values[1], 'dofile')
self:setFunctionReturn(func, 1, requireValue)
- self.requires[requireValue] = {
- mode = 'dofile',
- str = values[1],
- }
end
function mt:call(func, values)
@@ -1242,117 +1301,6 @@ function mt:createEnvironment()
gValue.child = envValue.child
end
-function mt:mergeRequire(value, strValue, destVM)
- -- 取出对方的主函数
- local main = destVM.results.main
- -- 获取主函数返回值,注意不能修改对方的环境
- local mainValue
- if not main.returns then
- mainValue = self:createValue('nil', {
- type = 'name',
- start = 0,
- finish = 0,
- [1] = '',
- uri = destVM.uri,
- })
- else
- mainValue = main.returns[1]
- end
- self:mergeValue(value, mainValue)
-
- -- 支持 require 'xxx' 的转到定义
- local strSource = strValue.source
- self.results.sources[strSource] = strValue
- strValue.uri = destVM.uri
-end
-
-function mt:mergeLoadFile(value, strValue, destVM)
- -- 取出对方的主函数
- local main = destVM.results.main
- -- loadfile 的返回值就是对方的主函数
- local mainValue = main
- self:mergeValue(value, mainValue)
-
- -- 支持 loadfile 'xxx.lua' 的转到定义
- local strSource = strValue.source
- self.results.sources[strSource] = strValue
- strValue.uri = destVM.uri
-end
-
-function mt:loadRequires()
- if not self.lsp or not self.lsp.workspace then
- return
- end
- local copy = {}
- for k, v in pairs(self.requires) do
- self.requires[k] = nil
- copy[k] = v
- end
- for value, data in pairs(copy) do
- local strValue = data.str
- local mode = data.mode
- local str = strValue.value
- if type(str) == 'string' then
- local uri
- if mode == 'require' then
- uri = self.lsp.workspace:searchPath(self.uri, str)
- elseif mode == 'loadfile' then
- uri = self.lsp.workspace:loadPath(self.uri, str)
- elseif mode == 'dofile' then
- uri = self.lsp.workspace:loadPath(self.uri, str)
- elseif mode == '' then
- end
- -- 如果循环require,这里会返回nil
- -- 会当场编译VM
- local destVM = self.lsp:loadVM(uri)
- if destVM then
- if mode == 'require' then
- self:mergeRequire(value, strValue, destVM)
- elseif mode == 'loadfile' then
- self:mergeLoadFile(value, strValue, destVM)
- elseif mode == 'dofile' then
- self:mergeRequire(value, strValue, destVM)
- end
- end
- end
- end
-end
-
-function mt:tryLoadRequires()
- if not self.lsp or not self.lsp.workspace then
- return
- end
- for value, data in pairs(self.requires) do
- local strValue = data.str
- local mode = data.mode
- local str = strValue.value
- if type(str) == 'string' then
- local uri
- if mode == 'require' then
- uri = self.lsp.workspace:searchPath(self.uri, str)
- elseif mode == 'loadfile' then
- uri = self.lsp.workspace:loadPath(self.uri, str)
- elseif mode == 'dofile' then
- uri = self.lsp.workspace:loadPath(self.uri, str)
- end
- -- 如果取不到VM(不编译),则做个标记,之后再取一次
- local destVM = self.lsp:getVM(uri)
- if destVM then
- if mode == 'require' then
- self:mergeRequire(value, strValue, destVM)
- elseif mode == 'loadfile' then
- self:mergeLoadFile(value, strValue, destVM)
- elseif mode == 'dofile' then
- self:mergeRequire(value, strValue, destVM)
- end
- self.requires[value] = nil
- else
- self.lsp:needRequires(self.uri)
- end
- end
- end
-end
-
local function compile(ast, lsp, uri)
local vm = setmetatable({
scope = env {
@@ -1372,7 +1320,6 @@ local function compile(ast, lsp, uri)
},
libraryValue = {},
libraryChild = {},
- requires = {},
lsp = lsp,
uri = uri,
}, mt)
@@ -1383,9 +1330,6 @@ local function compile(ast, lsp, uri)
-- 执行代码
vm:doActions(ast)
- -- 合并
- vm:tryLoadRequires()
-
return vm
end
diff --git a/server/src/service.lua b/server/src/service.lua
index 12a6ed25..0c167570 100644
--- a/server/src/service.lua
+++ b/server/src/service.lua
@@ -131,13 +131,13 @@ function mt:clearDiagnostics(uri)
})
end
-function mt:_buildTextCache()
+function mt:compileAll()
if not next(self._needCompile) then
return
end
local list = {}
- for uri in pairs(self._needCompile) do
- list[#list+1] = uri
+ for i, uri in ipairs(self._needCompile) do
+ list[i] = uri
end
local size = 0
@@ -167,6 +167,14 @@ function mt:read(mode)
return self._input(mode)
end
+function mt:needCompile(uri)
+ if self._needCompile[uri] then
+ return
+ end
+ self._needCompile[uri] = true
+ self._needCompile[#self._needCompile+1] = uri
+end
+
function mt:saveText(uri, version, text)
local obj = self._file[uri]
if obj then
@@ -175,13 +183,13 @@ function mt:saveText(uri, version, text)
end
obj.version = version
obj.text = text
- self._needCompile[uri] = true
+ self:needCompile(uri)
else
self._file[uri] = {
version = version,
text = text,
}
- self._needCompile[uri] = true
+ self:needCompile(uri)
end
end
@@ -198,7 +206,7 @@ function mt:readText(uri, path)
version = -1,
text = text,
}
- self._needCompile[uri] = true
+ self:needCompile(uri)
end
function mt:open(uri)
@@ -215,7 +223,7 @@ end
function mt:reCompile()
for uri in pairs(self._opening) do
- self._needCompile[uri] = true
+ self:needCompile(uri)
end
end
@@ -237,6 +245,11 @@ function mt:compileVM(uri)
return nil
end
self._needCompile[uri] = nil
+ for i, u in ipairs(self._needCompile) do
+ if u == uri then
+ table.remove(self._needCompile, i)
+ end
+ end
local ast = parser:ast(obj.text)
obj.vm = matcher.vm(ast, self, uri)
@@ -261,31 +274,6 @@ function mt:getVM(uri)
return obj.vm
end
-function mt:needRequires(uri)
- self._needRequire[uri] = true
-end
-
-function mt:_loadRequires()
- if not self.workspace then
- return
- end
- if not next(self._needRequire) then
- return
- end
- local copy = {}
- for uri in pairs(self._needRequire) do
- self._needRequire[uri] = nil
- copy[uri] = true
- end
- for uri in pairs(copy) do
- local obj = self._file[uri]
- if obj then
- obj.vm:loadRequires()
- self._needDiagnostics[uri] = true
- end
- end
-end
-
function mt:removeText(uri)
self._file[uri] = nil
end
@@ -302,9 +290,8 @@ function mt:onTick()
rpc:recieve(proto)
end
end
- self:_buildTextCache()
+ self:compileAll()
self:_doDiagnostic()
- self:_loadRequires()
if os.clock() - self._clock >= 600 then
self._clock = os.clock()
@@ -344,7 +331,6 @@ return function ()
local session = setmetatable({
_file = {},
_needCompile = {},
- _needRequire = {},
_needDiagnostics = {},
_opening = {},
_clock = -100,
diff --git a/server/test/crossfile/definition.lua b/server/test/crossfile/definition.lua
index 43261771..11d5f746 100644
--- a/server/test/crossfile/definition.lua
+++ b/server/test/crossfile/definition.lua
@@ -38,14 +38,14 @@ function TEST(data)
local sourceScript, sourceList = catch_target(data[2].content, '?')
local sourceUri = ws:uriEncode(fs.path(data[2].path))
- lsp:saveText(targetUri, 1, targetScript)
lsp:saveText(sourceUri, 1, sourceScript)
- ws:addFile(targetUri)
ws:addFile(sourceUri)
- lsp:compileVM(targetUri)
- lsp:compileVM(sourceUri)
+ lsp:saveText(targetUri, 1, targetScript)
+ ws:addFile(targetUri)
+ lsp:compileAll()
+ lsp:compileAll()
- local sourceVM = lsp:loadVM(sourceUri)
+ local sourceVM = lsp:getVM(sourceUri)
assert(sourceVM)
local sourcePos = (sourceList[1][1] + sourceList[1][2]) // 2
local positions = matcher.definition(sourceVM, sourcePos)
diff --git a/server/test/crossfile/hover.lua b/server/test/crossfile/hover.lua
index 8bcc9411..51ac85bd 100644
--- a/server/test/crossfile/hover.lua
+++ b/server/test/crossfile/hover.lua
@@ -135,3 +135,31 @@ TEST {
label = 'function (a: any, b: any)',
}
}
+
+TEST {
+ {
+ path = 'a.lua',
+ content = [[
+ local mt = {}
+ mt.__index = mt
+
+ function mt:add(a, b)
+ end
+
+ return function ()
+ return setmetatable({}, mt)
+ end
+ ]],
+ },
+ {
+ path = 'b.lua',
+ content = [[
+ local m = require 'a'
+ local obj = m()
+ obj:<?add?>()
+ ]]
+ },
+ hover = {
+ label = 'function mt:add(a: any, b: any)'
+ },
+}
diff --git a/server/test/hover/init.lua b/server/test/hover/init.lua
index 5d0cb068..445e72ba 100644
--- a/server/test/hover/init.lua
+++ b/server/test/hover/init.lua
@@ -198,3 +198,21 @@ end
function x()
-> any
]]
+
+TEST [[
+local mt = {}
+mt.__index = mt
+
+function mt:add(a, b)
+end
+
+local function init()
+ return setmetatable({}, mt)
+end
+
+local t = init()
+t:<?add?>()
+]]
+[[
+function mt:add(a: any, b: any)
+]]