diff options
-rw-r--r-- | changelog.md | 13 | ||||
-rw-r--r-- | script/vm/compiler.lua | 79 | ||||
-rw-r--r-- | script/vm/def.lua | 125 | ||||
-rw-r--r-- | test/definition/luadoc.lua | 24 | ||||
-rw-r--r-- | test/type_inference/init.lua | 12 |
5 files changed, 129 insertions, 124 deletions
diff --git a/changelog.md b/changelog.md index 92967c83..d3dec956 100644 --- a/changelog.md +++ b/changelog.md @@ -17,6 +17,18 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`. } ``` * `CHG` [#1177] re-support for symlinks, users need to maintain the correctness of symlinks themselves +* `CHG` [#1561] infer definitions and types across chain expression + ```lua + ---@class myClass + local myClass = {} + + myClass.a.b.c.e.f.g = 1 + + ---@type myClass + local class + + print(class.a.b.c.e.f.g) --> infered as integer + ``` * `FIX` [#1567] * `FIX` [#1593] * `FIX` [#1606] @@ -26,6 +38,7 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`. [#1458]: https://github.com/sumneko/lua-language-server/issues/1458 [#1557]: https://github.com/sumneko/lua-language-server/issues/1557 [#1558]: https://github.com/sumneko/lua-language-server/issues/1558 +[#1561]: https://github.com/sumneko/lua-language-server/issues/1561 [#1567]: https://github.com/sumneko/lua-language-server/issues/1567 [#1593]: https://github.com/sumneko/lua-language-server/issues/1593 [#1606]: https://github.com/sumneko/lua-language-server/issues/1606 diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 05ad0634..61838775 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1818,6 +1818,84 @@ local function compileByGlobal(source) end end +local nodeSwitch;nodeSwitch = util.switch() + : case 'field' + : case 'method' + : call(function (source, lastKey, pushResult) + return nodeSwitch(source.parent.type, source.parent, lastKey, pushResult) + end) + : case 'getfield' + : case 'setfield' + : case 'getmethod' + : case 'setmethod' + : case 'getindex' + : case 'setindex' + : call(function (source, lastKey, pushResult) + local parentNode = vm.compileNode(source.node) + local uri = guide.getUri(source) + local key = guide.getKeyName(source) + if type(key) ~= 'string' then + return + end + if lastKey then + key = key .. vm.ID_SPLITE .. lastKey + end + for pn in parentNode:eachObject() do + searchFieldSwitch(pn.type, uri, pn, key, false, pushResult) + end + return key, source.node + end) + : case 'tableindex' + : case 'tablefield' + : call(function (source, lastKey, pushResult) + if lastKey then + return + end + local key = guide.getKeyName(source) + if type(key) ~= 'string' then + return + end + local uri = guide.getUri(source) + local parentNode = vm.compileNode(source.node) + for pn in parentNode:eachObject() do + searchFieldSwitch(pn.type, uri, pn, key, false, pushResult) + end + end) + : case 'doc.see.field' + : call(function (source, lastKey, pushResult) + if lastKey then + return + end + local parentNode = vm.compileNode(source.parent.name) + local uri = guide.getUri(source) + for pn in parentNode:eachObject() do + searchFieldSwitch(pn.type, uri, pn, source[1], false, pushResult) + end + end) + +function vm.compileByNodeChain(source, pushResult) + local lastKey + local src = source + while true do + local key, node = nodeSwitch(src.type, src, lastKey, pushResult) + if not key then + break + end + src = node + lastKey = key + end +end + +---@param source vm.object +local function compileByParentNode(source) + if not vm.getNode(source):isEmpty() then + return + end + vm.compileByNodeChain(source, function (result) + vm.setNode(source, vm.compileNode(result)) + end) +end + ---@param source vm.object ---@return vm.node function vm.compileNode(source) @@ -1846,6 +1924,7 @@ function vm.compileNode(source) LOCK[source] = true compileByGlobal(source) compileByNode(source) + compileByParentNode(source) matchCall(source) LOCK[source] = nil diff --git a/script/vm/def.lua b/script/vm/def.lua index f557f221..7ce8ad7a 100644 --- a/script/vm/def.lua +++ b/script/vm/def.lua @@ -20,110 +20,6 @@ simpleSwitch = util.switch() end end) -local searchFieldSwitch = util.switch() - : case 'table' - : call(function (suri, obj, key, pushResult) - for _, field in ipairs(obj) do - if field.type == 'tablefield' - or field.type == 'tableindex' then - if guide.getKeyName(field) == key then - pushResult(field) - end - end - end - end) - : case 'global' - ---@param obj vm.global - ---@param key string - : call(function (suri, obj, key, pushResult) - if obj.cate == 'variable' then - local newGlobal = vm.getGlobal('variable', obj.name, key) - if newGlobal then - for _, set in ipairs(newGlobal:getSets(suri)) do - pushResult(set) - end - end - end - if obj.cate == 'type' then - vm.getClassFields(suri, obj, key, false, pushResult) - end - end) - : case 'local' - : call(function (suri, obj, key, pushResult) - local sources = vm.getLocalSourcesSets(obj, key) - if sources then - for _, src in ipairs(sources) do - pushResult(src) - end - end - end) - : case 'doc.type.table' - : call(function (suri, obj, key, pushResult) - for _, field in ipairs(obj.fields) do - local fieldKey = field.name - if fieldKey.type == 'doc.field.name' then - if fieldKey[1] == key then - pushResult(field) - end - end - end - end) - -local nodeSwitch;nodeSwitch = util.switch() - : case 'field' - : case 'method' - : call(function (source, lastKey, pushResult) - return nodeSwitch(source.parent.type, source.parent, lastKey, pushResult) - end) - : case 'getfield' - : case 'setfield' - : case 'getmethod' - : case 'setmethod' - : case 'getindex' - : case 'setindex' - : call(function (source, lastKey, pushResult) - local parentNode = vm.compileNode(source.node) - local uri = guide.getUri(source) - local key = guide.getKeyName(source) - if type(key) ~= 'string' then - return - end - if lastKey then - key = key .. vm.ID_SPLITE .. lastKey - end - for pn in parentNode:eachObject() do - searchFieldSwitch(pn.type, uri, pn, key, pushResult) - end - return key, source.node - end) - : case 'tableindex' - : case 'tablefield' - : call(function (source, lastKey, pushResult) - if lastKey then - return - end - local key = guide.getKeyName(source) - if type(key) ~= 'string' then - return - end - local uri = guide.getUri(source) - local parentNode = vm.compileNode(source.node) - for pn in parentNode:eachObject() do - searchFieldSwitch(pn.type, uri, pn, key, pushResult) - end - end) - : case 'doc.see.field' - : call(function (source, lastKey, pushResult) - if lastKey then - return - end - local parentNode = vm.compileNode(source.parent.name) - local uri = guide.getUri(source) - for pn in parentNode:eachObject() do - searchFieldSwitch(pn.type, uri, pn, source[1], pushResult) - end - end) - ---@param source parser.object ---@param pushResult fun(src: parser.object) local function searchBySimple(source, pushResult) @@ -142,25 +38,6 @@ local function searchByLocalID(source, pushResult) end end ----@param source parser.object ----@param pushResult fun(src: parser.object) -local function searchByParentNode(source, pushResult) - local lastKey - local src = source - while true do - local key, node = nodeSwitch(src.type, src, lastKey, pushResult) - if not key then - break - end - src = node - if lastKey then - lastKey = key .. vm.ID_SPLITE .. lastKey - else - lastKey = key - end - end -end - local function searchByNode(source, pushResult) local node = vm.compileNode(source) local suri = guide.getUri(source) @@ -200,7 +77,7 @@ function vm.getDefs(source) searchBySimple(source, pushResult) searchByLocalID(source, pushResult) - searchByParentNode(source, pushResult) + vm.compileByNodeChain(source, pushResult) searchByNode(source, pushResult) return results diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua index 2da10f93..c14e1242 100644 --- a/test/definition/luadoc.lua +++ b/test/definition/luadoc.lua @@ -932,4 +932,28 @@ local <!b!> b.<!<?__index?>!> = b ]] +TEST [[ +---@class myClass +local myClass = { nested = {} } + +function myClass.nested.<!fn!>() end + +---@type myClass +local class + +class.nested.<?fn?>() +]] + +TEST [[ +---@class myClass +local myClass = { has = { nested = {} } } + +function myClass.has.nested.<!fn!>() end + +---@type myClass +local class + +class.has.nested.<?fn?>() +]] + config.set(nil, 'Lua.type.castNumberToInteger', true) diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 6f74e1af..f0948b49 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -3791,3 +3791,15 @@ TEST 'A|B' [[ ---@type A|B local <?t?> ]] + +TEST 'function' [[ +---@class myClass +local myClass = { has = { nested = {} } } + +function myClass.has.nested.fn() end + +---@type myClass +local class + +class.has.nested.<?fn?>() +]] |