summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2021-02-26 15:15:05 +0800
committer最萌小汐 <sumneko@hotmail.com>2021-02-26 15:15:05 +0800
commitcbc8e42f84fdb8468d2f31797fc8a2d6344c53d4 (patch)
treec8a65bb1760fd97babcfccfadf9f3cf07b661d63
parent3227cfa5ea12cbc471cb656606379854eb12f375 (diff)
downloadlua-language-server-cbc8e42f84fdb8468d2f31797fc8a2d6344c53d4.zip
improve generic across `fun(arg: T):T`
-rw-r--r--script/parser/guide.lua43
-rw-r--r--script/parser/luadoc.lua3
-rw-r--r--test/type_inference/init.lua8
3 files changed, 40 insertions, 14 deletions
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index c657ef4b..91acb457 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -1687,26 +1687,43 @@ local function stepRefOfGeneric(status, typeUnit, args, mode)
if typeName == typeUnit then
goto CONTINUE
end
- local doc = m.getDocState(typeName)
- if doc.type ~= 'doc.param' then
+ local docArg = m.getParentType(typeName, 'doc.type.arg')
+ or m.getParentType(typeName, 'doc.param')
+ if not docArg then
goto CONTINUE
end
+ local doc = m.getDocState(docArg)
if not doc.bindSources then
goto CONTINUE
end
- local crossTable = stepRefOfGenericCrossTable(status, doc, typeName)
- local paramName = doc.param[1]
- for _, source in ipairs(doc.bindSources) do
- if source.type == 'local'
- and source[1] == paramName
- and source.parent.type == 'funcargs' then
- for index, arg in ipairs(source.parent) do
- if arg == source then
- appendValidGenericType(results, status, typeName, crossTable(args[index]))
+ local crossTable = stepRefOfGenericCrossTable(status, docArg, typeName)
+
+ -- find out param index
+ local genericIndex
+ if docArg.type == 'doc.param' then
+ local paramName = docArg.param[1]
+ for _, source in ipairs(doc.bindSources) do
+ if source.type == 'local'
+ and source[1] == paramName
+ and source.parent.type == 'funcargs' then
+ for index, arg in ipairs(source.parent) do
+ if arg == source then
+ genericIndex = index
+ break
+ end
end
end
end
+ elseif docArg.type == 'doc.type.arg' then
+ for index, arg in ipairs(docArg.parent.args) do
+ if arg == docArg then
+ genericIndex = index
+ break
+ end
+ end
end
+
+ appendValidGenericType(results, status, typeName, crossTable(args[genericIndex]))
::CONTINUE::
end
return results
@@ -2037,13 +2054,13 @@ local function checkSameSimpleAndMergeFunctionReturnsByDoc(status, results, sour
return true
end
-local function checkSameSimpleAndMergeDocFunctionReturn(status, results, docFunc, index)
+local function checkSameSimpleAndMergeDocFunctionReturn(status, results, docFunc, index, args)
if docFunc.type ~= 'doc.type.function' then
return
end
local rtn = docFunc.returns[index]
if rtn then
- local types = m.checkSameSimpleByDocType(status, rtn)
+ local types = m.checkSameSimpleByDocType(status, rtn, args)
if types then
for _, res in ipairs(types) do
results[#results+1] = res
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index 13816943..f30d47b4 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -998,7 +998,8 @@ local function bindGeneric(binded)
generics[name] = {}
end
elseif doc.type == 'doc.param'
- or doc.type == 'doc.return' then
+ or doc.type == 'doc.return'
+ or doc.type == 'doc.type' then
guide.eachSourceType(doc, 'doc.type.name', function (src)
local name = src[1]
if generics[name] then
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 4eee872f..883385f8 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -427,6 +427,14 @@ local t
local k, <?v?> = next(t)
]]
+TEST 'boolean' [[
+---@generic K
+---@type fun(arg: K):K
+local f
+
+local <?r?> = f(true)
+]]
+
TEST 'string' [[
---@generic T: table, K, V
---@param t T