Skip to content

Commit

Permalink
add database as a param to create_*_client
Browse files Browse the repository at this point in the history
- no tests are provided: those api will be tested by section tablespaces

Change-Id: Ice3a85edb7fddf78c617bbbfa3d461feb381c899
  • Loading branch information
s-kipnis committed Nov 16, 2023
1 parent 647b674 commit 9b8a664
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 63 deletions.
82 changes: 26 additions & 56 deletions packages/check-sql/src/ms_sql/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use crate::ms_sql::queries;
use anyhow::Result;
use futures::stream::{self, StreamExt};

use tiberius::{AuthMethod, Config, Query, Row, SqlBrowser};
#[cfg(windows)]
use tiberius::SqlBrowser;
use tiberius::{AuthMethod, Config, Query, Row};
use tokio::net::TcpStream;
use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};

Expand Down Expand Up @@ -92,7 +94,7 @@ impl InstanceEngine {
) -> String {
let mut result = String::new();
let instance_section = Section::new(INSTANCE_SECTION_NAME); // this is important section always present
match self.create_client(ms_sql.auth(), ms_sql.conn()).await {
match self.create_client(ms_sql.auth(), ms_sql.conn(), None).await {
Ok(mut client) => {
for section in sections {
result += &section.to_header();
Expand Down Expand Up @@ -128,6 +130,7 @@ impl InstanceEngine {
&self,
auth: &config::ms_sql::Authentication,
conn: &config::ms_sql::Connection,
database: Option<String>,
) -> Result<Client> {
let client = match auth.auth_type() {
config::ms_sql::AuthType::SqlServer | config::ms_sql::AuthType::Windows => {
Expand All @@ -136,6 +139,7 @@ impl InstanceEngine {
conn.hostname(),
self.port().unwrap_or(defaults::STANDARD_PORT),
credentials,
database,
)
.await?
} else {
Expand All @@ -145,7 +149,7 @@ impl InstanceEngine {

#[cfg(windows)]
config::ms_sql::AuthType::Integrated => {
create_local_instance_client(&self.name, conn.sql_browser_port()).await?
create_local_instance_client(&self.name, conn.sql_browser_port(), None).await?
}

_ => anyhow::bail!("Not supported authorization type"),
Expand Down Expand Up @@ -453,14 +457,14 @@ async fn create_client_from_config(
let client = match auth.auth_type() {
config::ms_sql::AuthType::SqlServer | config::ms_sql::AuthType::Windows => {
if let Some(credentials) = obtain_config_credentials(auth) {
create_remote_client(conn.hostname(), conn.port(), credentials).await?
create_remote_client(conn.hostname(), conn.port(), credentials, None).await?
} else {
anyhow::bail!("Not provided credentials")
}
}

#[cfg(windows)]
config::ms_sql::AuthType::Integrated => create_local_client().await?,
config::ms_sql::AuthType::Integrated => create_local_client(None).await?,

_ => anyhow::bail!("Not supported authorization type"),
};
Expand Down Expand Up @@ -515,12 +519,14 @@ pub async fn create_remote_client(
host: &str,
port: u16,
credentials: Credentials<'_>,
database: Option<String>,
) -> Result<Client> {
match _create_remote_client(
host,
port,
&credentials,
tiberius::EncryptionLevel::Required,
&database,
)
.await
{
Expand All @@ -536,6 +542,7 @@ pub async fn create_remote_client(
port,
&credentials,
tiberius::EncryptionLevel::NotSupported,
&database,
)
.await?)
}
Expand All @@ -555,12 +562,16 @@ pub async fn _create_remote_client(
port: u16,
credentials: &Credentials<'_>,
encryption: tiberius::EncryptionLevel,
database: &Option<String>,
) -> Result<Client> {
let mut config = Config::new();

config.host(host);
config.port(port);
config.encryption(encryption);
if let Some(db) = database {
config.database(db);
}
config.authentication(match credentials {
Credentials::SqlServer { user, password } => AuthMethod::sql_server(user, password),
#[cfg(windows)]
Expand All @@ -582,60 +593,14 @@ pub async fn _create_remote_client(
Ok(Client::connect(config, tcp.compat_write()).await?)
}

/// Create `remote` connection to MS SQL `instance`
///
/// # Arguments
///
/// * `host` - Hostname of MS SQL server
/// * `port` - Port of MS SQL server BROWSER, 1434 - default
/// * `credentials` - defines connection type and credentials itself
/// * `instance_name` - name of the instance to connect to
pub async fn create_remote_instance_client(
instance_name: &str,
host: &str,
sql_browser_port: Option<u16>,
credentials: Credentials<'_>,
) -> anyhow::Result<Client> {
let mut config = Config::new();

config.host(host);
// The default port of SQL Browser
config.port(sql_browser_port.unwrap_or(defaults::SQL_BROWSER_PORT));
config.authentication(match credentials {
Credentials::SqlServer { user, password } => AuthMethod::sql_server(user, password),
#[cfg(windows)]
Credentials::Windows { user, password } => AuthMethod::windows(user, password),
#[cfg(unix)]
Credentials::Windows {
user: _,
password: _,
} => anyhow::bail!("not supported"),
});

// The name of the database server instance.
config.instance_name(instance_name);

// on production, it is not a good idea to do this
config.trust_cert();

// This will create a new `TcpStream` from `async-std`, connected to the
// right port of the named instance.
let tcp = TcpStream::connect_named(&config)
.await
.map_err(|e| anyhow::anyhow!("{} {}", SQL_TCP_ERROR_TAG, e))?;

// And from here on continue the connection process in a normal way.
let s = Client::connect(config, tcp.compat_write())
.await
.map_err(|e| anyhow::anyhow!("{} {}", SQL_LOGIN_ERROR_TAG, e))?;
Ok(s)
}

/// Check `local` (Integrated) connection to MS SQL
#[cfg(windows)]
pub async fn create_local_client() -> Result<Client> {
pub async fn create_local_client(database: Option<String>) -> Result<Client> {
let mut config = Config::new();

if let Some(db) = database {
config.database(db);
}
config.authentication(AuthMethod::Integrated);
config.trust_cert(); // on production, it is not a good idea to do this

Expand All @@ -649,7 +614,7 @@ pub async fn create_local_client() -> Result<Client> {
}

#[cfg(unix)]
pub async fn create_local_client() -> Result<Client> {
pub async fn create_local_client(_database: Option<String>) -> Result<Client> {
anyhow::bail!("not supported");
}

Expand All @@ -663,13 +628,17 @@ pub async fn create_local_client() -> Result<Client> {
pub async fn create_local_instance_client(
instance_name: &str,
sql_browser_port: Option<u16>,
database: Option<String>,
) -> anyhow::Result<Client> {
let mut config = Config::new();

config.host("localhost");
// The default port of SQL Browser
config.port(sql_browser_port.unwrap_or(defaults::SQL_BROWSER_PORT));
config.authentication(AuthMethod::Integrated);
if let Some(db) = database {
config.database(db);
}

// The name of the database server instance.
config.instance_name(instance_name);
Expand Down Expand Up @@ -700,6 +669,7 @@ pub async fn create_local_instance_client(
pub async fn create_local_instance_client(
_instance_name: &str,
_port: Option<u16>,
_database: Option<String>,
) -> anyhow::Result<Client> {
anyhow::bail!("not supported");
}
Expand Down
1 change: 1 addition & 0 deletions packages/check-sql/tests/common/tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pub async fn create_remote_client(endpoint: &SqlDbEndpoint) -> Result<Client> {
user: &endpoint.user,
password: &endpoint.pwd,
},
None,
)
.await
}
Expand Down
15 changes: 8 additions & 7 deletions packages/check-sql/tests/test_ms_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn expected_instances() -> Vec<String> {
#[cfg(windows)]
#[tokio::test(flavor = "multi_thread")]
async fn test_local_connection() {
assert!(api::create_local_client().await.is_ok());
assert!(api::create_local_client(None).await.is_ok());
}

fn is_instance_good(i: &InstanceEngine) -> bool {
Expand All @@ -37,7 +37,7 @@ fn is_instance_good(i: &InstanceEngine) -> bool {
#[cfg(windows)]
#[tokio::test(flavor = "multi_thread")]
async fn test_find_all_instances_local() {
let mut client = api::create_local_client().await.unwrap();
let mut client = api::create_local_client(None).await.unwrap();
let instances = api::detect_instance_engines(&mut client).await.unwrap();
let all: Vec<InstanceEngine> = [&instances.0[..], &instances.1[..]].concat();
assert!(all.iter().all(is_instance_good), "{:?}", all);
Expand All @@ -50,7 +50,7 @@ async fn test_find_all_instances_local() {
#[cfg(windows)]
#[tokio::test(flavor = "multi_thread")]
async fn test_validate_all_instances_local() {
let mut client = api::create_local_client().await.unwrap();
let mut client = api::create_local_client(None).await.unwrap();
let instances = api::detect_instance_engines(&mut client).await.unwrap();
let names: Vec<String> = [&instances.0[..], &instances.1[..]]
.concat()
Expand All @@ -59,7 +59,7 @@ async fn test_validate_all_instances_local() {
.collect();

for name in names {
let c = api::create_local_instance_client(&name, None).await;
let c = api::create_local_instance_client(&name, None, None).await;
match c {
Ok(mut c) => assert!(tools::run_get_version(&mut c).await.is_some()),
Err(e) if e.to_string().starts_with(api::SQL_LOGIN_ERROR_TAG) => {
Expand All @@ -84,7 +84,8 @@ async fn test_remote_connection() {
api::Credentials::SqlServer {
user: &endpoint.user,
password: &endpoint.pwd,
}
},
None
)
.await
.is_ok());
Expand Down Expand Up @@ -131,7 +132,7 @@ async fn test_validate_all_instances_remote() {
.unwrap()
.unwrap();
for i in is {
match i.create_client(cfg.auth(), cfg.conn()).await {
match i.create_client(cfg.auth(), cfg.conn(), None).await {
Ok(mut c) => {
assert!(
tools::run_get_version(&mut c).await.is_some()
Expand Down Expand Up @@ -236,7 +237,7 @@ mssql:
.unwrap();

for i in is {
let c = i.create_client(ms_sql.auth(), ms_sql.conn()).await;
let c = i.create_client(ms_sql.auth(), ms_sql.conn(), None).await;
match c {
Ok(mut c) => assert!(
tools::run_get_version(&mut c).await.is_some()
Expand Down

0 comments on commit 9b8a664

Please sign in to comment.