diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index 9f51608b968afb4d6388e194ef6712346113aa49..8f060d7f9a0e107a410d3ffe71722f49059f7bc8 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -1031,7 +1031,7 @@ static inline int netlink_compare(struct rhashtable_compare_arg *arg,
 	const struct netlink_compare_arg *x = arg->key;
 	const struct netlink_sock *nlk = ptr;
 
-	return nlk->rhash_portid != x->portid ||
+	return nlk->portid != x->portid ||
 	       !net_eq(sock_net(&nlk->sk), read_pnet(&x->pnet));
 }
 
@@ -1057,7 +1057,7 @@ static int __netlink_insert(struct netlink_table *table, struct sock *sk)
 {
 	struct netlink_compare_arg arg;
 
-	netlink_compare_arg_init(&arg, sock_net(sk), nlk_sk(sk)->rhash_portid);
+	netlink_compare_arg_init(&arg, sock_net(sk), nlk_sk(sk)->portid);
 	return rhashtable_lookup_insert_key(&table->hash, &arg,
 					    &nlk_sk(sk)->node,
 					    netlink_rhashtable_params);
@@ -1110,8 +1110,8 @@ static int netlink_insert(struct sock *sk, u32 portid)
 
 	lock_sock(sk);
 
-	err = -EBUSY;
-	if (nlk_sk(sk)->portid)
+	err = nlk_sk(sk)->portid == portid ? 0 : -EBUSY;
+	if (nlk_sk(sk)->bound)
 		goto err;
 
 	err = -ENOMEM;
@@ -1119,7 +1119,7 @@ static int netlink_insert(struct sock *sk, u32 portid)
 	    unlikely(atomic_read(&table->hash.nelems) >= UINT_MAX))
 		goto err;
 
-	nlk_sk(sk)->rhash_portid = portid;
+	nlk_sk(sk)->portid = portid;
 	sock_hold(sk);
 
 	err = __netlink_insert(table, sk);
@@ -1135,7 +1135,9 @@ static int netlink_insert(struct sock *sk, u32 portid)
 		goto err;
 	}
 
-	nlk_sk(sk)->portid = portid;
+	/* We need to ensure that the socket is hashed and visible. */
+	smp_wmb();
+	nlk_sk(sk)->bound = portid;
 
 err:
 	release_sock(sk);
@@ -1521,6 +1523,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
 	struct sockaddr_nl *nladdr = (struct sockaddr_nl *)addr;
 	int err;
 	long unsigned int groups = nladdr->nl_groups;
+	bool bound;
 
 	if (addr_len < sizeof(struct sockaddr_nl))
 		return -EINVAL;
@@ -1537,9 +1540,14 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
 			return err;
 	}
 
-	if (nlk->portid)
+	bound = nlk->bound;
+	if (bound) {
+		/* Ensure nlk->portid is up-to-date. */
+		smp_rmb();
+
 		if (nladdr->nl_pid != nlk->portid)
 			return -EINVAL;
+	}
 
 	if (nlk->netlink_bind && groups) {
 		int group;
@@ -1555,7 +1563,10 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
 		}
 	}
 
-	if (!nlk->portid) {
+	/* No need for barriers here as we return to user-space without
+	 * using any of the bound attributes.
+	 */
+	if (!bound) {
 		err = nladdr->nl_pid ?
 			netlink_insert(sk, nladdr->nl_pid) :
 			netlink_autobind(sock);
@@ -1603,7 +1614,10 @@ static int netlink_connect(struct socket *sock, struct sockaddr *addr,
 	    !netlink_allowed(sock, NL_CFG_F_NONROOT_SEND))
 		return -EPERM;
 
-	if (!nlk->portid)
+	/* No need for barriers here as we return to user-space without
+	 * using any of the bound attributes.
+	 */
+	if (!nlk->bound)
 		err = netlink_autobind(sock);
 
 	if (err == 0) {
@@ -2444,10 +2458,13 @@ static int netlink_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
 		dst_group = nlk->dst_group;
 	}
 
-	if (!nlk->portid) {
+	if (!nlk->bound) {
 		err = netlink_autobind(sock);
 		if (err)
 			goto out;
+	} else {
+		/* Ensure nlk is hashed and visible. */
+		smp_rmb();
 	}
 
 	/* It's a really convoluted way for userland to ask for mmaped
@@ -3273,7 +3290,7 @@ static inline u32 netlink_hash(const void *data, u32 len, u32 seed)
 	const struct netlink_sock *nlk = data;
 	struct netlink_compare_arg arg;
 
-	netlink_compare_arg_init(&arg, sock_net(&nlk->sk), nlk->rhash_portid);
+	netlink_compare_arg_init(&arg, sock_net(&nlk->sk), nlk->portid);
 	return jhash2((u32 *)&arg, netlink_compare_arg_len / sizeof(u32), seed);
 }
 
diff --git a/net/netlink/af_netlink.h b/net/netlink/af_netlink.h
index 80b2b7526dfd26641542f4c5868fdd5b830056a0..14437d9b1965dcf3d3f085e4aba1f804bdc6f652 100644
--- a/net/netlink/af_netlink.h
+++ b/net/netlink/af_netlink.h
@@ -25,7 +25,6 @@ struct netlink_ring {
 struct netlink_sock {
 	/* struct sock has to be the first member of netlink_sock */
 	struct sock		sk;
-	u32			rhash_portid;
 	u32			portid;
 	u32			dst_portid;
 	u32			dst_group;
@@ -36,6 +35,7 @@ struct netlink_sock {
 	unsigned long		state;
 	size_t			max_recvmsg_len;
 	wait_queue_head_t	wait;
+	bool			bound;
 	bool			cb_running;
 	struct netlink_callback	cb;
 	struct mutex		*cb_mutex;