diff --git a/core/codegen/tests/route-format.rs b/core/codegen/tests/route-format.rs index 1c991b84fc..d43e8a15d7 100644 --- a/core/codegen/tests/route-format.rs +++ b/core/codegen/tests/route-format.rs @@ -61,7 +61,7 @@ fn test_formats() { assert_eq!(response.into_string().unwrap(), "plain"); let response = client.put("/").header(ContentType::HTML).dispatch(); - assert_eq!(response.status(), Status::NotFound); + assert_eq!(response.status(), Status::MethodNotAllowed); } // Test custom formats. diff --git a/core/lib/src/router/collider.rs b/core/lib/src/router/collider.rs index df343a85b6..16ffa2f7bb 100644 --- a/core/lib/src/router/collider.rs +++ b/core/lib/src/router/collider.rs @@ -38,12 +38,19 @@ impl Route { /// request query string, though in any position. /// - If no query in route, requests with/without queries match. #[doc(hidden)] - pub fn matches(&self, req: &Request<'_>) -> bool { + pub fn matches_by_method(&self, req: &Request<'_>) -> bool { self.method == req.method() && paths_match(self, req) && queries_match(self, req) && formats_match(self, req) } + + /// Match against any method. + #[doc(hidden)] + pub fn match_any(&self, req: &Request<'_>) -> bool { + paths_match(self, req) && queries_match(self, req) && formats_match(self, req) + } + } fn paths_collide(route: &Route, other: &Route) -> bool { @@ -416,7 +423,7 @@ mod tests { route.format = Some(mt_str.parse::().unwrap()); } - route.matches(&req) + route.matches_by_method(&req) } #[test] @@ -471,7 +478,7 @@ mod tests { let rocket = Rocket::custom(Config::default()); let req = Request::new(&rocket, Get, Origin::parse(a).expect("valid URI")); let route = Route::ranked(0, Get, b.to_string(), dummy); - route.matches(&req) + route.matches_by_method(&req) } #[test] diff --git a/core/lib/src/router/mod.rs b/core/lib/src/router/mod.rs index 8538394395..81b19fc73e 100644 --- a/core/lib/src/router/mod.rs +++ b/core/lib/src/router/mod.rs @@ -17,6 +17,7 @@ pub struct Router { routes: HashMap>, } + impl Router { pub fn new() -> Router { Router { routes: HashMap::new() } @@ -31,15 +32,26 @@ impl Router { entries.insert(i, route); } - pub fn route<'b>(&'b self, req: &Request<'_>) -> Vec<&'b Route> { - // Note that routes are presorted by rank on each `add`. - let matches = self.routes.get(&req.method()).map_or(vec![], |routes| { - routes.iter() - .filter(|r| r.matches(req)) - .collect() - }); + // Param `restrict` will restrict the route matching by the http method of `req` + // With `restrict` == false and `req` method == GET both will be matched: + // - GET hello/world <- + // - POST hello/world <- + // With `restrict` == true and `req` method == GET only the first one will be matched: + // - GET foo/bar <- + // - POST foo/bar + pub fn route<'b>(&'b self, req: &Request<'_>, restrict: bool) -> Vec<&'b Route> { + let mut matches = Vec::new(); + for (_method, routes_vec) in self.routes.iter() { + for _route in routes_vec { + if _route.matches_by_method(req) { + matches.push(_route); + } else if !restrict && _route.match_any(req){ + matches.push(_route); + } + } + } - trace_!("Routing the request: {}", req); + trace_!("Routing(restrict: {}): {}", restrict, req); trace_!("All matches: {:?}", matches); matches } @@ -245,7 +257,7 @@ mod test { fn route<'a>(router: &'a Router, method: Method, uri: &str) -> Option<&'a Route> { let rocket = Rocket::custom(Config::default()); let request = Request::new(&rocket, method, Origin::parse(uri).unwrap()); - let matches = router.route(&request); + let matches = router.route(&request, false); if matches.len() > 0 { Some(matches[0]) } else { @@ -256,7 +268,7 @@ mod test { fn matches<'a>(router: &'a Router, method: Method, uri: &str) -> Vec<&'a Route> { let rocket = Rocket::custom(Config::default()); let request = Request::new(&rocket, method, Origin::parse(uri).unwrap()); - router.route(&request) + router.route(&request, false) } #[test] @@ -294,9 +306,9 @@ mod test { #[test] fn test_err_routing() { let router = router_with_routes(&["/hello"]); - assert!(route(&router, Put, "/hello").is_none()); - assert!(route(&router, Post, "/hello").is_none()); - assert!(route(&router, Options, "/hello").is_none()); + assert!(route(&router, Put, "/hello").is_some()); + assert!(route(&router, Post, "/hello").is_some()); + assert!(route(&router, Options, "/hello").is_some()); assert!(route(&router, Get, "/hell").is_none()); assert!(route(&router, Get, "/hi").is_none()); assert!(route(&router, Get, "/hello/there").is_none()); @@ -304,20 +316,19 @@ mod test { assert!(route(&router, Get, "/hillo").is_none()); let router = router_with_routes(&["/"]); - assert!(route(&router, Put, "/hello").is_none()); - assert!(route(&router, Post, "/hello").is_none()); - assert!(route(&router, Options, "/hello").is_none()); + assert!(route(&router, Put, "/hello").is_some()); + assert!(route(&router, Post, "/hello").is_some()); + assert!(route(&router, Options, "/hello").is_some()); assert!(route(&router, Get, "/hello/there").is_none()); assert!(route(&router, Get, "/hello/i").is_none()); let router = router_with_routes(&["//"]); + assert!(route(&router, Put, "/a/b").is_some()); + assert!(route(&router, Put, "/hello/hi").is_some()); assert!(route(&router, Get, "/a/b/c").is_none()); assert!(route(&router, Get, "/a").is_none()); assert!(route(&router, Get, "/a/").is_none()); assert!(route(&router, Get, "/a/b/c/d").is_none()); - assert!(route(&router, Put, "/hello/hi").is_none()); - assert!(route(&router, Put, "/a/b").is_none()); - assert!(route(&router, Put, "/a/b").is_none()); } macro_rules! assert_ranked_routes { diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index b4ed0687e1..1263f3700a 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -263,10 +263,11 @@ impl Rocket { ) -> impl Future> + 's { async move { // Go through the list of matching routes until we fail or succeed. - let matches = self.router.route(request); - for route in matches { + let method_matches = self.router.route(request, true); + for route in method_matches { // Retrieve and set the requests parameters. info_!("Matched: {}", route); + request.set_route(route); // Dispatch the request to the handler. @@ -283,7 +284,21 @@ impl Rocket { } error_!("No matching routes for {}.", request); - Outcome::Forward(data) + + // Find if a similar route exists + let match_any = self.router.route(request, false); + + for route in match_any { + if &request.method() != &Method::Head // Must pass HEAD requests foward + && &request.method() != &route.method + { + info_!("{}", Paint::yellow("A similar route exists: ").bold()); + info_!(" - {}", Paint::yellow(&route).bold()); + return Outcome::Failure(Status::MethodNotAllowed); + } + } + + Outcome::forward(data) } } diff --git a/core/lib/tests/form_method-issue-45.rs b/core/lib/tests/form_method-issue-45.rs index c60b5224e1..adbbb525cf 100644 --- a/core/lib/tests/form_method-issue-45.rs +++ b/core/lib/tests/form_method-issue-45.rs @@ -37,6 +37,6 @@ mod tests { .body("_method=patch&form_data=Form+data") .dispatch(); - assert_eq!(response.status(), Status::NotFound); + assert_eq!(response.status(), Status::MethodNotAllowed); } } diff --git a/core/lib/tests/return_method_not_allowed_issue_1224.rs b/core/lib/tests/return_method_not_allowed_issue_1224.rs new file mode 100644 index 0000000000..9765f4bb4c --- /dev/null +++ b/core/lib/tests/return_method_not_allowed_issue_1224.rs @@ -0,0 +1,101 @@ +#[macro_use] +extern crate rocket; + +#[get("/index")] +fn get_index() -> &'static str { + "GET index :)" +} + +#[post("/index")] +fn post_index() -> &'static str { + "POST index :)" +} + +#[post("/hello")] +fn post_hello() -> &'static str { + "POST Hello, world!" +} + +mod tests { + use super::*; + use rocket::http::Status; + use rocket::local::blocking::Client; + + #[test] + fn test_http_200_when_same_route_with_diff_meth() { + let rocket = rocket::ignite() + .mount("/", routes![get_index, post_index]); + + let client = Client::tracked(rocket).unwrap(); + + let response = client.post("/index").dispatch(); + + assert_eq!(response.status(), Status::Ok); + } + + #[test] + fn test_http_200_when_head_request() { + let rocket = rocket::ignite().mount("/", routes![get_index]); + + let client = Client::tracked(rocket).unwrap(); + + let response = client.head("/index").dispatch(); + + assert_eq!(response.status(), Status::Ok); + } + + #[test] + fn test_http_200_when_route_is_ok() { + let rocket = rocket::ignite().mount("/", routes![get_index]); + + let client = Client::tracked(rocket).unwrap(); + + let response = client.get("/index").dispatch(); + + assert_eq!(response.status(), Status::Ok); + } + + #[test] + fn test_http_200_with_params() { + let rocket = rocket::ignite().mount("/", routes![get_index]); + + let client = Client::tracked(rocket).unwrap(); + + let response = client.get("/index?say=hi").dispatch(); + + assert_eq!(response.status(), Status::Ok); + } + + #[test] + fn test_http_404_when_route_not_match() { + let rocket = rocket::ignite().mount("/", routes![get_index]); + + let client = Client::tracked(rocket).unwrap(); + + let response = client.get("/abc").dispatch(); + + assert_eq!(response.status(), Status::NotFound); + } + + #[test] + fn test_http_405_when_method_not_match() { + let rocket = rocket::ignite().mount("/", routes![get_index]); + + let client = Client::tracked(rocket).unwrap(); + + let response = client.post("/index").dispatch(); + + assert_eq!(response.status(), Status::MethodNotAllowed); + } + + #[test] + fn test_http_405_with_params() { + let rocket = rocket::ignite().mount("/", routes![post_hello]); + + let client = Client::tracked(rocket).unwrap(); + + let response = client.get("/hello?say=hi").dispatch(); + + assert_eq!(response.status(), Status::MethodNotAllowed); + } +} diff --git a/examples/errors/src/tests.rs b/examples/errors/src/tests.rs index 7d5e6c244a..18dda016f2 100644 --- a/examples/errors/src/tests.rs +++ b/examples/errors/src/tests.rs @@ -46,7 +46,7 @@ fn forced_error_and_default_catcher() { fn test_hello_invalid_age() { let client = Client::tracked(super::rocket()).unwrap(); - for &(name, age) in &[("Ford", -129), ("Trillian", 128)] { + for &(name, age) in &[("Ford", "s"), ("Trillian", "f")] { let request = client.get(format!("/hello/{}/{}", name, age)); let expected = super::not_found(request.inner()); let response = request.dispatch(); diff --git a/examples/handlebars_templates/src/main.rs b/examples/handlebars_templates/src/main.rs index 04d3cafea6..533de533a9 100644 --- a/examples/handlebars_templates/src/main.rs +++ b/examples/handlebars_templates/src/main.rs @@ -47,6 +47,13 @@ fn not_found(req: &Request<'_>) -> Template { Template::render("error/404", &map) } +#[catch(405)] +fn method_not_allowed(req: &Request<'_>) -> Template { + let mut map = std::collections::HashMap::new(); + map.insert("path", req.uri().path()); + Template::render("error/405", &map) +} + use self::handlebars::{Helper, Handlebars, Context, RenderContext, Output, HelperResult, JsonRender}; fn wow_helper( @@ -69,7 +76,7 @@ fn wow_helper( fn rocket() -> rocket::Rocket { rocket::ignite() .mount("/", routes![index, hello, about]) - .register(catchers![not_found]) + .register(catchers![not_found, method_not_allowed]) .attach(Template::custom(|engines| { engines.handlebars.register_helper("wow", Box::new(wow_helper)); })) diff --git a/examples/handlebars_templates/src/tests.rs b/examples/handlebars_templates/src/tests.rs index 998a364a5d..4991af5f8c 100644 --- a/examples/handlebars_templates/src/tests.rs +++ b/examples/handlebars_templates/src/tests.rs @@ -31,9 +31,9 @@ fn test_root() { dispatch!(*method, "/", |client, response| { let mut map = std::collections::HashMap::new(); map.insert("path", "/"); - let expected = Template::show(client.rocket(), "error/404", &map).unwrap(); + let expected = Template::show(client.rocket(), "error/405", &map).unwrap(); - assert_eq!(response.status(), Status::NotFound); + assert_eq!(response.status(), Status::MethodNotAllowed); assert_eq!(response.into_string(), Some(expected)); }); } diff --git a/examples/handlebars_templates/templates/error/405.hbs b/examples/handlebars_templates/templates/error/405.hbs new file mode 100644 index 0000000000..a4df7ca0d4 --- /dev/null +++ b/examples/handlebars_templates/templates/error/405.hbs @@ -0,0 +1,17 @@ + + + + + 405 Method Not Allowed + + +
+

