summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/implementation.lua167
-rw-r--r--script/core/type-definition.lua3
-rw-r--r--script/provider/provider.lua108
-rw-r--r--script/vm/function.lua2
-rw-r--r--test.lua1
-rw-r--r--test/implementation/init.lua81
6 files changed, 311 insertions, 51 deletions
diff --git a/script/core/implementation.lua b/script/core/implementation.lua
new file mode 100644
index 00000000..da4d4c8b
--- /dev/null
+++ b/script/core/implementation.lua
@@ -0,0 +1,167 @@
+local workspace = require 'workspace'
+local files = require 'files'
+local vm = require 'vm'
+local findSource = require 'core.find-source'
+local guide = require 'parser.guide'
+local rpath = require 'workspace.require-path'
+local jumpSource = require 'core.jump-source'
+
+local function sortResults(results)
+ -- 先按照顺序排序
+ table.sort(results, function (a, b)
+ local u1 = guide.getUri(a.target)
+ local u2 = guide.getUri(b.target)
+ if u1 == u2 then
+ return a.target.start < b.target.start
+ else
+ return u1 < u2
+ end
+ end)
+ -- 如果2个结果处于嵌套状态,则取范围小的那个
+ local lf, lu
+ for i = #results, 1, -1 do
+ local res = results[i].target
+ local f = res.finish
+ local uri = guide.getUri(res)
+ if lf and f > lf and uri == lu then
+ table.remove(results, i)
+ else
+ lu = uri
+ lf = f
+ end
+ end
+end
+
+local accept = {
+ ['local'] = true,
+ ['setlocal'] = true,
+ ['getlocal'] = true,
+ ['label'] = true,
+ ['goto'] = true,
+ ['field'] = true,
+ ['method'] = true,
+ ['setglobal'] = true,
+ ['getglobal'] = true,
+ ['string'] = true,
+ ['boolean'] = true,
+ ['number'] = true,
+ ['integer'] = true,
+ ['...'] = true,
+
+ ['doc.type.name'] = true,
+ ['doc.class.name'] = true,
+ ['doc.extends.name'] = true,
+ ['doc.alias.name'] = true,
+ ['doc.cast.name'] = true,
+ ['doc.enum.name'] = true,
+ ['doc.field.name'] = true,
+}
+
+local function convertIndex(source)
+ if not source then
+ return
+ end
+ if source.type == 'string'
+ or source.type == 'boolean'
+ or source.type == 'number'
+ or source.type == 'integer' then
+ local parent = source.parent
+ if not parent then
+ return
+ end
+ if parent.type == 'setindex'
+ or parent.type == 'getindex'
+ or parent.type == 'tableindex' then
+ return parent
+ end
+ end
+ return source
+end
+
+return function (uri, offset)
+ local ast = files.getState(uri)
+ if not ast then
+ return nil
+ end
+
+ local source = convertIndex(findSource(ast, offset, accept))
+ if not source then
+ return nil
+ end
+
+ local results = {}
+
+ local defs = vm.getDefs(source)
+
+ for _, src in ipairs(defs) do
+ if src.type == 'global' then
+ goto CONTINUE
+ end
+ local root = guide.getRoot(src)
+ if not root then
+ goto CONTINUE
+ end
+ if src.type == 'self' then
+ goto CONTINUE
+ end
+ src = src.field or src.method or src
+ if src.type == 'getindex'
+ or src.type == 'setindex'
+ or src.type == 'tableindex' then
+ src = src.index
+ if not src then
+ goto CONTINUE
+ end
+ if not guide.isLiteral(src) then
+ goto CONTINUE
+ end
+ end
+ if src.type == 'doc.type.function'
+ or src.type == 'doc.type.table'
+ or src.type == 'doc.type.boolean'
+ or src.type == 'doc.type.integer'
+ or src.type == 'doc.type.string' then
+ goto CONTINUE
+ end
+ if src.type == 'doc.class' then
+ goto CONTINUE
+ end
+ if src.type == 'doc.alias' then
+ goto CONTINUE
+ end
+ if src.type == 'doc.enum' then
+ goto CONTINUE
+ end
+ if src.type == 'doc.type.field' then
+ goto CONTINUE
+ end
+ if src.type == 'doc.class.name'
+ or src.type == 'doc.alias.name'
+ or src.type == 'doc.enum.name'
+ or src.type == 'doc.field.name' then
+ goto CONTINUE
+ end
+ if src.type == 'doc.generic.name' then
+ goto CONTINUE
+ end
+ if src.type == 'doc.param' then
+ goto CONTINUE
+ end
+
+ results[#results+1] = {
+ target = src,
+ uri = root.uri,
+ source = source,
+ }
+ ::CONTINUE::
+ end
+
+ if #results == 0 then
+ return nil
+ end
+
+ sortResults(results)
+ jumpSource(results)
+
+ return results
+end
diff --git a/script/core/type-definition.lua b/script/core/type-definition.lua
index 0a821f25..d9939eb0 100644
--- a/script/core/type-definition.lua
+++ b/script/core/type-definition.lua
@@ -52,8 +52,9 @@ local accept = {
['doc.class.name'] = true,
['doc.extends.name'] = true,
['doc.alias.name'] = true,
+ ['doc.cast.name'] = true,
['doc.enum.name'] = true,
- ['doc.see.name'] = true,
+ ['doc.field.name'] = true,
}
local function checkRequire(source, offset)
diff --git a/script/provider/provider.lua b/script/provider/provider.lua
index 69fb3263..15e78b9a 100644
--- a/script/provider/provider.lua
+++ b/script/provider/provider.lua
@@ -368,6 +368,39 @@ m.register 'textDocument/hover' {
end
}
+local function convertDefinitionResult(state, result)
+ local response = {}
+ for i, info in ipairs(result) do
+ ---@type uri
+ local targetUri = info.uri
+ if targetUri then
+ local targetState = files.getState(targetUri)
+ if targetState then
+ if client.getAbility 'textDocument.definition.linkSupport' then
+ response[i] = converter.locationLink(targetUri
+ , converter.packRange(targetState, info.target.start, info.target.finish)
+ , converter.packRange(targetState, info.target.start, info.target.finish)
+ , converter.packRange(state, info.source.start, info.source.finish)
+ )
+ else
+ response[i] = converter.location(targetUri
+ , converter.packRange(targetState, info.target.start, info.target.finish)
+ )
+ end
+ else
+ response[i] = converter.location(
+ targetUri,
+ converter.range(
+ converter.position(guide.rowColOf(info.target.start)),
+ converter.position(guide.rowColOf(info.target.finish))
+ )
+ )
+ end
+ end
+ end
+ return response
+end
+
m.register 'textDocument/definition' {
capability = {
definitionProvider = true,
@@ -388,35 +421,7 @@ m.register 'textDocument/definition' {
if not result then
return nil
end
- local response = {}
- for i, info in ipairs(result) do
- ---@type uri
- local targetUri = info.uri
- if targetUri then
- local targetState = files.getState(targetUri)
- if targetState then
- if client.getAbility 'textDocument.definition.linkSupport' then
- response[i] = converter.locationLink(targetUri
- , converter.packRange(targetState, info.target.start, info.target.finish)
- , converter.packRange(targetState, info.target.start, info.target.finish)
- , converter.packRange(state, info.source.start, info.source.finish)
- )
- else
- response[i] = converter.location(targetUri
- , converter.packRange(targetState, info.target.start, info.target.finish)
- )
- end
- else
- response[i] = converter.location(
- targetUri,
- converter.range(
- converter.position(guide.rowColOf(info.target.start)),
- converter.position(guide.rowColOf(info.target.finish))
- )
- )
- end
- end
- end
+ local response = convertDefinitionResult(state, result)
return response
end
}
@@ -441,27 +446,32 @@ m.register 'textDocument/typeDefinition' {
if not result then
return nil
end
- local response = {}
- for i, info in ipairs(result) do
- ---@type uri
- local targetUri = info.uri
- if targetUri then
- local targetState = files.getState(targetUri)
- if targetState then
- if client.getAbility 'textDocument.typeDefinition.linkSupport' then
- response[i] = converter.locationLink(targetUri
- , converter.packRange(targetState, info.target.start, info.target.finish)
- , converter.packRange(targetState, info.target.start, info.target.finish)
- , converter.packRange(state, info.source.start, info.source.finish)
- )
- else
- response[i] = converter.location(targetUri
- , converter.packRange(targetState, info.target.start, info.target.finish)
- )
- end
- end
- end
+ local response = convertDefinitionResult(state, result)
+ return response
+ end
+}
+
+m.register 'textDocument/implementation' {
+ capability = {
+ implementationProvider = true,
+ },
+ abortByFileUpdate = true,
+ ---@async
+ function (params)
+ local uri = files.getRealUri(params.textDocument.uri)
+ workspace.awaitReady(uri)
+ local _ <close> = progress.create(uri, lang.script.WINDOW_PROCESSING_TYPE_DEFINITION, 0.5)
+ local state = files.getState(uri)
+ if not state then
+ return
+ end
+ local core = require 'core.implementation'
+ local pos = converter.unpackPosition(state, params.position)
+ local result = core(uri, pos)
+ if not result then
+ return nil
end
+ local response = convertDefinitionResult(state, result)
return response
end
}
diff --git a/script/vm/function.lua b/script/vm/function.lua
index c6df6349..dde8ecb2 100644
--- a/script/vm/function.lua
+++ b/script/vm/function.lua
@@ -337,7 +337,7 @@ function vm.getMatchedFunctions(func, args, mark)
local funcs = {}
local node = vm.compileNode(func)
for n in node:eachObject() do
- if (n.type == 'function' and not vm.isVarargFunctionWithOverloads(n))
+ if n.type == 'function'
or n.type == 'doc.type.function' then
funcs[#funcs+1] = n
end
diff --git a/test.lua b/test.lua
index 6e3a4290..aad95270 100644
--- a/test.lua
+++ b/test.lua
@@ -55,6 +55,7 @@ local function testAll()
test 'basic'
test 'definition'
test 'type_inference'
+ test 'implementation'
test 'references'
test 'hover'
test 'completion'
diff --git a/test/implementation/init.lua b/test/implementation/init.lua
new file mode 100644
index 00000000..678cb23b
--- /dev/null
+++ b/test/implementation/init.lua
@@ -0,0 +1,81 @@
+local core = require 'core.implementation'
+local files = require 'files'
+local vm = require 'vm'
+local catch = require 'catch'
+
+rawset(_G, 'TEST', true)
+
+local function founded(targets, results)
+ if #targets ~= #results then
+ return false
+ end
+ for _, target in ipairs(targets) do
+ for _, result in ipairs(results) do
+ if target[1] == result[1] and target[2] == result[2] then
+ goto NEXT
+ end
+ end
+ do return false end
+ ::NEXT::
+ end
+ return true
+end
+
+---@async
+function TEST(script)
+ local newScript, catched = catch(script, '!?')
+
+ files.setText(TESTURI, newScript)
+
+ local results = core(TESTURI, catched['?'][1][1])
+ if results then
+ local positions = {}
+ for i, result in ipairs(results) do
+ if not vm.isMetaFile(result.uri) then
+ positions[#positions+1] = { result.target.start, result.target.finish }
+ end
+ end
+ assert(founded(catched['!'], positions))
+ else assert(#catched['!'] == 0)
+ end
+
+ files.remove(TESTURI)
+end
+
+TEST [[
+---@class A
+---@field x number
+local M
+
+M.<!x!> = 1
+
+
+print(M.<?x?>)
+]]
+
+TEST [[
+---@class A
+---@field f fun()
+local M
+
+function M.<!f!>() end
+
+
+print(M.<?f?>)
+]]
+
+TEST [[
+---@class A
+local M
+
+function M:<!event!>(name) end
+
+---@class A
+---@field event fun(self, name: 'ev1')
+---@field event fun(self, name: 'ev2')
+
+---@type A
+local m
+
+m:<?event?>('ev1')
+]]