1- //go:build ccm  
2- // +build ccm  
1+ //go:build tc  
2+ // +build tc  
33
44/* 
55 * Licensed to the Apache Software Foundation (ASF) under one 
2828package  gocql
2929
3030import  (
31+ 	"context" 
3132	"fmt" 
3233	"sync" 
3334	"testing" 
3435	"time" 
35- 
36- 	"github.com/gocql/gocql/internal/ccm" 
3736)
3837
3938type  TestHostFilter  struct  {
4039	mu            sync.Mutex 
41- 	allowedHosts  map [string ]ccm. Host 
40+ 	allowedHosts  map [string ]TChost 
4241}
4342
4443func  (f  * TestHostFilter ) Accept (h  * HostInfo ) bool  {
@@ -48,37 +47,27 @@ func (f *TestHostFilter) Accept(h *HostInfo) bool {
4847	return  ok 
4948}
5049
51- func  (f  * TestHostFilter ) SetAllowedHosts (hosts  map [string ]ccm. Host ) {
50+ func  (f  * TestHostFilter ) SetAllowedHosts (hosts  map [string ]TChost ) {
5251	f .mu .Lock ()
5352	defer  f .mu .Unlock ()
5453	f .allowedHosts  =  hosts 
5554}
5655
5756func  TestControlConn_ReconnectRefreshesRing (t  * testing.T ) {
58- 	if  err  :=  ccm .AllUp (); err  !=  nil  {
59- 		t .Fatal (err )
60- 	}
61- 
62- 	allCcmHosts , err  :=  ccm .Status ()
63- 	if  err  !=  nil  {
64- 		t .Fatal (err )
65- 	}
57+ 	ctx  :=  context .Background ()
6658
67- 	if  len (allCcmHosts ) <  2  {
59+ 	if  len (cassNodes ) <  2  {
6860		t .Skip ("this test requires at least 2 nodes" )
6961	}
7062
71- 	allAllowedHosts  :=  map [string ]ccm.Host {}
72- 	var  firstNode  * ccm.Host 
73- 	for  _ , node  :=  range  allCcmHosts  {
74- 		if  firstNode  ==  nil  {
75- 			firstNode  =  & node 
76- 		}
63+ 	allAllowedHosts  :=  map [string ]TChost {}
64+ 	for  _ , node  :=  range  cassNodes  {
7765		allAllowedHosts [node .Addr ] =  node 
7866	}
7967
80- 	allowedHosts  :=  map [string ]ccm.Host {
81- 		firstNode .Addr : * firstNode ,
68+ 	firstNode  :=  cassNodes ["node1" ]
69+ 	allowedHosts  :=  map [string ]TChost {
70+ 		firstNode .Addr : firstNode ,
8271	}
8372
8473	testFilter  :=  & TestHostFilter {allowedHosts : allowedHosts }
@@ -99,9 +88,9 @@ func TestControlConn_ReconnectRefreshesRing(t *testing.T) {
9988	ccHost  :=  controlConnection .host 
10089
10190	var  ccHostName  string 
102- 	for  _ , node  :=  range  allCcmHosts  {
91+ 	for  name , node  :=  range  cassNodes  {
10392		if  node .Addr  ==  ccHost .ConnectAddress ().String () {
104- 			ccHostName  =  node . Name 
93+ 			ccHostName  =  name 
10594			break 
10695		}
10796	}
@@ -110,25 +99,15 @@ func TestControlConn_ReconnectRefreshesRing(t *testing.T) {
11099		t .Fatal ("could not find name of control host" )
111100	}
112101
113- 	if  err  :=  ccm . NodeDown ( ccHostName ); err  !=  nil  {
102+ 	if  err  :=  cassNodes [ ccHostName ]. TC . Stop ( ctx ,  nil ); err  !=  nil  {
114103		t .Fatal ()
115104	}
116105
117- 	defer  func () {
118- 		ccmStatus , err  :=  ccm .Status ()
119- 		if  err  !=  nil  {
120- 			t .Logf ("could not bring nodes back up after test: %v" , err )
121- 			return 
122- 		}
123- 		for  _ , node  :=  range  ccmStatus  {
124- 			if  node .State  ==  ccm .NodeStateDown  {
125- 				err  =  ccm .NodeUp (node .Name )
126- 				if  err  !=  nil  {
127- 					t .Logf ("could not bring node %v back up after test: %v" , node .Name , err )
128- 				}
129- 			}
106+ 	defer  func (ctx  context.Context ) {
107+ 		if  err  :=  restoreCluster (ctx ); err  !=  nil  {
108+ 			t .Fatalf ("couldn't restore a cluster : %v" , err )
130109		}
131- 	}()
110+ 	}(ctx )
132111
133112	assertNodeDown  :=  func () error  {
134113		hosts  :=  session .ring .currentHosts ()
@@ -159,19 +138,19 @@ func TestControlConn_ReconnectRefreshesRing(t *testing.T) {
159138	}
160139
161140	if  assertErr  !=  nil  {
162- 		t .Fatal (err )
141+ 		t .Fatal (assertErr )
163142	}
164143
165144	testFilter .SetAllowedHosts (allAllowedHosts )
166145
167- 	if  err  =   ccm . NodeUp ( ccHostName ); err  !=  nil  {
146+ 	if  err  :=   restoreCluster ( ctx ); err  !=  nil  {
168147		t .Fatal (err )
169148	}
170149
171150	assertNodeUp  :=  func () error  {
172151		hosts  :=  session .ring .currentHosts ()
173- 		if  len (hosts ) !=  len (allCcmHosts ) {
174- 			return  fmt .Errorf ("expected %v hosts in ring but there were %v" , len (allCcmHosts ), len (hosts ))
152+ 		if  len (hosts ) !=  len (cassNodes ) {
153+ 			return  fmt .Errorf ("expected %v hosts in ring but there were %v" , len (ccHostName ), len (hosts ))
175154		}
176155		for  _ , host  :=  range  hosts  {
177156			if  ! host .IsUp () {
@@ -181,8 +160,8 @@ func TestControlConn_ReconnectRefreshesRing(t *testing.T) {
181160		session .pool .mu .RLock ()
182161		poolsLen  :=  len (session .pool .hostConnPools )
183162		session .pool .mu .RUnlock ()
184- 		if  poolsLen  !=  len (allCcmHosts ) {
185- 			return  fmt .Errorf ("expected %v connection pool but there were %v" , len (allCcmHosts ), poolsLen )
163+ 		if  poolsLen  !=  len (cassNodes ) {
164+ 			return  fmt .Errorf ("expected %v connection pool but there were %v" , len (ccHostName ), poolsLen )
186165		}
187166		return  nil 
188167	}
@@ -196,6 +175,6 @@ func TestControlConn_ReconnectRefreshesRing(t *testing.T) {
196175	}
197176
198177	if  assertErr  !=  nil  {
199- 		t .Fatal (err )
178+ 		t .Fatal (assertErr )
200179	}
201180}
0 commit comments