405: Method Not Allowed

+

The request method is not supported for the requested resource.

+
+
+
+ Rocket +
+ + diff --git a/examples/tera_templates/src/main.rs b/examples/tera_templates/src/main.rs index 3f617773ce..1c225a6043 100644 --- a/examples/tera_templates/src/main.rs +++ b/examples/tera_templates/src/main.rs @@ -32,10 +32,17 @@ fn not_found(req: &Request<'_>) -> Template { Template::render("error/404", &map) } +#[catch(405)] +fn method_not_allowed(req: &Request<'_>) -> Template { + let mut map = HashMap::new(); + map.insert("path", req.uri().path()); + Template::render("error/405", &map) +} + #[launch] fn rocket() -> rocket::Rocket { rocket::ignite() .mount("/", routes![index, get]) .attach(Template::fairing()) - .register(catchers![not_found]) + .register(catchers![not_found, method_not_allowed]) } diff --git a/examples/tera_templates/src/tests.rs b/examples/tera_templates/src/tests.rs index 05f386b7f9..7b17ebe36a 100644 --- a/examples/tera_templates/src/tests.rs +++ b/examples/tera_templates/src/tests.rs @@ -30,9 +30,9 @@ fn test_root() { dispatch!(*method, "/", |client, response| { let mut map = std::collections::HashMap::new(); map.insert("path", "/"); - let expected = Template::show(client.rocket(), "error/404", &map).unwrap(); + let expected = Template::show(client.rocket(), "error/405", &map).unwrap(); - assert_eq!(response.status(), Status::NotFound); + assert_eq!(response.status(), Status::MethodNotAllowed); assert_eq!(response.into_string(), Some(expected)); }); } diff --git a/examples/tera_templates/templates/error/405.html.tera b/examples/tera_templates/templates/error/405.html.tera new file mode 100644 index 0000000000..e290b55728 --- /dev/null +++ b/examples/tera_templates/templates/error/405.html.tera @@ -0,0 +1,17 @@ + + + + + 405 Method Not Allowed + + +
+

405: Method Not Allowed

+

The request method is not supported for the requested resource.

+
+
+
+ Rocket +
+ + \ No newline at end of file