diff options
-rw-r--r-- | script/core/diagnostics/undefined-doc-name.lua | 2 | ||||
-rw-r--r-- | script/core/semantic-tokens.lua | 7 | ||||
-rw-r--r-- | script/parser/guide.lua | 24 | ||||
-rw-r--r-- | script/vm/compiler.lua | 23 | ||||
-rw-r--r-- | script/vm/global.lua | 3 | ||||
-rw-r--r-- | test/type_inference/init.lua | 38 |
6 files changed, 85 insertions, 12 deletions
diff --git a/script/core/diagnostics/undefined-doc-name.lua b/script/core/diagnostics/undefined-doc-name.lua index bacd4288..3c8ed469 100644 --- a/script/core/diagnostics/undefined-doc-name.lua +++ b/script/core/diagnostics/undefined-doc-name.lua @@ -32,7 +32,7 @@ return function (uri, callback) return end local name = source[1] - if name == '...' or name == '_' then + if name == '...' or name == '_' or name == 'self' then return end if #vm.getDocSets(uri, name) > 0 diff --git a/script/core/semantic-tokens.lua b/script/core/semantic-tokens.lua index 14d6ddc8..7b9adad9 100644 --- a/script/core/semantic-tokens.lua +++ b/script/core/semantic-tokens.lua @@ -437,6 +437,13 @@ local Care = util.switch() type = define.TokenTypes.type, modifieres = define.TokenModifiers.modification, } + elseif source[1] == 'self' then + results[#results+1] = { + start = source.start, + finish = source.finish, + type = define.TokenTypes.type, + modifieres = define.TokenModifiers.readonly, + } else results[#results+1] = { start = source.start, diff --git a/script/parser/guide.lua b/script/parser/guide.lua index b22c55f0..f7dcc116 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -1307,7 +1307,6 @@ function m.isBlockType(source) return blockTypes[source.type] == true end - ---@param source parser.object ---@return parser.object? function m.getSelfNode(source) @@ -1331,18 +1330,23 @@ function m.getSelfNode(source) return getmethod.node end if args.type == 'funcargs' then - local func = args.parent - if func.type ~= 'function' then - return nil - end - local setmethod = func.parent - if setmethod.type ~= 'setmethod' then - return nil - end - return setmethod.node + return m.getFunctionSelfNode(args.parent) end return nil end +---@param func parser.object +---@return parser.object? +function m.getFunctionSelfNode(func) + if func.type ~= 'function' then + return nil + end + local parent = func.parent + if parent.type == 'setmethod' + or parent.type == 'setfield' then + return parent.node + end + return nil +end return m diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index dcd32e77..2e481edf 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1564,6 +1564,29 @@ local compilerSwitch = util.switch() : call(function (source) vm.setNode(source, source) end) + : case 'doc.type.name' + : call(function (source) + if source[1] == 'self' then + local state = guide.getDocState(source) + if state.type == 'doc.return' + or state.type == 'doc.param' then + local func = state.bindSource + if func.type == 'function' then + local node = guide.getFunctionSelfNode(func) + if node then + vm.setNode(source, vm.compileNode(node)) + return + end + end + elseif state.type == 'doc.field' then + local class = state.class + if class then + vm.setNode(source, vm.compileNode(class)) + return + end + end + end + end) : case 'doc.generic.name' : call(function (source) vm.setNode(source, source) diff --git a/script/vm/global.lua b/script/vm/global.lua index e94f9239..027b096a 100644 --- a/script/vm/global.lua +++ b/script/vm/global.lua @@ -404,6 +404,9 @@ local compilerGlobalSwitch = util.switch() if name == '_' then return end + if name == 'self' then + return + end local type = vm.declareGlobal('type', name, uri) type:addGet(uri, source) source._globalNode = type diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 7f8e8e4f..a1ee533d 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -21,7 +21,8 @@ local function getSource(pos) or source.type == 'field' or source.type == 'method' or source.type == 'function' - or source.type == 'table' then + or source.type == 'table' + or source.type == 'doc.type.name' then result = source end end) @@ -3891,3 +3892,38 @@ print(<?t?>) ]] config.set(nil, 'Lua.runtime.special', nil) + +TEST 'A' [[ +---@class A +local mt + +---@return <?self?> +function mt:init() +end +]] + +TEST 'A' [[ +---@class A +local mt + +---@return self +function mt:init() +end + +local <?o?> = mt:init() +]] + +TEST 'A' [[ +---@class A +---@field x <?self?> +]] + +TEST 'A' [[ +---@class A +---@field x self + +---@type A +local o + +print(o.<?x?>) +]] |