summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/diagnostics/undefined-doc-name.lua2
-rw-r--r--script/core/semantic-tokens.lua7
-rw-r--r--script/parser/guide.lua24
-rw-r--r--script/vm/compiler.lua23
-rw-r--r--script/vm/global.lua3
-rw-r--r--test/type_inference/init.lua38
6 files changed, 85 insertions, 12 deletions
diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua
index bacd4288..3c8ed469 100644
--- a/script/core/diagnostics/undefined-doc-name.lua
+++ b/script/core/diagnostics/undefined-doc-name.lua
@@ -32,7 +32,7 @@ return function (uri, callback)
return
end
local name = source[1]
- if name == '...' or name == '_' then
+ if name == '...' or name == '_' or name == 'self' then
return
end
if #vm.getDocSets(uri, name) > 0
diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua
index 14d6ddc8..7b9adad9 100644
--- a/script/core/semantic-tokens.lua
+++ b/script/core/semantic-tokens.lua
@@ -437,6 +437,13 @@ local Care = util.switch()
type = define.TokenTypes.type,
modifieres = define.TokenModifiers.modification,
}
+ elseif source[1] == 'self' then
+ results[#results+1] = {
+ start = source.start,
+ finish = source.finish,
+ type = define.TokenTypes.type,
+ modifieres = define.TokenModifiers.readonly,
+ }
else
results[#results+1] = {
start = source.start,
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index b22c55f0..f7dcc116 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -1307,7 +1307,6 @@ function m.isBlockType(source)
return blockTypes[source.type] == true
end
-
---@param source parser.object
---@return parser.object?
function m.getSelfNode(source)
@@ -1331,18 +1330,23 @@ function m.getSelfNode(source)
return getmethod.node
end
if args.type == 'funcargs' then
- local func = args.parent
- if func.type ~= 'function' then
- return nil
- end
- local setmethod = func.parent
- if setmethod.type ~= 'setmethod' then
- return nil
- end
- return setmethod.node
+ return m.getFunctionSelfNode(args.parent)
end
return nil
end
+---@param func parser.object
+---@return parser.object?
+function m.getFunctionSelfNode(func)
+ if func.type ~= 'function' then
+ return nil
+ end
+ local parent = func.parent
+ if parent.type == 'setmethod'
+ or parent.type == 'setfield' then
+ return parent.node
+ end
+ return nil
+end
return m
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index dcd32e77..2e481edf 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -1564,6 +1564,29 @@ local compilerSwitch = util.switch()
: call(function (source)
vm.setNode(source, source)
end)
+ : case 'doc.type.name'
+ : call(function (source)
+ if source[1] == 'self' then
+ local state = guide.getDocState(source)
+ if state.type == 'doc.return'
+ or state.type == 'doc.param' then
+ local func = state.bindSource
+ if func.type == 'function' then
+ local node = guide.getFunctionSelfNode(func)
+ if node then
+ vm.setNode(source, vm.compileNode(node))
+ return
+ end
+ end
+ elseif state.type == 'doc.field' then
+ local class = state.class
+ if class then
+ vm.setNode(source, vm.compileNode(class))
+ return
+ end
+ end
+ end
+ end)
: case 'doc.generic.name'
: call(function (source)
vm.setNode(source, source)
diff --git a/script/vm/global.lua b/script/vm/global.lua
index e94f9239..027b096a 100644
--- a/script/vm/global.lua
+++ b/script/vm/global.lua
@@ -404,6 +404,9 @@ local compilerGlobalSwitch = util.switch()
if name == '_' then
return
end
+ if name == 'self' then
+ return
+ end
local type = vm.declareGlobal('type', name, uri)
type:addGet(uri, source)
source._globalNode = type
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 7f8e8e4f..a1ee533d 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -21,7 +21,8 @@ local function getSource(pos)
or source.type == 'field'
or source.type == 'method'
or source.type == 'function'
- or source.type == 'table' then
+ or source.type == 'table'
+ or source.type == 'doc.type.name' then
result = source
end
end)
@@ -3891,3 +3892,38 @@ print(<?t?>)
]]
config.set(nil, 'Lua.runtime.special', nil)
+
+TEST 'A' [[
+---@class A
+local mt
+
+---@return <?self?>
+function mt:init()
+end
+]]
+
+TEST 'A' [[
+---@class A
+local mt
+
+---@return self
+function mt:init()
+end
+
+local <?o?> = mt:init()
+]]
+
+TEST 'A' [[
+---@class A
+---@field x <?self?>
+]]
+
+TEST 'A' [[
+---@class A
+---@field x self
+
+---@type A
+local o
+
+print(o.<?x?>)
+]]