Skip to content

Commit

Permalink
[red-knot] Infer target types for unpacked tuple assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruvmanila committed Sep 10, 2024
1 parent b7cef6c commit 9ea561f
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 18 deletions.
32 changes: 22 additions & 10 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,10 @@ where
debug_assert!(self.current_assignment.is_none());
self.visit_expr(&node.value);
self.add_standalone_expression(&node.value);
self.current_assignment = Some(node.into());
self.current_assignment = Some(CurrentAssignment::Assign {
node,
target_index: 0,
});
for target in &node.targets {
self.visit_expr(target);
}
Expand Down Expand Up @@ -699,12 +702,13 @@ where
let symbol = self.add_or_update_symbol(id.clone(), flags);
if flags.contains(SymbolFlags::IS_DEFINED) {
match self.current_assignment {
Some(CurrentAssignment::Assign(assignment)) => {
Some(CurrentAssignment::Assign { node, target_index }) => {
self.add_definition(
symbol,
AssignmentDefinitionNodeRef {
assignment,
assignment: node,
target: name_node,
target_index,
},
);
}
Expand Down Expand Up @@ -851,6 +855,17 @@ where
self.visit_expr(key);
self.visit_expr(value);
}
ast::Expr::Tuple(ast::ExprTuple { elts, ctx, .. }) => {
for (index, element) in elts.iter().enumerate() {
if let Some(CurrentAssignment::Assign { target_index, .. }) =
self.current_assignment.as_mut()
{
*target_index = index;
}
self.visit_expr(element);
}
self.visit_expr_context(ctx);
}
_ => {
walk_expr(self, expr);
}
Expand Down Expand Up @@ -957,7 +972,10 @@ where

#[derive(Copy, Clone, Debug)]
enum CurrentAssignment<'a> {
Assign(&'a ast::StmtAssign),
Assign {
node: &'a ast::StmtAssign,
target_index: usize,
},
AnnAssign(&'a ast::StmtAnnAssign),
AugAssign(&'a ast::StmtAugAssign),
For(&'a ast::StmtFor),
Expand All @@ -969,12 +987,6 @@ enum CurrentAssignment<'a> {
WithItem(&'a ast::WithItem),
}

impl<'a> From<&'a ast::StmtAssign> for CurrentAssignment<'a> {
fn from(value: &'a ast::StmtAssign) -> Self {
Self::Assign(value)
}
}

impl<'a> From<&'a ast::StmtAnnAssign> for CurrentAssignment<'a> {
fn from(value: &'a ast::StmtAnnAssign) -> Self {
Self::AnnAssign(value)
Expand Down
22 changes: 16 additions & 6 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ pub(crate) struct ImportFromDefinitionNodeRef<'a> {
pub(crate) struct AssignmentDefinitionNodeRef<'a> {
pub(crate) assignment: &'a ast::StmtAssign,
pub(crate) target: &'a ast::ExprName,
pub(crate) target_index: usize,
}

#[derive(Copy, Clone, Debug)]
Expand Down Expand Up @@ -203,12 +204,15 @@ impl DefinitionNodeRef<'_> {
DefinitionNodeRef::NamedExpression(named) => {
DefinitionKind::NamedExpression(AstNodeRef::new(parsed, named))
}
DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef { assignment, target }) => {
DefinitionKind::Assignment(AssignmentDefinitionKind {
assignment: AstNodeRef::new(parsed.clone(), assignment),
target: AstNodeRef::new(parsed, target),
})
}
DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef {
assignment,
target,
target_index,
}) => DefinitionKind::Assignment(AssignmentDefinitionKind {
assignment: AstNodeRef::new(parsed.clone(), assignment),
target: AstNodeRef::new(parsed, target),
target_index,
}),
DefinitionNodeRef::AnnotatedAssignment(assign) => {
DefinitionKind::AnnotatedAssignment(AstNodeRef::new(parsed, assign))
}
Expand Down Expand Up @@ -276,6 +280,7 @@ impl DefinitionNodeRef<'_> {
Self::Assignment(AssignmentDefinitionNodeRef {
assignment: _,
target,
target_index: _,
}) => target.into(),
Self::AnnotatedAssignment(node) => node.into(),
Self::AugmentedAssignment(node) => node.into(),
Expand Down Expand Up @@ -381,6 +386,7 @@ impl ImportFromDefinitionKind {
pub struct AssignmentDefinitionKind {
assignment: AstNodeRef<ast::StmtAssign>,
target: AstNodeRef<ast::ExprName>,
target_index: usize,
}

impl AssignmentDefinitionKind {
Expand All @@ -391,6 +397,10 @@ impl AssignmentDefinitionKind {
pub(crate) fn target(&self) -> &ast::ExprName {
self.target.node()
}

pub(crate) fn target_index(&self) -> usize {
self.target_index
}
}

#[derive(Clone, Debug)]
Expand Down
13 changes: 13 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ impl<'db> Type<'db> {
matches!(self, Type::Never)
}

pub const fn as_tuple_type(&self) -> Option<&TupleType<'db>> {
match self {
Type::Tuple(tuple_type) => Some(tuple_type),
_ => None,
}
}

pub const fn into_class_type(self) -> Option<ClassType<'db>> {
match self {
Type::Class(class_type) => Some(class_type),
Expand Down Expand Up @@ -672,3 +679,9 @@ pub struct TupleType<'db> {
#[return_ref]
elements: Box<[Type<'db>]>,
}

impl<'db> TupleType<'db> {
pub fn get(&self, db: &'db dyn Db, index: usize) -> Option<&Type<'db>> {
self.elements(db).get(index)
}
}
35 changes: 33 additions & 2 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ impl<'db> TypeInferenceBuilder<'db> {
DefinitionKind::Assignment(assignment) => {
self.infer_assignment_definition(
assignment.target(),
assignment.target_index(),
assignment.assignment(),
definition,
);
Expand Down Expand Up @@ -957,19 +958,33 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_assignment_definition(
&mut self,
target: &ast::ExprName,
target_index: usize,
assignment: &ast::StmtAssign,
definition: Definition<'db>,
) {
let expression = self.index.expression(assignment.value.as_ref());
let result = infer_expression_types(self.db, expression);
self.extend(result);

let value_ty = self
.types
.expression_ty(assignment.value.scoped_ast_id(self.db, self.scope));

let target_ty = if let Some(tuple_type) = value_ty.as_tuple_type() {
// TODO: when does this happen?
tuple_type
.get(self.db, target_index)
.copied()
.unwrap_or(Type::Unknown)
} else {
value_ty
};

self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope), value_ty);
self.types.definitions.insert(definition, value_ty);
.insert(target.scoped_ast_id(self.db, self.scope), target_ty);

self.types.definitions.insert(definition, target_ty);
}

fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) {
Expand Down Expand Up @@ -4057,6 +4072,22 @@ mod tests {
Ok(())
}

#[test]
fn unpacked_tuple_assignment() {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
x, y = 1, 2
",
)
.unwrap();

assert_public_ty(&db, "/src/a.py", "x", "Literal[1]");
assert_public_ty(&db, "/src/a.py", "y", "Literal[2]");
}

#[test]
fn list_literal() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down

0 comments on commit 9ea561f

Please sign in to comment.