summaryrefslogtreecommitdiff
path: root/server-beta/src
diff options
context:
space:
mode:
Diffstat (limited to 'server-beta/src')
-rw-r--r--server-beta/src/core/definition.lua80
-rw-r--r--server-beta/src/core/engineer.lua3
-rw-r--r--server-beta/src/files.lua5
-rw-r--r--server-beta/src/parser/guide.lua17
-rw-r--r--server-beta/src/workspace/workspace.lua41
5 files changed, 130 insertions, 16 deletions
diff --git a/server-beta/src/core/definition.lua b/server-beta/src/core/definition.lua
index e0589752..2fe726e0 100644
--- a/server-beta/src/core/definition.lua
+++ b/server-beta/src/core/definition.lua
@@ -1,23 +1,73 @@
-local guide = require 'parser.guide'
-local engineer = require 'core.engineer'
+local guide = require 'parser.guide'
+local engineer = require 'core.engineer'
+local workspace = require 'workspace'
+
+local function findDef(searcher, source, callback)
+ searcher:eachDef(source, function (src)
+ if src.type == 'setfield'
+ or src.type == 'getfield'
+ or src.type == 'tablefield' then
+ callback(src.field)
+ elseif src.type == 'setindex'
+ or src.type == 'getindex'
+ or src.type == 'tableindex' then
+ callback(src.index)
+ elseif src.type == 'getmethod'
+ or src.type == 'setmethod' then
+ callback(src.method)
+ else
+ callback(src)
+ end
+ end)
+end
+
+---@param searcher engineer
+local function checkRequire(searcher, source, offset, callback)
+ if source.type ~= 'call' then
+ return
+ end
+ local func = source.node
+ local pathSource = source.args and source.args[1]
+ if not pathSource then
+ return
+ end
+ if not guide.isContain(pathSource, offset) then
+ return
+ end
+ local literal = guide.getLiteral(pathSource)
+ if type(literal) ~= 'string' then
+ return
+ end
+ local name = searcher:getSpecialName(func)
+ if name == 'require' then
+ local result = workspace.findUrisByRequirePath(literal, true)
+ for _, uri in ipairs(result) do
+ callback(uri)
+ end
+ elseif name == 'dofile'
+ or name == 'loadfile' then
+ local result = workspace.findUrisByFilePath(literal, true)
+ for _, uri in ipairs(result) do
+ callback(uri)
+ end
+ end
+end
return function (ast, offset)
local results = {}
local searcher = engineer(ast)
guide.eachSourceContain(ast.ast, offset, function (source)
- searcher:eachDef(source, function (src)
- if src.type == 'setfield'
- or src.type == 'getfield'
- or src.type == 'tablefield' then
- src = src.field
- elseif src.type == 'setindex'
- or src.type == 'getindex'
- or src.type == 'tableindex' then
- src = src.index
- elseif src.type == 'getmethod'
- or src.type == 'setmethod' then
- src = src.method
- end
+ checkRequire(searcher, source, offset, function (uri)
+ results[#results+1] = {
+ uri = uri,
+ source = source,
+ target = {
+ start = 0,
+ finish = 0,
+ }
+ }
+ end)
+ findDef(searcher, source, function (src)
results[#results+1] = {
uri = ast.uri,
source = source,
diff --git a/server-beta/src/core/engineer.lua b/server-beta/src/core/engineer.lua
index 656c29bf..a52bddc2 100644
--- a/server-beta/src/core/engineer.lua
+++ b/server-beta/src/core/engineer.lua
@@ -39,6 +39,9 @@ local specials = {
['rawset'] = true,
['rawget'] = true,
['setmetatable'] = true,
+ ['require'] = true,
+ ['dofile'] = true,
+ ['loadfile'] = true,
}
function mt:getSpecialName(source)
diff --git a/server-beta/src/files.lua b/server-beta/src/files.lua
index 5de2a8fc..1873637d 100644
--- a/server-beta/src/files.lua
+++ b/server-beta/src/files.lua
@@ -96,6 +96,11 @@ function m.removeAll()
end
end
+--- 遍历文件
+function m.eachFile()
+ return pairs(m.fileMap)
+end
+
--- 获取文件语法树
---@param uri string
---@return table ast
diff --git a/server-beta/src/parser/guide.lua b/server-beta/src/parser/guide.lua
index a8f22c6e..af160abb 100644
--- a/server-beta/src/parser/guide.lua
+++ b/server-beta/src/parser/guide.lua
@@ -73,6 +73,19 @@ function m.isLiteral(obj)
or tp == 'table'
end
+--- 获取字面量
+function m.getLiteral(obj)
+ local tp = obj.type
+ if tp == 'boolean' then
+ return obj[1]
+ elseif tp == 'string' then
+ return obj[1]
+ elseif tp == 'number' then
+ return obj[1]
+ end
+ return nil
+end
+
--- 寻找所在函数
function m.getParentFunction(obj)
for _ = 1, 1000 do
@@ -224,7 +237,9 @@ function m.isContain(source, offset)
return source.start <= offset and source.finish >= offset - 1
end
---- 判断offset在source的范围内
+--- 判断offset在source的影响范围内
+---
+--- 主要针对赋值等语句时,key包含value
function m.isInRange(source, offset)
return source.start <= offset and (source.range or source.finish) >= offset - 1
end
diff --git a/server-beta/src/workspace/workspace.lua b/server-beta/src/workspace/workspace.lua
index 7f2998be..6ee1be42 100644
--- a/server-beta/src/workspace/workspace.lua
+++ b/server-beta/src/workspace/workspace.lua
@@ -132,4 +132,45 @@ function m.preload()
log.info('Preload finish.')
end
+--- 查找符合指定file path的所有uri
+---@param path string
+---@param whole boolean
+function m.findUrisByFilePath(path, whole)
+ local results = {}
+ for uri in files.eachFile() do
+ local pathLen = #path
+ local uriLen = #uri
+ for i = uriLen, uriLen - pathLen + 1, -1 do
+ local see = uri:sub(i - pathLen + 1, i)
+ if files.eq(see, path) then
+ results[#results+1] = uri
+ end
+ if not whole then
+ break
+ end
+ end
+ end
+ return results
+end
+
+--- 查找符合指定require path的所有uri
+---@param path string
+---@param whole boolean
+function m.findUrisByRequirePath(path, whole)
+ local results = {}
+ local mark = {}
+ local input = path:gsub('%.', '/')
+ for _, luapath in ipairs(config.config.runtime.path) do
+ local part = luapath:gsub('%?', input)
+ local uris = m.findUrisByFilePath(part, whole)
+ for _, uri in ipairs(uris) do
+ if not mark[uri] then
+ mark[uri] = true
+ results[#results+1] = uri
+ end
+ end
+ end
+ return results
+end
+
return m