summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2022-10-11 21:08:12 +0800
committer最萌小汐 <sumneko@hotmail.com>2022-10-11 21:08:12 +0800
commit896a1affdeed8304688611674df88140a5cc180e (patch)
treeb9336e9c8d4ea52cf4b6f1aa54d341e1edad305f
parente1912470260bac17e7efc565cb59719beb810d01 (diff)
downloadlua-language-server-896a1affdeed8304688611674df88140a5cc180e.zip
infer definitions and types across chain exp
resolve #1561
-rw-r--r--changelog.md13
-rw-r--r--script/vm/compiler.lua79
-rw-r--r--script/vm/def.lua125
-rw-r--r--test/definition/luadoc.lua24
-rw-r--r--test/type_inference/init.lua12
5 files changed, 129 insertions, 124 deletions
diff --git a/changelog.md b/changelog.md
index 92967c83..d3dec956 100644
--- a/changelog.md
+++ b/changelog.md
@@ -17,6 +17,18 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`.
}
```
* `CHG` [#1177] re-support for symlinks, users need to maintain the correctness of symlinks themselves
+* `CHG` [#1561] infer definitions and types across chain expression
+ ```lua
+ ---@class myClass
+ local myClass = {}
+
+ myClass.a.b.c.e.f.g = 1
+
+ ---@type myClass
+ local class
+
+ print(class.a.b.c.e.f.g) --> infered as integer
+ ```
* `FIX` [#1567]
* `FIX` [#1593]
* `FIX` [#1606]
@@ -26,6 +38,7 @@ server will generate `doc.json` and `doc.md` in `LOGPATH`.
[#1458]: https://github.com/sumneko/lua-language-server/issues/1458
[#1557]: https://github.com/sumneko/lua-language-server/issues/1557
[#1558]: https://github.com/sumneko/lua-language-server/issues/1558
+[#1561]: https://github.com/sumneko/lua-language-server/issues/1561
[#1567]: https://github.com/sumneko/lua-language-server/issues/1567
[#1593]: https://github.com/sumneko/lua-language-server/issues/1593
[#1606]: https://github.com/sumneko/lua-language-server/issues/1606
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 05ad0634..61838775 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -1818,6 +1818,84 @@ local function compileByGlobal(source)
end
end
+local nodeSwitch;nodeSwitch = util.switch()
+ : case 'field'
+ : case 'method'
+ : call(function (source, lastKey, pushResult)
+ return nodeSwitch(source.parent.type, source.parent, lastKey, pushResult)
+ end)
+ : case 'getfield'
+ : case 'setfield'
+ : case 'getmethod'
+ : case 'setmethod'
+ : case 'getindex'
+ : case 'setindex'
+ : call(function (source, lastKey, pushResult)
+ local parentNode = vm.compileNode(source.node)
+ local uri = guide.getUri(source)
+ local key = guide.getKeyName(source)
+ if type(key) ~= 'string' then
+ return
+ end
+ if lastKey then
+ key = key .. vm.ID_SPLITE .. lastKey
+ end
+ for pn in parentNode:eachObject() do
+ searchFieldSwitch(pn.type, uri, pn, key, false, pushResult)
+ end
+ return key, source.node
+ end)
+ : case 'tableindex'
+ : case 'tablefield'
+ : call(function (source, lastKey, pushResult)
+ if lastKey then
+ return
+ end
+ local key = guide.getKeyName(source)
+ if type(key) ~= 'string' then
+ return
+ end
+ local uri = guide.getUri(source)
+ local parentNode = vm.compileNode(source.node)
+ for pn in parentNode:eachObject() do
+ searchFieldSwitch(pn.type, uri, pn, key, false, pushResult)
+ end
+ end)
+ : case 'doc.see.field'
+ : call(function (source, lastKey, pushResult)
+ if lastKey then
+ return
+ end
+ local parentNode = vm.compileNode(source.parent.name)
+ local uri = guide.getUri(source)
+ for pn in parentNode:eachObject() do
+ searchFieldSwitch(pn.type, uri, pn, source[1], false, pushResult)
+ end
+ end)
+
+function vm.compileByNodeChain(source, pushResult)
+ local lastKey
+ local src = source
+ while true do
+ local key, node = nodeSwitch(src.type, src, lastKey, pushResult)
+ if not key then
+ break
+ end
+ src = node
+ lastKey = key
+ end
+end
+
+---@param source vm.object
+local function compileByParentNode(source)
+ if not vm.getNode(source):isEmpty() then
+ return
+ end
+ vm.compileByNodeChain(source, function (result)
+ vm.setNode(source, vm.compileNode(result))
+ end)
+end
+
---@param source vm.object
---@return vm.node
function vm.compileNode(source)
@@ -1846,6 +1924,7 @@ function vm.compileNode(source)
LOCK[source] = true
compileByGlobal(source)
compileByNode(source)
+ compileByParentNode(source)
matchCall(source)
LOCK[source] = nil
diff --git a/script/vm/def.lua b/script/vm/def.lua
index f557f221..7ce8ad7a 100644
--- a/script/vm/def.lua
+++ b/script/vm/def.lua
@@ -20,110 +20,6 @@ simpleSwitch = util.switch()
end
end)
-local searchFieldSwitch = util.switch()
- : case 'table'
- : call(function (suri, obj, key, pushResult)
- for _, field in ipairs(obj) do
- if field.type == 'tablefield'
- or field.type == 'tableindex' then
- if guide.getKeyName(field) == key then
- pushResult(field)
- end
- end
- end
- end)
- : case 'global'
- ---@param obj vm.global
- ---@param key string
- : call(function (suri, obj, key, pushResult)
- if obj.cate == 'variable' then
- local newGlobal = vm.getGlobal('variable', obj.name, key)
- if newGlobal then
- for _, set in ipairs(newGlobal:getSets(suri)) do
- pushResult(set)
- end
- end
- end
- if obj.cate == 'type' then
- vm.getClassFields(suri, obj, key, false, pushResult)
- end
- end)
- : case 'local'
- : call(function (suri, obj, key, pushResult)
- local sources = vm.getLocalSourcesSets(obj, key)
- if sources then
- for _, src in ipairs(sources) do
- pushResult(src)
- end
- end
- end)
- : case 'doc.type.table'
- : call(function (suri, obj, key, pushResult)
- for _, field in ipairs(obj.fields) do
- local fieldKey = field.name
- if fieldKey.type == 'doc.field.name' then
- if fieldKey[1] == key then
- pushResult(field)
- end
- end
- end
- end)
-
-local nodeSwitch;nodeSwitch = util.switch()
- : case 'field'
- : case 'method'
- : call(function (source, lastKey, pushResult)
- return nodeSwitch(source.parent.type, source.parent, lastKey, pushResult)
- end)
- : case 'getfield'
- : case 'setfield'
- : case 'getmethod'
- : case 'setmethod'
- : case 'getindex'
- : case 'setindex'
- : call(function (source, lastKey, pushResult)
- local parentNode = vm.compileNode(source.node)
- local uri = guide.getUri(source)
- local key = guide.getKeyName(source)
- if type(key) ~= 'string' then
- return
- end
- if lastKey then
- key = key .. vm.ID_SPLITE .. lastKey
- end
- for pn in parentNode:eachObject() do
- searchFieldSwitch(pn.type, uri, pn, key, pushResult)
- end
- return key, source.node
- end)
- : case 'tableindex'
- : case 'tablefield'
- : call(function (source, lastKey, pushResult)
- if lastKey then
- return
- end
- local key = guide.getKeyName(source)
- if type(key) ~= 'string' then
- return
- end
- local uri = guide.getUri(source)
- local parentNode = vm.compileNode(source.node)
- for pn in parentNode:eachObject() do
- searchFieldSwitch(pn.type, uri, pn, key, pushResult)
- end
- end)
- : case 'doc.see.field'
- : call(function (source, lastKey, pushResult)
- if lastKey then
- return
- end
- local parentNode = vm.compileNode(source.parent.name)
- local uri = guide.getUri(source)
- for pn in parentNode:eachObject() do
- searchFieldSwitch(pn.type, uri, pn, source[1], pushResult)
- end
- end)
-
---@param source parser.object
---@param pushResult fun(src: parser.object)
local function searchBySimple(source, pushResult)
@@ -142,25 +38,6 @@ local function searchByLocalID(source, pushResult)
end
end
----@param source parser.object
----@param pushResult fun(src: parser.object)
-local function searchByParentNode(source, pushResult)
- local lastKey
- local src = source
- while true do
- local key, node = nodeSwitch(src.type, src, lastKey, pushResult)
- if not key then
- break
- end
- src = node
- if lastKey then
- lastKey = key .. vm.ID_SPLITE .. lastKey
- else
- lastKey = key
- end
- end
-end
-
local function searchByNode(source, pushResult)
local node = vm.compileNode(source)
local suri = guide.getUri(source)
@@ -200,7 +77,7 @@ function vm.getDefs(source)
searchBySimple(source, pushResult)
searchByLocalID(source, pushResult)
- searchByParentNode(source, pushResult)
+ vm.compileByNodeChain(source, pushResult)
searchByNode(source, pushResult)
return results
diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua
index 2da10f93..c14e1242 100644
--- a/test/definition/luadoc.lua
+++ b/test/definition/luadoc.lua
@@ -932,4 +932,28 @@ local <!b!>
b.<!<?__index?>!> = b
]]
+TEST [[
+---@class myClass
+local myClass = { nested = {} }
+
+function myClass.nested.<!fn!>() end
+
+---@type myClass
+local class
+
+class.nested.<?fn?>()
+]]
+
+TEST [[
+---@class myClass
+local myClass = { has = { nested = {} } }
+
+function myClass.has.nested.<!fn!>() end
+
+---@type myClass
+local class
+
+class.has.nested.<?fn?>()
+]]
+
config.set(nil, 'Lua.type.castNumberToInteger', true)
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 6f74e1af..f0948b49 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -3791,3 +3791,15 @@ TEST 'A|B' [[
---@type A|B
local <?t?>
]]
+
+TEST 'function' [[
+---@class myClass
+local myClass = { has = { nested = {} } }
+
+function myClass.has.nested.fn() end
+
+---@type myClass
+local class
+
+class.has.nested.<?fn?>()
+]]