summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/generic.lua27
-rw-r--r--script/core/noder.lua111
-rw-r--r--script/parser/luadoc.lua32
-rw-r--r--test/type_inference/init.lua7
4 files changed, 90 insertions, 87 deletions
diff --git a/script/core/generic.lua b/script/core/generic.lua
index 5b52a304..92e97362 100644
--- a/script/core/generic.lua
+++ b/script/core/generic.lua
@@ -195,15 +195,29 @@ local function buildValues(closure)
local protoFunction = closure.proto
local upvalues = closure.upvalues
local params = closure.call.args
+ local args = protoFunction.args
+ local paramMap = {}
+ if params then
+ for i, param in ipairs(params) do
+ local arg = args and args[i]
+ if arg then
+ if arg.type == 'local' then
+ paramMap[arg[1]] = param
+ elseif arg.type == 'doc.type.arg' then
+ paramMap[arg.name[1]] = param
+ end
+ end
+ end
+ end
if protoFunction.type == 'function' then
for _, doc in ipairs(protoFunction.bindDocs) do
if doc.type == 'doc.param' then
+ local name = doc.param[1]
local extends = doc.extends
- local index = extends and extends.paramIndex
- if index then
- local param = params and params[index]
- closure.params[index] = param and createValue(closure, extends, function (road, key, proto)
+ if name and extends then
+ local param = paramMap[name]
+ closure.params[name] = param and createValue(closure, extends, function (road, key, proto)
buildValue(road, key, proto, param, upvalues)
end) or extends
end
@@ -219,9 +233,10 @@ local function buildValues(closure)
end
if protoFunction.type == 'doc.type.function' then
for index, arg in ipairs(protoFunction.args) do
+ local name = arg.name[1]
local extends = arg.extends
- local param = params and params[index]
- closure.params[index] = param and createValue(closure, extends, function (road, key, proto)
+ local param = paramMap[name]
+ closure.params[name] = param and createValue(closure, extends, function (road, key, proto)
buildValue(road, key, proto, param, upvalues)
end) or extends
end
diff --git a/script/core/noder.lua b/script/core/noder.lua
index 814936ad..545ba729 100644
--- a/script/core/noder.lua
+++ b/script/core/noder.lua
@@ -26,6 +26,7 @@ local ANY_FIELD_CHAR = '*'
local INDEX_CHAR = '['
local RETURN_INDEX = SPLIT_CHAR .. '#'
local PARAM_INDEX = SPLIT_CHAR .. '&'
+local PARAM_NAME = SPLIT_CHAR .. '$'
local TABLE_KEY = SPLIT_CHAR .. '<'
local WEAK_TABLE_KEY = SPLIT_CHAR .. '<<'
local STRING_FIELD = SPLIT_CHAR .. STRING_CHAR
@@ -857,7 +858,6 @@ compileNodeMap = util.switch()
: call(function (noders, id, source)
pushForward(noders, id, 'dn:nil')
end)
- -- self -> mt:xx
: case 'local'
: call(function (noders, id, source)
if source[1] ~= 'self' then
@@ -1007,6 +1007,16 @@ compileNodeMap = util.switch()
pushForward(noders, getID(src), id)
end
end
+ if source.bindSources then
+ for _, src in ipairs(source.bindSources) do
+ local paramID = sformat('%s%s%s'
+ , getID(src)
+ , PARAM_NAME
+ , source.param[1]
+ )
+ pushForward(noders, paramID, id)
+ end
+ end
end)
: case 'doc.vararg'
: call(function (noders, id, source)
@@ -1054,8 +1064,8 @@ compileNodeMap = util.switch()
for index, param in ipairs(source.args) do
local paramID = sformat('%s%s%s'
, id
- , PARAM_INDEX
- , index
+ , PARAM_NAME
+ , param.name[1]
)
pushForward(noders, paramID, getID(param.extends))
end
@@ -1130,40 +1140,12 @@ compileNodeMap = util.switch()
end)
: case 'function'
: call(function (noders, id, source)
- local hasDocReturn = {}
+ local hasDocReturn
-- 检查 luadoc
if source.bindDocs then
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.return' then
- for _, rtn in ipairs(doc.returns) do
- local fullID = sformat('%s%s%s'
- , id
- , RETURN_INDEX
- , rtn.returnIndex
- )
- pushForward(noders, fullID, getID(rtn))
- for _, typeUnit in ipairs(rtn.types) do
- pushBackward(noders, getID(typeUnit), fullID, INFO_DEEP_AND_DONT_CROSS)
- end
- hasDocReturn[rtn.returnIndex] = true
- end
- end
- if doc.type == 'doc.param' then
- local paramName = doc.param[1]
- if source.docParamMap then
- local paramIndex = source.docParamMap[paramName]
- local param = source.args[paramIndex]
- if param then
- pushForward(noders, getID(param), getID(doc))
- param.docParam = doc
- local paramID = sformat('%s%s%s'
- , id
- , PARAM_INDEX
- , paramIndex
- )
- pushForward(noders, paramID, getID(doc.extends))
- end
- end
+ hasDocReturn = true
end
if doc.type == 'doc.vararg' then
if source.args then
@@ -1182,26 +1164,32 @@ compileNodeMap = util.switch()
end
end
end
+ if source.args then
+ for _, arg in ipairs(source.args) do
+ if arg.type == 'local' then
+ pushForward(noders, getID(arg), sformat('%s%s%s'
+ , id
+ , PARAM_NAME
+ , arg[1]
+ ))
+ else
+ pushForward(noders, getID(arg), sformat('%s%s%s'
+ , id
+ , PARAM_NAME
+ , '...'
+ ))
+ end
+ end
+ end
-- 检查实体返回值
- if source.returns then
- local returns = {}
+ if source.returns and not hasDocReturn then
for _, rtn in ipairs(source.returns) do
for index, rtnObj in ipairs(rtn) do
- if not hasDocReturn[index] then
- if not returns[index] then
- returns[index] = {}
- end
- returns[index][#returns[index]+1] = rtnObj
- end
- end
- end
- for index, rtnObjs in ipairs(returns) do
- local returnID = sformat('%s%s%s'
- , id
- , RETURN_INDEX
- , index
- )
- for _, rtnObj in ipairs(rtnObjs) do
+ local returnID = sformat('%s%s%s'
+ , id
+ , RETURN_INDEX
+ , index
+ )
pushForward(noders, returnID, getID(rtnObj))
pushBackward(noders, getID(rtnObj), returnID, INFO_DEEP_AND_DONT_CROSS)
end
@@ -1296,6 +1284,25 @@ compileNodeMap = util.switch()
end
end
end)
+ : case 'doc.return'
+ : call(function (noders, id, source)
+ if not source.bindSources then
+ return
+ end
+ for _, rtn in ipairs(source.returns) do
+ for _, src in ipairs(source.bindSources) do
+ local fullID = sformat('%s%s%s'
+ , getID(src)
+ , RETURN_INDEX
+ , rtn.returnIndex
+ )
+ pushForward(noders, fullID, getID(rtn))
+ for _, typeUnit in ipairs(rtn.types) do
+ pushBackward(noders, getID(typeUnit), fullID, INFO_DEEP_AND_DONT_CROSS)
+ end
+ end
+ end
+ end)
: case 'generic.closure'
: call(function (noders, id, source)
for i, rtn in ipairs(source.returns) do
@@ -1436,7 +1443,7 @@ function m.hasField(id)
end
local next2Char = ssub(id, #firstID + 2, #firstID + 2)
if next2Char == RETURN_INDEX
- or next2Char == PARAM_INDEX then
+ or next2Char == PARAM_NAME then
return false
end
return true
@@ -1460,7 +1467,7 @@ function m.isCommonField(field)
if ssub(field, 1, #RETURN_INDEX) == RETURN_INDEX then
return false
end
- if ssub(field, 1, #PARAM_INDEX) == PARAM_INDEX then
+ if ssub(field, 1, #PARAM_NAME) == PARAM_NAME then
return false
end
return true
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index 02b23d16..37447328 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -1265,36 +1265,10 @@ local function bindDocsBetween(sources, binded, bindSources, start, finish)
end
end
-local function bindParamAndReturnIndex(binded)
- local func
- for _, source in ipairs(binded[1].bindSources) do
- if source.type == 'function' then
- func = source
- break
- end
- end
- if not func then
- return
- end
- local paramMap
- if func.args then
- local paramIndex = 0
- paramMap = {}
- for _, param in ipairs(func.args) do
- paramIndex = paramIndex + 1
- if param[1] then
- paramMap[param[1]] = paramIndex
- end
- end
- func.docParamMap = paramMap
- end
+local function bindReturnIndex(binded)
local returnIndex = 0
for _, doc in ipairs(binded) do
- if doc.type == 'doc.param' then
- if paramMap and doc.extends then
- doc.extends.paramIndex = paramMap[doc.param[1]]
- end
- elseif doc.type == 'doc.return' then
+ if doc.type == 'doc.return' then
for _, rtn in ipairs(doc.returns) do
returnIndex = returnIndex + 1
rtn.returnIndex = returnIndex
@@ -1340,7 +1314,7 @@ local function bindDoc(sources, binded)
if #bindSources == 0 then
bindDocsBetween(sources, binded, bindSources, guide.positionOf(row + 1, 0), guide.positionOf(row + 2, 0))
end
- bindParamAndReturnIndex(binded)
+ bindReturnIndex(binded)
bindClassAndFields(binded)
end
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 13efbc59..e0c7b5c8 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -871,3 +871,10 @@ for _, a in ipairs(v) do
end
end
]]
+
+--TEST 'number' [[
+-----@param x number
+--local f
+--
+--f = function (<?x?>) end
+--]]